]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CUDA: faster tile FA, add oob checks, more HSs (llama/16492)
authorJohannes Gäßler <redacted>
Sat, 11 Oct 2025 18:54:32 +0000 (20:54 +0200)
committerGeorgi Gerganov <redacted>
Wed, 15 Oct 2025 06:29:17 +0000 (09:29 +0300)
18 files changed:
ggml/src/ggml-cuda/CMakeLists.txt
ggml/src/ggml-cuda/common.cuh
ggml/src/ggml-cuda/fattn-common.cuh
ggml/src/ggml-cuda/fattn-tile.cu
ggml/src/ggml-cuda/fattn-tile.cuh
ggml/src/ggml-cuda/fattn-wmma-f16.cuh
ggml/src/ggml-cuda/fattn.cu
ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu [new file with mode: 0644]
ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu [new file with mode: 0644]
ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu [new file with mode: 0644]
ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu [new file with mode: 0644]
ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu [new file with mode: 0644]
ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu [new file with mode: 0644]
ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu [new file with mode: 0644]
ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu [new file with mode: 0644]
ggml/src/ggml-cuda/template-instances/generate_cu_files.py
ggml/src/ggml-hip/CMakeLists.txt
ggml/src/ggml-musa/CMakeLists.txt

index bdcefe7b7ed7a8e8ab4fdd5e85a609a3df7176d6..3024775135966133a9afe7857f687aa75640acf1 100644 (file)
@@ -44,6 +44,8 @@ if (CUDAToolkit_FOUND)
     list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")
 
     file(GLOB   GGML_SOURCES_CUDA "*.cu")
+    file(GLOB   SRCS "template-instances/fattn-tile*.cu")
+    list(APPEND GGML_SOURCES_CUDA ${SRCS})
     file(GLOB   SRCS "template-instances/fattn-mma*.cu")
     list(APPEND GGML_SOURCES_CUDA ${SRCS})
     file(GLOB   SRCS "template-instances/mmq*.cu")
index d51abbeafa944fb966f864095c30634e29e1be82..e0abde5427c832852f1803b70f1167986af4f05e 100644 (file)
@@ -245,7 +245,8 @@ static bool fp16_available(const int cc) {
 }
 
 static bool fast_fp16_available(const int cc) {
-    return (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
+    return GGML_CUDA_CC_IS_AMD(cc) ||
+        (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610);
 }
 
 // To be used for feature selection of external libraries, e.g. cuBLAS.
@@ -571,6 +572,10 @@ static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v,
 }
 
 // Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes, especially on AMD.
+// Important: do not use this function if dst and src both point at registers.
+//     Due to the strict aliasing rule the compiler can do incorrect optimizations if src and dst have different types.
+//     The function is intended for copies between registers and SRAM/VRAM to make the compiler emit the right instructions.
+//     If dst and src point at different address spaces then they are guaranteed to not be aliased.
 template <int nbytes, int alignment = 0>
 static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) {
     if constexpr (alignment != 0) {
index 33d2f0f49e3de37489083ecc6bf1aa0ec6e5a839..bc0c2523cc82f7b790c603c67a5055f5eb435466 100644 (file)
@@ -793,8 +793,6 @@ void launch_fattn(
     GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
         "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
 
-    GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
-
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
     const int id  = ggml_cuda_get_device();
@@ -878,7 +876,7 @@ void launch_fattn(
     // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
     // Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
     //     multiple sequences of possibly different lengths.
-    if (mask && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
+    if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
         const int s31 = mask->nb[1] / sizeof(half2);
         const int s33 = mask->nb[3] / sizeof(half2);
 
@@ -916,8 +914,7 @@ void launch_fattn(
 
         dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
     } else {
-        GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
-        const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
+        const int ntiles_KQ = (K->ne[1] + KQ_row_granularity - 1) / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
 
         // parallel_blocks must not be larger than what the tensor size allows:
         parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
@@ -946,7 +943,7 @@ void launch_fattn(
 
         blocks_num.x = ntiles_x;
         blocks_num.y = parallel_blocks;
-        blocks_num.z = Q->ne[2]*Q->ne[3];
+        blocks_num.z = (Q->ne[2]/ncols2)*Q->ne[3];
 
         if (parallel_blocks > 1) {
             dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
index 68de623d803499cbb602f4385105212b39e5496a..3a5806d9091d763d344817498d378d018de71baa 100644 (file)
 #include "common.cuh"
-#include "fattn-common.cuh"
 #include "fattn-tile.cuh"
 #include "fattn-wmma-f16.cuh"
 
-// kq_stride == number of KQ rows to process per iteration
-// kq_nbatch == number of K columns to load in parallel for KQ calculation
-
-static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int cc, const int warp_size) {
-    if (GGML_CUDA_CC_IS_AMD(cc)) {
-        if (GGML_CUDA_CC_IS_RDNA(cc)) {
-            switch (D) {
-                case 64:
-                    return 128;
-                case 128:
-                case 256:
-                    return ncols <= 16 ? 128 : 64;
-                default:
-                    GGML_ABORT("fatal error");
-                    return -1;
-            }
-        }
-        switch (D) {
-            case 64:
-                return ncols == 32 ? 128 : 64;
-            case 128:
-                return ncols == 32 ? 64 : 32;
-            case 256:
-                return 32;
-            default:
-                GGML_ABORT("fatal error");
-                return -1;
-        }
-    }
-    if (fast_fp16_available(cc)) {
-        switch (D) {
-            case 64:
-            case 128:
-            case 256:
-                return ncols <= 16 ? 128 : 64;
-            default:
-                GGML_ABORT("fatal error");
-                return -1;
-        }
-    }
-    switch (D) {
-        case 64:
-            return ncols <= 16 ? 128 : 64;
-        case 128:
-            return ncols <= 16 ? 64 : 32;
-        case 256:
-            return 32;
-        default:
-            GGML_ABORT("fatal error");
-            return -1;
-    }
-    GGML_UNUSED(warp_size);
-}
-
-static constexpr __device__ int fattn_tile_get_kq_stride_device(int D, int ncols, int warp_size) {
-#ifdef GGML_USE_HIP
-#ifdef RDNA
-    switch (D) {
-        case 64:
-            return 128;
-        case 128:
-        case 256:
-            return ncols <= 16 ? 128 : 64;
-        default:
-            return -1;
-    }
-#else
-    switch (D) {
-        case 64:
-            return ncols == 32 ? 128 : 64;
-        case 128:
-            return ncols == 32 ? 64 : 32;
-        case 256:
-            return 32;
-        default:
-            return -1;
-    }
-#endif // RDNA
-#else
-#ifdef FAST_FP16_AVAILABLE
-    switch (D) {
-        case 64:
-        case 128:
-        case 256:
-            return ncols <= 16 ? 128 : 64;
-        default:
-            return -1;
-    }
-#else
-    switch (D) {
-        case 64:
-            return ncols <= 16 ? 128 : 64;
-        case 128:
-            return ncols <= 16 ? 64 : 32;
-        case 256:
-            return 32;
-        default:
-            return -1;
-    }
-#endif // FAST_FP16_AVAILABLE
-#endif // GGML_USE_HIP
-    GGML_UNUSED_VARS(ncols, warp_size);
-}
-
-static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols, int warp_size) {
-#ifdef GGML_USE_HIP
-    switch (D) {
-        case 64:
-            return 64;
-        case 128:
-        case 256:
-            return 128;
-        default:
-            return -1;
-    }
-#else
-#ifdef FAST_FP16_AVAILABLE
-    switch (D) {
-        case 64:
-            return 64;
-        case 128:
-        case 256:
-            return 128;
-        default:
-            return -1;
-    }
-#else
-    switch (D) {
-        case 64:
-            return 64;
-        case 128:
-            return 128;
-        case 256:
-            return ncols <= 16 ? 128 : 64;
-        default:
-            return -1;
-    }
-#endif // FAST_FP16_AVAILABLE
-#endif // GGML_USE_HIP
-    GGML_UNUSED_VARS(ncols, warp_size);
-}
-
-static int fattn_tile_get_nthreads_host(const int cc, const int ncols) {
-    return 256;
-    GGML_UNUSED_VARS(cc, ncols);
-}
-
-static constexpr __device__ int fattn_tile_get_nthreads_device(int ncols) {
-    return 256;
-    GGML_UNUSED(ncols);
-}
-
-static constexpr __device__ int fattn_tile_get_occupancy_device(int ncols) {
-#ifdef RDNA
-    return 3;
-#else
-    return ncols <= 16 ? 3 : 2;
-#endif // RDNA
-    GGML_UNUSED(ncols);
-}
-
-template<int D, int ncols, bool use_logit_softcap> // D == head size
-__launch_bounds__(fattn_tile_get_nthreads_device(ncols), fattn_tile_get_occupancy_device(ncols))
-static __global__ void flash_attn_tile(
-        const char * __restrict__ Q,
-        const char * __restrict__ K,
-        const char * __restrict__ V,
-        const char * __restrict__ mask,
-        const char * __restrict__ sinks,
-        const int  * __restrict__ KV_max,
-        float      * __restrict__ dst,
-        float2     * __restrict__ dst_meta,
-        const float scale,
-        const float max_bias,
-        const float m0,
-        const float m1,
-        const uint32_t n_head_log2,
-        const float logit_softcap,
-        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
-                            const int32_t nb01, const int32_t nb02, const int32_t nb03,
-        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
-                            const int32_t nb11, const int32_t nb12, const int64_t nb13,
-                            const int32_t nb21, const int32_t nb22, const int64_t nb23,
-                            const int32_t ne31, const int32_t ne32, const int32_t ne33,
-                            const int32_t nb31, const int32_t nb32, const int64_t nb33) {
-#ifdef FLASH_ATTN_AVAILABLE
-
-    // Skip unused kernel variants for faster compilation:
-#ifdef GGML_USE_WMMA_FATTN
-    NO_DEVICE_CODE;
-    return;
-#endif // GGML_USE_WMMA_FATTN
-
-    if (use_logit_softcap && !(D == 128 || D == 256)) {
-        GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
-            max_bias, m0, m1, n_head_log2, logit_softcap,
-            ne00, ne01, ne02, ne03,
-                  nb01, nb02, nb03,
-            ne10, ne11, ne12, ne13,
-                  nb11, nb12, nb13,
-                  nb21, nb22, nb23,
-                  ne31, ne32, ne33,
-                  nb31, nb32, nb33);
-        NO_DEVICE_CODE;
-        return;
-    }
-
-    constexpr int warp_size = 32;
-    constexpr int nwarps    = fattn_tile_get_nthreads_device(ncols) / warp_size;
-    constexpr int kq_stride = fattn_tile_get_kq_stride_device(D, ncols, warp_size);
-    static_assert(kq_stride % warp_size == 0, "kq_stride not divisable by warp_size.");
-    constexpr int kq_nbatch = fattn_tile_get_kq_nbatch_device(D, ncols, warp_size);
-    static_assert(kq_nbatch % (2*warp_size) == 0, "bad kq_nbatch");
-
-    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
-
-    const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
-
-    const int sequence = blockIdx.z / ne02;
-    const int head = blockIdx.z - sequence*ne02;
-    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float * Q_f    = (const float *) (Q    + nb03* sequence         + nb02* head              + nb01*ic0);
-    const half2 * K_h2   = (const half2 *) (K    + nb13* sequence         + nb12*(head / gqa_ratio));
-    const half2 * V_h2   = (const half2 *) (V    + nb13* sequence         + nb12*(head / gqa_ratio)); // K and V have same shape
-    const half  * maskh  = (const half  *) (mask + nb33*(sequence % ne33)                           + nb31*ic0);
-    const float * sinksf = (const float *) (sinks);
-
-    const int stride_KV2 = nb11 / sizeof(half2);
-
-    const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
-
-    constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
-    constexpr int cpy_ne = cpy_nb / 4;
-
-    constexpr int cpw = ncols/nwarps; // cols per warp
-
-    // softmax_iter_j == number of KQ columns for which to calculate softmax in parallel.
-    // KQ is originall 2D but uses a Z-shaped memory pattern for larger reads/writes.
-#ifdef FAST_FP16_AVAILABLE
-    constexpr int softmax_iter_j = cpw < 2*cpy_ne ? cpw : 2*cpy_ne;
-
-    __shared__ half  KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j];
-    __shared__ half2 Q_tmp[ncols][D/2];
-    __shared__ half2 KV_tmp[kq_stride * (kq_nbatch/2 + cpy_ne)]; // Padded to avoid memory bank conflicts.
-    half2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
-#else
-    constexpr int softmax_iter_j = cpw < 1*cpy_ne ? cpw : 1*cpy_ne;
-
-    __shared__ float KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j];
-    __shared__ float Q_tmp[ncols][D];
-    __shared__ float KV_tmp[kq_stride * (kq_nbatch + cpy_ne)]; // Padded to avoid memory bank conflicts.
-    float2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
-#endif // FAST_FP16_AVAILABLE
-    static_assert(cpw % softmax_iter_j == 0, "bad softmax_iter_j");
-
-    float KQ_max[cpw];
-#pragma unroll
-    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
-        KQ_max[j0/nwarps] = -FLT_MAX/2.0f;
-    }
-    float KQ_sum[cpw] = {0.0f};
-
-    // Load Q data, convert to FP16 if fast.
-#pragma unroll
-    for (int j0 = 0; j0 < cpw; ++j0) {
-        const int j = j0 + threadIdx.y*cpw;
-
-        constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
-
-#pragma unroll
-        for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
-            float tmp_f[cpy_ne_D] = {0.0f};
-            if (ic0 + j < ne01) {
-                ggml_cuda_memcpy_1<sizeof(tmp_f)>(tmp_f, &Q_f[j*(nb01/sizeof(float)) + i0 + threadIdx.x*cpy_ne_D]);
-            }
-
-#pragma unroll
-            for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
-                tmp_f[i1] *= scale;
-            }
-
-#ifdef FAST_FP16_AVAILABLE
-            half2 tmp_h2[cpy_ne_D/2];
-#pragma unroll
-            for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) {
-                tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]);
-            }
-            ggml_cuda_memcpy_1<sizeof(tmp_h2)>(&Q_tmp[j][i0/2 + threadIdx.x*(cpy_ne_D/2)], tmp_h2);
-#else
-            ggml_cuda_memcpy_1<sizeof(tmp_f)> (&Q_tmp[j][i0   + threadIdx.x* cpy_ne_D],    tmp_f);
-#endif // FAST_FP16_AVAILABLE
-        }
-    }
-
-    __syncthreads();
-
-    // Main loop over KV cache:
-    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
-    for (int k_VKQ_0 = blockIdx.y*kq_stride; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*kq_stride) {
-        // Calculate KQ tile and keep track of new maximum KQ values:
-
-        float KQ_max_new[cpw];
-#pragma unroll
-        for (int j = 0; j < cpw; ++j) {
-            KQ_max_new[j] = KQ_max[j];
-        }
-
-        float KQ_acc[kq_stride/warp_size][cpw] = {{0.0f}}; // Accumulators for KQ matrix multiplication.
-
-        // KQ = K @ Q matrix multiplication:
-#pragma unroll
-        for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += kq_nbatch) {
-#pragma unroll
-            for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += nwarps) {
-                const int i_KQ = i_KQ_0 + threadIdx.y;
-
-#ifdef FAST_FP16_AVAILABLE
-                constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/(2*warp_size) ? cpy_ne : kq_nbatch/(2*warp_size);
-#pragma unroll
-                for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += warp_size*cpy_ne_kqnb) {
-                    ggml_cuda_memcpy_1<cpy_ne_kqnb*4>(
-                        &KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb],
-                        &K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1 + threadIdx.x*cpy_ne_kqnb]);
-                }
-#else
-                constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/warp_size ? cpy_ne : kq_nbatch/warp_size;
-#pragma unroll
-                for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += warp_size*cpy_ne_kqnb) {
-                    half2 tmp_h2[cpy_ne_kqnb/2];
-                    ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
-                        tmp_h2, &K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1/2 + threadIdx.x*(cpy_ne_kqnb/2)]);
-
-                    float2 tmp_f2[cpy_ne_kqnb/2];
-#pragma unroll
-                    for (int k_KQ_2 = 0; k_KQ_2 < cpy_ne_kqnb/2; ++k_KQ_2) {
-                        tmp_f2[k_KQ_2] = __half22float2(tmp_h2[k_KQ_2]);
-                    }
-                    ggml_cuda_memcpy_1<sizeof(tmp_f2)>(
-                        &KV_tmp[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb], tmp_f2);
-                }
-#endif // FAST_FP16_AVAILABLE
-            }
-
-            __syncthreads();
-
-#ifdef FAST_FP16_AVAILABLE
-#pragma unroll
-            for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += cpy_ne) {
-                half2 K_k[kq_stride/warp_size][cpy_ne];
-                half2 Q_k[cpw][cpy_ne];
-#else
-#pragma unroll
-            for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += cpy_ne) {
-                float K_k[kq_stride/warp_size][cpy_ne];
-                float Q_k[cpw][cpy_ne];
-#endif // FAST_FP16_AVAILABLE
-
-#pragma unroll
-                for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
-                    const int i_KQ = i_KQ_0 + threadIdx.x;
-
-#ifdef FAST_FP16_AVAILABLE
-                    ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1]);
-#else
-                    ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch   + cpy_ne) + k_KQ_1]);
-#endif // FAST_FP16_AVAILABLE
-                }
-#pragma unroll
-                for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
-                    const int j_KQ = j_KQ_0 + threadIdx.y*cpw;
-
-#ifdef FAST_FP16_AVAILABLE
-                    ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1]);
-#else
-                    ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0   + k_KQ_1]);
-#endif // FAST_FP16_AVAILABLE
-                }
-
-#pragma unroll
-                for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
-#pragma unroll
-                    for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
-#pragma unroll
-                        for (int k = 0; k < cpy_ne; ++k) {
-                            ggml_cuda_mad(KQ_acc[i_KQ_0/warp_size][j_KQ_0], K_k[i_KQ_0/warp_size][k], Q_k[j_KQ_0][k]);
-                        }
-                    }
-                }
-            }
-
-            if (k_KQ_0 + kq_nbatch < D) {
-                __syncthreads(); // Sync not needed on last iteration.
-            }
-        }
-
-        // Apply logit softcap, mask, update KQ_max:
-#pragma unroll
-        for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
-            const int i_KQ = i_KQ_0 + threadIdx.x;
-
-#pragma unroll
-            for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
-                const int j_KQ = j_KQ_0 + threadIdx.y*cpw;
-
-                if (use_logit_softcap) {
-                    KQ_acc[i_KQ_0/warp_size][j_KQ_0] = logit_softcap * tanhf(KQ_acc[i_KQ_0/warp_size][j_KQ_0]);
-                }
-
-                KQ_acc[i_KQ_0/warp_size][j_KQ_0] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
-
-                KQ_max_new[j_KQ_0] = fmaxf(KQ_max_new[j_KQ_0], KQ_acc[i_KQ_0/warp_size][j_KQ_0]);
-            }
-        }
-
-        __syncthreads();
-
-        // Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators:
-#pragma unroll
-        for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
-#ifdef FAST_FP16_AVAILABLE
-            half  tmp[kq_stride/warp_size][softmax_iter_j];
-#else
-            float tmp[kq_stride/warp_size][softmax_iter_j];
-#endif // FAST_FP16_AVAILABLE
-
-#pragma unroll
-            for (int j1 = 0; j1 < softmax_iter_j; ++j1) {
-                KQ_max_new[j0+j1] = warp_reduce_max<warp_size>(KQ_max_new[j0+j1]);
-                const float KQ_max_scale = expf(KQ_max[j0+j1] - KQ_max_new[j0+j1]);
-                KQ_max[j0+j1] = KQ_max_new[j0+j1];
-
-                float KQ_sum_add = 0.0f;
-#pragma unroll
-                for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
-                    const float val = expf(KQ_acc[i0/warp_size][j0+j1] - KQ_max[j0+j1]);
-                    KQ_sum_add += val;
-                    tmp[i0/warp_size][j1] = val;
-                }
-                KQ_sum[j0+j1] = KQ_sum[j0+j1]*KQ_max_scale + KQ_sum_add;
-
-#ifdef FAST_FP16_AVAILABLE
-                const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
-#pragma unroll
-                for (int i0 = 0; i0 < D/2; i0 += warp_size) {
-                    VKQ[j0+j1][i0/warp_size] *= KQ_max_scale_h2;
-                }
-#else
-#pragma unroll
-                for (int i0 = 0; i0 < D/2; i0 += warp_size) {
-                    VKQ[j0+j1][i0/warp_size].x *= KQ_max_scale;
-                    VKQ[j0+j1][i0/warp_size].y *= KQ_max_scale;
-                }
-#endif // FAST_FP16_AVAILABLE
-            }
-
-#pragma unroll
-            for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
-                const int i = i0 + threadIdx.x;
-
-                ggml_cuda_memcpy_1<sizeof(tmp[0])>(
-                    KQ[j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j)][i], tmp[i0/warp_size]);
-            }
-        }
-
-        // VKQ = V @ KQ matrix multiplication:
-        constexpr int V_cols_per_iter = kq_stride*kq_nbatch / D; // Number of V columns that fit in SRAM for K.
-        static_assert(kq_stride % V_cols_per_iter == 0, "bad V_cols_per_iter");
-#pragma unroll
-        for (int k0 = 0; k0 < kq_stride; k0 += V_cols_per_iter) {
-#pragma unroll
-            for (int k1 = 0; k1 < V_cols_per_iter; k1 += nwarps) {
-                const int k_tile = k1 + threadIdx.y;
-
-#ifdef FAST_FP16_AVAILABLE
-                constexpr int cpy_ne_D = cpy_ne < D/(2*warp_size) ? cpy_ne : D/(2*warp_size);
-#pragma unroll
-                for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
-                    ggml_cuda_memcpy_1<cpy_ne_D*4>(
-                        &KV_tmp[k_tile*(D/2) + i0 + threadIdx.x*cpy_ne_D],
-                        &V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0 + threadIdx.x*cpy_ne_D]);
-                }
-#else
-                constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
-#pragma unroll
-                for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
-                    half2 tmp_h2[cpy_ne_D/2];
-                    ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
-                        tmp_h2, &V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0/2 + threadIdx.x*(cpy_ne_D/2)]);
-
-                    float2 tmp_f2[cpy_ne_D/2];
-#pragma unroll
-                    for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) {
-                        tmp_f2[i1] = __half22float2(tmp_h2[i1]);
-                    }
-                    ggml_cuda_memcpy_1<sizeof(tmp_f2)>(
-                        &KV_tmp[k_tile*D + i0 + threadIdx.x*cpy_ne_D], tmp_f2);
-                }
-#endif // FAST_FP16_AVAILABLE
-            }
-
-            __syncthreads();
-
-#ifdef FAST_FP16_AVAILABLE
-#pragma unroll
-            for (int k1 = 0; k1 < V_cols_per_iter; ++k1) {
-                half2 V_k[(D/2)/warp_size];
-                half2 KQ_k[cpw];
-
-                constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? cpy_ne/2 : (D/2)/warp_size;
-#pragma unroll
-                for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
-                    ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/warp_size], &KV_tmp[k1*(D/2) + i0 + threadIdx.x*cpy_ne_D]);
-                }
-#pragma unroll
-                for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
-                    const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j);
-
-                    half tmp[softmax_iter_j];
-                    ggml_cuda_memcpy_1<softmax_iter_j*sizeof(half)>(
-                        &tmp, KQ[j][k0 + k1]);
-#pragma unroll
-                    for (int j1 = 0; j1 < softmax_iter_j; ++j1) {
-                        KQ_k[j0+j1] = __half2half2(tmp[j1]);
-                    }
-                }
-
-#pragma unroll
-                for (int i0 = 0; i0 < D/2; i0 += warp_size) {
-#pragma unroll
-                    for (int j0 = 0; j0 < cpw; ++j0) {
-                        VKQ[j0][i0/warp_size] += V_k[i0/warp_size]*KQ_k[j0];
-                    }
-                }
-            }
-#else
-#pragma unroll
-            for (int k1 = 0; k1 < V_cols_per_iter; ++k1) {
-                float2 V_k[(D/2)/warp_size];
-                float  KQ_k[cpw];
-
-                constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
-#pragma unroll
-                for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
-                    ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/(2*warp_size)], &KV_tmp[k1*D + i0 + threadIdx.x*cpy_ne_D]);
-                }
-#pragma unroll
-                for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
-                    const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j);
-
-                    ggml_cuda_memcpy_1<softmax_iter_j*sizeof(float)>(
-                        &KQ_k[j0], KQ[j][k0 + k1]);
-                }
-
-#pragma unroll
-                for (int i0 = 0; i0 < D/2; i0 += warp_size) {
-#pragma unroll
-                    for (int j0 = 0; j0 < cpw; ++j0) {
-                        VKQ[j0][i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[j0];
-                        VKQ[j0][i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[j0];
-                    }
-                }
-            }
-#endif // FAST_FP16_AVAILABLE
-
-            __syncthreads();
-        }
-    }
-
-
-    // Attention sink: adjust running max and sum once per head
-    if (sinksf && blockIdx.y == 0) {
-        const float sink = sinksf[head];
-
-#pragma unroll
-        for (int j0 = 0; j0 < cpw; ++j0) {
-            float KQ_max_new_j = fmaxf(KQ_max[j0], sink);
-            KQ_max_new_j = warp_reduce_max<warp_size>(KQ_max_new_j);
-
-            const float KQ_max_scale = expf(KQ_max[j0] - KQ_max_new_j);
-            KQ_max[j0] = KQ_max_new_j;
-
-            const float val = expf(sink - KQ_max[j0]);
-            KQ_sum[j0] = KQ_sum[j0] * KQ_max_scale;
-            if (threadIdx.x == 0) {
-                KQ_sum[j0] += val;
-            }
-
-#ifdef FAST_FP16_AVAILABLE
-            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
-#pragma unroll
-            for (int i0 = 0; i0 < D/2; i0 += warp_size) {
-                VKQ[j0][i0/warp_size] *= KQ_max_scale_h2;
-            }
-#else
-#pragma unroll
-            for (int i0 = 0; i0 < D/2; i0 += warp_size) {
-                VKQ[j0][i0/warp_size].x *= KQ_max_scale;
-                VKQ[j0][i0/warp_size].y *= KQ_max_scale;
-            }
-#endif // FAST_FP16_AVAILABLE
-        }
-    }
-
-#pragma unroll
-    for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
-        KQ_sum[j_VKQ_0] = warp_reduce_sum<warp_size>(KQ_sum[j_VKQ_0]);
-    }
-    if (gridDim.y == 1) {
-#pragma unroll
-        for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
-#ifdef FAST_FP16_AVAILABLE
-            const half2 KQ_sum_j_inv = make_half2(1.0f/KQ_sum[j_VKQ_0], 1.0f/KQ_sum[j_VKQ_0]);
-#pragma unroll
-            for (int i = 0; i < (D/2)/warp_size; ++i) {
-                VKQ[j_VKQ_0][i] *= KQ_sum_j_inv;
-            }
-#else
-            const float KQ_sum_j_inv = 1.0f/KQ_sum[j_VKQ_0];
-#pragma unroll
-            for (int i = 0; i < (D/2)/warp_size; ++i) {
-                VKQ[j_VKQ_0][i].x *= KQ_sum_j_inv;
-                VKQ[j_VKQ_0][i].y *= KQ_sum_j_inv;
-            }
-#endif // FAST_FP16_AVAILABLE
-        }
-    }
-
-    // Write back results:
-#pragma unroll
-    for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
-        const int j_VKQ = j_VKQ_0 + threadIdx.y*cpw;
-
-        if (ic0 + j_VKQ >= ne01) {
-            return;
-        }
-
-        const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
-
-#ifdef FAST_FP16_AVAILABLE
-        constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? cpy_ne/2 : (D/2)/warp_size;
-#pragma unroll
-        for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
-            float2 tmp[cpy_ne_D];
-#pragma unroll
-            for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
-                tmp[i1] = __half22float2(VKQ[j_VKQ_0][i0/warp_size + i1]);
-            }
-            ggml_cuda_memcpy_1<sizeof(tmp)>(&dst[j_dst_unrolled*D + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp);
-        }
-#else
-        constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
-#pragma unroll
-        for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
-            ggml_cuda_memcpy_1<cpy_ne_D*4>(
-                &dst[j_dst_unrolled*D + i0 + threadIdx.x*cpy_ne_D], &VKQ[j_VKQ_0][i0/(2*warp_size)]);
-        }
-#endif // FAST_FP16_AVAILABLE
-
-        if (gridDim.y != 1 && threadIdx.x == 0) {
-            dst_meta[j_dst_unrolled] = make_float2(KQ_max[j_VKQ_0], KQ_sum[j_VKQ_0]);
-        }
-    }
-#else
-    GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
-        max_bias, m0, m1, n_head_log2, logit_softcap,
-        ne00, ne01, ne02, ne03,
-              nb01, nb02, nb03,
-        ne10, ne11, ne12, ne13,
-              nb11, nb12, nb13,
-              nb21, nb22, nb23,
-              ne31, ne32, ne33,
-              nb31, nb32, nb33);
-    NO_DEVICE_CODE;
-#endif // FLASH_ATTN_AVAILABLE
-}
-
-template <int D, bool use_logit_softcap>
-static void launch_fattn_tile_switch_ncols(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * Q = dst->src[0];
-
-    const int id        = ggml_cuda_get_device();
-    const int cc        = ggml_cuda_info().devices[id].cc;
-    const int warp_size = 32;
-
-    constexpr size_t nbytes_shared = 0;
-
-#ifdef GGML_USE_HIP
-    if constexpr (D <= 128) {
-        if (Q->ne[1] > 32) {
-            constexpr int cols_per_block = 64;
-            const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
-            fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
-            const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
-            launch_fattn<D, cols_per_block, 1>
-                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size);
-            return;
-        }
-    }
-#endif // GGML_USE_HIP
-
-    if (Q->ne[1] > 16) {
-        constexpr int cols_per_block = 32;
-        const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
-        fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
-        const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
-        launch_fattn<D, cols_per_block, 1>
-            (ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size);
-        return;
-    }
-
-    constexpr int cols_per_block = 16;
-    const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
-    fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
-    const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
-    launch_fattn<D, cols_per_block, 1>
-        (ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size);
-}
-
-template <bool use_logit_softcap>
-static void launch_fattn_tile_switch_head_size(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * Q = dst->src[0];
-    switch (Q->ne[0]) {
+void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * K = dst->src[1];
+    const ggml_tensor * V = dst->src[2];
+    switch (K->ne[0]) {
+        case  40: {
+            GGML_ASSERT(V->ne[0] == K->ne[0]);
+            ggml_cuda_flash_attn_ext_tile_case< 40,  40>(ctx, dst);
+        } break;
         case  64: {
-            launch_fattn_tile_switch_ncols< 64, use_logit_softcap>(ctx, dst);
+            GGML_ASSERT(V->ne[0] == K->ne[0]);
+            ggml_cuda_flash_attn_ext_tile_case< 64,  64>(ctx, dst);
+        } break;
+        case  80: {
+            GGML_ASSERT(V->ne[0] == K->ne[0]);
+            ggml_cuda_flash_attn_ext_tile_case< 80,  80>(ctx, dst);
+        } break;
+        case  96: {
+            GGML_ASSERT(V->ne[0] == K->ne[0]);
+            ggml_cuda_flash_attn_ext_tile_case< 96,  96>(ctx, dst);
+        } break;
+        case 112: {
+            GGML_ASSERT(V->ne[0] == K->ne[0]);
+            ggml_cuda_flash_attn_ext_tile_case<112, 112>(ctx, dst);
         } break;
         case 128: {
-            launch_fattn_tile_switch_ncols<128, use_logit_softcap>(ctx, dst);
+            GGML_ASSERT(V->ne[0] == K->ne[0]);
+            ggml_cuda_flash_attn_ext_tile_case<128, 128>(ctx, dst);
         } break;
         case 256: {
-            launch_fattn_tile_switch_ncols<256, use_logit_softcap>(ctx, dst);
+            GGML_ASSERT(V->ne[0] == K->ne[0]);
+            ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst);
+        } break;
+        case 576: {
+            GGML_ASSERT(V->ne[0] == 512);
+            ggml_cuda_flash_attn_ext_tile_case<576, 512>(ctx, dst);
         } break;
         default: {
             GGML_ABORT("Unsupported head size");
         } break;
     }
 }
-
-void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * KQV = dst;
-
-    float logit_softcap;
-    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
-
-    if (logit_softcap == 0.0f) {
-        constexpr bool use_logit_softcap = false;
-        launch_fattn_tile_switch_head_size<use_logit_softcap>(ctx, dst);
-    } else {
-        constexpr bool use_logit_softcap = true;
-        launch_fattn_tile_switch_head_size<use_logit_softcap>(ctx, dst);
-    }
-}
index 10dc22d1bf9711ca845a0525daf423f80c4c5a1a..2efc9cc880cf8fa3608870be37dd35dd1fc52dda 100644 (file)
 #include "common.cuh"
+#include "fattn-common.cuh"
+#include "fattn-wmma-f16.cuh"
+
+// nbatch_fa == number of KQ rows to process per iteration
+// nbatch_K == number of K columns to load in parallel for KQ calculation
+
+// TODO optimize kernel parameters for FP16 NVIDIA (P100)
+// TODO optimize kernel parameters for head sizes 40, 80, 96, 112
+
+// The ROCm compiler cannot handle templating in __launch_bounds__.
+// As a workaround, define a macro to package the kernel parameters as uint32_t:
+#define GGML_CUDA_FATTN_TILE_CONFIG_CASE(DKQ_, DV_, ncols_, nthreads, occupancy, nbatch_fa, nbatch_K) \
+    if (DKQ == (DKQ_) && DV == (DV_) && ncols == (ncols_)) {                                          \
+        static_assert((nthreads)          <= 512, "bad nthreads");                                    \
+        static_assert((occupancy)         <=   8, "bad occupancy");                                   \
+        static_assert((nbatch_fa)         <= 256, "bad nbatch_fa");                                   \
+        static_assert((nbatch_K)          <= 256, "bad nbatch_K");                                    \
+        return ((nthreads) << 0) | ((occupancy) << 10) | ((nbatch_fa) << 14) | ((nbatch_K) << 23);    \
+    }                                                                                                 \
+
+static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nvidia_fp16(const int DKQ, const int DV, const int ncols) {
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  2,  64, 2,  64,  40)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  4, 128, 2,  64,  40)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  8, 256, 2,  64,  40)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 16, 256, 2,  64,  40)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 32, 256, 2,  64,  40)
+
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  2,  64, 2,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  4, 128, 2,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  8, 256, 2,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 16, 256, 2,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 32, 256, 2,  64,  64)
+
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  2,  64, 2,  64,  40)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  4, 128, 2,  64,  40)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  8, 256, 2,  64,  40)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 16, 256, 2,  64,  40)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 32, 256, 2,  64,  40)
+
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  2,  64, 2,  64,  48)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  4, 128, 2,  64,  48)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  8, 256, 2,  64,  48)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 16, 256, 2,  64,  48)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 32, 256, 2,  64,  48)
+
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  2,  64, 2,  64,  56)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  4, 128, 2,  64,  56)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  8, 256, 2,  64,  56)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2,  64,  56)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2,  64,  56)
+
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  2,  64, 2,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  4, 128, 2,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  8, 256, 2,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2,  64,  64)
+
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  2,  64, 2,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  4, 128, 2,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  8, 256, 2,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2,  64,  64)
+
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2,  64,  64)
+
+    return 0;
+}
+
+static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nvidia_fp32(const int DKQ, const int DV, const int ncols) {
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  2,  64, 2,  32,  40)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  4, 128, 2,  32,  40)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  8, 256, 2,  32,  40)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 16, 256, 2,  32,  40)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 32, 256, 2,  32,  40)
+
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  2, 128, 3,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  4, 128, 3,  32,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  8, 128, 3,  32,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 16, 128, 3,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 32, 256, 2,  64,  64)
+
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  2,  64, 2,  32,  40)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  4, 128, 2,  32,  40)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  8, 256, 2,  32,  40)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 16, 256, 2,  32,  40)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 32, 256, 2,  32,  40)
+
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  2,  64, 2,  32,  48)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  4, 128, 2,  32,  48)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  8, 256, 2,  32,  48)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 16, 256, 2,  32,  48)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 32, 256, 2,  32,  48)
+
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  2,  64, 2,  32,  56)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  4, 128, 2,  32,  56)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  8, 256, 2,  32,  56)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2,  32,  56)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2,  32,  56)
+
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  2, 128, 3,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  4, 128, 3,  32, 128)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  8, 128, 3,  64, 128)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 128, 3,  32, 128)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2,  64,  64)
+
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  2, 128, 3,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  4, 128, 3,  32,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  8, 256, 2,  32, 256)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2,  32, 128)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2,  32,  64)
+
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2,  32,  64)
+
+    return 0;
+}
+
+static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_amd(const int DKQ, const int DV, const int ncols) {
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  2,  64, 2,  32,  40)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  4, 128, 2,  32,  40)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  8, 256, 2,  32,  40)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 16, 256, 2,  32,  40)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 32, 256, 2,  32,  40)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 64, 256, 2,  32,  40)
+
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  2,  64, 3,  32,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  4, 128, 3,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  8, 128, 2,  32,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 16, 256, 2, 128,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 32, 256, 2,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 64, 256, 2,  64,  64)
+
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  2,  64, 2,  32,  40)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  4, 128, 2,  32,  40)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  8, 256, 2,  32,  40)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 16, 256, 2,  32,  40)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 32, 256, 2,  32,  40)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 64, 256, 2,  32,  40)
+
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  2,  64, 2,  32,  48)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  4, 128, 2,  32,  48)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  8, 256, 2,  32,  48)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 16, 256, 2,  32,  48)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 32, 256, 2,  32,  48)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 64, 256, 2,  32,  48)
+
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  2,  64, 2,  32,  56)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  4, 128, 2,  32,  56)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  8, 256, 2,  32,  56)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2,  32,  56)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2,  32,  56)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2,  32,  56)
+
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  2, 256, 2, 128,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  4, 128, 2,  64, 128)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  8, 256, 2,  64, 128)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2,  64, 128)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 2,  64,  32)
+
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  2, 256, 2, 128,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  4, 256, 2,  64, 128)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  8, 256, 2,  64, 128)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2,  32, 128)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2,  32, 128)
+
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128,  64)
+
+    return 0;
+}
+
+static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_amd_rdna(const int DKQ, const int DV, const int ncols) {
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  2,  64, 2,  32,  40)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  4, 128, 2,  32,  40)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  8, 256, 2,  32,  40)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 16, 256, 2,  32,  40)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 32, 256, 2,  32,  40)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 64, 256, 2,  32,  40)
+
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  2,  64, 8,  32,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  4,  64, 8,  32,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  8, 128, 5, 128,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 16, 128, 5, 128,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 32, 128, 4,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 64, 128, 5,  64,  64)
+
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  2,  64, 2,  32,  40)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  4, 128, 2,  32,  40)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  8, 256, 2,  32,  40)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 16, 256, 2,  32,  40)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 32, 256, 2,  32,  40)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 64, 256, 2,  32,  40)
+
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  2,  64, 2,  32,  48)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  4, 128, 2,  32,  48)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  8, 256, 2,  32,  48)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 16, 256, 2,  32,  48)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 32, 256, 2,  32,  48)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 64, 256, 2,  32,  48)
+
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  2,  64, 2,  32,  56)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  4, 128, 2,  32,  56)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  8, 256, 2,  32,  56)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2,  32,  56)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2,  32,  56)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2,  32,  56)
+
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  2,  64, 8,  32,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  4, 128, 8,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  8, 128, 8,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 3, 128, 128)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 3, 128,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 3,  64,  64)
+
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  2,  64, 8,  32,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  4, 128, 6,  32, 256)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  8, 128, 6,  32, 256)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5,  32, 256)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3,  64, 128)
+
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128,  64)
+
+    return 0;
+}
+
+static __host__ uint32_t ggml_cuda_fattn_tile_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
+    if (GGML_CUDA_CC_IS_AMD(cc)) {
+        if (GGML_CUDA_CC_IS_RDNA(cc)) {
+            return ggml_cuda_fattn_tile_get_config_amd_rdna(DKQ, DV, ncols);
+        }
+        return ggml_cuda_fattn_tile_get_config_amd(DKQ, DV, ncols);
+    }
+    if (fast_fp16_available(cc)) {
+        return ggml_cuda_fattn_tile_get_config_nvidia_fp16(DKQ, DV, ncols);
+    }
+    return ggml_cuda_fattn_tile_get_config_nvidia_fp32(DKQ, DV, ncols);
+}
+
+static constexpr __device__ uint32_t ggml_cuda_fattn_tile_get_config(const int DKQ, const int DV, const int ncols) {
+#ifdef GGML_USE_HIP
+#ifdef RDNA
+    return ggml_cuda_fattn_tile_get_config_amd_rdna(DKQ, DV, ncols);
+#else
+    return ggml_cuda_fattn_tile_get_config_amd(DKQ, DV, ncols);
+#endif // RDNA
+#else
+#ifdef FAST_FP16_AVAILABLE
+    return ggml_cuda_fattn_tile_get_config_nvidia_fp16(DKQ, DV, ncols);
+#else
+    return ggml_cuda_fattn_tile_get_config_nvidia_fp32(DKQ, DV, ncols);
+#endif // FAST_FP16_AVAILABLE
+#endif // GGML_USE_HIP
+}
+
+static __host__ int ggml_cuda_fattn_tile_get_nthreads(const int DKQ, const int DV, const int ncols, const int cc) {
+    return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 0) & ((1 << 10) - 1);
+}
+
+static constexpr __device__ int ggml_cuda_fattn_tile_get_nthreads(const int DKQ, const int DV, const int ncols) {
+    return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 0) & ((1 << 10) - 1);
+}
+
+static __host__ int ggml_cuda_fattn_tile_get_occupancy(const int DKQ, const int DV, const int ncols, const int cc) {
+    return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 10) & ((1 << 4) - 1);
+}
+
+static constexpr __device__ int ggml_cuda_fattn_tile_get_occupancy(const int DKQ, const int DV, const int ncols) {
+    return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 10) & ((1 << 4) - 1);
+}
+
+static __host__ int ggml_cuda_fattn_tile_get_nbatch_fa(const int DKQ, const int DV, const int ncols, const int cc) {
+    return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 14) & ((1 << 9) - 1);
+}
+
+static constexpr __device__ int ggml_cuda_fattn_tile_get_nbatch_fa(const int DKQ, const int DV, const int ncols) {
+    return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 14) & ((1 << 9) - 1);
+}
+
+static __host__ int ggml_cuda_fattn_tile_get_nbatch_K(const int DKQ, const int DV, const int ncols, const int cc) {
+    return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 23) & ((1 << 9) - 1);
+}
+
+static constexpr __device__ int ggml_cuda_fattn_tile_get_nbatch_K(const int DKQ, const int DV, const int ncols) {
+    return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 23) & ((1 << 9) - 1);
+}
+
+// TODO: deduplicate with mma-f16
+template<int warp_size, int nwarps, int I, int J, int J_padding, bool oob_check>
+static __device__ __forceinline__ void flash_attn_tile_load_tile(
+        const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV, const int i_sup) {
+    constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
+    constexpr int cpy_ne = cpy_nb / 4;
+
+    auto load = [&] __device__ (const int n) {
+        const int stride_j = warp_size >> n;
+
+        if (stride_j == 0) {
+            return;
+        }
+
+        const int j0_start = stride_j == warp_size ? 0 : ((J/2)/cpy_ne) - ((J/2)/cpy_ne) % (2*stride_j);
+        const int j0_stop  =                             ((J/2)/cpy_ne) - ((J/2)/cpy_ne) % (1*stride_j);
+        const int stride_i = warp_size / stride_j;
+
+        if (j0_start == j0_stop) {
+            return;
+        }
+
+#pragma unroll
+        for (int i0 = 0; i0 < I; i0 += nwarps*stride_i) {
+            const int i = i0 + threadIdx.y*stride_i + (stride_j == warp_size ? 0 : threadIdx.x / stride_j);
+
+            if (i0 + nwarps*stride_i <= I || i < I) {
+#pragma unroll
+                for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) {
+                    const int j = j0*cpy_ne + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*cpy_ne;
+
+                    const half2 zero[cpy_ne] = {{0.0f, 0.0f}};
+                    ggml_cuda_memcpy_1<cpy_nb>(
+                        tile_KV + i*(J/2 + J_padding) + j,
+                        !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
+                }
+            }
+        }
+    };
+    // 1: max 64*16=512 bytes, 512 half
+    // 2: max 32*16=512 bytes, 256 half
+    // 3: max 16*16=256 bytes, 128 half
+    // 4: max  8*16=128 bytes,  64 half
+    // 5: max  4*16= 64 bytes,  32 half
+    // 6: max  2*16= 32 bytes,  16 half
+    // 7: max  1*16= 16 bytes,   8 half
+    static_assert(J % 8 == 0, "bad J");
+    static_assert((J/2) % cpy_ne == 0, "bad J");
+    ggml_cuda_unroll<7>{}(load);
+}
+
+template<int warp_size, int nwarps, int I, int J, int J_padding, bool oob_check>
+static __device__ __forceinline__ void flash_attn_tile_load_tile(
+        const half2 * const __restrict__ KV, float * const __restrict__ tile_KV, const int stride_KV, const int i_sup) {
+    constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
+    constexpr int cpy_ne = cpy_nb / 4;
+
+    auto load = [&] __device__ (const int n) {
+        const int stride_j = warp_size >> n;
+
+        if (stride_j == 0) {
+            return;
+        }
+
+        const int j0_start = stride_j == warp_size ? 0 : (J/cpy_ne) - (J/cpy_ne) % (2*stride_j);
+        const int j0_stop  =                             (J/cpy_ne) - (J/cpy_ne) % (1*stride_j);
+        const int stride_i = warp_size / stride_j;
+
+        if (j0_start == j0_stop) {
+            return;
+        }
+
+#pragma unroll
+        for (int i0 = 0; i0 < I; i0 += nwarps*stride_i) {
+            const int i = i0 + threadIdx.y*stride_i + (stride_j == warp_size ? 0 : threadIdx.x / stride_j);
+
+            if (i0 + nwarps*stride_i <= I || i < I) {
+#pragma unroll
+                for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) {
+                    const int j = j0*(cpy_ne/2) + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*(cpy_ne/2);
+
+                    const half2 zero[cpy_ne/2] = {{0.0f, 0.0f}};
+                    half2 tmp_h2[cpy_ne/2];
+                    ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
+                        tmp_h2, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
+
+                    float2 tmp_f2[cpy_ne/2];
+#pragma unroll
+                    for (int l = 0; l < cpy_ne/2; ++l) {
+                        tmp_f2[l] = __half22float2(tmp_h2[l]);
+                    }
+                    ggml_cuda_memcpy_1<sizeof(tmp_f2)>(tile_KV + i*(J + J_padding) + 2*j, tmp_f2);
+                }
+            }
+        }
+    };
+    // 1: max 32*16=512 bytes, 128 float
+    // 2: max 16*16=256 bytes,  64 float
+    // 3: max  8*16=128 bytes,  32 float
+    // 4: max  4*16= 64 bytes,  16 float
+    // 5: max  2*16= 32 bytes,   8 float
+    static_assert(J % 8 == 0, "bad J");
+    static_assert(J % cpy_ne == 0, "bad J");
+    ggml_cuda_unroll<5>{}(load);
+}
+
+// Function that performs a single iteration in for the KQ matrix multiplication:
+template <int warp_size, int nwarps, int ncols1, int ncols2, int DKQ, int nbatch_fa, int nbatch_K,
+    bool use_logit_softcap, bool oob_check, typename T_vec_dot>
+static __device__ __forceinline__ void flash_attn_tile_iter_KQ(
+        T_vec_dot   * const Q_tmp,
+        const half2 * const __restrict__ K_h2,
+        T_vec_dot   * const KV_tmp,
+        const int stride_K2,
+        const int k_VKQ_0,
+        const int k_VKQ_sup,
+        const int k_KQ_0,
+        float * KQ_acc) {
+    constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
+    constexpr int cpy_ne = cpy_nb / 4;
+
+    constexpr int ncols = ncols1*ncols2;
+    constexpr int cpw   = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp
+    constexpr int np    = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column
+
+    flash_attn_tile_load_tile<warp_size, nwarps, nbatch_fa, nbatch_K, cpy_ne, oob_check>
+        (K_h2 + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup);
+    __syncthreads();
+
+#ifdef FAST_FP16_AVAILABLE
+    static_assert((nbatch_K/2) % cpy_ne == 0, "bad nbatch_K");
+#pragma unroll
+    for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K/2; k_KQ_1 += cpy_ne) {
+        half2 K_k[nbatch_fa/(np*warp_size)][cpy_ne];
+        half2 Q_k[cpw][cpy_ne];
+#else
+    static_assert(nbatch_K % cpy_ne == 0, "bad nbatch_K");
+#pragma unroll
+    for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K; k_KQ_1 += cpy_ne) {
+        float K_k[nbatch_fa/(np*warp_size)][cpy_ne];
+        float Q_k[cpw][cpy_ne];
+#endif // FAST_FP16_AVAILABLE
+
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {
+            const int i_KQ = i_KQ_0 + (threadIdx.y % np)*warp_size + threadIdx.x;
+
+#ifdef FAST_FP16_AVAILABLE
+            ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/(np*warp_size)], &KV_tmp[i_KQ*(nbatch_K/2 + cpy_ne) + k_KQ_1]);
+#else
+            ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/(np*warp_size)], &KV_tmp[i_KQ*(nbatch_K   + cpy_ne) + k_KQ_1]);
+#endif // FAST_FP16_AVAILABLE
+        }
+#pragma unroll
+        for (int jc0 = 0; jc0 < cpw; ++jc0) {
+            const int jc = jc0 + (threadIdx.y / np)*cpw;
+
+#ifdef FAST_FP16_AVAILABLE
+            ggml_cuda_memcpy_1<cpy_nb>(&Q_k[jc0], &Q_tmp[jc*(DKQ/2) + k_KQ_0/2 + k_KQ_1]);
+#else
+            ggml_cuda_memcpy_1<cpy_nb>(&Q_k[jc0], &Q_tmp[jc* DKQ    + k_KQ_0   + k_KQ_1]);
+#endif // FAST_FP16_AVAILABLE
+        }
+
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {
+#pragma unroll
+            for (int jc0 = 0; jc0 < cpw; ++jc0) {
+#pragma unroll
+                for (int k = 0; k < cpy_ne; ++k) {
+                    ggml_cuda_mad(KQ_acc[i_KQ_0/(np*warp_size)*cpw + jc0], K_k[i_KQ_0/(np*warp_size)][k], Q_k[jc0][k]);
+                }
+            }
+        }
+    }
+
+    if (k_KQ_0 + nbatch_K < DKQ) {
+        __syncthreads(); // Sync not needed on last iteration.
+    }
+}
+
+// Function that performs a single iteration of the main loop over up to nbatch_fa tokens.
+template <int warp_size, int nwarps, int ncols1, int ncols2, int DKQ, int DV, int nbatch_fa, int nbatch_K,
+    bool use_logit_softcap, bool oob_check, typename T_vec_dot, typename T_KQ, typename T_acc>
+static __device__ __forceinline__ void flash_attn_tile_iter(
+        T_vec_dot * const Q_tmp,
+        const half2 * const __restrict__ K_h2,
+        const half2 * const __restrict__ V_h2,
+        const half  * const __restrict__ mask,
+        const float logit_softcap,
+        const float slope,
+        T_KQ      * const KQ,
+        T_vec_dot * const KV_tmp,
+        const int stride_K2,
+        const int stride_V2,
+        const int stride_mask,
+        float * const KQ_max,
+        float * const KQ_sum,
+        T_acc * const VKQ,
+        const int k_VKQ_0,
+        const int k_VKQ_max) {
+    constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
+    constexpr int cpy_ne = cpy_nb / 4;
+
+    constexpr int ncols = ncols1*ncols2;
+    constexpr int cpw   = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp
+    constexpr int np    = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column
+
+    constexpr int DVp = (DV + 2*warp_size - 1) & ~(2*warp_size - 1); // DV padded to multiple of 2*warp_size.
+
+    // KQ_cs == KQ chunk size, number of KQ values in j direction to store as one contiguous chunk in memory.
+    // KQ is originally 2D but uses a Z-shaped 3D memory pattern like KQ[ncols/KQ_cs][DVp][KQ_cs].
+#ifdef FAST_FP16_AVAILABLE
+    constexpr int KQ_cs = cpw < 2*cpy_ne ? cpw : 2*cpy_ne;
+#else
+    constexpr int KQ_cs = cpw < 1*cpy_ne ? cpw : 1*cpy_ne;
+#endif // FAST_FP16_AVAILABLE
+    static_assert(cpw % KQ_cs == 0, "bad KQ_cs");
+    const int k_VKQ_sup = k_VKQ_max - k_VKQ_0; // k supremum, only smaller k values have valid KV data
+
+    float KQ_max_new[cpw];
+#pragma unroll
+    for (int jc0 = 0; jc0 < cpw; ++jc0) {
+        KQ_max_new[jc0] = KQ_max[jc0];
+    }
+
+    float KQ_acc[nbatch_fa/(np*warp_size) * cpw] = {0.0f}; // Accumulators for KQ matrix multiplication.
+
+    // KQ = K @ Q matrix multiplication:
+    constexpr int nbatch_K_last = DKQ % nbatch_K;
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < DKQ - nbatch_K_last; k_KQ_0 += nbatch_K) {
+        flash_attn_tile_iter_KQ<warp_size, nwarps, ncols1, ncols2, DKQ, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>(
+            Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc);
+    }
+    if (nbatch_K_last > 0) {
+        constexpr int k_KQ_0 = DKQ - nbatch_K_last;
+        flash_attn_tile_iter_KQ<warp_size, nwarps, ncols1, ncols2, DKQ, nbatch_fa, nbatch_K_last, use_logit_softcap, oob_check>(
+            Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc);
+    }
+
+    // Apply logit softcap + mask, update KQ_max:
+#pragma unroll
+    for (int jc0 = 0; jc0 < cpw; ++jc0) {
+        const int j = (jc0 + (threadIdx.y / np)*cpw)/ncols2;
+
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {
+            const int i_KQ = i_KQ_0 + (threadIdx.y % np)*warp_size + threadIdx.x;
+
+            if (use_logit_softcap) {
+                KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] = logit_softcap * tanhf(KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
+            }
+
+            KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) && (!oob_check || i_KQ < k_VKQ_sup) ?
+                slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f;
+
+            KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
+        }
+
+        KQ_max_new[jc0] = warp_reduce_max<warp_size>(KQ_max_new[jc0]);
+    }
+
+    if constexpr (np == 1) {
+        __syncthreads();
+    } else {
+        static_assert(cpw == 1, "bad cpw");
+        __shared__ float KQ_max_new_shared[nwarps];
+        if (threadIdx.x == 0) {
+            KQ_max_new_shared[threadIdx.y] = KQ_max_new[0];
+        }
+        __syncthreads();
+        KQ_max_new[0] = KQ_max_new_shared[(threadIdx.y & ~(np-1)) + threadIdx.x % np];
+        KQ_max_new[0] = warp_reduce_max<np>(KQ_max_new[0]);
+    }
+
+    // Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators:
+#pragma unroll
+    for (int jc0 = 0; jc0 < cpw; jc0 += KQ_cs) {
+#ifdef FAST_FP16_AVAILABLE
+        half  tmp[nbatch_fa/(np*warp_size)][KQ_cs];
+#else
+        float tmp[nbatch_fa/(np*warp_size)][KQ_cs];
+#endif // FAST_FP16_AVAILABLE
+
+#pragma unroll
+        for (int jc1 = 0; jc1 < KQ_cs; ++jc1) {
+            const int jc = jc0 + jc1;
+
+            const float KQ_max_scale = expf(KQ_max[jc] - KQ_max_new[jc]);
+            KQ_max[jc] = KQ_max_new[jc];
+
+            float KQ_sum_add = 0.0f;
+#pragma unroll
+            for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {
+                const float val = expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]);
+                if (!oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup) {
+                    KQ_sum_add += val;
+                }
+                tmp[i0/(np*warp_size)][jc1] = val;
+            }
+            KQ_sum[jc] = KQ_sum[jc]*KQ_max_scale + KQ_sum_add;
+
+#ifdef FAST_FP16_AVAILABLE
+            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
+#pragma unroll
+            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
+                VKQ[jc*((DVp/2)/warp_size) + i0/warp_size] *= KQ_max_scale_h2;
+            }
+#else
+#pragma unroll
+            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
+                VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].x *= KQ_max_scale;
+                VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].y *= KQ_max_scale;
+            }
+#endif // FAST_FP16_AVAILABLE
+        }
+
+#pragma unroll
+        for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {
+            const int i = i0 + (threadIdx.y % np)*warp_size + threadIdx.x;
+
+            ggml_cuda_memcpy_1<sizeof(tmp[0])>(
+                KQ + (jc0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs))*(nbatch_fa*KQ_cs) + i*KQ_cs,
+                tmp[i0/(np*warp_size)]);
+        }
+    }
+
+    // VKQ = V @ KQ matrix multiplication:
+    static_assert(DV <= DKQ, "bad DV");
+    static_assert(DV % nbatch_K == 0 || (nbatch_K % 3 == 0 && DV % (nbatch_K*2/3) == 0), "bad nbatch_K");
+    constexpr int nbatch_V = (DV % nbatch_K == 0 ? nbatch_K : nbatch_K*2/3) * nbatch_fa / DV; // Number of V columns that fit in SRAM for K.
+    static_assert(nbatch_fa % nbatch_V == 0, "bad nbatch_V");
+    static_assert(nbatch_V % np == 0, "bad nbatch_V");
+#pragma unroll
+    for (int k0 = 0; k0 < nbatch_fa; k0 += nbatch_V) {
+        flash_attn_tile_load_tile<warp_size, nwarps, nbatch_V, DV, 0, oob_check>
+            (V_h2 + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0);
+        __syncthreads();
+
+#ifdef FAST_FP16_AVAILABLE
+#pragma unroll
+        for (int k1 = 0; k1 < nbatch_V; k1 += np) {
+            half2 V_k[(DVp/2)/warp_size];
+            half2 KQ_k[cpw];
+
+            constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size;
+#pragma unroll
+            for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
+                ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/warp_size], &KV_tmp[(k1 + threadIdx.y % np)*(DV/2) + i0 + threadIdx.x*cpy_ne_D]);
+            }
+#pragma unroll
+            for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) {
+                const int jc_KQ = jc_VKQ_0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs);
+
+                half tmp[KQ_cs];
+                ggml_cuda_memcpy_1<KQ_cs*sizeof(half)>(
+                    &tmp, KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + threadIdx.y % np)*KQ_cs);
+#pragma unroll
+                for (int jc_VKQ_1 = 0; jc_VKQ_1 < KQ_cs; ++jc_VKQ_1) {
+                    KQ_k[jc_VKQ_0+jc_VKQ_1] = __half2half2(tmp[jc_VKQ_1]);
+                }
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
+#pragma unroll
+                for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; ++jc_VKQ_0) {
+                    VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size] += V_k[i0/warp_size]*KQ_k[jc_VKQ_0];
+                }
+            }
+        }
+#else
+#pragma unroll
+        for (int k1 = 0; k1 < nbatch_V; k1 += np) {
+            float2 V_k[(DVp/2)/warp_size];
+            float  KQ_k[cpw];
+
+            constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
+#pragma unroll
+            for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
+                ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/(2*warp_size)], &KV_tmp[(k1 + threadIdx.y % np)*DV + i0 + threadIdx.x*cpy_ne_D]);
+            }
+#pragma unroll
+            for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) {
+                const int jc_KQ = jc_VKQ_0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs);
+
+                ggml_cuda_memcpy_1<KQ_cs*sizeof(float)>(
+                    &KQ_k[jc_VKQ_0], KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + threadIdx.y % np)*KQ_cs);
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
+#pragma unroll
+                for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; ++jc_VKQ_0) {
+                    VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[jc_VKQ_0];
+                    VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[jc_VKQ_0];
+                }
+            }
+        }
+#endif // FAST_FP16_AVAILABLE
+
+        __syncthreads();
+    }
+}
+
+template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap> // D == head size
+__launch_bounds__(ggml_cuda_fattn_tile_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_tile_get_occupancy(DKQ, DV, ncols1*ncols2))
+static __global__ void flash_attn_tile(
+        const char * __restrict__ Q,
+        const char * __restrict__ K,
+        const char * __restrict__ V,
+        const char * __restrict__ mask,
+        const char * __restrict__ sinks,
+        const int  * __restrict__ KV_max,
+        float      * __restrict__ dst,
+        float2     * __restrict__ dst_meta,
+        const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
+        const float logit_softcap,
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+                            const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+                            const int32_t nb11, const int32_t nb12, const int64_t nb13,
+                            const int32_t nb21, const int32_t nb22, const int64_t nb23,
+                            const int32_t ne31, const int32_t ne32, const int32_t ne33,
+                            const int32_t nb31, const int32_t nb32, const int64_t nb33) {
+#ifdef FLASH_ATTN_AVAILABLE
+
+    // Skip unused kernel variants for faster compilation:
+
+    if (
+#ifdef GGML_USE_WMMA_FATTN
+            (ncols2 != 1 && DV != 40 && DV != 512) ||
+#endif // GGML_USE_WMMA_FATTN
+            (use_logit_softcap && !(DV == 128 || DV == 256))
+    ) {
+        GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
+            max_bias, m0, m1, n_head_log2, logit_softcap,
+            ne00, ne01, ne02, ne03,
+                  nb01, nb02, nb03,
+            ne10, ne11, ne12, ne13,
+                  nb11, nb12, nb13,
+                  nb21, nb22, nb23,
+                  ne31, ne32, ne33,
+                  nb31, nb32, nb33);
+        NO_DEVICE_CODE;
+        return;
+    }
+
+    static_assert(ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols1*ncols2) != 0, "kernel config not defined");
+
+    constexpr int ncols     = ncols1*ncols2;
+    constexpr int warp_size = 32;
+    constexpr int nwarps    = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, ncols1*ncols2) / warp_size;
+    constexpr int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, ncols1*ncols2);
+    constexpr int nbatch_K  = ggml_cuda_fattn_tile_get_nbatch_K (DKQ, DV, ncols1*ncols2);
+
+    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+    const int col_Q_0 = blockIdx.x * ncols1; // Index of the first Q column for this CUDA block to work on.
+
+    const int sequence = blockIdx.z / (ne02/ncols2);
+    const int head0 = blockIdx.z*ncols2 - sequence*ne02; // == blockIdx.z % (ne02/ncols2)
+    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+    const float * Q_f  = (const float *) (Q + nb03*sequence + nb02* head0              + nb01*col_Q_0);
+    const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
+    const half2 * V_h2 = (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); // K and V have same shape
+
+    const half * maskh = mask ? (const half *) (mask + nb33*(sequence % ne33) + nb31*col_Q_0) : nullptr;
+
+    const int stride_K2   = nb11 / sizeof(half2);
+    const int stride_V2   = nb21 / sizeof(half2);
+    const int stride_mask = nb31 / sizeof(half);
+
+    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
+
+    constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
+    constexpr int cpy_ne = cpy_nb / 4;
+
+    constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp.
+    constexpr int np  = nwarps > ncols ? nwarps/ncols : 1; // Number of parallel warps per Q column.
+    static_assert(cpw == 1 || np == 1, "bad cpw / np");
+    static_assert(nbatch_fa % (np*warp_size) == 0, "nbatch_fa % (np*warp_size) != 0");
+
+    constexpr int DKQp = (DKQ + 2*warp_size - 1) & ~(2*warp_size - 1); // DKQ padded to multiple of 2*warp_size.
+    constexpr int DVp  = (DV  + 2*warp_size - 1) & ~(2*warp_size - 1); // DV  padded to multiple of 2*warp_size.
+
+    // Q_tmp == SRAM buffer to hold Q data for the entire lifetime of the kernel.
+    // KV_tmp == SRAM buffer to hold fragments of K/V data while iterating over ne11.
+    //     KV_tmp is padded to avoid memory conflicts for K (cpy_ne) and OOB accesses for V (DVp-DV).
+    // KQ == SRAM buffer to hold KQ fragments between KQ and VKQ matrix multiplications.
+    // VKQ == Accumulators in registers for the final VKQ result.
+#ifdef FAST_FP16_AVAILABLE
+    __shared__ half2 Q_tmp[ncols * DKQ/2];
+    __shared__ half2 KV_tmp[nbatch_fa * (nbatch_K/2 + cpy_ne) + DVp-DV];
+    __shared__ half  KQ[ncols * nbatch_fa];
+    half2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
+#else
+    __shared__ float Q_tmp[ncols * DKQ];
+    __shared__ float KV_tmp[nbatch_fa * (nbatch_K + cpy_ne) + DVp-DV];
+    __shared__ float KQ[ncols * nbatch_fa];
+    float2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
+#endif // FAST_FP16_AVAILABLE
+
+    float KQ_max[cpw];
+#pragma unroll
+    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+        KQ_max[j0/nwarps] = -FLT_MAX/2.0f;
+    }
+    float KQ_sum[cpw] = {0.0f};
+
+    // Load Q data, convert to FP16 if fast:
+#pragma unroll
+    for (int jc0 = 0; jc0 < cpw; ++jc0) {
+        const int jc = jc0 + (threadIdx.y / np)*cpw;
+
+        const int j = jc / ncols2;
+        const int c = jc % ncols2;
+
+        constexpr int cpy_ne_D = cpy_ne < DKQp/warp_size ? cpy_ne : DKQp/warp_size;
+
+#pragma unroll
+        for (int i0 = 0; i0 < DKQp; i0 += np*warp_size*cpy_ne_D) {
+            if (i0 + np*warp_size*cpy_ne_D <= DKQ || i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D < DKQ) {
+                float tmp_f[cpy_ne_D] = {0.0f};
+                if (ncols1 == 1 || col_Q_0 + j < ne01) {
+                    ggml_cuda_memcpy_1<sizeof(tmp_f)>
+                        (tmp_f, &Q_f[c*(nb02/sizeof(float)) + j*(nb01/sizeof(float))
+                                     + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]);
+                }
+
+#pragma unroll
+                for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
+                    tmp_f[i1] *= scale;
+                }
+
+#ifdef FAST_FP16_AVAILABLE
+                half2 tmp_h2[cpy_ne_D/2];
+#pragma unroll
+                for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) {
+                    tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]);
+                }
+                ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
+                    &Q_tmp[jc*(DKQ/2) + i0/2 + (threadIdx.y % np)*(warp_size*cpy_ne_D/2) + threadIdx.x*(cpy_ne_D/2)],
+                    tmp_h2);
+#else
+                ggml_cuda_memcpy_1<sizeof(tmp_f)>(
+                    &Q_tmp[jc* DKQ    + i0   + (threadIdx.y % np)*(warp_size*cpy_ne_D)   + threadIdx.x* cpy_ne_D],
+                    tmp_f);
+#endif // FAST_FP16_AVAILABLE
+            }
+        }
+    }
+
+    __syncthreads();
+
+    // Main loop over KV cache:
+    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
+    if (ncols2 == 1) {
+        // Branch with out-of-bounds checks.
+        int k_VKQ_0 = blockIdx.y*nbatch_fa;
+        while (k_VKQ_0 < k_VKQ_max - nbatch_fa) {
+            constexpr bool oob_check = false;
+            flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>
+                (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp,
+                stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max);
+            k_VKQ_0 += gridDim.y*nbatch_fa;
+        }
+        if (k_VKQ_0 < k_VKQ_max) {
+            constexpr bool oob_check = true;
+            flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>
+                (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp,
+                stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max);
+        }
+    } else {
+        // Branch without out-of-bounds checks.
+        for (int k_VKQ_0 = blockIdx.y*nbatch_fa; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*nbatch_fa) {
+            constexpr bool oob_check = false;
+            flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>
+                (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp,
+                stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max);
+        }
+    }
+
+#pragma unroll
+    for (int jc0 = 0; jc0 < cpw; ++jc0) {
+        KQ_sum[jc0] = warp_reduce_sum<warp_size>(KQ_sum[jc0]);
+    }
+
+    if constexpr (np > 1) {
+        static_assert(cpw == 1, "bad cpw");
+        static_assert(nbatch_fa*nbatch_K >= nwarps*DVp, "KV_tmp too small");
+
+#ifdef FAST_FP16_AVAILABLE
+        half2 * VKQ_combine    = (half2 *) KV_tmp;
+#else
+        float * VKQ_combine    = (float *) KV_tmp;
+#endif // FAST_FP16_AVAILABLE
+        float * KQ_sum_combine = (float *) Q_tmp;
+
+        if (threadIdx.y % np != 0) {
+#ifdef FAST_FP16_AVAILABLE
+            constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size;
+#pragma unroll
+            for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
+                ggml_cuda_memcpy_1<cpy_ne_D*4>(&VKQ_combine[threadIdx.y*(DVp/2) + i0 + threadIdx.x*cpy_ne_D], &VKQ[i0/warp_size]);
+            }
+#else
+            constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
+#pragma unroll
+            for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
+                ggml_cuda_memcpy_1<cpy_ne_D*4>(
+                    &VKQ_combine[threadIdx.y*DVp + i0 + threadIdx.x*cpy_ne_D], ((const float *) VKQ) + i0/warp_size);
+            }
+#endif // FAST_FP16_AVAILABLE
+
+            if (threadIdx.x == 0) {
+                KQ_sum_combine[threadIdx.y] = KQ_sum[0];
+            }
+
+            return;
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int ip = 1; ip < np; ++ip) {
+#ifdef FAST_FP16_AVAILABLE
+            constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size;
+#pragma unroll
+            for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
+                half2 tmp[cpy_ne_D];
+                ggml_cuda_memcpy_1<cpy_ne_D*4>(tmp, &VKQ_combine[(threadIdx.y + ip)*(DVp/2) + i0 + threadIdx.x*cpy_ne_D]);
+#pragma unroll
+                for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
+                    VKQ[i0/warp_size + i1] += tmp[i1];
+                }
+            }
+#else
+            constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
+#pragma unroll
+            for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
+                float tmp[cpy_ne_D];
+                ggml_cuda_memcpy_1<cpy_ne_D*4>(tmp, &VKQ_combine[(threadIdx.y + ip)*DVp + i0 + threadIdx.x*cpy_ne_D]);
+#pragma unroll
+                for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
+                    ((float *)VKQ)[i0/warp_size + i1] += tmp[i1];
+                }
+            }
+#endif // FAST_FP16_AVAILABLE
+
+            KQ_sum[0] += KQ_sum_combine[threadIdx.y + ip];
+        }
+    }
+
+    // Attention sink: adjust KQ max and sum only for the first of all parallel blocks:
+    if (sinks && blockIdx.y == 0) {
+#pragma unroll
+        for (int jc0 = 0; jc0 < cpw; ++jc0) {
+            const int jc = jc0 + (threadIdx.y/np)*cpw;
+            const float sink = ((const float *) sinks)[head0 + jc % ncols2];
+
+            float KQ_max_new_j = fmaxf(KQ_max[jc0], sink);
+            const float KQ_max_scale = expf(KQ_max[jc0] - KQ_max_new_j);
+            KQ_max[jc0] = KQ_max_new_j;
+
+            const float val = expf(sink - KQ_max[jc0]);
+            KQ_sum[jc0] = KQ_sum[jc0]*KQ_max_scale + val;
+
+#ifdef FAST_FP16_AVAILABLE
+            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
+#pragma unroll
+            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
+                VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size] *= KQ_max_scale_h2;
+            }
+#else
+#pragma unroll
+            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
+                VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size].x *= KQ_max_scale;
+                VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size].y *= KQ_max_scale;
+            }
+#endif // FAST_FP16_AVAILABLE
+        }
+    }
+
+    if (gridDim.y == 1) {
+#pragma unroll
+        for (int jc0 = 0; jc0 < cpw; ++jc0) {
+#ifdef FAST_FP16_AVAILABLE
+            const half2 KQ_sum_jc_inv = make_half2(1.0f/KQ_sum[jc0], 1.0f/KQ_sum[jc0]);
+#pragma unroll
+            for (int i = 0; i < (DVp/2)/warp_size; ++i) {
+                VKQ[jc0*((DVp/2)/warp_size) + i] *= KQ_sum_jc_inv;
+            }
+#else
+            const float KQ_sum_jc_inv = 1.0f/KQ_sum[jc0];
+#pragma unroll
+            for (int i = 0; i < (DVp/2)/warp_size; ++i) {
+                VKQ[jc0*((DVp/2)/warp_size) + i].x *= KQ_sum_jc_inv;
+                VKQ[jc0*((DVp/2)/warp_size) + i].y *= KQ_sum_jc_inv;
+            }
+#endif // FAST_FP16_AVAILABLE
+        }
+    }
+
+    // Write back results:
+#pragma unroll
+    for (int jc0 = 0; jc0 < cpw; ++jc0) {
+        const int jc = jc0 + (threadIdx.y/np)*cpw;
+
+        const int j = jc / ncols2;
+        const int c = jc % ncols2;
+
+        if (ncols1 > 1 && col_Q_0 + j >= ne01) {
+            return;
+        }
+
+        const int j_dst_unrolled = ((sequence*ne01 + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y;
+
+#ifdef FAST_FP16_AVAILABLE
+        constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size;
+#pragma unroll
+        for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
+            float2 tmp[cpy_ne_D];
+#pragma unroll
+            for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
+                tmp[i1] = __half22float2(VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size + i1]);
+            }
+            if (i0 + warp_size*cpy_ne_D <= DV/2 || i0 + threadIdx.x*cpy_ne_D < DV/2) {
+                ggml_cuda_memcpy_1<sizeof(tmp)>(&dst[j_dst_unrolled*DV + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp);
+            }
+        }
+#else
+        constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
+#pragma unroll
+        for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
+            if (i0 + warp_size*cpy_ne_D <= DV || i0 + threadIdx.x*cpy_ne_D < DV) {
+                ggml_cuda_memcpy_1<cpy_ne_D*4>(
+                    &dst[j_dst_unrolled*DV + i0 + threadIdx.x*cpy_ne_D],
+                    &VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size)]);
+            }
+        }
+#endif // FAST_FP16_AVAILABLE
+
+        if (gridDim.y != 1 && threadIdx.x == 0) {
+            dst_meta[j_dst_unrolled] = make_float2(KQ_max[jc0], KQ_sum[jc0]);
+        }
+    }
+#else
+    GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
+        max_bias, m0, m1, n_head_log2, logit_softcap,
+        ne00, ne01, ne02, ne03,
+              nb01, nb02, nb03,
+        ne10, ne11, ne12, ne13,
+              nb11, nb12, nb13,
+              nb21, nb22, nb23,
+              ne31, ne32, ne33,
+              nb31, nb32, nb33);
+    NO_DEVICE_CODE;
+#endif // FLASH_ATTN_AVAILABLE
+}
+
+template <int DKQ, int DV, int ncols2, bool use_logit_softcap>
+static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+
+    const int id        = ggml_cuda_get_device();
+    const int cc        = ggml_cuda_info().devices[id].cc;
+    const int warp_size = 32;
+
+    constexpr size_t nbytes_shared = 0;
+
+#ifdef GGML_USE_HIP
+    if constexpr (DV <= 128) {
+        if (Q->ne[1] > 32/ncols2) {
+            constexpr int cols_per_block = 64;
+            const int nwarps    = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
+            const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
+            fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
+            launch_fattn<DV, cols_per_block/ncols2, ncols2>
+                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
+            return;
+        }
+    }
+#endif // GGML_USE_HIP
+
+#ifndef GGML_USE_HIP
+    if constexpr (DV <= 256)
+#endif // GGML_USE_HIP
+    {
+        if (Q->ne[1] > 16/ncols2) {
+            constexpr int cols_per_block = 32;
+            const int nwarps    = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
+            const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
+            fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
+            launch_fattn<DV, cols_per_block/ncols2, ncols2>
+                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
+            return;
+        }
+    }
+
+    if (Q->ne[1] > 8/ncols2) {
+        constexpr int cols_per_block = 16;
+        const int nwarps    = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
+        const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
+        fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
+        launch_fattn<DV, cols_per_block/ncols2, ncols2>
+            (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
+        return;
+    }
+
+    if constexpr (ncols2 <= 8) {
+        if (Q->ne[1] > 4/ncols2) {
+            constexpr int cols_per_block = 8;
+            const int nwarps    = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
+            const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
+            fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
+            launch_fattn<DV, cols_per_block/ncols2, ncols2>
+                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
+            return;
+        }
+    }
+
+    if constexpr (ncols2 <= 4) {
+        if (Q->ne[1] > 2/ncols2) {
+            constexpr int cols_per_block = 4;
+            const int nwarps    = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
+            const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
+            fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
+            launch_fattn<DV, cols_per_block/ncols2, ncols2>
+                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
+            return;
+        }
+    }
+
+    if constexpr (ncols2 <= 2) {
+        constexpr int cols_per_block = 2;
+        const int nwarps    = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
+        const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
+        fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
+        launch_fattn<DV, cols_per_block/ncols2, ncols2>
+            (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
+        return;
+    }
+
+    GGML_ABORT("fatal error");
+}
+
+template <int DKQ, int DV, bool use_logit_softcap>
+static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV  = dst;
+    const ggml_tensor * Q    = dst->src[0];
+    const ggml_tensor * K    = dst->src[1];
+    const ggml_tensor * mask = dst->src[3];
+
+    float max_bias = 0.0f;
+    memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
+
+    GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
+    const int gqa_ratio = Q->ne[2] / K->ne[2];
+
+    const bool nvidia = GGML_CUDA_CC_IS_NVIDIA(ggml_cuda_info().devices[ggml_cuda_get_device()].cc);
+    const int gqa_limit = nvidia && gqa_ratio <= 4 ? 16 : INT_MAX;
+    const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0;
+
+    if constexpr (DV == 512) {
+        if (use_gqa_opt && gqa_ratio % 16 == 0) {
+            launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
+            return;
+        }
+    }
+
+    if constexpr (DV <= 256) {
+        if (use_gqa_opt && gqa_ratio % 8 == 0) {
+            launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);
+            return;
+        }
+
+        if (use_gqa_opt && gqa_ratio % 4 == 0) {
+            launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);
+            return;
+        }
+
+        if (use_gqa_opt && gqa_ratio % 2 == 0) {
+            launch_fattn_tile_switch_ncols1<DKQ, DV, 2, use_logit_softcap>(ctx, dst);
+            return;
+        }
+
+        launch_fattn_tile_switch_ncols1<DKQ, DV, 1, use_logit_softcap>(ctx, dst);
+        return;
+    }
+    GGML_ABORT("fatal error");
+}
+
+template <int DKQ, int DV>
+void ggml_cuda_flash_attn_ext_tile_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV = dst;
+
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        launch_fattn_tile_switch_ncols2<DKQ, DV, use_logit_softcap>(ctx, dst);
+    } else {
+        constexpr bool use_logit_softcap = true;
+        launch_fattn_tile_switch_ncols2<DKQ, DV, use_logit_softcap>(ctx, dst);
+    }
+}
 
 void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+#define DECL_FATTN_TILE_CASE(DKQ, DV)                             \
+    template void ggml_cuda_flash_attn_ext_tile_case              \
+    <DKQ, DV>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
+
+extern DECL_FATTN_TILE_CASE( 40,  40);
+extern DECL_FATTN_TILE_CASE( 64,  64);
+extern DECL_FATTN_TILE_CASE( 80,  80);
+extern DECL_FATTN_TILE_CASE( 96,  96);
+extern DECL_FATTN_TILE_CASE(112, 112);
+extern DECL_FATTN_TILE_CASE(128, 128);
+extern DECL_FATTN_TILE_CASE(256, 256);
+extern DECL_FATTN_TILE_CASE(576, 512);
index 1848d088361850436a682d46fe357eed02b7dd58..7235f1b77aeedcc426a8793a8f456461be28bc3d 100644 (file)
@@ -1,3 +1,5 @@
+#pragma once
+
 #include "common.cuh"
 
 #if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
index 0c8e7b3e41904e3e89ee6fbf66b9f1088265d86d..fe970adaecef3a523cd6409fc4da36f3389a77c9 100644 (file)
@@ -198,6 +198,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
     return BEST_FATTN_KERNEL_NONE;
 #endif// FLASH_ATTN_AVAILABLE
 
+    const ggml_tensor * KQV   = dst;
     const ggml_tensor * Q     = dst->src[0];
     const ggml_tensor * K     = dst->src[1];
     const ggml_tensor * V     = dst->src[2];
@@ -206,37 +207,32 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
     const int gqa_ratio = Q->ne[2] / K->ne[2];
     GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
 
-    const int cc = ggml_cuda_info().devices[device].cc;
+    float max_bias = 0.0f;
+    memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
 
-    // TODO: temporary until support is extended
-    //       https://github.com/ggml-org/llama.cpp/pull/16148#issuecomment-3343525206
-    if (K->ne[1] % FATTN_KQ_STRIDE != 0) {
-        return BEST_FATTN_KERNEL_NONE;
-    }
+    // The effective batch size for the kernel can be increased by gqa_ratio.
+    // The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded,
+    const bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
+
+    const int cc = ggml_cuda_info().devices[device].cc;
 
     switch (K->ne[0]) {
+        case  40:
         case  64:
-        case 128:
-        case 256:
-            if (V->ne[0] != K->ne[0]) {
-                return BEST_FATTN_KERNEL_NONE;
-            }
-            break;
         case  80:
         case  96:
+        case 128:
         case 112:
+        case 256:
             if (V->ne[0] != K->ne[0]) {
                 return BEST_FATTN_KERNEL_NONE;
             }
-            if (!ggml_cuda_should_use_wmma_fattn(cc) && !turing_mma_available(cc)) {
-                return BEST_FATTN_KERNEL_NONE;
-            }
             break;
         case 576:
             if (V->ne[0] != 512) {
                 return BEST_FATTN_KERNEL_NONE;
             }
-            if (!turing_mma_available(cc) || gqa_ratio % 16 != 0) {
+            if (!gqa_opt_applies || gqa_ratio % 16 != 0) {
                 return BEST_FATTN_KERNEL_NONE;
             }
             break;
@@ -270,47 +266,57 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
         return BEST_FATTN_KERNEL_NONE;
     }
 
-    const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0;
-
-    // If Turing tensor cores available, use them except for some cases with batch size 1:
-    if (turing_mma_available(cc)) {
-        best_fattn_kernel best = BEST_FATTN_KERNEL_MMA_F16;
+    // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes:
+    const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;
 
+    // If Turing tensor cores available, use them:
+    if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40) {
         if (can_use_vector_kernel) {
             if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
                 if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
-                    best = BEST_FATTN_KERNEL_VEC;
+                    return BEST_FATTN_KERNEL_VEC;
                 }
             } else {
                 if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
                     if (Q->ne[1] <= 2) {
-                        best = BEST_FATTN_KERNEL_VEC;
+                        return BEST_FATTN_KERNEL_VEC;
                     }
                 } else {
                     if (Q->ne[1] == 1) {
-                        best = BEST_FATTN_KERNEL_VEC;
+                        return BEST_FATTN_KERNEL_VEC;
                     }
                 }
             }
-            if ((gqa_ratio % 2 != 0 || !mask) && Q->ne[1] == 1) {
-                best = BEST_FATTN_KERNEL_VEC; // GQA-specific optimizations in the mma kernel do not apply.
+            if (!gqa_opt_applies && Q->ne[1] == 1) {
+                return BEST_FATTN_KERNEL_VEC;
             }
         }
 
-        return best;
-    }
-
-    // Use kernels specialized for small batch sizes if possible:
-    if (Q->ne[1] <= 8 && can_use_vector_kernel) {
-        return BEST_FATTN_KERNEL_VEC;
+        return BEST_FATTN_KERNEL_MMA_F16;
     }
 
-    // For large batch sizes, use the WMMA kernel if possible:
-    if (ggml_cuda_should_use_wmma_fattn(cc)) {
+    // Use the WMMA kernel if possible:
+    if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 576) {
+        if (can_use_vector_kernel && Q->ne[1] <= 2) {
+            return BEST_FATTN_KERNEL_VEC;
+        }
         return BEST_FATTN_KERNEL_WMMA_F16;
     }
 
-    // If there is no suitable kernel for tensor cores or small batch sizes, use the generic kernel for large batch sizes:
+    // If there are no tensor cores available, use the generic tile kernel:
+    if (can_use_vector_kernel) {
+        if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
+            if (Q->ne[1] == 1) {
+                if (!gqa_opt_applies) {
+                    return BEST_FATTN_KERNEL_VEC;
+                }
+            }
+        } else {
+            if (Q->ne[1] <= 2) {
+                return BEST_FATTN_KERNEL_VEC;
+            }
+        }
+    }
     return BEST_FATTN_KERNEL_TILE;
 }
 
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu
new file mode 100644 (file)
index 0000000..a8b15ad
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.cuh"
+
+DECL_FATTN_TILE_CASE(112, 112);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu
new file mode 100644 (file)
index 0000000..1da1810
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.cuh"
+
+DECL_FATTN_TILE_CASE(128, 128);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu
new file mode 100644 (file)
index 0000000..bc65c72
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.cuh"
+
+DECL_FATTN_TILE_CASE(256, 256);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu
new file mode 100644 (file)
index 0000000..10b330f
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.cuh"
+
+DECL_FATTN_TILE_CASE(40, 40);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu
new file mode 100644 (file)
index 0000000..254b7d2
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.cuh"
+
+DECL_FATTN_TILE_CASE(576, 512);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu
new file mode 100644 (file)
index 0000000..5caffac
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.cuh"
+
+DECL_FATTN_TILE_CASE(64, 64);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu
new file mode 100644 (file)
index 0000000..90abb3b
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.cuh"
+
+DECL_FATTN_TILE_CASE(80, 80);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu
new file mode 100644 (file)
index 0000000..7292c0a
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.cuh"
+
+DECL_FATTN_TILE_CASE(96, 96);
index d410080fab841042497d584d91ffa051bde4e984..81a986f38cacff056f6cd5e432e7565e923f69f7 100755 (executable)
@@ -3,8 +3,17 @@
 from glob import glob
 import os
 
+HEAD_SIZES_KQ = [40, 64, 80, 96, 112, 128, 256, 576]
+
 TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0"]
 
+SOURCE_FATTN_TILE = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.cuh"
+
+DECL_FATTN_TILE_CASE({head_size_kq}, {head_size_v});
+"""
+
 SOURCE_FATTN_VEC = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec.cuh"
@@ -51,6 +60,11 @@ def get_short_name(long_quant_name):
 for filename in glob("*.cu"):
     os.remove(filename)
 
+for head_size_kq in HEAD_SIZES_KQ:
+    head_size_v = head_size_kq if head_size_kq != 576 else 512
+    with open(f"fattn-tile-instance-dkq{head_size_kq}-dv{head_size_v}.cu", "w") as f:
+        f.write(SOURCE_FATTN_TILE.format(head_size_kq=head_size_kq, head_size_v=head_size_v))
+
 for type_k in TYPES_KV:
     for type_v in TYPES_KV:
         with open(f"fattn-vec-instance-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f:
@@ -64,7 +78,9 @@ for ncols in [8, 16, 32, 64]:
         with open(f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu", "w") as f:
             f.write(SOURCE_FATTN_MMA_START)
 
-            for head_size_kq in [64, 80, 96, 112, 128, 256, 576]:
+            for head_size_kq in HEAD_SIZES_KQ:
+                if head_size_kq == 40:
+                    continue
                 if head_size_kq != 576 and ncols2 == 16:
                     continue
                 if head_size_kq == 576 and ncols2 != 16:
index 0e2b1847e09e239e37d982929b29babe94a6528b..934aefdcb45fa258984de24b110285f6192a3702 100644 (file)
@@ -53,6 +53,8 @@ file(GLOB   GGML_HEADERS_ROCM "../ggml-cuda/*.cuh")
 list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h")
 
 file(GLOB   GGML_SOURCES_ROCM "../ggml-cuda/*.cu")
+file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-tile*.cu")
+list(APPEND GGML_SOURCES_ROCM ${SRCS})
 file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
 list(APPEND GGML_SOURCES_ROCM ${SRCS})
 file(GLOB   SRCS "../ggml-cuda/template-instances/mmq*.cu")
index f8477a2ef356da3eb53b0550828892e5b5c7a484..d76cb51977f9000cda679fef04b8d8e072a2bd09 100644 (file)
@@ -30,6 +30,8 @@ if (MUSAToolkit_FOUND)
     list(APPEND GGML_HEADERS_MUSA "../ggml-musa/mudnn.cuh")
 
     file(GLOB   GGML_SOURCES_MUSA "../ggml-cuda/*.cu")
+    file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-tile*.cu")
+    list(APPEND GGML_SOURCES_MUSA ${SRCS})
     file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
     list(APPEND GGML_SOURCES_MUSA ${SRCS})
     file(GLOB   SRCS "../ggml-cuda/template-instances/mmq*.cu")