]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : move progress calculation out of whisper.cpp (#1081)
authorHrishikesh Barman <redacted>
Tue, 25 Jul 2023 15:53:34 +0000 (21:23 +0530)
committerGitHub <redacted>
Tue, 25 Jul 2023 15:53:34 +0000 (18:53 +0300)
Current `progress_step` was hardcoded into whisper.cpp, this resulted in
bindings having to access progress only at that step even if progress
callback was being called at every iteration.

With this change we get greater granularity progress reporting from
whisper.cpp and bindings/implementations can define their own progress step.

examples/main/main.cpp
whisper.cpp

index 8dd31d028b1e732bf6572230dab0fd8cc835d721..4fbc3f69ad22811a046c52714900da3f7fb1ed36 100644 (file)
@@ -59,6 +59,7 @@ struct whisper_params {
     int32_t offset_t_ms  =  0;
     int32_t offset_n     =  0;
     int32_t duration_ms  =  0;
+    int32_t progress_step =  5;
     int32_t max_context  = -1;
     int32_t max_len      =  0;
     int32_t best_of      =  2;
@@ -218,6 +219,7 @@ struct whisper_print_user_data {
     const whisper_params * params;
 
     const std::vector<std::vector<float>> * pcmf32s;
+    int progress_prev;
 };
 
 std::string estimate_diarization_speaker(std::vector<std::vector<float>> pcmf32s, int64_t t0, int64_t t1, bool id_only = false) {
@@ -252,6 +254,14 @@ std::string estimate_diarization_speaker(std::vector<std::vector<float>> pcmf32s
 
     return speaker;
 }
+void whisper_print_progress_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int progress, void * user_data) {
+    int progress_step = ((whisper_print_user_data *) user_data)->params->progress_step;
+    int * progress_prev  = &(((whisper_print_user_data *) user_data)->progress_prev);
+    if (progress >= *progress_prev + progress_step) {
+        *progress_prev += progress_step;
+        fprintf(stderr, "%s: progress = %3d%%\n", __func__, progress);
+    }
+}
 
 void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int n_new, void * user_data) {
     const auto & params  = *((whisper_print_user_data *) user_data)->params;
@@ -895,7 +905,7 @@ int main(int argc, char ** argv) {
             wparams.entropy_thold    = params.entropy_thold;
             wparams.logprob_thold    = params.logprob_thold;
 
-            whisper_print_user_data user_data = { &params, &pcmf32s };
+            whisper_print_user_data user_data = { &params, &pcmf32s, 0 };
 
             // this callback is called on each new segment
             if (!wparams.print_realtime) {
@@ -903,6 +913,11 @@ int main(int argc, char ** argv) {
                 wparams.new_segment_callback_user_data = &user_data;
             }
 
+            if (wparams.print_progress) {
+                wparams.progress_callback           = whisper_print_progress_callback;
+                wparams.progress_callback_user_data = &user_data;
+            }
+
             // example for abort mechanism
             // in this example, we do not abort the processing, but we could if the flag is set to true
             // the callback is called before every encoder run - if it returns false, the processing is aborted
index 381874573ea48a9a7ee307c4e6a00711241be728..ab734fea66aec67ca7844d6ffc5889710ac283a5 100644 (file)
@@ -4163,9 +4163,6 @@ int whisper_full_with_state(
         }
     }
 
-    int progress_prev = 0;
-    int progress_step = 5;
-
     int seek = seek_start;
 
     std::vector<whisper_token> prompt;
@@ -4193,15 +4190,9 @@ int whisper_full_with_state(
     // main loop
     while (true) {
         const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start);
-        while (progress_cur >= progress_prev + progress_step) {
-            progress_prev += progress_step;
-            if (params.print_progress) {
-                fprintf(stderr, "%s: progress = %3d%%\n", __func__, progress_prev);
-            }
-        }
         if (params.progress_callback) {
             params.progress_callback(
-                ctx, ctx->state, progress_prev, params.progress_callback_user_data);
+                ctx, ctx->state, progress_cur, params.progress_callback_user_data);
         }
 
         // of only 1 second left, then stop