]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
vulkan: in flash attention, bounds check against nem1 (don't rely on GGML_KQ_MASK_PAD...
authorJeff Bolz <redacted>
Fri, 3 Oct 2025 08:33:08 +0000 (03:33 -0500)
committerGeorgi Gerganov <redacted>
Sun, 12 Oct 2025 04:57:25 +0000 (07:57 +0300)
src/ggml-vulkan/ggml-vulkan.cpp
src/ggml-vulkan/vulkan-shaders/flash_attn.comp
src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

index 003a9010674dd5f0f0130dab215215113d1f0c46..def8dc96d2f9a9def7a5aca41b32fbd0ecb3a685 100644 (file)
@@ -2614,8 +2614,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
         const uint32_t D_lsb = D ^ (D & (D-1));
         uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);
 
-        // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
-        GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
         return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split};
     };
 
@@ -7457,8 +7455,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     if (((HSK | HSV) % 16) != 0 && path == FA_COOPMAT2) {
         aligned = false;
     }
-    // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
-    GGML_ASSERT((nem1 % GGML_KQ_MASK_PAD) == 0);
 
     bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
 
index 43b906e5ed96ddfadf3823c3c6454eedf24ef854..e42475026a92ca115c9b040e107c1ba27595d22d 100644 (file)
@@ -153,12 +153,13 @@ void main() {
         }
 
         if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
+            bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
 
             [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
                 uint32_t c = (idx + tid) % Bc;
                 uint32_t r = (idx + tid) / Bc;
                 if (idx + tid < Bc * Br) {
-                    if (!KV_bounds_check || j * Bc + c < KV) {
+                    if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
                         masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
                     } else {
                         masksh[c][r] = float(0);
index ddb1246e0ba7c77ab8047a15ee07ad769eb8d28d..e76dbb4deca369eca80f3a122bcc98950f31b2bd 100644 (file)
@@ -201,11 +201,13 @@ void main() {
         }
 
         if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
+            bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
+
             [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
                 uint32_t c = (idx + tid) % Bc;
                 uint32_t r = (idx + tid) / Bc;
                 if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
-                    if (!KV_bounds_check || j * Bc + c < KV) {
+                    if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
                         sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
                     }
                 }
index ab647e9bc8b688407db6d8d4a1e5027658c381f7..a65553a481a2ca99bb93203fcaaf23ba666179a7 100644 (file)
@@ -154,15 +154,31 @@ void main() {
         }
 
         if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
-            tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
-            tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
-            tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
+            bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
 
-            coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
+            if (nem1_bounds_check) {
+                tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
+                tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
+                tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
 
-            coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
+                coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
 
-            S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
+                coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
+
+                S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
+            } else {
+                tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
+                // Don't clamp against nem1 when GQA is enabled
+                uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
+                tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
+                tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
+
+                coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
+
+                coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
+
+                S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
+            }
         }
 
         // Clear padding elements to -inf, so they don't contribute to rowmax