]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
vulkan: Use unclamped loads for flash attention mask (llama/12720)
authorJeff Bolz <redacted>
Sun, 6 Apr 2025 08:47:13 +0000 (03:47 -0500)
committerGeorgi Gerganov <redacted>
Thu, 24 Apr 2025 17:39:16 +0000 (20:39 +0300)
nem1 must be a multiple of GGML_KQ_MASK_PAD, and GGML_KQ_MASK_PAD is a multiple
of the number of rows in the matrix. The KV dim is a multiple of the number of
columns for the aligned shader.

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

index f3c24e503da2cc162b75e43c1d961b477848f60b..705a6135a658418d01f3b83378f6b5406fed1c88 100644 (file)
@@ -1833,6 +1833,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
             // can't use 256 for D==80.
             uint32_t wg_size = (small_rows && (D % 32) == 0) ? 256 : 128;
             auto rows_cols = fa_rows_cols(D, clamp, type, small_rows);
+            // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
+            GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
             return {wg_size, rows_cols[0], rows_cols[1], (D), clamp};
         };
 
@@ -5511,6 +5513,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
                    // the "aligned" shader variant will forcibly align strides, for performance
                    (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
 
+    // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
+    GGML_ASSERT((nem1 % GGML_KQ_MASK_PAD) == 0);
+
     vk_pipeline pipeline = pipelines[aligned];
     assert(pipeline);
 
index d78092000d83981e151faa0952e754789053e515..eedbc6f8b0e9c4e498b2a13f1fb16c445efb80c3 100644 (file)
@@ -256,7 +256,7 @@ void main() {
         }
 
         if (p.mask != 0) {
-            tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
+            tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
             tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
             // When using grouped query attention, all rows use the same mask.
             if (p.gqa_ratio > 1) {