]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : add progress callback (#600)
authorpajowu <redacted>
Thu, 30 Mar 2023 17:29:29 +0000 (19:29 +0200)
committerGitHub <redacted>
Thu, 30 Mar 2023 17:29:29 +0000 (20:29 +0300)
whisper.cpp
whisper.h

index 13e11141cc95e2cc8cc30296ce5dfb88e3607e34..95b6d33905d16c487c0c1d495a6f76443be44821 100644 (file)
@@ -3152,6 +3152,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         /*.new_segment_callback           =*/ nullptr,
         /*.new_segment_callback_user_data =*/ nullptr,
 
+        /*.progress_callback           =*/ nullptr,
+        /*.progress_callback_user_data =*/ nullptr,
+
         /*.encoder_begin_callback           =*/ nullptr,
         /*.encoder_begin_callback_user_data =*/ nullptr,
 
@@ -3868,6 +3871,10 @@ int whisper_full_with_state(
                 fprintf(stderr, "%s: progress = %3d%%\n", __func__, progress_prev);
             }
         }
+        if (params.progress_callback) {
+            params.progress_callback(
+                ctx, ctx->state, progress_prev, params.progress_callback_user_data);
+        }
 
         // of only 1 second left, then stop
         if (seek + 100 >= seek_end) {
@@ -4456,6 +4463,9 @@ int whisper_full_parallel(
         params_cur.new_segment_callback = nullptr;
         params_cur.new_segment_callback_user_data = nullptr;
 
+        params_cur.progress_callback = nullptr;
+        params_cur.progress_callback_user_data = nullptr;
+
         workers[i] = std::thread(whisper_full_with_state, ctx, states[i], std::move(params_cur), samples + start_samples, n_samples_cur);
     }
 
index fa6bff4fc8da098eed2ea831b865a29ec7433d48..a96c96c927e2b31cdb624c4fdedb07719be88c59 100644 (file)
--- a/whisper.h
+++ b/whisper.h
@@ -306,6 +306,9 @@ extern "C" {
     // Use the whisper_full_...() functions to obtain the text segments
     typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data);
 
+    // Progress callback
+    typedef void (*whisper_progress_callback)(struct whisper_context * ctx, struct whisper_state * state, int progress, void * user_data);
+
     // Encoder begin callback
     // If not NULL, called before the encoder starts
     // If it returns false, the computation is aborted
@@ -392,6 +395,10 @@ extern "C" {
         whisper_new_segment_callback new_segment_callback;
         void * new_segment_callback_user_data;
 
+        // called on each progress update
+        whisper_progress_callback progress_callback;
+        void * progress_callback_user_data;
+
         // called each time before the encoder starts
         whisper_encoder_begin_callback encoder_begin_callback;
         void * encoder_begin_callback_user_data;