]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
kv-cache : refactor the update/defrag mechanism (#13988)
authorGeorgi Gerganov <redacted>
Wed, 4 Jun 2025 15:58:20 +0000 (18:58 +0300)
committerGitHub <redacted>
Wed, 4 Jun 2025 15:58:20 +0000 (18:58 +0300)
* kv-cache : refactor update mechanism

ggml-ci

* memory : improve status handling

* defrag : reset head + add comments

ggml-ci

* cont : minor fixes

ggml-ci

src/llama-context.cpp
src/llama-context.h
src/llama-kv-cache-recurrent.cpp
src/llama-kv-cache-recurrent.h
src/llama-kv-cache-unified-iswa.cpp
src/llama-kv-cache-unified-iswa.h
src/llama-kv-cache-unified.cpp
src/llama-kv-cache-unified.h
src/llama-kv-cache.h
src/llama-memory.cpp
src/llama-memory.h

index 4ab57438794005fb3416b792156994a48b70d448..7c1a642c19464f62b940bf9492a37dec0271e262 100644 (file)
@@ -429,22 +429,54 @@ const llama_kv_cache * llama_context::get_kv_self() const {
     return kv_self;
 }
 
-bool llama_context::kv_self_update() {
+void llama_context::kv_self_defrag_sched() {
+    if (!memory) {
+        return;
+    }
+
+    memory_force_optimize = true;
+}
+
+bool llama_context::kv_self_update(bool optimize) {
     if (!memory) {
         return false;
     }
 
     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
-    if (!kv_self->update(*this)) {
-        // no updates have been performed
-        return false;
+    {
+        // TODO: remove in the future
+        optimize |= memory_force_optimize;
+        memory_force_optimize = false;
+
+        const auto kv_state = kv_self->init_update(this, optimize);
+        switch (kv_state->get_status()) {
+            case LLAMA_MEMORY_STATUS_SUCCESS:
+                {
+                    // noop
+                } break;
+            case LLAMA_MEMORY_STATUS_NO_UPDATE:
+                {
+                    // no updates need to be performed
+                    return false;
+                }
+            case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+            case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+                {
+                    LLAMA_LOG_ERROR("%s: failed to prepare memory update\n", __func__);
+                    return false;
+                }
+        }
+
+        if (!kv_state->apply()) {
+            LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
+        }
     }
 
     // if the KV cache did any computation, we have to reserve a new worst-case graph
     const auto kv_state = kv_self->init_full();
     if (!kv_state) {
-        throw std::runtime_error("failed to initialize KV cache");
+        throw std::runtime_error("failed to initialize memory state");
     }
 
     const uint32_t n_seqs   = cparams.n_seq_max;
@@ -452,7 +484,7 @@ bool llama_context::kv_self_update() {
 
     auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
     if (!gf) {
-        LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
+        LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
     }
 
     return true;
@@ -940,13 +972,13 @@ int llama_context::decode(llama_batch & inp_batch) {
         n_outputs_all = 1;
     }
 
+    bool did_optimize = false;
+
     // handle any pending defrags/shifts
-    kv_self_update();
+    kv_self_update(false);
 
     llama_memory_state_ptr kv_state;
 
-    bool did_defrag = false;
-
     while (true) {
         kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
         if (!kv_state) {
@@ -957,25 +989,32 @@ int llama_context::decode(llama_batch & inp_batch) {
             case LLAMA_MEMORY_STATUS_SUCCESS:
                 {
                 } break;
+            case LLAMA_MEMORY_STATUS_NO_UPDATE:
+                {
+                    LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, kv_state->get_status());
+
+                    return -2;
+                }
             case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
                 {
-                    if (!did_defrag) {
-                        did_defrag = true;
+                    if (!did_optimize) {
+                        did_optimize = true;
 
-                        kv_self->defrag_sched(-1.0f);
-                        if (kv_self_update()) {
-                            LLAMA_LOG_DEBUG("%s: failed to init batch of size %d, retrying after defrag\n", __func__, batch.n_tokens);
+                        if (kv_self_update(true)) {
+                            LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, batch.n_tokens);
 
                             continue;
                         }
                     }
 
-                    LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
+                    LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, batch.n_tokens);
 
                     return 1;
                 }
             case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
                 {
+                    LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, batch.n_tokens);
+
                     return -2;
                 }
         }
@@ -1189,11 +1228,6 @@ int llama_context::decode(llama_batch & inp_batch) {
     // wait for the computation to finish (automatically done when obtaining the model output)
     //synchronize();
 
-    // decide if we need to defrag the kv cache
-    if (cparams.defrag_thold > 0.0f) {
-        kv_self->defrag_sched(cparams.defrag_thold);
-    }
-
     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
     ggml_backend_sched_reset(sched.get());
@@ -2283,7 +2317,7 @@ llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
 
 // deprecated
 void llama_kv_self_update(llama_context * ctx) {
-    ctx->kv_self_update();
+    ctx->kv_self_update(false);
 }
 
 enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
@@ -2538,13 +2572,8 @@ llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
 
 // deprecated
 void llama_kv_self_defrag(llama_context * ctx) {
-    auto * kv = ctx->get_kv_self();
-    if (!kv) {
-        return;
-    }
-
     // force defrag
-    kv->defrag_sched(-1.0f);
+    ctx->kv_self_defrag_sched();
 }
 
 bool llama_kv_self_can_shift(const llama_context * ctx) {
index 3b880286bfd5de755467aea8eef849a3f13df766..c1c7efb31fe32d0e8305ee1db9f678f47fbbc571 100644 (file)
@@ -52,7 +52,8 @@ struct llama_context {
 
     // return true of the KV cache was updated
     // TODO: remove
-    bool kv_self_update();
+    bool kv_self_update(bool optimize);
+    void kv_self_defrag_sched();
 
     enum llama_pooling_type pooling_type() const;
 
@@ -231,6 +232,9 @@ private:
 
     std::unique_ptr<llama_memory_i> memory;
 
+    // TODO: temporary, until the llama_kv_self_defrag() API is removed
+    bool memory_force_optimize = false;
+
     // decode output (2-dimensional array: [n_outputs][n_vocab])
     size_t  logits_size = 0; // capacity (of floats) for logits
     float * logits      = nullptr;
index 641eab2f316cef56d8be36c472118b3099e642e0..77bd57065ac3dc6b2ac73431f167bc5cfeda9c86 100644 (file)
@@ -1,6 +1,7 @@
 #include "llama-kv-cache-recurrent.h"
 
 #include "llama-impl.h"
+#include "llama-io.h"
 #include "llama-batch.h"
 #include "llama-model.h"
 
@@ -386,6 +387,13 @@ llama_memory_state_ptr llama_kv_cache_recurrent::init_full() {
     return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
 }
 
+llama_memory_state_ptr llama_kv_cache_recurrent::init_update(llama_context * lctx, bool optimize) {
+    GGML_UNUSED(lctx);
+    GGML_UNUSED(optimize);
+
+    return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_NO_UPDATE);
+}
+
 bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
     // simply remember the full state because it is very small for this type of cache
     // TODO: optimize
@@ -419,17 +427,6 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
     return success;
 }
 
-bool llama_kv_cache_recurrent::update(llama_context & lctx) {
-    GGML_UNUSED(lctx);
-    // noop
-    return false;
-}
-
-void llama_kv_cache_recurrent::defrag_sched(float thold) {
-    GGML_UNUSED(thold);
-    // noop
-}
-
 bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
     const uint32_t n_tokens = ubatch.n_tokens;
     const uint32_t n_seqs   = ubatch.n_seqs;
index a178ae85c146a8e8f36579668aa2de3e67012cca..b32f258fbffbf9fd2a02e8b0b56f437ad5ddae5d 100644 (file)
@@ -52,9 +52,7 @@ public:
 
     llama_memory_state_ptr init_full() override;
 
-    bool update(llama_context & lctx) override;
-
-    void defrag_sched(float thold) override;
+    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
 
     bool prepare(const std::vector<llama_ubatch> & ubatches);
 
index 0eb04563435467b5845668f33bf3f15b96adc47c..3aa606c849843086f136329dd3a7177fe25c916c 100644 (file)
@@ -123,26 +123,16 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch
 
     assert(heads_base.size() == heads_swa.size());
 
-    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS,
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(
             this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
 }
 
 llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
-    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(this);
 }
 
-bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
-    bool res = false;
-
-    res = res | kv_base->update(lctx);
-    res = res | kv_swa ->update(lctx);
-
-    return res;
-}
-
-void llama_kv_cache_unified_iswa::defrag_sched(float thold) {
-    kv_base->defrag_sched(thold);
-    kv_swa ->defrag_sched(thold);
+llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(this, lctx, optimize);
 }
 
 bool llama_kv_cache_unified_iswa::get_can_shift() const {
@@ -174,26 +164,38 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
 
 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
-        llama_memory_status status,
-        llama_kv_cache_unified_iswa * kv) : status(status) {
-    state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base()));
-    state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa ()));
+        llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
+    state_base = kv->get_base()->init_full();
+    state_swa  = kv->get_swa ()->init_full();
+
+    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
+}
+
+llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
+        llama_kv_cache_unified_iswa * kv,
+        llama_context * lctx,
+        bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
+    state_base = kv->get_base()->init_update(lctx, optimize);
+    state_swa  = kv->get_swa ()->init_update(lctx, optimize);
+
+    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
 }
 
 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
-        llama_memory_status status,
         llama_kv_cache_unified_iswa * kv,
         llama_sbatch sbatch,
         std::vector<uint32_t> heads_base,
         std::vector<uint32_t> heads_swa,
         std::vector<llama_ubatch> ubatches)
-    : status(status),
-    sbatch(std::move(sbatch)),
-    ubatches(std::move(ubatches)) {
-        // note: here we copy the ubatches. not sure if this is ideal
-        state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base(), {}, std::move(heads_base), this->ubatches));
-        state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa (), {}, std::move(heads_swa),  this->ubatches));
-    }
+        : status(LLAMA_MEMORY_STATUS_SUCCESS),
+        sbatch(std::move(sbatch)),
+        ubatches(std::move(ubatches)) {
+    // note: here we copy the ubatches. not sure if this is ideal
+    state_base.reset(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches));
+    state_swa .reset(new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa),  this->ubatches));
+
+    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
+}
 
 llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
 
@@ -233,17 +235,18 @@ llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
 
 const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
     return ubatches[i_next];
 }
 
 const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    return state_base.get();
+    return static_cast<const llama_kv_cache_unified_state *>(state_base.get());
 }
 
 const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa()  const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    return state_swa.get();
+    return static_cast<const llama_kv_cache_unified_state *>(state_swa.get());
 }
index 8b067da038af6af2e8322aa98f2920f3f23d369a..cba5bbe95197e66a4a6808814df6d8e1f08b6b76 100644 (file)
@@ -54,9 +54,7 @@ public:
 
     llama_memory_state_ptr init_full() override;
 
-    bool update(llama_context & lctx) override;
-
-    void defrag_sched(float thold) override;
+    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
 
     bool get_can_shift() const override;
 
@@ -86,12 +84,16 @@ public:
 
     // used to create a full-cache state
     llama_kv_cache_unified_iswa_state(
-            llama_memory_status status,
             llama_kv_cache_unified_iswa * kv);
 
+    // used to create an update state
+    llama_kv_cache_unified_iswa_state(
+            llama_kv_cache_unified_iswa * kv,
+            llama_context * lctx,
+            bool optimize);
+
     // used to create a state from a batch
     llama_kv_cache_unified_iswa_state(
-            llama_memory_status status,
             llama_kv_cache_unified_iswa * kv,
             llama_sbatch sbatch,
             std::vector<uint32_t> heads_base,
@@ -120,7 +122,7 @@ public:
     const llama_kv_cache_unified_state * get_swa()  const;
 
 private:
-    const llama_memory_status status;
+    llama_memory_status status;
 
     //llama_kv_cache_unified_iswa * kv;
 
@@ -131,6 +133,6 @@ private:
 
     std::vector<llama_ubatch> ubatches;
 
-    std::unique_ptr<llama_kv_cache_unified_state> state_base;
-    std::unique_ptr<llama_kv_cache_unified_state> state_swa;
+    llama_memory_state_ptr state_base;
+    llama_memory_state_ptr state_swa;
 };
index 4007f202e313b9cb9186533532ae3c1f92481081..5354f808cd9595cde90a5f8ae8ab016e3c1212e8 100644 (file)
@@ -1,6 +1,7 @@
 #include "llama-kv-cache-unified.h"
 
 #include "llama-impl.h"
+#include "llama-io.h"
 #include "llama-model.h"
 #include "llama-context.h"
 
@@ -320,16 +321,49 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch(
         return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
     }
 
-    return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_SUCCESS,
+    return std::make_unique<llama_kv_cache_unified_state>(
             this, std::move(sbatch), std::move(heads), std::move(ubatches));
 }
 
 llama_memory_state_ptr llama_kv_cache_unified::init_full() {
-    return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
+    return std::make_unique<llama_kv_cache_unified_state>(this);
 }
 
-std::vector<uint32_t> llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
-    std::vector<uint32_t> res;
+llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
+    bool do_shift = get_has_shift();
+
+    defrag_info dinfo;
+
+    // see if we need to defrag
+    {
+        bool do_defrag = optimize;
+
+        const auto thold = lctx->get_cparams().defrag_thold;
+
+        if (!do_defrag && thold > 0.0f) {
+            const auto n_kv = cells.used_max_p1();
+
+            // - do not defrag small contexts (i.e. < 2048 tokens)
+            // - count the padding towards the number of used tokens
+            const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
+
+            if (fragmentation > thold) {
+                LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
+
+                do_defrag = true;
+            }
+        }
+
+        if (do_defrag) {
+            dinfo = defrag_prepare(lctx->graph_max_nodes());
+        }
+    }
+
+    return std::make_unique<llama_kv_cache_unified_state>(this, lctx, do_shift, std::move(dinfo));
+}
+
+llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
+    llama_kv_cache_unified::ubatch_heads res;
 
     struct state {
         uint32_t head_old; // old position of the head, before placing the ubatch
@@ -374,12 +408,12 @@ std::vector<uint32_t> llama_kv_cache_unified::prepare(const std::vector<llama_ub
     return res;
 }
 
-bool llama_kv_cache_unified::update(llama_context & lctx) {
+bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo) {
     bool updated = false;
 
-    auto * sched = lctx.get_sched();
+    auto * sched = lctx->get_sched();
 
-    if (cells.get_has_shift()) {
+    if (do_shift) {
         if (!get_can_shift()) {
             GGML_ABORT("The current KV cache / model configuration does not support K-shift");
         }
@@ -390,9 +424,9 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
         if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
             ggml_backend_sched_reset(sched);
 
-            auto * gf = lctx.graph_init();
+            auto * gf = lctx->graph_init();
 
-            auto res = build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
+            auto res = build_graph_shift(lctx->get_cparams(), lctx->get_ctx_compute(), gf);
             if (!res) {
                 LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__);
                 return updated;
@@ -405,7 +439,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
 
             res->set_inputs(nullptr);
 
-            if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
+            if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
                 LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__);
                 return updated;
             }
@@ -416,54 +450,53 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
         cells.reset_shift();
     }
 
-    if (do_defrag) {
+    if (!dinfo.empty()) {
         LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
 
-        if (defrag_prepare(lctx.graph_max_nodes())) {
-            ggml_backend_sched_reset(sched);
-
-            auto * gf = lctx.graph_init();
-
-            auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
-            if (!res) {
-                LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
-                return updated;
-            }
+        // apply moves:
+        {
+            const auto n_kv = dinfo.ids.size();
 
-            if (!ggml_backend_sched_alloc_graph(sched, gf)) {
-                LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
-                return updated;
-            }
+            for (uint32_t i = 0; i < n_kv; ++i) {
+                assert(dinfo.ids[i] <= n_kv);
 
-            res->set_inputs(nullptr);
+                if (dinfo.ids[i] == n_kv) {
+                    continue;
+                }
 
-            if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
-                LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
-                return updated;
+                cells.mv(i, dinfo.ids[i]);
             }
 
-            updated = true;
+            // reset the head so we can find the first free slot during the next ubatch
+            head = 0;
         }
 
-        do_defrag = false;
-    }
+        ggml_backend_sched_reset(sched);
 
-    return updated;
-}
+        auto * gf = lctx->graph_init();
 
-void llama_kv_cache_unified::defrag_sched(float thold) {
-    const auto n_kv = cells.used_max_p1();
+        auto res = build_graph_defrag(lctx->get_cparams(), lctx->get_ctx_compute(), gf, dinfo);
+        if (!res) {
+            LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
+            return updated;
+        }
+
+        if (!ggml_backend_sched_alloc_graph(sched, gf)) {
+            LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
+            return updated;
+        }
 
-    // - do not defrag small contexts (i.e. < 2048 tokens)
-    // - count the padding towards the number of used tokens
-    const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
+        res->set_inputs(nullptr);
 
-    // queue defragmentation for next llama_kv_cache_update
-    if (fragmentation > thold) {
-        LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
+        if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
+            LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
+            return updated;
+        }
 
-        do_defrag = true;
+        updated = true;
     }
+
+    return updated;
 }
 
 int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
@@ -612,6 +645,10 @@ uint32_t llama_kv_cache_unified::get_size() const {
     return cells.size();
 }
 
+bool llama_kv_cache_unified::get_has_shift() const {
+    return cells.get_has_shift();
+}
+
 uint32_t llama_kv_cache_unified::get_n_kv() const {
     return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
 }
@@ -941,12 +978,13 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
 }
 
 llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
-        const llama_cparams & cparams,
-               ggml_context * ctx,
-                ggml_cgraph * gf) const {
+                const llama_cparams & cparams,
+                       ggml_context * ctx,
+                        ggml_cgraph * gf,
+                  const defrag_info & dinfo) const {
     auto res = std::make_unique<llm_graph_result>();
 
-    const auto & ids = defrag_info.ids;
+    const auto & ids = dinfo.ids;
 
 #if 0
     // CPU defrag
@@ -1087,7 +1125,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
     return res;
 }
 
-bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
+llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const {
     const uint32_t n_layer = layers.size();
 
     const uint32_t n_kv   = cells.used_max_p1();
@@ -1108,14 +1146,9 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
     const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
 
     // determine which KV cells to move where
-    //
-    //  cell i moves to ids[i]
-    //
-    //  if ids[i] == i || ids[i] == n_kv, then cell i is not moved
-    //
-    auto & ids = defrag_info.ids;
+    defrag_info res;
+    auto & ids = res.ids;
 
-    ids.clear();
     ids.resize(n_kv, n_kv);
 
     for (uint32_t i0 = 0; i0 < n_used; ++i0) {
@@ -1179,11 +1212,6 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
             // this cell goes to (i0 + nf)
             ids[i1] = i0 + nf;
 
-            // move the cell meta data
-            cells.mv(i1, i0 + nf);
-
-            head = n_used;
-
             if (!cont) {
                 n_moves++;
                 cont = true;
@@ -1206,14 +1234,14 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
     }
 
     if (n_moves == 0) {
-        return false;
+        return {};
     }
 
     LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
 
     LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);
 
-    return true;
+    return res;
 }
 
 bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
@@ -1636,24 +1664,27 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
 llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {}
 
 llama_kv_cache_unified_state::llama_kv_cache_unified_state(
-            llama_memory_status status,
-            llama_kv_cache_unified * kv) : status(status), kv(kv) {
-        n_kv = kv->get_size();
-        head = 0;
-    }
+        llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
+    n_kv = kv->get_size();
+    head = 0;
+}
 
 llama_kv_cache_unified_state::llama_kv_cache_unified_state(
-            llama_memory_status status,
-            llama_kv_cache_unified * kv,
-            llama_sbatch sbatch,
-            std::vector<uint32_t> heads,
-            std::vector<llama_ubatch> ubatches)
-            : status(status),
-              kv(kv),
-              sbatch(std::move(sbatch)),
-              heads(std::move(heads)),
-              ubatches(std::move(ubatches)) {
+        llama_kv_cache_unified * kv,
+        llama_context * lctx,
+        bool do_shift,
+        defrag_info dinfo) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)) {
+    if (!do_shift && dinfo.empty()) {
+        status = LLAMA_MEMORY_STATUS_NO_UPDATE;
     }
+}
+
+llama_kv_cache_unified_state::llama_kv_cache_unified_state(
+        llama_kv_cache_unified * kv,
+        llama_sbatch sbatch,
+        llama_kv_cache_unified::ubatch_heads heads,
+        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sbatch(std::move(sbatch)), heads(std::move(heads)), ubatches(std::move(ubatches)) {
+}
 
 llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default;
 
@@ -1670,6 +1701,13 @@ bool llama_kv_cache_unified_state::next() {
 bool llama_kv_cache_unified_state::apply() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
+    // no ubatches -> this is a KV cache update
+    if (ubatches.empty()) {
+        kv->update(lctx, do_shift, dinfo);
+
+        return true;
+    }
+
     kv->apply_ubatch(heads[i_next], ubatches[i_next]);
 
     n_kv = kv->get_n_kv();
index 1f1d44b97c2ac713e54bdeb0f63ec40a7d9f275e..6ff388a88b1d58690d81215f18158e7d4c9a8589 100644 (file)
@@ -24,6 +24,19 @@ public:
     // this callback is used to filter out layers that should not be included in the cache
     using layer_filter_cb = std::function<bool(int32_t il)>;
 
+    using ubatch_heads = std::vector<uint32_t>;
+
+    struct defrag_info {
+        bool empty() const {
+            return ids.empty();
+        }
+
+        // contains information about which cell moves where:
+        //  - cell i moves to ids[i]
+        //  - if ids[i] == i || ids[i] == ids.size(), then cell i is not moved
+        std::vector<uint32_t> ids;
+    };
+
     llama_kv_cache_unified(
             const llama_model &  model,
               layer_filter_cb && filter,
@@ -66,9 +79,7 @@ public:
 
     llama_memory_state_ptr init_full() override;
 
-    bool update(llama_context & lctx) override;
-
-    void defrag_sched(float thold) override;
+    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
 
     bool get_can_shift() const override;
 
@@ -83,6 +94,8 @@ public:
 
     uint32_t get_size() const;
 
+    bool get_has_shift() const;
+
     //
     // graph_build API
     //
@@ -103,7 +116,9 @@ public:
 
     // find places for the provided ubatches in the cache, returns the head locations
     // return empty vector on failure
-    std::vector<uint32_t> prepare(const std::vector<llama_ubatch> & ubatches);
+    ubatch_heads prepare(const std::vector<llama_ubatch> & ubatches);
+
+    bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
 
     // return the cell position where we can insert the ubatch
     // return -1 on failure to find a contiguous slot of kv cells
@@ -133,8 +148,7 @@ private:
         ggml_tensor * v;
     };
 
-    bool do_defrag = false;
-    bool v_trans   = true;  // the value tensor is transposed
+    bool v_trans = true;  // the value tensor is transposed
 
     // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
     // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
@@ -160,13 +174,8 @@ private:
     // model layer id -> KV cache layer id
     std::unordered_map<int32_t, int32_t> map_layer_ids;
 
-    // defrag
-    struct {
-        std::vector<uint32_t> ids;
-    } defrag_info;
-
-    // return true if cells have been moved
-    bool defrag_prepare(int32_t n_max_nodes);
+    // return non-empty vector if cells have been moved
+    defrag_info defrag_prepare(int32_t n_max_nodes) const;
 
     size_t total_size() const;
 
@@ -192,7 +201,8 @@ private:
     llm_graph_result_ptr build_graph_defrag(
             const llama_cparams & cparams,
                    ggml_context * ctx,
-                    ggml_cgraph * gf) const;
+                    ggml_cgraph * gf,
+              const defrag_info & dinfo) const;
 
     void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
     void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
@@ -203,20 +213,29 @@ private:
 
 class llama_kv_cache_unified_state : public llama_memory_state_i {
 public:
+    // some shorthands
+    using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
+    using defrag_info  = llama_kv_cache_unified::defrag_info;
+
     // used for errors
     llama_kv_cache_unified_state(llama_memory_status status);
 
     // used to create a full-cache state
     llama_kv_cache_unified_state(
-            llama_memory_status status,
             llama_kv_cache_unified * kv);
 
-    // used to create a state from a batch
+    // used to create an update state
+    llama_kv_cache_unified_state(
+            llama_kv_cache_unified * kv,
+            llama_context * lctx,
+            bool do_shift,
+            defrag_info dinfo);
+
+    // used to create a decode state from a batch
     llama_kv_cache_unified_state(
-            llama_memory_status status,
             llama_kv_cache_unified * kv,
             llama_sbatch sbatch,
-            std::vector<uint32_t> heads,
+            ubatch_heads heads,
             std::vector<llama_ubatch> ubatches);
 
     virtual ~llama_kv_cache_unified_state();
@@ -253,16 +272,30 @@ public:
     void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
 
 private:
-    const llama_memory_status status;
+    llama_memory_status status;
 
     llama_kv_cache_unified * kv;
+    llama_context * lctx;
+
+    //
+    // update state
+    //
+
+    bool do_shift = false;
+
+    defrag_info dinfo;
+
+    //
+    // batch processing state
+    //
 
     llama_sbatch sbatch;
 
     // the index of the next ubatch to process
     size_t i_next = 0;
 
-    std::vector<uint32_t> heads;
+    ubatch_heads heads;
+
     std::vector<llama_ubatch> ubatches;
 
     //
index 2d04705f27857ea078761f3f065bf0bf1f3674f2..17a5e5cb84903c5c0482d1ced45d694672c53e9c 100644 (file)
@@ -1,12 +1,16 @@
 #pragma once
 
 #include "llama.h"
-#include "llama-io.h"
 #include "llama-memory.h"
 
+class llama_io_write_i;
+class llama_io_read_i;
+
 struct llama_kv_cache : public llama_memory_i {
     virtual ~llama_kv_cache() = default;
 
+    // TODO: move the init_ interfaces to llama_memory_i
+
     // split the input batch into a set of ubatches and verify that they can fit into the cache
     // return a state object containing the ubatches and KV cache state required to process them
     // check the llama_memory_state_i::get_status() for the result
@@ -19,16 +23,9 @@ struct llama_kv_cache : public llama_memory_i {
     // simulate full cache, used for allocating worst-case compute buffers
     virtual llama_memory_state_ptr init_full() = 0;
 
-    // process any pending defrag/shift/etc. operations
-    // optionally call once before processing a new batch
-    // return true if any operations were performed
-    virtual bool update(llama_context & lctx) = 0;
-
-    // schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing
-    // TODO: change to
-    //   llama_memory_state_ptr init_defrag(float thold) = 0;
-    //
-    virtual void defrag_sched(float thold) = 0;
+    // prepare for any pending memory updates, such as shifts, defrags, etc.
+    // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
+    virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0;
 
     // getters
     virtual bool get_can_shift() const = 0;
index 10173253edfe448ad9aeba27be3a9deee484bff7..f1107672c6476411b04521db02379255328e7728 100644 (file)
@@ -1 +1,42 @@
 #include "llama-memory.h"
+
+llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1) {
+    bool has_update = false;
+
+    switch (s0) {
+        case LLAMA_MEMORY_STATUS_SUCCESS:
+            {
+                has_update = true;
+                break;
+            }
+        case LLAMA_MEMORY_STATUS_NO_UPDATE:
+            {
+                break;
+            }
+        case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+        case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+            {
+                return s0;
+            }
+    }
+
+    switch (s1) {
+        case LLAMA_MEMORY_STATUS_SUCCESS:
+            {
+                has_update = true;
+                break;
+            }
+        case LLAMA_MEMORY_STATUS_NO_UPDATE:
+            {
+                break;
+            }
+        case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+        case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+            {
+                return s1;
+            }
+    }
+
+    // if either status has an update, then the combined status has an update
+    return has_update ? LLAMA_MEMORY_STATUS_SUCCESS : LLAMA_MEMORY_STATUS_NO_UPDATE;
+}
index b3799d66e8c170241284abbb4735cc729a553512..ab0d399c4eb602ebe2e6af74afd6f14b1a18aaf3 100644 (file)
@@ -36,12 +36,19 @@ public:
     virtual bool get_can_edit() const = 0;
 };
 
+using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
+
 enum llama_memory_status {
     LLAMA_MEMORY_STATUS_SUCCESS = 0,
+    LLAMA_MEMORY_STATUS_NO_UPDATE,
     LLAMA_MEMORY_STATUS_FAILED_PREPARE,
     LLAMA_MEMORY_STATUS_FAILED_COMPUTE,
 };
 
+// helper function for combining the status of two memory states
+// useful for implementing hybrid memory types (e.g. iSWA)
+llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1);
+
 // the interface for managing the memory state during batch processing
 // this interface is implemented per memory type. see:
 //   - llama_kv_cache_unified_state
@@ -69,7 +76,7 @@ public:
     // get the current ubatch
     virtual const llama_ubatch & get_ubatch() const = 0;
 
-    // get the status of the memory state
+    // get the status of the memory state - used for error handling and checking if any updates would be applied
     virtual llama_memory_status get_status() const = 0;
 };