]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : fix external encoder (#1860)
authorGeorgi Gerganov <redacted>
Mon, 12 Feb 2024 17:53:51 +0000 (19:53 +0200)
committerGitHub <redacted>
Mon, 12 Feb 2024 17:53:51 +0000 (19:53 +0200)
whisper.cpp

index dec995709a02d7502fece29cf4c160d1a36aa666..536adc3396dc469941e940b23caeb8b98523a1d9 100644 (file)
@@ -1659,22 +1659,9 @@ static struct ggml_cgraph * whisper_build_graph_conv(
         ggml_set_name(cur, "embd_conv");
         wstate.embd_conv = cur;
     } else {
-#ifdef WHISPER_USE_COREML
-        cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
-        ggml_allocr_alloc(alloc, cur);
+        ggml_build_forward_expand(gf, mel);
 
-        if (!ggml_allocr_is_measure(alloc)) {
-            whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) cur->data);
-        }
-#endif
-#ifdef WHISPER_USE_OPENVINO
         cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
-        ggml_allocr_alloc(alloc, cur);
-
-        if (!ggml_allocr_is_measure(alloc)) {
-            whisper_openvino_encode(wstate.ctx_openvino, mel, cur);
-        }
-#endif
 
         ggml_set_name(cur, "embd_enc");
         wstate.embd_enc = cur;
@@ -1708,14 +1695,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
 
     ggml_cgraph * gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
 
-    //ggml_allocr * alloc = wstate.alloc_encode.alloc;
-
-    //struct ggml_tensor * cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_ctx, n_state);
-    //ggml_allocr_alloc(alloc, cur);
-
-    //if (!ggml_allocr_is_measure(alloc)) {
-    //    ggml_backend_tensor_copy(wstate.embd_conv, cur);
-    //}
     struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
 
     const float KQscale = 1.0f/sqrtf(float(n_state)/n_head);
@@ -1957,14 +1936,6 @@ static struct ggml_cgraph * whisper_build_graph_cross(
 
     ggml_cgraph * gf = ggml_new_graph(ctx0);
 
-    //ggml_allocr * alloc = wstate.alloc_cross.alloc;
-
-    //struct ggml_tensor * cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
-    //ggml_allocr_alloc(alloc, cur);
-
-    //if (!ggml_allocr_is_measure(alloc)) {
-    //    ggml_backend_tensor_copy(wstate.embd_enc, cur);
-    //}
     struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
 
     const float  Kscale = pow(float(n_state) / n_head, -0.25);
@@ -2037,13 +2008,13 @@ static bool whisper_encode_internal(
             return false;
         }
 
+        struct ggml_tensor * mel = ggml_graph_get_tensor(gf, "mel");
+
         // set the input
         {
             const auto & mel_inp = wstate.mel;
             const int n_ctx      = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : wctx.model.hparams.n_audio_ctx;
 
-            struct ggml_tensor * mel = ggml_graph_get_tensor(gf, "mel");
-
             assert(mel->type == GGML_TYPE_F32);
             assert(mel_inp.n_mel == wctx.model.hparams.n_mels);
 
@@ -2068,6 +2039,12 @@ static bool whisper_encode_internal(
             if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
                 return false;
             }
+        } else {
+#if defined(WHISPER_USE_COREML)
+            whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) wstate.embd_enc->data);
+#elif defined(WHISPER_USE_OPENVINO)
+            whisper_openvino_encode(wstate.ctx_openvino, mel, wstate.embd_enc);
+#endif
         }
     }