]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
whisper : use flash attention (whisper/2152)
authorGeorgi Gerganov <redacted>
Wed, 15 May 2024 06:38:19 +0000 (09:38 +0300)
committerGeorgi Gerganov <redacted>
Wed, 15 May 2024 07:37:24 +0000 (10:37 +0300)
* whisper : use flash attention in the encoder

* whisper : add kv_pad

* whisper : remove extra backend instance (huh?)

* whisper : use FA for cross-attention

* whisper : use FA for self-attention

* whisper : simplify encoder FA

* whisper : add flash_attn runtime parameter

* scripts : add bench log

* scripts : add M1 Pro bench log

examples/whisper/main.cpp
examples/whisper/whisper.cpp
examples/whisper/whisper.h

index d11c1c3f81b0c76cd8000f2a817217f9db709613..45eb17fe7f327aada7bce233a60b695b04599bd2 100644 (file)
@@ -70,6 +70,7 @@ struct whisper_params {
     bool no_timestamps   = false;
     bool log_score       = false;
     bool use_gpu         = true;
+    bool flash_attn      = false;
 
     std::string language  = "en";
     std::string prompt;
@@ -168,7 +169,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-dtw"  || arg == "--dtw")             { params.dtw             = argv[++i]; }
         else if (arg == "-ls"   || arg == "--log-score")       { params.log_score       = true; }
         else if (arg == "-ng"   || arg == "--no-gpu")          { params.use_gpu         = false; }
-        else if (                  arg == "--suppress-regex")  { params.suppress_regex = argv[++i]; }
+        else if (arg == "-fa"   || arg == "--flash-attn")      { params.flash_attn      = true; }
+        else if (                  arg == "--suppress-regex")  { params.suppress_regex  = argv[++i]; }
         else if (                  arg == "--grammar")         { params.grammar         = argv[++i]; }
         else if (                  arg == "--grammar-rule")    { params.grammar_rule    = argv[++i]; }
         else if (                  arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
@@ -234,6 +236,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -dtw MODEL --dtw MODEL         [%-7s] compute token-level timestamps\n",                 params.dtw.c_str());
     fprintf(stderr, "  -ls,       --log-score         [%-7s] log best decoder scores of tokens\n",              params.log_score?"true":"false");
     fprintf(stderr, "  -ng,       --no-gpu            [%-7s] disable GPU\n",                                    params.use_gpu ? "false" : "true");
+    fprintf(stderr, "  -fa,       --flash-attn        [%-7s] flash attention\n",                                params.flash_attn ? "true" : "false");
     fprintf(stderr, "  --suppress-regex REGEX         [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str());
     fprintf(stderr, "  --grammar GRAMMAR              [%-7s] GBNF grammar to guide decoding\n",                 params.grammar.c_str());
     fprintf(stderr, "  --grammar-rule RULE            [%-7s] top-level GBNF grammar rule name\n",               params.grammar_rule.c_str());
@@ -977,7 +980,9 @@ int main(int argc, char ** argv) {
     // whisper init
 
     struct whisper_context_params cparams = whisper_context_default_params();
-    cparams.use_gpu = params.use_gpu;
+
+    cparams.use_gpu    = params.use_gpu;
+    cparams.flash_attn = params.flash_attn;
 
     if (!params.dtw.empty()) {
         cparams.dtw_token_timestamps = true;
index ff4223daf429ea416261f6fea3256fa2cdf203df..84aec8238cdb42d19cab3ef2e97a5aa0b91a2a94 100644 (file)
@@ -809,14 +809,15 @@ struct whisper_state {
     // shared between all decoders
     whisper_kv_cache kv_cross;
 
+    // padded buffer for flash-attention
+    whisper_kv_cache kv_pad;
+
     whisper_mel mel;
 
     whisper_batch batch;
 
     whisper_decoder decoders[WHISPER_MAX_DECODERS];
 
-    ggml_backend_t backend = nullptr;
-
     // ggml-alloc:
     // - stores meta info about the intermediate tensors into the `meta` buffers
     // - stores the actual tensor data into the `data` buffers
@@ -902,14 +903,12 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
 }
 
 static bool kv_cache_init(
-        const struct whisper_hparams & hparams,
              struct whisper_kv_cache & cache,
                       ggml_backend_t   backend,
                            ggml_type   wtype,
+                             int64_t   n_text_state,
+                             int64_t   n_text_layer,
                                  int   n_ctx) {
-    const int64_t n_text_state = hparams.n_text_state;
-    const int64_t n_text_layer = hparams.n_text_layer;
-
     const int64_t n_mem      = n_text_layer*n_ctx;
     const int64_t n_elements = n_text_state*n_mem;
 
@@ -941,6 +940,8 @@ static bool kv_cache_init(
         return false;
     }
 
+    ggml_backend_buffer_clear(cache.buffer, 0);
+
     return true;
 }
 
@@ -1068,6 +1069,26 @@ static void whisper_kv_cache_seq_cp(
     }
 }
 
+static uint32_t whisper_kv_cache_get_padding(const struct whisper_context & wctx) {
+    if (!wctx.params.flash_attn) {
+        return 1u;
+    }
+
+#ifdef GGML_USE_METAL
+    if (ggml_backend_is_metal(wctx.backend)) {
+        return 32u;
+    }
+#endif
+
+#ifdef GGML_USE_CUDA
+    if (ggml_backend_is_cuda(wctx.backend)) {
+        return 256u;
+    }
+#endif
+
+    return 1u;
+}
+
 // [EXPERIMENTAL] Token-level timestamps with DTW
 static bool aheads_masks_init(
         const whisper_context_params & cparams,
@@ -1872,6 +1893,14 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
     const int n_head  = hparams.n_audio_head;
     const int n_layer = hparams.n_audio_layer;
 
+    const int n_state_head = n_state/n_head;
+
+    auto & kv_pad = wstate.kv_pad;
+
+    WHISPER_ASSERT(!!kv_pad.ctx);
+
+    const int n_ctx_pad = GGML_PAD(n_ctx, 256);
+
     struct ggml_init_params params = {
         /*.mem_size   =*/ wstate.alloc_encode.meta.size(),
         /*.mem_buffer =*/ wstate.alloc_encode.meta.data(),
@@ -1884,7 +1913,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
 
     struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
 
-    const float KQscale = 1.0f/sqrtf(float(n_state)/n_head);
+    const float KQscale = 1.0f/sqrtf(float(n_state_head));
 
     // ===================================================================
     // NOTE: experimenting with partial evaluation of the encoder (ignore)
@@ -1934,14 +1963,14 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
 
             Qcur = ggml_add(ctx0, Qcur, layer.attn_q_b);
 
-            //Qcur = ggml_scale(ctx0, Qcur, pow(float(n_state)/n_head, -0.25));
+            //Qcur = ggml_scale(ctx0, Qcur, pow(float(n_state_head), -0.25));
 
             // note: no bias for Key
             struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
                     layer.attn_k_w,
                     cur);
 
-            //Kcur = ggml_scale(ctx0, Kcur, pow(float(n_state)/n_head, -0.25));
+            //Kcur = ggml_scale(ctx0, Kcur, pow(float(n_state_head), -0.25));
 
             struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
                     layer.attn_v_w,
@@ -1955,38 +1984,61 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
                 ggml_permute(ctx0,
                         ggml_cpy(ctx0,
                             Qcur,
-                            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
+                            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state_head, n_head, n_ctx)),
                         0, 2, 1, 3);
 
-            struct ggml_tensor * K =
-                ggml_permute(ctx0,
-                        ggml_cpy(ctx0,
-                            Kcur,
-                            ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
-                        0, 2, 1, 3);
-
-            // K * Q
-            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+            if (wctx.params.flash_attn) {
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, ggml_view_1d(ctx0, kv_pad.k, n_ctx*n_state, 0)));
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, ggml_view_1d(ctx0, kv_pad.v, n_ctx*n_state, 0)));
 
-            struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
+                struct ggml_tensor * K =
+                    ggml_view_3d(ctx0, kv_pad.k,
+                            n_state_head, n_ctx_pad, n_head,
+                            ggml_element_size(kv_pad.k)*n_state,
+                            ggml_element_size(kv_pad.k)*n_state_head,
+                            0);
 
-            struct ggml_tensor * V =
-                ggml_cpy(ctx0,
-                        ggml_permute(ctx0,
-                            ggml_reshape_3d(ctx0,
-                                Vcur,
-                                n_state/n_head, n_head, n_ctx),
-                            1, 2, 0, 3),
-                        ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)
-                        );
+                struct ggml_tensor * V =
+                    ggml_view_3d(ctx0, kv_pad.v,
+                            n_state_head, n_ctx_pad, n_head,
+                            ggml_element_size(kv_pad.v)*n_state,
+                            ggml_element_size(kv_pad.v)*n_state_head,
+                            0);
 
-            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+                cur = ggml_flash_attn_ext(ctx0, Q, K, V, nullptr, KQscale, 0.0f);
 
-            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
-
-            cur = ggml_cpy(ctx0,
-                    KQV_merged,
-                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
+                cur = ggml_reshape_2d(ctx0, cur, n_state, n_ctx);
+            } else {
+                struct ggml_tensor * K =
+                    ggml_permute(ctx0,
+                            ggml_cpy(ctx0,
+                                Kcur,
+                                ggml_new_tensor_3d(ctx0, wctx.itype, n_state_head, n_head, n_ctx)),
+                            0, 2, 1, 3);
+
+                // K * Q
+                struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+
+                struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
+
+                struct ggml_tensor * V =
+                    ggml_cpy(ctx0,
+                            ggml_permute(ctx0,
+                                ggml_reshape_3d(ctx0,
+                                    Vcur,
+                                    n_state_head, n_head, n_ctx),
+                                1, 2, 0, 3),
+                            ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state_head, n_head)
+                            );
+
+                struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+
+                struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+
+                cur = ggml_cpy(ctx0,
+                        KQV_merged,
+                        ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
+            }
         }
 
         // projection
@@ -2085,6 +2137,10 @@ static struct ggml_cgraph * whisper_build_graph_cross(
     const int n_state = hparams.n_audio_state;
     const int n_head  = hparams.n_audio_head;
 
+    const int n_state_head = n_state/n_head;
+
+    const int n_ctx_pad = GGML_PAD(n_ctx, 256);
+
     struct ggml_init_params params = {
         /*.mem_size   =*/ wstate.alloc_cross.meta.size(),
         /*.mem_buffer =*/ wstate.alloc_cross.meta.data(),
@@ -2097,18 +2153,18 @@ static struct ggml_cgraph * whisper_build_graph_cross(
 
     struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
 
-    const float  Kscale = pow(float(n_state) / n_head, -0.25);
+    const float  Kscale = pow(float(n_state_head), -0.25);
 
     for (int il = 0; il < model.hparams.n_text_layer; ++il) {
         auto & layer = model.layers_decoder[il];
 
-        struct ggml_tensor* Kcross = ggml_mul_mat(ctx0,
+        struct ggml_tensor * Kcross = ggml_mul_mat(ctx0,
                 layer.cross_attn_k_w,
                 cur);
 
         Kcross = ggml_scale(ctx0, Kcross, Kscale);
 
-        struct ggml_tensor* Vcross = ggml_mul_mat(ctx0,
+        struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
                 layer.cross_attn_v_w,
                 cur);
 
@@ -2116,15 +2172,25 @@ static struct ggml_cgraph * whisper_build_graph_cross(
                     Vcross,
                     layer.cross_attn_v_b);
 
-        Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
+        struct ggml_tensor * k;
+        struct ggml_tensor * v;
 
-        struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k,
-                n_state*n_ctx,
-                (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
+        if (wctx.params.flash_attn) {
+            k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx,
+                    (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx_pad));
 
-        struct ggml_tensor * v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
-                (   n_ctx)*ggml_element_size(wstate.kv_cross.v),
-                (il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state);
+            v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx,
+                    (ggml_element_size(wstate.kv_cross.v)*n_state)*(il*n_ctx_pad));
+        } else {
+            Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
+
+            k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx,
+                    (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
+
+            v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
+                    (   n_ctx)*ggml_element_size(wstate.kv_cross.v),
+                    (il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state);
+        }
 
         ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcross, k));
         ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcross, v));
@@ -2195,7 +2261,7 @@ static bool whisper_encode_internal(
         }
 
         if (!whisper_encode_external(wstate)) {
-            if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
+            if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) {
                 return false;
             }
         } else {
@@ -2218,7 +2284,7 @@ static bool whisper_encode_internal(
             return false;
         }
 
-        if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
+        if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) {
             return false;
         }
     }
@@ -2234,7 +2300,7 @@ static bool whisper_encode_internal(
             return false;
         }
 
-        if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
+        if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) {
             return false;
         }
     }
@@ -2263,11 +2329,15 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
     const int n_head  = hparams.n_text_head;
     const int n_layer = hparams.n_text_layer;
 
+    const int n_state_head = n_state/n_head;
+
     const int n_tokens    = batch.n_tokens;
     const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
 
-    const int32_t n_kv     = worst_case ? n_ctx            : kv_self.n;
-    const int32_t kv_head  = worst_case ? n_ctx - n_tokens : kv_self.head;
+    const int n_audio_ctx_pad = GGML_PAD(n_audio_ctx, 256);
+
+    const int32_t n_kv    = worst_case ? n_ctx            : kv_self.n;
+    const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head;
 
     //WHISPER_LOG_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx);
 
@@ -2289,12 +2359,14 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
     ggml_set_name(position, "position");
     ggml_set_input(position);
 
-    const float KQscale = pow(float(n_state)/n_head, -0.25);
+    const float KQscale = pow(float(n_state_head), -0.25);
 
-    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1);
     ggml_set_name(KQ_mask, "KQ_mask");
     ggml_set_input(KQ_mask);
 
+    struct ggml_tensor * KQ_mask_f16 = ggml_cast(ctx0, KQ_mask, GGML_TYPE_F16);
+
     // token encoding + position encoding
     struct ggml_tensor * cur =
         ggml_add(ctx0,
@@ -2350,12 +2422,25 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
                             Vcur,
                             layer.attn_v_b);
 
-                Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens));
+                struct ggml_tensor * k;
+                struct ggml_tensor * v;
+
+                if (wctx.params.flash_attn) {
+                    k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state,
+                            (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
+
+                    v = ggml_view_1d(ctx0, kv_self.v, n_tokens*n_state,
+                            (ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + kv_head));
+                } else {
+                    Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens));
 
-                struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
-                struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state,
-                        (   n_ctx)*ggml_element_size(kv_self.v),
-                        (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + kv_head*ggml_element_size(kv_self.v));
+                    k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state,
+                            (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
+
+                    v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state,
+                            (   n_ctx)*ggml_element_size(kv_self.v),
+                            (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + kv_head*ggml_element_size(kv_self.v));
+                }
 
                 ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
                 ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
@@ -2365,35 +2450,48 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
-                        ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
+                        ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
                         0, 2, 1, 3);
 
             struct ggml_tensor * K =
                 ggml_view_3d(ctx0, kv_self.k,
-                        n_state/n_head, n_kv, n_head,
+                        n_state_head, n_kv, n_head,
                         ggml_element_size(kv_self.k)*n_state,
-                        ggml_element_size(kv_self.k)*n_state/n_head,
+                        ggml_element_size(kv_self.k)*n_state_head,
                         ggml_element_size(kv_self.k)*n_state*n_ctx*il);
 
-            // K * Q
-            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+            if (wctx.params.flash_attn) {
+                struct ggml_tensor * V =
+                    ggml_view_3d(ctx0, kv_self.v,
+                            n_state_head, n_kv, n_head,
+                            ggml_element_size(kv_self.v)*n_state,
+                            ggml_element_size(kv_self.v)*n_state_head,
+                            ggml_element_size(kv_self.v)*n_state*n_ctx*il);
+
+                cur = ggml_flash_attn_ext(ctx0, Q, K, V, KQ_mask_f16, 1.0f, 0.0f);
+
+                cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens);
+            } else {
+                // K * Q
+                struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
 
-            struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, KQ_mask, 1.0f, 0.0f);
+                struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, KQ_mask, 1.0f, 0.0f);
 
-            struct ggml_tensor * V =
-                ggml_view_3d(ctx0, kv_self.v,
-                        n_kv, n_state/n_head, n_head,
-                        n_ctx*ggml_element_size(kv_self.v),
-                        n_ctx*ggml_element_size(kv_self.v)*n_state/n_head,
-                        n_ctx*ggml_element_size(kv_self.v)*n_state*il);
+                struct ggml_tensor * V =
+                    ggml_view_3d(ctx0, kv_self.v,
+                            n_kv, n_state_head, n_head,
+                            n_ctx*ggml_element_size(kv_self.v),
+                            n_ctx*ggml_element_size(kv_self.v)*n_state_head,
+                            n_ctx*ggml_element_size(kv_self.v)*n_state*il);
 
-            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+                struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
 
-            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+                struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
-            cur = ggml_cpy(ctx0,
-                    KQV_merged,
-                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
+                cur = ggml_cpy(ctx0,
+                        KQV_merged,
+                        ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
+            }
         }
 
         // projection
@@ -2432,80 +2530,77 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
                         Qcur,
                         layer.cross_attn_q_b);
 
-            Qcur = ggml_scale(ctx0, Qcur, KQscale);
-
-            // Kcross is already scaled
-            struct ggml_tensor * Kcross =
-                ggml_view_3d(ctx0, wstate.kv_cross.k,
-                        n_state/n_head, n_audio_ctx, n_head,
-                        ggml_element_size(wstate.kv_cross.k)*n_state,
-                        ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
-                        ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
-
-            //struct ggml_tensor * Vcross =
-            //    ggml_reshape_3d(ctx0,
-            //            ggml_view_1d(ctx0, wstate.kv_cross.v, n_audio_ctx*n_state, il*n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state),
-            //            n_state/n_head, n_head, n_audio_ctx);
-
-            //struct ggml_tensor * V_trans =
-            //    ggml_cpy(ctx0,
-            //            ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
-            //            ggml_new_tensor_3d(ctx0, Vcross->type, n_audio_ctx, n_state/n_head, n_head));
-
-            struct ggml_tensor * V =
-                ggml_view_3d(ctx0, wstate.kv_cross.v,
-                        n_audio_ctx, n_state/n_head, n_head,
-                        n_audio_ctx*ggml_element_size(wstate.kv_cross.v),
-                        n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
-                        n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il);
-
-            // ------
-
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
-                        ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
+                        ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
                         0, 2, 1, 3);
 
-            // K * Q
-            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, Q);
-
-            //struct ggml_tensor * KQ_scaled =
-            //    ggml_scale(ctx0,
-            //            KQ,
-            //            ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
-            //            );
+            if (wctx.params.flash_attn) {
+                struct ggml_tensor * Kcross =
+                    ggml_view_3d(ctx0, wstate.kv_cross.k,
+                            n_state_head, n_audio_ctx_pad, n_head,
+                            ggml_element_size(wstate.kv_cross.k)*n_state,
+                            ggml_element_size(wstate.kv_cross.k)*n_state_head,
+                            ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx_pad*il);
 
-            // no masking for cross-attention
-            //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
+                struct ggml_tensor * Vcross =
+                    ggml_view_3d(ctx0, wstate.kv_cross.v,
+                            n_state_head, n_audio_ctx_pad, n_head,
+                            ggml_element_size(wstate.kv_cross.v)*n_state,
+                            ggml_element_size(wstate.kv_cross.v)*n_state_head,
+                            ggml_element_size(wstate.kv_cross.v)*n_state*n_audio_ctx_pad*il);
 
-            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
+                cur = ggml_flash_attn_ext(ctx0, Q, Kcross, Vcross, nullptr, KQscale, 0.0f);
 
-            // [EXPERIMENTAL] Token-level timestamps with DTW
-            if (wctx.params.dtw_token_timestamps) {
-                if (wstate.aheads_masks.m[il] != nullptr) {
-                    struct ggml_tensor * aheads_KQs = ggml_reshape_2d(ctx0, KQ_soft_max, KQ_soft_max->ne[0] * KQ_soft_max->ne[1], KQ_soft_max->ne[2]);
-                    aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
-                    aheads_KQs = ggml_cont(ctx0, aheads_KQs);
-                    aheads_KQs = ggml_mul_mat(ctx0, wstate.aheads_masks.m[il], aheads_KQs);
-                    aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
-                    aheads_KQs = ggml_cont(ctx0, aheads_KQs);
-                    aheads_KQs = ggml_reshape_3d(ctx0, aheads_KQs, KQ_soft_max->ne[0], KQ_soft_max->ne[1], wstate.aheads_masks.m[il]->ne[1]);
-                    if (aheads_cross_QKs == NULL) {
-                        aheads_cross_QKs = aheads_KQs;
-                    } else {
-                        aheads_cross_QKs = ggml_concat(ctx0, aheads_cross_QKs, aheads_KQs);
+                cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens);
+            } else {
+                struct ggml_tensor * Kcross =
+                    ggml_view_3d(ctx0, wstate.kv_cross.k,
+                            n_state_head, n_audio_ctx, n_head,
+                            ggml_element_size(wstate.kv_cross.k)*n_state,
+                            ggml_element_size(wstate.kv_cross.k)*n_state_head,
+                            ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
+
+                struct ggml_tensor * Vcross =
+                    ggml_view_3d(ctx0, wstate.kv_cross.v,
+                            n_audio_ctx, n_state_head, n_head,
+                            n_audio_ctx*ggml_element_size(wstate.kv_cross.v),
+                            n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state_head,
+                            n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il);
+
+                // ------
+
+                // K * Q
+                struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, Q);
+
+                struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
+
+                // [EXPERIMENTAL] Token-level timestamps with DTW
+                if (wctx.params.dtw_token_timestamps) {
+                    if (wstate.aheads_masks.m[il] != nullptr) {
+                        struct ggml_tensor * aheads_KQs = ggml_reshape_2d(ctx0, KQ_soft_max, KQ_soft_max->ne[0] * KQ_soft_max->ne[1], KQ_soft_max->ne[2]);
+                        aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
+                        aheads_KQs = ggml_cont(ctx0, aheads_KQs);
+                        aheads_KQs = ggml_mul_mat(ctx0, wstate.aheads_masks.m[il], aheads_KQs);
+                        aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
+                        aheads_KQs = ggml_cont(ctx0, aheads_KQs);
+                        aheads_KQs = ggml_reshape_3d(ctx0, aheads_KQs, KQ_soft_max->ne[0], KQ_soft_max->ne[1], wstate.aheads_masks.m[il]->ne[1]);
+                        if (aheads_cross_QKs == NULL) {
+                            aheads_cross_QKs = aheads_KQs;
+                        } else {
+                            aheads_cross_QKs = ggml_concat(ctx0, aheads_cross_QKs, aheads_KQs);
+                        }
                     }
                 }
-            }
 
-            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+                struct ggml_tensor * KQV = ggml_mul_mat(ctx0, Vcross, KQ_soft_max);
 
-            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+                struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
-            // cur = KQV_merged.contiguous().view(n_state, n_tokens)
-            cur = ggml_cpy(ctx0,
-                    KQV_merged,
-                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
+                cur = ggml_cpy(ctx0,
+                        KQV_merged,
+                        ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
+            }
         }
 
         // projection
@@ -2638,7 +2733,9 @@ static bool whisper_decode_internal(
             return false;
         }
 
-        kv_self.n = whisper_kv_cache_cell_max(kv_self);
+        const uint32_t pad = whisper_kv_cache_get_padding(wctx);
+        kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(whisper_kv_cache_cell_max(kv_self), pad)));
+
         //kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
         //printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]);
     }
@@ -2672,9 +2769,10 @@ static bool whisper_decode_internal(
             struct ggml_tensor * KQ_mask = ggml_graph_get_tensor(gf, "KQ_mask");
 
             auto & kv_self = wstate.kv_self;
-            const int32_t n_kv     = kv_self.n;
 
-            wstate.inp_mask.resize(n_kv*n_tokens);
+            const int32_t n_kv = kv_self.n;
+
+            wstate.inp_mask.resize(ggml_nelements(KQ_mask));
 
             float * data = wstate.inp_mask.data();
             memset(data, 0, ggml_nbytes(KQ_mask));
@@ -2690,6 +2788,12 @@ static bool whisper_decode_internal(
                         }
                     }
                 }
+
+                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                    for (int j = 0; j < n_kv; ++j) {
+                        data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                    }
+                }
             }
 
             ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, ggml_nelements(KQ_mask)*sizeof(float));
@@ -2697,7 +2801,7 @@ static bool whisper_decode_internal(
 
         logits = gf->nodes[gf->n_nodes - 1];
 
-        if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
+        if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) {
             return false;
         }
     }
@@ -3144,18 +3248,14 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     whisper_state * state = new whisper_state;
 
-    state->backend = whisper_backend_init(ctx->params);
-    if (!state->backend) {
-        WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__);
-        whisper_free_state(state);
-        return nullptr;
-    }
-
     // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
     // in theory, there can be a case where this is not enough, but in practice it should always be enough
     const int factor = 3;
 
-    if (!kv_cache_init(ctx->model.hparams, state->kv_self, ctx->backend, ctx->itype, factor*ctx->model.hparams.n_text_ctx)) {
+    if (!kv_cache_init(state->kv_self, ctx->backend, ctx->itype,
+                ctx->model.hparams.n_text_state,
+                ctx->model.hparams.n_text_layer,
+                GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) {
         WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
         whisper_free_state(state);
         return nullptr;
@@ -3166,7 +3266,10 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         WHISPER_LOG_INFO("%s: kv self size  = %7.2f MB\n", __func__, memory_size / 1e6);
     }
 
-    if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->backend, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
+    if (!kv_cache_init(state->kv_cross, ctx->backend, ctx->itype,
+                ctx->model.hparams.n_text_state,
+                ctx->model.hparams.n_text_layer,
+                GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
         WHISPER_LOG_ERROR("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
         whisper_free_state(state);
         return nullptr;
@@ -3177,6 +3280,20 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
     }
 
+    if (!kv_cache_init(state->kv_pad, ctx->backend, ctx->itype,
+                ctx->model.hparams.n_audio_state,
+                1,
+                GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
+        WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
+        whisper_free_state(state);
+        return nullptr;
+    }
+
+    {
+        const size_t memory_size = ggml_nbytes(state->kv_pad.k) + ggml_nbytes(state->kv_pad.v);
+        WHISPER_LOG_INFO("%s: kv pad  size  = %7.2f MB\n", __func__, memory_size / 1e6);
+    }
+
     // [EXPERIMENTAL] Token-level timestamps with DTW
     if (ctx->params.dtw_token_timestamps) {
         if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, ctx->backend)) {
@@ -3347,6 +3464,7 @@ int whisper_ctx_init_openvino_encoder(
 struct whisper_context_params whisper_context_default_params() {
     struct whisper_context_params result = {
         /*.use_gpu              =*/ true,
+        /*.flash_attn           =*/ false,
         /*.gpu_device           =*/ 0,
 
         /*.dtw_token_timestamps =*/ false,
@@ -3445,6 +3563,16 @@ struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * bu
 struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params) {
     ggml_time_init();
 
+    if (params.flash_attn && params.dtw_token_timestamps) {
+        WHISPER_LOG_WARN("%s: dtw_token_timestamps is not supported with flash_attn - disabling\n", __func__);
+        params.dtw_token_timestamps = false;
+    }
+
+    WHISPER_LOG_INFO("%s: use gpu    = %d\n", __func__, params.use_gpu);
+    WHISPER_LOG_INFO("%s: flash attn = %d\n", __func__, params.flash_attn);
+    WHISPER_LOG_INFO("%s: gpu_device = %d\n", __func__, params.gpu_device);
+    WHISPER_LOG_INFO("%s: dtw        = %d\n", __func__, params.dtw_token_timestamps);
+
     whisper_context * ctx = new whisper_context;
     ctx->params = params;
 
@@ -3533,6 +3661,7 @@ void whisper_free_state(struct whisper_state * state) {
     if (state) {
         kv_cache_free(state->kv_self);
         kv_cache_free(state->kv_cross);
+        kv_cache_free(state->kv_pad);
 
 #ifdef WHISPER_USE_COREML
         if (state->ctx_coreml != nullptr) {
@@ -3555,8 +3684,6 @@ void whisper_free_state(struct whisper_state * state) {
         ggml_gallocr_free(state->alloc_cross.alloc);
         ggml_gallocr_free(state->alloc_decode.alloc);
 
-        ggml_backend_free(state->backend);
-
         // [EXPERIMENTAL] Token-level timestamps with DTW
         aheads_masks_free(state->aheads_masks);
 
index 6a875d3bbb9d34e67b6efe15e9eadc99511d5344..9c7c58d874b0c6be7fd7bb23ff1ba79cd0b9dc4b 100644 (file)
@@ -113,6 +113,7 @@ extern "C" {
 
     struct whisper_context_params {
         bool  use_gpu;
+        bool  flash_attn;
         int   gpu_device;  // CUDA device
 
         // [EXPERIMENTAL] Token-level timestamps with DTW