From: Jeff Bolz
Date: Wed, 9 Apr 2025 05:12:57 +0000 (-0500)
Subject: vulkan: Use fp16 for the flash attention P*V multiplication (llama/12783)
X-Git-Tag: upstream/1.7.5+105~44
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=1d50c6ac2262ada2d5e75dd1138b8fad3a10db15;p=pkg%2Fggml%2Fsources%2Fwhisper.cpp

vulkan: Use fp16 for the flash attention P*V multiplication (llama/12783)

This is consistent with the ggml-cuda behavior and the mul_mat fallback.
---

diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
index 8ddadb8a..a8f4bc41 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
@@ -330,9 +330,11 @@ void main() {
         // resize eM by using smear/reduce
         coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
 
-        O = eMdiag * O;
+        // multiply with fp16 accumulation, then add to O.
+        coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);
+        PV = coopMatMulAdd(P_A, V, PV);
 
-        O = coopMatMulAdd(P_A, V, O);
+        O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(PV);
     }
 
     // If there is split_k, then the split_k resolve shader does the final
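
Background on the change, as a minimal scalar sketch rather than the shader's cooperative-matrix code: the flash-attention inner-loop update is reordered so that the P_A*V product is first accumulated into a separate fp16 matrix PV, and only afterwards converted and added to the rescaled output accumulator O, instead of rescaling O and accumulating the product directly into it at ACC_TYPE precision. The C sketch below illustrates that ordering; the tile sizes, variable names, and use of _Float16 (a GCC/Clang extension, standardized in C23) are illustrative assumptions, not values from the shader.

    #include <stdio.h>

    #define BR 2   /* rows of the current query tile (illustrative size) */
    #define BC 2   /* columns of P_A / rows of V (illustrative size) */
    #define D  2   /* head dimension (illustrative size) */

    int main(void) {
        /* Toy stand-ins for the shader's P_A (softmax weights), V, the running
           output accumulator O, and the per-row rescale factor eMdiag. */
        float P_A[BR][BC] = {{0.25f, 0.75f}, {0.5f, 0.5f}};
        float V[BC][D]    = {{1.0f, 2.0f}, {3.0f, 4.0f}};
        float O[BR][D]    = {{10.0f, 20.0f}, {30.0f, 40.0f}};
        float eMdiag[BR]  = {0.5f, 0.25f};

        for (int i = 0; i < BR; ++i) {
            for (int j = 0; j < D; ++j) {
                /* PV = 0; PV = P_A * V + PV;
                   the product is accumulated in fp16, as in the new shader code. */
                _Float16 pv = (_Float16)0.0f;
                for (int k = 0; k < BC; ++k) {
                    pv += (_Float16)P_A[i][k] * (_Float16)V[k][j];
                }
                /* O = eMdiag * O + (converted) PV;
                   rescale the old accumulator, then add the converted product. */
                O[i][j] = eMdiag[i] * O[i][j] + (float)pv;
            }
        }

        for (int i = 0; i < BR; ++i) {
            printf("O[%d] = { %g, %g }\n", i, (double)O[i][0], (double)O[i][1]);
        }
        return 0;
    }

Per the commit message, the motivation is consistency with the ggml-cuda behavior and the mul_mat fallback; accumulating the P*V product in fp16 also typically lets the hardware use its faster fp16 matrix-multiply paths, at the cost of a small amount of accumulation precision for that one product.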