From: Ruben Ortlam Date: Mon, 16 Mar 2026 09:45:49 +0000 (+0100) Subject: vulkan: fix flash attention dot product precision (llama/20589) X-Git-Tag: v0.9.9~58 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=ec47d8e5d547c0812185f017ec2ed7d4850e2a24;p=pkg%2Fggml%2Fsources%2Fggml vulkan: fix flash attention dot product precision (llama/20589) --- diff --git a/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index ec48f5b1..11b7dce8 100644 --- a/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -245,7 +245,7 @@ void main() { #endif } [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Sf[r][c] += ACC_TYPE(dot(Q_cache[r], K_Tf)); + Sf[r][c] += dot(ACC_TYPEV4(Q_cache[r]), ACC_TYPEV4(K_Tf)); } } } @@ -270,7 +270,7 @@ void main() { #endif } [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Sf[r][c] += ACC_TYPE(dot(Qf[tile_row(r) * qf_stride + d * D_split + d_tid], K_Tf)); + Sf[r][c] += dot(ACC_TYPEV4(Qf[tile_row(r) * qf_stride + d * D_split + d_tid]), ACC_TYPEV4(K_Tf)); } } }