]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : add high-throughput mode (#14363)
authorGeorgi Gerganov <redacted>
Wed, 16 Jul 2025 13:35:42 +0000 (16:35 +0300)
committerGitHub <redacted>
Wed, 16 Jul 2025 13:35:42 +0000 (16:35 +0300)
* kv-cache : prepare K/V buffers for separation

ggml-ci

* batched-bench : fix oob write

ggml-ci

* llama : add "virtual sequences"

ggml-ci

* llama : use "stream" vs "virtual sequence"

ggml-ci

* graph : fix stream splitting when KV cache is not used

ggml-ci

* kv-cache : add multi-stream save/load support

ggml-ci

* llama : add "--attn-streams" flag

ggml-ci

* kv-cache : fix handling when find_slot fails

ggml-ci

* kv-cache : restore find_slot impl

ggml-ci

* kv-cache : add comments

* kv-cache : add bounds checks for sequence id

ggml-ci

* cont : add n_seq_max to batch allocr

ggml-ci

* kv-cache : perform stream copies lazily after llama_synchronize

ggml-ci

* kv-cache : avoid throwing exceptions across the C boundary

ggml-ci

* CUDA: 4D FlashAttention support (#14628)

* CUDA: 4D FlashAttention support

* CUDA: fix WMMA FA kernel

* llama : rename attn_streams -> kv_unified

ggml-ci

* common : rename kv_split -> kv_unified

ggml-ci

---------

Co-authored-by: Johannes Gäßler <redacted>
30 files changed:
common/arg.cpp
common/common.cpp
common/common.h
examples/embedding/embedding.cpp
examples/parallel/parallel.cpp
ggml/src/ggml-cuda/fattn-common.cuh
ggml/src/ggml-cuda/fattn-mma-f16.cuh
ggml/src/ggml-cuda/fattn-tile-f16.cu
ggml/src/ggml-cuda/fattn-tile-f32.cu
ggml/src/ggml-cuda/fattn-vec-f16.cuh
ggml/src/ggml-cuda/fattn-vec-f32.cuh
ggml/src/ggml-cuda/fattn-wmma-f16.cu
ggml/src/ggml-cuda/ggml-cuda.cu
include/llama.h
src/llama-batch.cpp
src/llama-batch.h
src/llama-context.cpp
src/llama-cparams.h
src/llama-graph.cpp
src/llama-graph.h
src/llama-hparams.cpp
src/llama-hparams.h
src/llama-kv-cache-unified-iswa.cpp
src/llama-kv-cache-unified-iswa.h
src/llama-kv-cache-unified.cpp
src/llama-kv-cache-unified.h
src/llama-memory-hybrid.cpp
src/llama-model.cpp
tests/test-backend-ops.cpp
tools/batched-bench/batched-bench.cpp

index 4c86f58f2cc33f4775193e4898a25879227a8a72..c1151f51da17b380acf1392b1330e60627196d05 100644 (file)
@@ -1464,6 +1464,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.swa_full = true;
         }
     ).set_env("LLAMA_ARG_SWA_FULL"));
+    add_opt(common_arg(
+        {"--kv-unified", "-kvu"},
+        string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n"
+            "[(more info)](https://github.com/ggml-org/llama.cpp/pull/14363)", params.kv_unified ? "true" : "false"),
+        [](common_params & params) {
+            params.kv_unified = true;
+        }
+    ).set_env("LLAMA_ARG_KV_SPLIT"));
     add_opt(common_arg(
         {"--no-context-shift"},
         string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),
index 262b67998fd11c2299be0f74121803a0433abec6..466271be61c630b6d8ea98d2921082d80f2904a1 100644 (file)
@@ -1163,6 +1163,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.no_perf           = params.no_perf;
     cparams.op_offload        = !params.no_op_offload;
     cparams.swa_full          = params.swa_full;
+    cparams.kv_unified        = params.kv_unified;
 
     cparams.type_k = params.cache_type_k;
     cparams.type_v = params.cache_type_v;
index e1f272318df7694af1a297708b18914324d6a90a..27adf552465e77aec5fb70a40dbc8ddab4cb0d50 100644 (file)
@@ -341,6 +341,7 @@ struct common_params {
     bool no_perf           = false; // disable performance metrics
     bool ctx_shift         = true;  // context shift on inifinite text generation
     bool swa_full          = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
+    bool kv_unified        = false; // enable unified KV cache
 
     bool input_prefix_bos  = false; // prefix BOS to user inputs, preceding input_prefix
     bool use_mmap          = true;  // use mmap for faster loads
index 0ec2999a0c8e90d6cfe5809e08f5de0e99496b0b..40ff6483807ee58485be09026973d147a6deff61 100644 (file)
@@ -107,7 +107,7 @@ int main(int argc, char ** argv) {
     const llama_vocab * vocab = llama_model_get_vocab(model);
 
     const int n_ctx_train = llama_model_n_ctx_train(model);
-    const int n_ctx = llama_n_ctx(ctx);
+    const int n_ctx       = llama_n_ctx(ctx);
 
     const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
 
index d53e089a4cbc2a72153c8c3934487d0d6c118395..46fb451baa7122dd603bb6df12b208e6a98003d3 100644 (file)
@@ -224,6 +224,7 @@ int main(int argc, char ** argv) {
         auto & client = clients[i];
         client.id = i;
         client.smpl = common_sampler_init(model, params.sampling);
+        //params.sampling.seed++;
     }
 
     std::vector<llama_token> tokens_system;
@@ -345,7 +346,7 @@ int main(int argc, char ** argv) {
                     client.n_decoded = 0;
                     client.i_batch   = batch.n_tokens - 1;
 
-                    LOG_INF("\033[31mClient %3d, seq %4d, junk = %4d, started decoding ...\033[0m\n", client.id, client.seq_id, n_junk_cur);
+                    LOG_INF("\033[31mClient %3d, seq %4d, junk = %4d, prompt = %d, started decoding ...\033[0m\n", client.id, client.seq_id, n_junk_cur, client.n_prompt);
 
                     g_seq_id += 1;
 
index 075f14a49e9ac91ce131370da5cac7b941937dfa..9122fca6cf99f3ed475167de211716a260bf70bb 100644 (file)
@@ -33,8 +33,10 @@ typedef void (* fattn_kernel_t)(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -521,7 +523,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
 template<int D, int ncols1, int ncols2> // D == head size
 __launch_bounds__(D, 1)
 static __global__ void flash_attn_stream_k_fixup(
-        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
+        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) {
     constexpr int ncols = ncols1*ncols2;
 
     const int bidx0 = blockIdx.x;
@@ -535,8 +537,8 @@ static __global__ void flash_attn_stream_k_fixup(
     const int iter_k = ne11 / FATTN_KQ_STRIDE;
     const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
 
-    const int kbc0      = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
-    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+    const int kbc0      = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
+    const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
 
     const bool did_not_have_any_data   = kbc0 == kbc0_stop;
     const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
@@ -545,14 +547,15 @@ static __global__ void flash_attn_stream_k_fixup(
         return;
     }
 
-    const int channel = kbc0 / (iter_k*iter_j);
-    const int jt      = (kbc0 - channel*iter_k*iter_j) / iter_k;
+    const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2));
+    const int head = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
+    const int jt = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
 
     if (jt*ncols1 + j >= ne01) {
         return;
     }
 
-    dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
+    dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid;
 
     // Load the partial result that needs a fixup:
     float dst_val = 0.0f;
@@ -571,7 +574,7 @@ static __global__ void flash_attn_stream_k_fixup(
     int bidx = bidx0 - 1;
     int kbc_stop = kbc0;
     while(true) {
-        const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+        const int kbc = bidx*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
         if (kbc == kbc_stop) { // Did not have any data.
             bidx--;
             kbc_stop = kbc;
@@ -617,16 +620,31 @@ static __global__ void flash_attn_combine_results(
         const float2 * __restrict__ VKQ_meta,
         float * __restrict__ dst,
         const int parallel_blocks) {
-    VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x;
-    VKQ_meta  += parallel_blocks   * gridDim.z*blockIdx.x;
-    dst       +=                 D * gridDim.z*blockIdx.x;
+    // Dimension 0: threadIdx.x
+    // Dimension 1: blockIdx.x
+    // Dimension 2: blockIdx.y
+    // Dimension 3: blockIdx.z
+    // Memory layout is permuted with [0, 2, 1, 3]
+
+    const int ne01 = gridDim.x;
+    const int ne02 = gridDim.y;
+
+    const int col      = blockIdx.x;
+    const int head     = blockIdx.y;
+    const int sequence = blockIdx.z;
+
+    const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head;
+
+    VKQ_parts += j_dst_unrolled * parallel_blocks*D;
+    VKQ_meta  += j_dst_unrolled * parallel_blocks;
+    dst       += j_dst_unrolled *                 D;
 
     const int tid = threadIdx.x;
     __builtin_assume(tid < D);
 
     extern __shared__ float2 meta[];
     for (int i = tid; i < 2*parallel_blocks; i += D) {
-        ((float *) meta)[i] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + i];
+        ((float *) meta)[i] = ((const float *)VKQ_meta) [i];
     }
 
     __syncthreads();
@@ -644,11 +662,11 @@ static __global__ void flash_attn_combine_results(
         const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
         *((uint32_t *) &KQ_max_scale) &= ftz_mask;
 
-        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid];
+        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*D + tid];
         VKQ_denominator += KQ_max_scale * meta[l].y;
     }
 
-    dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
+    dst[tid] = VKQ_numerator / VKQ_denominator;
 }
 
 [[noreturn]]
@@ -705,8 +723,6 @@ void launch_fattn(
 
     GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
 
-    GGML_ASSERT(Q->ne[3] == 1);
-
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
     const int id  = ggml_cuda_get_device();
@@ -853,8 +869,8 @@ void launch_fattn(
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0,
-        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0,
+        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
+        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0,
         Q->nb[1], Q->nb[2], Q->nb[3],
         nb11, nb12, nb13,
         nb21, nb22, nb23,
@@ -869,11 +885,11 @@ void launch_fattn(
 
             flash_attn_stream_k_fixup<DV, ncols1, ncols2>
                 <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
-                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
+                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]);
         }
     } else if (parallel_blocks > 1) {
         const dim3 block_dim_combine(DV, 1, 1);
-        const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
+        const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
         const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
 
         flash_attn_combine_results<DV>
index 709589854f0afa7c2b89f725b2330d568652e045..6fa2e77299eb0c648c165a356e4efd40ca93ec7b 100644 (file)
@@ -1224,8 +1224,10 @@ static __global__ void flash_attn_ext_f16(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -1274,8 +1276,8 @@ static __global__ void flash_attn_ext_f16(
     constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice.
 
     // kbc == k block continuous, current index in continuous ijk space.
-    int       kbc      = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
-    const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+    int       kbc      = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
+    const int kbc_stop = (blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
 
     // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
     // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
@@ -1285,18 +1287,19 @@ static __global__ void flash_attn_ext_f16(
     int kb0_start = kbc % iter_k;
     int kb0_stop  = min(iter_k, kb0_start + kbc_stop - kbc);
     while (kbc < kbc_stop && kb0_stop == iter_k) {
-        const int channel = kbc / (iter_k*iter_j);
-        const int jt      = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
+        const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
+        const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
+        const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
 
-        const float2 * Q_f2    = (const float2 *) (Q + nb02* channel*ncols2);
-        const half2  * K_h2    = (const half2  *) (K + nb12*(channel*ncols2 / gqa_ratio));
+        const float2 * Q_f2    = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
+        const half2  * K_h2    = (const half2  *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
         const half2  * mask_h2 = ncols2 == 1 && !mask ? nullptr :
-            (const half2  *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
-        float2       * dstk    = ((float2 *) dst) + channel*(ncols2 * DV/2);
+            (const half2  *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
+        float2       * dstk    = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
 
-        const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
+        const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
 
-        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
+        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
 
         const int kb0_start_kernel = kb0_start * kb_niter;
         const int kb0_stop_kernel  = kb0_stop  * kb_niter;
@@ -1325,18 +1328,19 @@ static __global__ void flash_attn_ext_f16(
         return;
     }
 
-    const int channel = kbc / (iter_k*iter_j);
-    const int jt      = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
+    const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
+    const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
+    const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
 
-    const float2 * Q_f2    = (const float2 *) (Q + nb02* channel*ncols2);
-    const half2  * K_h2    = (const half2  *) (K + nb12*(channel*ncols2 / gqa_ratio));
+    const float2 * Q_f2    = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
+    const half2  * K_h2    = (const half2  *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
     const half2  * mask_h2 = ncols2 == 1 && !mask ? nullptr :
-        (const half2  *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
-    float2       * dstk    = ((float2 *) dst) + channel*(ncols2 * DV/2);
+        (const half2  *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
+    float2       * dstk    = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
 
-    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
+    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
 
-    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
+    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
 
     const int kb0_start_kernel = kb0_start * kb_niter;
     const int kb0_stop_kernel  = kb0_stop  * kb_niter;
index 0c967f178e7b17a7e9b42d505e87bcda961ef62f..1f141328845a485e1c449d5a85e5a59a5ac733af 100644 (file)
@@ -31,8 +31,10 @@ static __global__ void flash_attn_tile_ext_f16(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -62,15 +64,17 @@ static __global__ void flash_attn_tile_ext_f16(
 
     const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2  = (const float2 *) (Q    + nb02* blockIdx.z              + nb01*ic0);
-    const half2  * K_h2  = (const half2  *) (K    + nb12*(blockIdx.z / gqa_ratio));
-    const half2  * V_h2  = (const half2  *) (V    + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
-    const half   * maskh = (const half   *) (mask + nb32*(blockIdx.z % ne32)      + nb31*ic0);
+    const float2 * Q_f2  = (const float2 *) (Q    + nb03* sequence         + nb02* head              + nb01*ic0);
+    const half2  * K_h2  = (const half2  *) (K    + nb13* sequence         + nb12*(head / gqa_ratio));
+    const half2  * V_h2  = (const half2  *) (V    + nb13* sequence         + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half   * maskh = (const half   *) (mask + nb33*(sequence % ne33)                           + nb31*ic0);
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
-    const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
+    const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -255,6 +259,8 @@ static __global__ void flash_attn_tile_ext_f16(
         __syncthreads();
     }
 
+    float2 * dst2 = (float2 *) dst;
+
 #pragma unroll
     for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
         const int j_VKQ = j_VKQ_0 + threadIdx.y;
@@ -266,21 +272,21 @@ static __global__ void flash_attn_tile_ext_f16(
         half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
         kqsum_j = warp_reduce_sum((float)kqsum_j);
 
+        const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
+
 #pragma unroll
-        for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
-            const int i0 = i00 + 2*threadIdx.x;
+        for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
+            const int i0 = i00 + threadIdx.x;
 
-            half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
+            half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
             if (gridDim.y == 1) {
                 dst_val /= __half2half2(kqsum_j);
             }
-            const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
-            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] =  __low2float(dst_val);
-            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = __high2float(dst_val);
+            dst2[j_dst_unrolled*(D/2) + i0] = __half22float2(dst_val);
         }
 
         if (gridDim.y != 1 && threadIdx.x == 0) {
-            dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
+            dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
         }
     }
 #else
@@ -290,8 +296,8 @@ static __global__ void flash_attn_tile_ext_f16(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
     GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
     GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
     GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
index 908c76dbdd270431e433963312eae3e759c463b6..a4965583cef1cfd00f8488348c1935d470b9d15d 100644 (file)
@@ -31,8 +31,10 @@ static __global__ void flash_attn_tile_ext_f32(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -74,15 +76,17 @@ static __global__ void flash_attn_tile_ext_f32(
 
     const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2  = (const float2 *) (Q    + nb02* blockIdx.z              + nb01*ic0);
-    const half2  * K_h2  = (const half2  *) (K    + nb12*(blockIdx.z / gqa_ratio));
-    const half2  * V_h2  = (const half2  *) (V    + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
-    const half   * maskh = (const half   *) (mask + nb32*(blockIdx.z % ne32)      + nb31*ic0);
+    const float2 * Q_f2  = (const float2 *) (Q    + nb03* sequence         + nb02* head              + nb01*ic0);
+    const half2  * K_h2  = (const half2  *) (K    + nb13* sequence         + nb12*(head / gqa_ratio));
+    const half2  * V_h2  = (const half2  *) (V    + nb13* sequence         + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half   * maskh = (const half   *) (mask + nb33*(sequence % ne33)                           + nb31*ic0);
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
-    const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
+    const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
 
@@ -265,6 +269,8 @@ static __global__ void flash_attn_tile_ext_f32(
         __syncthreads();
     }
 
+    float2 * dst2 = (float2 *) dst;
+
 #pragma unroll
     for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
         const int j_VKQ = j_VKQ_0 + threadIdx.y;
@@ -276,22 +282,22 @@ static __global__ void flash_attn_tile_ext_f32(
         float kqsum_j = kqsum[j_VKQ_0/nwarps];
         kqsum_j = warp_reduce_sum(kqsum_j);
 
+        const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
+
 #pragma unroll
-        for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
-            const int i0 = i00 + 2*threadIdx.x;
+        for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
+            const int i0 = i00 + threadIdx.x;
 
-            float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
+            float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
             if (gridDim.y == 1) {
                 dst_val.x /= kqsum_j;
                 dst_val.y /= kqsum_j;
             }
-            const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
-            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = dst_val.x;
-            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = dst_val.y;
+            dst2[j_dst_unrolled*(D/2) + i0] = dst_val;
         }
 
         if (gridDim.y != 1 && threadIdx.x == 0) {
-            dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
+            dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
         }
     }
 #else
index e78fb181919fda704f5f095c2f633e30e8dae0f5..b2d469938abf29d03b263fb1c56e7688a422481b 100644 (file)
@@ -28,8 +28,10 @@ static __global__ void flash_attn_vec_ext_f16(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -65,14 +67,16 @@ static __global__ void flash_attn_vec_ext_f16(
 
     const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    Q += nb02* blockIdx.z              + nb01*ic0;
-    K += nb12*(blockIdx.z / gqa_ratio);
-    V += nb22*(blockIdx.z / gqa_ratio);
+    Q += nb03*sequence + nb02* head              + nb01*ic0;
+    K += nb13*sequence + nb12*(head / gqa_ratio);
+    V += nb23*sequence + nb22*(head / gqa_ratio);
 
-    const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
+    const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
 
-    const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
+    const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -330,12 +334,11 @@ static __global__ void flash_attn_vec_ext_f16(
         if (gridDim.y == 1) {
             dst_val /= kqsum[j_VKQ];
         }
-        const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
-        dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
+        dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val;
     }
 
     if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
-        dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
+        dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
     }
 #else
     GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
@@ -344,8 +347,8 @@ static __global__ void flash_attn_vec_ext_f16(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
     GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne32);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
     GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
     GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
index b2f1724c955886858712bc6d66003d0456d3d32c..405b6f5106ea0761c608fe9fc77d1c0c4d1fbe74 100644 (file)
@@ -28,8 +28,10 @@ static __global__ void flash_attn_vec_ext_f32(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -53,8 +55,8 @@ static __global__ void flash_attn_vec_ext_f32(
         GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
         GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
         GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-        GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
-        GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+        GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
+        GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
         GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
         GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
         GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
@@ -77,14 +79,16 @@ static __global__ void flash_attn_vec_ext_f32(
 
     const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    Q += nb02* blockIdx.z              + nb01*ic0;
-    K += nb12*(blockIdx.z / gqa_ratio);
-    V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape
+    Q += nb03*sequence + nb02* head              + nb01*ic0;
+    K += nb13*sequence + nb12*(head / gqa_ratio);
+    V += nb23*sequence + nb22*(head / gqa_ratio);
 
-    const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
+    const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
 
-    const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
+    const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
     constexpr int nwarps = D / WARP_SIZE;
@@ -326,12 +330,11 @@ static __global__ void flash_attn_vec_ext_f32(
         if (gridDim.y == 1) {
             dst_val /= kqsum[j_VKQ];
         }
-        const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
-        dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
+        dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val;
     }
 
     if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
-        dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
+        dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
     }
 #else
     GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
@@ -340,8 +343,8 @@ static __global__ void flash_attn_vec_ext_f32(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
     GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
-    GGML_UNUSED(ne31); GGML_UNUSED(ne32);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb32);
+    GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
     GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
     GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
     GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
index c95ca7b1f285f54ede3d93c174fe2172e2d6f943..741b8781d29f5727e3076c36d70785f7941156ea 100644 (file)
@@ -47,8 +47,10 @@ static __global__ void flash_attn_ext_f16(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -95,17 +97,19 @@ static __global__ void flash_attn_ext_f16(
     constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
     constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
 
+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float * Q_f   = (const float *) (Q    + nb02* blockIdx.z              + nb01*ic0);
-    const half  * K_h   = (const half  *) (K    + nb12*(blockIdx.z / gqa_ratio));
-    const half  * V_h   = (const half  *) (V    + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
-    const half  * maskh = (const half  *) (mask + nb32*(blockIdx.z % ne32)      + nb31*ic0);
+    const float * Q_f   = (const float *) (Q    + nb03* sequence         + nb02* head              + nb01*ic0);
+    const half  * K_h   = (const half  *) (K    + nb13* sequence         + nb12*(head / gqa_ratio));
+    const half  * V_h   = (const half  *) (V    + nb13* sequence         + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half  * maskh = (const half  *) (mask + nb33*(sequence % ne33)                           + nb31*ic0);
     const half2 * mask2 = (const half2 *)  maskh;
 
     const int stride_Q  = nb01 / sizeof(float);
     const int stride_KV = nb11 / sizeof(half);
 
-    const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
+    const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);
     const half2 slope2 = make_half2(slopef, slopef);
 
@@ -400,7 +404,6 @@ static __global__ void flash_attn_ext_f16(
         if (ic0 + j_VKQ >= ne01) {
             return;
         }
-        const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
 
         float KQ_rowsum_j;
         if (std::is_same<KQ_acc_t, float>::value) {
@@ -409,6 +412,8 @@ static __global__ void flash_attn_ext_f16(
             KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
         }
 
+        const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
+
 #pragma unroll
         for (int i0 = 0; i0 < D; i0 += warp_size) {
             const int i = i0 + threadIdx.x;
@@ -419,7 +424,7 @@ static __global__ void flash_attn_ext_f16(
             if (gridDim.y == 1) {
                 dst_val /= KQ_rowsum_j;
             }
-            dst[j_dst*gridDim.z*D + blockIdx.z*D + i] = dst_val;
+            dst[j_dst_unrolled*D + i] = dst_val;
         }
 
         if (gridDim.y == 1 || threadIdx.x != 0) {
@@ -433,7 +438,7 @@ static __global__ void flash_attn_ext_f16(
             dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
         }
         dst_meta_val.y = KQ_rowsum_j;
-        dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = dst_meta_val;
+        dst_meta[j_dst_unrolled] = dst_meta_val;
     }
 #else
     GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
@@ -442,7 +447,8 @@ static __global__ void flash_attn_ext_f16(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
     GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
-    GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); GGML_UNUSED(nb31);
+    GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
     GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
     GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
index 8015b0d4e8d9245b439b5bd0a354d303411f7937..778d5a48bd9f85bbcbee26056370dc205df59cd8 100644 (file)
@@ -3413,12 +3413,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             if (op->src[0]->ne[0] == 192) {
                 return false;
             }
-            // TODO: support broadcast
-            // note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14500, but
-            //       the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
-            if (op->src[0]->ne[3] != 1) {
-                return false;
-            }
             if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
                 return false;
             }
@@ -3431,6 +3425,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
                 return true;
             }
+            if (op->src[3] && op->src[3]->ne[2] != 1) {
+                return false;
+            }
             return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
                 op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
         }
index bbe4f8dbfae666efbca5b0b7830abbb7dcef3ff1..db6a5337b02a71f8b1804badbe31232a03093a0f 100644 (file)
@@ -335,6 +335,9 @@ extern "C" {
         bool swa_full;    // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
                           // NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases
                           //       ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573
+        bool kv_unified;  // use a unified buffer across the input sequences when computing the attention
+                          // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
+                          // ref: https://github.com/ggml-org/llama.cpp/pull/14363
     };
 
     // model quantization parameters
index 3bc8554e51ccf518e781ba5076780ae757c294a9..f8227777f19de83eb6ed6b302165b32bb49fc8a5 100644 (file)
@@ -27,6 +27,7 @@ bool llama_batch_allocr::init(
         const llama_vocab & vocab,
         const llama_memory_i * memory,
         uint32_t n_embd,
+        uint32_t n_seq_max,
         bool output_all) {
     clear();
 
@@ -40,6 +41,11 @@ bool llama_batch_allocr::init(
     // validate input batch
     //
 
+    if (n_seq_max > LLAMA_MAX_SEQ) {
+        LLAMA_LOG_ERROR("%s: n_seq_max = %d > %d\n", __func__, n_seq_max, LLAMA_MAX_SEQ);
+        return false;
+    }
+
     if (batch.token) {
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
             if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
@@ -52,8 +58,8 @@ bool llama_batch_allocr::init(
     if (batch.seq_id) {
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
             for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
-                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) {
-                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ);
+                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= (llama_seq_id) n_seq_max)) {
+                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], (llama_seq_id) n_seq_max);
                     return false;
                 }
             }
@@ -86,7 +92,7 @@ bool llama_batch_allocr::init(
 
         // initialize the starting position for each sequence based on the positions in the memory
         llama_pos p0[LLAMA_MAX_SEQ];
-        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
             if (!memory) {
                 // if no memory -> start from 0
                 p0[s] = 0;
@@ -143,7 +149,8 @@ bool llama_batch_allocr::init(
     // compute stats
     //
 
-    this->n_embd = n_embd;
+    this->n_embd    = n_embd;
+    this->n_seq_max = n_seq_max;
 
     // count the outputs in this batch
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
@@ -189,7 +196,7 @@ bool llama_batch_allocr::init(
             seq_set_map[cur].push_back(i);
         }
 
-        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
             if (seq_set_unq.test(s)) {
                 seq_idx[s] = seq_id_unq.size();
                 seq_id_unq.push_back(s);
@@ -241,7 +248,7 @@ bool llama_batch_allocr::init(
     // consistency checks
     //
 
-    for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+    for (uint32_t s = 0; s < n_seq_max; ++s) {
         if (seq_pos[s].empty()) {
             continue;
         }
@@ -284,8 +291,8 @@ bool llama_batch_allocr::init(
     }
 
     if (memory) {
-        for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) {
-            for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) {
+        for (uint32_t s0 = 0; s0 < n_seq_max; ++s0) {
+            for (uint32_t s1 = 0; s1 < n_seq_max; ++s1) {
                 if (seq_cpl[s0][s1]) {
                     if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
                         memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
@@ -316,12 +323,12 @@ bool llama_batch_allocr::init(
     //
     {
         seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
-        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
             cur_seq_set[s].set();
         }
 
         llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
-        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
             cur_seq_pos[s] = -1;
         }
 
@@ -692,7 +699,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
         }
     }
 
-    for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+    for (uint32_t s = 0; s < n_seq_max; ++s) {
         if (seq_set_unq.test(s)) {
             ubatch.seq_idx[s] = ubatch.seq_id_unq.size();
             ubatch.seq_id_unq.push_back(s);
index 3420803ff946967319e96947497b315ae74d1f6c..1a24440ba7562e6111c4835c981e065f9e6a46f4 100644 (file)
@@ -48,6 +48,7 @@ public:
             const llama_vocab & vocab,
             const llama_memory_i * memory,
             uint32_t n_embd,
+            uint32_t n_seq_max,
             bool output_all);
 
     const llama_batch & get_batch() const;
@@ -100,6 +101,7 @@ private:
     const uint32_t n_pos_per_embd;
 
     uint32_t n_embd;
+    uint32_t n_seq_max;
     uint32_t n_outputs;
 
     std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
index 7c07b047b0dd9a6faaff2282f9c3acb1a0acd8b2..840ec9a9aaca1da466ba558cd068afd8f94d8ab2 100644 (file)
@@ -98,10 +98,20 @@ llama_context::llama_context(
         LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
         cparams.n_batch = GGML_KQ_MASK_PAD;
     }
-
     cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
 
     cparams.op_offload = params.op_offload;
+    cparams.kv_unified = params.kv_unified;
+
+    {
+        const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
+        const bool supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0;
+
+        if (!supports_set_rows && !cparams.kv_unified) {
+            LLAMA_LOG_WARN("%s: non-unified KV cache requires ggml_set_rows() - forcing unified KV cache\n", __func__);
+            cparams.kv_unified = true;
+        }
+    }
 
     const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
 
@@ -112,6 +122,7 @@ llama_context::llama_context(
     LLAMA_LOG_INFO("%s: n_ubatch      = %u\n",   __func__, cparams.n_ubatch);
     LLAMA_LOG_INFO("%s: causal_attn   = %d\n",   __func__, cparams.causal_attn);
     LLAMA_LOG_INFO("%s: flash_attn    = %d\n",   __func__, cparams.flash_attn);
+    LLAMA_LOG_INFO("%s: kv_unified    = %s\n",   __func__, cparams.kv_unified ? "true" : "false");
     LLAMA_LOG_INFO("%s: freq_base     = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale    = %g\n",   __func__, cparams.rope_freq_scale);
 
@@ -267,7 +278,7 @@ llama_context::llama_context(
 
     // reserve worst-case graph
     if (!hparams.vocab_only && memory) {
-        const uint32_t n_seqs = cparams.n_seq_max;
+        const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
         LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
@@ -300,7 +311,7 @@ llama_context::llama_context(
 
         // reserve with tg graph to get the number of splits and nodes
         {
-            auto * gf = graph_reserve(1, 1, 1, mctx.get());
+            auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute tg buffers");
             }
@@ -311,6 +322,10 @@ llama_context::llama_context(
 
         // reserve again with pp graph to avoid ggml-alloc reallocations during inference
         {
+            // TODO: not sure if the following graph would be worster case for multi-stream KV caches:
+            //
+            // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
+            //
             auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
@@ -475,7 +490,7 @@ bool llama_context::kv_self_update(bool optimize) {
             throw std::runtime_error("failed to initialize memory context");
         }
 
-        const uint32_t n_seqs   = cparams.n_seq_max;
+        const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
         auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
@@ -735,13 +750,15 @@ int llama_context::encode(const llama_batch & batch_inp) {
     const int32_t n_vocab = model.vocab.n_tokens();
 
     // note: during encode, we always pass the full sequence starting from pos = 0
-    if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
+    if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
 
     const uint32_t n_tokens = balloc->get_n_tokens();
 
+    // [TAG_NO_CACHE_PAD]
+    // TODO: add new split mode where we pad the input sequences so that ubatch.equal_seqs == true
     const llama_ubatch ubatch = balloc->split_simple(n_tokens);
 
     // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
@@ -910,7 +927,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     // when computing embeddings, all tokens are output
     const bool output_all = cparams.embeddings;
 
-    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, output_all)) {
+    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
@@ -2039,7 +2056,7 @@ void llama_context::opt_epoch_iter(
             batch.logits  [pos_batch]    = true;
         }
 
-        if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, true)) {
+        if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
             LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
             return;
         }
@@ -2198,6 +2215,7 @@ llama_context_params llama_context_default_params() {
         /*.no_perf                     =*/ true,
         /*.op_offload                  =*/ true,
         /*.swa_full                    =*/ true,
+        /*.kv_unified                  =*/ false,
     };
 
     return result;
index 118615d5bd2d59f4e640c661592cf91ee0d3fda3..38750affc500b74504f53bf65320332ddcef0d09 100644 (file)
@@ -11,8 +11,8 @@ struct llama_cparams {
     uint32_t n_batch;
     uint32_t n_ubatch;
     uint32_t n_seq_max;
-    int      n_threads;       // number of threads to use for generation
-    int      n_threads_batch; // number of threads to use for batch processing
+    int32_t  n_threads;       // number of threads to use for generation
+    int32_t  n_threads_batch; // number of threads to use for batch processing
 
     float rope_freq_base;
     float rope_freq_scale;
@@ -33,6 +33,7 @@ struct llama_cparams {
     bool no_perf;
     bool warmup;
     bool op_offload;
+    bool kv_unified;
 
     enum llama_pooling_type pooling_type;
 
index a248a7ec22350898ac2f31ba9d817e528651b925..1a6355e85d11ea651255c8b46ca3afa872e64249 100644 (file)
@@ -982,13 +982,16 @@ ggml_tensor * llm_graph_context::build_attn_mha(
              float     kq_scale) const {
     const bool v_trans = v->nb[1] > v->nb[2];
 
+    // split the batch into streams if needed
+    const auto n_stream = k->ne[3];
+
+    q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream);
+
     q = ggml_permute(ctx0, q, 0, 2, 1, 3);
     k = ggml_permute(ctx0, k, 0, 2, 1, 3);
     v = ggml_permute(ctx0, v, 0, 2, 1, 3);
 
-    const auto n_tokens = q->ne[1];
-    const auto n_head   = q->ne[2];
-    const auto n_kv     = k->ne[1];
+    const auto n_kv = k->ne[1];
 
     ggml_tensor * cur;
 
@@ -1030,7 +1033,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 #endif
         }
 
-        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
     } else {
         ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
 
@@ -1075,7 +1078,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 
         cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
 
-        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
+        // recombine streams
+        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
 
         if (!cparams.offload_kqv) {
             // all nodes between the KV store and the attention output are run on the CPU
@@ -1122,6 +1126,10 @@ ggml_tensor * llm_graph_context::build_attn(
 
     const auto & kq_mask = inp->get_kq_mask();
 
+    // [TAG_NO_CACHE_PAD]
+    // TODO: if ubatch.equal_seqs == true, we can split the three tensors below into ubatch.n_seqs_unq streams
+    assert(ubatch.equal_seqs == false);
+
     ggml_tensor * q = q_cur;
     ggml_tensor * k = k_cur;
     ggml_tensor * v = v_cur;
@@ -1156,13 +1164,14 @@ static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unifie
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
-        const auto n_kv = mctx_cur->get_n_kv();
+        const auto n_kv     = mctx_cur->get_n_kv();
         const auto n_tokens = ubatch.n_tokens;
+        const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
 
         inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1362,13 +1371,15 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
 
+    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
+
     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
         inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1382,7 +1393,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
         inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
         ggml_set_input(inp->self_kq_mask_swa);
 
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
index fbf8e2889564ddb591af4bad939857305737f36a..84a5b0b3f9c4033be8943301944870f814b956cf 100644 (file)
@@ -255,10 +255,10 @@ public:
     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
 
     ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
-    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
 
-    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch, 1, 1]
-    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch, 1, 1]
+    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]
 
     const llama_hparams & hparams;
     const llama_cparams & cparams;
@@ -289,14 +289,14 @@ public:
     ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
 
     ggml_tensor * self_k_idxs     = nullptr; // I64 [n_batch]
-    ggml_tensor * self_v_idxs     = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs     = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
     ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
-    ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
 
-    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_kv, n_batch, 1, 1]
-    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_kv, n_batch, 1, 1]
-    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_kv, n_batch, 1, 1]
-    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_kv, n_batch, 1, 1]
+    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]
 
     const llama_hparams & hparams;
     const llama_cparams & cparams;
index 7aa736e2f39db9ca876dde6a791daff8bd1e5bea..c6c67d26f9392cd4c81e50160e52690ec96b7f10 100644 (file)
@@ -65,6 +65,46 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
     return n_embd_head_v * n_head_kv;
 }
 
+bool llama_hparams::is_n_embd_k_gqa_variable() const {
+    const uint32_t val = n_embd_k_gqa();
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        if (val != n_embd_k_gqa(il)) {
+            return true;
+        }
+    }
+
+    return false;
+}
+
+bool llama_hparams::is_n_embd_v_gqa_variable() const {
+    const uint32_t val = n_embd_v_gqa();
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        if (val != n_embd_v_gqa(il)) {
+            return true;
+        }
+    }
+
+    return false;
+}
+
+uint32_t llama_hparams::n_embd_k_gqa_max() const {
+    uint32_t val = n_embd_k_gqa();
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        val = std::max(val, n_embd_k_gqa(il));
+    }
+
+    return val;
+}
+
+uint32_t llama_hparams::n_embd_v_gqa_max() const {
+    uint32_t val = n_embd_v_gqa();
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        val = std::max(val, n_embd_v_gqa(il));
+    }
+
+    return val;
+}
+
 uint32_t llama_hparams::n_embd_r() const {
     if (wkv_head_size != 0) {
         // for RWKV models
index 9116a3743c993957225e0060ccfdab811fd5f497..c422cd7be827a08deda24191903853c036a6f856 100644 (file)
@@ -191,6 +191,14 @@ struct llama_hparams {
     // dimension of value embeddings across all k-v heads
     uint32_t n_embd_v_gqa(uint32_t il = 0) const;
 
+    // true if any layer has a different n_embd_k_gqa/n_embd_v_gqa
+    bool is_n_embd_k_gqa_variable() const;
+    bool is_n_embd_v_gqa_variable() const;
+
+    // return the maximum n_embd_k_gqa/n_embd_v_gqa across all layers
+    uint32_t n_embd_k_gqa_max() const;
+    uint32_t n_embd_v_gqa_max() const;
+
     // dimension of the rolling state embeddings
     // corresponds to Mamba's conv_states size or RWKV's token_shift states size
     uint32_t n_embd_r() const;
index fe207ad536032d34930bc6f4bbfc4a35b4226bb5..01d27fb4db9b1d8adb104432b8c5c64f3b2ece7c 100644 (file)
@@ -18,16 +18,17 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
                      bool   v_trans,
                      bool   offload,
                      bool   swa_full,
+                     bool   unified,
                  uint32_t   kv_size,
                  uint32_t   n_seq_max,
                  uint32_t   n_ubatch,
-                 uint32_t   n_pad) : hparams(model.hparams) {
+                 uint32_t   n_pad) : hparams(model.hparams), unified(unified) {
     llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
     llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
 
     const uint32_t size_base = kv_size;
 
-    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad));
+    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad));
 
     // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
     if (swa_full) {
@@ -41,14 +42,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
 
     kv_base = std::make_unique<llama_kv_cache_unified>(
             model, std::move(filter_base), type_k, type_v,
-            v_trans, offload, size_base, n_seq_max, n_pad,
+            v_trans, offload, unified, size_base, n_seq_max, n_pad,
             0, LLAMA_SWA_TYPE_NONE);
 
     LLAMA_LOG_INFO("%s: creating     SWA KV cache, size = %u cells\n", __func__, size_swa);
 
     kv_swa = std::make_unique<llama_kv_cache_unified>(
             model, std::move(filter_swa), type_k, type_v,
-            v_trans, offload, size_swa, n_seq_max, n_pad,
+            v_trans, offload, unified, size_swa, n_seq_max, n_pad,
             hparams.n_swa, hparams.swa_type);
 }
 
@@ -100,6 +101,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
 
     // first try simple split
     do {
+        if (!unified) {
+            // requires equal splits, so we skip the simple split
+            break;
+        }
+
         balloc.split_reset();
 
         std::vector<llama_ubatch> ubatches;
@@ -140,7 +146,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
 
         std::vector<llama_ubatch> ubatches;
         while (true) {
-            auto ubatch = balloc.split_equal(n_ubatch, false);
+            auto ubatch = balloc.split_equal(n_ubatch, !unified);
 
             if (ubatch.n_tokens == 0) {
                 break;
index 23205d826b23b231bb2026f082debe064f092500..d2650dadd3595b2614551e36941bffda8f6909af 100644 (file)
@@ -20,6 +20,7 @@ public:
                          bool   v_trans,
                          bool   offload,
                          bool   swa_full,
+                         bool   unified,
                      uint32_t   kv_size,
                      uint32_t   n_seq_max,
                      uint32_t   n_ubatch,
@@ -68,6 +69,8 @@ public:
 private:
     const llama_hparams & hparams;
 
+    const bool unified;
+
     std::unique_ptr<llama_kv_cache_unified> kv_base;
     std::unique_ptr<llama_kv_cache_unified> kv_swa;
 };
index d3129cc53281e6589ebd7e7a3cae1ea6407878c4..7e92e6b4df9d4419578f940bafd4c4290f21389c 100644 (file)
@@ -23,13 +23,14 @@ llama_kv_cache_unified::llama_kv_cache_unified(
                 ggml_type    type_v,
                      bool    v_trans,
                      bool    offload,
+                     bool    unified,
                  uint32_t    kv_size,
                  uint32_t    n_seq_max,
                  uint32_t    n_pad,
                  uint32_t    n_swa,
            llama_swa_type    swa_type) :
     model(model), hparams(model.hparams), v_trans(v_trans),
-    n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
+    n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
 
     GGML_ASSERT(kv_size % n_pad == 0);
 
@@ -45,7 +46,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
             ggml_init_params params = {
-                /*.mem_size   =*/ size_t(2u*n_layer_cache*ggml_tensor_overhead()),
+                /*.mem_size   =*/ size_t(2u*(1 + n_stream)*n_layer_cache*ggml_tensor_overhead()),
                 /*.mem_buffer =*/ NULL,
                 /*.no_alloc   =*/ true,
             };
@@ -64,9 +65,33 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         return it->second;
     };
 
-    head = 0;
+    GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max);
 
-    cells.resize(kv_size);
+    v_heads.resize(n_stream);
+    for (uint32_t s = 0; s < n_stream; ++s) {
+        v_heads[s] = 0;
+    }
+
+    v_cells.resize(n_stream);
+    for (uint32_t s = 0; s < n_stream; ++s) {
+        v_cells[s].resize(kv_size);
+    }
+
+    // by default, all sequence ids are mapped to the 0th stream
+    seq_to_stream.resize(LLAMA_MAX_SEQ, 0);
+
+    if (n_stream > 1) {
+        seq_to_stream.resize(n_stream, 0);
+        for (uint32_t s = 0; s < n_stream; ++s) {
+            seq_to_stream[s] = s;
+        }
+    }
+
+    // [TAG_V_CACHE_VARIABLE]
+    if (v_trans && hparams.is_n_embd_v_gqa_variable()) {
+        LLAMA_LOG_WARN("%s: the V embeddings have different sizes across layers and FA is not enabled - padding V cache to %d\n",
+                __func__, hparams.n_embd_v_gqa_max());
+    }
 
     for (uint32_t il = 0; il < n_layer_cache; il++) {
         if (filter && !filter(il)) {
@@ -74,8 +99,9 @@ llama_kv_cache_unified::llama_kv_cache_unified(
             continue;
         }
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+        // [TAG_V_CACHE_VARIABLE]
+        const uint32_t n_embd_k_gqa =            hparams.n_embd_k_gqa(il);
+        const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max();
 
         const char * dev_name = "CPU";
 
@@ -98,14 +124,23 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         ggml_tensor * k;
         ggml_tensor * v;
 
-        k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size);
-        v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size);
+        k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream);
+        v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream);
 
         ggml_format_name(k, "cache_k_l%d", il);
         ggml_format_name(v, "cache_v_l%d", il);
 
+        std::vector<ggml_tensor *> k_stream;
+        std::vector<ggml_tensor *> v_stream;
+
+        for (uint32_t s = 0; s < n_stream; ++s) {
+            k_stream.push_back(ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]));
+            v_stream.push_back(ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]));
+        }
+
         map_layer_ids[il] = layers.size();
-        layers.push_back({ il, k, v });
+
+        layers.push_back({ il, k, v, k_stream, v_stream, });
     }
 
     // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
@@ -148,8 +183,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         const size_t memory_size_k = size_k_bytes();
         const size_t memory_size_v = size_v_bytes();
 
-        LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
-                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max,
+        LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u/%2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
+                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, n_stream,
                 ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
                 ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
     }
@@ -160,15 +195,21 @@ llama_kv_cache_unified::llama_kv_cache_unified(
     const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
     supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0;
 
+    if (!supports_set_rows) {
+        // ref: https://github.com/ggml-org/llama.cpp/pull/14363
+        GGML_ASSERT(unified && "cannot use non-unified KV cache without ggml_set_rows() support");
+    }
+
     if (!supports_set_rows) {
         LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__);
     }
 }
 
 void llama_kv_cache_unified::clear(bool data) {
-    cells.reset();
-
-    head = 0;
+    for (uint32_t s = 0; s < n_stream; ++s) {
+        v_cells[s].reset();
+        v_heads[s] = 0;
+    }
 
     if (data) {
         for (auto & buf : bufs) {
@@ -178,6 +219,11 @@ void llama_kv_cache_unified::clear(bool data) {
 }
 
 bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+
+    auto & cells = v_cells[seq_to_stream[seq_id]];
+    auto & head  = v_heads[seq_to_stream[seq_id]];
+
     uint32_t new_head = cells.size();
 
     if (p0 < 0) {
@@ -224,30 +270,94 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
 }
 
 void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
-    if (seq_id_src == seq_id_dst) {
+    GGML_ASSERT(seq_id_src >= 0 && (size_t) seq_id_src < seq_to_stream.size());
+    GGML_ASSERT(seq_id_dst >= 0 && (size_t) seq_id_dst < seq_to_stream.size());
+
+    const auto s0 = seq_to_stream[seq_id_src];
+    const auto s1 = seq_to_stream[seq_id_dst];
+
+    if (s0 == s1) {
+        // since both sequences are in the same stream, no data copy is necessary
+        // we just have to update the cells meta data
+
+        auto & cells = v_cells[s0];
+
+        if (seq_id_src == seq_id_dst) {
+            return;
+        }
+
+        if (p0 < 0) {
+            p0 = 0;
+        }
+
+        if (p1 < 0) {
+            p1 = std::numeric_limits<llama_pos>::max();
+        }
+
+        for (uint32_t i = 0; i < cells.size(); ++i) {
+            if (!cells.pos_in(i, p0, p1)) {
+                continue;
+            }
+
+            if (cells.seq_has(i, seq_id_src)) {
+                cells.seq_add(i, seq_id_dst);
+            }
+        }
+
         return;
     }
 
-    if (p0 < 0) {
-        p0 = 0;
+    // cross-stream sequence copies require to copy the actual buffer data
+
+    bool is_full = true;
+
+    if (p0 > 0 && p0 + 1 < (int) get_size()) {
+        is_full = false;
     }
 
-    if (p1 < 0) {
-        p1 = std::numeric_limits<llama_pos>::max();
+    if (p1 > 0 && p1 + 1 < (int) get_size()) {
+        is_full = false;
     }
 
-    for (uint32_t i = 0; i < cells.size(); ++i) {
-        if (!cells.pos_in(i, p0, p1)) {
-            continue;
-        }
+    GGML_ASSERT(is_full && "seq_cp() is only supported for full KV buffers");
+
+    // enqueue the copy operation - the buffer copy will be performed during the next update
+    sc_info.ssrc.push_back(s0);
+    sc_info.sdst.push_back(s1);
+
+    v_cells[s1].reset();
+    for (uint32_t i = 0; i < v_cells[s0].size(); ++i) {
+        if (v_cells[s0].seq_has(i, seq_id_src)) {
+            llama_pos pos   = v_cells[s0].pos_get(i);
+            llama_pos shift = v_cells[s0].get_shift(i);
+
+            if (shift != 0) {
+                pos -= shift;
+                assert(pos >= 0);
+            }
+
+            v_cells[s1].pos_set(i, pos);
+            v_cells[s1].seq_add(i, seq_id_dst);
 
-        if (cells.seq_has(i, seq_id_src)) {
-            cells.seq_add(i, seq_id_dst);
+            if (shift != 0) {
+                v_cells[s1].pos_add(i, shift);
+            }
         }
     }
+
+    v_heads[s1] = v_heads[s0];
+
+    //for (uint32_t s = 0; s < n_stream; ++s) {
+    //    LLAMA_LOG_WARN("%s: seq %d: min = %d, max = %d\n", __func__, s, v_cells[s].seq_pos_min(s), v_cells[s].seq_pos_max(s));
+    //}
 }
 
 void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
+    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+
+    auto & cells = v_cells[seq_to_stream[seq_id]];
+    auto & head  = v_heads[seq_to_stream[seq_id]];
+
     uint32_t new_head = cells.size();
 
     for (uint32_t i = 0; i < cells.size(); ++i) {
@@ -265,6 +375,11 @@ void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
 }
 
 void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+
+    auto & cells = v_cells[seq_to_stream[seq_id]];
+    auto & head  = v_heads[seq_to_stream[seq_id]];
+
     if (shift == 0) {
         return;
     }
@@ -304,6 +419,10 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po
 }
 
 void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+
+    auto & cells = v_cells[seq_to_stream[seq_id]];
+
     if (d == 1) {
         return;
     }
@@ -333,10 +452,18 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
 }
 
 llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
+    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+
+    const auto & cells = v_cells[seq_to_stream[seq_id]];
+
     return cells.seq_pos_min(seq_id);
 }
 
 llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
+    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+
+    const auto & cells = v_cells[seq_to_stream[seq_id]];
+
     return cells.seq_pos_max(seq_id);
 }
 
@@ -351,7 +478,7 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
 
         std::vector<llama_ubatch> ubatches;
         while (true) {
-            auto ubatch = balloc.split_simple(n_ubatch);
+            auto ubatch = n_stream == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true);
 
             if (ubatch.n_tokens == 0) {
                 break;
@@ -387,7 +514,10 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct
     defrag_info dinfo;
 
     // see if we need to defrag
-    {
+    if (n_stream == 1) {
+        // note : for now do not consider defrag for n_stream > 1
+        const auto & cells = v_cells[seq_to_stream[0]];
+
         bool do_defrag = optimize;
 
         const auto thold = lctx->get_cparams().defrag_thold;
@@ -411,22 +541,22 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct
         }
     }
 
-    return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
+    return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo), std::move(sc_info));
 }
 
 llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
     llama_kv_cache_unified::slot_info_vec_t res;
 
-    struct state {
-        uint32_t head_old; // old position of the head, before placing the ubatch
-
+    struct state_t {
         slot_info sinfo; // slot info for the ubatch
 
-        llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch
+        std::vector<uint32_t> v_heads_old; // old positions of the heads, before placing the ubatch
+
+        std::vector<llama_kv_cells_unified> v_cells; // copy of the old cells, before placing the ubatch
     };
 
     // remember the old state of the cells so we can restore it in the end
-    std::vector<state> states;
+    std::vector<state_t> states;
 
     bool success = true;
 
@@ -445,16 +575,35 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st
         res.push_back(sinfo_new);
 
         // store the old state of the cells in the recovery stack
-        states.push_back({head, sinfo_new, cells.cp(sinfo_new.idxs)});
+        {
+            state_t state = { sinfo_new, v_heads, {} };
+
+            for (uint32_t s = 0; s < sinfo_new.n_stream(); ++s) {
+                auto & cells = v_cells[sinfo_new.strm[s]];
+
+                state.v_cells.push_back(cells.cp(sinfo_new.idxs[s]));
+            }
+
+            states.push_back(std::move(state));
+        }
 
         // now emplace the ubatch
         apply_ubatch(sinfo_new, ubatch);
     }
 
+    GGML_ASSERT(!states.empty() || !success);
+
     // iterate backwards and restore the cells to their original state
     for (auto it = states.rbegin(); it != states.rend(); ++it) {
-        cells.set(it->sinfo.idxs, it->cells);
-        head = it->head_old;
+        const auto & sinfo = it->sinfo;
+
+        for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+            auto & cells = v_cells[sinfo.strm[s]];
+            auto & head  = v_heads[sinfo.strm[s]];
+
+            cells.set(sinfo.idxs[s], it->v_cells[s]);
+            head = it->v_heads_old[s];
+        }
     }
 
     if (!success) {
@@ -464,11 +613,38 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st
     return res;
 }
 
-bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo) {
+bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info) {
     bool updated = false;
 
     auto * sched = lctx->get_sched();
 
+    if (!sc_info.empty()) {
+        assert(n_stream > 1 && "stream copy should never happen with a single stream");
+
+        llama_synchronize(lctx);
+
+        const size_t n_copy = sc_info.ssrc.size();
+
+        for (size_t i = 0; i < n_copy; ++i) {
+            const auto ssrc = sc_info.ssrc[i];
+            const auto sdst = sc_info.sdst[i];
+
+            assert(ssrc < n_stream);
+            assert(sdst < n_stream);
+
+            LLAMA_LOG_DEBUG("%s: copying KV buffer: stream %d to stream %d\n", __func__, ssrc, sdst);
+
+            assert(ssrc != sdst);
+
+            for (uint32_t il = 0; il < layers.size(); ++il) {
+                const auto & layer = layers[il];
+
+                ggml_backend_tensor_copy(layer.k_stream[ssrc], layer.k_stream[sdst]);
+                ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]);
+            }
+        }
+    }
+
     if (do_shift) {
         if (!get_can_shift()) {
             GGML_ABORT("The current KV cache / model configuration does not support K-shift");
@@ -503,12 +679,20 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
             updated = true;
         }
 
-        cells.reset_shift();
+        for (uint32_t s = 0; s < n_stream; ++s) {
+            auto & cells = v_cells[s];
+
+            cells.reset_shift();
+        }
     }
 
     if (!dinfo.empty()) {
         LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
 
+        // note: for now do not consider defrag for n_stream > 1
+        auto & cells = v_cells[seq_to_stream[0]];
+        auto & head  = v_heads[seq_to_stream[0]];
+
         // apply moves:
         {
             const auto n_kv = dinfo.ids.size();
@@ -556,23 +740,13 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
 }
 
 llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
-    const uint32_t n_tokens = ubatch.n_tokens;
-
-    uint32_t head_cur = this->head;
-
-    // if we have enough unused cells before the current head ->
-    //   better to start searching from the beginning of the cache, hoping to fill it
-    if (head_cur > cells.get_used() + 2*ubatch.n_tokens) {
-        head_cur = 0;
-    }
+    if (debug > 0) {
+        const auto & cells = v_cells[seq_to_stream[1]];
 
-    if (n_tokens > cells.size()) {
-        LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
-        return { };
-    }
+        const uint32_t head_cur = v_heads[1];
 
-    if (debug > 0) {
-        LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", __func__, cells.used_max_p1(), cells.get_used(), head, get_size(), n_swa);
+        LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n",
+                __func__, cells.used_max_p1(), cells.get_used(), head_cur, get_size(), n_swa);
 
         if ((debug == 2 && n_swa > 0) || debug > 2) {
             std::string ss;
@@ -629,86 +803,133 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
         }
     }
 
-    uint32_t n_tested = 0;
+    uint32_t n_tokens = ubatch.n_tokens;
+    uint32_t n_seqs   = 1;
 
-    // for continuous slots, we test that all tokens in the ubatch fit, starting from the current head
-    // for non-continuous slots, we test the tokens one by one
-    const uint32_t n_test = cont ? n_tokens : 1;
+    if (n_stream > 1) {
+        GGML_ASSERT(n_tokens % ubatch.n_seqs_unq == 0);
 
-    slot_info res;
+        n_seqs   = ubatch.n_seqs_unq;
+        n_tokens = n_tokens / n_seqs;
+    }
+
+    slot_info res = {
+        /*.s0   =*/ LLAMA_MAX_SEQ,
+        /*.s1   =*/ 0,
+        /*.strm =*/ { },
+        /*.idxs =*/ { },
+    };
+
+    res.resize(n_seqs);
+
+    for (uint32_t s = 0; s < n_seqs; ++s) {
+        const auto seq_id = ubatch.seq_id_unq[s];
+
+        if (n_stream > 1) {
+            GGML_ASSERT(ubatch.n_seq_id[s*n_tokens]    == 1);
+            GGML_ASSERT(ubatch.seq_id  [s*n_tokens][0] == seq_id);
+        }
+
+        res.s0 = std::min<llama_seq_id>(res.s0, seq_to_stream[seq_id]);
+        res.s1 = std::max<llama_seq_id>(res.s1, seq_to_stream[seq_id]);
+
+        res.strm[s] = seq_to_stream[seq_id];
+        res.idxs[s].reserve(n_tokens);
 
-    auto & idxs = res.idxs;
+        const auto & cells = v_cells[seq_to_stream[seq_id]];
 
-    idxs.reserve(n_tokens);
+        uint32_t head_cur = v_heads[seq_to_stream[seq_id]];
 
-    while (true) {
-        if (head_cur + n_test > cells.size()) {
-            n_tested += cells.size() - head_cur;
+        // if we have enough unused cells before the current head ->
+        //   better to start searching from the beginning of the cache, hoping to fill it
+        if (head_cur > cells.get_used() + 2*n_tokens) {
             head_cur = 0;
-            continue;
         }
 
-        for (uint32_t i = 0; i < n_test; i++) {
-            const auto idx = head_cur;
+        if (n_tokens > cells.size()) {
+            LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
+            return { };
+        }
+
+        uint32_t n_tested = 0;
+
+        // for continuous slots, we test that all tokens in the ubatch fit, starting from the current head
+        // for non-continuous slots, we test the tokens one by one
+        const uint32_t n_test = cont ? n_tokens : 1;
+
+        while (true) {
+            if (head_cur + n_test > cells.size()) {
+                n_tested += cells.size() - head_cur;
+                head_cur = 0;
+                continue;
+            }
+
+            for (uint32_t i = 0; i < n_test; i++) {
+                const auto idx = head_cur;
+
+                head_cur++;
+                n_tested++;
 
-            //const llama_pos    pos    = ubatch.pos[i];
-            //const llama_seq_id seq_id = ubatch.seq_id[i][0];
+                //const llama_pos    pos    = ubatch.pos[i];
+                //const llama_seq_id seq_id = ubatch.seq_id[i][0];
 
-            // can we use this cell? either:
-            //  - the cell is empty
-            //  - the cell is occupied only by one sequence:
-            //    - (disabled) mask causally, if the sequence is the same as the one we are inserting
-            //    - mask SWA, using current max pos for that sequence in the cache
-            //                always insert in the cell with minimum pos
-            bool can_use = cells.is_empty(idx);
+                // can we use this cell? either:
+                //  - the cell is empty
+                //  - the cell is occupied only by one sequence:
+                //    - (disabled) mask causally, if the sequence is the same as the one we are inserting
+                //    - mask SWA, using current max pos for that sequence in the cache
+                //                always insert in the cell with minimum pos
+                bool can_use = cells.is_empty(idx);
 
-            if (!can_use && cells.seq_count(idx) == 1) {
-                const llama_pos pos_cell = cells.pos_get(idx);
+                if (!can_use && cells.seq_count(idx) == 1) {
+                    const llama_pos pos_cell = cells.pos_get(idx);
 
-                // (disabled) causal mask
-                // note: it's better to purge any "future" tokens beforehand
-                //if (cells.seq_has(idx, seq_id)) {
-                //    can_use = pos_cell >= pos;
-                //}
+                    // (disabled) causal mask
+                    // note: it's better to purge any "future" tokens beforehand
+                    //if (cells.seq_has(idx, seq_id)) {
+                    //    can_use = pos_cell >= pos;
+                    //}
 
-                if (!can_use) {
-                    const llama_seq_id seq_id_cell = cells.seq_get(idx);
+                    if (!can_use) {
+                        const llama_seq_id seq_id_cell = cells.seq_get(idx);
 
-                    // SWA mask
-                    if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
-                        can_use = true;
+                        // SWA mask
+                        if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
+                            can_use = true;
+                        }
                     }
                 }
-            }
 
-            head_cur++;
-            n_tested++;
+                if (can_use) {
+                    res.idxs[s].push_back(idx);
+                } else {
+                    if (cont) {
+                        break;
+                    }
+                }
+            }
 
-            if (can_use) {
-                idxs.push_back(idx);
-            } else {
+            if (res.idxs[s].size() == n_tokens) {
                 break;
             }
-        }
 
-        if (idxs.size() == n_tokens) {
-            break;
-        }
+            if (cont) {
+                res.idxs[s].clear();
+            }
 
-        if (cont) {
-            idxs.clear();
+            if (n_tested >= cells.size()) {
+                //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
+                return { };
+            }
         }
 
-        if (n_tested >= cells.size()) {
-            //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
+        // we didn't find a suitable slot - return empty result
+        if (res.idxs[s].size() < n_tokens) {
             return { };
         }
     }
 
-    // we didn't find a suitable slot - return empty result
-    if (idxs.size() < n_tokens) {
-        res.clear();
-    }
+    assert(res.s1 >= res.s0);
 
     return res;
 }
@@ -717,41 +938,51 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u
     // keep track of the max sequence position that we would overwrite with this ubatch
     // for non-SWA cache, this would be always empty
     llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
-    for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
+    for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
         seq_pos_max_rm[s] = -1;
     }
 
-    assert(ubatch.n_tokens == sinfo.idxs.size());
+    assert(ubatch.n_tokens == sinfo.n_stream()*sinfo.size());
 
-    for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
-        const auto idx = sinfo.idxs.at(i);
+    for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+        for (uint32_t ii = 0; ii < sinfo.size(); ++ii) {
+            const uint32_t i = s*sinfo.size() + ii;
 
-        if (!cells.is_empty(idx)) {
-            assert(cells.seq_count(idx) == 1);
+            auto & cells = v_cells[sinfo.strm[s]];
 
-            const llama_seq_id seq_id = cells.seq_get(idx);
-            const llama_pos    pos    = cells.pos_get(idx);
+            const auto idx = sinfo.idxs[s][ii];
 
-            seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
+            if (!cells.is_empty(idx)) {
+                assert(cells.seq_count(idx) == 1);
 
-            cells.rm(idx);
-        }
+                const llama_seq_id seq_id = cells.seq_get(idx);
+                const llama_pos    pos    = cells.pos_get(idx);
 
-        cells.pos_set(idx, ubatch.pos[i]);
+                seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
 
-        for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
-            cells.seq_add(idx, ubatch.seq_id[i][s]);
+                cells.rm(idx);
+            }
+
+            cells.pos_set(idx, ubatch.pos[i]);
+
+            for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
+                cells.seq_add(idx, ubatch.seq_id[i][s]);
+            }
         }
     }
 
     // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
     //       will be present in the cache. so we have to purge any position which is less than those we would overwrite
     //       ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
-    for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
+    for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
         if (seq_pos_max_rm[s] == -1) {
             continue;
         }
 
+        GGML_ASSERT(s < seq_to_stream.size());
+
+        auto & cells = v_cells[seq_to_stream[s]];
+
         if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
             LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
                     __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
@@ -761,7 +992,11 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u
     }
 
     // move the head at the end of the slot
-    head = sinfo.idxs.back() + 1;
+    for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+        auto & head = v_heads[sinfo.strm[s]];
+
+        head = sinfo.idxs[s].back() + 1;
+    }
 }
 
 bool llama_kv_cache_unified::get_can_shift() const {
@@ -769,49 +1004,87 @@ bool llama_kv_cache_unified::get_can_shift() const {
 }
 
 uint32_t llama_kv_cache_unified::get_size() const {
+    const auto & cells = v_cells[seq_to_stream[0]];
+
     return cells.size();
 }
 
+uint32_t llama_kv_cache_unified::get_n_stream() const {
+    return n_stream;
+}
+
 bool llama_kv_cache_unified::get_has_shift() const {
-    return cells.get_has_shift();
+    bool result = false;
+
+    for (uint32_t s = 0; s < n_stream; ++s) {
+        result |= v_cells[s].get_has_shift();
+    }
+
+    return result;
 }
 
 uint32_t llama_kv_cache_unified::get_n_kv() const {
-    return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
+    uint32_t result = 0;
+
+    for (uint32_t s = 0; s < n_stream; ++s) {
+        const auto & cells = v_cells[s];
+
+        result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result);
+    }
+
+    return result;
 }
 
-ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const {
+ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * k = layers[ikv].k;
 
-    return ggml_view_3d(ctx, k,
-            hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv,
+    const uint64_t kv_size      = get_size();
+    const uint64_t n_embd_k_gqa = k->ne[0];
+
+    assert(n_embd_k_gqa == hparams.n_embd_k_gqa(il));
+
+    const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
+
+    return ggml_view_4d(ctx, k,
+            hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, ns,
             ggml_row_size(k->type, hparams.n_embd_head_k),
-            ggml_row_size(k->type, hparams.n_embd_k_gqa(il)),
-            0);
+            ggml_row_size(k->type, n_embd_k_gqa),
+            ggml_row_size(k->type, n_embd_k_gqa*kv_size),
+            ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0);
 }
 
-ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const {
+ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * v = layers[ikv].v;
 
+    const uint64_t kv_size      = get_size();
+    const uint64_t n_embd_v_gqa = v->ne[0];
+
+    // [TAG_V_CACHE_VARIABLE]
+    assert(n_embd_v_gqa >= hparams.n_embd_v_gqa(il));
+
+    const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
+
     if (!v_trans) {
         // note: v->nb[1] <= v->nb[2]
-        return ggml_view_3d(ctx, v,
-                hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv,
-                ggml_row_size(v->type, hparams.n_embd_head_v),    // v->nb[1]
-                ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
-                0);
+        return ggml_view_4d(ctx, v,
+                hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns,
+                ggml_row_size(v->type, hparams.n_embd_head_v),            // v->nb[1]
+                ggml_row_size(v->type, n_embd_v_gqa),         // v->nb[2]
+                ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3]
+                ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0);
     }
 
     // note: v->nb[1] > v->nb[2]
-    return ggml_view_3d(ctx, v,
-            n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v,
-            ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1]
-            ggml_row_size(v->type, v->ne[1]),                       // v->nb[2]
-            0);
+    return ggml_view_4d(ctx, v,
+            n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns,
+            ggml_row_size(v->type, kv_size*hparams.n_embd_head_v),    // v->nb[1]
+            ggml_row_size(v->type, kv_size),                          // v->nb[2]
+            ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3]
+            ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0);
 }
 
 ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
@@ -825,12 +1098,18 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
     k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
 
     if (k_idxs && supports_set_rows) {
+        if (k->ne[2] > 1) {
+            k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
+        }
+
         return ggml_set_rows(ctx, k, k_cur, k_idxs);
     }
 
     // TODO: fallback to old ggml_cpy() method for backwards compatibility
     //       will be removed when ggml_set_rows() is adopted by all backends
 
+    GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS");
+
     ggml_tensor * k_view = ggml_view_1d(ctx, k,
             n_tokens*n_embd_k_gqa,
             ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head());
@@ -843,37 +1122,38 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 
     auto * v = layers[ikv].v;
 
-    const int64_t n_embd_v_gqa = v->ne[0];
-    const int64_t n_tokens = v_cur->ne[2];
+    const int64_t n_embd_v_gqa = v_cur->ne[0]*v_cur->ne[1];
+    const int64_t n_tokens     = v_cur->ne[2];
 
     v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
 
     if (v_idxs && supports_set_rows) {
         if (!v_trans) {
+            if (v->ne[2] > 1) {
+                v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
+            }
+
             return ggml_set_rows(ctx, v, v_cur, v_idxs);
         }
 
-        // the row becomes a single element
-        ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
+        // [TAG_V_CACHE_VARIABLE]
+        if (n_embd_v_gqa < v->ne[0]) {
+            v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0);
+        }
 
-        // note: the V cache is transposed when not using flash attention
-        v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
+        // the row becomes a single element
+        ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);
 
-        // note: we can be more explicit here at the cost of extra cont
-        //       however, above we take advantage that a row of single element is always continuous regardless of the row stride
-        //v_cur = ggml_transpose(ctx, v_cur);
-        //v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
+        v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]);
 
-        // we broadcast the KV indices n_embd_v_gqa times
-        // v      [1,        n_kv,     n_embd_v_gqa]
-        // v_cur  [1,        n_tokens, n_embd_v_gqa]
-        // v_idxs [n_tokens, 1,        1]
         return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
     }
 
     // TODO: fallback to old ggml_cpy() method for backwards compatibility
     //       will be removed when ggml_set_rows() is adopted by all backends
 
+    GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS");
+
     ggml_tensor * v_view = nullptr;
 
     if (!v_trans) {
@@ -904,7 +1184,13 @@ ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, con
 ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
     const uint32_t n_tokens = ubatch.n_tokens;
 
-    ggml_tensor * v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+    ggml_tensor * v_idxs;
+
+    if (!v_trans) {
+        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+    } else {
+        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa_max());
+    }
 
     ggml_set_input(v_idxs);
 
@@ -917,12 +1203,17 @@ void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_uba
     }
 
     const uint32_t n_tokens = ubatch->n_tokens;
+    GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());
 
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     int64_t * data = (int64_t *) dst->data;
 
-    for (int64_t i = 0; i < n_tokens; ++i) {
-        data[i] = sinfo.idxs.at(i);
+    for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+        const int64_t offs = sinfo.strm[s]*get_size();
+
+        for (uint32_t i = 0; i < sinfo.size(); ++i) {
+            data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
+        }
     }
 }
 
@@ -932,12 +1223,48 @@ void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_uba
     }
 
     const uint32_t n_tokens = ubatch->n_tokens;
+    GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());
 
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     int64_t * data = (int64_t *) dst->data;
 
-    for (int64_t i = 0; i < n_tokens; ++i) {
-        data[i] = sinfo.idxs.at(i);
+    if (!v_trans) {
+        for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+            const int64_t offs = sinfo.strm[s]*get_size();
+
+            for (uint32_t i = 0; i < sinfo.size(); ++i) {
+                data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
+            }
+        }
+    } else {
+        // note: the V cache is transposed when not using flash attention
+        const int64_t kv_size = get_size();
+
+        const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa_max();
+
+        for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+            const int64_t offs = sinfo.strm[s]*kv_size*n_embd_v_gqa;
+
+            for (uint32_t i = 0; i < sinfo.size(); ++i) {
+                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                    data[s*sinfo.size()*n_embd_v_gqa + i*n_embd_v_gqa + j] = offs + j*kv_size + sinfo.idxs[s][i];
+                }
+            }
+        }
+    }
+}
+
+void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+
+    int32_t * data = (int32_t *) dst->data;
+
+    for (uint32_t s = 0; s < n_stream; ++s) {
+        const auto & cells = v_cells[s];
+
+        for (uint32_t i = 0; i < cells.size(); ++i) {
+            data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
+        }
     }
 }
 
@@ -947,7 +1274,14 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     float * data = (float *) dst->data;
 
-    const int64_t n_kv = dst->ne[0];
+    const int64_t n_kv     = dst->ne[0];
+    const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch
+
+    GGML_ASSERT(n_tokens%n_stream == 0);
+
+    // n_tps == n_tokens_per_stream
+    const int64_t n_tps     = n_tokens/n_stream;
+    const int64_t n_tps_pad = GGML_PAD(n_tps, GGML_KQ_MASK_PAD);
 
     // Use only the previous KV cells of the correct sequence for each token of the ubatch.
     // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
@@ -962,67 +1296,66 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
     //      xxxxx-----
     // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
     for (uint32_t h = 0; h < 1; ++h) {
-        for (uint32_t i = 0; i < n_tokens; ++i) {
-            const llama_seq_id seq_id = ubatch->seq_id[i][0];
+        for (uint32_t s = 0; s < n_stream; ++s) {
+            for (uint32_t ii = 0; ii < n_tps; ++ii) {
+                const uint32_t i = s*n_tps + ii;
 
-            const llama_pos p1 = ubatch->pos[i];
+                const llama_seq_id seq_id = ubatch->seq_id[i][0];
 
-            for (uint32_t j = 0; j < n_kv; ++j) {
-                float f = 0.0f;
+                const auto & cells = v_cells[seq_to_stream[seq_id]];
 
-                bool masked = false;
+                const llama_pos p1 = ubatch->pos[i];
 
-                if (cells.is_empty(j)) {
-                    masked = true;
-                } else {
-                    const llama_pos p0 = cells.pos_get(j);
+                for (uint32_t j = 0; j < n_kv; ++j) {
+                    float f = 0.0f;
 
-                    // mask the token if not the same sequence
-                    masked = masked || (!cells.seq_has(j, seq_id));
+                    bool masked = false;
 
-                    // mask future tokens
-                    masked = masked || (causal_attn && p0 > p1);
+                    if (cells.is_empty(j)) {
+                        masked = true;
+                    } else {
+                        const llama_pos p0 = cells.pos_get(j);
+
+                        // mask the token if not the same sequence
+                        masked = masked || (!cells.seq_has(j, seq_id));
+
+                        // mask future tokens
+                        masked = masked || (causal_attn && p0 > p1);
 
-                    // apply SWA if any
-                    masked = masked || (is_masked_swa(p0, p1));
+                        // apply SWA if any
+                        masked = masked || (is_masked_swa(p0, p1));
 
-                    if (!masked && hparams.use_alibi) {
-                        f = -std::abs(p0 - p1);
+                        if (!masked && hparams.use_alibi) {
+                            f = -std::abs(p0 - p1);
+                        }
                     }
-                }
 
-                if (masked) {
-                    f = -INFINITY;
-                }
+                    if (masked) {
+                        f = -INFINITY;
+                    }
 
-                data[h*(n_kv*n_tokens) + i*n_kv + j] = f;
-            }
-        }
+                    data[h*n_stream*n_tps_pad*n_kv + s*n_tps_pad*n_kv + ii*n_kv + j] = f;
+                }
 
-        // mask padded tokens
-        if (data) {
-            for (uint32_t i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                for (uint32_t j = 0; j < n_kv; ++j) {
-                    data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                // mask padded tokens
+                if (data) {
+                    for (uint32_t ii = n_tps; ii < n_tps_pad; ++ii) {
+                        for (uint32_t j = 0; j < n_kv; ++j) {
+                            data[h*n_stream*n_tps_pad*n_kv + s*n_tps_pad*n_kv + ii*n_kv + j] = -INFINITY;
+                        }
+                    }
                 }
             }
         }
     }
 }
 
-void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
-    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
-
-    int32_t * data = (int32_t *) dst->data;
-
-    for (uint32_t i = 0; i < cells.size(); ++i) {
-        data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
-    }
-}
-
 void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
     const int64_t n_tokens = ubatch->n_tokens;
 
+    GGML_ASSERT(n_stream == 1 && "TODO: support multiple streams");
+    const auto & cells = v_cells[0];
+
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
 
@@ -1129,7 +1462,7 @@ public:
 
     void set_input(const llama_ubatch * ubatch) override;
 
-    ggml_tensor * k_shift; // I32 [kv_size]
+    ggml_tensor * k_shift; // I32 [kv_size*n_stream]
 
     const llama_kv_cache_unified * kv_self;
 };
@@ -1153,7 +1486,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
 
     auto inp = std::make_unique<llm_graph_input_k_shift>(this);
 
-    inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cells.size());
+    inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
     ggml_set_input(inp->k_shift);
 
     for (const auto & layer : layers) {
@@ -1169,7 +1502,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
 
         ggml_tensor * k =
             ggml_view_3d(ctx, layer.k,
-                n_embd_head_k, n_head_kv, cells.size(),
+                n_embd_head_k, n_head_kv, get_size()*n_stream,
                 ggml_row_size(layer.k->type, n_embd_head_k),
                 ggml_row_size(layer.k->type, n_embd_k_gqa),
                 0);
@@ -1191,6 +1524,10 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
                   const defrag_info & dinfo) const {
     auto res = std::make_unique<llm_graph_result>();
 
+    GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag");
+
+    const auto & cells = v_cells[0];
+
     const auto & ids = dinfo.ids;
 
 #if 0
@@ -1333,6 +1670,10 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
 }
 
 llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const {
+    GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag");
+
+    const auto & cells = v_cells[0];
+
     const uint32_t n_layer = layers.size();
 
     const uint32_t n_kv   = cells.used_max_p1();
@@ -1478,64 +1819,94 @@ bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
 }
 
 void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
-    std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
-    uint32_t cell_count = 0;
+    io.write(&n_stream, sizeof(n_stream));
 
-    // Count the number of cells with the specified seq_id
-    // Find all the ranges of cells with this seq id (or all, when -1)
-    uint32_t cell_range_begin = cells.size();
+    for (uint32_t s = 0; s < n_stream; ++s) {
+        cell_ranges_t cr { s, {} };
 
-    for (uint32_t i = 0; i < cells.size(); ++i) {
-        if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) {
-            ++cell_count;
-            if (cell_range_begin == cells.size()) {
-                cell_range_begin = i;
-            }
-        } else {
-            if (cell_range_begin != cells.size()) {
-                cell_ranges.emplace_back(cell_range_begin, i);
-                cell_range_begin = cells.size();
+        uint32_t cell_count = 0;
+
+        const auto & cells = v_cells[s];
+
+        // Count the number of cells with the specified seq_id
+        // Find all the ranges of cells with this seq id (or all, when -1)
+        uint32_t cell_range_begin = cells.size();
+
+        for (uint32_t i = 0; i < cells.size(); ++i) {
+            if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) {
+                ++cell_count;
+                if (cell_range_begin == cells.size()) {
+                    cell_range_begin = i;
+                }
+            } else {
+                if (cell_range_begin != cells.size()) {
+                    cr.data.emplace_back(cell_range_begin, i);
+                    cell_range_begin = cells.size();
+                }
             }
         }
-    }
 
-    if (cell_range_begin != cells.size()) {
-        cell_ranges.emplace_back(cell_range_begin, cells.size());
-    }
+        if (cell_range_begin != cells.size()) {
+            cr.data.emplace_back(cell_range_begin, cells.size());
+        }
 
-    // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
-    uint32_t cell_count_check = 0;
-    for (const auto & range : cell_ranges) {
-        cell_count_check += range.second - range.first;
-    }
-    GGML_ASSERT(cell_count == cell_count_check);
+        // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
+        uint32_t cell_count_check = 0;
+        for (const auto & range : cr.data) {
+            cell_count_check += range.second - range.first;
+        }
+        GGML_ASSERT(cell_count == cell_count_check);
 
-    io.write(&cell_count, sizeof(cell_count));
+        io.write(&cell_count, sizeof(cell_count));
 
-    state_write_meta(io, cell_ranges, seq_id);
-    state_write_data(io, cell_ranges);
+        // skip empty streams
+        if (cell_count == 0) {
+            continue;
+        }
+
+        state_write_meta(io, cr, seq_id);
+        state_write_data(io, cr);
+    }
 }
 
 void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
-    uint32_t cell_count;
-    io.read_to(&cell_count, sizeof(cell_count));
+    GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
 
-    bool res = true;
-    res = res && state_read_meta(io, cell_count, seq_id);
-    res = res && state_read_data(io, cell_count);
+    uint32_t n_stream_cur;
+    io.read_to(&n_stream_cur, sizeof(n_stream_cur));
+    if (n_stream_cur != n_stream) {
+        throw std::runtime_error("n_stream mismatch");
+    }
+
+    for (uint32_t s = 0; s < n_stream; ++s) {
+        uint32_t cell_count;
+        io.read_to(&cell_count, sizeof(cell_count));
+
+        if (cell_count == 0) {
+            continue;
+        }
+
+        const uint32_t strm = seq_id == -1 ? s : seq_to_stream[seq_id];
 
-    if (!res) {
-        if (seq_id == -1) {
-            clear(true);
-        } else {
-            seq_rm(seq_id, -1, -1);
+        bool res = true;
+        res = res && state_read_meta(io, strm, cell_count, seq_id);
+        res = res && state_read_data(io, strm, cell_count);
+
+        if (!res) {
+            if (seq_id == -1) {
+                clear(true);
+            } else {
+                seq_rm(seq_id, -1, -1);
+            }
+            throw std::runtime_error("failed to restore kv cache");
         }
-        throw std::runtime_error("failed to restore kv cache");
     }
 }
 
-void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
-    for (const auto & range : cell_ranges) {
+void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id) const {
+    const auto & cells = v_cells[cr.strm];
+
+    for (const auto & range : cr.data) {
         for (uint32_t i = range.first; i < range.second; ++i) {
             std::vector<llama_seq_id> seq_ids;
 
@@ -1560,7 +1931,9 @@ void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::
     }
 }
 
-void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
+void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const {
+    const auto & cells = v_cells[cr.strm];
+
     const uint32_t v_trans = this->v_trans ? 1 : 0;
     const uint32_t n_layer = layers.size();
 
@@ -1576,19 +1949,21 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
 
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
 
+        auto * k = layer.k_stream[cr.strm];
+
         // Write key type
-        const int32_t k_type_i = (int32_t)layer.k->type;
+        const int32_t k_type_i = (int32_t) k->type;
         io.write(&k_type_i, sizeof(k_type_i));
 
         // Write row size of key
-        const uint64_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa);
+        const uint64_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
         io.write(&k_size_row, sizeof(k_size_row));
 
         // Read each range of cells of k_size length each into tmp_buf and write out
-        for (const auto & range : cell_ranges) {
+        for (const auto & range : cr.data) {
             const size_t range_size = range.second - range.first;
             const size_t buf_size = range_size * k_size_row;
-            io.write_tensor(layer.k, range.first * k_size_row, buf_size);
+            io.write_tensor(k, range.first * k_size_row, buf_size);
         }
     }
 
@@ -1598,19 +1973,21 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
 
             const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
+            auto * v = layer.v_stream[cr.strm];
+
             // Write value type
-            const int32_t v_type_i = (int32_t)layer.v->type;
+            const int32_t v_type_i = (int32_t) v->type;
             io.write(&v_type_i, sizeof(v_type_i));
 
             // Write row size of value
-            const uint64_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa);
+            const uint64_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
             io.write(&v_size_row, sizeof(v_size_row));
 
             // Read each range of cells of v_size length each into tmp_buf and write out
-            for (const auto & range : cell_ranges) {
+            for (const auto & range : cr.data) {
                 const size_t range_size = range.second - range.first;
                 const size_t buf_size = range_size * v_size_row;
-                io.write_tensor(layer.v, range.first * v_size_row, buf_size);
+                io.write_tensor(v, range.first * v_size_row, buf_size);
             }
         }
     } else {
@@ -1622,12 +1999,14 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
 
             const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
+            auto * v = layer.v_stream[cr.strm];
+
             // Write value type
-            const int32_t v_type_i = (int32_t)layer.v->type;
+            const int32_t v_type_i = (int32_t) v->type;
             io.write(&v_type_i, sizeof(v_type_i));
 
             // Write element size
-            const uint32_t v_size_el = ggml_type_size(layer.v->type);
+            const uint32_t v_size_el = ggml_type_size(v->type);
             io.write(&v_size_el, sizeof(v_size_el));
 
             // Write GQA embedding size
@@ -1636,27 +2015,31 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
             // For each row, we get the element values of each cell
             for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
                 // Read each range of cells of v_size_el length each into tmp_buf and write out
-                for (const auto & range : cell_ranges) {
+                for (const auto & range : cr.data) {
                     const size_t range_size = range.second - range.first;
                     const size_t src_offset = (range.first + j * kv_size) * v_size_el;
                     const size_t buf_size = range_size * v_size_el;
-                    io.write_tensor(layer.v, src_offset, buf_size);
+                    io.write_tensor(v, src_offset, buf_size);
                 }
             }
         }
     }
 }
 
-bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
+bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) {
+    auto & cells = v_cells[strm];
+    auto & head  = v_heads[strm];
+
     if (dest_seq_id != -1) {
         // single sequence
-
         seq_rm(dest_seq_id, -1, -1);
 
         llama_batch_allocr balloc(hparams.n_pos_per_embd());
 
         llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
 
+        ubatch.seq_id_unq[0] = dest_seq_id;
+
         for (uint32_t i = 0; i < cell_count; ++i) {
             llama_pos pos;
             uint32_t n_seq_id;
@@ -1693,6 +2076,8 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
         // keep the head at the old position because we will read the KV data into it in state_read_data()
         head = head_cur;
 
+        LLAMA_LOG_DEBUG("%s: head_cur = %d, head = %d, cell_count = %d, dest_seq_id = %d\n", __func__, head_cur, head, cell_count, dest_seq_id);
+
         // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
         // Assume that this is one contiguous block of cells
         GGML_ASSERT(head_cur + cell_count <= cells.size());
@@ -1738,7 +2123,10 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
     return true;
 }
 
-bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
+bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) {
+    auto & cells = v_cells[strm];
+    auto & head  = v_heads[strm];
+
     uint32_t v_trans;
     uint32_t n_layer;
 
@@ -1766,10 +2154,12 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
 
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
 
+        auto * k = layer.k_stream[strm];
+
         // Read type of key
         int32_t k_type_i_ref;
         io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
-        const int32_t k_type_i = (int32_t) layer.k->type;
+        const int32_t k_type_i = (int32_t) k->type;
         if (k_type_i != k_type_i_ref) {
             LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
             return false;
@@ -1778,7 +2168,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
         // Read row size of key
         uint64_t k_size_row_ref;
         io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
-        const size_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa);
+        const size_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
         if (k_size_row != k_size_row_ref) {
             LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
             return false;
@@ -1786,7 +2176,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
 
         if (cell_count) {
             // Read and set the keys for the whole cell range
-            ggml_backend_tensor_set(layer.k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
+            ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
         }
     }
 
@@ -1796,10 +2186,12 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
 
             const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
+            auto * v = layer.v_stream[strm];
+
             // Read type of value
             int32_t v_type_i_ref;
             io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
-            const int32_t v_type_i = (int32_t)layer.v->type;
+            const int32_t v_type_i = (int32_t) v->type;
             if (v_type_i != v_type_i_ref) {
                 LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
                 return false;
@@ -1808,7 +2200,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
             // Read row size of value
             uint64_t v_size_row_ref;
             io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
-            const size_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa);
+            const size_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
             if (v_size_row != v_size_row_ref) {
                 LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
                 return false;
@@ -1816,7 +2208,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
 
             if (cell_count) {
                 // Read and set the values for the whole cell range
-                ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
+                ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
             }
         }
     } else {
@@ -1826,10 +2218,12 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
 
             const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
+            auto * v = layer.v_stream[strm];
+
             // Read type of value
             int32_t v_type_i_ref;
             io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
-            const int32_t v_type_i = (int32_t)layer.v->type;
+            const int32_t v_type_i = (int32_t) v->type;
             if (v_type_i != v_type_i_ref) {
                 LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
                 return false;
@@ -1838,7 +2232,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
             // Read element size of value
             uint32_t v_size_el_ref;
             io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
-            const size_t v_size_el = ggml_type_size(layer.v->type);
+            const size_t v_size_el = ggml_type_size(v->type);
             if (v_size_el != v_size_el_ref) {
                 LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
                 return false;
@@ -1856,7 +2250,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
                 // For each row in the transposed matrix, read the values for the whole cell range
                 for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
                     const size_t dst_offset = (head + j * cells.size()) * v_size_el;
-                    ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
+                    ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
                 }
             }
         }
@@ -1875,18 +2269,26 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
         llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
     n_kv = kv->get_size();
 
+    const uint32_t n_stream = kv->get_n_stream();
+
     // create a dummy slot info - the actual data is irrelevant. we just need to build the graph
     sinfos.resize(1);
-    sinfos[0].idxs.resize(1);
-    sinfos[0].idxs[0] = 0;
+    sinfos[0].s0 = 0;
+    sinfos[0].s1 = n_stream - 1;
+    sinfos[0].idxs.resize(n_stream);
+    for (uint32_t s = 0; s < n_stream; ++s) {
+        sinfos[0].strm.push_back(s);
+        sinfos[0].idxs[s].resize(1, 0);
+    }
 }
 
 llama_kv_cache_unified_context::llama_kv_cache_unified_context(
         llama_kv_cache_unified * kv,
         llama_context * lctx,
         bool do_shift,
-        defrag_info dinfo) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)) {
-    if (!do_shift && this->dinfo.empty()) {
+        defrag_info dinfo,
+        stream_copy_info sc_info) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)), sc_info(std::move(sc_info)) {
+    if (!do_shift && this->dinfo.empty() && this->sc_info.empty()) {
         status = LLAMA_MEMORY_STATUS_NO_UPDATE;
     }
 }
@@ -1914,7 +2316,7 @@ bool llama_kv_cache_unified_context::apply() {
 
     // no ubatches -> this is a KV cache update
     if (ubatches.empty()) {
-        kv->update(lctx, do_shift, dinfo);
+        kv->update(lctx, do_shift, dinfo, sc_info);
 
         return true;
     }
@@ -1941,11 +2343,11 @@ uint32_t llama_kv_cache_unified_context::get_n_kv() const {
 }
 
 ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const {
-    return kv->get_k(ctx, il, n_kv);
+    return kv->get_k(ctx, il, n_kv, sinfos[i_cur]);
 }
 
 ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const {
-    return kv->get_v(ctx, il, n_kv);
+    return kv->get_v(ctx, il, n_kv, sinfos[i_cur]);
 }
 
 ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
index b8b0356e830c89d84483e6afc474f6cf7492e1f2..3bfda4600d8432bac2d5ffbac53bf8a4aea9f863 100644 (file)
@@ -35,16 +35,50 @@ public:
         std::vector<uint32_t> ids;
     };
 
+    struct stream_copy_info {
+        bool empty() const {
+            assert(ssrc.size() == sdst.size());
+            return ssrc.empty();
+        }
+
+        std::vector<uint32_t> ssrc;
+        std::vector<uint32_t> sdst;
+    };
+
     // for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the
     //   KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]]
     struct slot_info {
         // data for ggml_set_rows
         using idx_vec_t = std::vector<uint32_t>;
 
-        idx_vec_t idxs;
+        // number of streams: ns = s1 - s0 + 1
+        llama_seq_id s0;
+        llama_seq_id s1;
+
+        std::vector<llama_seq_id> strm; // [ns]
+        std::vector<idx_vec_t>    idxs; // [ns]
 
         uint32_t head() const {
-            return idxs.at(0);
+            GGML_ASSERT(idxs.size() == 1);
+            GGML_ASSERT(!idxs[0].empty());
+
+            return idxs[0][0];
+        }
+
+        void resize(size_t n) {
+            strm.resize(n);
+            idxs.resize(n);
+        }
+
+        size_t size() const {
+            GGML_ASSERT(idxs.size() == strm.size());
+            GGML_ASSERT(!idxs.empty());
+
+            return idxs[0].size();
+        }
+
+        size_t n_stream() const {
+            return strm.size();
         }
 
         bool empty() const {
@@ -54,9 +88,6 @@ public:
         void clear() {
             idxs.clear();
         }
-
-        // TODO: implement
-        //std::vector<idx_vec_t> seq_idxs;
     };
 
     using slot_info_vec_t = std::vector<slot_info>;
@@ -68,6 +99,7 @@ public:
                     ggml_type    type_v,
                          bool    v_trans,
                          bool    offload,
+                         bool    unified,
                      uint32_t    kv_size,
                      uint32_t    n_seq_max,
                      uint32_t    n_pad,
@@ -111,7 +143,8 @@ public:
     // llama_kv_cache_unified specific API
     //
 
-    uint32_t get_size() const;
+    uint32_t get_size()     const;
+    uint32_t get_n_stream() const;
 
     bool get_has_shift() const;
 
@@ -122,8 +155,8 @@ public:
     uint32_t get_n_kv() const;
 
     // get views of the current state of the cache
-    ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
-    ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
+    ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
+    ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
 
     // store k_cur and v_cur in the cache based on the provided head location
     ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
@@ -137,7 +170,7 @@ public:
     // return empty vector on failure
     slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);
 
-    bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
+    bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info);
 
     // find a slot of kv cells that can hold the ubatch
     // if cont == true, then the slot must be continuous
@@ -157,8 +190,9 @@ public:
     void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
     void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
 
+    void set_input_k_shift(ggml_tensor * dst) const;
+
     void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
-    void set_input_k_shift   (ggml_tensor * dst) const;
     void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
 
 private:
@@ -172,15 +206,15 @@ private:
 
         ggml_tensor * k;
         ggml_tensor * v;
+
+        std::vector<ggml_tensor *> k_stream;
+        std::vector<ggml_tensor *> v_stream;
     };
 
     bool v_trans = true;  // the value tensor is transposed
 
-    // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
-    // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
-    uint32_t head = 0;
-
     const uint32_t n_seq_max = 1;
+    const uint32_t n_stream  = 1;
 
     // required padding
     const uint32_t n_pad = 1;
@@ -200,7 +234,17 @@ private:
     std::vector<ggml_context_ptr>        ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;
 
-    llama_kv_cells_unified cells;
+    // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
+    // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
+    std::vector<uint32_t> v_heads;
+
+    std::vector<llama_kv_cells_unified> v_cells;
+
+    // maps from a sequence id to a stream id
+    std::vector<uint32_t> seq_to_stream;
+
+    // pending stream copies that will be applied during the next update
+    stream_copy_info sc_info;
 
     std::vector<kv_layer> layers;
 
@@ -237,18 +281,25 @@ private:
                     ggml_cgraph * gf,
               const defrag_info & dinfo) const;
 
-    void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
-    void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
+    struct cell_ranges_t {
+        uint32_t strm;
 
-    bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
-    bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
+        std::vector<std::pair<uint32_t, uint32_t>> data; // ranges, from inclusive, to exclusive
+    };
+
+    void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const;
+    void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const;
+
+    bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
+    bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count);
 };
 
 class llama_kv_cache_unified_context : public llama_memory_context_i {
 public:
     // some shorthands
-    using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
-    using defrag_info     = llama_kv_cache_unified::defrag_info;
+    using slot_info_vec_t  = llama_kv_cache_unified::slot_info_vec_t;
+    using defrag_info      = llama_kv_cache_unified::defrag_info;
+    using stream_copy_info = llama_kv_cache_unified::stream_copy_info;
 
     // used for errors
     llama_kv_cache_unified_context(llama_memory_status status);
@@ -262,7 +313,8 @@ public:
             llama_kv_cache_unified * kv,
             llama_context * lctx,
             bool do_shift,
-            defrag_info dinfo);
+            defrag_info dinfo,
+            stream_copy_info sc_info);
 
     // used to create a batch procesing context from a batch
     llama_kv_cache_unified_context(
@@ -320,6 +372,8 @@ private:
 
     defrag_info dinfo;
 
+    stream_copy_info sc_info;
+
     //
     // batch processing context
     //
index 6cd10db06b77571675aa2a6dfa6ab8977552d199..eedfaec53e8760de9e85b0e7e3a917f90593276a 100644 (file)
@@ -40,6 +40,7 @@ llama_memory_hybrid::llama_memory_hybrid(
         offload,
         kv_size,
         n_seq_max,
+        1,
         n_pad,
         n_swa,
         swa_type
index 82ddc5cef67651530bd9776f55a3b9d3935cb604..67cae69579fdb50dc57d9ebac1e6b9f126b8a379 100644 (file)
@@ -16647,7 +16647,18 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                 } else {
                     const auto padding = llama_kv_cache_unified::get_padding(cparams);
 
-                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+                    uint32_t n_ctx_per_stream = cparams.n_ctx;
+
+                    if (!cparams.kv_unified) {
+                        n_ctx_per_stream = (cparams.n_ctx + cparams.n_seq_max - 1)/cparams.n_seq_max;
+                        n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding);
+
+                        cparams.n_ctx = n_ctx_per_stream*cparams.n_seq_max;
+                    } else {
+                        n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding);
+
+                        cparams.n_ctx = n_ctx_per_stream;
+                    }
 
                     LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
 
@@ -16661,7 +16672,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                                 !cparams.flash_attn,
                                 cparams.offload_kqv,
                                 params.swa_full,
-                                cparams.n_ctx,
+                                cparams.kv_unified,
+                                n_ctx_per_stream,
                                 cparams.n_seq_max,
                                 cparams.n_ubatch,
                                 padding);
@@ -16675,7 +16687,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                                 params.type_v,
                                 !cparams.flash_attn,
                                 cparams.offload_kqv,
-                                cparams.n_ctx,
+                                cparams.kv_unified,
+                                n_ctx_per_stream,
                                 cparams.n_seq_max,
                                 padding,
                                 hparams.n_swa,
index 81fe90b99323d7768324ba2a13777d1733fcd408..a3d68fba046cf5e207e45e12662f518c61be5a5d 100644 (file)
@@ -4282,7 +4282,7 @@ struct test_flash_attn_ext : public test_case {
 
         ggml_tensor * m = nullptr;
         if (mask) {
-            m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), nr23[0], nr23[1]);
+            m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, nr23[1]);
             ggml_set_name(m, "m");
         }
 
index a0a2e5ac56ea94eb88bf552a42570ae8e76ec413..03628f74b2880ec9c7266cd85d60b9adcbcb1cda 100644 (file)
@@ -127,10 +127,9 @@ int main(int argc, char ** argv) {
 
                 for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) {
                     for (int i = 0; i < pp; ++i) {
-                        common_batch_add(batch, 0, i, { j }, false);
+                        common_batch_add(batch, 0, i, { j }, i == pp - 1);
                     }
                 }
-                batch.logits[batch.n_tokens - 1] = true;
 
                 const auto t_pp_start = ggml_time_us();