]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : update FA call
authorGeorgi Gerganov <redacted>
Wed, 28 Aug 2024 08:02:54 +0000 (11:02 +0300)
committerGeorgi Gerganov <redacted>
Wed, 28 Aug 2024 10:22:20 +0000 (13:22 +0300)
src/whisper.cpp

index 0e72f87546559cab9f1f2f0d0a0ba7908c0edc3b..35874aa5abafebe2c2da5dcc70981f834e20830c 100644 (file)
@@ -2124,7 +2124,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
                             ggml_element_size(kv_pad.v)*n_state_head,
                             0);
 
-                cur = ggml_flash_attn_ext(ctx0, Q, K, V, nullptr, KQscale, 0.0f);
+                cur = ggml_flash_attn_ext(ctx0, Q, K, V, nullptr, KQscale, 0.0f, 0.0f);
 
                 cur = ggml_reshape_2d(ctx0, cur, n_state, n_ctx);
             } else {
@@ -2563,7 +2563,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
                             ggml_element_size(kv_self.v)*n_state_head,
                             ggml_element_size(kv_self.v)*n_state*n_ctx*il);
 
-                cur = ggml_flash_attn_ext(ctx0, Q, K, V, KQ_mask_f16, 1.0f, 0.0f);
+                cur = ggml_flash_attn_ext(ctx0, Q, K, V, KQ_mask_f16, 1.0f, 0.0f, 0.0f);
 
                 cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens);
             } else {
@@ -2645,7 +2645,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
                             ggml_element_size(wstate.kv_cross.v)*n_state_head,
                             ggml_element_size(wstate.kv_cross.v)*n_state*n_audio_ctx_pad*il);
 
-                cur = ggml_flash_attn_ext(ctx0, Q, Kcross, Vcross, nullptr, KQscale, 0.0f);
+                cur = ggml_flash_attn_ext(ctx0, Q, Kcross, Vcross, nullptr, KQscale, 0.0f, 0.0f);
 
                 cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens);
             } else {