]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: fix overflow in FA, tune performance (llama/14840)
authorJohannes Gäßler <redacted>
Wed, 23 Jul 2025 19:43:25 +0000 (21:43 +0200)
committerGeorgi Gerganov <redacted>
Thu, 24 Jul 2025 17:57:40 +0000 (20:57 +0300)
src/ggml-cuda/fattn-common.cuh
src/ggml-cuda/fattn-mma-f16.cuh
src/ggml-cuda/fattn-tile-f16.cu
src/ggml-cuda/fattn-tile-f32.cu
src/ggml-cuda/fattn-vec-f16.cuh
src/ggml-cuda/fattn-vec-f32.cuh
src/ggml-cuda/fattn-wmma-f16.cu
src/ggml-cuda/fattn.cu

index 3644ddf2fdf360f852be265161b65f07431639c7..95e704e393c2a60b975c97549936b1a5b9698fec 100644 (file)
@@ -23,33 +23,13 @@ typedef void (* fattn_kernel_t)(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int ne32,
-        const int ne33,
-        const int nb31,
-        const int nb32,
-        const int nb33,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3);
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+                            const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+                            const int32_t nb11, const int32_t nb12, const int64_t nb13,
+                            const int32_t nb21, const int32_t nb22, const int64_t nb23,
+                            const int32_t ne31, const int32_t ne32, const int32_t ne33,
+                            const int32_t nb31, const int32_t nb32, const int64_t nb33);
 
 typedef half (*vec_dot_KQ_f16_t)(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
@@ -892,14 +872,11 @@ void launch_fattn(
         mask ? ((const char *) mask->data) : nullptr,
         !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
-        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
-        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0,
-        Q->nb[1], Q->nb[2], Q->nb[3],
-        nb11, nb12, nb13,
+        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
+        K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13,
         nb21, nb22, nb23,
-        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
+        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0
     );
     CUDA_CHECK(cudaGetLastError());
 
index 6fa2e77299eb0c648c165a356e4efd40ca93ec7b..565853bfecdef8c60603d3867f1ba5ac0243842e 100644 (file)
@@ -408,7 +408,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         const int stride_K,
         const int stride_V,
         const int stride_mask,
-        const int jt,
         half2        * const __restrict__ tile_Q,
         half2        * const __restrict__ tile_K,
         half2        * const __restrict__ tile_V,
@@ -455,7 +454,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         cp_async_wait_all();
         __syncthreads();
         flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
-            (V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
+            (V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V);
     } else {
         constexpr bool use_cp_async = nstages == 1;
         if (ncols2 > 1 || mask_h2) {
@@ -471,7 +470,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         if (nstages <= 1) {
             constexpr bool use_cp_async = nstages == 1;
             flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
-                (K_h2 + k_VKQ_0*stride_K + k0_start, tile_K, k0_diff, stride_K);
+                (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K);
             if (use_cp_async) {
                 cp_async_wait_all();
             }
@@ -715,7 +714,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                     (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
             }
             flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
-                (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
+                (K_h2 + int64_t(k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
         }
     }
 
@@ -732,7 +731,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         if (nstages <= 1 && i0_start < reusable_cutoff) {
             constexpr bool use_cp_async = nstages == 1;
             flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
-                (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
+                (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
             if (use_cp_async) {
                 cp_async_wait_all();
             }
@@ -771,8 +770,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
     GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V);
-    GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
-    GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
+    GGML_UNUSED(stride_mask); GGML_UNUSED(tile_K);
     GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B);
     GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
     GGML_UNUSED(kb0); GGML_UNUSED(tile_Q);
@@ -920,7 +918,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
                 (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
         }
         flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
-            (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
+            (K_h2 + int64_t(kb0_start)*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
     }
 
     // Iterate over ne11 == previous tokens:
@@ -928,13 +926,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         constexpr bool last_iter = false;
         flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
             (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
-             ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
+             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
     }
     { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
         constexpr bool last_iter = true;
         flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
             (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
-             ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
+             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
     }
 
     // With multi-stage loading there is no __syncthreads at the end of the iter,
@@ -1214,33 +1212,13 @@ static __global__ void flash_attn_ext_f16(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int ne32,
-        const int ne33,
-        const int nb31,
-        const int nb32,
-        const int nb33,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+                            const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+                            const int32_t nb11, const int32_t nb12, const int64_t nb13,
+                            const int32_t nb21, const int32_t nb22, const int64_t nb23,
+                            const int32_t ne31, const int32_t ne32, const int32_t ne33,
+                            const int32_t nb31, const int32_t nb32, const int64_t nb33) {
 #if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
 
     // Skip unused kernel variants for faster compilation:
@@ -1359,8 +1337,7 @@ static __global__ void flash_attn_ext_f16(
     GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
     GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
     GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
-    GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
-    GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+    GGML_UNUSED(nb22); GGML_UNUSED(nb23);
     NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
 }
index 1f141328845a485e1c449d5a85e5a59a5ac733af..7661c21efbbdd9840fd9f2f851eba09e1e09c553 100644 (file)
@@ -21,33 +21,13 @@ static __global__ void flash_attn_tile_ext_f16(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int ne32,
-        const int ne33,
-        const int nb31,
-        const int nb32,
-        const int nb33,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+                            const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+                            const int32_t nb11, const int32_t nb12, const int64_t nb13,
+                            const int32_t nb21, const int32_t nb22, const int64_t nb23,
+                            const int32_t ne31, const int32_t ne32, const int32_t ne33,
+                            const int32_t nb31, const int32_t nb32, const int64_t nb33) {
 #if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 
     // Skip unused kernel variants for faster compilation:
@@ -127,7 +107,7 @@ static __global__ void flash_attn_tile_ext_f16(
             for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
                 const int k_KQ = k_KQ_0 + threadIdx.x;
 
-                KV_tmp[i_KQ][k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
+                KV_tmp[i_KQ][k_KQ] = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
             }
         }
 
@@ -221,7 +201,7 @@ static __global__ void flash_attn_tile_ext_f16(
             for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
                 const int i = i0 + threadIdx.x;
 
-                KV_tmp[k][i] = V_h2[(k_VKQ_0 + k)*stride_KV2 + i];
+                KV_tmp[k][i] = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i];
             }
         }
 
@@ -300,8 +280,7 @@ static __global__ void flash_attn_tile_ext_f16(
     GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
     GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
-    GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
-    GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+    GGML_UNUSED(nb23);
     NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }
index a4965583cef1cfd00f8488348c1935d470b9d15d..2e2ed5cd566a2c3edacd9297c70b2a479b7f30ca 100644 (file)
@@ -21,33 +21,13 @@ static __global__ void flash_attn_tile_ext_f32(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int ne32,
-        const int ne33,
-        const int nb31,
-        const int nb32,
-        const int nb33,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+                            const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+                            const int32_t nb11, const int32_t nb12, const int64_t nb13,
+                            const int32_t nb21, const int32_t nb22, const int64_t nb23,
+                            const int32_t ne31, const int32_t ne32, const int32_t ne33,
+                            const int32_t nb31, const int32_t nb32, const int64_t nb33) {
 #ifdef FLASH_ATTN_AVAILABLE
 
     // Skip unused kernel variants for faster compilation:
@@ -66,8 +46,7 @@ static __global__ void flash_attn_tile_ext_f32(
         GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
         GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
         GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
-        GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
-        GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+        GGML_UNUSED(nb23);
         NO_DEVICE_CODE;
         return;
     }
@@ -135,7 +114,7 @@ static __global__ void flash_attn_tile_ext_f32(
 
 #pragma unroll
             for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 2*WARP_SIZE) {
-                const half2 tmp = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x];
+                const half2 tmp = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x];
                 KV_tmp[i_KQ][k_KQ_0 + 0*WARP_SIZE + threadIdx.x] =  __low2float(tmp);
                 KV_tmp[i_KQ][k_KQ_0 + 1*WARP_SIZE + threadIdx.x] = __high2float(tmp);
             }
@@ -231,8 +210,9 @@ static __global__ void flash_attn_tile_ext_f32(
             for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
                 const int i = i0 + threadIdx.x;
 
-                KV_tmp2[k*(D/2) + i].x =  __low2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
-                KV_tmp2[k*(D/2) + i].y = __high2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
+                const half2 tmp = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i];
+                KV_tmp2[k*(D/2) + i].x =  __low2float(tmp);
+                KV_tmp2[k*(D/2) + i].y = __high2float(tmp);
             }
         }
 
@@ -312,7 +292,6 @@ static __global__ void flash_attn_tile_ext_f32(
     GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
     GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
     GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
-    GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
     NO_DEVICE_CODE;
 #endif // FLASH_ATTN_AVAILABLE
 }
index b2d469938abf29d03b263fb1c56e7688a422481b..f6ef236be98100cc35c803200b7dcc16dcff846c 100644 (file)
@@ -18,33 +18,13 @@ static __global__ void flash_attn_vec_ext_f16(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int ne32,
-        const int ne33,
-        const int nb31,
-        const int nb32,
-        const int nb33,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+                            const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+                            const int32_t nb11, const int32_t nb12, const int64_t nb13,
+                            const int32_t nb21, const int32_t nb22, const int64_t nb23,
+                            const int32_t ne31, const int32_t ne32, const int32_t ne33,
+                            const int32_t nb31, const int32_t nb32, const int64_t nb33) {
 #if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 
     // Skip unused kernel variants for faster compilation:
@@ -191,13 +171,16 @@ static __global__ void flash_attn_vec_ext_f16(
 
     half2 VKQ[ncols] = {{0.0f, 0.0f}};
 
+    K     += blockIdx.y*D * nb11;
+    V     += blockIdx.y*D * nb21;
+    maskh += blockIdx.y*D;
     for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
         // Calculate KQ tile and keep track of new maximum KQ values:
 
         if (mask) {
 #pragma unroll
             for (int j = 0; j < ncols; ++j) {
-                maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + k_VKQ_0 + tid];
+                maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + tid];
             }
 
             __syncthreads();
@@ -244,7 +227,7 @@ static __global__ void flash_attn_vec_ext_f16(
 
 #pragma unroll
             for (int j = 0; j < ncols; ++j) {
-                half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
+                half sum = vec_dot_KQ(K + i_KQ*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
                 sum = warp_reduce_sum((float)sum);
 
                 if (use_logit_softcap) {
@@ -300,14 +283,18 @@ static __global__ void flash_attn_vec_ext_f16(
             }
 
             half2 V_k;
-            reinterpret_cast<half&>(V_k.x) = dequantize_1_v(V + (k_VKQ_0 + k0 + 0)*nb21, tid);
-            reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (k_VKQ_0 + k0 + 1)*nb21, tid);
+            reinterpret_cast<half&>(V_k.x) = dequantize_1_v(V + (k0 + 0)*nb21, tid);
+            reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (k0 + 1)*nb21, tid);
 #pragma unroll
             for (int j = 0; j < ncols; ++j) {
                 VKQ[j] += V_k*KQ2[j*(D/2) + k0/2];
             }
         }
 
+        K     += gridDim.y*D * nb11;
+        V     += gridDim.y*D * nb21;
+        maskh += gridDim.y*D;
+
         __syncthreads();
     }
 
@@ -351,8 +338,7 @@ static __global__ void flash_attn_vec_ext_f16(
     GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
     GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
-    GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
-    GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+    GGML_UNUSED(nb23);
     NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }
index 405b6f5106ea0761c608fe9fc77d1c0c4d1fbe74..6a4bdc0ff9aac847f4e3b4e74aefdbe39bddaa6c 100644 (file)
@@ -18,33 +18,13 @@ static __global__ void flash_attn_vec_ext_f32(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int ne32,
-        const int ne33,
-        const int nb31,
-        const int nb32,
-        const int nb33,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+                            const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+                            const int32_t nb11, const int32_t nb12, const int64_t nb13,
+                            const int32_t nb21, const int32_t nb22, const int64_t nb23,
+                            const int32_t ne31, const int32_t ne32, const int32_t ne33,
+                            const int32_t nb31, const int32_t nb32, const int64_t nb33) {
 #ifdef FLASH_ATTN_AVAILABLE
 
     // Skip unused kernel variants for faster compilation:
@@ -59,8 +39,7 @@ static __global__ void flash_attn_vec_ext_f32(
         GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
         GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
         GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
-        GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
-        GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+        GGML_UNUSED(nb23);
         NO_DEVICE_CODE;
         return;
     }
@@ -198,13 +177,16 @@ static __global__ void flash_attn_vec_ext_f32(
 
     float VKQ[ncols] = {0.0f};
 
+    K     += blockIdx.y*D * nb11;
+    V     += blockIdx.y*D * nb21;
+    maskh += blockIdx.y*D;
     for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
         // Calculate KQ tile and keep track of new maximum KQ values:
 
         if (mask) {
 #pragma unroll
             for (int j = 0; j < ncols; ++j) {
-                maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + k_VKQ_0 + tid]);
+                maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + tid]);
             }
 
             __syncthreads();
@@ -246,7 +228,7 @@ static __global__ void flash_attn_vec_ext_f32(
 
 #pragma unroll
             for (int j = 0; j < ncols; ++j) {
-                float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
+                float sum = vec_dot_KQ(K + i_KQ*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
                 sum = warp_reduce_sum(sum);
 
                 if (use_logit_softcap) {
@@ -297,13 +279,17 @@ static __global__ void flash_attn_vec_ext_f32(
                 break;
             }
 
-            const float V_ki = dequantize_1_v(V + (k_VKQ_0 + k)*nb21, tid);
+            const float V_ki = dequantize_1_v(V + k*nb21, tid);
 #pragma unroll
             for (int j = 0; j < ncols; ++j) {
                 VKQ[j] += V_ki*KQ[j*D + k];
             }
         }
 
+        K     += gridDim.y*D * nb11;
+        V     += gridDim.y*D * nb21;
+        maskh += gridDim.y*D;
+
         __syncthreads();
     }
 
@@ -348,7 +334,6 @@ static __global__ void flash_attn_vec_ext_f32(
     GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
     GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
     GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
-    GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
     NO_DEVICE_CODE;
 #endif // FLASH_ATTN_AVAILABLE
 }
index 741b8781d29f5727e3076c36d70785f7941156ea..c9b083bed014bb717cc126eb0f4ccde125dea3e3 100644 (file)
@@ -37,33 +37,13 @@ static __global__ void flash_attn_ext_f16(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int ne32,
-        const int ne33,
-        const int nb31,
-        const int nb32,
-        const int nb33,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+                            const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+                            const int32_t nb11, const int32_t nb12, const int64_t nb13,
+                            const int32_t nb21, const int32_t nb22, const int64_t nb23,
+                            const int32_t ne31, const int32_t ne32, const int32_t ne33,
+                            const int32_t nb31, const int32_t nb32, const int64_t nb33) {
 #if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
@@ -197,7 +177,7 @@ static __global__ void flash_attn_ext_f16(
 #pragma unroll
             for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
                 frag_a_K K_a;
-                wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+                wmma::load_matrix_sync(K_a, K_h + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
 #pragma unroll
                 for (int j = 0; j < ncols/frag_n; ++j) {
                     wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
@@ -344,7 +324,7 @@ static __global__ void flash_attn_ext_f16(
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
 
                 frag_a_V v_a;
-                wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
+                wmma::load_matrix_sync(v_a, V_h + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
 #pragma unroll
                 for (int j = 0; j < ncols/frag_n; ++j) {
                     wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
@@ -451,7 +431,6 @@ static __global__ void flash_attn_ext_f16(
     GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
     GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
-    GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
     NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
 }
index 6bc0096cc65e6229caf18879b07571fec1346f23..d9f1613051d3a14085f5498de0b05a79b9260845 100644 (file)
@@ -280,22 +280,12 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
-    if (GGML_CUDA_CC_IS_AMD(cc)) {
 #if defined(GGML_HIP_ROCWMMA_FATTN)
-        if (fp16_mma_available(cc)) {
-            ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
-            return;
-        }
-#endif // defined(GGML_HIP_ROCWMMA_FATTN)
-
-        // On AMD the tile kernels perform poorly, use the vec kernel instead:
-        if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
-            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
-        } else {
-            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
-        }
+    if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) {
+        ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
         return;
     }
+#endif // defined(GGML_HIP_ROCWMMA_FATTN)
 
     if (!fast_fp16_available(cc)) {
         if (Q->ne[1] <= 8 || Q->ne[0] == 256) {