]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CUDA: Improve flash decoding kernel GPU occupancy for BS=1 case (llama/12183)
authorGaurav Garg <redacted>
Wed, 19 Mar 2025 19:52:06 +0000 (01:22 +0530)
committerGeorgi Gerganov <redacted>
Thu, 27 Mar 2025 09:06:03 +0000 (11:06 +0200)
- Find out active blocks per SM using cudaOccupancyMaxActiveBlocksPerMultiprocessor API. Use this value to determine the optimal parallel_blocks value.
- Prefer vector flash attention kernels over MMA kernel for BS=1

Fixes Issue: #12182
---------

Co-authored-by: Johannes Gäßler <redacted>
ggml/src/ggml-cuda/fattn-common.cuh
ggml/src/ggml-cuda/fattn-mma-f16.cuh
ggml/src/ggml-cuda/fattn-tile-f16.cu
ggml/src/ggml-cuda/fattn-tile-f32.cu
ggml/src/ggml-cuda/fattn-vec-f16.cuh
ggml/src/ggml-cuda/fattn-vec-f32.cuh
ggml/src/ggml-cuda/fattn-wmma-f16.cu
ggml/src/ggml-cuda/fattn.cu
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/vendors/hip.h
ggml/src/ggml-cuda/vendors/musa.h

index 4067fd41bc247947b788d2cd6214ae1b2d263a0f..1c2a2a138f9b3f311236ea2e9b4194d7f0267939 100644 (file)
@@ -606,48 +606,47 @@ static __global__ void flash_attn_stream_k_fixup(
     *dst = dst_val / rowsum;
 }
 
-template<int D, int parallel_blocks> // D == head size
+template<int D> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_combine_results(
         const float  * __restrict__ VKQ_parts,
         const float2 * __restrict__ VKQ_meta,
-        float * __restrict__ dst) {
-    VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
-    VKQ_meta  += parallel_blocks   * gridDim.y*blockIdx.x;
-    dst       +=                 D * gridDim.y*blockIdx.x;
+        float * __restrict__ dst,
+        const int parallel_blocks) {
+    VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x;
+    VKQ_meta  += parallel_blocks   * gridDim.z*blockIdx.x;
+    dst       +=                 D * gridDim.z*blockIdx.x;
 
     const int tid = threadIdx.x;
     __builtin_assume(tid < D);
 
-    __shared__ float2 meta[parallel_blocks];
+    extern __shared__ float2 meta[];
     if (tid < 2*parallel_blocks) {
-        ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid];
+        ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + tid];
     }
 
     __syncthreads();
 
     float kqmax = meta[0].x;
-#pragma unroll
     for (int l = 1; l < parallel_blocks; ++l) {
         kqmax = max(kqmax, meta[l].x);
     }
 
     float VKQ_numerator   = 0.0f;
     float VKQ_denominator = 0.0f;
-#pragma unroll
     for (int l = 0; l < parallel_blocks; ++l) {
         const float diff = meta[l].x - kqmax;
         const float KQ_max_scale = expf(diff);
         const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
         *((uint32_t *) &KQ_max_scale) &= ftz_mask;
 
-        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
+        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid];
         VKQ_denominator += KQ_max_scale * meta[l].y;
     }
 
-    dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
+    dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
 }
 
 static void on_no_fattn_vec_case(const int D) {
@@ -671,12 +670,10 @@ static void on_no_fattn_vec_case(const int D) {
     }
 }
 
-// parallel_blocks == 0 is stream-k decomposition
-template <int D, int ncols1, int ncols2, int parallel_blocks, int KQ_stride>
+template <int D, int ncols1, int ncols2, int KQ_stride>
 void launch_fattn(
-    ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
-    const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V,
-    const int warp_size = WARP_SIZE
+    ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
+    const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
 ) {
     constexpr int ncols = ncols1 * ncols2;
 
@@ -748,12 +745,14 @@ void launch_fattn(
         nb23 = nb23*bs*sizeof(half)/ts;
     }
 
+    int parallel_blocks = 1;
+
     const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
     const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
 
     const dim3 block_dim(warp_size, nwarps, 1);
     dim3 blocks_num;
-    if (parallel_blocks == 0) {
+    if (stream_k) {
         // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
         const int max_blocks = 2*nsm;
         const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
@@ -769,9 +768,43 @@ void launch_fattn(
 
         dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
     } else {
-        blocks_num.x = parallel_blocks*ntiles_x;
-        blocks_num.y = Q->ne[2];
-        blocks_num.z = Q->ne[3];
+        GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
+        const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
+
+        int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
+        CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
+
+        // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
+        parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
+
+        // parallel_blocks must not be larger than what the tensor size allows:
+        parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
+
+        // If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects.
+        // Test whether parallel_blocks can be set to a higher value for better efficiency.
+        const int blocks_per_wave = nsm * max_blocks_per_sm;
+        int nwaves_best = 0;
+        int efficiency_percent_best = 0;
+        for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) {
+            const int nblocks_total = ntiles_total * parallel_blocks_test;
+            const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave;
+            const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave);
+
+            // Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead.
+            if (efficiency_percent_best >= 90 && nwaves > nwaves_best) {
+                break;
+            }
+
+            if (efficiency_percent > efficiency_percent_best) {
+                nwaves_best = nwaves;
+                efficiency_percent_best = efficiency_percent;
+                parallel_blocks = parallel_blocks_test;
+            }
+        }
+
+        blocks_num.x = ntiles_x;
+        blocks_num.y = parallel_blocks;
+        blocks_num.z = Q->ne[2]*Q->ne[3];
 
         if (parallel_blocks > 1) {
             dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
@@ -803,7 +836,7 @@ void launch_fattn(
         K_data,
         V_data,
         mask ? ((const char *) mask->data) : nullptr,
-        (parallel_blocks) > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
+        !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
@@ -815,7 +848,7 @@ void launch_fattn(
     );
     CUDA_CHECK(cudaGetLastError());
 
-    if constexpr (parallel_blocks == 0) {
+    if (stream_k) {
         if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
             const dim3 block_dim_combine(D, 1, 1);
             const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
@@ -824,13 +857,14 @@ void launch_fattn(
                 <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
                 ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
         }
-    } else if constexpr (parallel_blocks > 1) {
+    } else if (parallel_blocks > 1) {
         const dim3 block_dim_combine(D, 1, 1);
-        const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
+        const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
+        const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
 
-        flash_attn_combine_results<D, parallel_blocks>
-            <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
-            (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+        flash_attn_combine_results<D>
+            <<<blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream>>>
+            (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
     }
     CUDA_CHECK(cudaGetLastError());
 }
index 718ee5402dccd16e8a80ae8a8c57ed90827cef20..024032f6221bc303f45539f5d89184908ffedc09 100644 (file)
@@ -970,7 +970,8 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
         fattn_kernel = flash_attn_ext_f16<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap>;
     }
 
-    launch_fattn<D, ncols1, ncols2, 0, KQ_per_iter>(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, true, true);
+    launch_fattn<D, ncols1, ncols2, KQ_per_iter>
+        (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, FATTN_KQ_STRIDE, true, true, true);
 }
 
 
index ef3569fab27892ef9ef4cf07208dfee5d6c6136c..77455d8e4f1a256a4410cb25658d581557f684cb 100644 (file)
@@ -4,7 +4,7 @@
 
 #define FATTN_KQ_STRIDE_TILE_F16 64
 
-template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
+template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
@@ -58,18 +58,17 @@ static __global__ void flash_attn_tile_ext_f16(
 
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
-    const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
-    const int ip  =  blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
+    const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2  = (const float2 *) (Q    + nb02* blockIdx.y              + nb01*ic0);
-    const half2  * K_h2  = (const half2  *) (K    + nb12*(blockIdx.y / gqa_ratio));
-    const half2  * V_h2  = (const half2  *) (V    + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+    const float2 * Q_f2  = (const float2 *) (Q    + nb02* blockIdx.z              + nb01*ic0);
+    const half2  * K_h2  = (const half2  *) (K    + nb12*(blockIdx.z / gqa_ratio));
+    const half2  * V_h2  = (const half2  *) (V    + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
     const half   * maskh = (const half   *)  mask + ne11*ic0;
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
-    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -105,8 +104,7 @@ static __global__ void flash_attn_tile_ext_f16(
 
     __syncthreads();
 
-    const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F16;
-    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F16) {
+    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) {
         // Calculate KQ tile and keep track of new maximum KQ values:
 
         half kqmax_new[ncols/nwarps];
@@ -271,16 +269,16 @@ static __global__ void flash_attn_tile_ext_f16(
             const int i0 = i00 + 2*threadIdx.x;
 
             half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
-            if (parallel_blocks == 1) {
+            if (gridDim.y == 1) {
                 dst_val /= __half2half2(kqsum_j);
             }
-            const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
-            dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] =  __low2float(dst_val);
-            dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = __high2float(dst_val);
+            const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
+            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] =  __low2float(dst_val);
+            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = __high2float(dst_val);
         }
 
-        if (parallel_blocks != 1 && threadIdx.x == 0) {
-            dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
+        if (gridDim.y != 1 && threadIdx.x == 0) {
+            dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
         }
     }
 #else
@@ -288,7 +286,7 @@ static __global__ void flash_attn_tile_ext_f16(
 #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }
 
-template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
+template <int cols_per_block, bool use_logit_softcap>
 void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
     switch (Q->ne[0]) {
@@ -296,15 +294,17 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
             constexpr int    D             = 64;
             constexpr int    nwarps        = 8;
             constexpr size_t nbytes_shared = 0;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
-            launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
+            launch_fattn<D, cols_per_block, 1, -1>
+                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
         } break;
         case 128: {
             constexpr int    D             = 128;
             constexpr int    nwarps        = 8;
             constexpr size_t nbytes_shared = 0;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
-            launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
+            launch_fattn<D, cols_per_block, 1, -1>
+                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
         } break;
         default: {
             GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
@@ -324,37 +324,22 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_ten
 
     if (Q->ne[1] <= 16) {
         constexpr int cols_per_block = 16;
-        constexpr int parallel_blocks = 4;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
-            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+            launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
         } else {
             constexpr bool use_logit_softcap = true;
-            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
-        }
-        return;
-    }
-
-    if (Q->ne[1] <= 32) {
-        constexpr int cols_per_block = 32;
-        constexpr int parallel_blocks = 4;
-        if (logit_softcap == 0.0f) {
-            constexpr bool use_logit_softcap = false;
-            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
-        } else {
-            constexpr bool use_logit_softcap = true;
-            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+            launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
         }
         return;
     }
 
     constexpr int cols_per_block = 32;
-    constexpr int parallel_blocks = 1;
     if (logit_softcap == 0.0f) {
         constexpr bool use_logit_softcap = false;
-        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
     } else {
         constexpr bool use_logit_softcap = true;
-        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
     }
 }
index 04b69c83be0488bcae89e6583c12cd3856409f44..85fea4404d07cfc7007ddff20ad83b72c92f74cf 100644 (file)
@@ -4,7 +4,7 @@
 
 #define FATTN_KQ_STRIDE_TILE_F32 32
 
-template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
+template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
@@ -58,18 +58,17 @@ static __global__ void flash_attn_tile_ext_f32(
 
     // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
-    const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
-    const int ip  =  blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
+    const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2  = (const float2 *) (Q    + nb02* blockIdx.y              + nb01*ic0);
-    const half2  * K_h2  = (const half2  *) (K    + nb12*(blockIdx.y / gqa_ratio));
-    const half2  * V_h2  = (const half2  *) (V    + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+    const float2 * Q_f2  = (const float2 *) (Q    + nb02* blockIdx.z              + nb01*ic0);
+    const half2  * K_h2  = (const half2  *) (K    + nb12*(blockIdx.z / gqa_ratio));
+    const half2  * V_h2  = (const half2  *) (V    + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
     const half   * maskh = (const half   *)  mask + ne11*ic0;
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
-    const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
 
@@ -103,8 +102,7 @@ static __global__ void flash_attn_tile_ext_f32(
 
     __syncthreads();
 
-    const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F32;
-    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F32) {
+    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F32; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F32) {
         // Calculate KQ tile and keep track of new maximum KQ values:
 
         float kqmax_new[ncols/nwarps];
@@ -269,17 +267,17 @@ static __global__ void flash_attn_tile_ext_f32(
             const int i0 = i00 + 2*threadIdx.x;
 
             float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
-            if (parallel_blocks == 1) {
+            if (gridDim.y == 1) {
                 dst_val.x /= kqsum_j;
                 dst_val.y /= kqsum_j;
             }
-            const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
-            dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = dst_val.x;
-            dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = dst_val.y;
+            const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
+            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = dst_val.x;
+            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = dst_val.y;
         }
 
-        if (parallel_blocks != 1 && threadIdx.x == 0) {
-            dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
+        if (gridDim.y != 1 && threadIdx.x == 0) {
+            dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
         }
     }
 #else
@@ -287,7 +285,7 @@ static __global__ void flash_attn_tile_ext_f32(
 #endif // FLASH_ATTN_AVAILABLE
 }
 
-template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
+template <int cols_per_block, bool use_logit_softcap>
 void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
     switch (Q->ne[0]) {
@@ -295,15 +293,17 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
             constexpr int    D             = 64;
             constexpr int    nwarps        = 8;
             constexpr size_t nbytes_shared = 0;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
-            launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
+            launch_fattn<D, cols_per_block, 1, -1>
+                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
         } break;
         case 128: {
             constexpr int    D             = 128;
             constexpr int    nwarps        = 8;
             constexpr size_t nbytes_shared = 0;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
-            launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
+            launch_fattn<D, cols_per_block, 1, -1>
+                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
         } break;
         default: {
             GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
@@ -320,37 +320,22 @@ void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_ten
 
     if (Q->ne[1] <= 16) {
         constexpr int cols_per_block = 16;
-        constexpr int parallel_blocks = 4;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
-            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+            launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
         } else {
             constexpr bool use_logit_softcap = true;
-            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
-        }
-        return;
-    }
-
-    if (Q->ne[1] <= 32) {
-        constexpr int cols_per_block = 32;
-        constexpr int parallel_blocks = 4;
-        if (logit_softcap == 0.0f) {
-            constexpr bool use_logit_softcap = false;
-            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
-        } else {
-            constexpr bool use_logit_softcap = true;
-            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+            launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
         }
         return;
     }
 
     constexpr int cols_per_block = 32;
-    constexpr int parallel_blocks = 1;
     if (logit_softcap == 0.0f) {
         constexpr bool use_logit_softcap = false;
-        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
     } else {
         constexpr bool use_logit_softcap = true;
-        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
     }
 }
index b7686c1ec3d47ef8cacf05f5487857f03cd51ed9..32c52ebe33e9c64131debdd3124d5d6dfd1ab312 100644 (file)
@@ -1,7 +1,7 @@
 #include "common.cuh"
 #include "fattn-common.cuh"
 
-template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
+template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
@@ -55,17 +55,16 @@ static __global__ void flash_attn_vec_ext_f16(
     constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
     constexpr dequantize_1_f16_t dequantize_1_v = get_dequantize_1_f16(type_V);
 
-    const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
-    const int ip  =  blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
+    const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    Q += nb02* blockIdx.y              + nb01*ic0;
-    K += nb12*(blockIdx.y / gqa_ratio);
-    V += nb22*(blockIdx.y / gqa_ratio);
+    Q += nb02* blockIdx.z              + nb01*ic0;
+    K += nb12*(blockIdx.z / gqa_ratio);
+    V += nb22*(blockIdx.z / gqa_ratio);
 
     const half * maskh = (const half   *)  mask + ne11*ic0;
 
-    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -172,8 +171,7 @@ static __global__ void flash_attn_vec_ext_f16(
 
     half2 VKQ[ncols] = {{0.0f, 0.0f}};
 
-    const int k_start = parallel_blocks == 1 ? 0 : ip*D;
-    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
+    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
         // Calculate KQ tile and keep track of new maximum KQ values:
 
         // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
@@ -283,29 +281,29 @@ static __global__ void flash_attn_vec_ext_f16(
         kqsum[j_VKQ] = warp_reduce_sum((float)kqsum[j_VKQ]);
 
         half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ]));
-        if (parallel_blocks == 1) {
+        if (gridDim.y == 1) {
             dst_val /= kqsum[j_VKQ];
         }
-        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
-        dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
+        const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
+        dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
     }
 
-    if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
-        dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
+    if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
+        dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
     }
 #else
    NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }
 
-template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
+template <int D, int cols_per_block, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
 void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     constexpr int nwarps = D/WARP_SIZE;
-    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
+    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, type_K, type_V, use_logit_softcap>;
     constexpr bool need_f16_K = D != 128;
     constexpr bool need_f16_V = D != 128 && D != 64;
     constexpr size_t nbytes_shared = 0;
-    launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
+    launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
 }
 
 template <int D, ggml_type type_K, ggml_type type_V>
@@ -325,65 +323,48 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 
     if (Q->ne[1] == 1) {
-        constexpr int cols_per_block  = 1;
-        constexpr int parallel_blocks = 4;
+        constexpr int cols_per_block = 1;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
-            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         } else {
             constexpr bool use_logit_softcap = true;
-            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
         return;
     }
 
     if (Q->ne[1] == 2) {
-        constexpr int cols_per_block  = 2;
-        constexpr int parallel_blocks = 4;
+        constexpr int cols_per_block = 2;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
-            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         } else {
             constexpr bool use_logit_softcap = true;
-            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
         return;
     }
 
     if (Q->ne[1] <= 4) {
-        constexpr int cols_per_block  = 4;
-        constexpr int parallel_blocks = 4;
+        constexpr int cols_per_block = 4;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
-            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         } else {
             constexpr bool use_logit_softcap = true;
-            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
         return;
     }
 
-    if (Q->ne[1] <= 8) {
-        constexpr int cols_per_block  = 8;
-        constexpr int parallel_blocks = 4;
-        if (logit_softcap == 0.0f) {
-            constexpr bool use_logit_softcap = false;
-            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
-        } else {
-            constexpr bool use_logit_softcap = true;
-            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
-        }
-        return;
-    }
-
-    constexpr int cols_per_block  = 8;
-    constexpr int parallel_blocks = 1;
+    constexpr int cols_per_block = 8;
     if (logit_softcap == 0.0f) {
         constexpr bool use_logit_softcap = false;
-        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
     } else {
         constexpr bool use_logit_softcap = true;
-        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
     }
 }
 
index c1d2dd8d19f4d8c5d53351ab4a5235d34bc878e8..336c136d19d7da6b35b734e1401546fedff28a5b 100644 (file)
@@ -1,7 +1,7 @@
 #include "common.cuh"
 #include "fattn-common.cuh"
 
-template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
+template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
@@ -55,16 +55,15 @@ static __global__ void flash_attn_vec_ext_f32(
     constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
     constexpr dequantize_1_f32_t dequantize_1_v = get_dequantize_1_f32(type_V);
 
-    const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
-    const int ip  =  blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
+    const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    Q += nb02* blockIdx.y              + nb01*ic0;
-    K += nb12*(blockIdx.y / gqa_ratio);
-    V += nb22*(blockIdx.y / gqa_ratio); // K and V have same shape
+    Q += nb02* blockIdx.z              + nb01*ic0;
+    K += nb12*(blockIdx.z / gqa_ratio);
+    V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape
     const half * maskh = (const half   *)  mask + ne11*ic0;
 
-    const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
     constexpr int nwarps = D / WARP_SIZE;
@@ -167,8 +166,7 @@ static __global__ void flash_attn_vec_ext_f32(
 
     float VKQ[ncols] = {0.0f};
 
-    const int k_start = parallel_blocks == 1 ? 0 : ip*D;
-    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
+    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
         // Calculate KQ tile and keep track of new maximum KQ values:
 
         float kqmax_new_arr[ncols];
@@ -268,29 +266,29 @@ static __global__ void flash_attn_vec_ext_f32(
         kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
 
         float dst_val = VKQ[j_VKQ];
-        if (parallel_blocks == 1) {
+        if (gridDim.y == 1) {
             dst_val /= kqsum[j_VKQ];
         }
-        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
-        dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
+        const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
+        dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
     }
 
-    if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
-        dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
+    if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
+        dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
     }
 #else
     NO_DEVICE_CODE;
 #endif // FLASH_ATTN_AVAILABLE
 }
 
-template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
+template <int D, int cols_per_block, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
 void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     constexpr int nwarps = D/WARP_SIZE;
-    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
+    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, type_K, type_V, use_logit_softcap>;
     constexpr bool need_f16_K = D != 128;
     constexpr bool need_f16_V = D != 128 && D != 64;
     constexpr size_t nbytes_shared = 0;
-    launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
+    launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
 }
 
 template <int D, ggml_type type_K, ggml_type type_V>
@@ -307,65 +305,48 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 
     if (Q->ne[1] == 1) {
-        constexpr int cols_per_block  = 1;
-        constexpr int parallel_blocks = 4;
+        constexpr int cols_per_block = 1;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         } else {
             constexpr bool use_logit_softcap = true;
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
         return;
     }
 
     if (Q->ne[1] == 2) {
-        constexpr int cols_per_block  = 2;
-        constexpr int parallel_blocks = 4;
+        constexpr int cols_per_block = 2;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         } else {
             constexpr bool use_logit_softcap = true;
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
         return;
     }
 
     if (Q->ne[1] <= 4) {
-        constexpr int cols_per_block  = 4;
-        constexpr int parallel_blocks = 4;
+        constexpr int cols_per_block = 4;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         } else {
             constexpr bool use_logit_softcap = true;
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
         return;
     }
 
-    if (Q->ne[1] <= 8) {
-        constexpr int cols_per_block  = 8;
-        constexpr int parallel_blocks = 4;
-        if (logit_softcap == 0.0f) {
-            constexpr bool use_logit_softcap = false;
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
-        } else {
-            constexpr bool use_logit_softcap = true;
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
-        }
-        return;
-    }
-
-    constexpr int cols_per_block  = 8;
-    constexpr int parallel_blocks = 1;
+    constexpr int cols_per_block = 8;
     if (logit_softcap == 0.0f) {
         constexpr bool use_logit_softcap = false;
-        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
     } else {
         constexpr bool use_logit_softcap = true;
-        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
     }
 }
 
index dab1d5cbcace4b43fdd9aa78c6531ef92b2aacc8..5c214ea3109d207abf5d688971eb74e146bc0e63 100644 (file)
@@ -18,7 +18,7 @@ namespace wmma = rocwmma;
 #endif // FP16_MMA_AVAILABLE
 
 // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
-template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap>
+template<int D, int ncols, int nwarps, int VKQ_stride, typename KQ_acc_t, bool use_logit_softcap>
 __launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), 1)
 static __global__ void flash_attn_ext_f16(
         const char * __restrict__ Q,
@@ -67,8 +67,7 @@ static __global__ void flash_attn_ext_f16(
 
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-    const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
-    const int ip  =        blockIdx.x % parallel_blocks;  // Index in group of blocks running for the same column in parallel.
+    const int ic0 = ncols*blockIdx.x; // Index of the first Q/QKV column to work on.
 
     static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
     static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
@@ -91,16 +90,16 @@ static __global__ void flash_attn_ext_f16(
     constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
 
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float * Q_f   = (const float *) (Q + nb02* blockIdx.y              + nb01*ic0);
-    const half  * K_h   = (const half  *) (K + nb12*(blockIdx.y / gqa_ratio));
-    const half  * V_h   = (const half  *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+    const float * Q_f   = (const float *) (Q + nb02* blockIdx.z              + nb01*ic0);
+    const half  * K_h   = (const half  *) (K + nb12*(blockIdx.z / gqa_ratio));
+    const half  * V_h   = (const half  *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
     const half  * maskh = (const half  *)  mask + (nb31/sizeof(half))* ic0;
     const half2 * mask2 = (const half2 *)  mask + (nb31/sizeof(half))*(ic0/2);
 
     const int stride_Q  = nb01 / sizeof(float);
     const int stride_KV = nb11 / sizeof(half);
 
-    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);
     const half2 slope2 = make_half2(slopef, slopef);
 
@@ -176,7 +175,7 @@ static __global__ void flash_attn_ext_f16(
     __syncthreads();
 
     // Iterate over ne11 == previous tokens:
-    for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
+    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE) {
         // Calculate tile of KQ:
 #pragma unroll
         for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
@@ -395,7 +394,7 @@ static __global__ void flash_attn_ext_f16(
         if (ic0 + j_VKQ >= ne01) {
             return;
         }
-        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
+        const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
 
         float KQ_rowsum_j;
         if (std::is_same<KQ_acc_t, float>::value) {
@@ -411,13 +410,13 @@ static __global__ void flash_attn_ext_f16(
                 break;
             }
             float dst_val = VKQ[j_VKQ*D_padded + i];
-            if (parallel_blocks == 1) {
+            if (gridDim.y == 1) {
                 dst_val /= KQ_rowsum_j;
             }
-            dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
+            dst[j_dst*gridDim.z*D + blockIdx.z*D + i] = dst_val;
         }
 
-        if (parallel_blocks == 1 || threadIdx.x != 0) {
+        if (gridDim.y == 1 || threadIdx.x != 0) {
             continue;
         }
 
@@ -428,7 +427,7 @@ static __global__ void flash_attn_ext_f16(
             dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
         }
         dst_meta_val.y = KQ_rowsum_j;
-        dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
+        dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = dst_meta_val;
     }
 #else
    NO_DEVICE_CODE;
@@ -462,60 +461,26 @@ static_assert(get_VKQ_stride( 80, 4, 16) ==  16, "Test failed.");
 template <int D, int cols_per_block, typename KQ_acc_t>
 void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * KQV = dst;
-    const ggml_tensor * Q   = dst->src[0];
 
     constexpr int nwarps = 4;
 
     constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
-    const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
-    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
     const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
 
     float logit_softcap;
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 
-    if (4*blocks_num_pb1 < 2*nsm) {
-        constexpr int parallel_blocks = 4;
-        fattn_kernel_t fattn_kernel;
-        if (logit_softcap == 0.0f) {
-            constexpr bool use_logit_softcap = false;
-            fattn_kernel = flash_attn_ext_f16<
-                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
-        } else {
-            constexpr bool use_logit_softcap = true;
-            fattn_kernel = flash_attn_ext_f16<
-                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
-        }
-        launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size);
-        return;
-    }
-    if (2*blocks_num_pb1 < 2*nsm) {
-        constexpr int parallel_blocks = 2;
-        fattn_kernel_t fattn_kernel;
-        if (logit_softcap == 0.0f) {
-            constexpr bool use_logit_softcap = false;
-            fattn_kernel = flash_attn_ext_f16<
-                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
-        } else {
-            constexpr bool use_logit_softcap = true;
-            fattn_kernel = flash_attn_ext_f16<
-                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
-        }
-        launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size);
-        return;
-    }
-    constexpr int parallel_blocks = 1;
     fattn_kernel_t fattn_kernel;
     if (logit_softcap == 0.0f) {
         constexpr bool use_logit_softcap = false;
         fattn_kernel = flash_attn_ext_f16<
-            D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+            D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>;
     } else {
         constexpr bool use_logit_softcap = true;
         fattn_kernel = flash_attn_ext_f16<
-            D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+            D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>;
     }
-    launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size);
+    launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size);
 }
 
 void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
index 2e72fc8fd380bb68d7cbdb2500b056925c5504a7..973541893ec21e8df9c58d5229908ac2fcda0aee 100644 (file)
@@ -281,13 +281,13 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
 
     if (!fp16_mma_available(cc)) {
         if (prec == GGML_PREC_DEFAULT) {
-            if (Q->ne[1] <= 8) {
+            if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
                 ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
             } else {
                 ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
             }
         } else {
-            if (Q->ne[1] <= 8) {
+            if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
                 ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
             } else {
                 ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
@@ -296,17 +296,17 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         return;
     }
 
-    const int gqa_ratio = Q->ne[2] / K->ne[2];
-    const bool mma_fast_for_bs1 = fp16_mma_available(cc) && gqa_ratio % 2 == 0 &&
-        K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && mask;
-    if (Q->ne[1] == 1 && Q->ne[0] % (2*warp_size) == 0 && !mma_fast_for_bs1) {
+    const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
+    const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
+    const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion;
+    const bool can_use_vector_kernel = (Q->ne[0] % (2*warp_size) == 0) && (prec == GGML_PREC_DEFAULT || Q->ne[0] <= 128);
+    if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
         if (prec == GGML_PREC_DEFAULT) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
-            return;
-        } else if(Q->ne[0] <= 128) {
+        } else {
             ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
-            return;
         }
+        return;
     }
 
     // The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
index 5cb56df9a81ae82e97d00b21952ac2438826540d..b783310ef7ba74596a4543b5c78692603881a19c 100644 (file)
@@ -3230,6 +3230,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
 #ifndef FLASH_ATTN_AVAILABLE
             return false;
 #endif // FLASH_ATTN_AVAILABLE
+            if (op->src[0]->ne[3] != 1) {
+                return false;
+            }
             if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
                 return false;
             }
index aace21e3a8b18ff114ff4f5dceb59919580f677b..a4c717a321cfb325b7276968e07138bb1859b92a 100644 (file)
 #define cudaGraph_t hipGraph_t
 #define cudaStream_t hipStream_t
 #define cudaSuccess hipSuccess
+#define cudaOccupancyMaxActiveBlocksPerMultiprocessor hipOccupancyMaxActiveBlocksPerMultiprocessor
 #define __trap() do { abort(); __builtin_unreachable(); } while(0)
 #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
 #define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
index 997f671431e01e16a61a1daf48f38b328350d6f8..f2d55796e7874954b458fc83e6344499c266d555 100644 (file)
 #define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
 #define cudaStreamBeginCapture musaStreamBeginCapture
 #define cudaStreamEndCapture musaStreamEndCapture
+#define cudaOccupancyMaxActiveBlocksPerMultiprocessor musaOccupancyMaxActiveBlocksPerMultiprocessor
 
 typedef mt_bfloat16 nv_bfloat16;