}
// Process new sample data and return any errors
-func (context *context) Process(data []float32, cb SegmentCallback) error {
+func (context *context) Process(
+ data []float32,
+ callNewSegment SegmentCallback,
+ callProgress ProgressCallback,
+) error {
if context.model.ctx == nil {
return ErrInternalAppError
}
processors := 0
if processors > 1 {
if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) {
- if cb != nil {
+ if callNewSegment != nil {
num_segments := context.model.ctx.Whisper_full_n_segments()
s0 := num_segments - new
for i := s0; i < num_segments; i++ {
- cb(toSegment(context.model.ctx, i))
+ callNewSegment(toSegment(context.model.ctx, i))
}
}
}); err != nil {
return err
}
} else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) {
- if cb != nil {
+ if callNewSegment != nil {
num_segments := context.model.ctx.Whisper_full_n_segments()
s0 := num_segments - new
for i := s0; i < num_segments; i++ {
- cb(toSegment(context.model.ctx, i))
+ callNewSegment(toSegment(context.model.ctx, i))
}
}
+ }, func(progress int) {
+ if callProgress != nil {
+ callProgress(progress)
+ }
}); err != nil {
return err
}
// time. It is called during the Process function
type SegmentCallback func(Segment)
+// ProgressCallback receives the integer progress value reported while
+// Process is running. A nil ProgressCallback is skipped by Process.
+type ProgressCallback func(int)
+
// Model is the interface to a whisper model. Create a new model with the
// function whisper.New(string)
type Model interface {
// Process mono audio data and return any errors.
// If defined, newly generated segments are passed to the
// callback function during processing.
- Process([]float32, SegmentCallback) error
+ Process([]float32, SegmentCallback, ProgressCallback) error
// After process is called, return segments until the end of the stream
// is reached, when io.EOF is returned.
#include <stdlib.h>
extern void callNewSegment(void* user_data, int new);
+extern void callProgress(void* user_data, int progress);
extern bool callEncoderBegin(void* user_data);
// Text segment callback
}
}
+// Progress callback
+// Invoked with the current progress value (installed as params.progress_callback)
+// Forwards progress to Go via callProgress unless user_data or ctx is NULL
+static void whisper_progress_cb(struct whisper_context* ctx, struct whisper_state* state, int progress, void* user_data) {
+ if(user_data != NULL && ctx != NULL) {
+ callProgress(user_data, progress);
+ }
+}
+
// Encoder begin callback
// If not NULL, called before the encoder starts
// If it returns false, the computation is aborted
params.new_segment_callback_user_data = (void*)(ctx);
params.encoder_begin_callback = whisper_encoder_begin_cb;
params.encoder_begin_callback_user_data = (void*)(ctx);
+ params.progress_callback = whisper_progress_cb;
+ params.progress_callback_user_data = (void*)(ctx);
return params;
}
*/
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
// Uses the specified decoding strategy to obtain the text.
-func (ctx *Context) Whisper_full(params Params, samples []float32, encoderBeginCallback func() bool, newSegmentCallback func(int)) error {
+func (ctx *Context) Whisper_full(
+ params Params,
+ samples []float32,
+ encoderBeginCallback func() bool,
+ newSegmentCallback func(int),
+ progressCallback func(int),
+) error {
registerEncoderBeginCallback(ctx, encoderBeginCallback)
registerNewSegmentCallback(ctx, newSegmentCallback)
+ registerProgressCallback(ctx, progressCallback)
defer registerEncoderBeginCallback(ctx, nil)
defer registerNewSegmentCallback(ctx, nil)
+ defer registerProgressCallback(ctx, nil)
if C.whisper_full((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples))) == 0 {
return nil
} else {
// Registered Go callbacks, one map per callback kind. The register*
// functions in this file key entries by unsafe.Pointer(ctx).
// NOTE(review): these are plain maps with no locking visible here; confirm
// that registration and C-side invocation never run concurrently.
var (
cbNewSegment = make(map[unsafe.Pointer]func(int))
+ cbProgress = make(map[unsafe.Pointer]func(int))
cbEncoderBegin = make(map[unsafe.Pointer]func() bool)
)
}
}
+func registerProgressCallback(ctx *Context, fn func(int)) { // set or clear the progress callback stored for ctx
+ if fn == nil {
+ delete(cbProgress, unsafe.Pointer(ctx)) // a nil fn unregisters: drop the entry for this context
+ } else {
+ cbProgress[unsafe.Pointer(ctx)] = fn // replaces any callback previously stored under the same key
+ }
+}
+
func registerEncoderBeginCallback(ctx *Context, fn func() bool) {
if fn == nil {
delete(cbEncoderBegin, unsafe.Pointer(ctx))
}
}
+//export callProgress
+func callProgress(user_data unsafe.Pointer, progress C.int) { // exported to C; the C progress callback calls it
+ if fn, ok := cbProgress[user_data]; ok { // no entry for user_data means nothing is invoked
+ fn(int(progress)) // hand the C int to the Go callback as a Go int
+ }
+}
+
//export callEncoderBegin
func callEncoderBegin(user_data unsafe.Pointer) C.bool {
if fn, ok := cbEncoderBegin[user_data]; ok {