]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
vulkan: Handle FA with all -inf mask values (llama/16447)
authorJeff Bolz <redacted>
Tue, 21 Oct 2025 03:16:08 +0000 (22:16 -0500)
committerGeorgi Gerganov <redacted>
Tue, 21 Oct 2025 15:14:33 +0000 (18:14 +0300)
src/ggml-vulkan/vulkan-shaders/flash_attn.comp
src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp

index 62acbf107a2984a6aaa944e409e9ed0227fc011c..2255f9c168e6eba9b72b99ca5065bc6be7a8d073 100644 (file)
@@ -345,7 +345,7 @@ void main() {
 
     float Lfrcp[Br];
     [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-        Lfrcp[r] = 1.0 / Lf[r];
+        Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]);
     }
 
     [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
index 2066a05b34902491ab5782a83070da5aa8a21750..8699fa6c9cbb7dfd6758eeaaca4fab47ee11be61 100644 (file)
@@ -380,7 +380,7 @@ void main() {
 
     float Lfrcp[rows_per_thread];
     [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-        Lfrcp[r] = 1.0 / Lf[r];
+        Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]);
     }
 
     [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
index 910da1ab0c28f33ca7f4cf032c23050d7ab9d833..fcfc60a8785444c79764f89a593bcba1d8d59dcd 100644 (file)
@@ -121,7 +121,11 @@ void main() {
     const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
 
     L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
+#if defined(ACC_TYPE_MAX)
+    M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(-ACC_TYPE_MAX / ACC_TYPE(2));
+#else
     M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(NEG_FLT_MAX_OVER_2);
+#endif
 
     coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> slopeMat = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(1.0);
 
@@ -294,7 +298,7 @@ void main() {
 
     [[unroll]]
     for (int k = 0; k < Ldiag.length(); ++k) {
-        Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k];
+        Ldiag[k] = (Ldiag[k] == 0.0) ? ACC_TYPE(0.0) : (ACC_TYPE(1.0) / Ldiag[k]);
     }
 
     O = Ldiag*O;
index 06e83822fe326929bc4c5af164679307bb0d954d..4eaddd31a8f58b880fb841bf8d42f6a4cddffd40 100644 (file)
@@ -91,7 +91,7 @@ void main() {
         L = L*ms + vs;
     }
 
-    L = 1.0 / L;
+    L = (L == 0.0) ? 0.0 : 1.0 / L;
 
     // D dimension is split across workgroups in the y dimension
     uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE;