]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : add mechanism for aborting the whisper_full() computation
authorGeorgi Gerganov <redacted>
Sun, 27 Nov 2022 18:28:36 +0000 (20:28 +0200)
committerGeorgi Gerganov <redacted>
Sun, 27 Nov 2022 18:42:45 +0000 (20:42 +0200)
examples/main/main.cpp
whisper.cpp
whisper.h

index 569404caa49840566a62d3d00a298e9cfeab3f24..465d43fb0796455bdfce509ab1a5b661500bb58d 100644 (file)
@@ -607,6 +607,19 @@ int main(int argc, char ** argv) {
                 wparams.new_segment_callback_user_data = &user_data;
             }
 
+            // example for abort mechanism
+            // in this example, we do not abort the processing, but we could if the flag is set to true
+            // the callback is called before every encoder run - if it returns false, the processing is aborted
+            {
+                static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
+
+                wparams.encoder_begin_callback = [](struct whisper_context * ctx, void * user_data) {
+                    bool is_aborted = *(bool*)user_data;
+                    return !is_aborted;
+                };
+                wparams.encoder_begin_callback_user_data = &is_aborted;
+            }
+
             if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
                 fprintf(stderr, "%s: failed to process audio\n", argv[0]);
                 return 10;
index 2daf41165d704c91e6fdba7490cbfb34592d014e..fbcb5d14c0323ff81780bdae48cba9a360c1aaa3 100644 (file)
@@ -2451,6 +2451,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
 
                     /*.new_segment_callback           =*/ nullptr,
                     /*.new_segment_callback_user_data =*/ nullptr,
+
+                    /*.encoder_begin_callback           =*/ nullptr,
+                    /*.encoder_begin_callback_user_data =*/ nullptr,
                 };
             } break;
         case WHISPER_SAMPLING_BEAM_SEARCH:
@@ -2497,6 +2500,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
 
                     /*.new_segment_callback           =*/ nullptr,
                     /*.new_segment_callback_user_data =*/ nullptr,
+
+                    /*.encoder_begin_callback           =*/ nullptr,
+                    /*.encoder_begin_callback_user_data =*/ nullptr,
                 };
             } break;
     }
@@ -2659,6 +2665,13 @@ int whisper_full(
             break;
         }
 
+        if (params.encoder_begin_callback) {
+            if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) {
+                fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__);
+                break;
+            }
+        }
+
         // encode audio features starting at offset seek
         if (whisper_encode(ctx, seek, params.n_threads) != 0) {
             fprintf(stderr, "%s: failed to encode\n", __func__);
index 4b5fbccd4e3c0e35b6ddddef3ad3ac5b9cee6d57..156edbbf45433adc46bab4ccd5ced437421e4050 100644 (file)
--- a/whisper.h
+++ b/whisper.h
@@ -185,6 +185,14 @@ extern "C" {
     // Use the whisper_full_...() functions to obtain the text segments
     typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, int n_new, void * user_data);
 
+    // Encoder begin callback
+    // If not NULL, called before the encoder starts
+    // If it returns false, the computation is aborted
+    typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data);
+
+    // Parameters for the whisper_full() function
+    // If you chnage the order or add new parameters, make sure to update the default values in whisper.cpp:
+    // whisper_full_default_params()
     struct whisper_full_params {
         enum whisper_sampling_strategy strategy;
 
@@ -231,6 +239,9 @@ extern "C" {
 
         whisper_new_segment_callback new_segment_callback;
         void * new_segment_callback_user_data;
+
+        whisper_encoder_begin_callback encoder_begin_callback;
+        void * encoder_begin_callback_user_data;
     };
 
     WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);