]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
go : add temperature options (#2417)
authorBinozo <redacted>
Fri, 20 Sep 2024 12:45:36 +0000 (14:45 +0200)
committerGitHub <redacted>
Fri, 20 Sep 2024 12:45:36 +0000 (15:45 +0300)
* Fixed go cuda bindings building

* Added note to go bindings Readme to build using cuda support

* Added temperature bindings for Go

---------

Co-authored-by: Binozo <redacted>
bindings/go/params.go
bindings/go/pkg/whisper/context.go
bindings/go/pkg/whisper/interface.go

index 9c075b6a2cbd8457ae040d0c5ef2bb748447f555..95c5bfaf9345d0177667d8754e090a659b335ee0 100644 (file)
@@ -131,6 +131,16 @@ func (p *Params) SetEntropyThold(t float32) {
        p.entropy_thold = C.float(t)
 }
 
+func (p *Params) SetTemperature(t float32) {
+       p.temperature = C.float(t)
+}
+
+// Sets the fallback temperature incrementation
+// Pass -1.0 to disable this feature
+func (p *Params) SetTemperatureFallback(t float32) {
+       p.temperature_inc = C.float(t)
+}
+
 // Set initial prompt
 func (p *Params) SetInitialPrompt(prompt string) {
        p.initial_prompt = C.CString(prompt)
@@ -162,6 +172,8 @@ func (p *Params) String() string {
        str += fmt.Sprintf(" audio_ctx=%d", p.audio_ctx)
        str += fmt.Sprintf(" initial_prompt=%s", C.GoString(p.initial_prompt))
        str += fmt.Sprintf(" entropy_thold=%f", p.entropy_thold)
+       str += fmt.Sprintf(" temperature=%f", p.temperature)
+       str += fmt.Sprintf(" temperature_inc=%f", p.temperature_inc)
        str += fmt.Sprintf(" beam_size=%d", p.beam_search.beam_size)
        if p.translate {
                str += " translate"
index dc34aa18bb87f66736b51b166c92cfaccf7caca5..06376b1b870f00b7a83e8e1884e4581bc4e23b35 100644 (file)
@@ -140,6 +140,17 @@ func (context *context) SetEntropyThold(t float32) {
        context.params.SetEntropyThold(t)
 }
 
+// Set Temperature
+func (context *context) SetTemperature(t float32) {
+       context.params.SetTemperature(t)
+}
+
+// Set the fallback temperature incrementation
+// Pass -1.0 to disable this feature
+func (context *context) SetTemperatureFallback(t float32) {
+       context.params.SetTemperatureFallback(t)
+}
+
 // Set initial prompt
 func (context *context) SetInitialPrompt(prompt string) {
        context.params.SetInitialPrompt(prompt)
index 6eb692ef6102b351fff7452dcb19dbd0b1ecfe78..8981b1a8116669bdb5d92c42c2ad2c55483ba4fd 100644 (file)
@@ -38,20 +38,22 @@ type Context interface {
        IsMultilingual() bool     // Return true if the model is multilingual.
        Language() string         // Get language
 
-       SetOffset(time.Duration)        // Set offset
-       SetDuration(time.Duration)      // Set duration
-       SetThreads(uint)                // Set number of threads to use
-       SetSplitOnWord(bool)            // Set split on word flag
-       SetTokenThreshold(float32)      // Set timestamp token probability threshold
-       SetTokenSumThreshold(float32)   // Set timestamp token sum probability threshold
-       SetMaxSegmentLength(uint)       // Set max segment length in characters
-       SetTokenTimestamps(bool)        // Set token timestamps flag
-       SetMaxTokensPerSegment(uint)    // Set max tokens per segment (0 = no limit)
-       SetAudioCtx(uint)               // Set audio encoder context
-       SetMaxContext(n int)            // Set maximum number of text context tokens to store
-       SetBeamSize(n int)              // Set Beam Size
-       SetEntropyThold(t float32)      // Set Entropy threshold
-       SetInitialPrompt(prompt string) // Set initial prompt
+       SetOffset(time.Duration)          // Set offset
+       SetDuration(time.Duration)        // Set duration
+       SetThreads(uint)                  // Set number of threads to use
+       SetSplitOnWord(bool)              // Set split on word flag
+       SetTokenThreshold(float32)        // Set timestamp token probability threshold
+       SetTokenSumThreshold(float32)     // Set timestamp token sum probability threshold
+       SetMaxSegmentLength(uint)         // Set max segment length in characters
+       SetTokenTimestamps(bool)          // Set token timestamps flag
+       SetMaxTokensPerSegment(uint)      // Set max tokens per segment (0 = no limit)
+       SetAudioCtx(uint)                 // Set audio encoder context
+       SetMaxContext(n int)              // Set maximum number of text context tokens to store
+       SetBeamSize(n int)                // Set Beam Size
+       SetEntropyThold(t float32)        // Set Entropy threshold
+       SetInitialPrompt(prompt string)   // Set initial prompt
+       SetTemperature(t float32)         // Set temperature
+       SetTemperatureFallback(t float32) // Set temperature incrementation
 
        // Process mono audio data and return any errors.
        // If defined, newly generated segments are passed to the