git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
vulkan: fix flash attention dot product precision (llama/20589)
author Ruben Ortlam <redacted>
Mon, 16 Mar 2026 09:45:49 +0000 (10:45 +0100)
committer Georgi Gerganov <redacted>
Sat, 28 Mar 2026 11:39:09 +0000 (13:39 +0200)
src/ggml-vulkan/vulkan-shaders/flash_attn.comp

index ec48f5b11528576018e89c50f1e28cefec1974ec..11b7dce8578546094104f4be499e05411ce8e58a 100644 (file)
@@ -245,7 +245,7 @@ void main() {
 #endif
                     }
                     [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-                        Sf[r][c] += ACC_TYPE(dot(Q_cache[r], K_Tf));
+                        Sf[r][c] += dot(ACC_TYPEV4(Q_cache[r]), ACC_TYPEV4(K_Tf));
                     }
                 }
             }
@@ -270,7 +270,7 @@ void main() {
 #endif
                     }
                     [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-                        Sf[r][c] += ACC_TYPE(dot(Qf[tile_row(r) * qf_stride + d * D_split + d_tid], K_Tf));
+                        Sf[r][c] += dot(ACC_TYPEV4(Qf[tile_row(r) * qf_stride + d * D_split + d_tid]), ACC_TYPEV4(K_Tf));
                     }
                 }
             }