]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
kv-cache : refactor + add llama_memory_state_i (#13746)
authorGeorgi Gerganov <redacted>
Sat, 31 May 2025 07:24:04 +0000 (10:24 +0300)
committerGitHub <redacted>
Sat, 31 May 2025 07:24:04 +0000 (10:24 +0300)
* kv-cache : simplify the "struct llama_kv_cache" interface

ggml-ci

* kv-cache : revert the (n_swa + n_ubatch) change (for next PR)

ggml-ci

* kv-cache : some comments

ggml-ci

* context : fix graph reserve for multiple sequences

ggml-ci

* kv-cache : fix typo [no ci]

* kv-cache : fix find_slot() logic for free slots

ggml-ci

* llama : add TODO for deprecating the defrag API in the future

* kv-cache : improve find_slot() using min/max seq pos info

ggml-ci

* llama : handle aborts and compute errors

ggml-ci

* memory : extract state into llama_memory_state

ggml-ci

* kv-cache : add comments

ggml-ci

* server : update batching logic to reset n_batch on successful decode

* server : upon full re-processing, remove the sequence from the cache

* kv-cache : add TODO for doing split_equal when split_simple fails

ggml-ci

14 files changed:
examples/parallel/parallel.cpp
include/llama.h
src/llama-batch.cpp
src/llama-batch.h
src/llama-context.cpp
src/llama-context.h
src/llama-graph.cpp
src/llama-graph.h
src/llama-kv-cache.cpp
src/llama-kv-cache.h
src/llama-kv-cells.h
src/llama-memory.h
src/llama-model.cpp
tools/server/server.cpp

index 931ea0035cffbd8b2f9d5b1c3b1d9bb37f613501..22118faf8c20d73ab7198e6c25eb51b52e423eac 100644 (file)
@@ -362,7 +362,9 @@ int main(int argc, char ** argv) {
         // process in chunks of params.n_batch
         int32_t n_batch = params.n_batch;
 
-        for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
+        int32_t i_next = 0;
+
+        for (int32_t i = 0; i < batch.n_tokens; i = i_next) {
             // experiment: process in powers of 2
             //if (i + n_batch > (int32_t) batch.n_tokens && n_batch > 32) {
             //    n_batch /= 2;
@@ -370,7 +372,7 @@ int main(int argc, char ** argv) {
             //    continue;
             //}
 
-            const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
+            const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
 
             llama_batch batch_view = {
                 n_tokens,
@@ -396,13 +398,18 @@ int main(int argc, char ** argv) {
 
                 // retry with half the batch size to try to find a free slot in the KV cache
                 n_batch /= 2;
-                i -= n_batch;
 
                 continue;
             }
 
             LOG_DBG("%s : decoded batch of %d tokens\n", __func__, n_tokens);
 
+            // move the head of the batch forward with the number of tokens we just processed
+            i_next = i + n_tokens;
+
+            // on successful decode, restore the original batch size
+            n_batch = params.n_batch;
+
             for (auto & client : clients) {
                 if (client.i_batch < (int) i || client.i_batch >= (int) (i + n_tokens)) {
                     continue;
index 01762bea2bf962c33d30c4b72da6c05ce4bbf26f..adc4c69288a3d1e134b75e092f59077e43aa4713 100644 (file)
@@ -259,9 +259,9 @@ extern "C" {
         llama_token  *  token;
         float        *  embd;
         llama_pos    *  pos;
-        int32_t      *  n_seq_id;
-        llama_seq_id ** seq_id;
-        int8_t       *  logits; // TODO: rename this to "output"
+        int32_t      *  n_seq_id; // TODO: remove, should belong to only 1 sequence
+        llama_seq_id ** seq_id;   // TODO: become llama_seq_id * seq_id;
+        int8_t       *  logits;   // TODO: rename this to "output"
     } llama_batch;
 
     enum llama_model_kv_override_type {
@@ -677,12 +677,14 @@ extern "C" {
 
     // Returns the smallest position present in the KV cache for the specified sequence
     // This is typically non-zero only for SWA caches
+    // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
     // Return -1 if the sequence is empty
     LLAMA_API llama_pos llama_kv_self_seq_pos_min(
             struct llama_context * ctx,
                     llama_seq_id   seq_id);
 
     // Returns the largest position present in the KV cache for the specified sequence
+    // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
     // Return -1 if the sequence is empty
     LLAMA_API llama_pos llama_kv_self_seq_pos_max(
             struct llama_context * ctx,
@@ -692,12 +694,14 @@ extern "C" {
     // This will be applied:
     //   - lazily on next llama_decode()
     //   - explicitly with llama_kv_self_update()
+    // TODO: deprecate and always update the cache lazily [TAG: API_KV_NO_DEFRAG]
     LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx);
 
     // Check if the context supports KV cache shifting
     LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx);
 
     // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
+    // TODO: deprecate and always update the cache lazily [TAG: API_KV_NO_DEFRAG]
     LLAMA_API void llama_kv_self_update(struct llama_context * ctx);
 
     //
index b98e3256c390d7942644d30545fb3cea8af072f7..6a19a243118d344bfd9f33a881a356dc74929138 100644 (file)
@@ -15,24 +15,31 @@ llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
             break;
         }
     }
-    ubatch_token.resize(!has_embd ? n_ubatch : 0);
-    ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
-    ubatch_pos.resize(n_ubatch);
-    ubatch_n_seq_id.resize(n_ubatch);
-    ubatch_seq_id.resize(n_ubatch);
-    ubatch_output.resize(n_ubatch);
+
+    udatas.push_back({});
+
+    auto & udata = udatas.back();
+
+    udata.token.resize(!has_embd ? n_ubatch : 0);
+    udata.embd.resize(has_embd ? n_embd * n_ubatch : 0);
+    udata.pos.resize(n_ubatch);
+    udata.n_seq_id.resize(n_ubatch);
+    udata.seq_id.resize(n_ubatch);
+    udata.output.resize(n_ubatch);
+
     llama_ubatch ubatch = {
         /*equal_seqs   =*/ true,
         /*n_tokens     =*/ 0,
         /*n_seq_tokens =*/ 0,
         /*n_seqs       =*/ 0,
-        /*token        =*/ !has_embd ? ubatch_token.data() : nullptr,
-        /*embd         =*/ has_embd  ? ubatch_embd.data()  : nullptr,
-        /*pos          =*/ ubatch_pos.data(),
-        /*n_seq_id     =*/ ubatch_n_seq_id.data(),
-        /*seq_id       =*/ ubatch_seq_id.data(),
-        /*output       =*/ ubatch_output.data(),
+        /*token        =*/ !has_embd ? udata.token.data() : nullptr,
+        /*embd         =*/ has_embd  ? udata.embd.data()  : nullptr,
+        /*pos          =*/ udata.pos.data(),
+        /*n_seq_id     =*/ udata.n_seq_id.data(),
+        /*seq_id       =*/ udata.seq_id.data(),
+        /*output       =*/ udata.output.data(),
     };
+
     return ubatch;
 }
 
index 6305051b62b794499a774f522169414d179e6989..b8260b94fd2d0aaf301347fdf70af299251556ef 100644 (file)
@@ -11,15 +11,15 @@ struct llama_ubatch {
     bool equal_seqs;
     // TODO: whole_seqs for embeddings?
 
-    uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
+    uint32_t n_tokens;     // total tokens (n_seq_tokens * n_seqs)
     uint32_t n_seq_tokens; // tokens per sequence
     uint32_t n_seqs;
 
     llama_token  *  token;    // [n_tokens]
     float        *  embd;     // [n_embd, n_tokens]
     llama_pos    *  pos;      // [n_tokens]
-    int32_t      *  n_seq_id; // [n_seqs]
-    llama_seq_id ** seq_id;   // [n_seqs]
+    int32_t      *  n_seq_id; // [n_seqs] // TODO: remove, should belong to only 1 sequence
+    llama_seq_id ** seq_id;   // [n_seqs] // TODO: become llama_seq_id * seq_id;
     int8_t       *  output;   // [n_tokens]
 };
 
@@ -49,13 +49,18 @@ struct llama_sbatch {
 
     const llama_batch * batch = nullptr;
 
-    // buffers for the ubatch
-    std::vector<llama_token>    ubatch_token;
-    std::vector<float>          ubatch_embd;
-    std::vector<llama_pos>      ubatch_pos;
-    std::vector<int32_t>        ubatch_n_seq_id;
-    std::vector<llama_seq_id *> ubatch_seq_id;
-    std::vector<int8_t>         ubatch_output;
+    // buffers for the ubatches
+    // TODO: very hacky, this needs a complete rework
+    struct ubatch_data {
+        std::vector<llama_token>    token;
+        std::vector<float>          embd;
+        std::vector<llama_pos>      pos;
+        std::vector<int32_t>        n_seq_id;
+        std::vector<llama_seq_id *> seq_id;
+        std::vector<int8_t>         output;
+    };
+
+    std::vector<ubatch_data> udatas;
 
     llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);
 
index e153351af38093a0b788dfc576f80d8301db1946..808fe5991088ce7e1e9db6a144ee810fb75621c0 100644 (file)
@@ -6,9 +6,10 @@
 #include "llama-model.h"
 #include "llama-kv-cache.h"
 
+#include <cinttypes>
 #include <cstring>
+#include <limits>
 #include <stdexcept>
-#include <cinttypes>
 
 //
 // llama_context
@@ -259,15 +260,9 @@ llama_context::llama_context(
 
     // reserve worst-case graph
     if (!hparams.vocab_only && memory) {
-        const uint32_t n_seqs = 1; // TODO: worst-case number of sequences
+        const uint32_t n_seqs = cparams.n_seq_max;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
-        llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-
-        // restore later
-        // TODO: something cleaner
-        const auto n_outputs_save = n_outputs;
-
         LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
 
         int n_splits_pp = -1;
@@ -279,23 +274,17 @@ llama_context::llama_context(
         // simulate full KV cache
         llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
-        kv_self->set_full();
+        const auto kv_state = kv_self->init_full();
+        if (!kv_state) {
+            throw std::runtime_error("failed to initialize KV cache");
+        }
 
         cross.v_embd.clear();
 
         // reserve pp graph first so that buffers are only allocated once
         {
-            llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-
-            // max number of outputs
-            n_outputs = ubatch_pp.n_tokens;
-
-            LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
-
-            auto * gf = graph_init();
-            graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
-
-            if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
+            if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
 
@@ -305,16 +294,8 @@ llama_context::llama_context(
 
         // reserve with tg graph to get the number of splits and nodes
         {
-            llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-
-            n_outputs = ubatch_tg.n_tokens;
-
-            LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_tg.n_tokens, ubatch_tg.n_seqs);
-
-            auto * gf = graph_init();
-            graph_build(ctx_compute.get(), gf, ubatch_tg, LLM_GRAPH_TYPE_DEFAULT);
-
-            if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+            auto * gf = graph_reserve(1, 1, 1, kv_state.get());
+            if (!gf) {
                 throw std::runtime_error("failed to allocate compute tg buffers");
             }
 
@@ -324,22 +305,12 @@ llama_context::llama_context(
 
         // reserve again with pp graph to avoid ggml-alloc reallocations during inference
         {
-            llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-
-            n_outputs = ubatch_pp.n_tokens;
-
-            LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
-
-            auto * gf = graph_init();
-            graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
-
-            if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
+            if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
         }
 
-        n_outputs = n_outputs_save;
-
         for (size_t i = 0; i < backend_ptrs.size(); ++i) {
             ggml_backend_t             backend = backend_ptrs[i];
             ggml_backend_buffer_type_t buft    = backend_buft[i];
@@ -454,33 +425,25 @@ const llama_kv_cache * llama_context::get_kv_self() const {
 }
 
 void llama_context::kv_self_update() {
-    bool need_reserve = false;
+    if (!memory) {
+        return;
+    }
 
     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
-    need_reserve = kv_self->update(*this);
-
-    // reserve a worst case graph if needed
-    if (need_reserve) {
-        LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__);
-
-        // build worst-case graph
-        uint32_t n_seqs = 1; // TODO: worst-case number of sequences
-        uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-
-        // simulate full KV cache
-        kv_self->set_full();
-
-        llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-        llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+    if (kv_self->update(*this)) {
+        // if the KV cache did any computation, we have to reserve a new worst-case graph
+        const auto kv_state = kv_self->init_full();
+        if (!kv_state) {
+            throw std::runtime_error("failed to initialize KV cache");
+        }
 
-        auto * gf = graph_init();
-        graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
+        const uint32_t n_seqs   = cparams.n_seq_max;
+        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
-        // initialize scheduler with the worst-case graph
-        ggml_backend_sched_reset(sched.get());
-        if (!ggml_backend_sched_reserve(sched.get(), gf)) {
-            LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
+        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
+        if (!gf) {
+            LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
         }
     }
 }
@@ -676,6 +639,49 @@ bool llama_context::apply_adapter_cvec(
     return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }
 
+llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_state_i * mstate, ggml_status & ret) {
+    if (mstate && !mstate->apply()) {
+        LLAMA_LOG_ERROR("%s: failed to apply memory state\n", __func__);
+        ret = GGML_STATUS_FAILED;
+        return nullptr;
+    }
+
+    auto * gf = graph_init();
+    if (!gf) {
+        LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
+        ret = GGML_STATUS_FAILED;
+        return nullptr;
+    }
+
+    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mstate);
+    if (!res) {
+        LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
+        ret = GGML_STATUS_FAILED;
+        return nullptr;
+    }
+
+    // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
+
+    if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
+        LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
+        ret = GGML_STATUS_ALLOC_FAILED;
+        return nullptr;
+    }
+
+    res->set_inputs(&ubatch);
+
+    const auto status = graph_compute(gf, ubatch.n_tokens > 1);
+    if (status != GGML_STATUS_SUCCESS) {
+        LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
+        ret = status;
+        return nullptr;
+    }
+
+    ret = GGML_STATUS_SUCCESS;
+
+    return res;
+}
+
 int llama_context::encode(llama_batch & inp_batch) {
     if (inp_batch.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
@@ -737,8 +743,6 @@ int llama_context::encode(llama_batch & inp_batch) {
 
     n_outputs = n_tokens;
 
-    //batch_manager->prepare(ubatch);
-
     ggml_backend_sched_reset(sched.get());
     ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
@@ -749,26 +753,18 @@ int llama_context::encode(llama_batch & inp_batch) {
     //       ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
     cparams.causal_attn = false;
 
-    auto * gf = graph_init();
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER);
-
-    ggml_backend_sched_alloc_graph(sched.get(), gf);
-
-    res->set_inputs(&ubatch);
+    ggml_status status;
+    const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
 
     cparams.causal_attn = causal_attn_org;
 
-    const auto compute_status = graph_compute(gf, n_tokens > 1);
-    switch (compute_status) {
-        case GGML_STATUS_SUCCESS:
-            break;
-        case GGML_STATUS_ABORTED:
-            return 2;
-        case GGML_STATUS_ALLOC_FAILED:
-            return -2;
-        case GGML_STATUS_FAILED:
-        default:
-            return -3;
+    if (!res) {
+        switch (status) {
+            case GGML_STATUS_ABORTED:      return  2;
+            case GGML_STATUS_ALLOC_FAILED: return -2;
+            case GGML_STATUS_FAILED:       return -3;
+            case GGML_STATUS_SUCCESS:      GGML_ABORT("should not happen");
+        }
     }
 
     auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
@@ -889,8 +885,6 @@ int llama_context::decode(llama_batch & inp_batch) {
     const int64_t n_tokens_all = batch.n_tokens;
     const int64_t n_embd       = hparams.n_embd;
 
-    llama_kv_cache_guard kv_guard(kv_self);
-
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
     // TODO: move the validation to the llama_batch_allocr
@@ -936,7 +930,28 @@ int llama_context::decode(llama_batch & inp_batch) {
         n_outputs_all = 1;
     }
 
-    llama_sbatch sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all);
+    // handle any pending defrags/shifts
+    kv_self_update();
+
+    auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
+    if (!kv_state) {
+        return -2;
+    }
+
+    switch (kv_state->get_status()) {
+        case LLAMA_MEMORY_STATUS_SUCCESS:
+            {
+            } break;
+        case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+            {
+                // not a fatal error, we can re-try with a different batch
+                return 1;
+            }
+        case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+            {
+                return -2;
+            }
+    }
 
     // reserve output buffer
     if (output_reserve(n_outputs_all) < n_outputs_all) {
@@ -944,13 +959,10 @@ int llama_context::decode(llama_batch & inp_batch) {
         return -2;
     };
 
-    // handle any pending defrags/shifts
-    kv_self_update();
-
     int64_t n_outputs_prev = 0;
 
-    while (sbatch.n_tokens > 0) {
-        llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
+    do {
+        const auto & ubatch = kv_state->get_ubatch();
 
         // count the outputs in this u_batch
         {
@@ -969,33 +981,37 @@ int llama_context::decode(llama_batch & inp_batch) {
             n_outputs = n_outputs_new;
         }
 
-        // find KV slot
-        if (!kv_self->find_slot(ubatch)) {
-            return 1;
-        }
-
         ggml_backend_sched_reset(sched.get());
         ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
-        auto * gf = graph_init();
-        auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DECODER);
+        ggml_status status;
+        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, kv_state.get(), status);
 
-        // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
+        if (!res) {
+            // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
+            llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES] = { std::numeric_limits<llama_pos>::max() };
 
-        ggml_backend_sched_alloc_graph(sched.get(), gf);
+            for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
+                const auto & seq_id = ubatch.seq_id[i][0];
 
-        res->set_inputs(&ubatch);
+                pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
+            }
+
+            for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+                if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
+                    continue;
+                }
+
+                LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
+
+                llama_kv_self_seq_rm(this, s, pos_min[s], -1);
+            }
 
-        const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
-        if (compute_status != GGML_STATUS_SUCCESS) {
-            switch (compute_status) {
-                case GGML_STATUS_ABORTED:
-                    return 2;
-                case GGML_STATUS_ALLOC_FAILED:
-                    return -2;
-                case GGML_STATUS_FAILED:
-                default:
-                    return -3;
+            switch (status) {
+                case GGML_STATUS_ABORTED:      return  2;
+                case GGML_STATUS_ALLOC_FAILED: return -2;
+                case GGML_STATUS_FAILED:       return -3;
+                case GGML_STATUS_SUCCESS:      GGML_ABORT("should not happen");
             }
         }
 
@@ -1082,10 +1098,7 @@ int llama_context::decode(llama_batch & inp_batch) {
         }
 
         n_outputs_prev += n_outputs;
-    }
-
-    // finalize the batch processing
-    kv_guard.commit();
+    } while (kv_state->next());
 
     // set to total number of outputs in the batch, for use in llama_get_logits_ith
     n_outputs = n_outputs_all;
@@ -1094,7 +1107,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     {
         bool sorted_output = true;
 
-        auto & out_ids = sbatch.out_ids;
+        auto & out_ids = kv_state->out_ids();
 
         GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
 
@@ -1254,11 +1267,52 @@ ggml_cgraph * llama_context::graph_init() {
     return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
 }
 
+ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate) {
+    LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
+
+    if (n_tokens % n_seqs != 0) {
+        n_tokens = (n_tokens / n_seqs) * n_seqs;
+        n_outputs = std::min(n_outputs, n_tokens);
+
+        LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
+    }
+
+    // store the n_outputs as it is, and restore it afterwards
+    // TODO: not sure if needed, might simplify in the future by removing this
+    const auto save_n_outputs = this->n_outputs;
+
+    this->n_outputs = n_outputs;
+
+    llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
+    llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+
+    auto * gf = graph_init();
+    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate);
+
+    this->n_outputs = save_n_outputs;
+
+    if (!res) {
+        LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__);
+        return nullptr;
+    }
+
+    ggml_backend_sched_reset(sched.get());
+
+    // initialize scheduler with the specified graph
+    if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+        LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
+        return nullptr;
+    }
+
+    return gf;
+}
+
 llm_graph_result_ptr llama_context::graph_build(
-            ggml_context * ctx,
-             ggml_cgraph * gf,
-      const llama_ubatch & ubatch,
-            llm_graph_type gtype) {
+                    ggml_context * ctx,
+                     ggml_cgraph * gf,
+              const llama_ubatch & ubatch,
+                  llm_graph_type   gtype,
+      const llama_memory_state_i * mstate) {
     return model.build_graph(
             {
                 /*.ctx         =*/ ctx,
@@ -1270,7 +1324,7 @@ llm_graph_result_ptr llama_context::graph_build(
                 /*.backend_cpu =*/ backend_cpu,
                 /*.cvec        =*/ &cvec,
                 /*.loras       =*/ &loras,
-                /*.memory      =*/ memory.get(),
+                /*.mstate      =*/ mstate,
                 /*.cross       =*/ &cross,
                 /*.n_outputs   =*/ n_outputs,
                 /*.cb          =*/ graph_get_cb(),
@@ -1951,7 +2005,6 @@ void llama_context::opt_epoch_iter(
     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
     kv_self->clear();
-    llama_kv_cache_guard kv_guard(kv_self);
 
     for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
         batch.n_tokens = n_batch;
@@ -1974,7 +2027,11 @@ void llama_context::opt_epoch_iter(
 
         int64_t n_outputs_all = n_tokens_all;
 
-        llama_sbatch sbatch = kv_self->sbatch_init(batch, /*logits_all =*/ true);
+        auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
+        if (!kv_state || kv_state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
+            LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
+            break;
+        }
 
         // reserve output buffer
         if (output_reserve(n_outputs_all) < n_outputs_all) {
@@ -1982,20 +2039,19 @@ void llama_context::opt_epoch_iter(
             GGML_ABORT("TODO: handle this error");
         };
 
-        for (uint32_t pos_batch = 0; pos_batch < n_batch; pos_batch += n_ubatch) {
-            llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
+        uint32_t pos_batch = 0;
+        do {
+            const auto & ubatch = kv_state->get_ubatch();
 
             n_outputs = ubatch.n_tokens;
 
-            // TODO: not sure if this is needed
-            if (!kv_self->find_slot(ubatch)) {
-                LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
-
-                GGML_ABORT("TODO: handle this error");
+            if (!kv_state->apply()) {
+                LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__);
+                break;
             }
 
             auto * gf = graph_init();
-            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
+            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, kv_state.get());
 
             struct ggml_context * ctx_compute_opt;
             {
@@ -2010,6 +2066,7 @@ void llama_context::opt_epoch_iter(
             }
             ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
             ggml_opt_alloc(opt_ctx, train);
+
             res->set_inputs(&ubatch);
             {
                 struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
@@ -2027,10 +2084,10 @@ void llama_context::opt_epoch_iter(
                 callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start);
             }
             ggml_free(ctx_compute_opt);
-        }
-    }
 
-    kv_guard.commit();
+            pos_batch += ubatch.n_tokens;
+        } while (kv_state->next());
+    }
 }
 
 void llama_context::opt_epoch(
index c0ceacb10ce6f56986cfa634a5700b664957a80b..5b79bafa75db755e56920b8b2940e7ef0f1da6ce 100644 (file)
@@ -18,6 +18,9 @@ struct llama_kv_cache;
 class llama_io_read_i;
 class llama_io_write_i;
 
+class llama_memory_i;
+class llama_memory_state_i;
+
 struct llama_context {
     // init scheduler and compute buffers, reserve worst-case graphs
     llama_context(
@@ -47,6 +50,7 @@ struct llama_context {
           llama_kv_cache * get_kv_self();
     const llama_kv_cache * get_kv_self() const;
 
+    // TODO: remove
     void kv_self_update();
 
     enum llama_pooling_type pooling_type() const;
@@ -88,6 +92,16 @@ struct llama_context {
                 int32_t   il_start,
                 int32_t   il_end);
 
+    // process a single ubatch with a specific graph type
+    // if memory_state is provided, it will be applied first to the context's memory
+    // ret contains the status of the graph computation
+    // returns nullptr only if ret != GGML_STATUS_SUCCESS
+    llm_graph_result_ptr process_ubatch(
+              const llama_ubatch & ubatch,
+                  llm_graph_type   gtype,
+            llama_memory_state_i * mstate,
+                     ggml_status & ret);
+
     int encode(llama_batch & inp_batch);
     int decode(llama_batch & inp_batch);
 
@@ -180,16 +194,18 @@ public:
     ggml_cgraph * graph_init();
 
     // returns the result of ggml_backend_sched_graph_compute_async execution
-    ggml_status graph_compute(
-            ggml_cgraph * gf,
-                   bool   batched);
+    ggml_status graph_compute(ggml_cgraph * gf, bool batched);
+
+    // reserve a graph with a dummy ubatch of the specified size
+    ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate);
 
 private:
     llm_graph_result_ptr graph_build(
-            ggml_context * ctx,
-             ggml_cgraph * gf,
-      const llama_ubatch & ubatch,
-          llm_graph_type   gtype);
+                    ggml_context * ctx,
+                     ggml_cgraph * gf,
+              const llama_ubatch & ubatch,
+                  llm_graph_type   gtype,
+      const llama_memory_state_i * mstate);
 
     llm_graph_cb graph_get_cb() const;
 
index 7c383e2eb3f27e53b5f0678f33ec724b94456714..b30f6fb4f4145f3441a615d94ab2cf948a5fcda9 100644 (file)
@@ -83,7 +83,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
     if (pos_bucket) {
-        kv_self->set_input_pos_bucket(pos_bucket, ubatch);
+        kv_state->set_input_pos_bucket(pos_bucket, ubatch);
     }
 }
 
@@ -234,7 +234,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
 void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
-    const int64_t n_kv = kv_self->n;
+    const int64_t n_kv = kv_state->get_n_kv();
 
     if (s_copy) {
         GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
@@ -242,7 +242,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
 
         // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
         for (uint32_t i = 0; i < n_kv; ++i) {
-            data[i] = kv_self->s_copy(i);
+            data[i] = kv_state->s_copy(i);
         }
     }
 }
@@ -250,7 +250,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
 void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
-    const int64_t n_kv = kv_self->n;
+    const int64_t n_kv = kv_state->get_n_kv();
 
     if (s_mask) {
         GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer));
@@ -258,7 +258,7 @@ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
 
         // clear unused states
         for (int i = 0; i < n_kv; ++i) {
-            data[i] = kv_self->s_mask(i);
+            data[i] = kv_state->s_mask(i);
         }
     }
 }
@@ -362,17 +362,17 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask) {
-        kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+        kv_state->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
 }
 
 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask) {
-        kv_self->get_kv_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+        kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
 
     if (self_kq_mask_swa) {
-        kv_self->get_kv_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
+        kv_state->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
     }
 }
 
@@ -448,7 +448,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     backend_cpu      (params.backend_cpu),
     cvec             (params.cvec),
     loras            (params.loras),
-    memory           (params.memory),
+    mstate           (params.mstate),
     cross            (params.cross),
     cb_func          (params.cb),
     res              (std::make_unique<llm_graph_result>()) {
@@ -954,11 +954,11 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_s_copy() const {
-    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
+    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
-    auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);
+    auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
 
-    const auto n_kv = kv_self->n;
+    const auto n_kv = kv_state->get_n_kv();
 
     auto & cur = inp->s_copy;
 
@@ -971,11 +971,11 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_s_mask() const {
-    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
+    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
-    auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
+    auto inp = std::make_unique<llm_graph_input_s_mask>(kv_state);
 
-    const auto n_kv = kv_self->n;
+    const auto n_kv = kv_state->get_n_kv();
 
     auto & cur = inp->s_mask;
 
@@ -1025,11 +1025,11 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
 
-    auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
+    auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_state);
 
-    const auto n_kv = kv_self->get_n();
+    const auto n_kv = kv_state->get_n_kv();
 
     auto & cur = inp->pos_bucket;
 
@@ -1231,14 +1231,14 @@ ggml_tensor * llm_graph_context::build_attn(
 }
 
 llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
 
-    auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state);
 
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
-        const auto n_kv = kv_self->get_n();
+        const auto n_kv = kv_state->get_n_kv();
 
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1268,19 +1268,19 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, kv_self->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il));
+        ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
 
     ggml_tensor * q = q_cur;
-    ggml_tensor * k = kv_self->get_k(ctx0, il);
-    ggml_tensor * v = kv_self->get_v(ctx0, il);
+    ggml_tensor * k = kv_state->get_k(ctx0, il);
+    ggml_tensor * v = kv_state->get_v(ctx0, il);
 
     ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
@@ -1301,12 +1301,12 @@ ggml_tensor * llm_graph_context::build_attn(
 }
 
 llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
-    const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
+    const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
 
-    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_self);
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
 
     {
-        const auto n_kv = kv_self->get_kv_base()->get_n();
+        const auto n_kv = kv_state->get_base()->get_n_kv();
 
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1318,7 +1318,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
     {
         GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
 
-        const auto n_kv = kv_self->get_kv_swa()->get_n();
+        const auto n_kv = kv_state->get_swa()->get_n_kv();
 
         inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
@@ -1348,23 +1348,23 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const bool is_swa = hparams.is_swa(il);
+    const auto * kv_state_iswa = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
 
-    const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
+    const bool is_swa = hparams.is_swa(il);
 
-    const auto * kv = is_swa ? kv_self->get_kv_swa() : kv_self->get_kv_base();
+    const auto * kv_state = is_swa ? kv_state_iswa->get_swa() : kv_state_iswa->get_base();
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, kv->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, kv->cpy_v(ctx0, v_cur, il));
+        ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
     }
 
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
     ggml_tensor * q = q_cur;
-    ggml_tensor * k = kv->get_k(ctx0, il);
-    ggml_tensor * v = kv->get_v(ctx0, il);
+    ggml_tensor * k = kv_state->get_k(ctx0, il);
+    ggml_tensor * v = kv_state->get_v(ctx0, il);
 
     ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
@@ -1446,12 +1446,12 @@ ggml_tensor * llm_graph_context::build_copy_mask_state(
          ggml_tensor * state_mask,
              int32_t   n_state,
              int32_t   n_seqs) const {
-    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
+    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
-    const auto n_kv    = kv_self->n;
-    const auto kv_head = kv_self->head;
+    const auto n_kv    = kv_state->get_n_kv();
+    const auto kv_head = kv_state->get_head();
 
-    ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_self->size);
+    ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_state->get_size());
 
     // copy states
     // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
@@ -1478,13 +1478,13 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
          ggml_tensor * state_mask,
   const llama_ubatch & ubatch,
                  int   il) const {
-    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
+    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
     const auto token_shift_count = hparams.token_shift_count;
 
     const int64_t n_seqs  = ubatch.n_seqs;
 
-    ggml_tensor * token_shift_all = kv_self->k_l[il];
+    ggml_tensor * token_shift_all = kv_state->get_k_l(il);
 
     ggml_tensor * token_shift = build_copy_mask_state(
             gf, token_shift_all, state_copy, state_mask,
@@ -1499,19 +1499,19 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
          ggml_tensor * token_shift,
   const llama_ubatch & ubatch,
                  int   il) const {
-    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
+    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
     const auto token_shift_count = hparams.token_shift_count;
     const auto n_embd = hparams.n_embd;
 
     const int64_t n_seqs = ubatch.n_seqs;
 
-    const auto kv_head = kv_self->head;
+    const auto kv_head = kv_state->get_head();
 
     return ggml_cpy(
         ctx0,
         ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
-        ggml_view_1d(ctx0, kv_self->k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self->k_l[il]))
+        ggml_view_1d(ctx0, kv_state->get_k_l(il), hparams.n_embd_k_s()*n_seqs, hparams.n_embd_k_s()*kv_head*ggml_element_size(kv_state->get_k_l(il)))
     );
 }
 
index 2b85bb25befbac4e9c3294c06d09e641c856a7ef..d1c5dd1bf036f6ed1a8272d4da29b67901086f5d 100644 (file)
@@ -17,10 +17,11 @@ struct ggml_tensor;
 struct llama_ubatch;
 struct llama_cparams;
 
-class llama_memory_i;
-class llama_kv_cache_unified;
-class llama_kv_cache_unified_iswa;
-class llama_kv_cache_recurrent;
+class llama_memory_state_i;
+
+class llama_kv_cache_unified_state;
+class llama_kv_cache_unified_iswa_state;
+class llama_kv_cache_recurrent_state;
 
 // certain models (typically multi-modal) can produce different types of graphs
 enum llm_graph_type {
@@ -133,7 +134,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
 public:
     llm_graph_input_pos_bucket_kv(
             const llama_hparams & hparams,
-            const llama_kv_cache_unified * kv_self) : hparams(hparams), kv_self(kv_self) {}
+            const llama_kv_cache_unified_state * kv_state) : hparams(hparams), kv_state(kv_state) {}
     virtual ~llm_graph_input_pos_bucket_kv() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
@@ -141,7 +142,7 @@ public:
     ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
 
     const llama_hparams & hparams;
-    const llama_kv_cache_unified * kv_self;
+    const llama_kv_cache_unified_state * kv_state;
 };
 
 class llm_graph_input_out_ids : public llm_graph_input_i {
@@ -188,26 +189,26 @@ public:
 
 class llm_graph_input_s_copy : public llm_graph_input_i {
 public:
-    llm_graph_input_s_copy(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
+    llm_graph_input_s_copy(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
     virtual ~llm_graph_input_s_copy() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * s_copy; // I32 [kv_size]
 
-    const llama_kv_cache_recurrent * kv_self;
+    const llama_kv_cache_recurrent_state * kv_state;
 };
 
 class llm_graph_input_s_mask : public llm_graph_input_i {
 public:
-    llm_graph_input_s_mask(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
+    llm_graph_input_s_mask(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
     virtual ~llm_graph_input_s_mask() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * s_mask; // F32 [1, n_kv]
 
-    const llama_kv_cache_recurrent * kv_self;
+    const llama_kv_cache_recurrent_state * kv_state;
 };
 
 class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -247,10 +248,10 @@ public:
     llm_graph_input_attn_kv_unified(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_kv_cache_unified * kv_self) :
+            const llama_kv_cache_unified_state * kv_state) :
         hparams(hparams),
         cparams(cparams),
-        kv_self(kv_self) {
+        kv_state(kv_state) {
     }
     ~llm_graph_input_attn_kv_unified() = default;
 
@@ -264,7 +265,7 @@ public:
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
-    const llama_kv_cache_unified * kv_self;
+    const llama_kv_cache_unified_state * kv_state;
 };
 
 class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
@@ -272,10 +273,10 @@ public:
     llm_graph_input_attn_kv_unified_iswa(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_kv_cache_unified_iswa * kv_self) :
+            const llama_kv_cache_unified_iswa_state * kv_state) :
         hparams(hparams),
         cparams(cparams),
-        kv_self(kv_self) {
+        kv_state(kv_state) {
     }
     ~llm_graph_input_attn_kv_unified_iswa() = default;
 
@@ -292,7 +293,7 @@ public:
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
-    const llama_kv_cache_unified_iswa * kv_self;
+    const llama_kv_cache_unified_iswa_state * kv_state;
 };
 
 class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -383,10 +384,10 @@ struct llm_graph_params {
     ggml_backend_sched_t sched;
     ggml_backend_t backend_cpu;
 
-    const llama_adapter_cvec  * cvec;
-    const llama_adapter_loras * loras;
-    const llama_memory_i      * memory;
-    const llama_cross         * cross;
+    const llama_adapter_cvec   * cvec;
+    const llama_adapter_loras  * loras;
+    const llama_memory_state_i * mstate;
+    const llama_cross          * cross;
 
     int32_t n_outputs;
 
@@ -435,10 +436,10 @@ struct llm_graph_context {
 
     ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
 
-    const llama_adapter_cvec  * cvec;
-    const llama_adapter_loras * loras;
-    const llama_memory_i      * memory;
-    const llama_cross         * cross;
+    const llama_adapter_cvec   * cvec;
+    const llama_adapter_loras  * loras;
+    const llama_memory_state_i * mstate;
+    const llama_cross          * cross;
 
     const llm_graph_cb & cb_func;
 
index 766f8d079afb26e5242cb1f1f519b597069028a9..86c4f2816f8097282fd1e220e88a2db5470c538b 100644 (file)
 // llama_kv_cache_unified
 //
 
-uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) {
-    // the FA kernels require padding to avoid extra runtime boundary checks
-    return cparams.flash_attn ? 256u : 32u;
-}
-
 llama_kv_cache_unified::llama_kv_cache_unified(
         const llama_model &  model,
           layer_filter_cb && filter,
@@ -293,26 +288,81 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
     return cells.seq_pos_max(seq_id);
 }
 
-void llama_kv_cache_unified::restore() {
-    for (auto & state : recovery.states) {
-        cells.set(state.i, state.cells);
+llama_memory_state_ptr llama_kv_cache_unified::init_batch(
+            const llama_batch & batch,
+            uint32_t n_ubatch,
+            bool embd_pooled,
+            bool logits_all) {
+    GGML_UNUSED(embd_pooled);
+
+    auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
+
+    std::vector<llama_ubatch> ubatches;
+    while (sbatch.n_tokens > 0) {
+        ubatches.push_back(sbatch.split_simple(n_ubatch));
+    }
+
+    auto heads = prepare(ubatches);
+    if (heads.empty()) {
+        return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
     }
 
-    recovery.clear();
+    return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_SUCCESS,
+            this, std::move(sbatch), std::move(heads), std::move(ubatches));
 }
 
-void llama_kv_cache_unified::commit() {
-    if (recovery.states.empty()) {
-        LLAMA_LOG_WARN("%s: the recovery information upon a commit was empty - might indicate a bug (ref: %s)\n",
-                __func__, "https://github.com/ggml-org/llama.cpp/pull/13194");
-        return;
+llama_memory_state_ptr llama_kv_cache_unified::init_full() {
+    return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
+}
+
+std::vector<uint32_t> llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
+    std::vector<uint32_t> res;
+
+    struct state {
+        uint32_t head_old; // old position of the head, before placing the ubatch
+        uint32_t head_new; // new position of the head, after placing the ubatch
+
+        llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch
+    };
+
+    // remember the old state of the cells so we can restore it in the end
+    std::vector<state> states;
+
+    bool success = true;
+
+    for (const auto & ubatch : ubatches) {
+        // only find a suitable slot for the ubatch. don't modify the cells yet
+        const int32_t head_new = find_slot(ubatch);
+        if (head_new < 0) {
+            success = false;
+            break;
+        }
+
+        // remeber the position that we found
+        res.push_back(head_new);
+
+        // store the old state of the cells in the recovery stack
+        states.push_back({head, (uint32_t) head_new, cells.cp(head_new, ubatch.n_tokens)});
+
+        // now emplace the ubatch
+        apply_ubatch(head_new, ubatch);
+    }
+
+    // iterate backwards and restore the cells to their original state
+    for (auto it = states.rbegin(); it != states.rend(); ++it) {
+        cells.set(it->head_new, it->cells);
+        head = it->head_old;
+    }
+
+    if (!success) {
+        return {};
     }
 
-    recovery.clear();
+    return res;
 }
 
 bool llama_kv_cache_unified::update(llama_context & lctx) {
-    bool need_reserve = false;
+    bool updated = false;
 
     auto * sched = lctx.get_sched();
 
@@ -330,14 +380,24 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
             auto * gf = lctx.graph_init();
 
             auto res = build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
+            if (!res) {
+                LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__);
+                return updated;
+            }
 
-            ggml_backend_sched_alloc_graph(sched, gf);
+            if (!ggml_backend_sched_alloc_graph(sched, gf)) {
+                LLAMA_LOG_ERROR("%s: failed to allocate compute graph for K-shift\n", __func__);
+                return updated;
+            }
 
             res->set_inputs(nullptr);
 
-            lctx.graph_compute(gf, false);
+            if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
+                LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__);
+                return updated;
+            }
 
-            need_reserve = true;
+            updated = true;
         }
 
         cells.reset_shift();
@@ -352,26 +412,38 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
             auto * gf = lctx.graph_init();
 
             auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
+            if (!res) {
+                LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
+                return updated;
+            }
 
-            ggml_backend_sched_alloc_graph(sched, gf);
+            if (!ggml_backend_sched_alloc_graph(sched, gf)) {
+                LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
+                return updated;
+            }
 
             res->set_inputs(nullptr);
 
-            lctx.graph_compute(gf, false);
+            if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
+                LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
+                return updated;
+            }
 
-            need_reserve = true;
+            updated = true;
         }
 
         do_defrag = false;
     }
 
-    return need_reserve;
+    return updated;
 }
 
 void llama_kv_cache_unified::defrag_sched(float thold) {
+    const auto n_kv = cells.used_max_p1();
+
     // - do not defrag small contexts (i.e. < 2048 tokens)
     // - count the padding towards the number of used tokens
-    const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n)) : 0.0f;
+    const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
 
     // queue defragmentation for next llama_kv_cache_update
     if (fragmentation > thold) {
@@ -381,55 +453,37 @@ void llama_kv_cache_unified::defrag_sched(float thold) {
     }
 }
 
-void llama_kv_cache_unified::set_full() {
-    n = cells.size();
-
-    // when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not
-    //   affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views.
-    //   we should only guarantee that the head position won't cause out-of-bounds view of the K, V tensors, so
-    //   setting it to 0 is the simplest way to achieve that
-    // ref: https://github.com/ggml-org/llama.cpp/issues/13359
-    head = 0;
-}
-
-llama_sbatch llama_kv_cache_unified::sbatch_init(const llama_batch & batch, bool logits_all) {
-    return llama_sbatch(batch, hparams.n_embd, true, logits_all);
-}
-
-llama_ubatch llama_kv_cache_unified::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
-    GGML_UNUSED(embd_pooled);
-    return sbatch.split_simple(n_ubatch);
-}
-
-bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
+int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
     const uint32_t n_tokens = ubatch.n_tokens;
 
+    uint32_t head_cur = this->head;
+
     // if we have enough unused cells before the current head ->
     //   better to start searching from the beginning of the cache, hoping to fill it
-    if (head > cells.get_used() + 2*ubatch.n_tokens) {
-        head = 0;
+    if (head_cur > cells.get_used() + 2*ubatch.n_tokens) {
+        head_cur = 0;
     }
 
     // otherwise, one cell per token.
 
     if (n_tokens > cells.size()) {
         LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
-        return false;
+        return -1;
     }
 
 //#define FIND_SLOT_DEBUG 1
 #if FIND_SLOT_DEBUG
-    LLAMA_LOG_WARN("begin: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
+    LLAMA_LOG_WARN("begin: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", cells.used_max_p1(), cells.get_used(), head, n_swa);
 
     // for debugging
     {
         std::string ss;
         if (n_swa > 0) {
-            for (uint32_t i = 0; i < size; ++i) {
+            for (uint32_t i = 0; i < cells.size(); ++i) {
                 if (cells.is_empty(i)) {
                     ss += '.';
                 } else {
-                    ss += 'x';
+                    ss += std::to_string(cells.seq_get(i));
                 }
                 if (i%256 == 255) {
                     ss += '\n';
@@ -438,23 +492,70 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
         }
         LLAMA_LOG_WARN("\n%s\n", ss.c_str());
     }
+
+    for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        if (cells.seq_pos_min(s) < 0) {
+            continue;
+        }
+
+        LLAMA_LOG_WARN("kv_cells: n_swa = %4d, min[%d] = %5d, max[%d] = %5d\n", n_swa, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s));
+    }
 #endif
 
     uint32_t n_tested = 0;
 
     while (true) {
-        if (head + n_tokens > cells.size()) {
-            n_tested += cells.size() - head;
-            head = 0;
+        if (head_cur + n_tokens > cells.size()) {
+            n_tested += cells.size() - head_cur;
+            head_cur = 0;
             continue;
         }
 
+        // keep track of what the minimum sequence positions would be if we accept the ubatch
+        llama_seq_id seq_pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
+        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+            seq_pos_min[s] = cells.seq_pos_min(s);
+        }
+
         bool found = true;
         for (uint32_t i = 0; i < n_tokens; i++) {
-            // TODO: improve to accept cells that are masked by the SWA
-            if (!cells.is_empty(head + i)) {
+            const llama_pos    pos    = ubatch.pos[i];
+            const llama_seq_id seq_id = ubatch.seq_id[i][0];
+
+            // can we use this cell? either:
+            //  - the cell is empty
+            //  - the cell is occupied only by one sequence:
+            //    - mask causally, if the sequence is the same as the one we are inserting
+            //    - mask SWA, using current max pos for that sequence in the cache
+            //                always insert in the cell with minimum pos
+            bool can_use = cells.is_empty(head_cur + i);
+
+            if (!can_use && cells.seq_count(head_cur + i) == 1) {
+                const llama_pos pos_cell = cells.pos_get(head_cur + i);
+
+                // causal mask
+                if (cells.seq_has(head_cur + i, seq_id)) {
+                    can_use = pos_cell >= pos;
+                }
+
+                if (!can_use) {
+                    const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);
+
+                    // SWA mask
+                    // note: we insert only in the cell with minimum pos in order to preserve the invariant that
+                    //       all positions between [pos_min, pos_max] for each sequence will be present in the cache
+                    //       ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
+                    if (pos_cell == seq_pos_min[seq_id_cell] &&
+                        is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
+                        seq_pos_min[seq_id_cell]++;
+                        can_use = true;
+                    }
+                }
+            }
+
+            if (!can_use) {
                 found = false;
-                head     += i + 1;
+                head_cur += i + 1;
                 n_tested += i + 1;
                 break;
             }
@@ -466,58 +567,55 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
 
         if (n_tested >= cells.size()) {
             //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
-            return false;
+            return -1;
         }
     }
 
-    // store the old state of the cells in the recovery stack
-    recovery.states.push_back({head, cells.cp(head, n_tokens)});
+    return head_cur;
+}
+
+void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
+    for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
+        if (!cells.is_empty(head_cur + i)) {
+            cells.rm(head_cur + i);
+        }
 
-    for (uint32_t i = 0; i < n_tokens; ++i) {
-        cells.pos_set(head + i, ubatch.pos[i]);
+        cells.pos_set(head_cur + i, ubatch.pos[i]);
 
         for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) {
-            cells.seq_add(head + i, ubatch.seq_id[i][j]);
+            cells.seq_add(head_cur + i, ubatch.seq_id[i][j]);
         }
     }
 
-    // a heuristic, to avoid attending the full cache if it is not yet utilized
-    // after enough generations, the benefit from this heuristic disappears
-    // if we start defragmenting the cache, the benefit from this will be more important
-    n = std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
-
-#ifdef FIND_SLOT_DEBUG
-    LLAMA_LOG_WARN("end:   n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
-#endif
-
-    return true;
+    // move the head at the end of the slot
+    head = head_cur + ubatch.n_tokens;
 }
 
 bool llama_kv_cache_unified::get_can_shift() const {
     return true;
 }
 
-uint32_t llama_kv_cache_unified::get_n() const {
-    return n;
-}
-
 uint32_t llama_kv_cache_unified::get_size() const {
     return cells.size();
 }
 
-ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il) const {
+uint32_t llama_kv_cache_unified::get_n_kv() const {
+    return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
+}
+
+ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * k = layers[ikv].k;
 
     return ggml_view_3d(ctx, k,
-            hparams.n_embd_head_k, hparams.n_head_kv(il), n,
+            hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv,
             ggml_row_size(k->type, hparams.n_embd_head_k),
             ggml_row_size(k->type, hparams.n_embd_k_gqa(il)),
             0);
 }
 
-ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il) const {
+ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * v = layers[ikv].v;
@@ -525,7 +623,7 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il) cons
     if (!v_trans) {
         // note: v->nb[1] <= v->nb[2]
         return ggml_view_3d(ctx, v,
-                hparams.n_embd_head_v, hparams.n_head_kv(il), n,
+                hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv,
                 ggml_row_size(v->type, hparams.n_embd_head_v),    // v->nb[1]
                 ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
                 0);
@@ -533,13 +631,13 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il) cons
 
     // note: v->nb[1] > v->nb[2]
     return ggml_view_3d(ctx, v,
-            n, hparams.n_head_kv(il), hparams.n_embd_head_v,
+            n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v,
             ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1]
             ggml_row_size(v->type, v->ne[1]),                       // v->nb[2]
             0);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
+ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * k = layers[ikv].k;
@@ -548,12 +646,12 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
 
     ggml_tensor * k_view = ggml_view_1d(ctx, k,
             n_tokens*hparams.n_embd_k_gqa(il),
-            ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head);
+            ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head_cur);
 
     return ggml_cpy(ctx, k_cur, k_view);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
+ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * v = layers[ikv].v;
@@ -567,12 +665,12 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     if (!v_trans) {
         v_view = ggml_view_1d(ctx, v,
                 n_tokens*hparams.n_embd_v_gqa(il),
-                ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head);
+                ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur);
     } else {
         // note: the V cache is transposed when not using flash attention
         v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
                 (v->ne[1])*ggml_element_size(v),
-                (    head)*ggml_element_size(v));
+                (head_cur)*ggml_element_size(v));
 
         v_cur = ggml_transpose(ctx, v_cur);
     }
@@ -580,33 +678,6 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     return ggml_cpy(ctx, v_cur, v_view);
 }
 
-void llama_kv_cache_unified::prune_swa(llama_seq_id seq_id, llama_pos pmin, llama_pos pmax) {
-    // no pruning is needed when the cache does not use SWA
-    GGML_ASSERT(swa_type != LLAMA_SWA_TYPE_NONE && "do not prune non-SWA cache");
-
-    int n_attended = 0;
-
-    for (uint32_t i = 0; i < cells.size(); ++i) {
-        if (!cells.seq_has(i, seq_id)) {
-            continue;
-        }
-
-        const llama_pos p0 = cells.pos_get(i);
-
-        if (p0 <= pmin && !is_masked_swa(p0, pmin)) {
-            n_attended++;
-        }
-
-        if (is_masked_swa(p0, pmax)) {
-            cells.seq_rm(i, seq_id);
-        }
-    }
-
-    if (n_attended < std::min<int>(n_swa, pmin)) {
-        LLAMA_LOG_WARN("%s: partial SWA cache detected - possible loss of information, pmin = %d, n_attended = %d, n_swa = %d\n", __func__, pmin, n_attended, n_swa);
-    }
-}
-
 void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     const int64_t n_tokens     = ubatch->n_tokens;
     const int64_t n_seq_tokens = ubatch->n_seq_tokens;
@@ -615,7 +686,7 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     float * data = (float *) dst->data;
 
-    const int64_t n_kv = n;
+    const auto n_kv = dst->ne[0];
 
     // Use only the previous KV cells of the correct sequence for each token of the ubatch.
     // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
@@ -636,7 +707,7 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
             for (int j = 0; j < n_seq_tokens; ++j) {
                 const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j];
 
-                for (int i = 0; i < n_kv; ++i) {
+                for (uint32_t i = 0; i < n_kv; ++i) {
                     float f = 0.0f;
 
                     bool masked = false;
@@ -672,7 +743,7 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
         // mask padded tokens
         if (data) {
             for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                for (int j = 0; j < n_kv; ++j) {
+                for (uint32_t j = 0; j < n_kv; ++j) {
                     data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
                 }
             }
@@ -698,7 +769,7 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama
 
     int32_t * data = (int32_t *) dst->data;
 
-    const int64_t n_kv = n;
+    const int32_t n_kv = dst->ne[0];
 
     for (int h = 0; h < 1; ++h) {
         for (int j = 0; j < n_tokens; ++j) {
@@ -1362,20 +1433,24 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
             batch.seq_id[i]   = &dest_seq_id;
         }
 
-        if (!find_slot(batch)) {
+        const auto head_cur = find_slot(batch);
+        if (head_cur < 0) {
             LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
             return false;
         }
 
-        commit();
+        apply_ubatch(head_cur, batch);
 
-        // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
+        // keep the head at the old position because we will read the KV data into it in state_read_data()
+        head = head_cur;
+
+        // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
         // Assume that this is one contiguous block of cells
-        GGML_ASSERT(head + cell_count <= cells.size());
-        GGML_ASSERT(cells.pos_get(head)                  == batch.pos[0]);
-        GGML_ASSERT(cells.pos_get(head + cell_count - 1) == batch.pos[cell_count - 1]);
-        GGML_ASSERT(cells.seq_has(head,                  dest_seq_id));
-        GGML_ASSERT(cells.seq_has(head + cell_count - 1, dest_seq_id));
+        GGML_ASSERT(head_cur + cell_count <= cells.size());
+        GGML_ASSERT(cells.pos_get(head_cur)                  == batch.pos[0]);
+        GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == batch.pos[cell_count - 1]);
+        GGML_ASSERT(cells.seq_has(head_cur,                  dest_seq_id));
+        GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id));
     } else {
         // whole KV cache restore
 
@@ -1425,10 +1500,12 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
         LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size());
         return false;
     }
+
     if (cell_count > cells.size()) {
         LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, cells.size());
         return false;
     }
+
     if (this->v_trans != (bool) v_trans) {
         LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
         return false;
@@ -1539,6 +1616,108 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     return true;
 }
 
+//
+// llama_kv_cache_unified_state
+//
+
+llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {}
+
+llama_kv_cache_unified_state::llama_kv_cache_unified_state(
+            llama_memory_status status,
+            llama_kv_cache_unified * kv) : status(status), kv(kv) {
+        n_kv = kv->get_size();
+        head = 0;
+    }
+
+llama_kv_cache_unified_state::llama_kv_cache_unified_state(
+            llama_memory_status status,
+            llama_kv_cache_unified * kv,
+            llama_sbatch sbatch,
+            std::vector<uint32_t> heads,
+            std::vector<llama_ubatch> ubatches)
+            : status(status),
+              kv(kv),
+              sbatch(std::move(sbatch)),
+              heads(std::move(heads)),
+              ubatches(std::move(ubatches)) {
+    }
+
+llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default;
+
+bool llama_kv_cache_unified_state::next() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    if (++i_next >= ubatches.size()) {
+        return false;
+    }
+
+    return true;
+}
+
+bool llama_kv_cache_unified_state::apply() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    kv->apply_ubatch(heads[i_next], ubatches[i_next]);
+
+    n_kv = kv->get_n_kv();
+    head = heads[i_next];
+
+    return true;
+}
+
+std::vector<int64_t> & llama_kv_cache_unified_state::out_ids() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    return sbatch.out_ids;
+}
+
+llama_memory_status llama_kv_cache_unified_state::get_status() const {
+    return status;
+}
+
+const llama_ubatch & llama_kv_cache_unified_state::get_ubatch() const {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    return ubatches[i_next];
+}
+
+uint32_t llama_kv_cache_unified_state::get_n_kv() const {
+    return n_kv;
+}
+
+ggml_tensor * llama_kv_cache_unified_state::get_k(ggml_context * ctx, int32_t il) const {
+    return kv->get_k(ctx, il, n_kv);
+}
+
+ggml_tensor * llama_kv_cache_unified_state::get_v(ggml_context * ctx, int32_t il) const {
+    return kv->get_v(ctx, il, n_kv);
+}
+
+ggml_tensor * llama_kv_cache_unified_state::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
+    return kv->cpy_k(ctx, k_cur, il, head);
+}
+
+ggml_tensor * llama_kv_cache_unified_state::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
+    return kv->cpy_v(ctx, v_cur, il, head);
+}
+
+void llama_kv_cache_unified_state::set_input_k_shift(ggml_tensor * dst) const {
+    kv->set_input_k_shift(dst);
+}
+
+void llama_kv_cache_unified_state::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
+    kv->set_input_kq_mask(dst, ubatch, causal_attn);
+}
+
+void llama_kv_cache_unified_state::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    kv->set_input_pos_bucket(dst, ubatch);
+}
+
+uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) {
+    // the FA kernels require padding to avoid extra runtime boundary checks
+    return cparams.flash_attn ? 256u : 32u;
+}
+
 //
 // llama_kv_cache_unified_iswa
 //
@@ -1561,13 +1740,12 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
 
     uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, n_pad));
 
-    // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size and disable pruning
+    // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
     if (swa_full) {
         LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n",
                 __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
 
         size_swa = size_base;
-        do_prune = false;
     }
 
     LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
@@ -1628,31 +1806,46 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
     return kv_swa->seq_pos_max(seq_id);
 }
 
-void llama_kv_cache_unified_iswa::restore() {
-    kv_base->restore();
-    kv_swa ->restore();
-}
+llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
+    GGML_UNUSED(embd_pooled);
 
-void llama_kv_cache_unified_iswa::commit() {
-    kv_base->commit();
-    kv_swa ->commit();
+    auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
 
-    // slide the attention window, forgetting/pruning old tokens that are outside the window
-    if (do_prune) {
-        for (const auto & [seq_id, entry] : pending.pos) {
-            kv_swa->prune_swa(seq_id, entry.pmin, entry.pmax);
-        }
+    // TODO: if we fail with split_simple, we should attempt split_equal
+
+    std::vector<llama_ubatch> ubatches;
+
+    while (sbatch.n_tokens > 0) {
+        auto ubatch = sbatch.split_simple(n_ubatch);
+
+        ubatches.push_back(ubatch);
+    }
 
+    auto heads_base = kv_base->prepare(ubatches);
+    if (heads_base.empty()) {
+        return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
     }
 
-    pending.clear();
+    auto heads_swa = kv_swa->prepare(ubatches);
+    if (heads_swa.empty()) {
+        return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    }
+
+    assert(heads_base.size() == heads_swa.size());
+
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS,
+            this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+}
+
+llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
 }
 
 bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
-    bool res = true;
+    bool res = false;
 
-    res = res & kv_base->update(lctx);
-    res = res & kv_swa ->update(lctx);
+    res = res | kv_base->update(lctx);
+    res = res | kv_swa ->update(lctx);
 
     return res;
 }
@@ -1662,68 +1855,107 @@ void llama_kv_cache_unified_iswa::defrag_sched(float thold) {
     kv_swa ->defrag_sched(thold);
 }
 
-void llama_kv_cache_unified_iswa::set_full() {
-    kv_base->set_full();
-    kv_swa ->set_full();
+bool llama_kv_cache_unified_iswa::get_can_shift() const {
+    return kv_base->get_size() == kv_swa->get_size();
 }
 
-llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) {
-    pending.clear();
+void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+    kv_base->state_write(io, seq_id);
+    kv_swa ->state_write(io, seq_id);
+}
 
-    if (do_prune) {
-        for (int i = 0; i < batch.n_tokens; ++i) {
-            for (int s = 0; s < batch.n_seq_id[i]; ++s) {
-                const llama_seq_id seq_id = batch.seq_id[i][s];
-                const llama_pos    pos    = batch.pos[i];
+void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+    kv_base->state_read(io, seq_id);
+    kv_swa ->state_read(io, seq_id);
+}
 
-                if (pending.pos.find(seq_id) == pending.pos.end()) {
-                    pending.pos[seq_id].pmin = pos;
-                    pending.pos[seq_id].pmax = pos;
-                } else {
-                    pending.pos[seq_id].pmin = std::min(pending.pos[seq_id].pmin, pos);
-                    pending.pos[seq_id].pmax = std::max(pending.pos[seq_id].pmax, pos);
-                }
-            }
-        }
-    }
+llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const {
+    return kv_base.get();
+}
 
-    return llama_sbatch(batch, hparams.n_embd, true, logits_all);
+llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
+    return kv_swa.get();
 }
 
-llama_ubatch llama_kv_cache_unified_iswa::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
-    GGML_UNUSED(embd_pooled);
-    return sbatch.split_simple(n_ubatch);
+//
+// llama_kv_cache_unified_iswa_state
+//
+
+llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
+
+llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
+        llama_memory_status status,
+        llama_kv_cache_unified_iswa * kv) : status(status) {
+    state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base()));
+    state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa ()));
+}
+
+llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
+        llama_memory_status status,
+        llama_kv_cache_unified_iswa * kv,
+        llama_sbatch sbatch,
+        std::vector<uint32_t> heads_base,
+        std::vector<uint32_t> heads_swa,
+        std::vector<llama_ubatch> ubatches)
+    : status(status),
+    sbatch(std::move(sbatch)),
+    ubatches(std::move(ubatches)) {
+        // note: here we copy the ubatches. not sure if this is ideal
+        state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base(), {}, std::move(heads_base), this->ubatches));
+        state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa (), {}, std::move(heads_swa),  this->ubatches));
+    }
+
+llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
+
+bool llama_kv_cache_unified_iswa_state::next() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    state_base->next();
+    state_swa ->next();
+
+    if (++i_next >= ubatches.size()) {
+        return false;
+    }
+
+    return true;
 }
 
-bool llama_kv_cache_unified_iswa::find_slot(const llama_ubatch & batch) {
+bool llama_kv_cache_unified_iswa_state::apply() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
     bool res = true;
 
-    res = res & kv_base->find_slot(batch);
-    res = res & kv_swa ->find_slot(batch);
+    res = res & state_base->apply();
+    res = res & state_swa ->apply();
 
     return res;
 }
 
-bool llama_kv_cache_unified_iswa::get_can_shift() const {
-    return kv_base->get_size() == kv_swa->get_size();
+std::vector<int64_t> & llama_kv_cache_unified_iswa_state::out_ids() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    return sbatch.out_ids;
 }
 
-void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
-    kv_base->state_write(io, seq_id);
-    kv_swa ->state_write(io, seq_id);
+llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
+    return status;
 }
 
-void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
-    kv_base->state_read(io, seq_id);
-    kv_swa ->state_read(io, seq_id);
+const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+    return ubatches[i_next];
 }
 
-llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_kv_base() const {
-    return kv_base.get();
+const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    return state_base.get();
 }
 
-llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_kv_swa() const {
-    return kv_swa.get();
+const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa()  const {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    return state_swa.get();
 }
 
 //
@@ -2071,50 +2303,82 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
     return result;
 }
 
-void llama_kv_cache_recurrent::restore() {
-    if (pending.ranges.empty()) {
-        return;
+llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
+    GGML_UNUSED(embd_pooled);
+
+    auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
+
+    std::vector<llama_ubatch> ubatches;
+
+    while (sbatch.n_tokens > 0) {
+        llama_ubatch ubatch;
+
+        if (embd_pooled) {
+            // Pooled embeddings cannot be split across ubatches (yet)
+            ubatch = sbatch.split_seq(n_ubatch);
+        } else {
+            ubatch = sbatch.split_equal(n_ubatch);
+        }
+
+        ubatches.push_back(ubatch);
     }
 
-    seq_rm(-1, -1, -1);
-}
+    if (!prepare(ubatches)) {
+        return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    }
 
-void llama_kv_cache_recurrent::commit() {
-    pending.ranges.clear();
+    return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this, std::move(sbatch), std::move(ubatches));
 }
 
-bool llama_kv_cache_recurrent::update(llama_context & ctx) {
-    GGML_UNUSED(ctx);
-    return false;
+llama_memory_state_ptr llama_kv_cache_recurrent::init_full() {
+    return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
 }
 
-void llama_kv_cache_recurrent::defrag_sched(float thold) {
-    GGML_UNUSED(thold);
-    // noop
-}
+bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
+    // simply remember the full state because it is very small for this type of cache
+    // TODO: optimize
+    auto org_cells = cells;
+    auto org_used = used;
+    auto org_head = head;
 
-void llama_kv_cache_recurrent::set_full() {
-    n = size;
-    head = 0;
-}
+    bool success = true;
+
+    // TODO: here we have to verify that all ubatches can fit in the cells
+    //       however, the current implementation is broken because it relies on s_copy() and s_mask() to update the cells
+    //         during the compute of each ubatch. to reproduce, uncomment the following loop and run:
+    //
+    //           $ llama-parallel -m ./mamba-130m/ggml-model-f16.gguf -np 5 -ns 8
+    //
+    //       recovery from failures when the batch does not fit in the KV cache will not work correctly until this is fixed
+    //
+    GGML_UNUSED(ubatches);
+    //for (const auto & ubatch : ubatches) {
+    //    if (!find_slot(ubatch)) {
+    //        success = false;
+    //        break;
+    //    }
+    //}
 
-llama_sbatch llama_kv_cache_recurrent::sbatch_init(
-        const llama_batch & batch,
-        bool logits_all) {
-    return llama_sbatch(batch, hparams.n_embd, false, logits_all);
+    // restore the original state
+    cells = std::move(org_cells);
+    used = org_used;
+    head = org_head;
+
+    return success;
 }
 
-llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
-    if (embd_pooled) {
-        // Pooled embeddings cannot be split across ubatches (yet)
-        return sbatch.split_seq(n_ubatch);
-    }
+bool llama_kv_cache_recurrent::update(llama_context & lctx) {
+    GGML_UNUSED(lctx);
+    // noop
+    return false;
+}
 
-    return sbatch.split_equal(n_ubatch);
+void llama_kv_cache_recurrent::defrag_sched(float thold) {
+    GGML_UNUSED(thold);
+    // noop
 }
 
-bool llama_kv_cache_recurrent::find_slot(
-       const llama_ubatch & ubatch) {
+bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
     const uint32_t n_tokens = ubatch.n_tokens;
     const uint32_t n_seqs   = ubatch.n_seqs;
 
@@ -2332,18 +2596,6 @@ float llama_kv_cache_recurrent::s_mask(int i) const {
     return res;
 }
 
-uint32_t llama_kv_cache_recurrent::cell_max() const {
-    for (uint32_t i = size; i > 0; --i) {
-        const kv_cell & cell = cells[i - 1];
-
-        if (cell.pos >= 0 && !cell.is_empty()) {
-            return i;
-        }
-    }
-
-    return 0;
-}
-
 size_t llama_kv_cache_recurrent::total_size() const {
     size_t size = 0;
     for (const auto & buf : bufs) {
@@ -2558,11 +2810,11 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
         }
         batch.n_seq_id[0] = 1;
         batch.seq_id[0] = &dest_seq_id;
+
         if (!find_slot(batch)) {
             LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
             return false;
         }
-        commit();
 
         // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
         // Assume that this is one contiguous block of cells
@@ -2745,3 +2997,84 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
 
     return true;
 }
+
+//
+// llama_kv_cache_recurrent_state
+//
+
+llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(llama_memory_status status) : status(status) {}
+
+llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(
+        llama_memory_status status,
+        llama_kv_cache_recurrent * kv) : status(status), kv(kv), is_full(true) {
+}
+
+llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(
+        llama_memory_status status,
+        llama_kv_cache_recurrent * kv,
+        llama_sbatch sbatch,
+        std::vector<llama_ubatch> ubatches) : status(status), kv(kv), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {}
+
+llama_kv_cache_recurrent_state::~llama_kv_cache_recurrent_state() = default;
+
+bool llama_kv_cache_recurrent_state::next() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    if (++i_next >= ubatches.size()) {
+        return false;
+    }
+
+    return true;
+}
+
+bool llama_kv_cache_recurrent_state::apply() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    kv->find_slot(ubatches[i_next]);
+
+    return true;
+}
+
+std::vector<int64_t> & llama_kv_cache_recurrent_state::out_ids() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    return sbatch.out_ids;
+}
+
+llama_memory_status llama_kv_cache_recurrent_state::get_status() const {
+    return status;
+}
+
+const llama_ubatch & llama_kv_cache_recurrent_state::get_ubatch() const {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    return ubatches[i_next];
+}
+
+uint32_t llama_kv_cache_recurrent_state::get_n_kv() const {
+    return is_full ? kv->size : kv->n;
+}
+
+uint32_t llama_kv_cache_recurrent_state::get_head() const {
+    return is_full ? 0 : kv->head;
+}
+
+uint32_t llama_kv_cache_recurrent_state::get_size() const {
+    return kv->size;
+}
+
+ggml_tensor * llama_kv_cache_recurrent_state::get_k_l(int32_t il) const {
+    return kv->k_l[il];
+}
+
+ggml_tensor * llama_kv_cache_recurrent_state::get_v_l(int32_t il) const {
+    return kv->v_l[il];
+}
+
+int32_t llama_kv_cache_recurrent_state::s_copy(int i) const {
+    return kv->s_copy(i);
+}
+
+float llama_kv_cache_recurrent_state::s_mask(int i) const {
+    return kv->s_mask(i);
+}
index ce6261e45a6e17e0052d8fa6ffaeff32b9b91641..d2439e13603a0344b8cffa64810e1975060c42fd 100644 (file)
@@ -2,6 +2,7 @@
 
 #include "llama.h"
 #include "llama-io.h"
+#include "llama-batch.h"
 #include "llama-graph.h"
 #include "llama-memory.h"
 #include "llama-kv-cells.h"
 
 struct llama_cparams;
 struct llama_hparams;
-struct llama_ubatch;
-struct llama_sbatch;
 struct llama_model;
 struct llama_context;
 
 struct llama_kv_cache : public llama_memory_i {
     virtual ~llama_kv_cache() = default;
 
-    // call if batch processing fails - restores the cache state
-    virtual void restore() = 0;
+    // split the input batch into a set of ubatches and verify that they can fit into the cache
+    // return a state object containing the ubatches and KV cache state required to process them
+    // check the llama_memory_state_i::get_status() for the result
+    virtual llama_memory_state_ptr init_batch(
+            const llama_batch & batch,
+            uint32_t n_ubatch,
+            bool embd_pooled,
+            bool logits_all) = 0;
 
-    // call after successful batch processing - clears any pending state
-    virtual void commit()  = 0;
+    // simulate full cache, used for allocating worst-case compute buffers
+    virtual llama_memory_state_ptr init_full() = 0;
 
     // process any pending defrag/shift/etc. operations
     // optionally call once before processing a new batch
+    // return true if any operations were performed
     virtual bool update(llama_context & lctx) = 0;
 
     // schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing
-    virtual void defrag_sched(float thold) = 0;
-
-    // simulate full cache, used for allocating worst-case compute buffers
-    // TODO: remove
-    virtual void set_full() = 0;
-
-    //
-    // batch processing
+    // TODO: change to
+    //   llama_memory_state_ptr init_defrag(float thold) = 0;
     //
-
-    // =============================================================================================================
-    // TODO: refactor and simplify this [TAG: KV_API]
-
-    virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;
-
-    // different KV caches require different batch splitting strategies
-    virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const = 0;
-
-    // find an empty slot of size "n_tokens" in the cache
-    virtual bool find_slot(const llama_ubatch & batch) = 0;
-
-    // =============================================================================================================
+    virtual void defrag_sched(float thold) = 0;
 
     // getters
     virtual bool get_can_shift() const = 0;
@@ -69,25 +57,6 @@ struct llama_kv_cache : public llama_memory_i {
     virtual void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) = 0;
 };
 
-//
-// llama_kv_cache_guard
-//
-
-struct llama_kv_cache_guard {
-    llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {}
-
-    ~llama_kv_cache_guard() {
-        kv->restore();
-    }
-
-    void commit() {
-        kv->commit();
-    }
-
-private:
-    llama_kv_cache * kv;
-};
-
 //
 // llama_kv_cache_unified
 //
@@ -133,22 +102,17 @@ public:
     // llama_kv_cache
     //
 
-    void restore() override;
-    void commit()  override;
-
-    bool update(llama_context & ctx) override;
+    llama_memory_state_ptr init_batch(
+            const llama_batch & batch,
+            uint32_t n_ubatch,
+            bool embd_pooled,
+            bool logits_all) override;
 
-    void defrag_sched(float thold) override;
-
-    void set_full() override;
+    llama_memory_state_ptr init_full() override;
 
-    llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
-    llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
+    bool update(llama_context & lctx) override;
 
-    // updates the cache head
-    // Note: On success, it's important that cache.head points
-    // to the first cell of the slot.
-    bool find_slot(const llama_ubatch & batch) override;
+    void defrag_sched(float thold) override;
 
     bool get_can_shift() const override;
 
@@ -161,18 +125,40 @@ public:
     // llama_kv_cache_unified specific API
     //
 
-    uint32_t get_n()    const;
     uint32_t get_size() const;
 
+    //
+    // graph_build API
+    //
+
+    uint32_t get_n_kv() const;
+
     // get views of the current state of the cache
-    ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
-    ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
+    ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
+    ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
 
-    // store k_cur and v_cur in the cache based on the current head location
-    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
-    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
+    // store k_cur and v_cur in the cache based on the provided head location
+    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const;
+    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const;
+
+    //
+    // preparation API
+    //
+
+    // find places for the provided ubatches in the cache, returns the head locations
+    // return empty vector on failure
+    std::vector<uint32_t> prepare(const std::vector<llama_ubatch> & ubatches);
 
-    void prune_swa(llama_seq_id seq_id, llama_pos pmin, llama_pos pmax);
+    // return the cell position where we can insert the ubatch
+    // return -1 on failure to find a contiguous slot of kv cells
+    int32_t find_slot(const llama_ubatch & ubatch) const;
+
+    // emplace the ubatch context into slot: [head_cur, head_cur + ubatch.n_tokens)
+    void apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch);
+
+    //
+    // set_input API
+    //
 
     void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
     void set_input_k_shift   (ggml_tensor * dst) const;
@@ -194,11 +180,9 @@ private:
     bool do_defrag = false;
     bool v_trans   = true;  // the value tensor is transposed
 
-    uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
-
-    // computed before each graph build
-    // TODO: cells should start to maintain this value dynamically based on the edits
-    uint32_t n = 0;
+    // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
+    // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
+    uint32_t head = 0;
 
     const uint32_t n_seq_max = 1;
 
@@ -220,24 +204,6 @@ private:
     // model layer id -> KV cache layer id
     std::unordered_map<int32_t, int32_t> map_layer_ids;
 
-    // recovery information used to restore the KV cells to their original state in case of a failure
-    // TODO: do not store as a state in the llama_kv_cache object, instead return upon batch preparation
-    //       to achieve that, first need to refactor the llama_kv_cache interface [TAG: KV_API]
-    struct {
-        void clear() {
-            states.clear();
-        }
-
-        struct state {
-            uint32_t i;
-
-            llama_kv_cells_unified cells;
-        };
-
-        // stack with the partial states before each ubatch
-        std::vector<state> states;
-    } recovery;
-
     // defrag
     struct {
         std::vector<uint32_t> ids;
@@ -279,13 +245,88 @@ private:
     bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
 };
 
+class llama_kv_cache_unified_state : public llama_memory_state_i {
+public:
+    // used for errors
+    llama_kv_cache_unified_state(llama_memory_status status);
+
+    // used to create a full-cache state
+    llama_kv_cache_unified_state(
+            llama_memory_status status,
+            llama_kv_cache_unified * kv);
+
+    // used to create a state from a batch
+    llama_kv_cache_unified_state(
+            llama_memory_status status,
+            llama_kv_cache_unified * kv,
+            llama_sbatch sbatch,
+            std::vector<uint32_t> heads,
+            std::vector<llama_ubatch> ubatches);
+
+    virtual ~llama_kv_cache_unified_state();
+
+    //
+    // llama_memory_state_i
+    //
+
+    bool next()  override;
+    bool apply() override;
+
+    std::vector<int64_t> & out_ids() override;
+
+    llama_memory_status  get_status() const override;
+    const llama_ubatch & get_ubatch() const override;
+
+    //
+    // llama_kv_cache_unified_state specific API
+    //
+
+    uint32_t get_n_kv() const;
+
+    // get views of the current state of the cache
+    ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
+    ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
+
+    // store k_cur and v_cur in the cache based on the provided head location
+    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
+    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
+
+    void set_input_k_shift(ggml_tensor * dst) const;
+
+    void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
+    void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
+
+private:
+    const llama_memory_status status;
+
+    llama_kv_cache_unified * kv;
+
+    llama_sbatch sbatch;
+
+    // the index of the next ubatch to process
+    size_t i_next = 0;
+
+    std::vector<uint32_t> heads;
+    std::vector<llama_ubatch> ubatches;
+
+    //
+    // data needed for building the compute graph for the current ubatch:
+    //
+
+    // a heuristic, to avoid attending the full cache if it is not yet utilized
+    // as the cache gets filled, the benefit from this heuristic disappears
+    int32_t n_kv;
+
+    // the beginning of the current slot in which the ubatch will be inserted
+    int32_t head;
+};
+
 //
 // llama_kv_cache_unified_iswa
 //
 
 // utilizes two instances of llama_kv_cache_unified
 //   the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
-//   upon successful commit, the SWA cache removes old tokens outside the n_swa window
 
 class llama_kv_cache_unified_iswa : public llama_kv_cache {
 public:
@@ -322,19 +363,17 @@ public:
     // llama_kv_cache
     //
 
-    void restore() override;
-    void commit()  override;
-
-    bool update(llama_context & ctx) override;
-
-    void defrag_sched(float thold) override;
+    llama_memory_state_ptr init_batch(
+            const llama_batch & batch,
+            uint32_t n_ubatch,
+            bool embd_pooled,
+            bool logits_all) override;
 
-    void set_full() override;
+    llama_memory_state_ptr init_full() override;
 
-    llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
-    llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
+    bool update(llama_context & lctx) override;
 
-    bool find_slot(const llama_ubatch & batch) override;
+    void defrag_sched(float thold) override;
 
     bool get_can_shift() const override;
 
@@ -347,58 +386,80 @@ public:
     // llama_kv_cache_unified_iswa specific API
     //
 
-    llama_kv_cache_unified * get_kv_base() const;
-    llama_kv_cache_unified * get_kv_swa () const;
+    llama_kv_cache_unified * get_base() const;
+    llama_kv_cache_unified * get_swa () const;
 
 private:
     const llama_hparams & hparams;
 
-    bool do_prune = true;
+    std::unique_ptr<llama_kv_cache_unified> kv_base;
+    std::unique_ptr<llama_kv_cache_unified> kv_swa;
+};
 
-    struct {
-        struct entry {
-            llama_pos pmin;
-            llama_pos pmax;
-        };
+class llama_kv_cache_unified_iswa_state : public llama_memory_state_i {
+public:
+    // used for errors
+    llama_kv_cache_unified_iswa_state(llama_memory_status status);
 
-        void clear() {
-            pos.clear();
-        }
+    // used to create a full-cache state
+    llama_kv_cache_unified_iswa_state(
+            llama_memory_status status,
+            llama_kv_cache_unified_iswa * kv);
 
-        // used to perform SWA pruning of old tokens
-        std::unordered_map<llama_seq_id, entry> pos;
-    } pending;
+    // used to create a state from a batch
+    llama_kv_cache_unified_iswa_state(
+            llama_memory_status status,
+            llama_kv_cache_unified_iswa * kv,
+            llama_sbatch sbatch,
+            std::vector<uint32_t> heads_base,
+            std::vector<uint32_t> heads_swa,
+            std::vector<llama_ubatch> ubatches);
 
-    std::unique_ptr<llama_kv_cache_unified> kv_base;
-    std::unique_ptr<llama_kv_cache_unified> kv_swa;
+    virtual ~llama_kv_cache_unified_iswa_state();
+
+    //
+    // llama_memory_state_i
+    //
+
+    bool next()  override;
+    bool apply() override;
+
+    std::vector<int64_t> & out_ids() override;
+
+    llama_memory_status  get_status() const override;
+    const llama_ubatch & get_ubatch() const override;
+
+    //
+    // llama_kv_cache_unified_iswa_state specific API
+    //
+
+    const llama_kv_cache_unified_state * get_base() const;
+    const llama_kv_cache_unified_state * get_swa()  const;
+
+private:
+    const llama_memory_status status;
+
+    //llama_kv_cache_unified_iswa * kv;
+
+    llama_sbatch sbatch;
+
+    // the index of the next ubatch to process
+    size_t i_next = 0;
+
+    std::vector<llama_ubatch> ubatches;
+
+    std::unique_ptr<llama_kv_cache_unified_state> state_base;
+    std::unique_ptr<llama_kv_cache_unified_state> state_swa;
 };
 
 //
 // llama_kv_cache_recurrent
 //
 
+// TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i
+//       see the implementation of llama_kv_cache_unified_state_i for an example how to do it
 class llama_kv_cache_recurrent : public llama_kv_cache {
 public:
-    struct kv_cell {
-        llama_pos pos  = -1;
-        int32_t   src  = -1; // used to copy states
-        int32_t   tail = -1;
-
-        std::set<llama_seq_id> seq_id;
-
-        bool has_seq_id(const llama_seq_id & id) const {
-            return seq_id.find(id) != seq_id.end();
-        }
-
-        bool is_empty() const {
-            return seq_id.empty();
-        }
-
-        bool is_same_seq(const kv_cell & other) const {
-            return seq_id == other.seq_id;
-        }
-    };
-
     llama_kv_cache_recurrent(
             const llama_model & model,
                     ggml_type   type_k,
@@ -428,19 +489,22 @@ public:
     // llama_kv_cache
     //
 
-    void restore() override;
-    void commit()  override;
+    llama_memory_state_ptr init_batch(
+            const llama_batch & batch,
+            uint32_t n_ubatch,
+            bool embd_pooled,
+            bool logits_all) override;
 
-    bool update(llama_context & ctx) override;
+    llama_memory_state_ptr init_full() override;
 
-    void defrag_sched(float thold) override;
+    bool update(llama_context & lctx) override;
 
-    void set_full() override;
+    void defrag_sched(float thold) override;
 
-    llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
-    llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
+    bool prepare(const std::vector<llama_ubatch> & ubatches);
 
-    bool find_slot(const llama_ubatch & batch) override;
+    // find a contiguous slot of kv cells and emplace the ubatch there
+    bool find_slot(const llama_ubatch & ubatch);
 
     bool get_can_shift() const override;
 
@@ -460,6 +524,27 @@ public:
     // computed before each graph build
     uint32_t n = 0;
 
+    // TODO: optimize for recurrent state needs
+    struct kv_cell {
+        llama_pos pos  = -1;
+        int32_t   src  = -1; // used to copy states
+        int32_t   tail = -1;
+
+        std::set<llama_seq_id> seq_id;
+
+        bool has_seq_id(const llama_seq_id & id) const {
+            return seq_id.find(id) != seq_id.end();
+        }
+
+        bool is_empty() const {
+            return seq_id.empty();
+        }
+
+        bool is_same_seq(const kv_cell & other) const {
+            return seq_id == other.seq_id;
+        }
+    };
+
     std::vector<kv_cell> cells;
 
     std::vector<ggml_tensor *> k_l; // per layer
@@ -469,26 +554,11 @@ private:
     //const llama_model & model;
     const llama_hparams & hparams;
 
-    // commit/restore cache
-    // TODO: rework for recurrent cache
-    struct slot_range {
-        uint32_t c0 = 0; // note: these are cell indices, not sequence positions
-        uint32_t c1 = 0;
-    };
-
-    // pending cell updates that are not yet committed
-    struct {
-        std::vector<slot_range> ranges;
-    } pending;
-
     const uint32_t n_seq_max = 1;
 
     std::vector<ggml_context_ptr>        ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;
 
-    // find how many cells are currently in use
-    uint32_t cell_max() const;
-
     size_t total_size() const;
 
     size_t size_k_bytes() const;
@@ -500,3 +570,67 @@ private:
     bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
     bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
 };
+
+class llama_kv_cache_recurrent_state : public llama_memory_state_i {
+public:
+    // used for errors
+    llama_kv_cache_recurrent_state(llama_memory_status status);
+
+    // used to create a full-cache state
+    llama_kv_cache_recurrent_state(
+            llama_memory_status status,
+            llama_kv_cache_recurrent * kv);
+
+    // used to create a state from a batch
+    llama_kv_cache_recurrent_state(
+            llama_memory_status status,
+            llama_kv_cache_recurrent * kv,
+            llama_sbatch sbatch,
+            std::vector<llama_ubatch> ubatches);
+
+    virtual ~llama_kv_cache_recurrent_state();
+
+    //
+    // llama_memory_state_i
+    //
+
+    bool next()  override;
+    bool apply() override;
+
+    std::vector<int64_t> & out_ids() override;
+
+    llama_memory_status  get_status() const override;
+    const llama_ubatch & get_ubatch() const override;
+
+    //
+    // llama_kv_cache_recurrent_state specific API
+    //
+
+    uint32_t get_n_kv() const;
+    uint32_t get_head() const;
+    uint32_t get_size() const;
+
+    ggml_tensor * get_k_l(int32_t il) const;
+    ggml_tensor * get_v_l(int32_t il) const;
+
+    int32_t s_copy(int i) const;
+    float   s_mask(int i) const;
+
+private:
+    const llama_memory_status status;
+
+    llama_kv_cache_recurrent * kv;
+
+    llama_sbatch sbatch;
+
+    size_t i_next = 0;
+
+    std::vector<llama_ubatch> ubatches;
+
+    //
+    // data needed for building the compute graph for the current ubatch:
+    // TODO: extract all the state like `head` and `n` here
+    //
+
+    const bool is_full = false;
+};
index dbbd03fcba2817834f8b82eab42f981dc3fec759..9e2c4d927699d72300ed013b17135a9b12c63bdb 100644 (file)
@@ -68,12 +68,6 @@ public:
     // the index of the last cell that is used + 1
     // return 0 if no cells are used
     uint32_t used_max_p1() const {
-#if 0
-        if (!seq_pos[0].empty()) printf("kv_cells: min[0] = %5d, max[0] = %5d\n", *seq_pos[0].begin(), *seq_pos[0].rbegin());
-        if (!seq_pos[1].empty()) printf("kv_cells: min[1] = %5d, max[1] = %5d\n", *seq_pos[1].begin(), *seq_pos[1].rbegin());
-        if (!seq_pos[2].empty()) printf("kv_cells: min[2] = %5d, max[2] = %5d\n", *seq_pos[2].begin(), *seq_pos[2].rbegin());
-#endif
-
         return used.empty() ? 0 : *used.rbegin() + 1;
     }
 
@@ -144,6 +138,19 @@ public:
         }
     }
 
+    // clear a non-empty cell
+    void rm(uint32_t i) {
+        assert(i < pos.size());
+        assert(pos[i] != -1);
+
+        seq_pos_rm(i);
+
+        pos[i] = -1;
+        seq[i].reset();
+
+        used.erase(i);
+    }
+
     // note: call only if the cell has seq_id
     // return true if the cell becomes empty
     bool seq_rm(uint32_t i, llama_seq_id seq_id) {
@@ -196,6 +203,15 @@ public:
         return false;
     }
 
+    // number of different sequences in the cell
+    int seq_count(uint32_t i) const {
+        assert(i < pos.size());
+        assert(pos[i] != -1);
+
+        return seq[i].count();
+    }
+
+    // check if the cell contains seq_id
     bool seq_has(uint32_t i, llama_seq_id seq_id) const {
         assert(i < pos.size());
         assert(seq_id >= 0);
@@ -213,6 +229,20 @@ public:
         seq_pos[seq_id].insert(pos[i]);
     }
 
+    // return the sequence id of this cell
+    // note: call only for cells with exactly one sequence
+    llama_seq_id seq_get(uint32_t i) const {
+        assert(seq[i].count() == 1);
+
+        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+            if (seq[i].test(s)) {
+                return s;
+            }
+        }
+
+        return -1;
+    }
+
     // the minimum position of sequence seq_id currently present in any of the cells
     // return -1 if the sequence is not present
     llama_pos seq_pos_min(llama_seq_id seq_id) const {
@@ -268,6 +298,7 @@ public:
     void pos_set(uint32_t i, llama_pos p) {
         assert(i < pos.size());
         assert(pos[i] == -1);
+        assert(seq[i].none());
 
         pos[i] = p;
 
index a2d250434affa8c58fe6cb7639279b34b5397b67..b3799d66e8c170241284abbb4735cc729a553512 100644 (file)
@@ -2,6 +2,11 @@
 
 #include "llama.h"
 
+#include <memory>
+#include <vector>
+
+struct llama_ubatch;
+
 struct llama_memory_params {
     // kv cache
     ggml_type type_k;
@@ -30,3 +35,42 @@ public:
 
     virtual bool get_can_edit() const = 0;
 };
+
+enum llama_memory_status {
+    LLAMA_MEMORY_STATUS_SUCCESS = 0,
+    LLAMA_MEMORY_STATUS_FAILED_PREPARE,
+    LLAMA_MEMORY_STATUS_FAILED_COMPUTE,
+};
+
+// the interface for managing the memory state during batch processing
+// this interface is implemented per memory type. see:
+//   - llama_kv_cache_unified_state
+//   - llama_kv_cache_unified_iswa_state
+//   ...
+//
+// the only method that can mutate the memory and the memory state is llama_memory_i::apply()
+//
+// TODO: rename to llama_memory_context_i ?
+class llama_memory_state_i {
+public:
+    virtual ~llama_memory_state_i() = default;
+
+    // consume the current ubatch from the state and proceed to the next one
+    // return false if we are done
+    virtual bool next() = 0;
+
+    // apply the memory state for the current ubatch to the memory object
+    // return false on failure
+    virtual bool apply() = 0;
+
+    // TODO: this might get reworked in the future when refactoring llama_batch
+    virtual std::vector<int64_t> & out_ids() = 0;
+
+    // get the current ubatch
+    virtual const llama_ubatch & get_ubatch() const = 0;
+
+    // get the status of the memory state
+    virtual llama_memory_status get_status() const = 0;
+};
+
+using llama_memory_state_ptr = std::unique_ptr<llama_memory_state_i>;
index 3f1f6c9bf3b067536efe8ab7a04594aa543a4c0e..e85becbb8f6958118acf72ae6e995871da094d6b 100644 (file)
@@ -8892,9 +8892,9 @@ struct llm_build_mamba : public llm_graph_context {
              ggml_tensor * state_mask,
       const llama_ubatch & ubatch,
                      int   il) const {
-        const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
+        const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
-        const auto kv_head = kv_self->head;
+        const auto kv_head = kv_state->get_head();
 
         const int64_t d_conv  = hparams.ssm_d_conv;
         const int64_t d_inner = hparams.ssm_d_inner;
@@ -8912,8 +8912,8 @@ struct llm_build_mamba : public llm_graph_context {
         GGML_ASSERT(ubatch.equal_seqs);
         GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
 
-        ggml_tensor * conv_states_all = kv_self->k_l[il];
-        ggml_tensor * ssm_states_all  = kv_self->v_l[il];
+        ggml_tensor * conv_states_all = kv_state->get_k_l(il);
+        ggml_tensor * ssm_states_all  = kv_state->get_v_l(il);
 
         // (ab)using the KV cache to store the states
         ggml_tensor * conv = build_copy_mask_state(
@@ -11640,7 +11640,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
             ggml_tensor * state_mask,
             const llama_ubatch & ubatch,
             int   il) const {
-        const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
+        const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
         const auto n_tokens = ubatch.n_tokens;
         const auto n_seqs = ubatch.n_seqs;
@@ -11650,7 +11650,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
         const auto n_head = n_embd / head_size;
         const auto n_head_kv = hparams.n_head_kv(il);
 
-        const auto kv_head = kv_self->head;
+        const auto kv_head = kv_state->get_head();
 
         const auto & layer = model.layers[il];
 
@@ -11762,7 +11762,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
         }
 
         ggml_tensor * wkv_state = build_copy_mask_state(
-                gf, kv_self->v_l[il], state_copy, state_mask,
+                gf, kv_state->get_v_l(il), state_copy, state_mask,
                 hparams.n_embd_v_s(), n_seqs);
 
         ggml_tensor * wkv_output;
@@ -11781,9 +11781,9 @@ struct llm_build_rwkv6_base : public llm_graph_context {
                     wkv_state,
                     ggml_view_1d(
                         ctx0,
-                        kv_self->v_l[il],
+                        kv_state->get_v_l(il),
                         hparams.n_embd_v_s() * n_seqs,
-                        hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self->v_l[il])
+                        hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_state->get_v_l(il))
                         )
                     )
                 );
@@ -12036,7 +12036,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
             ggml_tensor *& first_layer_value,
             const llama_ubatch & ubatch,
             int   il) const {
-        const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
+        const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
         const auto n_tokens = ubatch.n_tokens;
         const auto n_seqs = ubatch.n_seqs;
@@ -12045,7 +12045,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
         const auto head_count = n_embd / head_size;
         const auto n_seq_tokens = ubatch.n_seq_tokens;
 
-        const auto kv_head = kv_self->head;
+        const auto kv_head = kv_state->get_head();
 
         const auto & layer = model.layers[il];
 
@@ -12116,7 +12116,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
         a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
 
         ggml_tensor * wkv_state = build_copy_mask_state(
-                gf, kv_self->v_l[il], state_copy, state_mask,
+                gf, kv_state->get_v_l(il), state_copy, state_mask,
                 hparams.n_embd_v_s(), n_seqs);
 
         ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
@@ -12130,9 +12130,9 @@ struct llm_build_rwkv7_base : public llm_graph_context {
                     wkv_state,
                     ggml_view_1d(
                         ctx0,
-                        kv_self->v_l[il],
+                        kv_state->get_v_l(il),
                         hparams.n_embd_v_s() * n_seqs,
-                        hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self->v_l[il])
+                        hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_state->get_v_l(il))
                         )
                     )
                 );
index 5d03dc3dc790ac5ccfdd82e0f0b2fb38419600ef..90981ff9a5ef75835f0d284158bbd8f4e004d9bc 100644 (file)
@@ -3214,9 +3214,12 @@ struct server_context {
                             }
 
                             if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) {
-                                if (llama_kv_self_seq_pos_min(ctx, slot.id) > 0) {
+                                const auto pos_min = llama_kv_self_seq_pos_min(ctx, slot.id);
+                                if (pos_min > 0) {
+                                    SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min);
                                     SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
                                             "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
+                                    llama_kv_self_seq_rm(ctx, slot.id, 0, -1);
                                     slot.n_past = 0;
                                 }
                             }
@@ -3379,8 +3382,10 @@ struct server_context {
             }
         }
 
+        int32_t i_next = 0;
+
         // process the created batch of tokens
-        for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
+        for (int32_t i = 0; i < batch.n_tokens; i = i_next) {
             const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
 
             llama_batch batch_view = {
@@ -3425,13 +3430,18 @@ struct server_context {
 
                 // retry with half the batch size to try to find a free slot in the KV cache
                 n_batch /= 2;
-                i -= n_batch;
 
                 SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
 
                 continue; // continue loop of n_batch
             }
 
+            // move the head of the batch forward with the number of tokens we just processed
+            i_next = i + n_tokens;
+
+            // on successful decode, restore the original batch size
+            n_batch = llama_n_batch(ctx);
+
             for (auto & slot : slots) {
                 if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
                     continue; // continue loop of slots