From: Ruben Ortlam
Date: Mon, 16 Mar 2026 09:45:49 +0000 (+0100)
Subject: vulkan: fix flash attention dot product precision (#20589)
X-Git-Tag: upstream/0.0.8611~238
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=46dba9fce860d41ac545224623f27ac71f9d264a;p=pkg%2Fggml%2Fsources%2Fllama.cpp

vulkan: fix flash attention dot product precision (#20589)
---

diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
index ec48f5b11..11b7dce85 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
@@ -245,7 +245,7 @@ void main() {
 #endif
 }
 [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
- Sf[r][c] += ACC_TYPE(dot(Q_cache[r], K_Tf));
+ Sf[r][c] += dot(ACC_TYPEV4(Q_cache[r]), ACC_TYPEV4(K_Tf));
 }
 }
 }
@@ -270,7 +270,7 @@ void main() {
 #endif
 }
 [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
- Sf[r][c] += ACC_TYPE(dot(Qf[tile_row(r) * qf_stride + d * D_split + d_tid], K_Tf));
+ Sf[r][c] += dot(ACC_TYPEV4(Qf[tile_row(r) * qf_stride + d * D_split + d_tid]), ACC_TYPEV4(K_Tf));
 }
 }
 }