]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: broadcasting for FlashAttention mask (llama/14500)
authorJohannes Gäßler <redacted>
Wed, 2 Jul 2025 11:42:12 +0000 (13:42 +0200)
committerGeorgi Gerganov <redacted>
Sat, 12 Jul 2025 13:05:00 +0000 (16:05 +0300)
src/ggml-cuda/fattn-common.cuh
src/ggml-cuda/fattn-mma-f16.cuh
src/ggml-cuda/fattn-tile-f16.cu
src/ggml-cuda/fattn-tile-f32.cu
src/ggml-cuda/fattn-vec-f16.cuh
src/ggml-cuda/fattn-vec-f32.cuh
src/ggml-cuda/fattn-wmma-f16.cu

index cfab2b5ebaccc6000ddc0e10a55e1ac18bbedd51..075f14a49e9ac91ce131370da5cac7b941937dfa 100644 (file)
@@ -32,7 +32,9 @@ typedef void (* fattn_kernel_t)(
         const int ne12,
         const int ne13,
         const int ne31,
+        const int ne32,
         const int nb31,
+        const int nb32,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -851,7 +853,8 @@ void launch_fattn(
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-        mask ? mask->ne[1] : 0, mask ?  mask->nb[1] : 0,
+        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0,
+        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0,
         Q->nb[1], Q->nb[2], Q->nb[3],
         nb11, nb12, nb13,
         nb21, nb22, nb23,
index e230f6d494d77fdceec51d94ad79bcff46d03601..709589854f0afa7c2b89f725b2330d568652e045 100644 (file)
@@ -1223,7 +1223,9 @@ static __global__ void flash_attn_ext_f16(
         const int ne12,
         const int ne13,
         const int ne31,
+        const int ne32,
         const int nb31,
+        const int nb32,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -1288,7 +1290,8 @@ static __global__ void flash_attn_ext_f16(
 
         const float2 * Q_f2    = (const float2 *) (Q + nb02* channel*ncols2);
         const half2  * K_h2    = (const half2  *) (K + nb12*(channel*ncols2 / gqa_ratio));
-        const half2  * mask_h2 = ncols2 > 1 || mask ? (const half2  *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
+        const half2  * mask_h2 = ncols2 == 1 && !mask ? nullptr :
+            (const half2  *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
         float2       * dstk    = ((float2 *) dst) + channel*(ncols2 * DV/2);
 
         const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
@@ -1327,7 +1330,8 @@ static __global__ void flash_attn_ext_f16(
 
     const float2 * Q_f2    = (const float2 *) (Q + nb02* channel*ncols2);
     const half2  * K_h2    = (const half2  *) (K + nb12*(channel*ncols2 / gqa_ratio));
-    const half2  * mask_h2 = ncols2 > 1 || mask ? (const half2  *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
+    const half2  * mask_h2 = ncols2 == 1 && !mask ? nullptr :
+        (const half2  *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
     float2       * dstk    = ((float2 *) dst) + channel*(ncols2 * DV/2);
 
     const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
@@ -1348,8 +1352,8 @@ static __global__ void flash_attn_ext_f16(
     GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
     GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10);
-    GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
+    GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
     GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
     GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
     GGML_UNUSED(ne2); GGML_UNUSED(ne3);
index 9283560d5c4ee3e4e0adb37fd8eeee1124967a53..0c967f178e7b17a7e9b42d505e87bcda961ef62f 100644 (file)
@@ -6,7 +6,7 @@
 
 template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-__launch_bounds__(nwarps*WARP_SIZE, 1)
+__launch_bounds__(nwarps*WARP_SIZE, 2)
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_tile_ext_f16(
         const char * __restrict__ Q,
@@ -30,7 +30,9 @@ static __global__ void flash_attn_tile_ext_f16(
         const int ne12,
         const int ne13,
         const int ne31,
+        const int ne32,
         const int nb31,
+        const int nb32,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -64,7 +66,7 @@ static __global__ void flash_attn_tile_ext_f16(
     const float2 * Q_f2  = (const float2 *) (Q    + nb02* blockIdx.z              + nb01*ic0);
     const half2  * K_h2  = (const half2  *) (K    + nb12*(blockIdx.z / gqa_ratio));
     const half2  * V_h2  = (const half2  *) (V    + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
-    const half   * maskh = (const half   *)  mask + ne11*ic0;
+    const half   * maskh = (const half   *) (mask + nb32*(blockIdx.z % ne32)      + nb31*ic0);
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
@@ -288,8 +290,8 @@ static __global__ void flash_attn_tile_ext_f16(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
     GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
     GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
     GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
index 32673adb57fc1d7d74d9830372baf72f92c5504d..124d5d3e89122ae1debef81ae88cdc4270eda7ab 100644 (file)
@@ -6,7 +6,7 @@
 
 template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-__launch_bounds__(nwarps*WARP_SIZE, 1)
+__launch_bounds__(nwarps*WARP_SIZE, 2)
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_tile_ext_f32(
         const char * __restrict__ Q,
@@ -30,7 +30,9 @@ static __global__ void flash_attn_tile_ext_f32(
         const int ne12,
         const int ne13,
         const int ne31,
+        const int ne32,
         const int nb31,
+        const int nb32,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -58,8 +60,8 @@ static __global__ void flash_attn_tile_ext_f32(
         GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
         GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
         GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-        GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
-        GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+        GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
+        GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
         GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
         GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
         GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
@@ -76,7 +78,7 @@ static __global__ void flash_attn_tile_ext_f32(
     const float2 * Q_f2  = (const float2 *) (Q    + nb02* blockIdx.z              + nb01*ic0);
     const half2  * K_h2  = (const half2  *) (K    + nb12*(blockIdx.z / gqa_ratio));
     const half2  * V_h2  = (const half2  *) (V    + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
-    const half   * maskh = (const half   *)  mask + ne11*ic0;
+    const half   * maskh = (const half   *) (mask + nb32*(blockIdx.z % ne32)      + nb31*ic0);
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
index 35e649cb3c81bdbbd81914bc97345c654db3e65f..e78fb181919fda704f5f095c2f633e30e8dae0f5 100644 (file)
@@ -27,7 +27,9 @@ static __global__ void flash_attn_vec_ext_f16(
         const int ne12,
         const int ne13,
         const int ne31,
+        const int ne32,
         const int nb31,
+        const int nb32,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -68,7 +70,7 @@ static __global__ void flash_attn_vec_ext_f16(
     K += nb12*(blockIdx.z / gqa_ratio);
     V += nb22*(blockIdx.z / gqa_ratio);
 
-    const half * maskh = (const half   *)  mask + ne11*ic0;
+    const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
 
     const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);
@@ -342,8 +344,8 @@ static __global__ void flash_attn_vec_ext_f16(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
     GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
     GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
     GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
index 95396791779698ce15e692ef8b8e32bbc99ac396..c22baf41764d168e78e60e61b8ee4f3d11ce6193 100644 (file)
@@ -27,7 +27,9 @@ static __global__ void flash_attn_vec_ext_f32(
         const int ne12,
         const int ne13,
         const int ne31,
+        const int ne32,
         const int nb31,
+        const int nb32,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -51,8 +53,8 @@ static __global__ void flash_attn_vec_ext_f32(
         GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
         GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
         GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-        GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
-        GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+        GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
+        GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
         GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
         GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
         GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
@@ -79,7 +81,8 @@ static __global__ void flash_attn_vec_ext_f32(
     Q += nb02* blockIdx.z              + nb01*ic0;
     K += nb12*(blockIdx.z / gqa_ratio);
     V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape
-    const half * maskh = (const half   *)  mask + ne11*ic0;
+
+    const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
 
     const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
 
index f3b794c3644c8a5bb7343887a15eee1f13c6f034..c95ca7b1f285f54ede3d93c174fe2172e2d6f943 100644 (file)
@@ -46,7 +46,9 @@ static __global__ void flash_attn_ext_f16(
         const int ne12,
         const int ne13,
         const int ne31,
+        const int ne32,
         const int nb31,
+        const int nb32,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -94,11 +96,11 @@ static __global__ void flash_attn_ext_f16(
     constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
 
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float * Q_f   = (const float *) (Q + nb02* blockIdx.z              + nb01*ic0);
-    const half  * K_h   = (const half  *) (K + nb12*(blockIdx.z / gqa_ratio));
-    const half  * V_h   = (const half  *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
-    const half  * maskh = (const half  *)  mask + (nb31/sizeof(half))* ic0;
-    const half2 * mask2 = (const half2 *)  mask + (nb31/sizeof(half))*(ic0/2);
+    const float * Q_f   = (const float *) (Q    + nb02* blockIdx.z              + nb01*ic0);
+    const half  * K_h   = (const half  *) (K    + nb12*(blockIdx.z / gqa_ratio));
+    const half  * V_h   = (const half  *) (V    + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
+    const half  * maskh = (const half  *) (mask + nb32*(blockIdx.z % ne32)      + nb31*ic0);
+    const half2 * mask2 = (const half2 *)  maskh;
 
     const int stride_Q  = nb01 / sizeof(float);
     const int stride_KV = nb11 / sizeof(half);
@@ -440,7 +442,7 @@ static __global__ void flash_attn_ext_f16(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
     GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
-    GGML_UNUSED(ne31); GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
     GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
     GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);