]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan: fix OOB check in flash_attn_mask_opt (#20296)
authorJeff Bolz <redacted>
Thu, 12 Mar 2026 05:35:49 +0000 (00:35 -0500)
committerGitHub <redacted>
Thu, 12 Mar 2026 05:35:49 +0000 (06:35 +0100)
ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp

index 8807c3e2b6efc36a3126b3fa3c1e42245bc5955b..6574955cf103dfa987c7cccb33380b7236efc8a7 100644 (file)
@@ -8840,7 +8840,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     }
 
     // Only use mask opt when the mask is fairly large. This hasn't been tuned extensively.
-    bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768;
+    bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768 && nem0 >= tuning_params.block_cols * 16;
     vk_fa_pipeline_state fa_pipeline_state = get_fa_pipeline_state(ctx->device, tuning_params, HSK, HSV, aligned, f32acc,
                                                                    mask != nullptr, use_mask_opt, logit_softcap != 0);
 
index 8c92c1adcda8ff8cb9eeca6238cd4ff4909ec249..0e417708062b75fd6d6a00e3cd749b763cbc69ec 100644 (file)
@@ -33,6 +33,61 @@ layout (push_constant) uniform parameter {
 shared float minsh[NUM_SUBGROUPS];
 shared float maxsh[NUM_SUBGROUPS];
 
+float FLT_MAX_OVER_2 = uintBitsToFloat(0x7EFFFFFF);
+
+void loadvec4(inout uint result, const uint i0, const uint i1, const uint i2, const uint i3, const bool need_bounds_check) {
+    const uint tid = gl_LocalInvocationIndex;
+
+    [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) {
+        float min_v = FLT_MAX_OVER_2;
+        float max_v = -FLT_MAX_OVER_2;
+        [[unroll]] for (uint i = 0; i < Br * Bc / 4; i += BLOCK_SIZE) {
+            uint j0 = (i + tid) % (Bc / 4);
+            uint j1 = (i + tid) / (Bc / 4);
+
+            j0 *= 4;
+            j0 += (i0 * 16 + block_x) * Bc;
+            j1 += i1 * Br;
+
+            if (!need_bounds_check || j0 + 3 < nem0) {
+                vec4 f = vec4(data_av4[(j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3) / 4]);
+                [[unroll]] for (int c = 0; c < 4; ++c) {
+                    min_v = min(min_v, f[c]);
+                    max_v = max(max_v, f[c]);
+                }
+            } else {
+                [[unroll]] for (int c = 0; c < 4; ++c) {
+                    if (j0 + c < nem0) {
+                        float f = float(data_a[j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3]);
+                        min_v = min(min_v, f);
+                        max_v = max(max_v, f);
+                    }
+                }
+            }
+        }
+        min_v = subgroupMin(min_v);
+        max_v = subgroupMax(max_v);
+        if (gl_SubgroupInvocationID == 0) {
+            minsh[gl_SubgroupID] = min_v;
+            maxsh[gl_SubgroupID] = max_v;
+        }
+        barrier();
+        if (tid == 0) {
+            [[unroll]] for (uint i = 0; i < NUM_SUBGROUPS; ++i) {
+                min_v = min(min_v, minsh[i]);
+                max_v = max(max_v, maxsh[i]);
+            }
+            if (max_v <= -FLT_MAX_OVER_2) {
+                result |= 1 << (2*block_x);
+            }
+            if (min_v == 0.0f && max_v == 0.0f) {
+                result |= 2 << (2*block_x);
+            }
+        }
+        barrier();
+    }
+}
+
 // For each Br x Bc block of the mask (input) buffer, read all values and check
 // if it's all -inf or all zero. Write out a two-bit code indicating which it is
 // (or zero for neither). Each workgroup processes 16 tiles and writes out a
@@ -48,50 +103,15 @@ void main() {
     const uint i2 = gl_WorkGroupID.z % nem2;
     const uint i3 = gl_WorkGroupID.z / nem2;
 
-    float FLT_MAX_OVER_2 = uintBitsToFloat(0x7EFFFFFF);
-
     uint result = 0;
 
     // Fast path for fully in-bounds blocks where we can do f16vec4 loads
     if ((nem0 % Bc) == 0 && (nem1 % Br) == 0 &&
         ((Br * Bc) % (BLOCK_SIZE * 4)) == 0) {
-        [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) {
-            float min_v = FLT_MAX_OVER_2;
-            float max_v = -FLT_MAX_OVER_2;
-            [[unroll]] for (uint i = 0; i < Br * Bc / 4; i += BLOCK_SIZE) {
-                uint j0 = (i + tid) % (Bc / 4);
-                uint j1 = (i + tid) / (Bc / 4);
-
-                j0 *= 4;
-                j0 += (i0 * 16 + block_x) * Bc;
-                j1 += i1 * Br;
-
-                vec4 f = vec4(data_av4[(j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3) / 4]);
-                [[unroll]] for (int c = 0; c < 4; ++c) {
-                    min_v = min(min_v, f[c]);
-                    max_v = max(max_v, f[c]);
-                }
-            }
-            min_v = subgroupMin(min_v);
-            max_v = subgroupMax(max_v);
-            if (gl_SubgroupInvocationID == 0) {
-                minsh[gl_SubgroupID] = min_v;
-                maxsh[gl_SubgroupID] = max_v;
-            }
-            barrier();
-            if (tid == 0) {
-                [[unroll]] for (uint i = 0; i < NUM_SUBGROUPS; ++i) {
-                    min_v = min(min_v, minsh[i]);
-                    max_v = max(max_v, maxsh[i]);
-                }
-                if (max_v <= -FLT_MAX_OVER_2) {
-                    result |= 1 << (2*block_x);
-                }
-                if (min_v == 0.0f && max_v == 0.0f) {
-                    result |= 2 << (2*block_x);
-                }
-            }
-            barrier();
+        if ((i0 + 1) * 16 * Bc <= nem0) {
+            loadvec4(result, i0, i1, i2, i3, false);
+        } else {
+            loadvec4(result, i0, i1, i2, i3, true);
         }
     } else {
         [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) {