]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : allow non-CoreML fallback when Core ML cannot be loaded (#812)
authorCanis Lupus <redacted>
Sat, 29 Apr 2023 07:49:02 +0000 (08:49 +0100)
committerGitHub <redacted>
Sat, 29 Apr 2023 07:49:02 +0000 (10:49 +0300)
if the Core ML model cannot be loaded, continue without Core ML instead of
returning. This allows a single build to transcribe using Core ML models
where available, and regular models when not.

whisper.cpp

index 583f2d8e32085a7ef2f82784acebfcca51791ed7..9abdb6c0829a2a45deae6d8c827784ae7f698100 100644 (file)
@@ -592,7 +592,7 @@ struct whisper_state {
 
     std::string path_model; // populated by whisper_init_from_file()
 #ifdef WHISPER_USE_COREML
-    whisper_coreml_context * ctx_coreml;
+    whisper_coreml_context * ctx_coreml = nullptr;
 #endif
 
     // [EXPERIMENTAL] token-level timestamps data
@@ -1385,320 +1385,331 @@ static bool whisper_encode_internal(
         }
     }
 
-#ifndef WHISPER_USE_COREML
     struct ggml_tensor * cur;
 
-    // convolution + gelu
+#ifndef WHISPER_USE_COREML
+    const bool use_coreml = false;
+#else
+    const bool use_coreml = wstate.ctx_coreml != nullptr;
+#endif
+
+    if (!use_coreml)
     {
-        wstate.use_buf(ctx0, 1);
+        // convolution + gelu
+        {
+            wstate.use_buf(ctx0, 1);
 
-        cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel);
-        cur = ggml_add(ctx0,
-            ggml_repeat(ctx0,
-                model.e_conv_1_b,
-                cur),
-            cur);
+            cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel);
+            cur = ggml_add(ctx0,
+                ggml_repeat(ctx0,
+                    model.e_conv_1_b,
+                    cur),
+                cur);
 
-        cur = ggml_gelu(ctx0, cur);
+            cur = ggml_gelu(ctx0, cur);
 
-        wstate.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
-        cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur);
-        cur = ggml_add(ctx0,
-            ggml_repeat(ctx0,
-                model.e_conv_2_b,
-                cur),
-            cur);
+            cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur);
+            cur = ggml_add(ctx0,
+                ggml_repeat(ctx0,
+                    model.e_conv_2_b,
+                    cur),
+                cur);
 
-        cur = ggml_gelu(ctx0, cur);
-    }
+            cur = ggml_gelu(ctx0, cur);
+        }
 
-    wstate.use_buf(ctx0, 3);
+        wstate.use_buf(ctx0, 3);
 
-    // ===================================================================
-    // NOTE: experimenting with partial evaluation of the encoder (ignore)
-    //static int iter = -1;
-    //const int n_iter = 1500/n_ctx;
+        // ===================================================================
+        // NOTE: experimenting with partial evaluation of the encoder (ignore)
+        //static int iter = -1;
+        //const int n_iter = 1500/n_ctx;
 
-    //iter = (iter + 1) % n_iter;
+        //iter = (iter + 1) % n_iter;
 
-    //if (iter == 0) {
-    //    memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k));
-    //    memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v));
-    //}
+        //if (iter == 0) {
+        //    memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k));
+        //    memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v));
+        //}
 
-    static int iter = 0;
+        static int iter = 0;
 
-    const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe);
-    const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;
+        const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe);
+        const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;
 
-    struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
+        struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
 
-    cur = ggml_add(ctx0, e_pe, ggml_transpose(ctx0, cur));
+        cur = ggml_add(ctx0, e_pe, ggml_transpose(ctx0, cur));
 
-    // ===================================================================
+        // ===================================================================
 
-    // original:
-    //cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur));
+        // original:
+        //cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur));
 
-    struct ggml_tensor * inpL = cur;
+        struct ggml_tensor * inpL = cur;
 
-    for (int il = 0; il < n_layer; ++il) {
-        const auto & layer = model.layers_encoder[il];
+        for (int il = 0; il < n_layer; ++il) {
+            const auto & layer = model.layers_encoder[il];
 
-        // norm
-        {
-            wstate.use_buf(ctx0, 0);
+            // norm
+            {
+                wstate.use_buf(ctx0, 0);
 
-            cur = ggml_norm(ctx0, inpL);
+                cur = ggml_norm(ctx0, inpL);
 
-            // cur = ln_0_w*cur + ln_0_b
-            cur = ggml_add(ctx0,
-                ggml_mul(ctx0,
-                    ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
-                    cur),
-                ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
-        }
+                // cur = ln_0_w*cur + ln_0_b
+                cur = ggml_add(ctx0,
+                    ggml_mul(ctx0,
+                        ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
+                        cur),
+                    ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
+            }
 
-        // self-attention
-        {
-            wstate.use_buf(ctx0, 1);
+            // self-attention
+            {
+                wstate.use_buf(ctx0, 1);
 
-            struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
-                layer.attn_q_w,
-                cur);
+                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
+                    layer.attn_q_w,
+                    cur);
 
-            Qcur = ggml_add(ctx0,
-                ggml_repeat(ctx0,
-                    layer.attn_q_b,
-                    Qcur),
-                Qcur);
+                Qcur = ggml_add(ctx0,
+                    ggml_repeat(ctx0,
+                        layer.attn_q_b,
+                        Qcur),
+                    Qcur);
 
-            //Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
+                //Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
 
-            // note: no bias for Key
-            struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
-                layer.attn_k_w,
-                cur);
+                // note: no bias for Key
+                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
+                    layer.attn_k_w,
+                    cur);
 
-            //Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
+                //Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
 
-            struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
-                layer.attn_v_w,
-                cur);
+                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
+                    layer.attn_v_w,
+                    cur);
 
-            Vcur = ggml_add(ctx0,
-                ggml_repeat(ctx0,
-                    layer.attn_v_b,
-                    Vcur),
-                Vcur);
+                Vcur = ggml_add(ctx0,
+                    ggml_repeat(ctx0,
+                        layer.attn_v_b,
+                        Vcur),
+                    Vcur);
 
-            // ------
+                // ------
 
-            wstate.use_buf(ctx0, 0);
+                wstate.use_buf(ctx0, 0);
 
-#ifdef WHISPER_USE_FLASH_ATTN
-            struct ggml_tensor * Q =
-                ggml_permute(ctx0,
-                        ggml_cpy(ctx0,
-                            Qcur,
-                            ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
-                        0, 2, 1, 3);
+    #ifdef WHISPER_USE_FLASH_ATTN
+                struct ggml_tensor * Q =
+                    ggml_permute(ctx0,
+                            ggml_cpy(ctx0,
+                                Qcur,
+                                ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
+                            0, 2, 1, 3);
+
+                struct ggml_tensor * K =
+                    ggml_permute(ctx0,
+                            ggml_cpy(ctx0,
+                                Kcur,
+                                ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
+                            0, 2, 1, 3);
+
+                struct ggml_tensor * V =
+                    ggml_cpy(ctx0,
+                            ggml_permute(ctx0,
+                                ggml_reshape_3d(ctx0,
+                                    Vcur,
+                                    n_state/n_head, n_head, n_ctx),
+                                1, 2, 0, 3),
+                            ggml_new_tensor_3d(ctx0, wctx.wtype, n_ctx, n_state/n_head, n_head));
+
+                struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false);
+    #else
+                struct ggml_tensor * Q =
+                    ggml_permute(ctx0,
+                            ggml_cpy(ctx0,
+                                Qcur,
+                                ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
+                            0, 2, 1, 3);
+
+                struct ggml_tensor * K =
+                    ggml_permute(ctx0,
+                            ggml_cpy(ctx0,
+                                Kcur,
+                                ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
+                            0, 2, 1, 3);
+
+                // K * Q
+                struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+
+                struct ggml_tensor * KQ_scaled =
+                    ggml_scale(ctx0,
+                            KQ,
+                            ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
+                            );
+
+                struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_scaled);
+
+                //struct ggml_tensor * V_trans =
+                //    ggml_permute(ctx0,
+                //            ggml_cpy(ctx0,
+                //                Vcur,
+                //                ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
+                //            1, 2, 0, 3);
+
+                //struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
+
+                struct ggml_tensor * V =
+                    ggml_cpy(ctx0,
+                            ggml_permute(ctx0,
+                                ggml_reshape_3d(ctx0,
+                                    Vcur,
+                                    n_state/n_head, n_head, n_ctx),
+                                0, 2, 1, 3),
+                            ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_ctx, n_head)
+                            );
+
+                struct ggml_tensor * KQV = ggml_mul_mat(ctx0, ggml_transpose(ctx0, V), KQ_soft_max);
+    #endif
+                struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
-            struct ggml_tensor * K =
-                ggml_permute(ctx0,
-                        ggml_cpy(ctx0,
-                            Kcur,
-                            ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
-                        0, 2, 1, 3);
+                wstate.use_buf(ctx0, 1);
 
-            struct ggml_tensor * V =
-                ggml_cpy(ctx0,
-                        ggml_permute(ctx0,
-                            ggml_reshape_3d(ctx0,
-                                Vcur,
-                                n_state/n_head, n_head, n_ctx),
-                            1, 2, 0, 3),
-                        ggml_new_tensor_3d(ctx0, wctx.wtype, n_ctx, n_state/n_head, n_head));
-
-            struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false);
-#else
-            struct ggml_tensor * Q =
-                ggml_permute(ctx0,
-                        ggml_cpy(ctx0,
-                            Qcur,
-                            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
-                        0, 2, 1, 3);
+                cur = ggml_cpy(ctx0,
+                    KQV_merged,
+                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
+            }
 
-            struct ggml_tensor * K =
-                ggml_permute(ctx0,
-                        ggml_cpy(ctx0,
-                            Kcur,
-                            ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
-                        0, 2, 1, 3);
+            // projection
+            {
+                wstate.use_buf(ctx0, 0);
 
-            // K * Q
-            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+                cur = ggml_mul_mat(ctx0,
+                    layer.attn_ln_1_w,
+                    cur);
 
-            struct ggml_tensor * KQ_scaled =
-                ggml_scale(ctx0,
-                        KQ,
-                        ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
-                        );
+                wstate.use_buf(ctx0, 1);
 
-            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_scaled);
+                cur = ggml_add(ctx0,
+                    ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
+                    cur);
+            }
 
-            //struct ggml_tensor * V_trans =
-            //    ggml_permute(ctx0,
-            //            ggml_cpy(ctx0,
-            //                Vcur,
-            //                ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
-            //            1, 2, 0, 3);
+            wstate.use_buf(ctx0, 2);
 
-            //struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
+            // add the input
+            cur = ggml_add(ctx0, cur, inpL);
 
-            struct ggml_tensor * V =
-                ggml_cpy(ctx0,
-                        ggml_permute(ctx0,
-                            ggml_reshape_3d(ctx0,
-                                Vcur,
-                                n_state/n_head, n_head, n_ctx),
-                            0, 2, 1, 3),
-                        ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_ctx, n_head)
-                        );
-
-            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, ggml_transpose(ctx0, V), KQ_soft_max);
-#endif
-            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+            struct ggml_tensor * inpFF = cur;
 
-            wstate.use_buf(ctx0, 1);
+            // feed-forward network
+            {
+                // norm
+                {
+                    wstate.use_buf(ctx0, 0);
 
-            cur = ggml_cpy(ctx0,
-                KQV_merged,
-                ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
-        }
+                    cur = ggml_norm(ctx0, inpFF);
 
-        // projection
-        {
-            wstate.use_buf(ctx0, 0);
+                    wstate.use_buf(ctx0, 1);
 
-            cur = ggml_mul_mat(ctx0,
-                layer.attn_ln_1_w,
-                cur);
+                    // cur = mlp_ln_w*cur + mlp_ln_b
+                    cur = ggml_add(ctx0,
+                        ggml_mul(ctx0,
+                            ggml_repeat(ctx0, layer.mlp_ln_w, cur),
+                            cur),
+                        ggml_repeat(ctx0, layer.mlp_ln_b, cur));
+                }
 
-            wstate.use_buf(ctx0, 1);
+    #ifdef WHISPER_USE_FLASH_FF
+                wstate.use_buf(ctx0, 0);
 
-            cur = ggml_add(ctx0,
-                ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
-                cur);
-        }
+                cur = ggml_flash_ff(ctx0,
+                    ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.wtype, n_state, n_ctx)),
+                    layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
+    #else
+                wstate.use_buf(ctx0, 0);
 
-        wstate.use_buf(ctx0, 2);
+                // fully connected
+                cur = ggml_mul_mat(ctx0,
+                    layer.mlp_0_w,
+                    cur);
 
-        // add the input
-        cur = ggml_add(ctx0, cur, inpL);
+                wstate.use_buf(ctx0, 1);
 
-        struct ggml_tensor * inpFF = cur;
+                cur = ggml_add(ctx0,
+                    ggml_repeat(ctx0, layer.mlp_0_b, cur),
+                    cur);
 
-        // feed-forward network
-        {
-            // norm
-            {
                 wstate.use_buf(ctx0, 0);
 
-                cur = ggml_norm(ctx0, inpFF);
+                // GELU activation
+                cur = ggml_gelu(ctx0, cur);
 
                 wstate.use_buf(ctx0, 1);
 
-                // cur = mlp_ln_w*cur + mlp_ln_b
-                cur = ggml_add(ctx0,
-                    ggml_mul(ctx0,
-                        ggml_repeat(ctx0, layer.mlp_ln_w, cur),
-                        cur),
-                    ggml_repeat(ctx0, layer.mlp_ln_b, cur));
-            }
+                // projection
+                cur = ggml_mul_mat(ctx0,
+                    layer.mlp_1_w,
+                    cur);
 
-#ifdef WHISPER_USE_FLASH_FF
-            wstate.use_buf(ctx0, 0);
+                wstate.use_buf(ctx0, 0);
 
-            cur = ggml_flash_ff(ctx0,
-                ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.wtype, n_state, n_ctx)),
-                layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
-#else
-            wstate.use_buf(ctx0, 0);
+                cur = ggml_add(ctx0,
+                    ggml_repeat(ctx0, layer.mlp_1_b, cur),
+                    cur);
+    #endif
+            }
 
-            // fully connected
-            cur = ggml_mul_mat(ctx0,
-                layer.mlp_0_w,
-                cur);
+            wstate.use_buf(ctx0, 3);
 
-            wstate.use_buf(ctx0, 1);
+            inpL = ggml_add(ctx0, cur, inpFF);
+        }
 
-            cur = ggml_add(ctx0,
-                ggml_repeat(ctx0, layer.mlp_0_b, cur),
-                cur);
+        cur = inpL;
 
+        // norm
+        {
             wstate.use_buf(ctx0, 0);
 
-            // GELU activation
-            cur = ggml_gelu(ctx0, cur);
+            cur = ggml_norm(ctx0, cur);
 
             wstate.use_buf(ctx0, 1);
 
-            // projection
-            cur = ggml_mul_mat(ctx0,
-                layer.mlp_1_w,
-                cur);
-
-            wstate.use_buf(ctx0, 0);
-
+            // cur = ln_f_g*cur + ln_f_b
             cur = ggml_add(ctx0,
-                ggml_repeat(ctx0, layer.mlp_1_b, cur),
-                cur);
-#endif
+                ggml_mul(ctx0,
+                    ggml_repeat(ctx0, model.e_ln_w, cur),
+                    cur),
+                ggml_repeat(ctx0, model.e_ln_b, cur));
         }
 
-        wstate.use_buf(ctx0, 3);
-
-        inpL = ggml_add(ctx0, cur, inpFF);
-    }
+        wstate.use_buf(ctx0, -1);
 
-    cur = inpL;
-
-    // norm
-    {
-        wstate.use_buf(ctx0, 0);
-
-        cur = ggml_norm(ctx0, cur);
+        // run the computation
+        {
+            struct ggml_cgraph gf = {};
+            gf.n_threads = n_threads;
 
-        wstate.use_buf(ctx0, 1);
+            ggml_build_forward_expand(&gf, cur);
+            ggml_graph_compute(ctx0, &gf);
 
-        // cur = ln_f_g*cur + ln_f_b
-        cur = ggml_add(ctx0,
-            ggml_mul(ctx0,
-                ggml_repeat(ctx0, model.e_ln_w, cur),
-                cur),
-            ggml_repeat(ctx0, model.e_ln_b, cur));
+            //ggml_graph_print(&gf);
+        }
     }
-
-    wstate.use_buf(ctx0, -1);
-
-    // run the computation
+#ifdef WHISPER_USE_COREML
+    else
     {
-        struct ggml_cgraph gf = {};
-        gf.n_threads = n_threads;
+        wstate.use_buf(ctx0, -1);
 
-        ggml_build_forward_expand(&gf, cur);
-        ggml_graph_compute(ctx0, &gf);
+        cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
 
-        //ggml_graph_print(&gf);
+        whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data);
     }
-#else
-    wstate.use_buf(ctx0, -1);
-
-    struct ggml_tensor * cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
-
-    whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data);
 #endif
 
     // cur
@@ -2569,10 +2580,12 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
     state->ctx_coreml = whisper_coreml_init(path_coreml.c_str());
     if (!state->ctx_coreml) {
         fprintf(stderr, "%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
+#ifndef WHISPER_COREML_ALLOW_FALLBACK        
         return nullptr;
+#endif
+    } else {
+        fprintf(stderr, "%s: Core ML model loaded\n", __func__);        
     }
-
-    fprintf(stderr, "%s: Core ML model loaded\n", __func__);
 #endif
 
     state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
@@ -2745,8 +2758,10 @@ void whisper_free_state(struct whisper_state * state)
         }
 
 #ifdef WHISPER_USE_COREML
-        whisper_coreml_free(state->ctx_coreml);
-        state->ctx_coreml = nullptr;
+        if (state->ctx_coreml != nullptr) {
+            whisper_coreml_free(state->ctx_coreml);
+            state->ctx_coreml = nullptr;
+        }
 #endif
 
         delete state;