From: Jeff Bolz Date: Sun, 6 Apr 2025 09:03:47 +0000 (-0500) Subject: vulkan: fix NaN issue in flash attention shader (llama/12776) X-Git-Tag: upstream/1.7.5+105~57 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=3c26dd3353bced6fa88c2e7e6c7d921a1e09dfcd;p=pkg%2Fggml%2Fsources%2Fwhisper.cpp vulkan: fix NaN issue in flash attention shader (llama/12776) Use -FLT_MAX/2 rather than -inf as the initial value for computing the maximum. --- diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index eedbc6f8..8ddadb8a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -227,8 +227,11 @@ void main() { coopmat L, M; + // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M. + const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF); + L = coopmat(0); - M = coopmat(-1.0/0.0); + M = coopmat(NEG_FLT_MAX_OVER_2); coopmat slopeMat = coopmat(1.0); @@ -278,7 +281,7 @@ void main() { uint R = ((i + 1) * Br > N) ? (N % Br) : Br; uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc; - coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(-1.0/0.0), R, C); + coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(NEG_FLT_MAX_OVER_2), R, C); } coopmat rowmax, P, rowsum, eM;