]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan: use aligned loads for flash attention mask (#12853)
authorJeff Bolz <redacted>
Sat, 12 Apr 2025 08:44:48 +0000 (03:44 -0500)
committerGitHub <redacted>
Sat, 12 Apr 2025 08:44:48 +0000 (10:44 +0200)
Rewrite the stride logic for the mask tensor in the FA shader to force the
stride to be aligned, to allow using more efficient loads.

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

index a8f4bc41726c2fad026b836f9fab11115d97970c..e1baa85f9e33050b86e7282178cf5b8e878d7ff7 100644 (file)
@@ -201,6 +201,11 @@ void main() {
     uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
     uint32_t k_stride = p.nb11;
     uint32_t v_stride = p.nb21;
+    // When using grouped query attention, all rows use the same mask (stride 0).
+    // "p.gqa_ratio >> 16" is just a roundabout way of writing zero
+    // that prevents the compiler from folding the "&" through the select
+    // and breaking the alignment detection.
+    uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
     // hint to the compiler that strides are aligned for the aligned variant of the shader
     if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
     {
@@ -209,6 +214,7 @@ void main() {
         k_stride &= ~7;
         v_stride &= ~7;
 #endif
+        m_stride &= ~7;
     }
     tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
     tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
@@ -261,10 +267,7 @@ void main() {
         if (p.mask != 0) {
             tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
             tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
-            // When using grouped query attention, all rows use the same mask.
-            if (p.gqa_ratio > 1) {
-                tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, 0, 1);
-            }
+            tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
 
             coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;