]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : abort callback improvements (#1345)
authormkiol <redacted>
Sun, 8 Oct 2023 14:22:24 +0000 (16:22 +0200)
committerGitHub <redacted>
Sun, 8 Oct 2023 14:22:24 +0000 (17:22 +0300)
* whisper : initialize abort_callback to null

* whisper : add example how to use abort_callback

examples/main/main.cpp
whisper.cpp

index 60c1cca756a683de1321e72a38838957e464b629..cdd16ac73a3b5971dc6100d23e835a611a3dd394 100644 (file)
@@ -944,8 +944,9 @@ int main(int argc, char ** argv) {
                 wparams.progress_callback_user_data = &user_data;
             }
 
-            // example for abort mechanism
-            // in this example, we do not abort the processing, but we could if the flag is set to true
+            // examples for abort mechanism
+            // in examples below, we do not abort the processing, but we could if the flag is set to true
+
             // the callback is called before every encoder run - if it returns false, the processing is aborted
             {
                 static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
@@ -957,6 +958,17 @@ int main(int argc, char ** argv) {
                 wparams.encoder_begin_callback_user_data = &is_aborted;
             }
 
+            // the callback is called before every computation - if it returns true, the computation is aborted
+            {
+                static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
+
+                wparams.abort_callback = [](void * user_data) {
+                    bool is_aborted = *(bool*)user_data;
+                    return is_aborted;
+                };
+                wparams.abort_callback_user_data = &is_aborted;
+            }
+
             if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
                 fprintf(stderr, "%s: failed to process audio\n", argv[0]);
                 return 10;
index 403c2d09b9c9087252da591f805425a9b3365ad6..ccac6aaf581c1565a94c8f02ab45ac90fe79baad 100644 (file)
@@ -3773,6 +3773,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         /*.encoder_begin_callback           =*/ nullptr,
         /*.encoder_begin_callback_user_data =*/ nullptr,
 
+        /*.abort_callback           =*/ nullptr,
+        /*.abort_callback_user_data =*/ nullptr,
+
         /*.logits_filter_callback           =*/ nullptr,
         /*.logits_filter_callback_user_data =*/ nullptr,
     };