git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
vulkan: restore -inf check in FA shaders (llama/19582)
author: Jeff Bolz <redacted>
Fri, 13 Feb 2026 19:35:29 +0000 (11:35 -0800)
committer: Georgi Gerganov <redacted>
Sat, 14 Feb 2026 22:20:18 +0000 (00:20 +0200)
src/ggml-vulkan/ggml-vulkan.cpp
src/ggml-vulkan/vulkan-shaders/flash_attn.comp
src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

index e5dcd3cbda276cb02bc729f524fb04c5642da59d..82933ae03301cb397b20fa4e159d09bcc2fea7c9 100644 (file)
@@ -8422,6 +8422,8 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
     const uint32_t acctype = f32acc ? 4 : 2;
     const uint32_t f16vec4 = 8;
 
+    const uint32_t tmpsh = (Bc / MatBc) * sizeof(float);
+
     const uint32_t qstride = hsk_pad / 4 + 2;
     const uint32_t Qf = Br * qstride * f16vec4;
 
@@ -8438,7 +8440,7 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
 
     const uint32_t slope = Br * acctype;
 
-    const uint32_t total_size = Qf + Psh + sfsh + ksh + slope;
+    const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + slope;
     const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
 
     VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", kv_type=" << kv_type << ", total_size=" << total_size << ", supported=" << supported);
index 914f131c96518adc01aa5b096bbd6ef23281c52b..0735f678549aaa1f1eeb54e9822055bc3195071a 100644 (file)
@@ -130,6 +130,7 @@ void main() {
         if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
             bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
 
+            float max_mask = NEG_FLT_MAX_OVER_2;
             [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
                 uint32_t c = (idx + tid) % Bc;
                 uint32_t r = (idx + tid) / Bc;
@@ -137,12 +138,25 @@ void main() {
                     if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
                         float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
                         masksh[c][r] = m;
+                        max_mask = max(max_mask, m);
                     } else {
                         masksh[c][r] = float(0);
                     }
                 }
             }
+            // skip the block if the mask is entirely -inf
+            bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
             barrier();
+            if (gl_SubgroupInvocationID == 0) {
+                tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
+            }
+            barrier();
+            [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
+                max_mask = max(max_mask, tmpsh[s]);
+            }
+            if (max_mask <= NEG_FLT_MAX_OVER_2) {
+                continue;
+            }
         }
 
         float Sf[Br][cols_per_thread];
@@ -260,6 +274,9 @@ void main() {
         barrier();
     }
 
+    // prevent race on tmpsh
+    barrier();
+
     // reduce across threads
 
     [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
index b31777382341a5c5c79cb04f67faf60be899f9c5..19630972dafb56f47d76d8084f131ed4deb42611 100644 (file)
@@ -42,6 +42,8 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
     return elem;
 }
 
+shared float tmpsh[row_split];
+
 const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4
 shared f16vec4 Qf[Br * qstride];
 
@@ -213,6 +215,19 @@ void main() {
                         }
                     }
                 }
+                // skip the block if the mask is entirely -inf
+                bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
+                barrier();
+                if (gl_SubgroupInvocationID == 0) {
+                    tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
+                }
+                barrier();
+                [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
+                    max_mask = max(max_mask, tmpsh[s]);
+                }
+                if (max_mask <= NEG_FLT_MAX_OVER_2) {
+                    continue;
+                }
             }
         }
 
index 39f0c4d23b9ffa3b342fdff4f26cf414fa70374e..853f17fa16ee136257f44863fd0ee7d6bbb6d75d 100644 (file)
@@ -176,7 +176,14 @@ void main() {
                     tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
                     tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t
 
+                    coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
+
                     coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
+                    // skip the block if the mask is entirely -inf
+                    coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
+                    if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
+                        continue;
+                    }
                 } else {
                     tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
                     // Don't clamp against nem1 when GQA is enabled
@@ -184,7 +191,14 @@ void main() {
                     tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
                     tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
 
+                    coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
+
                     coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
+                    // skip the block if the mask is entirely -inf
+                    coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
+                    if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
+                        continue;
+                    }
                 }
             }
         }