From: Georgi Gerganov Date: Wed, 28 Aug 2024 08:02:54 +0000 (+0300) Subject: whisper : update FA call X-Git-Tag: upstream/1.7.4~458 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=9d754a56cf628cb5e487bc134078b4b04ee14b98;p=pkg%2Fggml%2Fsources%2Fwhisper.cpp whisper : update FA call --- diff --git a/src/whisper.cpp b/src/whisper.cpp index 0e72f875..35874aa5 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -2124,7 +2124,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder( ggml_element_size(kv_pad.v)*n_state_head, 0); - cur = ggml_flash_attn_ext(ctx0, Q, K, V, nullptr, KQscale, 0.0f); + cur = ggml_flash_attn_ext(ctx0, Q, K, V, nullptr, KQscale, 0.0f, 0.0f); cur = ggml_reshape_2d(ctx0, cur, n_state, n_ctx); } else { @@ -2563,7 +2563,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder( ggml_element_size(kv_self.v)*n_state_head, ggml_element_size(kv_self.v)*n_state*n_ctx*il); - cur = ggml_flash_attn_ext(ctx0, Q, K, V, KQ_mask_f16, 1.0f, 0.0f); + cur = ggml_flash_attn_ext(ctx0, Q, K, V, KQ_mask_f16, 1.0f, 0.0f, 0.0f); cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens); } else { @@ -2645,7 +2645,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder( ggml_element_size(wstate.kv_cross.v)*n_state_head, ggml_element_size(wstate.kv_cross.v)*n_state*n_audio_ctx_pad*il); - cur = ggml_flash_attn_ext(ctx0, Q, Kcross, Vcross, nullptr, KQscale, 0.0f); + cur = ggml_flash_attn_ext(ctx0, Q, Kcross, Vcross, nullptr, KQscale, 0.0f, 0.0f); cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens); } else {