]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
whisper : sync with whisper.cpp
authorGeorgi Gerganov <redacted>
Mon, 10 Apr 2023 19:39:24 +0000 (22:39 +0300)
committerGeorgi Gerganov <redacted>
Mon, 10 Apr 2023 19:39:24 +0000 (22:39 +0300)
examples/whisper/whisper.cpp

index 71f7a96c67be322fc18ad51d6edffd28403dfc9d..24b9e5d8bd042c5ee9969da44058f68d1a675a09 100644 (file)
@@ -1511,8 +1511,7 @@ static bool whisper_encode_internal(
                                 Vcur,
                                 n_state/n_head, n_head, n_ctx),
                             1, 2, 0, 3),
-                        ggml_new_tensor_3d(ctx0, wctx.wtype, n_ctx, n_state/n_head, n_head)
-                        );
+                        ggml_new_tensor_3d(ctx0, wctx.wtype, n_ctx, n_state/n_head, n_head));
 
             struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false);
 #else
@@ -1736,10 +1735,12 @@ static bool whisper_encode_internal(
 
             wstate.use_buf(ctx0, -1);
 
-            //struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
-            //struct ggml_tensor * v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
-            struct ggml_tensor* k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
-            struct ggml_tensor* v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.v)*n_state)*(il*n_ctx));
+            Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
+
+            struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
+            struct ggml_tensor * v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
+                    (   n_ctx)*ggml_element_size(wstate.kv_cross.v),
+                    (il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state);
 
             ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k));
             ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
@@ -1874,20 +1875,24 @@ static bool whisper_decode_internal(
 
             Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
 
-            struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
-                    layer.attn_v_w,
-                    cur);
-
-            Vcur = ggml_add(ctx0,
-                    ggml_repeat(ctx0,
-                        layer.attn_v_b,
-                        Vcur),
-                    Vcur);
-
             // store key and value to memory
             {
+                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
+                        layer.attn_v_w,
+                        cur);
+
+                Vcur = ggml_add(ctx0,
+                        ggml_repeat(ctx0,
+                            layer.attn_v_b,
+                            Vcur),
+                        Vcur);
+
+                Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, N));
+
                 struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past));
-                struct ggml_tensor * v = ggml_view_1d(ctx0, kv_self.v, N*n_state, (ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + n_past));
+                struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_state,
+                        (   n_ctx)*ggml_element_size(kv_self.v),
+                        (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + n_past*ggml_element_size(kv_self.v));
 
                 ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
                 ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
@@ -1926,16 +1931,14 @@ static bool whisper_decode_internal(
 
             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
 
-            struct ggml_tensor * V_trans =
-                ggml_cpy(ctx0,
-                        ggml_permute(ctx0,
-                            ggml_reshape_3d(ctx0,
-                                ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.v)*n_state),
-                                n_state/n_head, n_head, n_past + N),
-                            1, 2, 0, 3),
-                        ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_state/n_head, n_head));
+            struct ggml_tensor * V =
+                ggml_view_3d(ctx0, kv_self.v,
+                        n_past + N, n_state/n_head, n_head,
+                        n_ctx*ggml_element_size(kv_self.v),
+                        n_ctx*ggml_element_size(kv_self.v)*n_state/n_head,
+                        il*n_ctx*ggml_element_size(kv_self.v)*n_state);
 
-            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
 
             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
@@ -1998,15 +2001,22 @@ static bool whisper_decode_internal(
                         ggml_view_1d(ctx0, wstate.kv_cross.k, M*n_state, il*M*ggml_element_size(wstate.kv_cross.k)*n_state),
                         n_state/n_head, n_head, M);
 
-            struct ggml_tensor * Vcross =
-                ggml_reshape_3d(ctx0,
-                        ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*ggml_element_size(wstate.kv_cross.v)*n_state),
-                        n_state/n_head, n_head, M);
+            //struct ggml_tensor * Vcross =
+            //    ggml_reshape_3d(ctx0,
+            //            ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*ggml_element_size(wstate.kv_cross.v)*n_state),
+            //            n_state/n_head, n_head, M);
 
-            struct ggml_tensor * V_trans =
-                ggml_cpy(ctx0,
-                        ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
-                        ggml_new_tensor_3d(ctx0, Vcross->type, M, n_state/n_head, n_head));
+            //struct ggml_tensor * V_trans =
+            //    ggml_cpy(ctx0,
+            //            ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
+            //            ggml_new_tensor_3d(ctx0, Vcross->type, M, n_state/n_head, n_head));
+
+            struct ggml_tensor * V =
+                ggml_view_3d(ctx0, wstate.kv_cross.v,
+                        M, n_state/n_head, n_head,
+                        M*ggml_element_size(wstate.kv_cross.v),
+                        M*ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
+                        il*M*ggml_element_size(wstate.kv_cross.v)*n_state);
 
             // ------
 
@@ -2033,7 +2043,7 @@ static bool whisper_decode_internal(
 
             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
 
-            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
 
             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);