]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: fix FA occupancy, optimize tile kernel (llama/15982)
authorJohannes Gäßler <redacted>
Wed, 17 Sep 2025 13:32:42 +0000 (15:32 +0200)
committerGeorgi Gerganov <redacted>
Sat, 20 Sep 2025 10:33:50 +0000 (13:33 +0300)
src/ggml-cuda/common.cuh
src/ggml-cuda/fattn-common.cuh
src/ggml-cuda/fattn-tile.cu
src/ggml-cuda/vendors/hip.h

index b0feea362380b4fecd5e222237ce72486c30a0ea..045c6d3006b2e084a90a7648e080c10db133bf45 100644 (file)
@@ -75,6 +75,8 @@
 #define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
 #define GGML_CUDA_CC_IS_GCN(cc)   (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA1)
 #define GGML_CUDA_CC_IS_CDNA(cc)  (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1)
+#define GGML_CUDA_CC_IS_CDNA1(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_CDNA2)
+#define GGML_CUDA_CC_IS_CDNA2(cc) (cc >= GGML_CUDA_CC_CDNA2 && cc < GGML_CUDA_CC_CDNA3)
 #define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1)
 
 // Moore Threads
@@ -325,6 +327,20 @@ static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
 #endif // defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
 }
 
+// Maximum number of bytes that can be copied in a single instruction.
+static constexpr __device__ int ggml_cuda_get_max_cpy_bytes() {
+#ifdef GGML_USE_HIP
+    return 16;
+#else
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+    return 16;
+#else
+    return 8;
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+#endif // GGML_USE_HIP
+}
+
+
 [[noreturn]]
 static __device__ void no_device_code(
     const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
index b69f57d659a266dbd626adf93640df63b672c086..142a3a88d1d7cd23326df54fec47109487e4f0be 100644 (file)
@@ -647,9 +647,7 @@ static __global__ void flash_attn_stream_k_fixup(
 }
 
 template<int D> // D == head size
-#if !defined(GGML_USE_HIP)
 __launch_bounds__(D, 1)
-#endif // !(defined(GGML_USE_HIP)
 static __global__ void flash_attn_combine_results(
         const float  * __restrict__ VKQ_parts,
         const float2 * __restrict__ VKQ_meta,
@@ -692,10 +690,7 @@ static __global__ void flash_attn_combine_results(
     float VKQ_numerator   = 0.0f;
     float VKQ_denominator = 0.0f;
     for (int l = 0; l < parallel_blocks; ++l) {
-        const float diff = meta[l].x - kqmax;
-        float KQ_max_scale = expf(diff);
-        const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
-        *((uint32_t *) &KQ_max_scale) &= ftz_mask;
+        const float KQ_max_scale = expf(meta[l].x - kqmax);
 
         VKQ_numerator   += KQ_max_scale * VKQ_parts[l*D + tid];
         VKQ_denominator += KQ_max_scale * meta[l].y;
@@ -836,11 +831,10 @@ void launch_fattn(
         CUDA_CHECK(cudaGetLastError());
     }
 
-    int parallel_blocks = 1;
-
     const dim3 block_dim(warp_size, nwarps, 1);
     int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
     CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
+    int parallel_blocks = max_blocks_per_sm;
 
     dim3 blocks_num;
     if (stream_k) {
@@ -862,9 +856,6 @@ void launch_fattn(
         GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
         const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
 
-        // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
-        parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
-
         // parallel_blocks must not be larger than what the tensor size allows:
         parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
 
index c6a399ce5d791ca397232ac88621e920985f287e..a2d9951ea56a729b058682f484d06ce16d9d3646 100644 (file)
@@ -2,20 +2,30 @@
 #include "fattn-common.cuh"
 #include "fattn-tile.cuh"
 
-#define FATTN_TILE_NTHREADS 256
+// kq_stride == number of KQ rows to process per iteration
+// kq_nbatch == number of K columns to load in parallel for KQ calculation
 
 static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int cc, const int warp_size) {
     if (GGML_CUDA_CC_IS_AMD(cc)) {
+        if (GGML_CUDA_CC_IS_RDNA(cc)) {
+            switch (D) {
+                case 64:
+                    return 128;
+                case 128:
+                case 256:
+                    return ncols <= 16 ? 128 : 64;
+                default:
+                    GGML_ABORT("fatal error");
+                    return -1;
+            }
+        }
         switch (D) {
             case 64:
-                return 64;
+                return ncols == 32 ? 128 : 64;
             case 128:
+                return ncols == 32 ? 64 : 32;
             case 256:
-                if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
-                    return ncols <= 16 ? 64 : 32;
-                } else {
-                    return 64;
-                }
+                return 32;
             default:
                 GGML_ABORT("fatal error");
                 return -1;
@@ -49,24 +59,28 @@ static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int
 
 static constexpr __device__ int fattn_tile_get_kq_stride_device(int D, int ncols, int warp_size) {
 #ifdef GGML_USE_HIP
+#ifdef RDNA
     switch (D) {
         case 64:
-            return 64;
+            return 128;
         case 128:
-#if defined(GCN) || defined(CDNA)
-            return ncols <= 16 ? 64 : 32;
-#else
-            return 64;
-#endif // defined(GCN) || defined(CDNA)
         case 256:
-#if defined(GCN) || defined(CDNA)
-            return ncols <= 16 ? 64 : 32;
+            return ncols <= 16 ? 128 : 64;
+        default:
+            return -1;
+    }
 #else
-            return 64;
-#endif // defined(GCN) || defined(CDNA)
+    switch (D) {
+        case 64:
+            return ncols == 32 ? 128 : 64;
+        case 128:
+            return ncols == 32 ? 64 : 32;
+        case 256:
+            return 32;
         default:
             return -1;
     }
+#endif // RDNA
 #else
 #ifdef FAST_FP16_AVAILABLE
     switch (D) {
@@ -100,17 +114,8 @@ static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols
         case 64:
             return 64;
         case 128:
-#if defined(GCN) || defined(CDNA)
-            return ncols <= 16 ? 64 : 128;
-#else
-            return 64;
-#endif // defined(GCN) || defined(CDNA)
         case 256:
-#if defined(GCN) || defined(CDNA)
-            return ncols <= 16 ? 64 : 128;
-#else
-            return ncols <= 16 ? 64 : 256;
-#endif // defined(GCN) || defined(CDNA)
+            return 128;
         default:
             return -1;
     }
@@ -120,9 +125,8 @@ static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols
         case 64:
             return 64;
         case 128:
-            return ncols <= 16 ? 128 : 64;
         case 256:
-            return ncols <= 16 ? 64 : 128;
+            return 128;
         default:
             return -1;
     }
@@ -142,12 +146,27 @@ static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols
     GGML_UNUSED_VARS(ncols, warp_size);
 }
 
-template<int D, int ncols, bool use_logit_softcap> // D == head size
-#ifdef GGML_USE_HIP
-__launch_bounds__(FATTN_TILE_NTHREADS, 1)
+static int fattn_tile_get_nthreads_host(const int cc, const int ncols) {
+    return 256;
+    GGML_UNUSED_VARS(cc, ncols);
+}
+
+static constexpr __device__ int fattn_tile_get_nthreads_device(int ncols) {
+    return 256;
+    GGML_UNUSED(ncols);
+}
+
+static constexpr __device__ int fattn_tile_get_occupancy_device(int ncols) {
+#ifdef RDNA
+    return 3;
 #else
-__launch_bounds__(FATTN_TILE_NTHREADS, 2)
-#endif // GGML_USE_HIP
+    return ncols <= 16 ? 3 : 2;
+#endif // RDNA
+    GGML_UNUSED(ncols);
+}
+
+template<int D, int ncols, bool use_logit_softcap> // D == head size
+__launch_bounds__(fattn_tile_get_nthreads_device(ncols), fattn_tile_get_occupancy_device(ncols))
 static __global__ void flash_attn_tile(
         const char * __restrict__ Q,
         const char * __restrict__ K,
@@ -193,7 +212,7 @@ static __global__ void flash_attn_tile(
     }
 
     constexpr int warp_size = 32;
-    constexpr int nwarps    = FATTN_TILE_NTHREADS / warp_size;
+    constexpr int nwarps    = fattn_tile_get_nthreads_device(ncols) / warp_size;
     constexpr int kq_stride = fattn_tile_get_kq_stride_device(D, ncols, warp_size);
     static_assert(kq_stride % warp_size == 0, "kq_stride not divisable by warp_size.");
     constexpr int kq_nbatch = fattn_tile_get_kq_nbatch_device(D, ncols, warp_size);
@@ -206,90 +225,126 @@ static __global__ void flash_attn_tile(
     const int sequence = blockIdx.z / ne02;
     const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2   = (const float2 *) (Q    + nb03* sequence         + nb02* head              + nb01*ic0);
-    const half2  * K_h2   = (const half2  *) (K    + nb13* sequence         + nb12*(head / gqa_ratio));
-    const half2  * V_h2   = (const half2  *) (V    + nb13* sequence         + nb12*(head / gqa_ratio)); // K and V have same shape
-    const half   * maskh  = (const half   *) (mask  + nb33*(sequence % ne33)                          + nb31*ic0);
-    const float  * sinksf = (const float  *) (sinks);
+    const float * Q_f    = (const float *) (Q    + nb03* sequence         + nb02* head              + nb01*ic0);
+    const half2 * K_h2   = (const half2 *) (K    + nb13* sequence         + nb12*(head / gqa_ratio));
+    const half2 * V_h2   = (const half2 *) (V    + nb13* sequence         + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half  * maskh  = (const half  *) (mask + nb33*(sequence % ne33)                           + nb31*ic0);
+    const float * sinksf = (const float *) (sinks);
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
     const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
 
-#if defined(GGML_USE_HIP)
-    constexpr int cpy_nb = 16;
-#else
-    constexpr int cpy_nb = 8;
-#endif // defined(GGML_USE_HIP) && defined(GCN)
+    constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
     constexpr int cpy_ne = cpy_nb / 4;
 
-    __shared__ float KQ[ncols][kq_stride];
+    constexpr int cpw = ncols/nwarps; // cols per warp
+
+    // softmax_iter_j == number of KQ columns for which to calculate softmax in parallel.
+    // KQ is originall 2D but uses a Z-shaped memory pattern for larger reads/writes.
 #ifdef FAST_FP16_AVAILABLE
+    constexpr int softmax_iter_j = cpw < 2*cpy_ne ? cpw : 2*cpy_ne;
+
+    __shared__ half  KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j];
     __shared__ half2 Q_tmp[ncols][D/2];
-    __shared__ half2 KV_tmp_h2[kq_stride * (kq_nbatch/2 + cpy_ne)]; // Padded to avoid memory bank conflicts.
-    half2 VKQ[ncols/nwarps][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
+    __shared__ half2 KV_tmp[kq_stride * (kq_nbatch/2 + cpy_ne)]; // Padded to avoid memory bank conflicts.
+    half2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
 #else
+    constexpr int softmax_iter_j = cpw < 1*cpy_ne ? cpw : 1*cpy_ne;
+
+    __shared__ float KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j];
     __shared__ float Q_tmp[ncols][D];
-    __shared__ float KV_tmp_f[kq_stride * (kq_nbatch + cpy_ne)]; // Padded to avoid memory bank conflicts.
-    float2 * KV_tmp_f2 = (float2 *) KV_tmp_f;
-    float2 VKQ[ncols/nwarps][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
+    __shared__ float KV_tmp[kq_stride * (kq_nbatch + cpy_ne)]; // Padded to avoid memory bank conflicts.
+    float2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
 #endif // FAST_FP16_AVAILABLE
+    static_assert(cpw % softmax_iter_j == 0, "bad softmax_iter_j");
 
-
-    float kqmax[ncols/nwarps];
+    float KQ_max[cpw];
 #pragma unroll
     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
-        kqmax[j0/nwarps] = -FLT_MAX/2.0f;
+        KQ_max[j0/nwarps] = -FLT_MAX/2.0f;
     }
-    float kqsum[ncols/nwarps] = {0.0f};
+    float KQ_sum[cpw] = {0.0f};
 
+    // Load Q data, convert to FP16 if fast.
 #pragma unroll
-    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
-        const int j = j0 + threadIdx.y;
+    for (int j0 = 0; j0 < cpw; ++j0) {
+        const int j = j0 + threadIdx.y*cpw;
+
+        constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
 
 #pragma unroll
-        for (int i0 = 0; i0 < D/2; i0 += warp_size) {
-            const float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0 + threadIdx.x] : make_float2(0.0f, 0.0f);
+        for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
+            float tmp_f[cpy_ne_D] = {0.0f};
+            if (ic0 + j < ne01) {
+                ggml_cuda_memcpy_1<sizeof(tmp_f)>(tmp_f, &Q_f[j*(nb01/sizeof(float)) + i0 + threadIdx.x*cpy_ne_D]);
+            }
+
+#pragma unroll
+            for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
+                tmp_f[i1] *= scale;
+            }
+
 #ifdef FAST_FP16_AVAILABLE
-            Q_tmp[j][i0 + threadIdx.x] = make_half2(tmp.x * scale, tmp.y * scale);
+            half2 tmp_h2[cpy_ne_D/2];
+#pragma unroll
+            for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) {
+                tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]);
+            }
+            ggml_cuda_memcpy_1<sizeof(tmp_h2)>(&Q_tmp[j][i0/2 + threadIdx.x*(cpy_ne_D/2)], tmp_h2);
 #else
-            Q_tmp[j][2*i0             + threadIdx.x] = tmp.x * scale;
-            Q_tmp[j][2*i0 + warp_size + threadIdx.x] = tmp.y * scale;
+            ggml_cuda_memcpy_1<sizeof(tmp_f)> (&Q_tmp[j][i0   + threadIdx.x* cpy_ne_D],    tmp_f);
 #endif // FAST_FP16_AVAILABLE
         }
     }
 
     __syncthreads();
 
+    // Main loop over KV cache:
     const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
     for (int k_VKQ_0 = blockIdx.y*kq_stride; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*kq_stride) {
         // Calculate KQ tile and keep track of new maximum KQ values:
 
-        float kqmax_new[ncols/nwarps];
+        float KQ_max_new[cpw];
 #pragma unroll
-        for (int j = 0; j < ncols/nwarps; ++j) {
-            kqmax_new[j] = kqmax[j];
+        for (int j = 0; j < cpw; ++j) {
+            KQ_max_new[j] = KQ_max[j];
         }
 
-        float sum[kq_stride/warp_size][ncols/nwarps] = {{0.0f}};
+        float KQ_acc[kq_stride/warp_size][cpw] = {{0.0f}}; // Accumulators for KQ matrix multiplication.
 
+        // KQ = K @ Q matrix multiplication:
 #pragma unroll
         for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += kq_nbatch) {
 #pragma unroll
             for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += nwarps) {
                 const int i_KQ = i_KQ_0 + threadIdx.y;
 
-#pragma unroll
-                for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += warp_size) {
-                    const half2 tmp_h2 = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1 + threadIdx.x];
 #ifdef FAST_FP16_AVAILABLE
-                    KV_tmp_h2[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1 + threadIdx.x] = tmp_h2;
+                constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/(2*warp_size) ? cpy_ne : kq_nbatch/(2*warp_size);
+#pragma unroll
+                for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += warp_size*cpy_ne_kqnb) {
+                    ggml_cuda_memcpy_1<cpy_ne_kqnb*4>(
+                        &KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb],
+                        &K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1 + threadIdx.x*cpy_ne_kqnb]);
+                }
 #else
-                    const float2 tmp_f2 = __half22float2(tmp_h2);
-                    KV_tmp_f[i_KQ*(kq_nbatch + cpy_ne) + 2*k_KQ_1             + threadIdx.x] = tmp_f2.x;
-                    KV_tmp_f[i_KQ*(kq_nbatch + cpy_ne) + 2*k_KQ_1 + warp_size + threadIdx.x] = tmp_f2.y;
-#endif // FAST_FP16_AVAILABLE
+                constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/warp_size ? cpy_ne : kq_nbatch/warp_size;
+#pragma unroll
+                for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += warp_size*cpy_ne_kqnb) {
+                    half2 tmp_h2[cpy_ne_kqnb/2];
+                    ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
+                        tmp_h2, &K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1/2 + threadIdx.x*(cpy_ne_kqnb/2)]);
+
+                    float2 tmp_f2[cpy_ne_kqnb/2];
+#pragma unroll
+                    for (int k_KQ_2 = 0; k_KQ_2 < cpy_ne_kqnb/2; ++k_KQ_2) {
+                        tmp_f2[k_KQ_2] = __half22float2(tmp_h2[k_KQ_2]);
+                    }
+                    ggml_cuda_memcpy_1<sizeof(tmp_f2)>(
+                        &KV_tmp[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb], tmp_f2);
                 }
+#endif // FAST_FP16_AVAILABLE
             }
 
             __syncthreads();
@@ -298,12 +353,12 @@ static __global__ void flash_attn_tile(
 #pragma unroll
             for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += cpy_ne) {
                 half2 K_k[kq_stride/warp_size][cpy_ne];
-                half2 Q_k[ncols/nwarps][cpy_ne];
+                half2 Q_k[cpw][cpy_ne];
 #else
 #pragma unroll
             for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += cpy_ne) {
                 float K_k[kq_stride/warp_size][cpy_ne];
-                float Q_k[ncols/nwarps][cpy_ne];
+                float Q_k[cpw][cpy_ne];
 #endif // FAST_FP16_AVAILABLE
 
 #pragma unroll
@@ -311,29 +366,29 @@ static __global__ void flash_attn_tile(
                     const int i_KQ = i_KQ_0 + threadIdx.x;
 
 #ifdef FAST_FP16_AVAILABLE
-                    ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp_h2[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1]);
+                    ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1]);
 #else
-                    ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp_f [i_KQ*(kq_nbatch   + cpy_ne) + k_KQ_1]);
+                    ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch   + cpy_ne) + k_KQ_1]);
 #endif // FAST_FP16_AVAILABLE
                 }
 #pragma unroll
-                for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
-                    const int j_KQ = j_KQ_0 + threadIdx.y;
+                for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
+                    const int j_KQ = j_KQ_0 + threadIdx.y*cpw;
 
 #ifdef FAST_FP16_AVAILABLE
-                    ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0/nwarps], &Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1]);
+                    ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1]);
 #else
-                    ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0/nwarps], &Q_tmp[j_KQ][k_KQ_0   + k_KQ_1]);
+                    ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0   + k_KQ_1]);
 #endif // FAST_FP16_AVAILABLE
                 }
 
 #pragma unroll
                 for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
 #pragma unroll
-                    for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
+                    for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
 #pragma unroll
                         for (int k = 0; k < cpy_ne; ++k) {
-                            ggml_cuda_mad(sum[i_KQ_0/warp_size][j_KQ_0/nwarps], K_k[i_KQ_0/warp_size][k], Q_k[j_KQ_0/nwarps][k]);
+                            ggml_cuda_mad(KQ_acc[i_KQ_0/warp_size][j_KQ_0], K_k[i_KQ_0/warp_size][k], Q_k[j_KQ_0][k]);
                         }
                     }
                 }
@@ -344,104 +399,77 @@ static __global__ void flash_attn_tile(
             }
         }
 
+        // Apply logit softcap, mask, update KQ_max:
 #pragma unroll
         for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
             const int i_KQ = i_KQ_0 + threadIdx.x;
 
 #pragma unroll
-            for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
-                const int j_KQ = j_KQ_0 + threadIdx.y;
+            for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
+                const int j_KQ = j_KQ_0 + threadIdx.y*cpw;
 
                 if (use_logit_softcap) {
-                    sum[i_KQ_0/warp_size][j_KQ_0/nwarps] = logit_softcap * tanhf(sum[i_KQ_0/warp_size][j_KQ_0/nwarps]);
+                    KQ_acc[i_KQ_0/warp_size][j_KQ_0] = logit_softcap * tanhf(KQ_acc[i_KQ_0/warp_size][j_KQ_0]);
                 }
 
-                sum[i_KQ_0/warp_size][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
-
-                kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/warp_size][j_KQ_0/nwarps]);
+                KQ_acc[i_KQ_0/warp_size][j_KQ_0] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
 
-                KQ[j_KQ][i_KQ] = sum[i_KQ_0/warp_size][j_KQ_0/nwarps];
+                KQ_max_new[j_KQ_0] = fmaxf(KQ_max_new[j_KQ_0], KQ_acc[i_KQ_0/warp_size][j_KQ_0]);
             }
         }
 
         __syncthreads();
 
+        // Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators:
 #pragma unroll
-        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
-            const int j = j0 + threadIdx.y;
-
-            kqmax_new[j0/nwarps] = warp_reduce_max<warp_size>(kqmax_new[j0/nwarps]);
-            const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new[j0/nwarps]);
-            kqmax[j0/nwarps] = kqmax_new[j0/nwarps];
-
-            float kqsum_add = 0.0f;
-            if (kq_stride % (4*warp_size) == 0 && cpy_ne % 4 == 0) {
-#pragma unroll
-                for (int i0 = 0; i0 < kq_stride; i0 += 4*warp_size) {
-                    const int i = i0 + 4*threadIdx.x;
-
-                    float4 val = *(const float4 *) &KQ[j][i];
-                    val.x = expf(val.x - kqmax[j0/nwarps]);
-                    val.y = expf(val.y - kqmax[j0/nwarps]);
-                    val.z = expf(val.z - kqmax[j0/nwarps]);
-                    val.w = expf(val.w - kqmax[j0/nwarps]);
-                    kqsum_add += val.x + val.y + val.z + val.w;
-
+        for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
 #ifdef FAST_FP16_AVAILABLE
-                    const half2 tmp[2] = {make_half2(val.x, val.y), make_half2(val.z, val.w)};
-                    ggml_cuda_memcpy_1<sizeof(tmp)>(&KQ[j][i/2], &tmp);
+            half  tmp[kq_stride/warp_size][softmax_iter_j];
 #else
-                    ggml_cuda_memcpy_1<sizeof(val)>(&KQ[j][i], &val);
+            float tmp[kq_stride/warp_size][softmax_iter_j];
 #endif // FAST_FP16_AVAILABLE
-                }
-            } else if (kq_stride % (2*warp_size) == 0 && cpy_ne % 2 == 0) {
+
 #pragma unroll
-                for (int i0 = 0; i0 < kq_stride; i0 += 2*warp_size) {
-                    const int i = i0 + 2*threadIdx.x;
+            for (int j1 = 0; j1 < softmax_iter_j; ++j1) {
+                KQ_max_new[j0+j1] = warp_reduce_max<warp_size>(KQ_max_new[j0+j1]);
+                const float KQ_max_scale = expf(KQ_max[j0+j1] - KQ_max_new[j0+j1]);
+                KQ_max[j0+j1] = KQ_max_new[j0+j1];
 
-                    float2 val = *(const float2 *) &KQ[j][i];
-                    val.x = expf(val.x - kqmax[j0/nwarps]);
-                    val.y = expf(val.y - kqmax[j0/nwarps]);
-                    kqsum_add += val.x + val.y;
-#ifdef FAST_FP16_AVAILABLE
-                    const half2 tmp = make_half2(val.x, val.y);
-                    ggml_cuda_memcpy_1<sizeof(tmp)>(&KQ[j][i/2], &tmp);
-#else
-                    ggml_cuda_memcpy_1<sizeof(val)>(&KQ[j][i], &val);
-#endif // FAST_FP16_AVAILABLE
-                }
-            } else {
+                float KQ_sum_add = 0.0f;
+#pragma unroll
                 for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
-                    const int i = i0 + threadIdx.x;
+                    const float val = expf(KQ_acc[i0/warp_size][j0+j1] - KQ_max[j0+j1]);
+                    KQ_sum_add += val;
+                    tmp[i0/warp_size][j1] = val;
+                }
+                KQ_sum[j0+j1] = KQ_sum[j0+j1]*KQ_max_scale + KQ_sum_add;
 
-                    const float diff = KQ[j][i] - kqmax[j0/nwarps];
-                    const float val = expf(diff);
-                    kqsum_add += val;
 #ifdef FAST_FP16_AVAILABLE
-                    ((half *) KQ[j])[i] = val;
+                const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
+#pragma unroll
+                for (int i0 = 0; i0 < D/2; i0 += warp_size) {
+                    VKQ[j0+j1][i0/warp_size] *= KQ_max_scale_h2;
+                }
 #else
-                    KQ[j][i] = val;
-#endif // FAST_FP16_AVAILABLE
+#pragma unroll
+                for (int i0 = 0; i0 < D/2; i0 += warp_size) {
+                    VKQ[j0+j1][i0/warp_size].x *= KQ_max_scale;
+                    VKQ[j0+j1][i0/warp_size].y *= KQ_max_scale;
                 }
+#endif // FAST_FP16_AVAILABLE
             }
-            kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add;
 
-#ifdef FAST_FP16_AVAILABLE
-            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
 #pragma unroll
-            for (int i0 = 0; i0 < D/2; i0 += warp_size) {
-                VKQ[j0/nwarps][i0/warp_size] *= KQ_max_scale_h2;
-            }
-#else
-#pragma unroll
-            for (int i0 = 0; i0 < D/2; i0 += warp_size) {
-                VKQ[j0/nwarps][i0/warp_size].x *= KQ_max_scale;
-                VKQ[j0/nwarps][i0/warp_size].y *= KQ_max_scale;
+            for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
+                const int i = i0 + threadIdx.x;
+
+                ggml_cuda_memcpy_1<sizeof(tmp[0])>(
+                    KQ[j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j)][i], tmp[i0/warp_size]);
             }
-#endif // FAST_FP16_AVAILABLE
         }
 
-        constexpr int V_cols_per_iter = kq_stride*kq_nbatch / D;
+        // VKQ = V @ KQ matrix multiplication:
+        constexpr int V_cols_per_iter = kq_stride*kq_nbatch / D; // Number of V columns that fit in SRAM for K.
         static_assert(kq_stride % V_cols_per_iter == 0, "bad V_cols_per_iter");
 #pragma unroll
         for (int k0 = 0; k0 < kq_stride; k0 += V_cols_per_iter) {
@@ -449,65 +477,96 @@ static __global__ void flash_attn_tile(
             for (int k1 = 0; k1 < V_cols_per_iter; k1 += nwarps) {
                 const int k_tile = k1 + threadIdx.y;
 
-#pragma unroll
-                for (int i0 = 0; i0 < D/2; i0 += warp_size) {
-                    const int i = i0 + threadIdx.x;
-
-                    const half2 tmp = V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i];
 #ifdef FAST_FP16_AVAILABLE
-                    KV_tmp_h2[k_tile*(D/2) + i] = tmp;
+                constexpr int cpy_ne_D = cpy_ne < D/(2*warp_size) ? cpy_ne : D/(2*warp_size);
+#pragma unroll
+                for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
+                    ggml_cuda_memcpy_1<cpy_ne_D*4>(
+                        &KV_tmp[k_tile*(D/2) + i0 + threadIdx.x*cpy_ne_D],
+                        &V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0 + threadIdx.x*cpy_ne_D]);
+                }
 #else
-                    KV_tmp_f2[k_tile*(D/2) + i] = __half22float2(tmp);
-#endif // FAST_FP16_AVAILABLE
+                constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
+#pragma unroll
+                for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
+                    half2 tmp_h2[cpy_ne_D/2];
+                    ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
+                        tmp_h2, &V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0/2 + threadIdx.x*(cpy_ne_D/2)]);
+
+                    float2 tmp_f2[cpy_ne_D/2];
+#pragma unroll
+                    for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) {
+                        tmp_f2[i1] = __half22float2(tmp_h2[i1]);
+                    }
+                    ggml_cuda_memcpy_1<sizeof(tmp_f2)>(
+                        &KV_tmp[k_tile*D + i0 + threadIdx.x*cpy_ne_D], tmp_f2);
                 }
+#endif // FAST_FP16_AVAILABLE
             }
 
             __syncthreads();
 
+#ifdef FAST_FP16_AVAILABLE
 #pragma unroll
             for (int k1 = 0; k1 < V_cols_per_iter; ++k1) {
-#ifdef FAST_FP16_AVAILABLE
                 half2 V_k[(D/2)/warp_size];
-                half2 KQ_k[ncols/nwarps];
-#else
-                float2 V_k[(D/2)/warp_size];
-                float  KQ_k[ncols/nwarps];
-#endif // FAST_FP16_AVAILABLE
+                half2 KQ_k[cpw];
 
+                constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? cpy_ne/2 : (D/2)/warp_size;
 #pragma unroll
-                for (int i0 = 0; i0 < D/2; i0 += warp_size) {
-                    const int i = i0 + threadIdx.x;
+                for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
+                    ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/warp_size], &KV_tmp[k1*(D/2) + i0 + threadIdx.x*cpy_ne_D]);
+                }
+#pragma unroll
+                for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
+                    const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j);
 
-#ifdef FAST_FP16_AVAILABLE
-                    V_k[i0/warp_size] = KV_tmp_h2[k1*(D/2) + i];
+                    half tmp[softmax_iter_j];
+                    ggml_cuda_memcpy_1<softmax_iter_j*sizeof(half)>(
+                        &tmp, KQ[j][k0 + k1]);
+#pragma unroll
+                    for (int j1 = 0; j1 < softmax_iter_j; ++j1) {
+                        KQ_k[j0+j1] = __half2half2(tmp[j1]);
+                    }
+                }
+
+#pragma unroll
+                for (int i0 = 0; i0 < D/2; i0 += warp_size) {
+#pragma unroll
+                    for (int j0 = 0; j0 < cpw; ++j0) {
+                        VKQ[j0][i0/warp_size] += V_k[i0/warp_size]*KQ_k[j0];
+                    }
+                }
+            }
 #else
-                    V_k[i0/warp_size] = KV_tmp_f2[k1*(D/2) + i];
-#endif // FAST_FP16_AVAILABLE
+#pragma unroll
+            for (int k1 = 0; k1 < V_cols_per_iter; ++k1) {
+                float2 V_k[(D/2)/warp_size];
+                float  KQ_k[cpw];
+
+                constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
+#pragma unroll
+                for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
+                    ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/(2*warp_size)], &KV_tmp[k1*D + i0 + threadIdx.x*cpy_ne_D]);
                 }
 #pragma unroll
-                for (int j0 = 0; j0 < ncols; j0 += nwarps) {
-                    const int j = j0 + threadIdx.y;
+                for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
+                    const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j);
 
-#ifdef FAST_FP16_AVAILABLE
-                    KQ_k[j0/nwarps] = __half2half2(((const half *)KQ[j])[k0 + k1]);
-#else
-                    KQ_k[j0/nwarps] = KQ[j][k0 + k1];
-#endif // FAST_FP16_AVAILABLE
+                    ggml_cuda_memcpy_1<softmax_iter_j*sizeof(float)>(
+                        &KQ_k[j0], KQ[j][k0 + k1]);
                 }
 
 #pragma unroll
                 for (int i0 = 0; i0 < D/2; i0 += warp_size) {
 #pragma unroll
-                    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
-#ifdef FAST_FP16_AVAILABLE
-                        VKQ[j0/nwarps][i0/warp_size]   += V_k[i0/warp_size]  *KQ_k[j0/nwarps];
-#else
-                        VKQ[j0/nwarps][i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[j0/nwarps];
-                        VKQ[j0/nwarps][i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[j0/nwarps];
-#endif // FAST_FP16_AVAILABLE
+                    for (int j0 = 0; j0 < cpw; ++j0) {
+                        VKQ[j0][i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[j0];
+                        VKQ[j0][i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[j0];
                     }
                 }
             }
+#endif // FAST_FP16_AVAILABLE
 
             __syncthreads();
         }
@@ -519,69 +578,92 @@ static __global__ void flash_attn_tile(
         const float sink = sinksf[head];
 
 #pragma unroll
-        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
-            float kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
-            kqmax_new_j = warp_reduce_max<warp_size>(kqmax_new_j);
+        for (int j0 = 0; j0 < cpw; ++j0) {
+            float KQ_max_new_j = fmaxf(KQ_max[j0], sink);
+            KQ_max_new_j = warp_reduce_max<warp_size>(KQ_max_new_j);
 
-            const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new_j);
-            kqmax[j0/nwarps] = kqmax_new_j;
+            const float KQ_max_scale = expf(KQ_max[j0] - KQ_max_new_j);
+            KQ_max[j0] = KQ_max_new_j;
 
-            const float val = expf(sink - kqmax[j0/nwarps]);
-            kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
+            const float val = expf(sink - KQ_max[j0]);
+            KQ_sum[j0] = KQ_sum[j0] * KQ_max_scale;
             if (threadIdx.x == 0) {
-                kqsum[j0/nwarps] += val;
+                KQ_sum[j0] += val;
             }
 
 #ifdef FAST_FP16_AVAILABLE
             const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
 #pragma unroll
             for (int i0 = 0; i0 < D/2; i0 += warp_size) {
-                VKQ[j0/nwarps][i0/warp_size] *= KQ_max_scale_h2;
+                VKQ[j0][i0/warp_size] *= KQ_max_scale_h2;
             }
 #else
 #pragma unroll
             for (int i0 = 0; i0 < D/2; i0 += warp_size) {
-                VKQ[j0/nwarps][i0/warp_size].x *= KQ_max_scale;
-                VKQ[j0/nwarps][i0/warp_size].y *= KQ_max_scale;
+                VKQ[j0][i0/warp_size].x *= KQ_max_scale;
+                VKQ[j0][i0/warp_size].y *= KQ_max_scale;
             }
 #endif // FAST_FP16_AVAILABLE
         }
     }
 
-    float2 * dst2 = (float2 *) dst;
+#pragma unroll
+    for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
+        KQ_sum[j_VKQ_0] = warp_reduce_sum<warp_size>(KQ_sum[j_VKQ_0]);
+    }
+    if (gridDim.y == 1) {
+#pragma unroll
+        for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
+#ifdef FAST_FP16_AVAILABLE
+            const half2 KQ_sum_j_inv = make_half2(1.0f/KQ_sum[j_VKQ_0], 1.0f/KQ_sum[j_VKQ_0]);
+#pragma unroll
+            for (int i = 0; i < (D/2)/warp_size; ++i) {
+                VKQ[j_VKQ_0][i] *= KQ_sum_j_inv;
+            }
+#else
+            const float KQ_sum_j_inv = 1.0f/KQ_sum[j_VKQ_0];
+#pragma unroll
+            for (int i = 0; i < (D/2)/warp_size; ++i) {
+                VKQ[j_VKQ_0][i].x *= KQ_sum_j_inv;
+                VKQ[j_VKQ_0][i].y *= KQ_sum_j_inv;
+            }
+#endif // FAST_FP16_AVAILABLE
+        }
+    }
 
+    // Write back results:
 #pragma unroll
-    for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
-        const int j_VKQ = j_VKQ_0 + threadIdx.y;
+    for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
+        const int j_VKQ = j_VKQ_0 + threadIdx.y*cpw;
 
         if (ic0 + j_VKQ >= ne01) {
             return;
         }
 
-        float kqsum_j = kqsum[j_VKQ_0/nwarps];
-        kqsum_j = warp_reduce_sum<warp_size>(kqsum_j);
-
         const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
 
-#pragma unroll
-        for (int i00 = 0; i00 < D/2; i00 += warp_size) {
-            const int i0 = i00 + threadIdx.x;
-
 #ifdef FAST_FP16_AVAILABLE
-            float2 dst_val = __half22float2(VKQ[j_VKQ_0/nwarps][i0/warp_size]);
-#else
-            float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/warp_size];
-#endif // FAST_FP16_AVAILABLE
-
-            if (gridDim.y == 1) {
-                dst_val.x /= kqsum_j;
-                dst_val.y /= kqsum_j;
+        constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? cpy_ne/2 : (D/2)/warp_size;
+#pragma unroll
+        for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
+            float2 tmp[cpy_ne_D];
+#pragma unroll
+            for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
+                tmp[i1] = __half22float2(VKQ[j_VKQ_0][i0/warp_size + i1]);
             }
-            dst2[j_dst_unrolled*(D/2) + i0] = dst_val;
+            ggml_cuda_memcpy_1<sizeof(tmp)>(&dst[j_dst_unrolled*D + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp);
         }
+#else
+        constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
+#pragma unroll
+        for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
+            ggml_cuda_memcpy_1<cpy_ne_D*4>(
+                &dst[j_dst_unrolled*D + i0 + threadIdx.x*cpy_ne_D], &VKQ[j_VKQ_0][i0/(2*warp_size)]);
+        }
+#endif // FAST_FP16_AVAILABLE
 
         if (gridDim.y != 1 && threadIdx.x == 0) {
-            dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
+            dst_meta[j_dst_unrolled] = make_float2(KQ_max[j_VKQ_0], KQ_sum[j_VKQ_0]);
         }
     }
 #else
@@ -602,15 +684,29 @@ template <int D, bool use_logit_softcap>
 static void launch_fattn_tile_switch_ncols(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
 
-    const int id                 = ggml_cuda_get_device();
-    const int cc                 = ggml_cuda_info().devices[id].cc;
-    const int warp_size          = 32;
-    const int nwarps             = FATTN_TILE_NTHREADS / warp_size;
+    const int id        = ggml_cuda_get_device();
+    const int cc        = ggml_cuda_info().devices[id].cc;
+    const int warp_size = 32;
 
     constexpr size_t nbytes_shared = 0;
 
+#ifdef GGML_USE_HIP
+    if constexpr (D <= 128) {
+        if (Q->ne[1] > 32) {
+            constexpr int cols_per_block = 64;
+            const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
+            fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
+            const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
+            launch_fattn<D, cols_per_block, 1>
+                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size);
+            return;
+        }
+    }
+#endif // GGML_USE_HIP
+
     if (Q->ne[1] > 16) {
         constexpr int cols_per_block = 32;
+        const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
         fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
         const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
         launch_fattn<D, cols_per_block, 1>
@@ -619,6 +715,7 @@ static void launch_fattn_tile_switch_ncols(ggml_backend_cuda_context & ctx, ggml
     }
 
     constexpr int cols_per_block = 16;
+    const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
     fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
     const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
     launch_fattn<D, cols_per_block, 1>
index 12bbee45566de393509d9c91400d3329d8161740..37386afcd405b73338055435a5fd68d60fc27226 100644 (file)
 
 #define __CUDA_ARCH__ 1300
 
-#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)
-#define GCN
-#endif
-
 #if defined(__gfx900__) || defined(__gfx906__)
 #define GCN5
-#endif
+#endif // defined(__gfx900__) || defined(__gfx906__)
 
 #if defined(__gfx803__)
 #define GCN4
-#endif
+#endif // defined(__gfx803__)
 
-#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
-#define CDNA // For the entire family
-#endif
+#if defined(GCN5) || defined(GCN4)
+#define GCN
+#endif // defined(GCN5) || defined(GCN4)
 
 #if defined(__gfx942__)
 #define CDNA3
-#endif
+#endif // defined(__gfx942__)
 
 #if defined(__gfx90a__)
 #define CDNA2
-#endif
+#endif // defined(__gfx90a__)
 
 #if defined(__gfx908__)
 #define CDNA1
-#endif
+#endif // defined(__gfx908__)
+
+#if defined(CDNA3) || defined(CDNA2) || defined(CDNA1)
+#define CDNA // For the entire family
+#endif // defined(CDNA3) || defined(CDNA2) || defined(CDNA1)
 
 #if defined(__GFX12__)
 #define RDNA4
-#endif
+#endif // defined(__GFX12__)
 
 #if defined(__GFX11__)
 #define RDNA3
-#endif
+#endif // defined(__GFX11__)
 
 #if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
     defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
 
 #if defined(__gfx1010__) || defined(__gfx1012__)
 #define RDNA1
-#endif
+#endif // defined(__gfx1010__) || defined(__gfx1012__)
+
+#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(RDNA1)
+#define RDNA // For the entire family
+#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(RDNA1)
 
 #ifndef __has_builtin
     #define __has_builtin(x) 0