]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
[CUDA ] Write an optimized flash_attn_stream_k_fixup kernel (#21159)
authorGaurav Garg <redacted>
Mon, 6 Apr 2026 18:34:29 +0000 (00:04 +0530)
committerGitHub <redacted>
Mon, 6 Apr 2026 18:34:29 +0000 (20:34 +0200)
* Write an optimized flash_attn_stream_k_fixup kernel

Write a specialized and more optimized kernel for cases where nblocks_stream_k is multiple of ntiles_dst.
Make nblocks_stream_k to multiple of ntiles_dst if nblocks_stream_k > 2 * ntiles_dst

* Use the new kernel only for nblocks_stream_k_raw > 4 * ntiles_dst to make sure we have enough concurrency on GPUs

* Address review comments

* Address review comments

* Revert variable names to original

ggml/src/ggml-cuda/fattn-common.cuh

index c59a4db39993b7c1ab7abf812c2a23136fa1b134..beeb5238946445fe3bd46efe0a478774e4bce9a5 100644 (file)
@@ -676,9 +676,96 @@ static __global__ void flash_attn_mask_to_KV_max(
 
 template<int D, int ncols1, int ncols2> // D == head size
 __launch_bounds__(D, 1)
-static __global__ void flash_attn_stream_k_fixup(
-        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03,
-        const int ne11, const int ne12, const int nbatch_fa) {
+static __global__ void flash_attn_stream_k_fixup_uniform(
+        float * __restrict__ dst,
+        const float2 * __restrict__ dst_fixup,
+        const int ne01, const int ne02,
+        const int ne12, const int nblocks_stream_k,
+        const int gqa_ratio,
+        const int blocks_per_tile,
+        const uint3 fd_iter_j_z_ne12,
+        const uint3 fd_iter_j_z,
+        const uint3 fd_iter_j) {
+    constexpr int ncols = ncols1*ncols2;
+
+    const int tile_idx = blockIdx.x; // One block per output tile.
+    const int j        = blockIdx.y;
+    const int c        = blockIdx.z;
+    const int jc       = j*ncols2 + c;
+    const int tid      = threadIdx.x;
+
+    // nblocks_stream_k is a multiple of ntiles_dst (== gridDim.x), so each tile gets the same number of blocks.
+    const int b_first = tile_idx * blocks_per_tile;
+    const int b_last  = b_first + blocks_per_tile - 1;
+
+    const float * dst_fixup_data = ((const float *) dst_fixup) + nblocks_stream_k*(2*2*ncols);
+
+    // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
+    const uint2 dm0 = fast_div_modulo(tile_idx, fd_iter_j_z_ne12);
+    const uint2 dm1 = fast_div_modulo(dm0.y,    fd_iter_j_z);
+    const uint2 dm2 = fast_div_modulo(dm1.y,    fd_iter_j);
+
+    const int sequence = dm0.x;
+    const int z_KV     = dm1.x;
+    const int zt_gqa   = dm2.x;
+    const int jt       = dm2.y;
+
+    const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
+
+    if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) {
+        return;
+    }
+
+    dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid;
+
+    // Load the partial result that needs a fixup
+    float dst_val = *dst;
+    float max_val;
+    float rowsum;
+    {
+        const float2 tmp = dst_fixup[b_last*ncols + jc];
+        max_val = tmp.x;
+        rowsum  = tmp.y;
+    }
+
+    // Combine with all previous blocks in this tile.
+    for (int bidx = b_last - 1; bidx >= b_first; --bidx) {
+        const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid];
+
+        const float2 tmp = dst_fixup[(nblocks_stream_k + bidx)*ncols + jc];
+
+        const float max_val_new = fmaxf(max_val, tmp.x);
+
+        const float diff_val = max_val - max_val_new;
+        const float diff_add = tmp.x   - max_val_new;
+
+        const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
+        const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
+
+        dst_val = scale_val*dst_val + scale_add*dst_add;
+        rowsum  = scale_val*rowsum  + scale_add*tmp.y;
+
+        max_val = max_val_new;
+    }
+
+    // Write back final result:
+    *dst = dst_val / rowsum;
+}
+
+// General fixup kernel for the case where the number of blocks per tile is not uniform across tiles
+// (blocks_num.x not a multiple of ntiles_dst)
+template <int D, int ncols1, int ncols2> // D == head size
+__launch_bounds__(D, 1)
+static __global__ void flash_attn_stream_k_fixup_general(
+        float * __restrict__ dst,
+        const float2 * __restrict__ dst_fixup,
+        const int ne01, const int ne02,
+        const int gqa_ratio,
+        const int total_work,
+        const uint3 fd_iter_k_j_z_ne12,
+        const uint3 fd_iter_k_j_z,
+        const uint3 fd_iter_k_j,
+        const uint3 fd_iter_k) {
     constexpr int ncols = ncols1*ncols2;
 
     const int bidx0 = blockIdx.x;
@@ -689,27 +776,26 @@ static __global__ void flash_attn_stream_k_fixup(
 
     const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
 
-    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-
-    const int iter_k     = (ne11      + (nbatch_fa - 1)) / nbatch_fa;
-    const int iter_j     = (ne01      + (ncols1    - 1)) / ncols1;
-    const int iter_z_gqa = (gqa_ratio + (ncols2    - 1)) / ncols2;
-
-    const int kbc0      = int64_t(bidx0 + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
-    const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
+    const int kbc0      = int64_t(bidx0 + 0)*total_work / gridDim.x;
+    const int kbc0_stop = int64_t(bidx0 + 1)*total_work / gridDim.x;
 
     const bool did_not_have_any_data   = kbc0 == kbc0_stop;
-    const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
-    const bool did_not_write_last      = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0;
+    const bool wrote_beginning_of_tile = fastmodulo(kbc0, fd_iter_k) == 0;
+    const bool did_not_write_last      = fastdiv(kbc0, fd_iter_k) == fastdiv(kbc0_stop, fd_iter_k) && fastmodulo(kbc0_stop, fd_iter_k) != 0;
     if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
         return;
     }
 
     // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
-    const int sequence =  kbc0 /(iter_k*iter_j*iter_z_gqa*ne12);
-    const int z_KV     = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
-    const int zt_gqa   = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
-    const int jt       = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
+    const uint2 dm0 = fast_div_modulo(kbc0, fd_iter_k_j_z_ne12);
+    const uint2 dm1 = fast_div_modulo(dm0.y, fd_iter_k_j_z);
+    const uint2 dm2 = fast_div_modulo(dm1.y, fd_iter_k_j);
+    const uint2 dm3 = fast_div_modulo(dm2.y, fd_iter_k);
+
+    const int sequence = dm0.x;
+    const int z_KV     = dm1.x;
+    const int zt_gqa   = dm2.x;
+    const int jt       = dm3.x;
 
     const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
 
@@ -733,10 +819,11 @@ static __global__ void flash_attn_stream_k_fixup(
 
     // Iterate over previous blocks and compute the combined results.
     // All CUDA blocks that get here must have a previous block that needs a fixup.
+    const int tile_kbc0 = fastdiv(kbc0, fd_iter_k);
     int bidx = bidx0 - 1;
     int kbc_stop = kbc0;
     while(true) {
-        const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
+        const int kbc = int64_t(bidx)*total_work / gridDim.x;
         if (kbc == kbc_stop) { // Did not have any data.
             bidx--;
             kbc_stop = kbc;
@@ -762,7 +849,7 @@ static __global__ void flash_attn_stream_k_fixup(
         max_val = max_val_new;
 
         // If this block started in a previous tile we are done and don't need to combine additional partial results.
-        if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
+        if (fastmodulo(kbc, fd_iter_k) == 0 || fastdiv(kbc, fd_iter_k) < tile_kbc0) {
             break;
         }
         bidx--;
@@ -976,14 +1063,28 @@ void launch_fattn(
         const int tiles_nwaves = (ntiles_dst + max_blocks - 1) / max_blocks;
         const int tiles_efficiency_percent = 100 * ntiles_dst / (max_blocks*tiles_nwaves);
 
-        const int nblocks_stream_k = std::min(max_blocks, ntiles_KV*ntiles_dst);
-
         const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || amd_wmma_available(cc) || tiles_efficiency_percent < 75;
 
-        blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_dst;
+        blocks_num.x = ntiles_dst;
         blocks_num.y = 1;
         blocks_num.z = 1;
 
+        if(use_stream_k) {
+            const int nblocks_stream_k_raw = std::min(max_blocks, ntiles_KV*ntiles_dst);
+            // Round down to a multiple of ntiles_dst so that each output tile gets the same number of blocks (avoids fixup).
+            // Only do this if the occupancy loss from rounding is acceptable.
+            const int nblocks_stream_k_rounded = (nblocks_stream_k_raw / ntiles_dst) * ntiles_dst;
+            const int max_efficiency_loss_percent = 5;
+            const int efficiency_loss_percent = nblocks_stream_k_rounded > 0
+                ? 100 * (nblocks_stream_k_raw - nblocks_stream_k_rounded) / nblocks_stream_k_raw
+                : 100;
+            const int nblocks_stream_k = efficiency_loss_percent <= max_efficiency_loss_percent
+                ? nblocks_stream_k_rounded
+                : nblocks_stream_k_raw;
+
+            blocks_num.x = nblocks_stream_k;
+        }
+
         if (ntiles_dst % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
             dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2)));
         }
@@ -1063,13 +1164,40 @@ void launch_fattn(
     CUDA_CHECK(cudaGetLastError());
 
     if (stream_k) {
-        if (ntiles_dst % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
+        if ((int)blocks_num.x % ntiles_dst == 0 && (int)blocks_num.x > ntiles_dst) {
+            // Optimized fixup: nblocks_stream_k is a multiple of ntiles_dst, launch one block per tile.
+            const int nblocks_sk  = (int)blocks_num.x;
+            const int bpt         = nblocks_sk / ntiles_dst;
+
+            const uint3 fd0 = init_fastdiv_values(ntiles_x * ntiles_z_gqa * K->ne[2]);
+            const uint3 fd1 = init_fastdiv_values(ntiles_x * ntiles_z_gqa);
+            const uint3 fd2 = init_fastdiv_values(ntiles_x);
+
+            const dim3 block_dim_combine(DV, 1, 1);
+            const dim3 blocks_num_combine = {(unsigned)ntiles_dst, ncols1, ncols2};
+
+            flash_attn_stream_k_fixup_uniform<DV, ncols1, ncols2>
+                <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
+                ((float *) KQV->data, dst_tmp_meta.ptr,
+                 Q->ne[1], Q->ne[2], K->ne[2], nblocks_sk,
+                 gqa_ratio, bpt, fd0, fd1, fd2);
+        } else if (ntiles_dst % blocks_num.x != 0) {
+            // General fixup for the cases where nblocks_stream_k < ntiles_dst.
+            const int total_work = ntiles_KV * ntiles_dst;
+
+            const uint3 fd_k_j_z_ne12 = init_fastdiv_values(ntiles_KV * ntiles_x * ntiles_z_gqa * K->ne[2]);
+            const uint3 fd_k_j_z      = init_fastdiv_values(ntiles_KV * ntiles_x * ntiles_z_gqa);
+            const uint3 fd_k_j        = init_fastdiv_values(ntiles_KV * ntiles_x);
+            const uint3 fd_k          = init_fastdiv_values(ntiles_KV);
+
             const dim3 block_dim_combine(DV, 1, 1);
             const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
 
-            flash_attn_stream_k_fixup<DV, ncols1, ncols2>
+            flash_attn_stream_k_fixup_general<DV, ncols1, ncols2>
                 <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
-                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], K->ne[2], nbatch_fa);
+                ((float *) KQV->data, dst_tmp_meta.ptr,
+                 Q->ne[1], Q->ne[2], gqa_ratio, total_work,
+                 fd_k_j_z_ne12, fd_k_j_z, fd_k_j, fd_k);
         }
     } else if (parallel_blocks > 1) {
         const dim3 block_dim_combine(DV, 1, 1);