]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan: fix NaN issue in flash attention shader (#12776)
authorJeff Bolz <redacted>
Sun, 6 Apr 2025 09:03:47 +0000 (04:03 -0500)
committerGitHub <redacted>
Sun, 6 Apr 2025 09:03:47 +0000 (11:03 +0200)
Use -FLT_MAX/2 rather than -inf as the initial value for computing the maximum.

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

index eedbc6f8b0e9c4e498b2a13f1fb16c445efb80c3..8ddadb8a15de5bc3aed9ffc7c9c60a037b8e165e 100644 (file)
@@ -227,8 +227,11 @@ void main() {
 
     coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;
 
+    // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
+    const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
+
     L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
-    M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(-1.0/0.0);
+    M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(NEG_FLT_MAX_OVER_2);
 
     coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> slopeMat = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(1.0);
 
@@ -278,7 +281,7 @@ void main() {
             uint R = ((i + 1) * Br >  N) ?  (N % Br) : Br;
             uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc;
 
-            coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(-1.0/0.0), R, C);
+            coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(NEG_FLT_MAX_OVER_2), R, C);
         }
 
         coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> rowmax, P, rowsum, eM;