]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
Revert "whisper : remove extra backend instance (huh?)" (#2182)
authorGeorgi Gerganov <redacted>
Mon, 27 May 2024 07:20:25 +0000 (10:20 +0300)
committerGitHub <redacted>
Mon, 27 May 2024 07:20:25 +0000 (10:20 +0300)
This reverts commit 4caa64b73ed4c0e71097c865b0f6a9c136b007c6.

whisper.cpp

index 84aec8238cdb42d19cab3ef2e97a5aa0b91a2a94..7b8c683fca72585d0c5be637a703dc03383fed23 100644 (file)
@@ -818,6 +818,8 @@ struct whisper_state {
 
     whisper_decoder decoders[WHISPER_MAX_DECODERS];
 
+    ggml_backend_t backend = nullptr;
+
     // ggml-alloc:
     // - stores meta info about the intermediate tensors into the `meta` buffers
     // - stores the actual tensor data into the `data` buffers
@@ -2261,7 +2263,7 @@ static bool whisper_encode_internal(
         }
 
         if (!whisper_encode_external(wstate)) {
-            if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) {
+            if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
                 return false;
             }
         } else {
@@ -2284,7 +2286,7 @@ static bool whisper_encode_internal(
             return false;
         }
 
-        if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) {
+        if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
             return false;
         }
     }
@@ -2300,7 +2302,7 @@ static bool whisper_encode_internal(
             return false;
         }
 
-        if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) {
+        if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
             return false;
         }
     }
@@ -2801,7 +2803,7 @@ static bool whisper_decode_internal(
 
         logits = gf->nodes[gf->n_nodes - 1];
 
-        if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) {
+        if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
             return false;
         }
     }
@@ -3248,6 +3250,13 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     whisper_state * state = new whisper_state;
 
+    state->backend = whisper_backend_init(ctx->params);
+    if (!state->backend) {
+        WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__);
+        whisper_free_state(state);
+        return nullptr;
+    }
+
     // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
     // in theory, there can be a case where this is not enough, but in practice it should always be enough
     const int factor = 3;
@@ -3684,6 +3693,8 @@ void whisper_free_state(struct whisper_state * state) {
         ggml_gallocr_free(state->alloc_cross.alloc);
         ggml_gallocr_free(state->alloc_decode.alloc);
 
+        ggml_backend_free(state->backend);
+
         // [EXPERIMENTAL] Token-level timestamps with DTW
         aheads_masks_free(state->aheads_masks);