]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: limit number of FA stream-k CUDA blocks (llama/20586)
authorJohannes Gäßler <redacted>
Sun, 15 Mar 2026 17:30:47 +0000 (18:30 +0100)
committerGeorgi Gerganov <redacted>
Sun, 15 Mar 2026 19:50:13 +0000 (21:50 +0200)
src/ggml-cuda/fattn-common.cuh

index b6a7460da831dc8e28e9ff5a0c175ef12501de5f..e9abdf288c43f42db533b9457e1786342ecafb07 100644 (file)
@@ -892,7 +892,7 @@ void launch_fattn(
     const int ntiles_x     = ((Q->ne[1] + ncols1 - 1) / ncols1);
     const int gqa_ratio    = Q->ne[2] / K->ne[2];
     const int ntiles_z_gqa = ((gqa_ratio + ncols2 - 1) / ncols2);
-    const int ntiles_total = ntiles_x * ntiles_z_gqa * K->ne[2] * Q->ne[3];
+    const int ntiles_dst   = ntiles_x * ntiles_z_gqa * K->ne[2] * Q->ne[3];
 
     // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
     // Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
@@ -919,37 +919,37 @@ void launch_fattn(
     GGML_ASSERT(max_blocks_per_sm > 0);
     int parallel_blocks = max_blocks_per_sm;
 
+    const int ntiles_KV = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by KV cache length.
+
     dim3 blocks_num;
     if (stream_k) {
         // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
         const int max_blocks = max_blocks_per_sm*nsm;
-        const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
-        const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
+        const int tiles_nwaves = (ntiles_dst + max_blocks - 1) / max_blocks;
+        const int tiles_efficiency_percent = 100 * ntiles_dst / (max_blocks*tiles_nwaves);
 
-        const int nblocks_stream_k = max_blocks;
+        const int nblocks_stream_k = std::min(max_blocks, ntiles_KV*ntiles_dst);
 
         const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || amd_wmma_available(cc) || tiles_efficiency_percent < 75;
 
-        blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
+        blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_dst;
         blocks_num.y = 1;
         blocks_num.z = 1;
 
-        if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
+        if (ntiles_dst % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
             dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2)));
         }
     } else {
-        const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size.
-
         // parallel_blocks must not be larger than what the tensor size allows:
-        parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
+        parallel_blocks = std::min(parallel_blocks, ntiles_KV);
 
         // If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects.
         // Test whether parallel_blocks can be set to a higher value for better efficiency.
         const int blocks_per_wave = nsm * max_blocks_per_sm;
         int nwaves_best = 0;
         int efficiency_percent_best = 0;
-        for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) {
-            const int nblocks_total = ntiles_total * parallel_blocks_test;
+        for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KV; ++parallel_blocks_test) {
+            const int nblocks_total = ntiles_dst * parallel_blocks_test;
             const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave;
             const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave);
 
@@ -1015,7 +1015,7 @@ void launch_fattn(
     CUDA_CHECK(cudaGetLastError());
 
     if (stream_k) {
-        if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
+        if (ntiles_dst % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
             const dim3 block_dim_combine(DV, 1, 1);
             const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};