]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
go : add Encoder Begin Callback (#2900)
authorAmanda Der Bedrosian <redacted>
Wed, 19 Mar 2025 07:05:04 +0000 (00:05 -0700)
committerGitHub <redacted>
Wed, 19 Mar 2025 07:05:04 +0000 (09:05 +0200)
Adding in EncoderBeginCallback to the Context's Process callback.
This optional callback function returns false if computation should
be aborted.

Co-authored-by: Amanda Der Bedrosian <redacted>
bindings/go/README.md
bindings/go/examples/go-whisper/process.go
bindings/go/pkg/whisper/context.go
bindings/go/pkg/whisper/context_test.go
bindings/go/pkg/whisper/interface.go

index 6958ede80f25afdcf33e53633d90faed7641d707..cbd2a622874cb841ad4b5f870e79cecb5c7b2822 100644 (file)
@@ -31,7 +31,7 @@ func main() {
        if err != nil {
                panic(err)
        }
-       if err := context.Process(samples, nil, nil); err != nil {
+       if err := context.Process(samples, nil, nil, nil); err != nil {
                return err
        }
 
index 71e52f0100069191c1321098d15f6df7b401b836..833947e843caa8a3369c6de9c57916f886671671 100644 (file)
@@ -67,7 +67,7 @@ func Process(model whisper.Model, path string, flags *Flags) error {
        // Process the data
        fmt.Fprintf(flags.Output(), "  ...processing %q\n", path)
        context.ResetTimings()
-       if err := context.Process(data, cb, nil); err != nil {
+       if err := context.Process(data, nil, cb, nil); err != nil {
                return err
        }
 
index 06376b1b870f00b7a83e8e1884e4581bc4e23b35..a806129364badac90813e44c5bd8955b9c431570 100644 (file)
@@ -189,6 +189,7 @@ func (context *context) WhisperLangAutoDetect(offset_ms int, n_threads int) ([]f
 // Process new sample data and return any errors
 func (context *context) Process(
        data []float32,
+       callEncoderBegin EncoderBeginCallback,
        callNewSegment SegmentCallback,
        callProgress ProgressCallback,
 ) error {
@@ -203,7 +204,20 @@ func (context *context) Process(
        // We don't do parallel processing at the moment
        processors := 0
        if processors > 1 {
-               if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) {
+               if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, callEncoderBegin,
+                       func(new int) {
+                               if callNewSegment != nil {
+                                       num_segments := context.model.ctx.Whisper_full_n_segments()
+                                       s0 := num_segments - new
+                                       for i := s0; i < num_segments; i++ {
+                                               callNewSegment(toSegment(context.model.ctx, i))
+                                       }
+                               }
+                       }); err != nil {
+                       return err
+               }
+       } else if err := context.model.ctx.Whisper_full(context.params, data, callEncoderBegin,
+               func(new int) {
                        if callNewSegment != nil {
                                num_segments := context.model.ctx.Whisper_full_n_segments()
                                s0 := num_segments - new
@@ -211,22 +225,11 @@ func (context *context) Process(
                                        callNewSegment(toSegment(context.model.ctx, i))
                                }
                        }
-               }); err != nil {
-                       return err
-               }
-       } else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) {
-               if callNewSegment != nil {
-                       num_segments := context.model.ctx.Whisper_full_n_segments()
-                       s0 := num_segments - new
-                       for i := s0; i < num_segments; i++ {
-                               callNewSegment(toSegment(context.model.ctx, i))
+               }, func(progress int) {
+                       if callProgress != nil {
+                               callProgress(progress)
                        }
-               }
-       }, func(progress int) {
-               if callProgress != nil {
-                       callProgress(progress)
-               }
-       }); err != nil {
+               }); err != nil {
                return err
        }
 
index 7d83a8dffeb797edd9d01b8585a63a5437851995..510514817345a38be1b5781e4275cdb0364d7e66 100644 (file)
@@ -88,6 +88,6 @@ func TestProcess(t *testing.T) {
        context, err := model.NewContext()
        assert.NoError(err)
 
-       err = context.Process(data, nil, nil)
+       err = context.Process(data, nil, nil, nil)
        assert.NoError(err)
 }
index 8981b1a8116669bdb5d92c42c2ad2c55483ba4fd..2b6a9c8ecfa1175ac70dcf1369a1441e59be96c7 100644 (file)
@@ -16,6 +16,10 @@ type SegmentCallback func(Segment)
 // processing. It is called during the Process function
 type ProgressCallback func(int)
 
+// EncoderBeginCallback is the callback function for checking if we want to
+// continue processing. It is called during the Process function
+type EncoderBeginCallback func() bool
+
 // Model is the interface to a whisper model. Create a new model with the
 // function whisper.New(string)
 type Model interface {
@@ -31,7 +35,7 @@ type Model interface {
        Languages() []string
 }
 
-// Context is the speach recognition context.
+// Context is the speech recognition context.
 type Context interface {
        SetLanguage(string) error // Set the language to use for speech recognition, use "auto" for auto detect language.
        SetTranslate(bool)        // Set translate flag
@@ -58,7 +62,7 @@ type Context interface {
        // Process mono audio data and return any errors.
        // If defined, newly generated segments are passed to the
        // callback function during processing.
-       Process([]float32, SegmentCallback, ProgressCallback) error
+       Process([]float32, EncoderBeginCallback, SegmentCallback, ProgressCallback) error
 
        // After process is called, return segments until the end of the stream
        // is reached, when io.EOF is returned.