]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: optimize FA for GQA + large batches (#12014)
authorJohannes Gäßler <redacted>
Sat, 22 Feb 2025 11:20:17 +0000 (12:20 +0100)
committerGitHub <redacted>
Sat, 22 Feb 2025 11:20:17 +0000 (12:20 +0100)
32 files changed:
ggml/src/ggml-cuda/cp-async.cuh
ggml/src/ggml-cuda/fattn-common.cuh
ggml/src/ggml-cuda/fattn-mma-f16.cuh
ggml/src/ggml-cuda/fattn-tile-f16.cu
ggml/src/ggml-cuda/fattn-tile-f32.cu
ggml/src/ggml-cuda/fattn-vec-f16.cuh
ggml/src/ggml-cuda/fattn-vec-f32.cuh
ggml/src/ggml-cuda/fattn-wmma-f16.cu
ggml/src/ggml-cuda/fattn.cu
ggml/src/ggml-cuda/mma.cuh
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb16.cu [deleted file]
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb32.cu [deleted file]
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb64.cu [deleted file]
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb8.cu [deleted file]
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu [new file with mode: 0644]
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu [new file with mode: 0644]
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu [new file with mode: 0644]
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu [new file with mode: 0644]
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu [new file with mode: 0644]
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu [new file with mode: 0644]
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu [new file with mode: 0644]
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu [new file with mode: 0644]
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu [new file with mode: 0644]
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu [new file with mode: 0644]
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu [new file with mode: 0644]
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu [new file with mode: 0644]
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu [new file with mode: 0644]
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu [new file with mode: 0644]
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu [new file with mode: 0644]
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu [new file with mode: 0644]
ggml/src/ggml-cuda/template-instances/generate_cu_files.py
tests/test-backend-ops.cpp

index 51aa41e7e60ef0d0287823405f55a8f8fef45d6e..ecb659997ba516278235ece913905f1dcb3ade5b 100644 (file)
@@ -24,7 +24,7 @@ static __device__ __forceinline__ void cp_async_cg_16(const unsigned int dst, co
     } else
 #endif // CUDART_VERSION >= 11040
     {
-        asm volatile("cp.async.cg.shared.global.L2 [%0], [%1], 16;"
+        asm volatile("cp.async.cg.shared.global [%0], [%1], 16;"
             : : "r"(dst), "l"(src));
     }
 #else
index fefbd319baf76ff1637aa400215aee7efb2e2884..7b9566fb4be3268b89e4bde50c10c42c7c491841 100644 (file)
@@ -516,27 +516,25 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
         nullptr;
 }
 
-// The HIP compiler for some reason complains that it can't unroll a loop because of the jt*ncols + j >= ne01 conditional.
-#ifdef __clang__
-#pragma clang diagnostic push
-#pragma clang diagnostic ignored "-Wpass-failed"
-#endif // __clang__
-
-template<int D, int ncols, int KQ_stride> // D == head size
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+template<int D, int ncols1, int ncols2, int KQ_stride> // D == head size
 __launch_bounds__(D, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_stream_k_fixup(
         float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
-    const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
-
-    const int iter_k = ne11 / KQ_stride;
-    const int iter_j = (ne01 + (ncols - 1)) / ncols;
+    constexpr int ncols = ncols1*ncols2;
 
     const int bidx0 = blockIdx.x;
+    const int j     = blockIdx.y;
+    const int c     = blockIdx.z;
+    const int jc    = j*ncols2 + c;
+    const int tid   = threadIdx.x;
+
+    const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
+
+    const int iter_k = ne11 / FATTN_KQ_STRIDE;
+    const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
 
-    const int kbc0      = (bidx0 + 0)*iter_k*iter_j*ne02 / gridDim.x;
-    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*ne02 / gridDim.x;
+    const int kbc0      = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
 
     const bool did_not_have_any_data   = kbc0 == kbc0_stop;
     const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
@@ -548,22 +546,22 @@ static __global__ void flash_attn_stream_k_fixup(
     const int channel = kbc0 / (iter_k*iter_j);
     const int jt      = (kbc0 - channel*iter_k*iter_j) / iter_k;
 
-    dst += jt*ncols*ne02*D + channel*D;
+    if (jt*ncols1 + j >= ne01) {
+        return;
+    }
 
-    // Load the partial result that needs a fixup:
-    float dst_val[ncols] = {0.0f};
-    float max_val[ncols] = {0.0f};
-    float rowsum[ncols]  = {0.0f};
-#pragma unroll
-    for (int j = 0; j < ncols; ++j) {
-        if (jt*ncols + j >= ne01) {
-            break;
-        }
-        dst_val[j] = dst[j*ne02*D + threadIdx.x];
+    dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
 
-        const float2 tmp = dst_fixup[bidx0*ncols + j];
-        max_val[j] = tmp.x;
-        rowsum[j]  = tmp.y;
+    // Load the partial result that needs a fixup:
+    float dst_val = 0.0f;
+    float max_val = 0.0f;
+    float rowsum  = 0.0f;
+    {
+        dst_val = *dst;
+
+        const float2 tmp = dst_fixup[bidx0*ncols + jc];
+        max_val = tmp.x;
+        rowsum  = tmp.y;
     }
 
     // Iterate over previous blocks and compute the combined results.
@@ -571,36 +569,30 @@ static __global__ void flash_attn_stream_k_fixup(
     int bidx = bidx0 - 1;
     int kbc_stop = kbc0;
     while(true) {
-        const int kbc = bidx*iter_k*iter_j*ne02 / gridDim.x;
+        const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
         if (kbc == kbc_stop) { // Did not have any data.
             bidx--;
             kbc_stop = kbc;
             continue;
         }
 
-#pragma unroll
-        for (int j = 0; j < ncols; ++j) {
-            if (jt*ncols + j >= ne01) {
-                break;
-            }
-            const float dst_add = dst_fixup_data[bidx*ncols*D + j*D + threadIdx.x];
+        const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid];
 
-            const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + j];
+        const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + jc];
 
-            // Scale the current and new value accumulators depending on the max. values.
-            const float max_val_new = fmaxf(max_val[j], tmp.x);
+        // Scale the current and new value accumulators depending on the max. values.
+        const float max_val_new = fmaxf(max_val, tmp.x);
 
-            const float diff_val = max_val[j] - max_val_new;
-            const float diff_add = tmp.x      - max_val_new;
+        const float diff_val = max_val - max_val_new;
+        const float diff_add = tmp.x   - max_val_new;
 
-            const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
-            const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
+        const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
+        const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
 
-            dst_val[j] = scale_val*dst_val[j] + scale_add*dst_add;
-            rowsum[j]  = scale_val*rowsum[j]  + scale_add*tmp.y;
+        dst_val = scale_val*dst_val + scale_add*dst_add;
+        rowsum  = scale_val*rowsum  + scale_add*tmp.y;
 
-            max_val[j] = max_val_new;
-        }
+        max_val = max_val_new;
 
         // If this block started in a previous tile we are done and don't need to combine additional partial results.
         if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
@@ -611,19 +603,9 @@ static __global__ void flash_attn_stream_k_fixup(
     }
 
     // Write back final result:
-#pragma unroll
-    for (int j = 0; j < ncols; ++j) {
-        if (jt*ncols + j >= ne01) {
-            return;
-        }
-        dst[j*ne02*D + threadIdx.x] = dst_val[j] / rowsum[j];
-    }
+    *dst = dst_val / rowsum;
 }
 
-#ifdef __clang__
-#pragma clang diagnostic pop
-#endif // __clang__
-
 template<int D, int parallel_blocks> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
@@ -690,11 +672,13 @@ static void on_no_fattn_vec_case(const int D) {
 }
 
 // parallel_blocks == 0 is stream-k decomposition
-template <int D, int cols_per_block, int parallel_blocks, int KQ_stride>
+template <int D, int ncols1, int ncols2, int parallel_blocks, int KQ_stride>
 void launch_fattn(
     ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
     const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
 ) {
+    constexpr int ncols = ncols1 * ncols2;
+
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
     const ggml_tensor * V = dst->src[2];
@@ -763,25 +747,26 @@ void launch_fattn(
         nb23 = nb23*bs*sizeof(half)/ts;
     }
 
-    const int ntiles_x = ((Q->ne[1] + cols_per_block - 1) / cols_per_block);
-    const int ntiles_total = ntiles_x*Q->ne[2]*Q->ne[3];
+    const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
+    const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
 
     const dim3 block_dim(WARP_SIZE, nwarps, 1);
     dim3 blocks_num;
     if (parallel_blocks == 0) {
         // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
-        const int tiles_nwaves  = (ntiles_total + 2*nsm - 1) / (2*nsm);
-        const int tiles_efficiency_percent = 100 * ntiles_total / (2*nsm*tiles_nwaves);
+        const int max_blocks = 2*nsm;
+        const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
+        const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
 
-        const int nblocks_stream_k = 2*nsm;
+        const int nblocks_stream_k = max_blocks;
 
-        const bool use_stream_k = tiles_efficiency_percent < 75 || cc >= GGML_CUDA_CC_ADA_LOVELACE;
+        const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || tiles_efficiency_percent < 75;
 
         blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
         blocks_num.y = 1;
         blocks_num.z = 1;
 
-        dst_tmp_meta.alloc(blocks_num.x*cols_per_block * (2*2 + D) * sizeof(float));
+        dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
     } else {
         blocks_num.x = parallel_blocks*ntiles_x;
         blocks_num.y = Q->ne[2];
@@ -793,7 +778,6 @@ void launch_fattn(
         }
     }
 
-
     float scale         = 1.0f;
     float max_bias      = 0.0f;
     float logit_softcap = 0.0f;
@@ -832,9 +816,9 @@ void launch_fattn(
     if constexpr (parallel_blocks == 0) {
         if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
             const dim3 block_dim_combine(D, 1, 1);
-            const dim3 blocks_num_combine = blocks_num;
+            const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
 
-            flash_attn_stream_k_fixup<D, cols_per_block, KQ_stride>
+            flash_attn_stream_k_fixup<D, ncols1, ncols2, KQ_stride>
                 <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
                 ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
         }
index d777f5413ed97277a0edaab46a3b8a188b8b646d..b2e0db9a2cc254ecf318da319d4e76bfec446f77 100644 (file)
@@ -5,12 +5,15 @@
 
 using namespace ggml_cuda_mma;
 
-typedef tile<16, 8, half2> tile_A;
-typedef tile< 8, 8, half2> tile_B;
-typedef tile<16, 8, float> tile_C_KQ;
-typedef tile<16, 4, half2> tile_C_VKQ;
-
-template<int D, int nwarps, int KQ_stride>
+typedef tile<16,  8, half2> tile_A;
+typedef tile< 8,  8, half2> tile_B;
+typedef tile<16,  8, half2> tile_B_16;
+typedef tile<16,  8, float> tile_C_KQ;
+typedef tile<16, 16, float> tile_C_KQ_16;
+typedef tile<16,  4, half2> tile_C_VKQ;
+typedef tile<16,  8, half2> tile_C_VKQ_16;
+
+template<int D, int nwarps, int KQ_per_iter>
 static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
         const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV) {
     constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
@@ -27,7 +30,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
     constexpr int chunks_per_row = k0_sync_start / h2_per_chunk;
     constexpr int stride_i = WARP_SIZE / chunks_per_row;
 #pragma unroll
-    for (int i0 = 0; i0 < KQ_stride; i0 += nwarps*stride_i) {
+    for (int i0 = 0; i0 < KQ_per_iter; i0 += nwarps*stride_i) {
         const int i = i0 + threadIdx.y*stride_i + (chunks_per_row == WARP_SIZE ? 0 : threadIdx.x / chunks_per_row);
         const int k = (chunks_per_row == WARP_SIZE ? threadIdx.x : threadIdx.x % chunks_per_row)*h2_per_chunk;
 
@@ -40,7 +43,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
 
     // If D is not a power of 2, the rest is loaded synchronously.
     // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
-    static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds");
+    static_assert(KQ_per_iter % (4*nwarps) == 0, "out of bounds");
 #pragma unroll
     for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
         const int k0_start = stride_k == WARP_SIZE ? k0_sync_start : D/2 - (D/2) % (2*stride_k);
@@ -52,7 +55,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
         }
 
 #pragma unroll
-        for (int i0 = 0; i0 < KQ_stride; i0 += nwarps*stride_i) {
+        for (int i0 = 0; i0 < KQ_per_iter; i0 += nwarps*stride_i) {
             const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
 
 #pragma unroll
@@ -65,12 +68,54 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
     }
 }
 
-template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap, bool needs_fixup, bool is_fixup, bool last_iter>
+template<int ncols1, int nwarps, int KQ_per_iter>
+static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
+        const half2 * const __restrict__ mask_h2, half2 * const __restrict__ tile_mask, const int stride_mask) {
+    static_assert(KQ_per_iter == 2*WARP_SIZE || KQ_per_iter == WARP_SIZE, "bad KQ_per_iter");
+#ifdef CP_ASYNC_AVAILABLE
+    constexpr int preload = KQ_per_iter * sizeof(half);
+    constexpr int cols_per_warp = 8*WARP_SIZE/KQ_per_iter;
+    constexpr int stride_j = nwarps * cols_per_warp;
+
+    const unsigned int tile_mask_32 = __cvta_generic_to_shared(tile_mask);
+
+#pragma unroll
+    for (int j0 = 0; j0 < ncols1; j0 += stride_j) {
+        const int j = j0 + threadIdx.y*cols_per_warp +
+            (KQ_per_iter == 2*WARP_SIZE ? threadIdx.x / (WARP_SIZE/4) : threadIdx.x / (WARP_SIZE/8));
+
+        if (j0 + stride_j > ncols1 && j >= ncols1) {
+            break;
+        }
+
+        const int i = 4 * (KQ_per_iter == 2*WARP_SIZE ? threadIdx.x % (WARP_SIZE/4) : threadIdx.x % (WARP_SIZE/8));
+
+        cp_async_cg_16<preload>(tile_mask_32 + j*(KQ_per_iter*sizeof(half) + 16) + i*sizeof(half2), mask_h2 + j*stride_mask + i);
+    }
+#else
+    constexpr int cols_per_warp = 2*WARP_SIZE/KQ_per_iter;
+    constexpr int stride_j = nwarps * cols_per_warp;
+#pragma unroll
+    for (int j0 = 0; j0 < ncols1; j0 += stride_j) {
+        const int j = j0 + threadIdx.y*cols_per_warp + (KQ_per_iter == 2*WARP_SIZE ? 0 : threadIdx.x / (WARP_SIZE/2));
+
+        if (j0 + stride_j > ncols1 && j >= ncols1) {
+            break;
+        }
+
+        const int i = KQ_per_iter == 2*WARP_SIZE ? threadIdx.x : threadIdx.x % (WARP_SIZE/2);
+
+        tile_mask[j*(KQ_per_iter/2 + 4) + i] = mask_h2[j*stride_mask + i];
+    }
+#endif // CP_ASYNC_AVAILABLE
+}
+
+template<int D, int ncols1, int ncols2, int nwarps, int KQ_per_iter, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup, bool last_iter>
 static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         const float2 * const __restrict__ Q_f2,
         const half2  * const __restrict__ K_h2,
         const half2  * const __restrict__ V_h2,
-        const half   * const __restrict__ maskh,
+        const half2  * const __restrict__ mask_h2,
         float2       * const __restrict__ dstk,
         float2       * const __restrict__ dstk_fixup,
         const float scale,
@@ -78,42 +123,60 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         const float logit_softcap,
         const int ne01,
         const int ne02,
-        const int stride_Q,
         const int stride_KV,
         const int stride_mask,
         const int jt,
         half2        * const __restrict__ tile_K,
         half2        * const __restrict__ tile_V,
+        half2        * const __restrict__ tile_mask,
         const tile_B * const __restrict__ Q_B,
         tile_C_VKQ   * const __restrict__ VKQ_C,
-        float2 & KQ_max,
-        float2 & KQ_rowsum,
+        float        * const __restrict__ KQ_max,
+        float        * const __restrict__ KQ_rowsum,
         const int kb0) {
 #ifdef NEW_MMA_AVAILABLE
-    constexpr int np = nwarps*tile_B::I / ncols; // Number of parallel CUDA warps per Q column.
-    constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
+    constexpr int cols_per_warp   = ntiles * tile_B::I;
+    constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
+    constexpr int np              = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
+    constexpr int D2_padded       = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
+
+    const int k_VKQ_0 = kb0 * KQ_per_iter;
+    tile_C_KQ KQ_C[KQ_per_iter/(np*tile_C_KQ::I) * ntiles];
 
-    const int k_VKQ_0 = kb0*KQ_stride;
-    tile_C_KQ KQ_C[KQ_stride/(np*tile_C_KQ::I)];
+    // Use wide variants of tiles if ntiles >= 2.
+    tile_B_16     * Q_B_16   = (tile_B_16     *) Q_B;
+    tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C;
+    tile_C_KQ_16  * KQ_C_16  = (tile_C_KQ_16  *) KQ_C;
 
 #ifdef CP_ASYNC_AVAILABLE
     cp_async_wait_all();
     __syncthreads();
-    flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV);
+    flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV);
 #else
-    flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(K_h2 + k_VKQ_0*stride_KV, tile_K, stride_KV);
+    if (ncols2 > 1 || mask_h2) {
+        flash_attn_ext_f16_load_mask<ncols1, nwarps, KQ_per_iter>(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask);
+    }
+    flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(K_h2 + k_VKQ_0*stride_KV, tile_K, stride_KV);
     __syncthreads();
 #endif // CP_ASYNC_AVAILABLE
 
     // Calculate tile of KQ:
 #pragma unroll
-    for (int i_KQ_00 = 0; i_KQ_00 < KQ_stride; i_KQ_00 += np*tile_A::I) {
+    for (int i_KQ_00 = 0; i_KQ_00 < KQ_per_iter; i_KQ_00 += np*tile_A::I) {
         const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I;
 #pragma unroll
         for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += tile_A::J) {
             tile_A K_A;
             load_ldmatrix(K_A, tile_K + i_KQ_0*D2_padded + k_KQ_0, D2_padded);
-            mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, ((tile_B *) Q_B)[k_KQ_0/tile_A::J]);
+            if (ntiles == 1) {
+                mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]);
+            } else {
+#pragma unroll
+                for (int t = 0; t < ntiles/2; ++t) {
+                    // Wide version of KQ_C is column-major => swap A and B.
+                    mma(KQ_C_16[i_KQ_00/(np*tile_A::I) * ntiles/2 + t], Q_B_16[k_KQ_0/tile_A::J * ntiles/2 + t], K_A);
+                }
+            }
         }
     }
 
@@ -122,9 +185,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 #endif // CP_ASYNC_AVAILABLE
 
     if (use_logit_softcap) {
-        static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size");
+        static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size");
 #pragma unroll
-        for (int i = 0; i < KQ_stride/(np*tile_C_KQ::I); ++i) {
+        for (int i = 0; i < KQ_per_iter/(np*tile_C_KQ::I) * ntiles; ++i) {
 #pragma unroll
             for (int l = 0; l < tile_C_KQ::ne; ++l) {
                 KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
@@ -132,109 +195,209 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         }
     }
 
-    if (maskh) {
-        static_assert(KQ_stride % (np       *tile_C_KQ::I) == 0, "bad loop size");
-        static_assert(ncols     % (nwarps/np*tile_C_KQ::J) == 0, "bad loop size");
+    float KQ_max_new[cols_per_thread];
+#pragma unroll
+    for (int col = 0; col < cols_per_thread; ++col) {
+        KQ_max_new[col] = KQ_max[col];
+    }
+    float KQ_rowsum_add[cols_per_thread] = {0.0f};
+
+    if (ntiles == 1) {
+        if (ncols2 > 1 || mask_h2) {
+#pragma unroll
+            for (int i00 = 0; i00 < KQ_per_iter; i00 += np*tile_C_KQ::I) {
+                const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I;
+#pragma unroll
+                for (int l = 0; l < tile_C_KQ::ne; ++l) {
+                    const int i = i0 + tile_C_KQ::get_i(l);
+                    const int j = ((threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l)) / ncols2;
+
+                    KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope *
+                        __half2float(((const half *) tile_mask)[j*(KQ_per_iter + 8) + i]);
+                }
+            }
+        }
+
+        // Calculate softmax for each KQ column using the current max. value.
+        // The divisor is stored in KQ_rowsum and will be applied at the end.
+        static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size");
+#pragma unroll
+        for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ::I); ++k) {
+#pragma unroll
+            for (int l = 0; l < tile_C_KQ::ne; ++l) {
+                KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l]);
+            }
+        }
+
+        // Values per KQ column are spread across 8 threads, does not need full warp reduce:
 #pragma unroll
-        for (int i00 = 0; i00 < KQ_stride; i00 += np*tile_C_KQ::I) {
-            const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I;
+        for (int col = 0; col < cols_per_thread; ++col) {
+#pragma unroll
+            for (int offset = 16; offset >= 4; offset >>= 1) {
+                KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
+            }
+        }
+
+        static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size");
+
+#pragma unroll
+        for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ::I); ++k) {
 #pragma unroll
             for (int l = 0; l < tile_C_KQ::ne; ++l) {
-                const int i = i0 + tile_C_KQ::get_i(l);
-                const int j = (threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l);
+                KQ_C[k].x[l] = expf(KQ_C[k].x[l] - KQ_max_new[l % 2]);
 
-                KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope*__half2float(maskh[j*stride_mask + k_VKQ_0 + i]);
+                KQ_rowsum_add[l % 2] += KQ_C[k].x[l];
+            }
+        }
+    } else { // ntiles > 1
+        if (ncols2 > 1 || mask_h2) {
+#pragma unroll
+            for (int i00 = 0; i00 < KQ_per_iter; i00 += np*tile_C_KQ_16::J) {
+                const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ_16::J;
+#pragma unroll
+                for (int t = 0; t < ntiles/2; ++t) {
+#pragma unroll
+                    for (int l0 = 0; l0 < tile_C_KQ_16::ne; l0 += 2) {
+                        const int i = (i0 + tile_C_KQ_16::get_j(l0)) / 2;
+                        const int j = ((threadIdx.y / np)*cols_per_warp + t*tile_C_KQ_16::I + tile_C_KQ_16::get_i(l0)) / ncols2;
+
+                        const float2 tmp = __half22float2(tile_mask[j*(KQ_per_iter/2 + 4) + i]);
+                        const int KQ_index = i00/(np*tile_C_KQ_16::J) * ntiles/2 + t;
+                        KQ_C_16[KQ_index].x[l0 + 0] += slope*tmp.x;
+                        KQ_C_16[KQ_index].x[l0 + 1] += slope*tmp.y;
+                    }
+                }
             }
         }
-    }
 
-    // Calculate softmax for each KQ column using the current max. value.
-    // The divisor is stored in KQ_rowsum and will be applied at the end.
-    float2 KQ_max_new = KQ_max;
-    static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size");
+        // Calculate softmax for each KQ column using the current max. value.
+        // The divisor is stored in KQ_rowsum and will be applied at the end.
+        static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size");
+#pragma unroll
+        for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ_16::J); ++k) {
 #pragma unroll
-    for (int k = 0; k < KQ_stride/(np*tile_C_KQ::I); ++k) {
+            for (int t = 0; t < ntiles/2; ++t) {
 #pragma unroll
-        for (int l0 = 0; l0 < tile_C_KQ::ne; l0 += 2) {
-            KQ_max_new.x = fmaxf(KQ_max_new.x, KQ_C[k].x[l0 + 0]);
-            KQ_max_new.y = fmaxf(KQ_max_new.y, KQ_C[k].x[l0 + 1]);
+                for (int l = 0; l < tile_C_KQ_16::ne; ++l) {
+                    const int KQ_index = 2*t + (l/2) % 2;
+                    KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l]);
+                }
+            }
         }
-    }
 
-    // Values per KQ column are spread across 8 threads, does not need full warp reduce:
+        // Values per KQ column are spread across 4 threads, does not need full warp reduce:
 #pragma unroll
-    for (int offset = 16; offset > 2; offset >>= 1) {
-        KQ_max_new.x = fmaxf(KQ_max_new.x, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.x, offset, WARP_SIZE));
-        KQ_max_new.y = fmaxf(KQ_max_new.y, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.y, offset, WARP_SIZE));
-    }
+        for (int col = 0; col < cols_per_thread; ++col) {
+#pragma unroll
+            for (int offset = 2; offset >= 1; offset >>= 1) {
+                KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
+            }
+        }
 
-    float2 KQ_rowsum_add = make_float2(0.0f, 0.0f);
-    static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size");
+        static_assert(KQ_per_iter % (np*tile_C_KQ_16::J) == 0, "bad loop size");
+#pragma unroll
+        for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ_16::J); ++k) {
 #pragma unroll
-    for (int k = 0; k < KQ_stride/(np*tile_C_KQ::I); ++k) {
+            for (int t = 0; t < ntiles/2; ++t) {
 #pragma unroll
-        for (int l = 0; l < tile_C_KQ::ne; ++l) {
-            const float KQ_max_l = l % 2 == 0 ? KQ_max_new.x : KQ_max_new.y;
-            const float diff = KQ_C[k].x[l] - KQ_max_l;
-            KQ_C[k].x[l] = expf(diff);
+                for (int l = 0; l < tile_C_KQ_16::ne; ++l) {
+                    const int KQ_index = 2*t + (l/2) % 2;
 
-            if (l % 2 == 0) {
-                KQ_rowsum_add.x += KQ_C[k].x[l];
-            } else {
-                KQ_rowsum_add.y += KQ_C[k].x[l];
+                    KQ_C_16[k*ntiles/2 + t].x[l] = expf(KQ_C_16[k*ntiles/2 + t].x[l] - KQ_max_new[KQ_index]);
+
+                    KQ_rowsum_add[KQ_index] += KQ_C_16[k*ntiles/2 + t].x[l];
+                }
             }
         }
     }
 
     {
-        const float2 diff = make_float2(KQ_max.x - KQ_max_new.x, KQ_max.y - KQ_max_new.y);
-        const float2 KQ_max_scale = make_float2(expf(diff.x), expf(diff.y));
-        KQ_max = KQ_max_new;
+        float KQ_max_scale[cols_per_thread];
+#pragma unroll
+        for (int col = 0; col < cols_per_thread; ++col) {
+            KQ_max_scale[col] = expf(KQ_max[col] - KQ_max_new[col]);
+            KQ_max[col] = KQ_max_new[col];
 
-        // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
-        KQ_rowsum.x = KQ_max_scale.x*KQ_rowsum.x + KQ_rowsum_add.x;
-        KQ_rowsum.y = KQ_max_scale.y*KQ_rowsum.y + KQ_rowsum_add.y;
+            // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
+            KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col];
+        }
 
-        const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale.x, KQ_max_scale.y);
+        if (ntiles == 1) {
+            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
 #pragma unroll
-        for (int i = 0; i < D/tile_C_VKQ::I; ++i) {
+            for (int i = 0; i < D/tile_C_VKQ::I; ++i) {
 #pragma unroll
-            for (int l = 0; l < tile_C_VKQ::ne; ++l) {
-                VKQ_C[i].x[l] *= KQ_max_scale_h2;
+                for (int l = 0; l < tile_C_VKQ::ne; ++l) {
+                    VKQ_C[i].x[l] *= KQ_max_scale_h2;
+                }
+            }
+        } else {
+#pragma unroll
+            for (int col = 0; col < cols_per_thread; ++col) {
+                const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
+#pragma unroll
+                for (int i = 0; i < D/tile_C_VKQ_16::J; ++i) {
+#pragma unroll
+                    for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
+                        VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
+                    }
+                }
             }
         }
     }
 
     // Convert KQ C tiles into B tiles for VKQ calculation:
-    tile_B B[KQ_stride/(np*2*tile_B::J)];
-    static_assert(KQ_stride % (np*2*tile_B::J) == 0, "bad loop size");
+    tile_B B[KQ_per_iter/(np*2*tile_B::J) * ntiles];
+    tile_B_16 * B_16 = (tile_B_16 *) B;
+    static_assert(KQ_per_iter % (np*2*tile_B::J) == 0, "bad loop size");
+    if (ntiles == 1) {
 #pragma unroll
-    for (int k = 0; k < KQ_stride/(np*2*tile_B::J); ++k) {
-        B[k] = get_transposed(get_half2(KQ_C[k]));
+        for (int k = 0; k < KQ_per_iter/(np*2*tile_B::J); ++k) {
+            B[k] = get_transposed(get_half2(KQ_C[k]));
+        }
+    } else {
+        for (int k = 0; k < KQ_per_iter/(np*2*tile_B_16::J); ++k) {
+#pragma unroll
+            for (int t = 0; t < ntiles/2; ++t) {
+                B_16[k*ntiles/2 + t] = get_half2(KQ_C_16[k*ntiles/2 + t]);
+            }
+        }
     }
 
 #ifdef CP_ASYNC_AVAILABLE
+    // Preload K tile for next iteration:
     cp_async_wait_all();
     __syncthreads();
     if (!last_iter) {
-        flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(K_h2 + (k_VKQ_0 + KQ_stride)*stride_KV, tile_K, stride_KV);
+        if (ncols2 > 1 || mask_h2) {
+            flash_attn_ext_f16_load_mask<ncols1, nwarps, KQ_per_iter>(mask_h2 + (k_VKQ_0 + KQ_per_iter)/2, tile_mask, stride_mask);
+        }
+        flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(K_h2 + (k_VKQ_0 + KQ_per_iter)*stride_KV, tile_K, stride_KV);
     }
 #else
-    flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV);
+    flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV);
     __syncthreads();
 #endif // CP_ASYNC_AVAILABLE
 
     // Calculate VKQ tile:
 #pragma unroll
     for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += tile_C_VKQ::I) {
-        static_assert((KQ_stride/2) % (np*tile_A::J) == 0, "bad loop size");
+        static_assert((KQ_per_iter/2) % (np*tile_A::J) == 0, "bad loop size");
 #pragma unroll
-        for (int k00 = 0; k00 < KQ_stride/2; k00 += np*tile_A::J) {
+        for (int k00 = 0; k00 < KQ_per_iter/2; k00 += np*tile_A::J) {
             const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
 
             tile_A A;
             load_ldmatrix_trans(A, tile_V + 2*k0*D2_padded + i_VKQ_0/2, D2_padded);
-            mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
+            if (ntiles == 1) {
+                mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
+            } else {
+#pragma unroll
+                for (int t = 0; t < ntiles/2; ++t) {
+                    // Wide version of VKQ_C is column-major => swap A and B.
+                    mma(VKQ_C_16[i_VKQ_0/tile_C_VKQ::I * ntiles/2 + t], B_16[k00/(np*tile_A::J) * ntiles/2 + t], A);
+                }
+            }
         }
     }
 
@@ -247,12 +410,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 #endif // NEW_MMA_AVAILABLE
 }
 
-template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
+template<int D, int ncols1, int ncols2, int nwarps, int KQ_per_iter, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
 static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const float2 * const __restrict__ Q_f2,
         const half2  * const __restrict__ K_h2,
         const half2  * const __restrict__ V_h2,
-        const half   * const __restrict__ maskh,
+        const half2  * const __restrict__ mask_h2,
         float2       * const __restrict__ dstk,
         float2       * const __restrict__ dstk_fixup,
         const float scale,
@@ -260,7 +423,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const float logit_softcap,
         const int ne01,
         const int ne02,
-        const int stride_Q,
+        const int stride_Q1,
+        const int stride_Q2,
         const int stride_KV,
         const int stride_mask,
         const int jt,
@@ -269,63 +433,78 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
 #ifdef NEW_MMA_AVAILABLE
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
-    static_assert(nwarps*tile_B::I % ncols == 0, "bad nwarps");
-    constexpr int np = nwarps*tile_B::I / ncols; // Number of parallel CUDA warps per Q column.
+    constexpr int ncols           = ncols1 * ncols2;
+    constexpr int cols_per_warp   = ntiles * tile_B::I;
+    constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
+    constexpr int np              = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
+
+    static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps");
 
-    static_assert(D         % nwarps == 0, "bad D");
-    static_assert(KQ_stride % nwarps == 0, "bad KQ_stride");
+    static_assert(D           % nwarps == 0, "bad D");
+    static_assert(KQ_per_iter % nwarps == 0, "bad KQ_per_iter");
 
     constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
 
-    // Temporary shared buffer for loading K/V data with KQ_stride*D logical elements:
+    // Temporary shared buffer for loading K/V data with KQ_per_iter*D logical elements:
     extern __shared__ half2 tile_K[];
 #ifdef CP_ASYNC_AVAILABLE
-    half2 * tile_V = tile_K + KQ_stride*D2_padded;
+    half2 * tile_V    = tile_K + KQ_per_iter*D2_padded;
 #else
-    half2 * tile_V = tile_K;
+    half2 * tile_V    = tile_K;
 #endif // CP_ASYNC_AVAILABLE
+    half2 * tile_mask = tile_V + KQ_per_iter*D2_padded;
 
-    tile_B Q_B[D/(2*tile_B::J)];
-    tile_C_VKQ VKQ_C[D/tile_C_VKQ::I];
+    tile_B       Q_B[D/(2*tile_B::J) * ntiles];
+    tile_C_VKQ VKQ_C[D/tile_C_VKQ::I * ntiles];
 
-    float2 KQ_rowsum = {0.0f, 0.0f};
-    float2    KQ_max = {-FLT_MAX/2.0f, -FLT_MAX/2.0f};
+    tile_B_16     * Q_B_16   = (tile_B_16     *) Q_B;
+    tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C;
+
+    float KQ_rowsum[cols_per_thread] = {0.0f};
+    float KQ_max[cols_per_thread];
+#pragma unroll
+    for (int col = 0; col < cols_per_thread; ++col) {
+        KQ_max[col] = -FLT_MAX/2.0f;
+    }
 
     // Temporarily load Q data into tile_K, will be loaded into registers afterwards.
     // The loading is done with decreasing granularity for D for better memory bandwidth.
     const half2 scale_h2 = make_half2(scale, scale);
 #pragma unroll
     for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
-        const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
-        const int k0_stop  =                             D/2 - (D/2) % (1*stride_k);
-        const int stride_j = WARP_SIZE / stride_k;
+        const int k0_start  = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
+        const int k0_stop   =                             D/2 - (D/2) % (1*stride_k);
+        const int stride_jc = WARP_SIZE / stride_k;
 
         if (k0_start == k0_stop) {
             continue;
         }
 
-        if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
-            break;
-        }
-
 #pragma unroll
-        for (int j0 = 0; j0 < ncols; j0 += nwarps*stride_j) {
-            const int j = j0 + threadIdx.y*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+        for (int jc0 = 0; jc0 < ncols; jc0 += nwarps*stride_jc) {
+            const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+
+            if (jc0 + nwarps*stride_jc > ncols && jc >= ncols) {
+                break;
+            }
+
+            const int j = jc / ncols2;
+            const int c = jc % ncols2;
 
-            if (jt*ncols + j < ne01) {
+            if (jt*ncols1 + j < ne01) {
 #pragma unroll
                 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
                     const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
 
-                    const float2 tmp = Q_f2[(jt*ncols + j)*stride_Q + k];
-                    tile_K[j*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y);
+                    const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k];
+                    tile_K[jc*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y);
                 }
             } else {
 #pragma unroll
                 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
                     const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
 
-                    tile_K[j*D2_padded + k] = make_half2(0.0f, 0.0f);
+                    tile_K[jc*D2_padded + k] = make_half2(0.0f, 0.0f);
                 }
             }
         }
@@ -334,128 +513,217 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     __syncthreads();
 
     {
-        const int j0 = (threadIdx.y / np) * tile_B::I;
+        const int j0 = (threadIdx.y / np) * cols_per_warp;
 
 #pragma unroll
         for (int k0 = 0; k0 < D/2; k0 += tile_B::J) {
-            load_ldmatrix(Q_B[k0/tile_B::J], tile_K + j0*D2_padded + k0, D2_padded);
+            if (ntiles == 1) {
+                load_ldmatrix(Q_B[k0/tile_B::J], tile_K + j0*D2_padded + k0, D2_padded);
+            } else {
+#pragma unroll
+                for (int t = 0; t < ntiles/2; ++t) {
+                    load_ldmatrix(Q_B_16[k0/tile_B_16::J * ntiles/2 + t],
+                        tile_K + (j0 + t*tile_B_16::I)*D2_padded + k0, D2_padded);
+                }
+            }
         }
     }
 
     __syncthreads();
 
-    // Preload K data for first iteration when using cp_async:
+    // Preload mask and K data for first iteration when using cp_async:
 #ifdef CP_ASYNC_AVAILABLE
-    flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(K_h2 + kb0_start*KQ_stride*stride_KV, tile_K, stride_KV);
+    if (ncols2 > 1 || mask_h2) {
+        flash_attn_ext_f16_load_mask<ncols1, nwarps, KQ_per_iter>(mask_h2 + kb0_start*KQ_per_iter/2, tile_mask, stride_mask);
+    }
+    flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(K_h2 + kb0_start*KQ_per_iter*stride_KV, tile_K, stride_KV);
 #endif // CP_ASYNC_AVAILABLE
 
     // Iterate over ne11 == previous tokens:
     for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
         constexpr bool last_iter = false;
-        flash_attn_ext_f16_iter<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup, last_iter>
-            (Q_f2, K_h2, V_h2, maskh, dstk, dstk_fixup, scale, slope, logit_softcap,
-             ne01, ne02, stride_Q, stride_KV, stride_mask, jt, tile_K, tile_V, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
+        flash_attn_ext_f16_iter<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
+            (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
+             ne01, ne02, stride_KV, stride_mask, jt, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
     }
     { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
         constexpr bool last_iter = true;
-        flash_attn_ext_f16_iter<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup, last_iter>
-            (Q_f2, K_h2, V_h2, maskh, dstk, dstk_fixup, scale, slope, logit_softcap,
-             ne01, ne02, stride_Q, stride_KV, stride_mask, jt, tile_K, tile_V, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
+        flash_attn_ext_f16_iter<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
+            (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
+             ne01, ne02, stride_KV, stride_mask, jt, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
     }
 
     // With cp_async there is no __syncthreads at the end of the iter,
     //     there can be a race condition on shared memory access for combining/writing back results.
 #ifdef CP_ASYNC_AVAILABLE
-    if (nwarps*tile_B::I > KQ_stride) {
+    if (nwarps*cols_per_warp > KQ_per_iter) {
         __syncthreads();
     }
 #endif // CP_ASYNC_AVAILABLE
 
     // Finally, sum up partial KQ rowsums.
-    // The partial sums are spread across 8 threads each, does not need full reduce.
+    // The partial sums are spread across 8/4 threads each, does not need full reduce.
+    {
+        constexpr int offset_first = ntiles == 1 ? 16 : 2;
+        constexpr int offset_last  = ntiles == 1 ?  4 : 1;
+#pragma unroll
+        for (int col = 0; col < cols_per_thread; ++col) {
 #pragma unroll
-    for (int offset = 16; offset > 2; offset >>= 1) {
-        KQ_rowsum.x += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum.x, offset, WARP_SIZE);
-        KQ_rowsum.y += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum.y, offset, WARP_SIZE);
+            for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
+                KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE);
+            }
+        }
     }
 
     // Write VKQ accumulators to shared memory in column-major format.
     // It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
     // Also for np > 1 the combination is done via these values in shared memory.
-    const int j_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // j combine write data
+    if (ntiles == 1) {
+        const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data
 #pragma unroll
-    for (int k0 = 0; k0 < D/2; k0 += tile_B::J) {
-        const tile_B B = get_transposed(VKQ_C[k0/tile_B::J]); // Conversion of C to B matrix puts it in column-major format.
+        for (int k0 = 0; k0 < D/2; k0 += tile_B::J) {
+            const tile_B B = get_transposed(VKQ_C[k0/tile_B::J]); // Conversion of C to B matrix puts it in column-major format.
 
 #pragma unroll
-        for (int l = 0; l < tile_B::ne; ++l) {
-            const int k = k0 + tile_B::get_j(l);
+            for (int l = 0; l < tile_B::ne; ++l) {
+                const int k = k0 + tile_B::get_j(l);
 
-            tile_K[j_cwd*D2_padded + k] = B.x[l];
+                tile_K[jc_cwd*D2_padded + k] = B.x[l];
+            }
+        }
+    } else {
+#pragma unroll
+        for (int t = 0; t < ntiles/2; ++t) {
+            const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I;
+#pragma unroll
+            for (int k0 = 0; k0 < D/2; k0 += tile_C_VKQ_16::J) {
+#pragma unroll
+                for (int l = 0; l < tile_C_VKQ_16::ne; ++l) {
+                    const int j = j0 + tile_C_VKQ_16::get_i(l);
+                    const int k = k0 + tile_C_VKQ_16::get_j(l);
+
+                    tile_K[j*D2_padded + k] = VKQ_C_16[k0/tile_C_VKQ_16::J * ntiles/2 + t].x[l];
+                }
+            }
         }
     }
 
-    const int j_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // j combine write meta offset
-    const int j_cwm = threadIdx.y*(2*tile_C_VKQ::J) + 2*tile_C_VKQ::get_j(-1) + j_cwmo; // j combine write meta
-    const float2 KQ_cmr = make_float2(((const float *) &KQ_max)[j_cwmo], ((const float *) &KQ_rowsum)[j_cwmo]); // KQ combine max rowsum
+    if constexpr (ntiles == 1) {
+        const int jc_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // jc combine write meta offset
+        const int jc_cwm = threadIdx.y*(2*tile_C_VKQ::J) + 2*tile_C_VKQ::get_j(-1) + jc_cwmo; // jc combine write meta
+        const float2 KQ_cmr = make_float2(KQ_max[jc_cwmo], KQ_rowsum[jc_cwmo]); // KQ combine max rowsum
 
-    if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) {
-        // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
-        ((float2 *) tile_K)[j_cwm*(D2_padded/2) + D/4] = KQ_cmr;
-    }
+        if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) {
+            // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
+            ((float2 *) tile_K)[jc_cwm*(D2_padded/2) + D/4] = KQ_cmr;
+        }
 
-    __syncthreads();
+        __syncthreads();
 
-    static_assert(np == 1 || np == 2 || np == 4, "bad np");
-    if (np == 1) {
-        // No combination is needed, the meta data can be directly written from registers to VRAM.
-        if (needs_fixup && threadIdx.x < tile_B::I) {
-            float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
-            dstk_fixup_meta[j_cwm] = KQ_cmr;
+        if (np == 1) {
+            // No combination is needed, the meta data can be directly written from registers to VRAM.
+            if (needs_fixup && threadIdx.x < tile_B::I) {
+                float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
+                dstk_fixup_meta[jc_cwm] = KQ_cmr;
+            }
+            if (is_fixup && threadIdx.x < tile_B::I) {
+                float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
+                dstk_fixup_meta[jc_cwm] = KQ_cmr;
+            }
         }
-        if (is_fixup && threadIdx.x < tile_B::I) {
-            float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
-            dstk_fixup_meta[j_cwm] = KQ_cmr;
+    } else {
+        static_assert(ntiles == 2 || ntiles == 4, "bad ntiles");
+        const int jc_cwm = threadIdx.y*cols_per_warp // jc combine write meta
+            + (ntiles == 4 ? ((threadIdx.x % 4) / 2) * tile_C_VKQ_16::I : 0)
+            + tile_C_VKQ_16::get_i(threadIdx.x % 4);
+        const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]); // KQ combine max rowsum
+
+        if (((!needs_fixup && !is_fixup) || np > 1) && (ntiles == 4 || threadIdx.x % 4 < cols_per_thread)) {
+            // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
+            ((float2 *) tile_K)[jc_cwm*(D2_padded/2) + D/4] = KQ_cmr;
+        }
+
+        __syncthreads();
+
+        if (np == 1) {
+            // No combination is needed, the meta data can be directly written from registers to VRAM.
+            if (needs_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) {
+                float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
+                dstk_fixup_meta[jc_cwm] = KQ_cmr;
+            }
+            if (is_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) {
+                float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
+                dstk_fixup_meta[jc_cwm] = KQ_cmr;
+            }
         }
-    } else if (threadIdx.y % np == 0) {
+    }
+
+    static_assert(np == 1 || ntiles == 1 || ntiles == 2, "bad ntiles");
+    if (np > 1 && threadIdx.y % np == 0) {
         // Combine the meta data for parallel warps via shared memory.
         // Warps with threadIdx.y % np != 0 must NOT return early.
         // All threads must return simultaneously to avoid race conditions with work on the next tile.
 
-        float * meta_j = (float *) tile_K + (threadIdx.y*tile_B::I + threadIdx.x)*D2_padded + D/2;
+        constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1;
 
-        float KQ_cm = -FLT_MAX/2; // KQ combine max per parallel warp.
-        if (np*tile_B::I == WARP_SIZE || threadIdx.x < np*tile_B::I) {
-            KQ_cm = meta_j[0];
+        const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);
+        float2 * const meta_ptr = ((float2 *) tile_K) + jc_meta*(D2_padded/2) + D/4;
+        float2 meta[nmeta];
+#pragma unroll
+        for (int imeta = 0; imeta < nmeta; ++imeta) {
+            meta[imeta] = meta_ptr[imeta * WARP_SIZE * D2_padded/2];
         }
 
-        float KQ_cmn = KQ_cm; // KQ combine max new, max between all parallel warps.
+        float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps.
+#pragma unroll
+        for (int imeta = 1; imeta < nmeta; ++imeta) {
+            KQ_cmn = fmaxf(KQ_cmn, meta[imeta].x);
+        }
 #pragma unroll
-        for (int offset = np*tile_B::I/2; offset >= tile_B::I; offset >>= 1) {
+        for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
+            if (offset >= WARP_SIZE) {
+                continue;
+            }
             KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
         }
 
-        const float KQ_cms = expf(KQ_cm - KQ_cmn); // KQ combine max scale per warp.
-        float KQ_crs = 0.0f; // KQ combine rowsum, scaled sum of all parallel warps.
-        if (np*tile_B::I == WARP_SIZE || threadIdx.x < np*tile_B::I) {
-            KQ_crs = KQ_cms*meta_j[1];
+        float KQ_cms[nmeta]; // KQ combine max scale per warp.
+#pragma unroll
+        for (int imeta = 0; imeta < nmeta; ++imeta) {
+            KQ_cms[imeta] = expf(meta[imeta].x - KQ_cmn);
         }
+
+        float KQ_crs = KQ_cms[0]*meta[0].y; // KQ combine rowsum, scaled sum of all parallel warps.
 #pragma unroll
-        for (int offset = np*tile_B::I/2; offset >= tile_B::I; offset >>= 1) {
+        for (int imeta = 1; imeta < nmeta; ++imeta) {
+            KQ_crs += KQ_cms[imeta]*meta[imeta].y;
+        }
+#pragma unroll
+        for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
+            if (offset >= WARP_SIZE) {
+                continue;
+            }
             KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
         }
 
         // Write back combined meta data:
-        if (np*tile_B::I == WARP_SIZE || threadIdx.x < np*tile_B::I) {
-            *((float2 *) meta_j) = make_float2(KQ_cms, KQ_crs); // Combined KQ max scale + rowsum.
+#pragma unroll
+        for (int imeta = 0; imeta < nmeta; ++imeta) {
+            if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) {
+                // Combined KQ max scale + rowsum.
+                meta_ptr[imeta * WARP_SIZE * D2_padded/2] = make_float2(KQ_cms[imeta], KQ_crs);
+            }
         }
-        if (needs_fixup && threadIdx.x < tile_B::I) {
+
+        // Combined KQ max + rowsum.
+        static_assert(cols_per_warp <= WARP_SIZE);
+        if (needs_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) {
             float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
-            dstk_fixup_meta[(threadIdx.y/np)*tile_B::I + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
+            dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
         }
-        if (is_fixup && threadIdx.x < tile_B::I) {
+        if (is_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) {
             float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
-            dstk_fixup_meta[(threadIdx.y/np)*tile_B::I + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
+            dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
         }
     }
 
@@ -470,27 +738,32 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
 
 #pragma unroll
         for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
-            const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
-            const int k0_stop  =                             D/2 - (D/2) % (1*stride_k);
-            const int stride_j = WARP_SIZE / stride_k;
+            const int k0_start  = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
+            const int k0_stop   =                             D/2 - (D/2) % (1*stride_k);
+            const int stride_jc = WARP_SIZE / stride_k;
 
             if (k0_start == k0_stop) {
                 continue;
             }
 
-            if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
-                break;
-            }
-
 #pragma unroll
-            for (int j0_dst = 0; j0_dst < ncols; j0_dst += (nwarps/np)*stride_j) {
-                const int j_dst = j0_dst + (threadIdx.y/np)*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
-                const int j_tile_K = (j_dst/tile_B::I)*(np*tile_B::I) + j_dst % tile_B::I;
+            for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) {
+                const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+
+                if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) {
+                    break;
+                }
+
+                const int jc_tile_K = (jc_dst/cols_per_warp)*(np*cols_per_warp) + jc_dst % cols_per_warp;
+
+                const int j_dst = jc_dst / ncols2;
+                const int c_dst = jc_dst % ncols2;
 
-                if (!is_fixup && jt*ncols + j_dst >= ne01) {
+                if (!is_fixup && jt*ncols1 + j_dst >= ne01) {
                     continue;
                 }
-                const float * meta_j = (const float *) tile_K + j_tile_K*D2_padded + D/2;
+
+                const float * meta_j = (const float *) tile_K + jc_tile_K*D2_padded + D/2;
 #pragma unroll
                 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
                     const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
@@ -498,8 +771,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
                     float2 dstk_val = make_float2(0.0f, 0.0f);
 #pragma unroll
                     for (int ip = 0; ip < np; ++ip) {
-                        const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*tile_B::I*D2_padded + 0];
-                        const float2 dstk_val_add = __half22float2(tile_K[(j_tile_K + ip*tile_B::I)*D2_padded + k]);
+                        const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*cols_per_warp * D2_padded + 0];
+                        const float2 dstk_val_add = __half22float2(tile_K[(jc_tile_K + ip*cols_per_warp) * D2_padded + k]);
                         dstk_val.x += dstk_val_add.x*KQ_crs;
                         dstk_val.y += dstk_val_add.y*KQ_crs;
                     }
@@ -511,9 +784,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
                     }
 
                     if (is_fixup) {
-                        dstk_fixup_data[j_dst*(D/2) + k] = dstk_val;
+                        dstk_fixup_data[jc_dst*(D/2) + k] = dstk_val;
                     } else {
-                        dstk[(jt*ncols + j_dst)*ne02*(D/2) + k] = dstk_val;
+                        dstk[((jt*ncols1 + j_dst)*ne02 + c_dst)*(D/2) + k] = dstk_val;
                     }
                 }
             }
@@ -528,10 +801,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
 #endif // NEW_MMA_AVAILABLE
 }
 
-template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap>
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+template<int D, int ncols1, int ncols2, int nwarps, int KQ_per_iter, int ntiles, bool use_logit_softcap>
 __launch_bounds__(nwarps*WARP_SIZE, 2)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_ext_f16(
         const char * __restrict__ Q,
         const char * __restrict__ K,
@@ -579,20 +850,23 @@ static __global__ void flash_attn_ext_f16(
         return;
     }
 
-    static_assert(FATTN_KQ_STRIDE % KQ_stride == 0, "bad KQ_stride");
+    static_assert(FATTN_KQ_STRIDE % KQ_per_iter == 0, "bad KQ_per_iter");
 
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
 
-    const int stride_Q    = nb01 / sizeof(float2);
+    const int stride_Q1   = nb01 / sizeof(float2);
+    const int stride_Q2   = nb02 / sizeof(float2);
     const int stride_KV   = nb11 / sizeof(half2);
-    const int stride_mask = nb31 / sizeof(half);
+    const int stride_mask = nb31 / sizeof(half2);
+
+    const int iter_k = ne11 / FATTN_KQ_STRIDE;
+    const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
 
-    const int iter_k = ne11 / KQ_stride;
-    const int iter_j = (ne01 + (ncols - 1)) / ncols;
+    constexpr int kb_niter = FATTN_KQ_STRIDE / KQ_per_iter; // Number of kernel iterations per assigned KQ slice.
 
     // kbc == k block continuous, current index in continuous ijk space.
-    int       kbc      = (blockIdx.x + 0)*iter_k*iter_j*ne02 / gridDim.x;
-    const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*ne02 / gridDim.x;
+    int       kbc      = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+    const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
 
     // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
     // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
@@ -605,25 +879,28 @@ static __global__ void flash_attn_ext_f16(
         const int channel = kbc / (iter_k*iter_j);
         const int jt      = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
 
-        const float2 * Q_f2  = (const float2 *) (Q + nb02* channel);
-        const half2  * K_h2  = (const half2  *) (K + nb12*(channel / gqa_ratio));
-        const half2  * V_h2  = (const half2  *) (V + nb12*(channel / gqa_ratio)); // K and V have same shape
-        const half   * maskh = mask ? (const half  *) mask + (nb31/sizeof(half))*jt*ncols : nullptr;
-        float2       * dstk  = ((float2 *) dst) + channel*(D/2);
+        const float2 * Q_f2    = (const float2 *) (Q + nb02* channel*ncols2);
+        const half2  * K_h2    = (const half2  *) (K + nb12*(channel*ncols2 / gqa_ratio));
+        const half2  * V_h2    = (const half2  *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape
+        const half2  * mask_h2 = ncols2 > 1 || mask ? (const half2  *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
+        float2       * dstk    = ((float2 *) dst) + channel*(ncols2 * D/2);
 
-        const float slope = get_alibi_slope(max_bias, channel, n_head_log2, m0, m1);
+        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
+
+        const int kb0_start_kernel = kb0_start * kb_niter;
+        const int kb0_stop_kernel  = kb0_stop  * kb_niter;
 
         constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
         if (kb0_start == 0) {
             constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
-            flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
-                (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
-                 ne01, ne02, stride_Q, stride_KV, stride_mask, jt, kb0_start, kb0_stop);
+            flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
+                (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+                 ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
         } else {
             constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
-            flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
-                (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
-                 ne01, ne02, stride_Q, stride_KV, stride_mask, jt, kb0_start, kb0_stop);
+            flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
+                (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+                 ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
         }
 
         kbc += iter_k;
@@ -640,39 +917,46 @@ static __global__ void flash_attn_ext_f16(
     const int channel = kbc / (iter_k*iter_j);
     const int jt      = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
 
-    const float2 * Q_f2  = (const float2 *) (Q + nb02* channel);
-    const half2  * K_h2  = (const half2  *) (K + nb12*(channel / gqa_ratio));
-    const half2  * V_h2  = (const half2  *) (V + nb12*(channel / gqa_ratio)); // K and V have same shape
-    const half   * maskh = mask ? (const half  *) mask + (nb31/sizeof(half))*jt*ncols : nullptr;
-    float2       * dstk  = ((float2 *) dst) + channel*(D/2);
+    const float2 * Q_f2    = (const float2 *) (Q + nb02* channel*ncols2);
+    const half2  * K_h2    = (const half2  *) (K + nb12*(channel*ncols2 / gqa_ratio));
+    const half2  * V_h2    = (const half2  *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape
+    const half2  * mask_h2 = ncols2 > 1 || mask ? (const half2  *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
+    float2       * dstk    = ((float2 *) dst) + channel*(ncols2 * D/2);
+
+    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
 
-    const float slope = get_alibi_slope(max_bias, channel, n_head_log2, m0, m1);
+    const int kb0_start_kernel = kb0_start * kb_niter;
+    const int kb0_stop_kernel  = kb0_stop  * kb_niter;
 
     constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
     constexpr bool needs_fixup = false;
-    flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
-        (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
-         ne01, ne02, stride_Q, stride_KV, stride_mask, jt, kb0_start, kb0_stop);
+    flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
+        (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+         ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
 }
 
-template <int D, int cols_per_block>
+template <int D, int ncols1, int ncols2>
 void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    typedef tile<16, 8, half2> tile_A;
-    typedef tile< 8, 8, half2> tile_B;
+    constexpr int ncols         = ncols1 * ncols2;
+    constexpr int KQ_per_iter   = D <= 128 && ncols1 <= 64 ? 64 : 32;
+    constexpr int nwarps        = (KQ_per_iter == 32 && ncols <= 16) ? 2 : 4;
+    constexpr int ntiles        = ncols <= 8 ? 1 : (ncols <= 64 ? 2 : 4);
+    constexpr int cols_per_warp = ntiles * tile_B::I;
 
-    static_assert(D              % tile_B::J == 0, "bad D");
-    static_assert(cols_per_block % tile_B::I == 0, "bad cols_per_block");
+    static_assert(D     %    tile_B::J  == 0, "bad D");
+    static_assert(ncols % cols_per_warp == 0, "bad ncols");
 
     const ggml_tensor * KQV = dst;
-    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    const int id    = ggml_cuda_get_device();
+    const int cc    = ggml_cuda_info().devices[id].cc;
+
+    const int KQ_shared_rows = cp_async_available(cc) ? 2*KQ_per_iter : KQ_per_iter;
 
-    constexpr int KQ_stride = D <= 128 ? 64 : 32;
-    constexpr int nwarps    = (KQ_stride == 32 && cols_per_block <= 16) ?
-                              cols_per_block/tile_B::J * KQ_stride/tile_A::I : (cols_per_block <= 8 ? 4 : 8);
+    const size_t nbytes_shared_KV      = KQ_shared_rows       * (D           + 8) * sizeof(half);
+    const size_t nbytes_shared_mask    = ncols1               * (KQ_per_iter + 8) * sizeof(half);
+    const size_t nbytes_shared_combine = nwarps*cols_per_warp * (D           + 8) * sizeof(half);
 
-    const int    nrows_KQ      = cp_async_available(cc) ? 2*KQ_stride : KQ_stride;
-    const int    nrows_combine = nwarps*tile_B::J;
-    const size_t nbytes_shared = std::max(nrows_KQ, nrows_combine) * (D + 8) * sizeof(half);
+    const size_t nbytes_shared_total = std::max(nbytes_shared_KV + nbytes_shared_mask, nbytes_shared_combine);
 
     float logit_softcap;
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
@@ -680,42 +964,58 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
     fattn_kernel_t fattn_kernel;
     if (logit_softcap == 0.0f) {
         constexpr bool use_logit_softcap = false;
-        fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, KQ_stride, use_logit_softcap>;
+        fattn_kernel = flash_attn_ext_f16<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap>;
     } else {
         constexpr bool use_logit_softcap = true;
-        fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, KQ_stride, use_logit_softcap>;
+        fattn_kernel = flash_attn_ext_f16<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap>;
     }
-    launch_fattn<D, cols_per_block, 0, KQ_stride>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
+
+    launch_fattn<D, ncols1, ncols2, 0, KQ_per_iter>(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, true, true);
 }
 
-#define DECL_FATTN_MMA_F16_CASE(D, cols_per_block)                          \
+
+#define DECL_FATTN_MMA_F16_CASE(D, ncols1, ncols2)                          \
     template void ggml_cuda_flash_attn_ext_mma_f16_case                     \
-    <D, cols_per_block>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
-
-extern DECL_FATTN_MMA_F16_CASE( 64,  8);
-extern DECL_FATTN_MMA_F16_CASE( 80,  8);
-extern DECL_FATTN_MMA_F16_CASE( 96,  8);
-extern DECL_FATTN_MMA_F16_CASE(112,  8);
-extern DECL_FATTN_MMA_F16_CASE(128,  8);
-extern DECL_FATTN_MMA_F16_CASE(256,  8);
-
-extern DECL_FATTN_MMA_F16_CASE( 64, 16);
-extern DECL_FATTN_MMA_F16_CASE( 80, 16);
-extern DECL_FATTN_MMA_F16_CASE( 96, 16);
-extern DECL_FATTN_MMA_F16_CASE(112, 16);
-extern DECL_FATTN_MMA_F16_CASE(128, 16);
-extern DECL_FATTN_MMA_F16_CASE(256, 16);
-
-extern DECL_FATTN_MMA_F16_CASE( 64, 32);
-extern DECL_FATTN_MMA_F16_CASE( 80, 32);
-extern DECL_FATTN_MMA_F16_CASE( 96, 32);
-extern DECL_FATTN_MMA_F16_CASE(112, 32);
-extern DECL_FATTN_MMA_F16_CASE(128, 32);
-extern DECL_FATTN_MMA_F16_CASE(256, 32);
-
-extern DECL_FATTN_MMA_F16_CASE( 64, 64);
-extern DECL_FATTN_MMA_F16_CASE( 80, 64);
-extern DECL_FATTN_MMA_F16_CASE( 96, 64);
-extern DECL_FATTN_MMA_F16_CASE(112, 64);
-extern DECL_FATTN_MMA_F16_CASE(128, 64);
-extern DECL_FATTN_MMA_F16_CASE(256, 64);
+    <D, ncols1, ncols2>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
+
+#define DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(D, ncols) \
+    extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/1, 1); \
+    extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/2, 2); \
+    extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/4, 4); \
+    extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/8, 8); \
+
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,   8);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,   8);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,   8);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112,   8);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128,   8);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256,   8);
+
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  16);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  16);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  16);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112,  16);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128,  16);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256,  16);
+
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  32);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  32);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  32);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112,  32);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128,  32);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256,  32);
+
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  64);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  64);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  64);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112,  64);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128,  64);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256,  64);
+
+// Kernels with ncols == 128 are only 4% faster due to register pressure.
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 128);
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 128);
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 128);
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 128);
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128);
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 128); // Needs too much shared memory.
index d4edbad07f26a2842299e3bcb8a5df480ef74dc9..b8b415effb7e181c02d2a204e3b5b7eb6d4c9c61 100644 (file)
@@ -302,14 +302,14 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
             constexpr int    nwarps        = 8;
             constexpr size_t nbytes_shared = 0;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
-            launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
+            launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
         } break;
         case 128: {
             constexpr int    D             = 128;
             constexpr int    nwarps        = 8;
             constexpr size_t nbytes_shared = 0;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
-            launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
+            launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
         } break;
         default: {
             GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
index 0d274f33255b7ee03f4f7aaa82f49557fe4c565c..4352a284464764800f0528589a610a6012661238 100644 (file)
@@ -296,14 +296,14 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
             constexpr int    nwarps        = 8;
             constexpr size_t nbytes_shared = 0;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
-            launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
+            launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
         } break;
         case 128: {
             constexpr int    D             = 128;
             constexpr int    nwarps        = 8;
             constexpr size_t nbytes_shared = 0;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
-            launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
+            launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
         } break;
         default: {
             GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
index d9ac4424606e40cfad933ebbb8082ad5900c8352..e758a0f6ec276eca1e6f831e09e499ee88e9ef99 100644 (file)
@@ -310,7 +310,7 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx,
     constexpr bool need_f16_K = D != 128;
     constexpr bool need_f16_V = D != 128 && D != 64;
     constexpr size_t nbytes_shared = 0;
-    launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
+    launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
 }
 
 template <int D, ggml_type type_K, ggml_type type_V>
index 6ef8f9dcc27564b019dee207c25e4013012c7df2..134144a383ffaaa31bf5ba7acf3643d880d73d59 100644 (file)
@@ -290,7 +290,7 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx,
     constexpr bool need_f16_K = D != 128;
     constexpr bool need_f16_V = D != 128 && D != 64;
     constexpr size_t nbytes_shared = 0;
-    launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
+    launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
 }
 
 template <int D, ggml_type type_K, ggml_type type_V>
index 45702ad651fe67e906d35832b86863220c81e191..de38470abec456ae3a23a91a13e90ff3ee1b8cd6 100644 (file)
@@ -478,7 +478,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
             fattn_kernel = flash_attn_ext_f16<
                 D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
         }
-        launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
+        launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
         return;
     }
     if (2*blocks_num_pb1 < 2*nsm) {
@@ -493,7 +493,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
             fattn_kernel = flash_attn_ext_f16<
                 D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
         }
-        launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
+        launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
         return;
     }
     constexpr int parallel_blocks = 1;
@@ -507,7 +507,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
         fattn_kernel = flash_attn_ext_f16<
             D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
     }
-    launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
+    launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
 }
 
 void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
index b0cf152f52cf1bb5449342b1a3c95ffe22e082c3..b1becccb4de72f432fbd7690be0e3c6171562298 100644 (file)
@@ -8,28 +8,50 @@
 #include "fattn-wmma-f16.cuh"
 #include "fattn.cuh"
 
-template <int cols_per_block>
+template <int D, int ncols2>
+static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+
+    if (Q->ne[1] <= 8/ncols2) {
+        ggml_cuda_flash_attn_ext_mma_f16_case<D, 8/ncols2, ncols2>(ctx, dst);
+        return;
+    }
+
+    if (Q->ne[1] <= 16/ncols2) {
+        ggml_cuda_flash_attn_ext_mma_f16_case<D, 16/ncols2, ncols2>(ctx, dst);
+        return;
+    }
+
+    if (Q->ne[1] <= 32/ncols2) {
+        ggml_cuda_flash_attn_ext_mma_f16_case<D, 32/ncols2, ncols2>(ctx, dst);
+        return;
+    }
+
+    ggml_cuda_flash_attn_ext_mma_f16_case<D, 64/ncols2, ncols2>(ctx, dst);
+}
+
+template <int ncols2>
 static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
 
     switch (Q->ne[0]) {
         case 64:
-            ggml_cuda_flash_attn_ext_mma_f16_case< 64, cols_per_block>(ctx, dst);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 64, ncols2>(ctx, dst);
             break;
         case 80:
-            ggml_cuda_flash_attn_ext_mma_f16_case< 80, cols_per_block>(ctx, dst);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 80, ncols2>(ctx, dst);
             break;
         case 96:
-            ggml_cuda_flash_attn_ext_mma_f16_case< 96, cols_per_block>(ctx, dst);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 96, ncols2>(ctx, dst);
             break;
         case 112:
-            ggml_cuda_flash_attn_ext_mma_f16_case<112, cols_per_block>(ctx, dst);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<112, ncols2>(ctx, dst);
             break;
         case 128:
-            ggml_cuda_flash_attn_ext_mma_f16_case<128, cols_per_block>(ctx, dst);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst);
             break;
         case 256:
-            ggml_cuda_flash_attn_ext_mma_f16_case<256, cols_per_block>(ctx, dst);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst);
             break;
         default:
             GGML_ABORT("fatal error");
@@ -38,24 +60,35 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context
 }
 
 static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * Q = dst->src[0];
+    const ggml_tensor * KQV  = dst;
+    const ggml_tensor * Q    = dst->src[0];
+    const ggml_tensor * K    = dst->src[1];
+    const ggml_tensor * mask = dst->src[3];
+
+    float max_bias = 0.0f;
+    memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
+
+    const float use_gqa_opt = mask && max_bias == 0.0f;
+
+    GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
+    const int gqa_ratio = Q->ne[2] / K->ne[2];
 
-    if (Q->ne[1] <= 8) {
+    if (use_gqa_opt && gqa_ratio % 8 == 0) {
         ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst);
         return;
     }
 
-    if (Q->ne[1] <= 16) {
-        ggml_cuda_flash_attn_ext_mma_f16_switch_hs<16>(ctx, dst);
+    if (use_gqa_opt && gqa_ratio == 4) {
+        ggml_cuda_flash_attn_ext_mma_f16_switch_hs<4>(ctx, dst);
         return;
     }
 
-    if (Q->ne[1] <= 32) {
-        ggml_cuda_flash_attn_ext_mma_f16_switch_hs<32>(ctx, dst);
+    if (use_gqa_opt && gqa_ratio == 2) {
+        ggml_cuda_flash_attn_ext_mma_f16_switch_hs<2>(ctx, dst);
         return;
     }
 
-    ggml_cuda_flash_attn_ext_mma_f16_switch_hs<64>(ctx, dst);
+    ggml_cuda_flash_attn_ext_mma_f16_switch_hs<1>(ctx, dst);
 }
 
 #define FATTN_VEC_F16_CASE(D, type_K, type_V)                               \
@@ -209,8 +242,11 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
 }
 
 void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * KQV = dst;
-    const ggml_tensor * Q   = dst->src[0];
+    const ggml_tensor * KQV  = dst;
+    const ggml_tensor * Q    = dst->src[0];
+    const ggml_tensor * K    = dst->src[1];
+    const ggml_tensor * V    = dst->src[2];
+    const ggml_tensor * mask = dst->src[3];
 
     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
@@ -252,7 +288,10 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         return;
     }
 
-    if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
+    const int gqa_ratio = Q->ne[2] / K->ne[2];
+    const bool mma_fast_for_bs1 = fp16_mma_available(cc) && gqa_ratio % 2 == 0 &&
+        K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && mask;
+    if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0 && !mma_fast_for_bs1) {
         if (prec == GGML_PREC_DEFAULT) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
             return;
index 0a5656e4cb3101630835e9c5716b84e46fefcd20..9206bfeba3de40d63bc25f3adc374ba8c546f570 100644 (file)
@@ -73,6 +73,8 @@ namespace ggml_cuda_mma {
                 return threadIdx.x / 4;
             } else if constexpr (I == 16 && J == 8) {
                 return (l / 2) * 8 + threadIdx.x / 4;
+            } else if constexpr (I == 16 && J == 16) {
+                return ((l / 2) % 2) * 8 + threadIdx.x / 4;
             } else {
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
             }
@@ -85,6 +87,8 @@ namespace ggml_cuda_mma {
                 return 4 * l + threadIdx.x % 4;
             } else if constexpr (I == 16 && J == 8) {
                 return 2 * (threadIdx.x % 4) + l % 2;
+            } else if constexpr (I == 16 && J == 16) {
+                return 8 * (l / 4) + 2 * (threadIdx.x % 4) + l % 2;
             } else {
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
             }
@@ -289,6 +293,42 @@ namespace ggml_cuda_mma {
 #endif // NEW_MMA_AVAILABLE
     }
 
+    static __device__ __forceinline__ void mma(
+            tile<16, 8, half2> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
+#ifdef NEW_MMA_AVAILABLE
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int       * Dxi = (int       *) D.x;
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+        asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2]));
+        asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
+            : "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
+#else
+        // On Turing m16n8k16 mma is not available, use 4x m8n8k8 mma instead:
+        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
+            : "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
+            : "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#else
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
+        NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
+    }
+
     static __device__ __forceinline__ void mma(
             tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
 #ifdef NEW_MMA_AVAILABLE
@@ -316,4 +356,39 @@ namespace ggml_cuda_mma {
 #endif // NEW_MMA_AVAILABLE
     }
 
+    static __device__ __forceinline__ void mma(
+            tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
+#ifdef NEW_MMA_AVAILABLE
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int       * Dxi = (int       *) D.x;
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+        asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2]));
+        asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
+#else
+        // On Turing m16n8k16 mma is not available, use 4x m8n8k8 mma instead:
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#else
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
+        NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
+    }
 }
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb16.cu
deleted file mode 100644 (file)
index f09bdef..0000000
+++ /dev/null
@@ -1,10 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-mma-f16.cuh"
-
-DECL_FATTN_MMA_F16_CASE(64, 16);
-DECL_FATTN_MMA_F16_CASE(80, 16);
-DECL_FATTN_MMA_F16_CASE(96, 16);
-DECL_FATTN_MMA_F16_CASE(112, 16);
-DECL_FATTN_MMA_F16_CASE(128, 16);
-DECL_FATTN_MMA_F16_CASE(256, 16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb32.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb32.cu
deleted file mode 100644 (file)
index 2211088..0000000
+++ /dev/null
@@ -1,10 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-mma-f16.cuh"
-
-DECL_FATTN_MMA_F16_CASE(64, 32);
-DECL_FATTN_MMA_F16_CASE(80, 32);
-DECL_FATTN_MMA_F16_CASE(96, 32);
-DECL_FATTN_MMA_F16_CASE(112, 32);
-DECL_FATTN_MMA_F16_CASE(128, 32);
-DECL_FATTN_MMA_F16_CASE(256, 32);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb64.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb64.cu
deleted file mode 100644 (file)
index d24b085..0000000
+++ /dev/null
@@ -1,10 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-mma-f16.cuh"
-
-DECL_FATTN_MMA_F16_CASE(64, 64);
-DECL_FATTN_MMA_F16_CASE(80, 64);
-DECL_FATTN_MMA_F16_CASE(96, 64);
-DECL_FATTN_MMA_F16_CASE(112, 64);
-DECL_FATTN_MMA_F16_CASE(128, 64);
-DECL_FATTN_MMA_F16_CASE(256, 64);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb8.cu
deleted file mode 100644 (file)
index bdf86c0..0000000
+++ /dev/null
@@ -1,10 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-mma-f16.cuh"
-
-DECL_FATTN_MMA_F16_CASE(64, 8);
-DECL_FATTN_MMA_F16_CASE(80, 8);
-DECL_FATTN_MMA_F16_CASE(96, 8);
-DECL_FATTN_MMA_F16_CASE(112, 8);
-DECL_FATTN_MMA_F16_CASE(128, 8);
-DECL_FATTN_MMA_F16_CASE(256, 8);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu
new file mode 100644 (file)
index 0000000..8010861
--- /dev/null
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 1, 8);
+DECL_FATTN_MMA_F16_CASE(80, 1, 8);
+DECL_FATTN_MMA_F16_CASE(96, 1, 8);
+DECL_FATTN_MMA_F16_CASE(112, 1, 8);
+DECL_FATTN_MMA_F16_CASE(128, 1, 8);
+DECL_FATTN_MMA_F16_CASE(256, 1, 8);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu
new file mode 100644 (file)
index 0000000..66161c0
--- /dev/null
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 16, 1);
+DECL_FATTN_MMA_F16_CASE(80, 16, 1);
+DECL_FATTN_MMA_F16_CASE(96, 16, 1);
+DECL_FATTN_MMA_F16_CASE(112, 16, 1);
+DECL_FATTN_MMA_F16_CASE(128, 16, 1);
+DECL_FATTN_MMA_F16_CASE(256, 16, 1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu
new file mode 100644 (file)
index 0000000..ee88c72
--- /dev/null
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 16, 2);
+DECL_FATTN_MMA_F16_CASE(80, 16, 2);
+DECL_FATTN_MMA_F16_CASE(96, 16, 2);
+DECL_FATTN_MMA_F16_CASE(112, 16, 2);
+DECL_FATTN_MMA_F16_CASE(128, 16, 2);
+DECL_FATTN_MMA_F16_CASE(256, 16, 2);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
new file mode 100644 (file)
index 0000000..d888a5a
--- /dev/null
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 16, 4);
+DECL_FATTN_MMA_F16_CASE(80, 16, 4);
+DECL_FATTN_MMA_F16_CASE(96, 16, 4);
+DECL_FATTN_MMA_F16_CASE(112, 16, 4);
+DECL_FATTN_MMA_F16_CASE(128, 16, 4);
+DECL_FATTN_MMA_F16_CASE(256, 16, 4);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
new file mode 100644 (file)
index 0000000..d93a2d0
--- /dev/null
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 2, 4);
+DECL_FATTN_MMA_F16_CASE(80, 2, 4);
+DECL_FATTN_MMA_F16_CASE(96, 2, 4);
+DECL_FATTN_MMA_F16_CASE(112, 2, 4);
+DECL_FATTN_MMA_F16_CASE(128, 2, 4);
+DECL_FATTN_MMA_F16_CASE(256, 2, 4);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu
new file mode 100644 (file)
index 0000000..617464c
--- /dev/null
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 2, 8);
+DECL_FATTN_MMA_F16_CASE(80, 2, 8);
+DECL_FATTN_MMA_F16_CASE(96, 2, 8);
+DECL_FATTN_MMA_F16_CASE(112, 2, 8);
+DECL_FATTN_MMA_F16_CASE(128, 2, 8);
+DECL_FATTN_MMA_F16_CASE(256, 2, 8);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu
new file mode 100644 (file)
index 0000000..970d2b6
--- /dev/null
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 32, 1);
+DECL_FATTN_MMA_F16_CASE(80, 32, 1);
+DECL_FATTN_MMA_F16_CASE(96, 32, 1);
+DECL_FATTN_MMA_F16_CASE(112, 32, 1);
+DECL_FATTN_MMA_F16_CASE(128, 32, 1);
+DECL_FATTN_MMA_F16_CASE(256, 32, 1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu
new file mode 100644 (file)
index 0000000..65cd377
--- /dev/null
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 32, 2);
+DECL_FATTN_MMA_F16_CASE(80, 32, 2);
+DECL_FATTN_MMA_F16_CASE(96, 32, 2);
+DECL_FATTN_MMA_F16_CASE(112, 32, 2);
+DECL_FATTN_MMA_F16_CASE(128, 32, 2);
+DECL_FATTN_MMA_F16_CASE(256, 32, 2);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu
new file mode 100644 (file)
index 0000000..f4a8bf3
--- /dev/null
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 4, 2);
+DECL_FATTN_MMA_F16_CASE(80, 4, 2);
+DECL_FATTN_MMA_F16_CASE(96, 4, 2);
+DECL_FATTN_MMA_F16_CASE(112, 4, 2);
+DECL_FATTN_MMA_F16_CASE(128, 4, 2);
+DECL_FATTN_MMA_F16_CASE(256, 4, 2);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
new file mode 100644 (file)
index 0000000..de191a8
--- /dev/null
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 4, 4);
+DECL_FATTN_MMA_F16_CASE(80, 4, 4);
+DECL_FATTN_MMA_F16_CASE(96, 4, 4);
+DECL_FATTN_MMA_F16_CASE(112, 4, 4);
+DECL_FATTN_MMA_F16_CASE(128, 4, 4);
+DECL_FATTN_MMA_F16_CASE(256, 4, 4);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu
new file mode 100644 (file)
index 0000000..e8cb0e1
--- /dev/null
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 4, 8);
+DECL_FATTN_MMA_F16_CASE(80, 4, 8);
+DECL_FATTN_MMA_F16_CASE(96, 4, 8);
+DECL_FATTN_MMA_F16_CASE(112, 4, 8);
+DECL_FATTN_MMA_F16_CASE(128, 4, 8);
+DECL_FATTN_MMA_F16_CASE(256, 4, 8);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu
new file mode 100644 (file)
index 0000000..a532e96
--- /dev/null
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 64, 1);
+DECL_FATTN_MMA_F16_CASE(80, 64, 1);
+DECL_FATTN_MMA_F16_CASE(96, 64, 1);
+DECL_FATTN_MMA_F16_CASE(112, 64, 1);
+DECL_FATTN_MMA_F16_CASE(128, 64, 1);
+DECL_FATTN_MMA_F16_CASE(256, 64, 1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu
new file mode 100644 (file)
index 0000000..bf25181
--- /dev/null
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 8, 1);
+DECL_FATTN_MMA_F16_CASE(80, 8, 1);
+DECL_FATTN_MMA_F16_CASE(96, 8, 1);
+DECL_FATTN_MMA_F16_CASE(112, 8, 1);
+DECL_FATTN_MMA_F16_CASE(128, 8, 1);
+DECL_FATTN_MMA_F16_CASE(256, 8, 1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu
new file mode 100644 (file)
index 0000000..378c132
--- /dev/null
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 8, 2);
+DECL_FATTN_MMA_F16_CASE(80, 8, 2);
+DECL_FATTN_MMA_F16_CASE(96, 8, 2);
+DECL_FATTN_MMA_F16_CASE(112, 8, 2);
+DECL_FATTN_MMA_F16_CASE(128, 8, 2);
+DECL_FATTN_MMA_F16_CASE(256, 8, 2);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
new file mode 100644 (file)
index 0000000..372641b
--- /dev/null
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 8, 4);
+DECL_FATTN_MMA_F16_CASE(80, 8, 4);
+DECL_FATTN_MMA_F16_CASE(96, 8, 4);
+DECL_FATTN_MMA_F16_CASE(112, 8, 4);
+DECL_FATTN_MMA_F16_CASE(128, 8, 4);
+DECL_FATTN_MMA_F16_CASE(256, 8, 4);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu
new file mode 100644 (file)
index 0000000..9ff5968
--- /dev/null
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 8, 8);
+DECL_FATTN_MMA_F16_CASE(80, 8, 8);
+DECL_FATTN_MMA_F16_CASE(96, 8, 8);
+DECL_FATTN_MMA_F16_CASE(112, 8, 8);
+DECL_FATTN_MMA_F16_CASE(128, 8, 8);
+DECL_FATTN_MMA_F16_CASE(256, 8, 8);
index a2628f16e57d1fe85a755649c42b32e3804acd69..dd373a09d26f3a19c09daad3d953c7ed3fd26f3e 100755 (executable)
@@ -18,7 +18,7 @@ SOURCE_FATTN_MMA_START = """// This file has been autogenerated by generate_cu_f
 
 """
 
-SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size}, {cols_per_block});\n"
+SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size}, {ncols1}, {ncols2});\n"
 
 TYPES_MMQ = [
     "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
@@ -57,12 +57,18 @@ for vkq_size in [16, 32]:
                 with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f:
                     f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v))
 
-for cols_per_block in [8, 16, 32, 64]:
-    with open(f"fattn-mma-f16-instance-cpb{cols_per_block}.cu", "w") as f:
-        f.write(SOURCE_FATTN_MMA_START)
-
-        for head_size in [64, 80, 96, 112, 128, 256]:
-            f.write(SOURCE_FATTN_MMA_CASE.format(cols_per_block=cols_per_block, head_size=head_size))
+for ncols in [8, 16, 32, 64, 128]:
+    for ncols2 in [1, 2, 4, 8]:
+        ncols1 = ncols // ncols2
+        if ncols == 128:
+            continue  # Too much register pressure.
+        with open(f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu", "w") as f:
+            f.write(SOURCE_FATTN_MMA_START)
+
+            for head_size in [64, 80, 96, 112, 128, 256]:
+                if ncols == 128 and head_size == 256:
+                    continue  # Needs too much shared memory.
+                f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size=head_size))
 
 for type in TYPES_MMQ:
     with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:
index c9ab6c135f325608fd7268631e48d778fe56640a..e1f7e6758b62efcc35ef451cecea3e5902b7b4f9 100644 (file)
@@ -3119,6 +3119,7 @@ struct test_leaky_relu : public test_case {
 struct test_flash_attn_ext : public test_case {
     const int64_t hs; // head size
     const int64_t nh; // num heads
+    const int64_t nr; // repeat in Q, tests for grouped-query attention
     const int64_t kv; // kv size
     const int64_t nb; // batch size
 
@@ -3131,7 +3132,7 @@ struct test_flash_attn_ext : public test_case {
     std::array<int32_t, 4> permute;
 
     std::string vars() override {
-        return VARS_TO_STR9(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV, permute);
+        return VARS_TO_STR10(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV, permute);
     }
 
     double max_nmse_err() override {
@@ -3142,13 +3143,13 @@ struct test_flash_attn_ext : public test_case {
         GGML_UNUSED(t);
         // Just counting matmul costs:
         // Q*K^T is nb x hs x kv, P*V is nb x kv x hs, per head
-        return 2 * 2 * nh * nb * hs * kv;
+        return 2 * 2 * nh*nr * nb * hs * kv;
     }
 
-    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8,
+    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t nr = 1, int64_t kv = 96, int64_t nb = 8,
                         bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16,
                         std::array<int32_t, 4> permute = {0, 1, 2, 3})
-        : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV), permute(permute) {}
+        : hs(hs), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV), permute(permute) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));
@@ -3166,13 +3167,13 @@ struct test_flash_attn_ext : public test_case {
             return t;
         };
 
-        ggml_tensor * q = create_permuted(GGML_TYPE_F32, hs_padded, nb, nh, 1);
+        ggml_tensor * q = create_permuted(GGML_TYPE_F32, hs_padded, nb, nh*nr, 1);
         ggml_set_name(q, "q");
 
-        ggml_tensor * k = create_permuted(type_KV,       hs_padded, kv, nh, 1);
+        ggml_tensor * k = create_permuted(type_KV,       hs_padded, kv, nh,    1);
         ggml_set_name(k, "k");
 
-        ggml_tensor * v = create_permuted(type_KV,       hs_padded, kv, nh, 1);
+        ggml_tensor * v = create_permuted(type_KV,       hs_padded, kv, nh,    1);
         ggml_set_name(v, "v");
 
         ggml_tensor * m = nullptr;
@@ -4278,14 +4279,18 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                 if (!mask && max_bias > 0.0f) continue;
                 for (float logit_softcap : {0.0f, 10.0f}) {
                     if (hs != 128 && logit_softcap != 0.0f) continue;
-                    for (int nh : { 32, }) {
-                        for (int kv : { 512, 1024, }) {
-                            for (int nb : { 1, 3, 32, 35, }) {
-                                for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
-                                    test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV));
-                                    // run fewer test cases permuted
-                                    if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
-                                        test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV, {0, 2, 1, 3}));
+                    for (int nh : { 4, }) {
+                        for (int nr : { 1, 4, 16 }) {
+                            if (nr == 16 && hs != 128) continue;
+                            for (int kv : { 512, 1024, }) {
+                                if (nr != 1 && kv != 512) continue;
+                                for (int nb : { 1, 3, 32, 35, }) {
+                                    for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
+                                        test_cases.emplace_back(new test_flash_attn_ext(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV));
+                                        // run fewer test cases permuted
+                                        if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
+                                            test_cases.emplace_back(new test_flash_attn_ext(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV, {0, 2, 1, 3}));
+                                        }
                                     }
                                 }
                             }