]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
go : improve progress reporting and callback handling (#1024)
authorBo-Yi Wu <redacted>
Sun, 25 Jun 2023 11:07:55 +0000 (19:07 +0800)
committerGitHub <redacted>
Sun, 25 Jun 2023 11:07:55 +0000 (14:07 +0300)
- Rename `cb` to `callNewSegment` in the `Process` function
- Add `callProgress` as a new parameter to the `Process` function
- Introduce `ProgressCallback` type for reporting progress during processing
- Update `Whisper_full` function to include `progressCallback` parameter
- Add `registerProgressCallback` function and `cbProgress` map for handling progress callbacks

Signed-off-by: appleboy <redacted>
bindings/go/Makefile
bindings/go/pkg/whisper/context.go
bindings/go/pkg/whisper/interface.go
bindings/go/whisper.go
bindings/go/whisper_test.go

index 6be2979948e1b82b1430813e062ee96f95b12fa2..74118262b60bec7b3e4dd7adc74e0d5419ffb4c4 100644 (file)
@@ -32,7 +32,7 @@ mkdir:
 modtidy:
        @go mod tidy
 
-clean: 
+clean:
        @echo Clean
        @rm -fr $(BUILD_DIR)
        @go clean
index 593b32b37b09d0b9ece3ac4e0dcdc7a08f0d9b17..e193832e98a2a49dd6df9da98dae139f63b46ac2 100644 (file)
@@ -152,7 +152,11 @@ func (context *context) WhisperLangAutoDetect(offset_ms int, n_threads int) ([]f
 }
 
 // Process new sample data and return any errors
-func (context *context) Process(data []float32, cb SegmentCallback) error {
+func (context *context) Process(
+       data []float32,
+       callNewSegment SegmentCallback,
+       callProgress ProgressCallback,
+) error {
        if context.model.ctx == nil {
                return ErrInternalAppError
        }
@@ -165,24 +169,28 @@ func (context *context) Process(data []float32, cb SegmentCallback) error {
        processors := 0
        if processors > 1 {
                if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) {
-                       if cb != nil {
+                       if callNewSegment != nil {
                                num_segments := context.model.ctx.Whisper_full_n_segments()
                                s0 := num_segments - new
                                for i := s0; i < num_segments; i++ {
-                                       cb(toSegment(context.model.ctx, i))
+                                       callNewSegment(toSegment(context.model.ctx, i))
                                }
                        }
                }); err != nil {
                        return err
                }
        } else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) {
-               if cb != nil {
+               if callNewSegment != nil {
                        num_segments := context.model.ctx.Whisper_full_n_segments()
                        s0 := num_segments - new
                        for i := s0; i < num_segments; i++ {
-                               cb(toSegment(context.model.ctx, i))
+                               callNewSegment(toSegment(context.model.ctx, i))
                        }
                }
+       }, func(progress int) {
+               if callProgress != nil {
+                       callProgress(progress)
+               }
        }); err != nil {
                return err
        }
index e65fed178d5b295349145b00d4cc8b135e6fd01a..dc9c66df9c3e042f90dee9cecd26094083807c68 100644 (file)
@@ -12,6 +12,10 @@ import (
 // time. It is called during the Process function
 type SegmentCallback func(Segment)
 
+// ProgressCallback is the callback function for reporting progress during
+// processing. It is called during the Process function
+type ProgressCallback func(int)
+
 // Model is the interface to a whisper model. Create a new model with the
 // function whisper.New(string)
 type Model interface {
@@ -47,7 +51,7 @@ type Context interface {
        // Process mono audio data and return any errors.
        // If defined, newly generated segments are passed to the
        // callback function during processing.
-       Process([]float32, SegmentCallback) error
+       Process([]float32, SegmentCallback, ProgressCallback) error
 
        // After process is called, return segments until the end of the stream
        // is reached, when io.EOF is returned.
index babadf006c2b2ebe71b44ead38a7ad456b27cfec..451f3f8d6563787e4bc3dd4c6f29614aa9cab9db 100644 (file)
@@ -15,6 +15,7 @@ import (
 #include <stdlib.h>
 
 extern void callNewSegment(void* user_data, int new);
+extern void callProgress(void* user_data, int progress);
 extern bool callEncoderBegin(void* user_data);
 
 // Text segment callback
@@ -26,6 +27,15 @@ static void whisper_new_segment_cb(struct whisper_context* ctx, struct whisper_s
     }
 }
 
+// Progress callback
+// Called on every newly generated text segment
+// Use the whisper_full_...() functions to obtain the text segments
+static void whisper_progress_cb(struct whisper_context* ctx, struct whisper_state* state, int progress, void* user_data) {
+    if(user_data != NULL && ctx != NULL) {
+        callProgress(user_data, progress);
+    }
+}
+
 // Encoder begin callback
 // If not NULL, called before the encoder starts
 // If it returns false, the computation is aborted
@@ -43,6 +53,8 @@ static struct whisper_full_params whisper_full_default_params_cb(struct whisper_
        params.new_segment_callback_user_data = (void*)(ctx);
        params.encoder_begin_callback = whisper_encoder_begin_cb;
        params.encoder_begin_callback_user_data = (void*)(ctx);
+       params.progress_callback = whisper_progress_cb;
+       params.progress_callback_user_data = (void*)(ctx);
        return params;
 }
 */
@@ -290,11 +302,19 @@ func (ctx *Context) Whisper_full_default_params(strategy SamplingStrategy) Param
 
 // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
 // Uses the specified decoding strategy to obtain the text.
-func (ctx *Context) Whisper_full(params Params, samples []float32, encoderBeginCallback func() bool, newSegmentCallback func(int)) error {
+func (ctx *Context) Whisper_full(
+       params Params,
+       samples []float32,
+       encoderBeginCallback func() bool,
+       newSegmentCallback func(int),
+       progressCallback func(int),
+) error {
        registerEncoderBeginCallback(ctx, encoderBeginCallback)
        registerNewSegmentCallback(ctx, newSegmentCallback)
+       registerProgressCallback(ctx, progressCallback)
        defer registerEncoderBeginCallback(ctx, nil)
        defer registerNewSegmentCallback(ctx, nil)
+       defer registerProgressCallback(ctx, nil)
        if C.whisper_full((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples))) == 0 {
                return nil
        } else {
@@ -370,6 +390,7 @@ func (ctx *Context) Whisper_full_get_token_p(segment int, token int) float32 {
 
 var (
        cbNewSegment   = make(map[unsafe.Pointer]func(int))
+       cbProgress     = make(map[unsafe.Pointer]func(int))
        cbEncoderBegin = make(map[unsafe.Pointer]func() bool)
 )
 
@@ -381,6 +402,14 @@ func registerNewSegmentCallback(ctx *Context, fn func(int)) {
        }
 }
 
+func registerProgressCallback(ctx *Context, fn func(int)) {
+       if fn == nil {
+               delete(cbProgress, unsafe.Pointer(ctx))
+       } else {
+               cbProgress[unsafe.Pointer(ctx)] = fn
+       }
+}
+
 func registerEncoderBeginCallback(ctx *Context, fn func() bool) {
        if fn == nil {
                delete(cbEncoderBegin, unsafe.Pointer(ctx))
@@ -396,6 +425,13 @@ func callNewSegment(user_data unsafe.Pointer, new C.int) {
        }
 }
 
+//export callProgress
+func callProgress(user_data unsafe.Pointer, progress C.int) {
+       if fn, ok := cbProgress[user_data]; ok {
+               fn(int(progress))
+       }
+}
+
 //export callEncoderBegin
 func callEncoderBegin(user_data unsafe.Pointer) C.bool {
        if fn, ok := cbEncoderBegin[user_data]; ok {
index 2c95c81ff0a4f3833467d04f3f0379fef4a62b06..40648ffa8d4e59d026820474ef2c17ffb577aed0 100644 (file)
@@ -52,7 +52,7 @@ func Test_Whisper_001(t *testing.T) {
        defer ctx.Whisper_free()
        params := ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
        data := buf.AsFloat32Buffer().Data
-       err = ctx.Whisper_full(params, data, nil, nil)
+       err = ctx.Whisper_full(params, data, nil, nil, nil)
        assert.NoError(err)
 
        // Print out tokens