]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
vulkan: Implement grouped query attention in the coopmat2 FA shader (llama/12559)
authorJeff Bolz <redacted>
Wed, 2 Apr 2025 17:40:32 +0000 (12:40 -0500)
committerGeorgi Gerganov <redacted>
Thu, 24 Apr 2025 17:39:16 +0000 (20:39 +0300)
When adjacent batches of Q share the same batches of K/V, batch them into
the same workgroup. For example, when:

dst(128,32,1,1) = FA(q(128,1,32,1), k(128,16640,8,1), v(128,16640,8,1))

previously we would run 32 workgroups computing 1 result each, now we will
run 8 workgroups computing 4 results each.

This doesn't directly translate to better performance (at least when you have
>=32 SMs), but in a subsequent change I'll enable split_k which will scale much
better with 4x fewer workgroups.

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

index ee0969fe189b4879f3a8351df44989fbde2e28f3..f60fe33aae18c94e005ee6c9fc8e8e3eb48df611 100644 (file)
@@ -31,6 +31,7 @@
 
 #define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
 #define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
+static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
 
 #define VK_VENDOR_ID_AMD 0x1002
 #define VK_VENDOR_ID_APPLE 0x106b
@@ -501,6 +502,8 @@ struct vk_flash_attn_push_constants {
     uint32_t n_head_log2;
     float m0;
     float m1;
+
+    uint32_t gqa_ratio;
 };
 
 struct vk_op_push_constants {
@@ -5402,7 +5405,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     const uint32_t nbm1 = mask ? mask->nb[1] : 0;
 
     const uint32_t D = neq0;
-    const uint32_t N = neq1;
+    uint32_t N = neq1;
     const uint32_t KV = nek1;
 
     GGML_ASSERT(ne0 == D);
@@ -5460,6 +5463,22 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     vk_pipeline pipeline = pipelines[aligned];
     assert(pipeline);
 
+    uint32_t gqa_ratio = 1;
+    uint32_t qk_ratio = neq2 / nek2;
+    uint32_t workgroups_x = (uint32_t)neq1;
+    uint32_t workgroups_y = (uint32_t)neq2;
+    uint32_t workgroups_z = (uint32_t)neq3;
+
+    if (N == 1 && qk_ratio > 1 && is_pow2(qk_ratio) && gqa_ratio <= flash_attention_num_small_rows &&
+        qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
+        // grouped query attention - make the N dimension equal to gqa_ratio, reduce
+        // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
+        // and change addressing calculations to index Q's dimension 2.
+        gqa_ratio = qk_ratio;
+        N = gqa_ratio;
+        workgroups_y /= N;
+    }
+
     if (dryrun) {
         // Request descriptor sets
         ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
@@ -5549,7 +5568,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
                                               v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
                                               nbm1,
                                               scale, max_bias, logit_softcap,
-                                              mask != nullptr, n_head_log2, m0, m1 };
+                                              mask != nullptr, n_head_log2, m0, m1, gqa_ratio };
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
                                 {
                                     vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
@@ -5558,7 +5577,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
                                     vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
                                     vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
                                 },
-                                sizeof(vk_flash_attn_push_constants), &pc, { (uint32_t)neq1, (uint32_t)neq2, (uint32_t)neq3 });
+                                sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x, workgroups_y, workgroups_z });
 }
 
 static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
index df30355f635b8e8a1d16f7814b9019ab62c9ff31..cac8f107b5d749fad550b52dc39bbaee773bec78 100644 (file)
@@ -61,6 +61,8 @@ layout (push_constant) uniform parameter {
     uint32_t n_head_log2;
     float m0;
     float m1;
+
+    uint32_t gqa_ratio;
 } p;
 
 layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
@@ -103,6 +105,28 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
 #define DECODEFUNC
 #endif
 
+// Store the output when doing grouped query attention.
+// Rows index by Q's dimension 2, and the first N rows are valid.
+D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
+{
+    if (r < N && c < D) {
+        uint32_t offset = (iq2 + r) * D + c;
+        data_o[o_offset + offset] = D_TYPE(elem);
+    }
+    return elem;
+}
+
+// Load the slope matrix, indexed by Q's dimension 2.
+ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
+{
+    const uint32_t h = iq2 + (r & (p.gqa_ratio - 1));
+
+    const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
+    const int      exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
+
+    return ACC_TYPE(pow(base, ACC_TYPE(exph)));
+}
+
 void main() {
 #ifdef NEEDS_INIT_IQ_SHMEM
     init_iq_shmem(gl_WorkGroupSize);
@@ -116,7 +140,9 @@ void main() {
 
     const uint32_t i = gl_WorkGroupID.x;
 
-    const uint32_t iq2 = gl_WorkGroupID.y;
+    // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
+    // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
+    const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
     const uint32_t iq3 = gl_WorkGroupID.z;
 
     // broadcast factors
@@ -149,8 +175,10 @@ void main() {
     tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
     tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
 
-    // nb?1 are already divided by the type size and are in units of elements
-    uint32_t q_stride = p.nb01;
+    // nb?1 are already divided by the type size and are in units of elements.
+    // When using grouped query attention, Q is indexed by iq2, so the stride
+    // should be nb02 (which is in bytes).
+    uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
     uint32_t k_stride = p.nb11;
     uint32_t v_stride = p.nb21;
     // hint to the compiler that strides are aligned for the aligned variant of the shader
@@ -182,16 +210,11 @@ void main() {
     L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
     M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(-1.0/0.0);
 
-    ACC_TYPE slope = ACC_TYPE(1.0);
+    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> slopeMat = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(1.0);
 
     // ALiBi
     if (p.max_bias > 0.0f) {
-        const uint32_t h = iq2;
-
-        const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
-        const int      exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
-
-        slope = pow(base, ACC_TYPE(exph));
+        coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
     }
 
     [[dont_unroll]]
@@ -215,12 +238,16 @@ void main() {
         if (p.mask != 0) {
             tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
             tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
+            // When using grouped query attention, all rows use the same mask.
+            if (p.gqa_ratio > 1) {
+                tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, 0, 1);
+            }
 
             coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
 
             coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
 
-            S += slope*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
+            S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
         }
 
         // Clear padding elements to -inf, so they don't contribute to rowmax
@@ -297,13 +324,18 @@ void main() {
 
     O = Ldiag*O;
 
-    tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
-    tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D);
-
-    // permute dimensions
-    tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
     uint32_t o_offset = iq3*p.ne2*p.ne1;
 
     coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
-    coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, 1, 0, D), tensorViewPermute);
+    if (p.gqa_ratio > 1) {
+        coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
+    } else {
+        tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
+        tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D);
+
+        // permute dimensions
+        tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
+
+        coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, D), tensorViewPermute);
+    }
 }