]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: fix padding of GQA to power of 2 in FA (llama/19115)
authorJohannes Gäßler <redacted>
Mon, 26 Jan 2026 22:24:58 +0000 (23:24 +0100)
committerGeorgi Gerganov <redacted>
Fri, 30 Jan 2026 11:49:29 +0000 (13:49 +0200)
src/ggml-cuda/fattn-common.cuh
src/ggml-cuda/fattn-mma-f16.cuh
tests/test-backend-ops.cpp

index 1f5f1b9206cf6fb748c5e2eedda83080eefdf5dc..3d7daccfdf873b2ef3b4aac14bb742d0460f379a 100644 (file)
@@ -629,8 +629,8 @@ static __global__ void flash_attn_mask_to_KV_max(
 template<int D, int ncols1, int ncols2> // D == head size
 __launch_bounds__(D, 1)
 static __global__ void flash_attn_stream_k_fixup(
-        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11,
-        const int nbatch_fa) {
+        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03,
+        const int ne11, const int ne12, const int nbatch_fa) {
     constexpr int ncols = ncols1*ncols2;
 
     const int bidx0 = blockIdx.x;
@@ -641,12 +641,14 @@ static __global__ void flash_attn_stream_k_fixup(
 
     const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
 
-    const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
-    const int iter_j = (ne01 + (ncols1    - 1)) / ncols1;
-    const int iter_z = (ne02 + (ncols2    - 1)) / ncols2;
+    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
 
-    const int kbc0      = int64_t(bidx0 + 0)*(iter_k*iter_j*iter_z*ne03) / gridDim.x;
-    const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*iter_z*ne03) / gridDim.x;
+    const int iter_k     = (ne11      + (nbatch_fa - 1)) / nbatch_fa;
+    const int iter_j     = (ne01      + (ncols1    - 1)) / ncols1;
+    const int iter_z_gqa = (gqa_ratio + (ncols2    - 1)) / ncols2;
+
+    const int kbc0      = int64_t(bidx0 + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
+    const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
 
     const bool did_not_have_any_data   = kbc0 == kbc0_stop;
     const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
@@ -655,15 +657,19 @@ static __global__ void flash_attn_stream_k_fixup(
         return;
     }
 
-    const int sequence = kbc0 / (iter_k*iter_j*iter_z);
-    const int zt = (kbc0 - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j);
-    const int jt = (kbc0 - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
+    // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
+    const int sequence =  kbc0 /(iter_k*iter_j*iter_z_gqa*ne12);
+    const int z_KV     = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
+    const int zt_gqa   = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
+    const int jt       = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
+
+    const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
 
-    if (jt*ncols1 + j >= ne01 || zt*ncols2 + c >= ne02) {
+    if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) {
         return;
     }
 
-    dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt*(ncols2*D) + (j*ne02 + c)*D + tid;
+    dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid;
 
     // Load the partial result that needs a fixup:
     float dst_val = 0.0f;
@@ -682,7 +688,7 @@ static __global__ void flash_attn_stream_k_fixup(
     int bidx = bidx0 - 1;
     int kbc_stop = kbc0;
     while(true) {
-        const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z*ne03) / gridDim.x;
+        const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
         if (kbc == kbc_stop) { // Did not have any data.
             bidx--;
             kbc_stop = kbc;
@@ -883,9 +889,10 @@ void launch_fattn(
         }
     }
 
-    const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
-    const int ntiles_z =  ((Q->ne[2] + ncols2 - 1) / ncols2);
-    const int ntiles_total = ntiles_x * ntiles_z * Q->ne[3];
+    const int ntiles_x     = ((Q->ne[1] + ncols1 - 1) / ncols1);
+    const int gqa_ratio    = Q->ne[2] / K->ne[2];
+    const int ntiles_z_gqa = ((gqa_ratio + ncols2 - 1) / ncols2);
+    const int ntiles_total = ntiles_x * ntiles_z_gqa * K->ne[2] * Q->ne[3];
 
     // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
     // Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
@@ -960,7 +967,7 @@ void launch_fattn(
 
         blocks_num.x = ntiles_x;
         blocks_num.y = parallel_blocks;
-        blocks_num.z = ntiles_z*Q->ne[3];
+        blocks_num.z = ntiles_z_gqa*K->ne[2]*Q->ne[3];
 
         if (parallel_blocks > 1) {
             dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
@@ -1014,7 +1021,7 @@ void launch_fattn(
 
             flash_attn_stream_k_fixup<DV, ncols1, ncols2>
                 <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
-                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], nbatch_fa);
+                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], K->ne[2], nbatch_fa);
         }
     } else if (parallel_blocks > 1) {
         const dim3 block_dim_combine(DV, 1, 1);
index 9004d46904e9a493de6b4659f74a21d89d001335..0b8ef90794c7cc63db0b8300e7bab17d16814602 100644 (file)
@@ -933,6 +933,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const float logit_softcap,
         const uint3 ne01,
         const int ne02,
+        const int gqa_ratio,
         const int ne11,
         const int stride_Q1,
         const int stride_Q2,
@@ -940,7 +941,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const int stride_V,
         const int stride_mask,
         const int jt,
-        const int zt,
+        const int zt_gqa,
         const int kb0_start,
         const int kb0_stop) {
 #if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
@@ -1023,7 +1024,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             const int j = jc / ncols2;
             const int c = jc % ncols2;
 
-            if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt*ncols2 + c < ne02)) {
+            if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt_gqa*ncols2 + c < gqa_ratio)) {
 #pragma unroll
                 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
                     const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
@@ -1409,7 +1410,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
                     const int j_dst = jc_dst / ncols2;
                     const int c_dst = jc_dst % ncols2;
 
-                    if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= int(ne01.z)) || (ncols2 > 1 && zt*ncols2 + c_dst >= ne02))) {
+                    if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= int(ne01.z)) || (ncols2 > 1 && zt_gqa*ncols2 + c_dst >= gqa_ratio))) {
                         continue;
                     }
 
@@ -1448,7 +1449,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     }
 #else
     GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dstk_fixup,
-        scale, slope, logit_softcap, ne01, ne02,
+        scale, slope, logit_softcap, ne01, ne02, gqa_ratio,
         stride_Q1, stride_Q2, stride_K, stride_V, stride_mask,
         jt, kb0_start, kb0_stop);
     NO_DEVICE_CODE;
@@ -1521,13 +1522,13 @@ static __global__ void flash_attn_ext_f16(
 
     const int stride_V = V_is_K_view ? stride_K : nb21 / sizeof(half2);
 
-    const int iter_k = (ne11   + (nbatch_fa - 1)) / nbatch_fa;
-    const int iter_j = (ne01.z + (ncols1    - 1)) / ncols1;
-    const int iter_z = (ne02   + (ncols2    - 1)) / ncols2;
+    const int iter_k     = (ne11      + (nbatch_fa - 1)) / nbatch_fa;
+    const int iter_j     = (ne01.z    + (ncols1    - 1)) / ncols1;
+    const int iter_z_gqa = (gqa_ratio + (ncols2    - 1)) / ncols2;
 
     // kbc == k block continuous, current index in continuous ijk space.
-    int       kbc      = int64_t(blockIdx.x + 0)*(iter_k*iter_j*iter_z*ne03) / gridDim.x;
-    const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*iter_z*ne03) / gridDim.x;
+    int       kbc      = int64_t(blockIdx.x + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
+    const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
 
     // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
     // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
@@ -1538,22 +1539,24 @@ static __global__ void flash_attn_ext_f16(
     int kb0_stop  = min(iter_k, kb0_start + kbc_stop - kbc);
 
     while (kbc < kbc_stop && kb0_stop == iter_k) {
-        const int sequence = kbc / (iter_k*iter_j*iter_z);
-        const int zt = (kbc - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j); // head in units of ncols2
-        const int jt = (kbc - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
+        // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
+        const int sequence =  kbc /(iter_k*iter_j*iter_z_gqa*ne12);
+        const int z_KV     = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
+        const int zt_gqa   = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
+        const int jt       = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
 
-        const int head0 = zt * ncols2;
+        const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
 
-        const float2 * Q_f2   = (const float2 *) (Q + nb03*sequence + nb02* head0);
-        const half2  * K_h2   = (const half2  *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
+        const float2 * Q_f2   = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
+        const half2  * K_h2   = (const half2  *) (K + nb13*sequence + nb12*z_KV);
         const half   * mask_h = ncols2 == 1 && !mask ? nullptr :
             (const half *) (mask + nb33*(sequence % ne33));
-        float2       * dstk   = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
+        float2       * dstk   = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2);
 
-        const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
-        const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
+        const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV);
+        const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr;
 
-        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
+        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f;
 
         if (KV_max) {
             kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
@@ -1563,12 +1566,12 @@ static __global__ void flash_attn_ext_f16(
             constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
             flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
                 (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
-                 ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start, kb0_stop);
+                 ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
         } else {
             constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile.
             flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
                 (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
-                 ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start, kb0_stop);
+                 ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
         }
 
         kbc += iter_k;
@@ -1582,22 +1585,24 @@ static __global__ void flash_attn_ext_f16(
         return;
     }
 
-    const int sequence = kbc / (iter_k*iter_j*iter_z);
-    const int zt = (kbc - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j); // head in units of ncols2
-    const int jt = (kbc - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
+    // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index.
+    const int sequence =  kbc /(iter_k*iter_j*iter_z_gqa*ne12);
+    const int z_KV     = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
+    const int zt_gqa   = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
+    const int jt       = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
 
-    const int head0 = zt * ncols2;
+    const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
 
-    const float2 * Q_f2   = (const float2 *) (Q + nb03*sequence + nb02* head0);
-    const half2  * K_h2   = (const half2  *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
+    const float2 * Q_f2   = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
+    const half2  * K_h2   = (const half2  *) (K + nb13*sequence + nb12*z_KV);
     const half   * mask_h = ncols2 == 1 && !mask ? nullptr :
         (const half *) (mask + nb33*(sequence % ne33));
-    float2       * dstk   = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
+    float2       * dstk   = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2);
 
-    const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
-    const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
+    const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV);
+    const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr;
 
-    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
+    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f;
 
     if (KV_max) {
         kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
@@ -1607,7 +1612,7 @@ static __global__ void flash_attn_ext_f16(
     constexpr bool needs_fixup = false;
     flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
         (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
-         ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start, kb0_stop);
+         ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
 #else
     GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
         max_bias, m0, m1, n_head_log2, logit_softcap,
index 146d05f53bc76d3afd6a8ade8811f4257fde0965..d4c1f525c674c5cead70cb6b0983da0672696753 100644 (file)
@@ -8216,8 +8216,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                             for (int nh : { 4, }) {
                                 for (int nr3 : { 1, 3, }) {
                                     if (hsk > 64 && nr3 > 1) continue; // skip broadcast for large head sizes
-                                    for (int nr2 : { 1, 4, 16 }) {
-                                        if (nr2 == 16 && hsk != 128) continue;
+                                    for (int nr2 : { 1, 4, 12 }) {
+                                        if (nr2 == 12 && hsk != 128) continue;
                                         //for (int kv : { 1, 17, 31, 33, 61, 113, 65, 127, 129, 130, 255, 260, 371, 380, 407, 512, 1024, }) {
                                         for (int kv : { 113, 512, 1024, }) {
                                             if (nr2 != 1 && kv != 512) continue;