]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : add abort callback (#1335)
authormkiol <redacted>
Wed, 4 Oct 2023 08:57:55 +0000 (10:57 +0200)
committerGitHub <redacted>
Wed, 4 Oct 2023 08:57:55 +0000 (11:57 +0300)
whisper.cpp
whisper.h

index 916883c36e935017c8361d02a626eabbfc6fff88..403c2d09b9c9087252da591f805425a9b3365ad6 100644 (file)
@@ -125,9 +125,17 @@ static void byteswap_tensor(ggml_tensor * tensor) {
 // ggml helpers
 //
 
-static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
+static void ggml_graph_compute_helper(
+        std::vector<uint8_t> & buf,
+                 ggml_cgraph * graph,
+                         int   n_threads,
+      whisper_abort_callback   abort_callback,
+                        void * abort_callback_data) {
     struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
 
+    plan.abort_callback = abort_callback;
+    plan.abort_callback_data = abort_callback_data;
+
     if (plan.work_size > 0) {
         buf.resize(plan.work_size);
         plan.work_data = buf.data();
@@ -1922,7 +1930,9 @@ static bool whisper_encode_internal(
         whisper_context & wctx,
           whisper_state & wstate,
               const int   mel_offset,
-              const int   n_threads) {
+              const int   n_threads,
+ whisper_abort_callback   abort_callback,
+                   void * abort_callback_data) {
     const int64_t t_start_us = ggml_time_us();
 
     // conv
@@ -1936,7 +1946,7 @@ static bool whisper_encode_internal(
         ggml_allocr_alloc_graph(alloc, gf);
 
         if (!whisper_encode_external(wstate)) {
-            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
         }
     }
 
@@ -1955,10 +1965,10 @@ static bool whisper_encode_internal(
             ggml_metal_set_n_cb     (wstate.ctx_metal, n_threads);
             ggml_metal_graph_compute(wstate.ctx_metal, gf);
         } else {
-            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
         }
 #else
-        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
 #endif
     }
 
@@ -1977,10 +1987,10 @@ static bool whisper_encode_internal(
             ggml_metal_set_n_cb     (wstate.ctx_metal, n_threads);
             ggml_metal_graph_compute(wstate.ctx_metal, gf);
         } else {
-            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
         }
 #else
-        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
 #endif
     }
 
@@ -2346,7 +2356,9 @@ static bool whisper_decode_internal(
     const whisper_token * tokens,
               const int   n_tokens,
               const int   n_past,
-              const int   n_threads) {
+              const int   n_threads,
+ whisper_abort_callback   abort_callback,
+                   void * abort_callback_data) {
     const int64_t t_start_us = ggml_time_us();
 
     const auto & model   = wctx.model;
@@ -2375,10 +2387,10 @@ static bool whisper_decode_internal(
             ggml_metal_set_n_cb     (wstate.ctx_metal, n_threads);
             ggml_metal_graph_compute(wstate.ctx_metal, gf);
         } else {
-            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
         }
 #else
-        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
 #endif
     }
 
@@ -3290,7 +3302,7 @@ int whisper_set_mel(
 }
 
 int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) {
-    if (!whisper_encode_internal(*ctx, *state, offset, n_threads)) {
+    if (!whisper_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) {
         log("%s: failed to eval\n", __func__);
         return -1;
     }
@@ -3299,7 +3311,7 @@ int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state
 }
 
 int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
-    if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads)) {
+    if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) {
         log("%s: failed to eval\n", __func__);
         return -1;
     }
@@ -3310,7 +3322,7 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
 int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
     const int selected_decoder_id = 0;
 
-    if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
+    if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
         log("%s: failed to eval\n", __func__);
         return 1;
     }
@@ -3327,7 +3339,7 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
         return false;
     }
 
-    if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
+    if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
         log("%s: failed to eval\n", __func__);
         return 1;
     }
@@ -4594,7 +4606,7 @@ int whisper_full_with_state(
         }
 
         // encode audio features starting at offset seek
-        if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads)) {
+        if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
             log("%s: failed to encode\n", __func__);
             return -6;
         }
@@ -4677,7 +4689,7 @@ int whisper_full_with_state(
                 }
                 WHISPER_PRINT_DEBUG("\n\n");
 
-                if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
+                if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
                     log("%s: failed to decode\n", __func__);
                     return -7;
                 }
@@ -4901,7 +4913,7 @@ int whisper_full_with_state(
 
                     //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
 
-                    if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) {
+                    if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
                         log("%s: failed to decode\n", __func__);
                         return -8;
                     }
@@ -5473,12 +5485,12 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
             double tsum = 0.0;
 
             // heat-up
-            ggml_graph_compute_helper(work, &gf, n_threads);
+            ggml_graph_compute_helper(work, &gf, n_threads, nullptr , nullptr);
 
             for (int i = 0; i < n_max; ++i) {
                 const int64_t t0 = ggml_time_us();
 
-                ggml_graph_compute_helper(work, &gf, n_threads);
+                ggml_graph_compute_helper(work, &gf, n_threads, nullptr, nullptr);
 
                 const int64_t t1 = ggml_time_us();
 
index 6c0efc15870b18545f53971cc66acbcdbb8e589a..c3118c9c99b7ef0d6b217f4f6ff8ceae1d0a2b8f 100644 (file)
--- a/whisper.h
+++ b/whisper.h
@@ -334,6 +334,11 @@ extern "C" {
     // If it returns false, the computation is aborted
     typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data);
 
+    // Abort callback
+    // If not NULL, called before ggml computation
+    // If it returns true, the computation is aborted
+    typedef bool (*whisper_abort_callback)(void * user_data);
+
     // Logits filter callback
     // Can be used to modify the logits before sampling
     // If not NULL, called after applying temperature to logits
@@ -428,6 +433,10 @@ extern "C" {
         whisper_encoder_begin_callback encoder_begin_callback;
         void * encoder_begin_callback_user_data;
 
+        // called each time before ggml computation starts
+        whisper_abort_callback abort_callback;
+        void * abort_callback_user_data;
+
         // called by each decoder to filter obtained logits
         whisper_logits_filter_callback logits_filter_callback;
         void * logits_filter_callback_user_data;