]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
llama : add high-throughput mode (llama/14363)
authorGeorgi Gerganov <redacted>
Wed, 16 Jul 2025 13:35:42 +0000 (16:35 +0300)
committerGeorgi Gerganov <redacted>
Sat, 19 Jul 2025 14:47:23 +0000 (17:47 +0300)
* kv-cache : prepare K/V buffers for separation

ggml-ci

* batched-bench : fix oob write

ggml-ci

* llama : add "virtual sequences"

ggml-ci

* llama : use "stream" vs "virtual sequence"

ggml-ci

* graph : fix stream splitting when KV cache is not used

ggml-ci

* kv-cache : add multi-stream save/load support

ggml-ci

* llama : add "--attn-streams" flag

ggml-ci

* kv-cache : fix handling when find_slot fails

ggml-ci

* kv-cache : restore find_slot impl

ggml-ci

* kv-cache : add comments

* kv-cache : add bounds checks for sequence id

ggml-ci

* cont : add n_seq_max to batch allocr

ggml-ci

* kv-cache : perform stream copies lazily after llama_synchronize

ggml-ci

* kv-cache : avoid throwing exceptions across the C boundary

ggml-ci

* CUDA: 4D FlashAttention support (llama/14628)

* CUDA: 4D FlashAttention support

* CUDA: fix WMMA FA kernel

* llama : rename attn_streams -> kv_unified

ggml-ci

* common : rename kv_split -> kv_unified

ggml-ci

---------

Co-authored-by: Johannes Gäßler <redacted>
src/ggml-cuda/fattn-common.cuh
src/ggml-cuda/fattn-mma-f16.cuh
src/ggml-cuda/fattn-tile-f16.cu
src/ggml-cuda/fattn-tile-f32.cu
src/ggml-cuda/fattn-vec-f16.cuh
src/ggml-cuda/fattn-vec-f32.cuh
src/ggml-cuda/fattn-wmma-f16.cu
src/ggml-cuda/ggml-cuda.cu
tests/test-backend-ops.cpp

index 075f14a49e9ac91ce131370da5cac7b941937dfa..9122fca6cf99f3ed475167de211716a260bf70bb 100644 (file)
@@ -33,8 +33,10 @@ typedef void (* fattn_kernel_t)(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -521,7 +523,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
 template<int D, int ncols1, int ncols2> // D == head size
 __launch_bounds__(D, 1)
 static __global__ void flash_attn_stream_k_fixup(
-        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
+        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) {
     constexpr int ncols = ncols1*ncols2;
 
     const int bidx0 = blockIdx.x;
@@ -535,8 +537,8 @@ static __global__ void flash_attn_stream_k_fixup(
     const int iter_k = ne11 / FATTN_KQ_STRIDE;
     const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
 
-    const int kbc0      = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
-    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+    const int kbc0      = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
+    const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
 
     const bool did_not_have_any_data   = kbc0 == kbc0_stop;
     const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
@@ -545,14 +547,15 @@ static __global__ void flash_attn_stream_k_fixup(
         return;
     }
 
-    const int channel = kbc0 / (iter_k*iter_j);
-    const int jt      = (kbc0 - channel*iter_k*iter_j) / iter_k;
+    const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2));
+    const int head = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
+    const int jt = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
 
     if (jt*ncols1 + j >= ne01) {
         return;
     }
 
-    dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
+    dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid;
 
     // Load the partial result that needs a fixup:
     float dst_val = 0.0f;
@@ -571,7 +574,7 @@ static __global__ void flash_attn_stream_k_fixup(
     int bidx = bidx0 - 1;
     int kbc_stop = kbc0;
     while(true) {
-        const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+        const int kbc = bidx*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
         if (kbc == kbc_stop) { // Did not have any data.
             bidx--;
             kbc_stop = kbc;
@@ -617,16 +620,31 @@ static __global__ void flash_attn_combine_results(
         const float2 * __restrict__ VKQ_meta,
         float * __restrict__ dst,
         const int parallel_blocks) {
-    VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x;
-    VKQ_meta  += parallel_blocks   * gridDim.z*blockIdx.x;
-    dst       +=                 D * gridDim.z*blockIdx.x;
+    // Dimension 0: threadIdx.x
+    // Dimension 1: blockIdx.x
+    // Dimension 2: blockIdx.y
+    // Dimension 3: blockIdx.z
+    // Memory layout is permuted with [0, 2, 1, 3]
+
+    const int ne01 = gridDim.x;
+    const int ne02 = gridDim.y;
+
+    const int col      = blockIdx.x;
+    const int head     = blockIdx.y;
+    const int sequence = blockIdx.z;
+
+    const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head;
+
+    VKQ_parts += j_dst_unrolled * parallel_blocks*D;
+    VKQ_meta  += j_dst_unrolled * parallel_blocks;
+    dst       += j_dst_unrolled *                 D;
 
     const int tid = threadIdx.x;
     __builtin_assume(tid < D);
 
     extern __shared__ float2 meta[];
     for (int i = tid; i < 2*parallel_blocks; i += D) {
-        ((float *) meta)[i] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + i];
+        ((float *) meta)[i] = ((const float *)VKQ_meta) [i];
     }
 
     __syncthreads();
@@ -644,11 +662,11 @@ static __global__ void flash_attn_combine_results(
         const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
         *((uint32_t *) &KQ_max_scale) &= ftz_mask;
 
-        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid];
+        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*D + tid];
         VKQ_denominator += KQ_max_scale * meta[l].y;
     }
 
-    dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
+    dst[tid] = VKQ_numerator / VKQ_denominator;
 }
 
 [[noreturn]]
@@ -705,8 +723,6 @@ void launch_fattn(
 
     GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
 
-    GGML_ASSERT(Q->ne[3] == 1);
-
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
     const int id  = ggml_cuda_get_device();
@@ -853,8 +869,8 @@ void launch_fattn(
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0,
-        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0,
+        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
+        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0,
         Q->nb[1], Q->nb[2], Q->nb[3],
         nb11, nb12, nb13,
         nb21, nb22, nb23,
@@ -869,11 +885,11 @@ void launch_fattn(
 
             flash_attn_stream_k_fixup<DV, ncols1, ncols2>
                 <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
-                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
+                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]);
         }
     } else if (parallel_blocks > 1) {
         const dim3 block_dim_combine(DV, 1, 1);
-        const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
+        const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
         const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
 
         flash_attn_combine_results<DV>
index 709589854f0afa7c2b89f725b2330d568652e045..6fa2e77299eb0c648c165a356e4efd40ca93ec7b 100644 (file)
@@ -1224,8 +1224,10 @@ static __global__ void flash_attn_ext_f16(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -1274,8 +1276,8 @@ static __global__ void flash_attn_ext_f16(
     constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice.
 
     // kbc == k block continuous, current index in continuous ijk space.
-    int       kbc      = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
-    const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+    int       kbc      = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
+    const int kbc_stop = (blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
 
     // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
     // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
@@ -1285,18 +1287,19 @@ static __global__ void flash_attn_ext_f16(
     int kb0_start = kbc % iter_k;
     int kb0_stop  = min(iter_k, kb0_start + kbc_stop - kbc);
     while (kbc < kbc_stop && kb0_stop == iter_k) {
-        const int channel = kbc / (iter_k*iter_j);
-        const int jt      = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
+        const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
+        const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
+        const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
 
-        const float2 * Q_f2    = (const float2 *) (Q + nb02* channel*ncols2);
-        const half2  * K_h2    = (const half2  *) (K + nb12*(channel*ncols2 / gqa_ratio));
+        const float2 * Q_f2    = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
+        const half2  * K_h2    = (const half2  *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
         const half2  * mask_h2 = ncols2 == 1 && !mask ? nullptr :
-            (const half2  *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
-        float2       * dstk    = ((float2 *) dst) + channel*(ncols2 * DV/2);
+            (const half2  *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
+        float2       * dstk    = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
 
-        const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
+        const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
 
-        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
+        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
 
         const int kb0_start_kernel = kb0_start * kb_niter;
         const int kb0_stop_kernel  = kb0_stop  * kb_niter;
@@ -1325,18 +1328,19 @@ static __global__ void flash_attn_ext_f16(
         return;
     }
 
-    const int channel = kbc / (iter_k*iter_j);
-    const int jt      = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
+    const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
+    const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
+    const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
 
-    const float2 * Q_f2    = (const float2 *) (Q + nb02* channel*ncols2);
-    const half2  * K_h2    = (const half2  *) (K + nb12*(channel*ncols2 / gqa_ratio));
+    const float2 * Q_f2    = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
+    const half2  * K_h2    = (const half2  *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
     const half2  * mask_h2 = ncols2 == 1 && !mask ? nullptr :
-        (const half2  *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
-    float2       * dstk    = ((float2 *) dst) + channel*(ncols2 * DV/2);
+        (const half2  *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
+    float2       * dstk    = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
 
-    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
+    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
 
-    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
+    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
 
     const int kb0_start_kernel = kb0_start * kb_niter;
     const int kb0_stop_kernel  = kb0_stop  * kb_niter;
index 0c967f178e7b17a7e9b42d505e87bcda961ef62f..1f141328845a485e1c449d5a85e5a59a5ac733af 100644 (file)
@@ -31,8 +31,10 @@ static __global__ void flash_attn_tile_ext_f16(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -62,15 +64,17 @@ static __global__ void flash_attn_tile_ext_f16(
 
     const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2  = (const float2 *) (Q    + nb02* blockIdx.z              + nb01*ic0);
-    const half2  * K_h2  = (const half2  *) (K    + nb12*(blockIdx.z / gqa_ratio));
-    const half2  * V_h2  = (const half2  *) (V    + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
-    const half   * maskh = (const half   *) (mask + nb32*(blockIdx.z % ne32)      + nb31*ic0);
+    const float2 * Q_f2  = (const float2 *) (Q    + nb03* sequence         + nb02* head              + nb01*ic0);
+    const half2  * K_h2  = (const half2  *) (K    + nb13* sequence         + nb12*(head / gqa_ratio));
+    const half2  * V_h2  = (const half2  *) (V    + nb13* sequence         + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half   * maskh = (const half   *) (mask + nb33*(sequence % ne33)                           + nb31*ic0);
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
-    const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
+    const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -255,6 +259,8 @@ static __global__ void flash_attn_tile_ext_f16(
         __syncthreads();
     }
 
+    float2 * dst2 = (float2 *) dst;
+
 #pragma unroll
     for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
         const int j_VKQ = j_VKQ_0 + threadIdx.y;
@@ -266,21 +272,21 @@ static __global__ void flash_attn_tile_ext_f16(
         half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
         kqsum_j = warp_reduce_sum((float)kqsum_j);
 
+        const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
+
 #pragma unroll
-        for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
-            const int i0 = i00 + 2*threadIdx.x;
+        for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
+            const int i0 = i00 + threadIdx.x;
 
-            half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
+            half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
             if (gridDim.y == 1) {
                 dst_val /= __half2half2(kqsum_j);
             }
-            const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
-            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] =  __low2float(dst_val);
-            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = __high2float(dst_val);
+            dst2[j_dst_unrolled*(D/2) + i0] = __half22float2(dst_val);
         }
 
         if (gridDim.y != 1 && threadIdx.x == 0) {
-            dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
+            dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
         }
     }
 #else
@@ -290,8 +296,8 @@ static __global__ void flash_attn_tile_ext_f16(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
     GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
     GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
     GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
index 908c76dbdd270431e433963312eae3e759c463b6..a4965583cef1cfd00f8488348c1935d470b9d15d 100644 (file)
@@ -31,8 +31,10 @@ static __global__ void flash_attn_tile_ext_f32(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -74,15 +76,17 @@ static __global__ void flash_attn_tile_ext_f32(
 
     const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2  = (const float2 *) (Q    + nb02* blockIdx.z              + nb01*ic0);
-    const half2  * K_h2  = (const half2  *) (K    + nb12*(blockIdx.z / gqa_ratio));
-    const half2  * V_h2  = (const half2  *) (V    + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
-    const half   * maskh = (const half   *) (mask + nb32*(blockIdx.z % ne32)      + nb31*ic0);
+    const float2 * Q_f2  = (const float2 *) (Q    + nb03* sequence         + nb02* head              + nb01*ic0);
+    const half2  * K_h2  = (const half2  *) (K    + nb13* sequence         + nb12*(head / gqa_ratio));
+    const half2  * V_h2  = (const half2  *) (V    + nb13* sequence         + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half   * maskh = (const half   *) (mask + nb33*(sequence % ne33)                           + nb31*ic0);
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
-    const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
+    const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
 
@@ -265,6 +269,8 @@ static __global__ void flash_attn_tile_ext_f32(
         __syncthreads();
     }
 
+    float2 * dst2 = (float2 *) dst;
+
 #pragma unroll
     for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
         const int j_VKQ = j_VKQ_0 + threadIdx.y;
@@ -276,22 +282,22 @@ static __global__ void flash_attn_tile_ext_f32(
         float kqsum_j = kqsum[j_VKQ_0/nwarps];
         kqsum_j = warp_reduce_sum(kqsum_j);
 
+        const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
+
 #pragma unroll
-        for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
-            const int i0 = i00 + 2*threadIdx.x;
+        for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
+            const int i0 = i00 + threadIdx.x;
 
-            float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
+            float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
             if (gridDim.y == 1) {
                 dst_val.x /= kqsum_j;
                 dst_val.y /= kqsum_j;
             }
-            const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
-            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = dst_val.x;
-            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = dst_val.y;
+            dst2[j_dst_unrolled*(D/2) + i0] = dst_val;
         }
 
         if (gridDim.y != 1 && threadIdx.x == 0) {
-            dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
+            dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
         }
     }
 #else
index e78fb181919fda704f5f095c2f633e30e8dae0f5..b2d469938abf29d03b263fb1c56e7688a422481b 100644 (file)
@@ -28,8 +28,10 @@ static __global__ void flash_attn_vec_ext_f16(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -65,14 +67,16 @@ static __global__ void flash_attn_vec_ext_f16(
 
     const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    Q += nb02* blockIdx.z              + nb01*ic0;
-    K += nb12*(blockIdx.z / gqa_ratio);
-    V += nb22*(blockIdx.z / gqa_ratio);
+    Q += nb03*sequence + nb02* head              + nb01*ic0;
+    K += nb13*sequence + nb12*(head / gqa_ratio);
+    V += nb23*sequence + nb22*(head / gqa_ratio);
 
-    const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
+    const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
 
-    const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
+    const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -330,12 +334,11 @@ static __global__ void flash_attn_vec_ext_f16(
         if (gridDim.y == 1) {
             dst_val /= kqsum[j_VKQ];
         }
-        const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
-        dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
+        dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val;
     }
 
     if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
-        dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
+        dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
     }
 #else
     GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
@@ -344,8 +347,8 @@ static __global__ void flash_attn_vec_ext_f16(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
     GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne32);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
     GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
     GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
index b2f1724c955886858712bc6d66003d0456d3d32c..405b6f5106ea0761c608fe9fc77d1c0c4d1fbe74 100644 (file)
@@ -28,8 +28,10 @@ static __global__ void flash_attn_vec_ext_f32(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -53,8 +55,8 @@ static __global__ void flash_attn_vec_ext_f32(
         GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
         GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
         GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-        GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
-        GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+        GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
+        GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
         GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
         GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
         GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
@@ -77,14 +79,16 @@ static __global__ void flash_attn_vec_ext_f32(
 
     const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    Q += nb02* blockIdx.z              + nb01*ic0;
-    K += nb12*(blockIdx.z / gqa_ratio);
-    V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape
+    Q += nb03*sequence + nb02* head              + nb01*ic0;
+    K += nb13*sequence + nb12*(head / gqa_ratio);
+    V += nb23*sequence + nb22*(head / gqa_ratio);
 
-    const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
+    const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
 
-    const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
+    const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
     constexpr int nwarps = D / WARP_SIZE;
@@ -326,12 +330,11 @@ static __global__ void flash_attn_vec_ext_f32(
         if (gridDim.y == 1) {
             dst_val /= kqsum[j_VKQ];
         }
-        const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
-        dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
+        dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val;
     }
 
     if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
-        dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
+        dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
     }
 #else
     GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
@@ -340,8 +343,8 @@ static __global__ void flash_attn_vec_ext_f32(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
     GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
-    GGML_UNUSED(ne31); GGML_UNUSED(ne32);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb32);
+    GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
     GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
     GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
     GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
index c95ca7b1f285f54ede3d93c174fe2172e2d6f943..741b8781d29f5727e3076c36d70785f7941156ea 100644 (file)
@@ -47,8 +47,10 @@ static __global__ void flash_attn_ext_f16(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -95,17 +97,19 @@ static __global__ void flash_attn_ext_f16(
     constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
     constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
 
+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float * Q_f   = (const float *) (Q    + nb02* blockIdx.z              + nb01*ic0);
-    const half  * K_h   = (const half  *) (K    + nb12*(blockIdx.z / gqa_ratio));
-    const half  * V_h   = (const half  *) (V    + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
-    const half  * maskh = (const half  *) (mask + nb32*(blockIdx.z % ne32)      + nb31*ic0);
+    const float * Q_f   = (const float *) (Q    + nb03* sequence         + nb02* head              + nb01*ic0);
+    const half  * K_h   = (const half  *) (K    + nb13* sequence         + nb12*(head / gqa_ratio));
+    const half  * V_h   = (const half  *) (V    + nb13* sequence         + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half  * maskh = (const half  *) (mask + nb33*(sequence % ne33)                           + nb31*ic0);
     const half2 * mask2 = (const half2 *)  maskh;
 
     const int stride_Q  = nb01 / sizeof(float);
     const int stride_KV = nb11 / sizeof(half);
 
-    const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
+    const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);
     const half2 slope2 = make_half2(slopef, slopef);
 
@@ -400,7 +404,6 @@ static __global__ void flash_attn_ext_f16(
         if (ic0 + j_VKQ >= ne01) {
             return;
         }
-        const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
 
         float KQ_rowsum_j;
         if (std::is_same<KQ_acc_t, float>::value) {
@@ -409,6 +412,8 @@ static __global__ void flash_attn_ext_f16(
             KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
         }
 
+        const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
+
 #pragma unroll
         for (int i0 = 0; i0 < D; i0 += warp_size) {
             const int i = i0 + threadIdx.x;
@@ -419,7 +424,7 @@ static __global__ void flash_attn_ext_f16(
             if (gridDim.y == 1) {
                 dst_val /= KQ_rowsum_j;
             }
-            dst[j_dst*gridDim.z*D + blockIdx.z*D + i] = dst_val;
+            dst[j_dst_unrolled*D + i] = dst_val;
         }
 
         if (gridDim.y == 1 || threadIdx.x != 0) {
@@ -433,7 +438,7 @@ static __global__ void flash_attn_ext_f16(
             dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
         }
         dst_meta_val.y = KQ_rowsum_j;
-        dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = dst_meta_val;
+        dst_meta[j_dst_unrolled] = dst_meta_val;
     }
 #else
     GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
@@ -442,7 +447,8 @@ static __global__ void flash_attn_ext_f16(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
     GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
-    GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); GGML_UNUSED(nb31);
+    GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
     GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
     GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
index 8015b0d4e8d9245b439b5bd0a354d303411f7937..778d5a48bd9f85bbcbee26056370dc205df59cd8 100644 (file)
@@ -3413,12 +3413,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             if (op->src[0]->ne[0] == 192) {
                 return false;
             }
-            // TODO: support broadcast
-            // note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14500, but
-            //       the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
-            if (op->src[0]->ne[3] != 1) {
-                return false;
-            }
             if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
                 return false;
             }
@@ -3431,6 +3425,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
                 return true;
             }
+            if (op->src[3] && op->src[3]->ne[2] != 1) {
+                return false;
+            }
             return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
                 op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
         }
index 81fe90b99323d7768324ba2a13777d1733fcd408..a3d68fba046cf5e207e45e12662f518c61be5a5d 100644 (file)
@@ -4282,7 +4282,7 @@ struct test_flash_attn_ext : public test_case {
 
         ggml_tensor * m = nullptr;
         if (mask) {
-            m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), nr23[0], nr23[1]);
+            m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, nr23[1]);
             ggml_set_name(m, "m");
         }