]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan: Use fp16 for the flash attention P*V multiplication (#12783)
authorJeff Bolz <redacted>
Wed, 9 Apr 2025 05:12:57 +0000 (00:12 -0500)
committerGitHub <redacted>
Wed, 9 Apr 2025 05:12:57 +0000 (07:12 +0200)
This is consistent with the ggml-cuda behavior and the mul_mat fallback.

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

index 8ddadb8a15de5bc3aed9ffc7c9c60a037b8e165e..a8f4bc41726c2fad026b836f9fab11115d97970c 100644 (file)
@@ -330,9 +330,11 @@ void main() {
         // resize eM by using smear/reduce
         coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
 
-        O = eMdiag * O;
+        // multiply with fp16 accumulation, then add to O.
+        coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);
+        PV = coopMatMulAdd(P_A, V, PV);
 
-        O = coopMatMulAdd(P_A, V, O);
+        O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(PV);
     }
 
     // If there is split_k, then the split_k resolve shader does the final