From: Jeff Bolz Date: Tue, 21 Oct 2025 03:16:08 +0000 (-0500) Subject: vulkan: Handle FA with all -inf mask values (llama/16447) X-Git-Tag: upstream/0.9.4.185~102 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=a118c51aaf2a723a5facbb5af1d9730736f36cd3;p=pkg%2Fggml%2Fsources%2Fggml vulkan: Handle FA with all -inf mask values (llama/16447) --- diff --git a/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 62acbf10..2255f9c1 100644 --- a/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -345,7 +345,7 @@ void main() { float Lfrcp[Br]; [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - Lfrcp[r] = 1.0 / Lf[r]; + Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]); } [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { diff --git a/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index 2066a05b..8699fa6c 100644 --- a/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -380,7 +380,7 @@ void main() { float Lfrcp[rows_per_thread]; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Lfrcp[r] = 1.0 / Lf[r]; + Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]); } [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { diff --git a/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 910da1ab..fcfc60a8 100644 --- a/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -121,7 +121,11 @@ void main() { const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF); L = coopmat(0); +#if defined(ACC_TYPE_MAX) + M = coopmat(-ACC_TYPE_MAX / ACC_TYPE(2)); +#else M = coopmat(NEG_FLT_MAX_OVER_2); +#endif coopmat slopeMat = coopmat(1.0); @@ -294,7 +298,7 @@ void main() { [[unroll]] for (int k = 0; k < Ldiag.length(); ++k) { - Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k]; + Ldiag[k] = (Ldiag[k] == 0.0) ? ACC_TYPE(0.0) : (ACC_TYPE(1.0) / Ldiag[k]); } O = Ldiag*O; diff --git a/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp b/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp index 06e83822..4eaddd31 100644 --- a/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +++ b/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp @@ -91,7 +91,7 @@ void main() { L = L*ms + vs; } - L = 1.0 / L; + L = (L == 0.0) ? 0.0 : 1.0 / L; // D dimension is split across workgroups in the y dimension uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE;