]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
kv-cache : drop the "unified" prefix (#15467)
authorGeorgi Gerganov <redacted>
Thu, 21 Aug 2025 14:00:33 +0000 (17:00 +0300)
committerGitHub <redacted>
Thu, 21 Aug 2025 14:00:33 +0000 (17:00 +0300)
* kv-cache : drop the "unified" prefix

ggml-ci

* cont : fix comment [no ci]

19 files changed:
include/llama.h
src/CMakeLists.txt
src/llama-context.cpp
src/llama-graph.cpp
src/llama-graph.h
src/llama-kv-cache-iswa.cpp [new file with mode: 0644]
src/llama-kv-cache-iswa.h [new file with mode: 0644]
src/llama-kv-cache-unified-iswa.cpp [deleted file]
src/llama-kv-cache-unified-iswa.h [deleted file]
src/llama-kv-cache-unified.cpp [deleted file]
src/llama-kv-cache-unified.h [deleted file]
src/llama-kv-cache.cpp [new file with mode: 0644]
src/llama-kv-cache.h [new file with mode: 0644]
src/llama-kv-cells.h
src/llama-memory-hybrid.cpp
src/llama-memory-hybrid.h
src/llama-memory-recurrent.h
src/llama-memory.h
src/llama-model.cpp

index 135eaf1b655695e8e407dedd65ff2c6fe4db53dd..c465ced4ffa01b2b84691db535ec913d05d5e087 100644 (file)
@@ -64,8 +64,6 @@ extern "C" {
 
     typedef struct llama_memory_i * llama_memory_t;
 
-    struct llama_kv_cache; // DEPRECATED (use llama_memory instead)
-
     typedef int32_t llama_pos;
     typedef int32_t llama_token;
     typedef int32_t llama_seq_id;
@@ -469,8 +467,6 @@ extern "C" {
     LLAMA_API           llama_memory_t   llama_get_memory  (const struct llama_context * ctx);
     LLAMA_API  enum llama_pooling_type   llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type
 
-    DEPRECATED(LLAMA_API struct llama_kv_cache * llama_get_kv_self(struct llama_context * ctx), "use llama_get_memory instead");
-
     LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
     LLAMA_API enum llama_rope_type       llama_model_rope_type(const struct llama_model * model);
 
index 8f9cd652447abe72c07cf1699dfc9aaad38fd9ec..18cfc76564d3690f78ec9ff34e87d9bb2379ef3f 100644 (file)
@@ -20,8 +20,8 @@ add_library(llama
             llama-hparams.cpp
             llama-impl.cpp
             llama-io.cpp
-            llama-kv-cache-unified.cpp
-            llama-kv-cache-unified-iswa.cpp
+            llama-kv-cache.cpp
+            llama-kv-cache-iswa.cpp
             llama-memory.cpp
             llama-memory-hybrid.cpp
             llama-memory-recurrent.cpp
index 1ebfc88ab651a48f5698ab445dae9e4186e886ac..fb6fbe982c663a56869d7af096c14bcdf3e57a70 100644 (file)
@@ -2338,11 +2338,6 @@ const llama_model * llama_get_model(const llama_context * ctx) {
     return &ctx->get_model();
 }
 
-// deprecated
-llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
-    return dynamic_cast<llama_kv_cache *>(ctx->get_memory());
-}
-
 // deprecated
 void llama_kv_self_update(llama_context * ctx) {
     ctx->kv_self_update(false);
index 053c72d6dc8d187087865a1b18792b4dade405f9..04baf03ea04bef595383361c9d0095f7d01d9c3a 100644 (file)
@@ -4,8 +4,8 @@
 #include "llama-batch.h"
 #include "llama-cparams.h"
 
-#include "llama-kv-cache-unified.h"
-#include "llama-kv-cache-unified-iswa.h"
+#include "llama-kv-cache.h"
+#include "llama-kv-cache-iswa.h"
 #include "llama-memory-hybrid.h"
 #include "llama-memory-recurrent.h"
 
@@ -277,7 +277,7 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
                 for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
                     const llama_seq_id s0 = ubatch->seq_id[i0][0];
 
-                    // TODO: reimplement this like in llama_kv_cache_unified
+                    // TODO: reimplement this like in llama_kv_cache
                     if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) {
                         if (hparams.use_alibi) {
                             f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
@@ -294,15 +294,15 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
+void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
     mctx->set_input_k_idxs(self_k_idxs, ubatch);
     mctx->set_input_v_idxs(self_v_idxs, ubatch);
 
     mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
 }
 
-bool llm_graph_input_attn_kv_unified::can_reuse(const llm_graph_params & params) {
-    const auto * mctx = static_cast<const llama_kv_cache_unified_context *>(params.mctx);
+bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);
 
     this->mctx = mctx;
 
@@ -319,7 +319,7 @@ bool llm_graph_input_attn_kv_unified::can_reuse(const llm_graph_params & params)
     return res;
 }
 
-void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
+void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
     mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
     mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
 
@@ -331,8 +331,8 @@ void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch
     mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
 }
 
-bool llm_graph_input_attn_kv_unified_iswa::can_reuse(const llm_graph_params & params) {
-    const auto * mctx = static_cast<const llama_kv_cache_unified_iswa_context *>(params.mctx);
+bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_kv_cache_iswa_context *>(params.mctx);
 
     this->mctx = mctx;
 
@@ -1186,7 +1186,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
-    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
+    const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
 
     auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);
 
@@ -1399,17 +1399,17 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unified_impl(
+static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
            ggml_context * ctx0,
      const llama_ubatch & ubatch,
     const llama_hparams & hparams,
     const llama_cparams & cparams,
-    const llama_kv_cache_unified_context * mctx_cur) {
+    const llama_kv_cache_context * mctx_cur) {
 
-    auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
+    auto inp = std::make_unique<llm_graph_input_attn_kv>(hparams, cparams, mctx_cur);
 
     {
-        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
+        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
 
         const auto n_kv     = mctx_cur->get_n_kv();
         const auto n_tokens = ubatch.n_tokens;
@@ -1427,16 +1427,16 @@ static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unifie
     return inp;
 }
 
-llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
-    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
+llm_graph_input_attn_kv * llm_graph_context::build_attn_inp_kv() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
 
-    auto inp = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
+    auto inp = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
 
-    return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
+    return (llm_graph_input_attn_kv *) res->add_input(std::move(inp));
 }
 
 ggml_tensor * llm_graph_context::build_attn(
-        llm_graph_input_attn_kv_unified * inp,
+        llm_graph_input_attn_kv * inp,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,
@@ -1488,7 +1488,7 @@ ggml_tensor * llm_graph_context::build_attn(
 }
 
 ggml_tensor * llm_graph_context::build_attn(
-        llm_graph_input_attn_kv_unified_iswa * inp,
+        llm_graph_input_attn_kv_iswa * inp,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,
@@ -1513,7 +1513,7 @@ ggml_tensor * llm_graph_context::build_attn(
 }
 
 ggml_tensor * llm_graph_context::build_attn_with_sinks(
-        llm_graph_input_attn_kv_unified_iswa * inp,
+        llm_graph_input_attn_kv_iswa * inp,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,
@@ -1636,10 +1636,10 @@ ggml_tensor * llm_graph_context::build_attn(
 // TODO: maybe separate the inner implementation into a separate function
 //       like with the non-sliding window equivalent
 //       once sliding-window hybrid caches are a thing.
-llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
-    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
+llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_iswa_context *>(mctx);
 
-    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
+    auto inp = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, mctx_cur);
 
     const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
 
@@ -1656,7 +1656,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
     }
 
     {
-        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
+        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA");
 
         const auto n_kv = mctx_cur->get_swa()->get_n_kv();
 
@@ -1669,7 +1669,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
     }
 
-    return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
+    return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
 }
 
 ggml_tensor * llm_graph_context::build_rs(
@@ -1792,7 +1792,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
 
     auto inp_rs   = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
-    auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
+    auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
 
     auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
 
index 6ff49de3a1ce848ac7312bb32856bd0f4a73062f..6636fa256f65a0ceeb74ea9173a77f95726d4637 100644 (file)
@@ -19,8 +19,8 @@ struct llama_cparams;
 
 struct llama_memory_context_i;
 
-class llama_kv_cache_unified_context;
-class llama_kv_cache_unified_iswa_context;
+class llama_kv_cache_context;
+class llama_kv_cache_iswa_context;
 class llama_memory_recurrent_context;
 class llama_memory_hybrid_context;
 
@@ -152,7 +152,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
 public:
     llm_graph_input_pos_bucket_kv(
             const llama_hparams & hparams,
-            const llama_kv_cache_unified_context * mctx) : hparams(hparams), mctx(mctx) {}
+            const llama_kv_cache_context * mctx) : hparams(hparams), mctx(mctx) {}
     virtual ~llm_graph_input_pos_bucket_kv() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
@@ -161,7 +161,7 @@ public:
 
     const llama_hparams hparams;
 
-    const llama_kv_cache_unified_context * mctx;
+    const llama_kv_cache_context * mctx;
 };
 
 class llm_graph_input_out_ids : public llm_graph_input_i {
@@ -257,17 +257,17 @@ public:
     const llama_cparams cparams;
 };
 
-class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
+class llm_graph_input_attn_kv : public llm_graph_input_i {
 public:
-    llm_graph_input_attn_kv_unified(
+    llm_graph_input_attn_kv(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_kv_cache_unified_context * mctx) :
+            const llama_kv_cache_context * mctx) :
         hparams(hparams),
         cparams(cparams),
         mctx(mctx) {
     }
-    ~llm_graph_input_attn_kv_unified() = default;
+    ~llm_graph_input_attn_kv() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
@@ -290,20 +290,20 @@ public:
     const llama_hparams hparams;
     const llama_cparams cparams;
 
-    const llama_kv_cache_unified_context * mctx;
+    const llama_kv_cache_context * mctx;
 };
 
-class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
+class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
 public:
-    llm_graph_input_attn_kv_unified_iswa(
+    llm_graph_input_attn_kv_iswa(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_kv_cache_unified_iswa_context * mctx) :
+            const llama_kv_cache_iswa_context * mctx) :
         hparams(hparams),
         cparams(cparams),
         mctx(mctx) {
     }
-    ~llm_graph_input_attn_kv_unified_iswa() = default;
+    ~llm_graph_input_attn_kv_iswa() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
@@ -330,7 +330,7 @@ public:
     const llama_hparams hparams;
     const llama_cparams cparams;
 
-    const llama_kv_cache_unified_iswa_context * mctx;
+    const llama_kv_cache_iswa_context * mctx;
 };
 
 class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -351,7 +351,7 @@ public:
 class llm_graph_input_mem_hybrid : public llm_graph_input_i {
 public:
     llm_graph_input_mem_hybrid(
-            std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn,
+            std::unique_ptr<llm_graph_input_attn_kv> inp_attn,
             std::unique_ptr<llm_graph_input_rs>              inp_rs,
             const llama_memory_hybrid_context *              mctx) :
         inp_attn(std::move(inp_attn)),
@@ -361,11 +361,11 @@ public:
 
     void set_input(const llama_ubatch * ubatch) override;
 
-    std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn;
-    std::unique_ptr<llm_graph_input_rs>              inp_rs;
+    std::unique_ptr<llm_graph_input_attn_kv> inp_attn;
+    std::unique_ptr<llm_graph_input_rs>      inp_rs;
 
-    llm_graph_input_attn_kv_unified * get_attn() const { return inp_attn.get(); }
-    llm_graph_input_rs              * get_recr() const { return inp_rs.get(); }
+    llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); }
+    llm_graph_input_rs      * get_recr() const { return inp_rs.get(); }
 
     const llama_memory_hybrid_context * mctx;
 };
@@ -703,10 +703,10 @@ struct llm_graph_context {
                   float   kq_scale,
                     int   il) const;
 
-    llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const;
+    llm_graph_input_attn_kv * build_attn_inp_kv() const;
 
     ggml_tensor * build_attn(
-            llm_graph_input_attn_kv_unified * inp,
+            llm_graph_input_attn_kv * inp,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -717,11 +717,11 @@ struct llm_graph_context {
                   float   kq_scale,
                     int   il) const;
 
-    llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
+    llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const;
 
     // note: if k_cur or v_cur are not provided, they will not be stored in the memory
     ggml_tensor * build_attn(
-            llm_graph_input_attn_kv_unified_iswa * inp,
+            llm_graph_input_attn_kv_iswa * inp,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -734,7 +734,7 @@ struct llm_graph_context {
 
     // TODO: temporary to keep the diff small. after the code is public will refactor to simplify this
     ggml_tensor * build_attn_with_sinks(
-            llm_graph_input_attn_kv_unified_iswa * inp,
+            llm_graph_input_attn_kv_iswa * inp,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -765,7 +765,7 @@ struct llm_graph_context {
     //
 
     // TODO: move this implementation to llama_memory_recurrent.
-    //       this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
+    //       this is analogous to llama_kv_cache::cpy_k / cpy_v
     //       when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
     //         implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
     //         `llama_memory_recurrent`
diff --git a/src/llama-kv-cache-iswa.cpp b/src/llama-kv-cache-iswa.cpp
new file mode 100644 (file)
index 0000000..a11ee5a
--- /dev/null
@@ -0,0 +1,301 @@
+#include "llama-kv-cache-iswa.h"
+
+#include "llama-impl.h"
+#include "llama-batch.h"
+#include "llama-model.h"
+
+#include <algorithm>
+#include <cassert>
+
+//
+// llama_kv_cache_iswa
+//
+
+llama_kv_cache_iswa::llama_kv_cache_iswa(
+        const llama_model & model,
+                ggml_type   type_k,
+                ggml_type   type_v,
+                     bool   v_trans,
+                     bool   offload,
+                     bool   swa_full,
+                     bool   unified,
+                 uint32_t   kv_size,
+                 uint32_t   n_seq_max,
+                 uint32_t   n_ubatch,
+                 uint32_t   n_pad) : hparams(model.hparams), unified(unified) {
+    llama_kv_cache::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
+    llama_kv_cache::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
+
+    const uint32_t size_base = kv_size;
+
+    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad));
+
+    // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
+    if (swa_full) {
+        LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n",
+                __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
+
+        size_swa = size_base;
+    }
+
+    LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
+
+    kv_base = std::make_unique<llama_kv_cache>(
+            model, std::move(filter_base), type_k, type_v,
+            v_trans, offload, unified, size_base, n_seq_max, n_pad,
+            0, LLAMA_SWA_TYPE_NONE);
+
+    LLAMA_LOG_INFO("%s: creating     SWA KV cache, size = %u cells\n", __func__, size_swa);
+
+    kv_swa = std::make_unique<llama_kv_cache>(
+            model, std::move(filter_swa), type_k, type_v,
+            v_trans, offload, unified, size_swa, n_seq_max, n_pad,
+            hparams.n_swa, hparams.swa_type);
+}
+
+void llama_kv_cache_iswa::clear(bool data) {
+    kv_base->clear(data);
+    kv_swa ->clear(data);
+}
+
+bool llama_kv_cache_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    bool res = true;
+
+    res = res & kv_base->seq_rm(seq_id, p0, p1);
+    res = res & kv_swa ->seq_rm(seq_id, p0, p1);
+
+    return res;
+}
+
+void llama_kv_cache_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+    kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+    kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+}
+
+void llama_kv_cache_iswa::seq_keep(llama_seq_id seq_id) {
+    kv_base->seq_keep(seq_id);
+    kv_swa ->seq_keep(seq_id);
+}
+
+void llama_kv_cache_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+    kv_base->seq_add(seq_id, p0, p1, shift);
+    kv_swa ->seq_add(seq_id, p0, p1, shift);
+}
+
+void llama_kv_cache_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+    kv_base->seq_div(seq_id, p0, p1, d);
+    kv_swa ->seq_div(seq_id, p0, p1, d);
+}
+
+llama_pos llama_kv_cache_iswa::seq_pos_min(llama_seq_id seq_id) const {
+    // the base cache is a superset of the SWA cache, so we can just check the SWA cache
+    return kv_swa->seq_pos_min(seq_id);
+}
+
+llama_pos llama_kv_cache_iswa::seq_pos_max(llama_seq_id seq_id) const {
+    return kv_swa->seq_pos_max(seq_id);
+}
+
+llama_memory_context_ptr llama_kv_cache_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
+    GGML_UNUSED(embd_all);
+
+    // first try simple split
+    do {
+        if (!unified) {
+            // requires equal splits, so we skip the simple split
+            break;
+        }
+
+        balloc.split_reset();
+
+        std::vector<llama_ubatch> ubatches;
+        while (true) {
+            auto ubatch = balloc.split_simple(n_ubatch);
+
+            if (ubatch.n_tokens == 0) {
+                break;
+            }
+
+            ubatches.push_back(std::move(ubatch)); // NOLINT
+        }
+
+        if (balloc.get_n_used() < balloc.get_n_tokens()) {
+            // failed to find a suitable split
+            break;
+        }
+
+        auto sinfos_base = kv_base->prepare(ubatches);
+        if (sinfos_base.empty()) {
+            break;
+        }
+
+        auto sinfos_swa = kv_swa->prepare(ubatches);
+        if (sinfos_swa.empty()) {
+            break;
+        }
+
+        assert(sinfos_base.size() == sinfos_swa.size());
+
+        return std::make_unique<llama_kv_cache_iswa_context>(
+                this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
+    } while (false);
+
+    // if it fails, try equal split
+    do {
+        balloc.split_reset();
+
+        std::vector<llama_ubatch> ubatches;
+        while (true) {
+            auto ubatch = balloc.split_equal(n_ubatch, !unified);
+
+            if (ubatch.n_tokens == 0) {
+                break;
+            }
+
+            ubatches.push_back(std::move(ubatch)); // NOLINT
+        }
+
+        if (balloc.get_n_used() < balloc.get_n_tokens()) {
+            // failed to find a suitable split
+            break;
+        }
+
+        auto sinfos_base = kv_base->prepare(ubatches);
+        if (sinfos_base.empty()) {
+            break;
+        }
+
+        auto sinfos_swa = kv_swa->prepare(ubatches);
+        if (sinfos_swa.empty()) {
+            break;
+        }
+
+        assert(sinfos_base.size() == sinfos_swa.size());
+
+        return std::make_unique<llama_kv_cache_iswa_context>(
+                this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
+    } while (false);
+
+    // TODO: if we fail again, we should attempt different splitting strategies
+    //       but to do that properly, we first have to refactor the batches to be more flexible
+
+    return std::make_unique<llama_kv_cache_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+}
+
+llama_memory_context_ptr llama_kv_cache_iswa::init_full() {
+    return std::make_unique<llama_kv_cache_iswa_context>(this);
+}
+
+llama_memory_context_ptr llama_kv_cache_iswa::init_update(llama_context * lctx, bool optimize) {
+    return std::make_unique<llama_kv_cache_iswa_context>(this, lctx, optimize);
+}
+
+bool llama_kv_cache_iswa::get_can_shift() const {
+    return kv_base->get_size() == kv_swa->get_size();
+}
+
+void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
+    if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
+        kv_base->state_write(io, seq_id, flags);
+    }
+
+    kv_swa->state_write(io, seq_id, flags);
+}
+
+void llama_kv_cache_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
+    if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
+        kv_base->state_read(io, seq_id, flags);
+    }
+
+    kv_swa->state_read(io, seq_id, flags);
+}
+
+llama_kv_cache * llama_kv_cache_iswa::get_base() const {
+    return kv_base.get();
+}
+
+llama_kv_cache * llama_kv_cache_iswa::get_swa() const {
+    return kv_swa.get();
+}
+
+//
+// llama_kv_cache_iswa_context
+//
+
+llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(llama_memory_status status) : status(status) {}
+
+llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
+        llama_kv_cache_iswa * kv) :
+    ctx_base(kv->get_base()->init_full()),
+    ctx_swa (kv->get_swa ()->init_full()),
+    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
+}
+
+llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
+        llama_kv_cache_iswa * kv,
+        llama_context * lctx,
+        bool optimize) :
+    ctx_base(kv->get_base()->init_update(lctx, optimize)),
+    ctx_swa (kv->get_swa ()->init_update(lctx, optimize)),
+    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
+}
+
+llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
+        llama_kv_cache_iswa * kv,
+        slot_info_vec_t sinfos_base,
+        slot_info_vec_t sinfos_swa,
+        std::vector<llama_ubatch> ubatches) :
+    ubatches(std::move(ubatches)),
+    // note: here we copy the ubatches. not sure if this is ideal
+    ctx_base(new llama_kv_cache_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
+    ctx_swa (new llama_kv_cache_context(kv->get_swa (), std::move(sinfos_swa),  this->ubatches)),
+    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
+}
+
+llama_kv_cache_iswa_context:: ~llama_kv_cache_iswa_context() = default;
+
+bool llama_kv_cache_iswa_context::next() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    ctx_base->next();
+    ctx_swa ->next();
+
+    if (++i_next >= ubatches.size()) {
+        return false;
+    }
+
+    return true;
+}
+
+bool llama_kv_cache_iswa_context::apply() {
+    assert(!llama_memory_status_is_fail(status));
+
+    bool res = true;
+
+    res = res & ctx_base->apply();
+    res = res & ctx_swa ->apply();
+
+    return res;
+}
+
+llama_memory_status llama_kv_cache_iswa_context::get_status() const {
+    return status;
+}
+
+const llama_ubatch & llama_kv_cache_iswa_context::get_ubatch() const {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    return ubatches[i_next];
+}
+
+const llama_kv_cache_context * llama_kv_cache_iswa_context::get_base() const {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    return static_cast<const llama_kv_cache_context *>(ctx_base.get());
+}
+
+const llama_kv_cache_context * llama_kv_cache_iswa_context::get_swa()  const {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    return static_cast<const llama_kv_cache_context *>(ctx_swa.get());
+}
diff --git a/src/llama-kv-cache-iswa.h b/src/llama-kv-cache-iswa.h
new file mode 100644 (file)
index 0000000..dd673f1
--- /dev/null
@@ -0,0 +1,133 @@
+#pragma once
+
+#include "llama-kv-cache.h"
+
+#include <vector>
+
+//
+// llama_kv_cache_iswa
+//
+
+// utilizes two instances of llama_kv_cache
+//   the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
+
+class llama_kv_cache_iswa : public llama_memory_i {
+public:
+    llama_kv_cache_iswa(
+            const llama_model & model,
+                    ggml_type   type_k,
+                    ggml_type   type_v,
+                         bool   v_trans,
+                         bool   offload,
+                         bool   swa_full,
+                         bool  ,
+                     uint32_t   kv_size,
+                     uint32_t   n_seq_max,
+                     uint32_t   n_ubatch,
+                     uint32_t   n_pad);
+
+    ~llama_kv_cache_iswa() = default;
+
+    //
+    // llama_memory_i
+    //
+
+    llama_memory_context_ptr init_batch(
+            llama_batch_allocr & balloc,
+            uint32_t n_ubatch,
+            bool embd_all) override;
+
+    llama_memory_context_ptr init_full() override;
+
+    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
+
+    bool get_can_shift() const override;
+
+    void clear(bool data) override;
+
+    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
+    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+    void seq_keep(llama_seq_id seq_id)                                                          override;
+    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
+    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
+
+    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
+    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
+
+    // state write/load
+
+    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
+    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
+
+    //
+    // llama_kv_cache_iswa specific API
+    //
+
+    llama_kv_cache * get_base() const;
+    llama_kv_cache * get_swa () const;
+
+private:
+    const llama_hparams & hparams;
+
+    const bool unified;
+
+    std::unique_ptr<llama_kv_cache> kv_base;
+    std::unique_ptr<llama_kv_cache> kv_swa;
+};
+
+class llama_kv_cache_iswa_context : public llama_memory_context_i {
+public:
+    using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
+
+    // used for errors
+    llama_kv_cache_iswa_context(llama_memory_status status);
+
+    // used to create a full-cache context
+    llama_kv_cache_iswa_context(
+            llama_kv_cache_iswa * kv);
+
+    // used to create an update context
+    llama_kv_cache_iswa_context(
+            llama_kv_cache_iswa * kv,
+            llama_context * lctx,
+            bool optimize);
+
+    // used to create a batch processing context from a batch
+    llama_kv_cache_iswa_context(
+            llama_kv_cache_iswa * kv,
+            slot_info_vec_t sinfos_base,
+            slot_info_vec_t sinfos_swa,
+            std::vector<llama_ubatch> ubatches);
+
+    virtual ~llama_kv_cache_iswa_context();
+
+    //
+    // llama_memory_context_i
+    //
+
+    bool next()  override;
+    bool apply() override;
+
+    llama_memory_status  get_status() const override;
+    const llama_ubatch & get_ubatch() const override;
+
+    //
+    // llama_kv_cache_iswa_context specific API
+    //
+
+    const llama_kv_cache_context * get_base() const;
+    const llama_kv_cache_context * get_swa()  const;
+
+private:
+    //llama_kv_cache_iswa * kv;
+
+    // the index of the next ubatch to process
+    size_t i_next = 0;
+
+    std::vector<llama_ubatch> ubatches;
+
+    const llama_memory_context_ptr ctx_base;
+    const llama_memory_context_ptr ctx_swa;
+
+    const llama_memory_status status;
+};
diff --git a/src/llama-kv-cache-unified-iswa.cpp b/src/llama-kv-cache-unified-iswa.cpp
deleted file mode 100644 (file)
index 1e363ff..0000000
+++ /dev/null
@@ -1,301 +0,0 @@
-#include "llama-kv-cache-unified-iswa.h"
-
-#include "llama-impl.h"
-#include "llama-batch.h"
-#include "llama-model.h"
-
-#include <algorithm>
-#include <cassert>
-
-//
-// llama_kv_cache_unified_iswa
-//
-
-llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
-        const llama_model & model,
-                ggml_type   type_k,
-                ggml_type   type_v,
-                     bool   v_trans,
-                     bool   offload,
-                     bool   swa_full,
-                     bool   unified,
-                 uint32_t   kv_size,
-                 uint32_t   n_seq_max,
-                 uint32_t   n_ubatch,
-                 uint32_t   n_pad) : hparams(model.hparams), unified(unified) {
-    llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
-    llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
-
-    const uint32_t size_base = kv_size;
-
-    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad));
-
-    // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
-    if (swa_full) {
-        LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n",
-                __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
-
-        size_swa = size_base;
-    }
-
-    LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
-
-    kv_base = std::make_unique<llama_kv_cache_unified>(
-            model, std::move(filter_base), type_k, type_v,
-            v_trans, offload, unified, size_base, n_seq_max, n_pad,
-            0, LLAMA_SWA_TYPE_NONE);
-
-    LLAMA_LOG_INFO("%s: creating     SWA KV cache, size = %u cells\n", __func__, size_swa);
-
-    kv_swa = std::make_unique<llama_kv_cache_unified>(
-            model, std::move(filter_swa), type_k, type_v,
-            v_trans, offload, unified, size_swa, n_seq_max, n_pad,
-            hparams.n_swa, hparams.swa_type);
-}
-
-void llama_kv_cache_unified_iswa::clear(bool data) {
-    kv_base->clear(data);
-    kv_swa ->clear(data);
-}
-
-bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
-    bool res = true;
-
-    res = res & kv_base->seq_rm(seq_id, p0, p1);
-    res = res & kv_swa ->seq_rm(seq_id, p0, p1);
-
-    return res;
-}
-
-void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
-    kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
-    kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
-}
-
-void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) {
-    kv_base->seq_keep(seq_id);
-    kv_swa ->seq_keep(seq_id);
-}
-
-void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
-    kv_base->seq_add(seq_id, p0, p1, shift);
-    kv_swa ->seq_add(seq_id, p0, p1, shift);
-}
-
-void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
-    kv_base->seq_div(seq_id, p0, p1, d);
-    kv_swa ->seq_div(seq_id, p0, p1, d);
-}
-
-llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const {
-    // the base cache is a superset of the SWA cache, so we can just check the SWA cache
-    return kv_swa->seq_pos_min(seq_id);
-}
-
-llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
-    return kv_swa->seq_pos_max(seq_id);
-}
-
-llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
-    GGML_UNUSED(embd_all);
-
-    // first try simple split
-    do {
-        if (!unified) {
-            // requires equal splits, so we skip the simple split
-            break;
-        }
-
-        balloc.split_reset();
-
-        std::vector<llama_ubatch> ubatches;
-        while (true) {
-            auto ubatch = balloc.split_simple(n_ubatch);
-
-            if (ubatch.n_tokens == 0) {
-                break;
-            }
-
-            ubatches.push_back(std::move(ubatch)); // NOLINT
-        }
-
-        if (balloc.get_n_used() < balloc.get_n_tokens()) {
-            // failed to find a suitable split
-            break;
-        }
-
-        auto sinfos_base = kv_base->prepare(ubatches);
-        if (sinfos_base.empty()) {
-            break;
-        }
-
-        auto sinfos_swa = kv_swa->prepare(ubatches);
-        if (sinfos_swa.empty()) {
-            break;
-        }
-
-        assert(sinfos_base.size() == sinfos_swa.size());
-
-        return std::make_unique<llama_kv_cache_unified_iswa_context>(
-                this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
-    } while (false);
-
-    // if it fails, try equal split
-    do {
-        balloc.split_reset();
-
-        std::vector<llama_ubatch> ubatches;
-        while (true) {
-            auto ubatch = balloc.split_equal(n_ubatch, !unified);
-
-            if (ubatch.n_tokens == 0) {
-                break;
-            }
-
-            ubatches.push_back(std::move(ubatch)); // NOLINT
-        }
-
-        if (balloc.get_n_used() < balloc.get_n_tokens()) {
-            // failed to find a suitable split
-            break;
-        }
-
-        auto sinfos_base = kv_base->prepare(ubatches);
-        if (sinfos_base.empty()) {
-            break;
-        }
-
-        auto sinfos_swa = kv_swa->prepare(ubatches);
-        if (sinfos_swa.empty()) {
-            break;
-        }
-
-        assert(sinfos_base.size() == sinfos_swa.size());
-
-        return std::make_unique<llama_kv_cache_unified_iswa_context>(
-                this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
-    } while (false);
-
-    // TODO: if we fail again, we should attempt different splitting strategies
-    //       but to do that properly, we first have to refactor the batches to be more flexible
-
-    return std::make_unique<llama_kv_cache_unified_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-}
-
-llama_memory_context_ptr llama_kv_cache_unified_iswa::init_full() {
-    return std::make_unique<llama_kv_cache_unified_iswa_context>(this);
-}
-
-llama_memory_context_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
-    return std::make_unique<llama_kv_cache_unified_iswa_context>(this, lctx, optimize);
-}
-
-bool llama_kv_cache_unified_iswa::get_can_shift() const {
-    return kv_base->get_size() == kv_swa->get_size();
-}
-
-void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
-    if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
-        kv_base->state_write(io, seq_id, flags);
-    }
-
-    kv_swa->state_write(io, seq_id, flags);
-}
-
-void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
-    if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
-        kv_base->state_read(io, seq_id, flags);
-    }
-
-    kv_swa->state_read(io, seq_id, flags);
-}
-
-llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const {
-    return kv_base.get();
-}
-
-llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
-    return kv_swa.get();
-}
-
-//
-// llama_kv_cache_unified_iswa_context
-//
-
-llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(llama_memory_status status) : status(status) {}
-
-llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
-        llama_kv_cache_unified_iswa * kv) :
-    ctx_base(kv->get_base()->init_full()),
-    ctx_swa (kv->get_swa ()->init_full()),
-    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
-}
-
-llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
-        llama_kv_cache_unified_iswa * kv,
-        llama_context * lctx,
-        bool optimize) :
-    ctx_base(kv->get_base()->init_update(lctx, optimize)),
-    ctx_swa (kv->get_swa ()->init_update(lctx, optimize)),
-    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
-}
-
-llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
-        llama_kv_cache_unified_iswa * kv,
-        slot_info_vec_t sinfos_base,
-        slot_info_vec_t sinfos_swa,
-        std::vector<llama_ubatch> ubatches) :
-    ubatches(std::move(ubatches)),
-    // note: here we copy the ubatches. not sure if this is ideal
-    ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
-    ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(sinfos_swa),  this->ubatches)),
-    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
-}
-
-llama_kv_cache_unified_iswa_context:: ~llama_kv_cache_unified_iswa_context() = default;
-
-bool llama_kv_cache_unified_iswa_context::next() {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
-
-    ctx_base->next();
-    ctx_swa ->next();
-
-    if (++i_next >= ubatches.size()) {
-        return false;
-    }
-
-    return true;
-}
-
-bool llama_kv_cache_unified_iswa_context::apply() {
-    assert(!llama_memory_status_is_fail(status));
-
-    bool res = true;
-
-    res = res & ctx_base->apply();
-    res = res & ctx_swa ->apply();
-
-    return res;
-}
-
-llama_memory_status llama_kv_cache_unified_iswa_context::get_status() const {
-    return status;
-}
-
-const llama_ubatch & llama_kv_cache_unified_iswa_context::get_ubatch() const {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
-
-    return ubatches[i_next];
-}
-
-const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_base() const {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
-
-    return static_cast<const llama_kv_cache_unified_context *>(ctx_base.get());
-}
-
-const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_swa()  const {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
-
-    return static_cast<const llama_kv_cache_unified_context *>(ctx_swa.get());
-}
diff --git a/src/llama-kv-cache-unified-iswa.h b/src/llama-kv-cache-unified-iswa.h
deleted file mode 100644 (file)
index 7bc4df7..0000000
+++ /dev/null
@@ -1,133 +0,0 @@
-#pragma once
-
-#include "llama-kv-cache-unified.h"
-
-#include <vector>
-
-//
-// llama_kv_cache_unified_iswa
-//
-
-// utilizes two instances of llama_kv_cache_unified
-//   the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
-
-class llama_kv_cache_unified_iswa : public llama_memory_i {
-public:
-    llama_kv_cache_unified_iswa(
-            const llama_model & model,
-                    ggml_type   type_k,
-                    ggml_type   type_v,
-                         bool   v_trans,
-                         bool   offload,
-                         bool   swa_full,
-                         bool   unified,
-                     uint32_t   kv_size,
-                     uint32_t   n_seq_max,
-                     uint32_t   n_ubatch,
-                     uint32_t   n_pad);
-
-    ~llama_kv_cache_unified_iswa() = default;
-
-    //
-    // llama_memory_i
-    //
-
-    llama_memory_context_ptr init_batch(
-            llama_batch_allocr & balloc,
-            uint32_t n_ubatch,
-            bool embd_all) override;
-
-    llama_memory_context_ptr init_full() override;
-
-    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
-
-    bool get_can_shift() const override;
-
-    void clear(bool data) override;
-
-    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
-    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
-    void seq_keep(llama_seq_id seq_id)                                                          override;
-    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
-    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
-
-    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
-    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
-
-    // state write/load
-
-    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
-    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
-
-    //
-    // llama_kv_cache_unified_iswa specific API
-    //
-
-    llama_kv_cache_unified * get_base() const;
-    llama_kv_cache_unified * get_swa () const;
-
-private:
-    const llama_hparams & hparams;
-
-    const bool unified;
-
-    std::unique_ptr<llama_kv_cache_unified> kv_base;
-    std::unique_ptr<llama_kv_cache_unified> kv_swa;
-};
-
-class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
-public:
-    using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
-
-    // used for errors
-    llama_kv_cache_unified_iswa_context(llama_memory_status status);
-
-    // used to create a full-cache context
-    llama_kv_cache_unified_iswa_context(
-            llama_kv_cache_unified_iswa * kv);
-
-    // used to create an update context
-    llama_kv_cache_unified_iswa_context(
-            llama_kv_cache_unified_iswa * kv,
-            llama_context * lctx,
-            bool optimize);
-
-    // used to create a batch processing context from a batch
-    llama_kv_cache_unified_iswa_context(
-            llama_kv_cache_unified_iswa * kv,
-            slot_info_vec_t sinfos_base,
-            slot_info_vec_t sinfos_swa,
-            std::vector<llama_ubatch> ubatches);
-
-    virtual ~llama_kv_cache_unified_iswa_context();
-
-    //
-    // llama_memory_context_i
-    //
-
-    bool next()  override;
-    bool apply() override;
-
-    llama_memory_status  get_status() const override;
-    const llama_ubatch & get_ubatch() const override;
-
-    //
-    // llama_kv_cache_unified_iswa_context specific API
-    //
-
-    const llama_kv_cache_unified_context * get_base() const;
-    const llama_kv_cache_unified_context * get_swa()  const;
-
-private:
-    //llama_kv_cache_unified_iswa * kv;
-
-    // the index of the next ubatch to process
-    size_t i_next = 0;
-
-    std::vector<llama_ubatch> ubatches;
-
-    const llama_memory_context_ptr ctx_base;
-    const llama_memory_context_ptr ctx_swa;
-
-    const llama_memory_status status;
-};
diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp
deleted file mode 100644 (file)
index 478ebff..0000000
+++ /dev/null
@@ -1,2410 +0,0 @@
-#include "llama-kv-cache-unified.h"
-
-#include "llama-impl.h"
-#include "llama-io.h"
-#include "llama-model.h"
-#include "llama-context.h"
-
-#include <algorithm>
-#include <cassert>
-#include <cmath>
-#include <limits>
-#include <map>
-#include <stdexcept>
-
-//
-// llama_kv_cache_unified
-//
-
-llama_kv_cache_unified::llama_kv_cache_unified(
-        const llama_model &  model,
-          layer_filter_cb && filter,
-                ggml_type    type_k,
-                ggml_type    type_v,
-                     bool    v_trans,
-                     bool    offload,
-                     bool    unified,
-                 uint32_t    kv_size,
-                 uint32_t    n_seq_max,
-                 uint32_t    n_pad,
-                 uint32_t    n_swa,
-           llama_swa_type    swa_type) :
-    model(model), hparams(model.hparams), v_trans(v_trans),
-    n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
-
-    GGML_ASSERT(kv_size % n_pad == 0);
-
-    // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
-    auto n_layer_cache = hparams.n_layer;
-    if (model.arch == LLM_ARCH_GEMMA3N) {
-        n_layer_cache = 20;
-    }
-    if (model.arch == LLM_ARCH_GLM4_MOE) {
-        // GLM-4.5: Only process up to last layer, skip final NextN layer
-        n_layer_cache = hparams.n_layer - hparams.nextn_predict_layers;
-    }
-
-    // create a context for each buffer type
-    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
-    auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
-        auto it = ctx_map.find(buft);
-        if (it == ctx_map.end()) {
-            ggml_init_params params = {
-                /*.mem_size   =*/ size_t(2u*(1 + n_stream)*n_layer_cache*ggml_tensor_overhead()),
-                /*.mem_buffer =*/ NULL,
-                /*.no_alloc   =*/ true,
-            };
-
-            ggml_context * ctx = ggml_init(params);
-            if (!ctx) {
-                return nullptr;
-            }
-
-            ctx_map[buft] = ctx;
-            ctxs.emplace_back(ctx);
-
-            return ctx;
-        }
-
-        return it->second;
-    };
-
-    GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max);
-
-    v_heads.resize(n_stream);
-    for (uint32_t s = 0; s < n_stream; ++s) {
-        v_heads[s] = 0;
-    }
-
-    v_cells.resize(n_stream);
-    for (uint32_t s = 0; s < n_stream; ++s) {
-        v_cells[s].resize(kv_size);
-    }
-
-    // by default, all sequence ids are mapped to the 0th stream
-    seq_to_stream.resize(LLAMA_MAX_SEQ, 0);
-
-    if (n_stream > 1) {
-        seq_to_stream.resize(n_stream, 0);
-        for (uint32_t s = 0; s < n_stream; ++s) {
-            seq_to_stream[s] = s;
-        }
-    }
-
-    // [TAG_V_CACHE_VARIABLE]
-    if (v_trans && hparams.is_n_embd_v_gqa_variable()) {
-        LLAMA_LOG_WARN("%s: the V embeddings have different sizes across layers and FA is not enabled - padding V cache to %d\n",
-                __func__, hparams.n_embd_v_gqa_max());
-    }
-
-    for (uint32_t il = 0; il < n_layer_cache; il++) {
-        if (filter && !filter(il)) {
-            LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
-            continue;
-        }
-
-        // [TAG_V_CACHE_VARIABLE]
-        const uint32_t n_embd_k_gqa =            hparams.n_embd_k_gqa(il);
-        const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max();
-
-        const char * dev_name = "CPU";
-
-        ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
-
-        if (offload) {
-            auto * dev = model.dev_layer(il);
-            buft = ggml_backend_dev_buffer_type(dev);
-
-            dev_name = ggml_backend_dev_name(dev);
-        }
-
-        LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, il, dev_name);
-
-        ggml_context * ctx = ctx_for_buft(buft);
-        if (!ctx) {
-            throw std::runtime_error("failed to create ggml context for kv cache");
-        }
-
-        ggml_tensor * k;
-        ggml_tensor * v;
-
-        k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream);
-        v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream);
-
-        ggml_format_name(k, "cache_k_l%d", il);
-        ggml_format_name(v, "cache_v_l%d", il);
-
-        std::vector<ggml_tensor *> k_stream;
-        std::vector<ggml_tensor *> v_stream;
-
-        for (uint32_t s = 0; s < n_stream; ++s) {
-            k_stream.push_back(ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]));
-            v_stream.push_back(ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]));
-        }
-
-        map_layer_ids[il] = layers.size();
-
-        layers.push_back({ il, k, v, k_stream, v_stream, });
-    }
-
-    // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
-    if (model.arch == LLM_ARCH_GEMMA3N) {
-        LLAMA_LOG_DEBUG("%s: GEMMA3N: reuse layers [%d, %d]\n", __func__, n_layer_cache, hparams.n_layer - 1);
-
-        for (uint32_t il = n_layer_cache; il < hparams.n_layer; il++) {
-            if (filter && !filter(il)) {
-                LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
-                continue;
-            }
-
-            const bool     is_swa   = hparams.is_swa(il);
-            const uint32_t il_reuse = n_layer_cache - (is_swa ? 2 : 1);
-
-            GGML_ASSERT(map_layer_ids.find(il_reuse) != map_layer_ids.end());
-            map_layer_ids[il] = map_layer_ids[il_reuse];
-
-            LLAMA_LOG_DEBUG("%s: layer %3d: reuse layer %d, isw = %d\n", __func__, il, il_reuse, is_swa);
-        }
-    }
-
-    // allocate tensors and initialize the buffers to avoid NaNs in the padding
-    for (auto it : ctx_map) {
-        auto * buft = it.first;
-        auto * ctx  = it.second;
-
-        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
-        if (!buf) {
-            throw std::runtime_error("failed to allocate buffer for kv cache");
-        }
-
-        LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
-
-        ggml_backend_buffer_clear(buf, 0);
-        bufs.emplace_back(buf);
-    }
-
-    {
-        const size_t memory_size_k = size_k_bytes();
-        const size_t memory_size_v = size_v_bytes();
-
-        LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u/%u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
-                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, n_stream,
-                ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
-                ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
-    }
-
-    const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
-    debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
-
-    const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
-    supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) != 0 : supports_set_rows;
-
-    if (!supports_set_rows) {
-        // ref: https://github.com/ggml-org/llama.cpp/pull/14363
-        GGML_ASSERT(unified && "cannot use non-unified KV cache without ggml_set_rows() support");
-    }
-
-    if (!supports_set_rows) {
-        LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__);
-    }
-}
-
-void llama_kv_cache_unified::clear(bool data) {
-    for (uint32_t s = 0; s < n_stream; ++s) {
-        v_cells[s].reset();
-        v_heads[s] = 0;
-    }
-
-    if (data) {
-        for (auto & buf : bufs) {
-            ggml_backend_buffer_clear(buf.get(), 0);
-        }
-    }
-}
-
-bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
-    GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
-
-    if (p0 < 0) {
-        p0 = 0;
-    }
-
-    if (p1 < 0) {
-        p1 = std::numeric_limits<llama_pos>::max();
-    }
-
-    if (seq_id >= 0) {
-        auto & cells = v_cells[seq_to_stream[seq_id]];
-        auto & head  = v_heads[seq_to_stream[seq_id]];
-
-        uint32_t new_head = cells.size();
-
-        for (uint32_t i = 0; i < cells.size(); ++i) {
-            if (!cells.pos_in(i, p0, p1)) {
-                continue;
-            }
-
-            if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
-                if (new_head == cells.size()) {
-                    new_head = i;
-                }
-            }
-        }
-
-        // If we freed up a slot, set head to it so searching can start there.
-        if (new_head != cells.size() && new_head < head) {
-            head = new_head;
-        }
-    } else {
-        // match any sequence
-        for (uint32_t s = 0; s < n_stream; ++s) {
-            auto & cells = v_cells[s];
-            auto & head  = v_heads[s];
-
-            uint32_t new_head = cells.size();
-
-            for (uint32_t i = 0; i < cells.size(); ++i) {
-                if (!cells.pos_in(i, p0, p1)) {
-                    continue;
-                }
-
-                cells.rm(i);
-
-                if (new_head == cells.size()) {
-                    new_head = i;
-                }
-            }
-
-            // If we freed up a slot, set head to it so searching can start there.
-            if (new_head != cells.size() && new_head < head) {
-                head = new_head;
-            }
-        }
-    }
-
-    return true;
-}
-
-void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
-    GGML_ASSERT(seq_id_src >= 0 && (size_t) seq_id_src < seq_to_stream.size());
-    GGML_ASSERT(seq_id_dst >= 0 && (size_t) seq_id_dst < seq_to_stream.size());
-
-    const auto s0 = seq_to_stream[seq_id_src];
-    const auto s1 = seq_to_stream[seq_id_dst];
-
-    if (s0 == s1) {
-        // since both sequences are in the same stream, no data copy is necessary
-        // we just have to update the cells meta data
-
-        auto & cells = v_cells[s0];
-
-        if (seq_id_src == seq_id_dst) {
-            return;
-        }
-
-        if (p0 < 0) {
-            p0 = 0;
-        }
-
-        if (p1 < 0) {
-            p1 = std::numeric_limits<llama_pos>::max();
-        }
-
-        for (uint32_t i = 0; i < cells.size(); ++i) {
-            if (!cells.pos_in(i, p0, p1)) {
-                continue;
-            }
-
-            if (cells.seq_has(i, seq_id_src)) {
-                cells.seq_add(i, seq_id_dst);
-            }
-        }
-
-        return;
-    }
-
-    // cross-stream sequence copies require to copy the actual buffer data
-
-    bool is_full = true;
-
-    if (p0 > 0 && p0 + 1 < (int) get_size()) {
-        is_full = false;
-    }
-
-    if (p1 > 0 && p1 + 1 < (int) get_size()) {
-        is_full = false;
-    }
-
-    GGML_ASSERT(is_full && "seq_cp() is only supported for full KV buffers");
-
-    // enqueue the copy operation - the buffer copy will be performed during the next update
-    sc_info.ssrc.push_back(s0);
-    sc_info.sdst.push_back(s1);
-
-    v_cells[s1].reset();
-    for (uint32_t i = 0; i < v_cells[s0].size(); ++i) {
-        if (v_cells[s0].seq_has(i, seq_id_src)) {
-            llama_pos pos   = v_cells[s0].pos_get(i);
-            llama_pos shift = v_cells[s0].get_shift(i);
-
-            if (shift != 0) {
-                pos -= shift;
-                assert(pos >= 0);
-            }
-
-            v_cells[s1].pos_set(i, pos);
-            v_cells[s1].seq_add(i, seq_id_dst);
-
-            if (shift != 0) {
-                v_cells[s1].pos_add(i, shift);
-            }
-        }
-    }
-
-    v_heads[s1] = v_heads[s0];
-
-    //for (uint32_t s = 0; s < n_stream; ++s) {
-    //    LLAMA_LOG_WARN("%s: seq %d: min = %d, max = %d\n", __func__, s, v_cells[s].seq_pos_min(s), v_cells[s].seq_pos_max(s));
-    //}
-}
-
-void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
-    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
-
-    auto & cells = v_cells[seq_to_stream[seq_id]];
-    auto & head  = v_heads[seq_to_stream[seq_id]];
-
-    uint32_t new_head = cells.size();
-
-    for (uint32_t i = 0; i < cells.size(); ++i) {
-        if (cells.seq_keep(i, seq_id)) {
-            if (new_head == cells.size()) {
-                new_head = i;
-            }
-        }
-    }
-
-    // If we freed up a slot, set head to it so searching can start there.
-    if (new_head != cells.size() && new_head < head) {
-        head = new_head;
-    }
-}
-
-void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
-    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
-
-    auto & cells = v_cells[seq_to_stream[seq_id]];
-    auto & head  = v_heads[seq_to_stream[seq_id]];
-
-    if (shift == 0) {
-        return;
-    }
-
-    uint32_t new_head = cells.size();
-
-    if (p0 < 0) {
-        p0 = 0;
-    }
-
-    if (p1 < 0) {
-        p1 = std::numeric_limits<llama_pos>::max();
-    }
-
-    // If there is no range then return early to avoid looping over all cells.
-    if (p0 == p1) {
-        return;
-    }
-
-    for (uint32_t i = 0; i < cells.size(); ++i) {
-        if (!cells.pos_in(i, p0, p1)) {
-            continue;
-        }
-
-        if (cells.seq_has(i, seq_id)) {
-            if (cells.pos_add(i, shift)) {
-                if (new_head == cells.size()) {
-                    new_head = i;
-                }
-            }
-        }
-    }
-
-    // If we freed up a slot, set head to it so searching can start there.
-    // Otherwise we just start the next search from the beginning.
-    head = new_head != cells.size() ? new_head : 0;
-}
-
-void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
-    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
-
-    auto & cells = v_cells[seq_to_stream[seq_id]];
-
-    if (d == 1) {
-        return;
-    }
-
-    if (p0 < 0) {
-        p0 = 0;
-    }
-
-    if (p1 < 0) {
-        p1 = std::numeric_limits<llama_pos>::max();
-    }
-
-    // If there is no range then return early to avoid looping over the cache.
-    if (p0 == p1) {
-        return;
-    }
-
-    for (uint32_t i = 0; i < cells.size(); ++i) {
-        if (!cells.pos_in(i, p0, p1)) {
-            continue;
-        }
-
-        if (cells.seq_has(i, seq_id)) {
-            cells.pos_div(i, d);
-        }
-    }
-}
-
-llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
-    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
-
-    const auto & cells = v_cells[seq_to_stream[seq_id]];
-
-    return cells.seq_pos_min(seq_id);
-}
-
-llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
-    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
-
-    const auto & cells = v_cells[seq_to_stream[seq_id]];
-
-    return cells.seq_pos_max(seq_id);
-}
-
-llama_memory_context_ptr llama_kv_cache_unified::init_batch(
-            llama_batch_allocr & balloc,
-            uint32_t n_ubatch,
-            bool embd_all) {
-    GGML_UNUSED(embd_all);
-
-    do {
-        balloc.split_reset();
-
-        std::vector<llama_ubatch> ubatches;
-        while (true) {
-            auto ubatch = n_stream == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true);
-
-            if (ubatch.n_tokens == 0) {
-                break;
-            }
-
-            ubatches.push_back(std::move(ubatch)); // NOLINT
-        }
-
-        if (balloc.get_n_used() < balloc.get_n_tokens()) {
-            // failed to find a suitable split
-            break;
-        }
-
-        auto sinfos = prepare(ubatches);
-        if (sinfos.empty()) {
-            break;
-        }
-
-        return std::make_unique<llama_kv_cache_unified_context>(
-                this, std::move(sinfos), std::move(ubatches));
-    } while (false);
-
-    return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-}
-
-llama_memory_context_ptr llama_kv_cache_unified::init_full() {
-    return std::make_unique<llama_kv_cache_unified_context>(this);
-}
-
-llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
-    bool do_shift = get_has_shift();
-
-    defrag_info dinfo;
-
-    // see if we need to defrag
-    if (n_stream == 1) {
-        // note : for now do not consider defrag for n_stream > 1
-        const auto & cells = v_cells[seq_to_stream[0]];
-
-        bool do_defrag = optimize;
-
-        const auto thold = lctx->get_cparams().defrag_thold;
-
-        if (!do_defrag && thold > 0.0f) {
-            const auto n_kv = cells.used_max_p1();
-
-            // - do not defrag small contexts (i.e. < 2048 tokens)
-            // - count the padding towards the number of used tokens
-            const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
-
-            if (fragmentation > thold) {
-                LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
-
-                do_defrag = true;
-            }
-        }
-
-        if (do_defrag) {
-            dinfo = defrag_prepare(lctx->graph_max_nodes());
-        }
-    }
-
-    return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo), std::move(sc_info));
-}
-
-llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
-    llama_kv_cache_unified::slot_info_vec_t res;
-
-    struct state_t {
-        slot_info sinfo; // slot info for the ubatch
-
-        std::vector<uint32_t> v_heads_old; // old positions of the heads, before placing the ubatch
-
-        std::vector<llama_kv_cells_unified> v_cells; // copy of the old cells, before placing the ubatch
-    };
-
-    // remember the old state of the cells so we can restore it in the end
-    std::vector<state_t> states;
-
-    bool success = true;
-
-    for (const auto & ubatch : ubatches) {
-        // non-continuous slots require support for ggml_set_rows()
-        const bool cont = supports_set_rows ? false : true;
-
-        // only find a suitable slot for the ubatch. don't modify the cells yet
-        const auto sinfo_new = find_slot(ubatch, cont);
-        if (sinfo_new.empty()) {
-            success = false;
-            break;
-        }
-
-        // remeber the position that we found
-        res.push_back(sinfo_new);
-
-        // store the old state of the cells in the recovery stack
-        {
-            state_t state = { sinfo_new, v_heads, {} };
-
-            for (uint32_t s = 0; s < sinfo_new.n_stream(); ++s) {
-                auto & cells = v_cells[sinfo_new.strm[s]];
-
-                state.v_cells.push_back(cells.cp(sinfo_new.idxs[s]));
-            }
-
-            states.push_back(std::move(state));
-        }
-
-        // now emplace the ubatch
-        apply_ubatch(sinfo_new, ubatch);
-    }
-
-    GGML_ASSERT(!states.empty() || !success);
-
-    // iterate backwards and restore the cells to their original state
-    for (auto it = states.rbegin(); it != states.rend(); ++it) {
-        const auto & sinfo = it->sinfo;
-
-        for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
-            auto & cells = v_cells[sinfo.strm[s]];
-            auto & head  = v_heads[sinfo.strm[s]];
-
-            cells.set(sinfo.idxs[s], it->v_cells[s]);
-            head = it->v_heads_old[s];
-        }
-    }
-
-    if (!success) {
-        return {};
-    }
-
-    return res;
-}
-
-bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info) {
-    bool updated = false;
-
-    auto * sched = lctx->get_sched();
-
-    if (!sc_info.empty()) {
-        assert(n_stream > 1 && "stream copy should never happen with a single stream");
-
-        llama_synchronize(lctx);
-
-        const size_t n_copy = sc_info.ssrc.size();
-
-        for (size_t i = 0; i < n_copy; ++i) {
-            const auto ssrc = sc_info.ssrc[i];
-            const auto sdst = sc_info.sdst[i];
-
-            assert(ssrc < n_stream);
-            assert(sdst < n_stream);
-
-            LLAMA_LOG_DEBUG("%s: copying KV buffer: stream %d to stream %d\n", __func__, ssrc, sdst);
-
-            assert(ssrc != sdst);
-
-            for (uint32_t il = 0; il < layers.size(); ++il) {
-                const auto & layer = layers[il];
-
-                ggml_backend_tensor_copy(layer.k_stream[ssrc], layer.k_stream[sdst]);
-                ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]);
-            }
-        }
-    }
-
-    if (do_shift) {
-        if (!get_can_shift()) {
-            GGML_ABORT("The current KV cache / model configuration does not support K-shift");
-        }
-
-        LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
-
-        // apply K-shift if needed
-        if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
-            ggml_backend_sched_reset(sched);
-
-            auto * res = lctx->get_gf_res_reserve();
-
-            res->reset();
-
-            auto * gf = build_graph_shift(res, lctx);
-            if (!ggml_backend_sched_alloc_graph(sched, gf)) {
-                LLAMA_LOG_ERROR("%s: failed to allocate compute graph for K-shift\n", __func__);
-                return updated;
-            }
-
-            res->set_inputs(nullptr);
-
-            if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
-                LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__);
-                return updated;
-            }
-
-            updated = true;
-        }
-
-        for (uint32_t s = 0; s < n_stream; ++s) {
-            auto & cells = v_cells[s];
-
-            cells.reset_shift();
-        }
-    }
-
-    if (!dinfo.empty()) {
-        LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
-
-        // note: for now do not consider defrag for n_stream > 1
-        auto & cells = v_cells[seq_to_stream[0]];
-        auto & head  = v_heads[seq_to_stream[0]];
-
-        // apply moves:
-        {
-            const auto n_kv = dinfo.ids.size();
-
-            for (uint32_t i = 0; i < n_kv; ++i) {
-                assert(dinfo.ids[i] <= n_kv);
-
-                if (dinfo.ids[i] == n_kv || dinfo.ids[i] == i) {
-                    continue;
-                }
-
-                cells.mv(i, dinfo.ids[i]);
-            }
-
-            // reset the head so we can find the first free slot during the next ubatch
-            head = 0;
-        }
-
-        ggml_backend_sched_reset(sched);
-
-        auto * res = lctx->get_gf_res_reserve();
-
-        res->reset();
-
-        auto * gf = build_graph_defrag(res, lctx, dinfo);
-        if (!ggml_backend_sched_alloc_graph(sched, gf)) {
-            LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
-            return updated;
-        }
-
-        res->set_inputs(nullptr);
-
-        if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
-            LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
-            return updated;
-        }
-
-        updated = true;
-    }
-
-    return updated;
-}
-
-llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
-
-    if (debug > 0) {
-        for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
-            const auto seq_id = ubatch.seq_id_unq[s];
-            const auto stream_id = seq_to_stream[seq_id];
-            const auto & cells = v_cells[stream_id];
-            const uint32_t head_cur = v_heads[stream_id];
-
-            LLAMA_LOG_DEBUG("%s: stream[%d], n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n",
-                    __func__, stream_id, cells.used_max_p1(), cells.get_used(), head_cur, get_size(), n_swa);
-
-            if ((debug == 2 && n_swa > 0) || debug > 2) {
-                std::string ss;
-                for (uint32_t i = 0; i < cells.size(); ++i) {
-                    if (cells.is_empty(i)) {
-                        ss += '.';
-                    } else {
-                        assert(cells.seq_count(i) >= 1);
-
-                        if (cells.seq_count(i) == 1) {
-                            ss += std::to_string(cells.seq_get(i));
-                        } else {
-                            ss += 'M';
-                        }
-                    }
-                    if (i%256 == 255) {
-                        ss += " *";
-                        ss += '\n';
-                    }
-                }
-                LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
-            }
-
-            if ((debug == 2 && n_swa > 0) || debug > 2) {
-                std::string ss;
-                for (uint32_t i = 0; i < cells.size(); ++i) {
-                    std::string cur;
-                    if (cells.is_empty(i)) {
-                        cur = '.';
-                    } else {
-                        cur = std::to_string(cells.pos_get(i));
-                    }
-                    const int n = cur.size();
-                    for (int j = 0; j < 5 - n; ++j) {
-                        cur += ' ';
-                    }
-                    ss += cur;
-                    if (i%256 == 255) {
-                        ss += " *";
-                    }
-                    if (i%64 == 63) {
-                        ss += '\n';
-                    }
-                }
-                LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
-            }
-
-            for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
-                if (cells.seq_pos_min(s) < 0) {
-                    continue;
-                }
-
-                LLAMA_LOG_DEBUG("%s: stream[%d] min[%d] = %5d, max[%d] = %5d\n", __func__, stream_id, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s));
-            }
-        }
-    }
-
-    uint32_t n_tokens = ubatch.n_tokens;
-    uint32_t n_seqs   = 1;
-
-    if (n_stream > 1) {
-        GGML_ASSERT(n_tokens % ubatch.n_seqs_unq == 0);
-
-        n_seqs   = ubatch.n_seqs_unq;
-        n_tokens = n_tokens / n_seqs;
-    }
-
-    slot_info res = {
-        /*.s0   =*/ LLAMA_MAX_SEQ,
-        /*.s1   =*/ 0,
-        /*.strm =*/ { },
-        /*.idxs =*/ { },
-    };
-
-    res.resize(n_seqs);
-
-    for (uint32_t s = 0; s < n_seqs; ++s) {
-        const auto seq_id = ubatch.seq_id_unq[s];
-
-        if (n_stream > 1) {
-            GGML_ASSERT(ubatch.n_seq_id[s*n_tokens]    == 1);
-            GGML_ASSERT(ubatch.seq_id  [s*n_tokens][0] == seq_id);
-        }
-
-        res.s0 = std::min<llama_seq_id>(res.s0, seq_to_stream[seq_id]);
-        res.s1 = std::max<llama_seq_id>(res.s1, seq_to_stream[seq_id]);
-
-        res.strm[s] = seq_to_stream[seq_id];
-        res.idxs[s].reserve(n_tokens);
-
-        const auto & cells = v_cells[seq_to_stream[seq_id]];
-
-        uint32_t head_cur = v_heads[seq_to_stream[seq_id]];
-
-        // if we have enough unused cells before the current head ->
-        //   better to start searching from the beginning of the cache, hoping to fill it
-        if (head_cur > cells.get_used() + 2*n_tokens) {
-            head_cur = 0;
-        }
-
-        if (n_tokens > cells.size()) {
-            LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
-            return { };
-        }
-
-        uint32_t n_tested = 0;
-
-        // for continuous slots, we test that all tokens in the ubatch fit, starting from the current head
-        // for non-continuous slots, we test the tokens one by one
-        const uint32_t n_test = cont ? n_tokens : 1;
-
-        while (true) {
-            if (head_cur + n_test > cells.size()) {
-                n_tested += cells.size() - head_cur;
-                head_cur = 0;
-                continue;
-            }
-
-            for (uint32_t i = 0; i < n_test; i++) {
-                const auto idx = head_cur;
-
-                head_cur++;
-                n_tested++;
-
-                //const llama_pos    pos    = ubatch.pos[i];
-                //const llama_seq_id seq_id = ubatch.seq_id[i][0];
-
-                // can we use this cell? either:
-                //  - the cell is empty
-                //  - the cell is occupied only by one sequence:
-                //    - (disabled) mask causally, if the sequence is the same as the one we are inserting
-                //    - mask SWA, using current max pos for that sequence in the cache
-                //                always insert in the cell with minimum pos
-                bool can_use = cells.is_empty(idx);
-
-                if (!can_use && cells.seq_count(idx) == 1) {
-                    const llama_pos pos_cell = cells.pos_get(idx);
-
-                    // (disabled) causal mask
-                    // note: it's better to purge any "future" tokens beforehand
-                    //if (cells.seq_has(idx, seq_id)) {
-                    //    can_use = pos_cell >= pos;
-                    //}
-
-                    if (!can_use) {
-                        const llama_seq_id seq_id_cell = cells.seq_get(idx);
-
-                        // SWA mask
-                        if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
-                            can_use = true;
-                        }
-                    }
-                }
-
-                if (can_use) {
-                    res.idxs[s].push_back(idx);
-                } else {
-                    if (cont) {
-                        break;
-                    }
-                }
-            }
-
-            if (res.idxs[s].size() == n_tokens) {
-                break;
-            }
-
-            if (cont) {
-                res.idxs[s].clear();
-            }
-
-            if (n_tested >= cells.size()) {
-                //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
-                return { };
-            }
-        }
-
-        // we didn't find a suitable slot - return empty result
-        if (res.idxs[s].size() < n_tokens) {
-            return { };
-        }
-    }
-
-    assert(res.s1 >= res.s0);
-
-    return res;
-}
-
-void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
-    // keep track of the max sequence position that we would overwrite with this ubatch
-    // for non-SWA cache, this would be always empty
-    llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
-    for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
-        seq_pos_max_rm[s] = -1;
-    }
-
-    assert(ubatch.n_tokens == sinfo.n_stream()*sinfo.size());
-
-    for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
-        for (uint32_t ii = 0; ii < sinfo.size(); ++ii) {
-            const uint32_t i = s*sinfo.size() + ii;
-
-            auto & cells = v_cells[sinfo.strm[s]];
-
-            const auto idx = sinfo.idxs[s][ii];
-
-            if (!cells.is_empty(idx)) {
-                assert(cells.seq_count(idx) == 1);
-
-                const llama_seq_id seq_id = cells.seq_get(idx);
-                const llama_pos    pos    = cells.pos_get(idx);
-
-                seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
-
-                cells.rm(idx);
-            }
-
-            cells.pos_set(idx, ubatch.pos[i]);
-
-            for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
-                cells.seq_add(idx, ubatch.seq_id[i][s]);
-            }
-        }
-    }
-
-    // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
-    //       will be present in the cache. so we have to purge any position which is less than those we would overwrite
-    //       ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
-    for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
-        if (seq_pos_max_rm[s] == -1) {
-            continue;
-        }
-
-        GGML_ASSERT(s < seq_to_stream.size());
-
-        auto & cells = v_cells[seq_to_stream[s]];
-
-        if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
-            LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
-                    __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
-
-            seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
-        }
-    }
-
-    // move the head at the end of the slot
-    for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
-        auto & head = v_heads[sinfo.strm[s]];
-
-        head = sinfo.idxs[s].back() + 1;
-    }
-}
-
-bool llama_kv_cache_unified::get_can_shift() const {
-    return true;
-}
-
-uint32_t llama_kv_cache_unified::get_size() const {
-    const auto & cells = v_cells[seq_to_stream[0]];
-
-    return cells.size();
-}
-
-uint32_t llama_kv_cache_unified::get_n_stream() const {
-    return n_stream;
-}
-
-bool llama_kv_cache_unified::get_has_shift() const {
-    bool result = false;
-
-    for (uint32_t s = 0; s < n_stream; ++s) {
-        result |= v_cells[s].get_has_shift();
-    }
-
-    return result;
-}
-
-uint32_t llama_kv_cache_unified::get_n_kv() const {
-    uint32_t result = 0;
-
-    for (uint32_t s = 0; s < n_stream; ++s) {
-        const auto & cells = v_cells[s];
-
-        result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result);
-    }
-
-    return result;
-}
-
-bool llama_kv_cache_unified::get_supports_set_rows() const {
-    return supports_set_rows;
-}
-
-ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
-    const int32_t ikv = map_layer_ids.at(il);
-
-    auto * k = layers[ikv].k;
-
-    const uint64_t kv_size      = get_size();
-    const uint64_t n_embd_k_gqa = k->ne[0];
-
-    assert(n_embd_k_gqa == hparams.n_embd_k_gqa(il));
-
-    const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
-
-    return ggml_view_4d(ctx, k,
-            hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, ns,
-            ggml_row_size(k->type, hparams.n_embd_head_k),
-            ggml_row_size(k->type, n_embd_k_gqa),
-            ggml_row_size(k->type, n_embd_k_gqa*kv_size),
-            ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0);
-}
-
-ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
-    const int32_t ikv = map_layer_ids.at(il);
-
-    auto * v = layers[ikv].v;
-
-    const uint64_t kv_size      = get_size();
-    const uint64_t n_embd_v_gqa = v->ne[0];
-
-    // [TAG_V_CACHE_VARIABLE]
-    assert(n_embd_v_gqa >= hparams.n_embd_v_gqa(il));
-
-    const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
-
-    if (!v_trans) {
-        // note: v->nb[1] <= v->nb[2]
-        return ggml_view_4d(ctx, v,
-                hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns,
-                ggml_row_size(v->type, hparams.n_embd_head_v),            // v->nb[1]
-                ggml_row_size(v->type, n_embd_v_gqa),         // v->nb[2]
-                ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3]
-                ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0);
-    }
-
-    // note: v->nb[1] > v->nb[2]
-    return ggml_view_4d(ctx, v,
-            n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns,
-            ggml_row_size(v->type, kv_size*hparams.n_embd_head_v),    // v->nb[1]
-            ggml_row_size(v->type, kv_size),                          // v->nb[2]
-            ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3]
-            ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0);
-}
-
-ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
-    const int32_t ikv = map_layer_ids.at(il);
-
-    auto * k = layers[ikv].k;
-
-    const int64_t n_embd_k_gqa = k->ne[0];
-    const int64_t n_tokens = k_cur->ne[2];
-
-    k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
-
-    if (k_idxs && supports_set_rows) {
-        if (k->ne[2] > 1) {
-            k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
-        }
-
-        return ggml_set_rows(ctx, k, k_cur, k_idxs);
-    }
-
-    // TODO: fallback to old ggml_cpy() method for backwards compatibility
-    //       will be removed when ggml_set_rows() is adopted by all backends
-
-    GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS");
-
-    ggml_tensor * k_view = ggml_view_1d(ctx, k,
-            n_tokens*n_embd_k_gqa,
-            ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head());
-
-    return ggml_cpy(ctx, k_cur, k_view);
-}
-
-ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
-    const int32_t ikv = map_layer_ids.at(il);
-
-    auto * v = layers[ikv].v;
-
-    const int64_t n_embd_v_gqa = v_cur->ne[0]*v_cur->ne[1];
-    const int64_t n_tokens     = v_cur->ne[2];
-
-    v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
-
-    if (v_idxs && supports_set_rows) {
-        if (!v_trans) {
-            if (v->ne[2] > 1) {
-                v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
-            }
-
-            return ggml_set_rows(ctx, v, v_cur, v_idxs);
-        }
-
-        // [TAG_V_CACHE_VARIABLE]
-        if (n_embd_v_gqa < v->ne[0]) {
-            v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0);
-        }
-
-        // the row becomes a single element
-        ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);
-
-        v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]);
-
-        return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
-    }
-
-    // TODO: fallback to old ggml_cpy() method for backwards compatibility
-    //       will be removed when ggml_set_rows() is adopted by all backends
-
-    GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS");
-
-    ggml_tensor * v_view = nullptr;
-
-    if (!v_trans) {
-        v_view = ggml_view_1d(ctx, v,
-                n_tokens*n_embd_v_gqa,
-                ggml_row_size(v->type, n_embd_v_gqa)*sinfo.head());
-    } else {
-        v_cur = ggml_transpose(ctx, v_cur);
-
-        v_view = ggml_view_2d(ctx, v, n_tokens, n_embd_v_gqa,
-                (v->ne[1]    )*ggml_element_size(v),
-                (sinfo.head())*ggml_element_size(v));
-    }
-
-    return ggml_cpy(ctx, v_cur, v_view);
-}
-
-ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
-    const uint32_t n_tokens = ubatch.n_tokens;
-
-    ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
-
-    ggml_set_input(k_idxs);
-
-    return k_idxs;
-}
-
-ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
-    const uint32_t n_tokens = ubatch.n_tokens;
-
-    ggml_tensor * v_idxs;
-
-    if (!v_trans) {
-        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
-    } else {
-        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa_max());
-    }
-
-    ggml_set_input(v_idxs);
-
-    return v_idxs;
-}
-
-void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
-    if (!supports_set_rows) {
-        return;
-    }
-
-    const uint32_t n_tokens = ubatch->n_tokens;
-    GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());
-
-    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
-    int64_t * data = (int64_t *) dst->data;
-
-    for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
-        const int64_t offs = sinfo.strm[s]*get_size();
-
-        for (uint32_t i = 0; i < sinfo.size(); ++i) {
-            data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
-        }
-    }
-}
-
-void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
-    if (!supports_set_rows) {
-        return;
-    }
-
-    const uint32_t n_tokens = ubatch->n_tokens;
-    GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());
-
-    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
-    int64_t * data = (int64_t *) dst->data;
-
-    if (!v_trans) {
-        for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
-            const int64_t offs = sinfo.strm[s]*get_size();
-
-            for (uint32_t i = 0; i < sinfo.size(); ++i) {
-                data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
-            }
-        }
-    } else {
-        // note: the V cache is transposed when not using flash attention
-        const int64_t kv_size = get_size();
-
-        const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa_max();
-
-        for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
-            const int64_t offs = sinfo.strm[s]*kv_size*n_embd_v_gqa;
-
-            for (uint32_t i = 0; i < sinfo.size(); ++i) {
-                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                    data[s*sinfo.size()*n_embd_v_gqa + i*n_embd_v_gqa + j] = offs + j*kv_size + sinfo.idxs[s][i];
-                }
-            }
-        }
-    }
-}
-
-void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
-    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
-
-    int32_t * data = (int32_t *) dst->data;
-
-    for (uint32_t s = 0; s < n_stream; ++s) {
-        const auto & cells = v_cells[s];
-
-        for (uint32_t i = 0; i < cells.size(); ++i) {
-            data[s*cells.size() + i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
-        }
-    }
-}
-
-void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
-    const uint32_t n_tokens = ubatch->n_tokens;
-
-    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
-    float * data = (float *) dst->data;
-
-    const int64_t n_kv     = dst->ne[0];
-    const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch
-
-    GGML_ASSERT(n_tokens%n_stream == 0);
-
-    // n_tps == n_tokens_per_stream
-    const int64_t n_tps     = n_tokens/n_stream;
-    const int64_t n_tps_pad = GGML_PAD(n_tps, GGML_KQ_MASK_PAD);
-
-    std::fill(data, data + ggml_nelements(dst), -INFINITY);
-
-    // Use only the previous KV cells of the correct sequence for each token of the ubatch.
-    // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
-    // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
-    //   Causal mask:
-    //      xxx-------
-    //      xxxx------
-    //      xxxxx-----
-    //   Non-causal mask:
-    //      xxxxx-----
-    //      xxxxx-----
-    //      xxxxx-----
-    // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
-    // TODO: optimize this section
-    for (uint32_t h = 0; h < 1; ++h) {
-        for (uint32_t s = 0; s < n_stream; ++s) {
-            for (uint32_t ii = 0; ii < n_tps; ++ii) {
-                const uint32_t i = s*n_tps + ii;
-
-                const llama_seq_id seq_id = ubatch->seq_id[i][0];
-
-                const auto & cells = v_cells[seq_to_stream[seq_id]];
-
-                const llama_pos p1 = ubatch->pos[i];
-
-                const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
-
-                for (uint32_t j = 0; j < n_kv; ++j) {
-                    if (cells.is_empty(j)) {
-                        continue;
-                    }
-
-                    // mask the token if not the same sequence
-                    if (!cells.seq_has(j, seq_id)) {
-                        continue;
-                    }
-
-                    const llama_pos p0 = cells.pos_get(j);
-
-                    // mask future tokens
-                    if (causal_attn && p0 > p1) {
-                        continue;
-                    }
-
-                    // apply SWA if any
-                    if (is_masked_swa(p0, p1)) {
-                        continue;
-                    }
-
-                    data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
-                }
-            }
-        }
-    }
-}
-
-void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
-    const int64_t n_tokens = ubatch->n_tokens;
-
-    GGML_ASSERT(n_stream == 1 && "TODO: support multiple streams");
-    const auto & cells = v_cells[0];
-
-    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
-    GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
-
-    int32_t * data = (int32_t *) dst->data;
-
-    const int32_t n_kv = dst->ne[0];
-
-    for (int h = 0; h < 1; ++h) {
-        for (int i = 0; i < n_tokens; ++i) {
-            for (int j = 0; j < n_kv; ++j) {
-                // the position when the cells is empty is irrelevant - it will be masked out later in the attention
-                const llama_pos p0 = cells.is_empty(j) ? -1 : cells.pos_get(j);
-
-                data[h*(n_kv*n_tokens) + i*n_kv + j] = llama_relative_position_bucket(p0, ubatch->pos[i], hparams.n_rel_attn_bkts, false);
-            }
-        }
-    }
-}
-
-size_t llama_kv_cache_unified::total_size() const {
-    size_t size = 0;
-
-    for (const auto & buf : bufs) {
-        size += ggml_backend_buffer_get_size(buf.get());
-    }
-
-    return size;
-}
-
-size_t llama_kv_cache_unified::size_k_bytes() const {
-    size_t size_k_bytes = 0;
-
-    for (const auto & layer : layers) {
-        size_k_bytes += ggml_nbytes(layer.k);
-    }
-
-    return size_k_bytes;
-}
-
-size_t llama_kv_cache_unified::size_v_bytes() const {
-    size_t size_v_bytes = 0;
-
-    for (const auto & layer : layers) {
-        size_v_bytes += ggml_nbytes(layer.v);
-    }
-
-    return size_v_bytes;
-}
-
-ggml_tensor * llama_kv_cache_unified::build_rope_shift(
-        const llama_cparams & cparams,
-               ggml_context * ctx,
-                ggml_tensor * cur,
-                ggml_tensor * shift,
-                ggml_tensor * factors,
-                      float   freq_base,
-                      float   freq_scale) const {
-    const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
-
-    const auto & yarn_ext_factor = cparams.yarn_ext_factor;
-    const auto & yarn_beta_fast  = cparams.yarn_beta_fast;
-    const auto & yarn_beta_slow  = cparams.yarn_beta_slow;
-
-    const auto & n_rot     = hparams.n_rot;
-    const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE
-                                // @ngxson : this is a workaround
-                                // for M-RoPE, we want to rotate the whole vector when doing KV shift
-                                // a normal RoPE should work, we just need to use the correct ordering
-                                // ref: https://github.com/ggml-org/llama.cpp/pull/13870
-                                ? LLAMA_ROPE_TYPE_NEOX
-                                : hparams.rope_type;
-
-    // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
-    // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
-    const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2
-                                    ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale))
-                                    : cparams.yarn_attn_factor;
-
-    ggml_tensor * tmp;
-
-    if (ggml_is_quantized(cur->type)) {
-        // dequantize to f32 -> RoPE -> quantize back
-        tmp = ggml_cast(ctx, cur, GGML_TYPE_F32);
-
-        tmp = ggml_rope_ext(ctx, tmp,
-                shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
-
-        tmp = ggml_cpy(ctx, tmp, cur);
-    } else {
-        // we rotate only the first n_rot dimensions
-        tmp = ggml_rope_ext_inplace(ctx, cur,
-                shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
-    }
-
-    return tmp;
-}
-
-class llm_graph_input_k_shift : public llm_graph_input_i {
-public:
-    llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
-    virtual ~llm_graph_input_k_shift() = default;
-
-    void set_input(const llama_ubatch * ubatch) override;
-
-    ggml_tensor * k_shift; // I32 [kv_size*n_stream]
-
-    const llama_kv_cache_unified * kv_self;
-};
-
-void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
-    GGML_UNUSED(ubatch);
-
-    if (k_shift) {
-        kv_self->set_input_k_shift(k_shift);
-    }
-}
-
-ggml_cgraph * llama_kv_cache_unified::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
-    auto * ctx = res->get_ctx();
-    auto * gf  = res->get_gf();
-
-    const auto & n_embd_head_k = hparams.n_embd_head_k;
-  //const auto & n_embd_head_v = hparams.n_embd_head_v;
-
-    auto inp = std::make_unique<llm_graph_input_k_shift>(this);
-
-    inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
-    ggml_set_input(inp->k_shift);
-
-    const auto & cparams = lctx->get_cparams();
-
-    for (const auto & layer : layers) {
-        const uint32_t il = layer.il;
-
-        const int64_t n_head_kv    = hparams.n_head_kv(il);
-        const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-
-        const float freq_base_l  = model.get_rope_freq_base (cparams, il);
-        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
-
-        ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
-
-        ggml_tensor * k =
-            ggml_view_3d(ctx, layer.k,
-                n_embd_head_k, n_head_kv, get_size()*n_stream,
-                ggml_row_size(layer.k->type, n_embd_head_k),
-                ggml_row_size(layer.k->type, n_embd_k_gqa),
-                0);
-
-        ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
-
-        ggml_build_forward_expand(gf, cur);
-    }
-
-    res->add_input(std::move(inp));
-
-    return gf;
-}
-
-ggml_cgraph * llama_kv_cache_unified::build_graph_defrag(
-         llm_graph_result * res,
-            llama_context * lctx,
-        const defrag_info & dinfo) const {
-    auto * ctx = res->get_ctx();
-    auto * gf  = res->get_gf();
-
-    GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag");
-
-    const auto & cells = v_cells[0];
-
-    const auto & ids = dinfo.ids;
-
-    const auto & cparams = lctx->get_cparams();
-
-#if 0
-    // CPU defrag
-    //
-    // TODO: optimizations are possible:
-    //       - multiple threads
-    //       - avoid copying to the host memory when already there
-    //
-    // likely not worth the effort, as we have ggml_graph based defrag
-    //
-
-    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
-    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
-
-    const uint32_t kv_size = size;
-
-    std::vector<uint8_t> buf_k;
-    std::vector<uint8_t> buf_v;
-
-    for (uint32_t il = 0; il < n_layer; ++il) {
-        const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
-        const size_t k_size     = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size);
-
-        const size_t v_size_el = ggml_type_size(v_l[il]->type);
-        const size_t v_size    = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size);
-
-        buf_k.resize(k_size);
-        buf_v.resize(v_size);
-
-        ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size());
-        ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size());
-
-        // batch move [i, i+nm) to [id, id+nm)
-        // note: cells can move only to a lower index
-        for (uint32_t i = 0; i < n_kv; ++i) {
-            const uint32_t id = ids[i];
-
-            if (i == id || id == n_kv) {
-                continue;
-            }
-
-            uint32_t nm = 1;
-
-            while (i + nm < n_kv && ids[i + nm] == id + nm) {
-                nm++;
-            }
-
-            // move keys
-            {
-                const int64_t os =  i*k_size_row;
-                const int64_t od = id*k_size_row;
-
-                memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
-            }
-
-            // move values (note: they are transposed)
-            {
-                const int64_t os =  i;
-                const int64_t od = id;
-
-                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                    memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
-                }
-            }
-
-            i += nm - 1;
-        }
-
-        ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size());
-        ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
-    }
-#else
-    for (uint32_t i = 0; i < ids.size(); ++i) {
-        const uint32_t id = ids[i];
-
-        if (i == id || id == ids.size()) {
-            continue;
-        }
-
-        uint32_t nm = 1;
-
-        while (i + nm < ids.size() && ids[i + nm] == id + nm) {
-            nm++;
-        }
-
-        for (const auto & layer : layers) {
-            const uint32_t il = layer.il;
-
-            const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-            const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-
-            ggml_tensor * view_k_src = ggml_view_2d(ctx, layer.k,
-                    n_embd_k_gqa, nm,
-                    ggml_row_size(layer.k->type, n_embd_k_gqa),
-                    ggml_row_size(layer.k->type, n_embd_k_gqa*i));
-
-            ggml_tensor * view_k_dst = ggml_view_2d(ctx, layer.k,
-                    n_embd_k_gqa, nm,
-                    ggml_row_size(layer.k->type, n_embd_k_gqa),
-                    ggml_row_size(layer.k->type, n_embd_k_gqa*id));
-
-            ggml_tensor * view_v_src;
-            ggml_tensor * view_v_dst;
-
-            if (cparams.flash_attn) {
-                // NOTE: the V cache is not transposed when using flash attention
-                view_v_src = ggml_view_2d(ctx, layer.v,
-                        n_embd_v_gqa, nm,
-                        ggml_row_size(layer.v->type, n_embd_v_gqa),
-                        ggml_row_size(layer.v->type, n_embd_v_gqa*i));
-
-                view_v_dst = ggml_view_2d(ctx, layer.v,
-                        n_embd_v_gqa, nm,
-                        ggml_row_size(layer.v->type, n_embd_v_gqa),
-                        ggml_row_size(layer.v->type, n_embd_v_gqa*id));
-            } else {
-                view_v_src = ggml_view_2d(ctx, layer.v,
-                        nm, n_embd_v_gqa,
-                        ggml_row_size(layer.v->type, cells.size()),
-                        ggml_row_size(layer.v->type, i));
-
-                view_v_dst = ggml_view_2d(ctx, layer.v,
-                        nm, n_embd_v_gqa,
-                        ggml_row_size(layer.v->type, cells.size()),
-                        ggml_row_size(layer.v->type, id));
-            }
-
-            ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst));
-            ggml_build_forward_expand(gf, ggml_cpy(ctx, view_v_src, view_v_dst));
-        }
-
-        i += nm - 1;
-    }
-
-    //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
-#endif
-
-    return gf;
-}
-
-llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const {
-    GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag");
-
-    const auto & cells = v_cells[0];
-
-    const uint32_t n_layer = layers.size();
-
-    const uint32_t n_kv   = cells.used_max_p1();
-    const uint32_t n_used = cells.get_used();
-
-    assert(n_used <= n_kv);
-
-    //const int64_t t_start = ggml_time_us();
-
-    // number of cells moved
-    uint32_t n_moves = 0;
-
-    // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag)
-    //   - source view, destination view, copy operation
-    //   - x2 for keys and values
-    //const uint32_t max_moves = max_nodes()/(6*n_layer);
-    // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
-    const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
-
-    // determine which KV cells to move where
-    defrag_info res;
-    auto & ids = res.ids;
-
-    ids.resize(n_kv, n_kv);
-
-    for (uint32_t i0 = 0; i0 < n_used; ++i0) {
-        if (!cells.is_empty(i0)) {
-            ids[i0] = i0;
-
-            continue;
-        }
-
-        // found a hole - fill it with data from the end of the cache
-
-        uint32_t nh = 1;
-
-        // determine the size of the hole
-        while (i0 + nh < n_used && cells.is_empty(i0 + nh)) {
-            nh++;
-        }
-
-        uint32_t nf = 0;
-        uint32_t is = n_kv - 1;
-
-        // starting from the end, find nh non-empty cells
-        for (; is > i0; --is) {
-            if (cells.is_empty(is) || ids[is] != n_kv) {
-                continue;
-            }
-
-            // non-empty cell which is not yet moved
-            nf++;
-
-            if (nf == nh) {
-                break;
-            }
-        }
-
-        // this can only happen if `n_used` is not accurate, which would be a bug
-        GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh");
-
-        nf = 0;
-
-        uint32_t i1 = is;
-
-        // are we moving a continuous block of memory?
-        bool cont = false;
-
-        // should we stop searching for the next move?
-        bool stop = false;
-
-        // go back and move the nf cells to the hole
-        for (; i1 < n_kv; ++i1) {
-            if (cells.is_empty(i1) || ids[i1] != n_kv) {
-                if (n_moves == max_moves) {
-                    stop = true;
-                    break;
-                }
-
-                cont = false;
-                continue;
-            }
-
-            // this cell goes to (i0 + nf)
-            ids[i1] = i0 + nf;
-
-            if (!cont) {
-                n_moves++;
-                cont = true;
-            }
-
-            nf++;
-
-            if (nf == nh) {
-                break;
-            }
-        }
-
-        if (stop || n_moves == max_moves) {
-            break;
-        }
-
-        //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
-
-        i0 += nh - 1;
-    }
-
-    if (n_moves == 0) {
-        return {};
-    }
-
-    LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
-
-    LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);
-
-    return res;
-}
-
-bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
-    assert(p0 >= 0 && p1 >= 0);
-
-    switch (swa_type) {
-        case LLAMA_SWA_TYPE_NONE:
-            {
-            } break;
-        case LLAMA_SWA_TYPE_STANDARD:
-            {
-                if (p1 - p0 >= (int32_t) n_swa) {
-                    return true;
-                }
-            } break;
-        case LLAMA_SWA_TYPE_CHUNKED:
-            {
-                const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
-
-                if (p0 < pos_chunk_start) {
-                    return true;
-                }
-            } break;
-    }
-
-    return false;
-}
-
-void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
-    GGML_UNUSED(flags);
-
-    io.write(&n_stream, sizeof(n_stream));
-
-    for (uint32_t s = 0; s < n_stream; ++s) {
-        cell_ranges_t cr { s, {} };
-
-        uint32_t cell_count = 0;
-
-        const auto & cells = v_cells[s];
-
-        // Count the number of cells with the specified seq_id
-        // Find all the ranges of cells with this seq id (or all, when -1)
-        uint32_t cell_range_begin = cells.size();
-
-        for (uint32_t i = 0; i < cells.size(); ++i) {
-            if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) {
-                ++cell_count;
-                if (cell_range_begin == cells.size()) {
-                    cell_range_begin = i;
-                }
-            } else {
-                if (cell_range_begin != cells.size()) {
-                    cr.data.emplace_back(cell_range_begin, i);
-                    cell_range_begin = cells.size();
-                }
-            }
-        }
-
-        if (cell_range_begin != cells.size()) {
-            cr.data.emplace_back(cell_range_begin, cells.size());
-        }
-
-        // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
-        uint32_t cell_count_check = 0;
-        for (const auto & range : cr.data) {
-            cell_count_check += range.second - range.first;
-        }
-        GGML_ASSERT(cell_count == cell_count_check);
-
-        io.write(&cell_count, sizeof(cell_count));
-
-        // skip empty streams
-        if (cell_count == 0) {
-            continue;
-        }
-
-        state_write_meta(io, cr, seq_id);
-        state_write_data(io, cr);
-    }
-}
-
-void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
-    GGML_UNUSED(flags);
-
-    GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
-
-    uint32_t n_stream_cur;
-    io.read_to(&n_stream_cur, sizeof(n_stream_cur));
-    if (n_stream_cur != n_stream) {
-        throw std::runtime_error("n_stream mismatch");
-    }
-
-    for (uint32_t s = 0; s < n_stream; ++s) {
-        uint32_t cell_count;
-        io.read_to(&cell_count, sizeof(cell_count));
-
-        if (cell_count == 0) {
-            continue;
-        }
-
-        const uint32_t strm = seq_id == -1 ? s : seq_to_stream[seq_id];
-
-        bool res = true;
-        res = res && state_read_meta(io, strm, cell_count, seq_id);
-        res = res && state_read_data(io, strm, cell_count);
-
-        if (!res) {
-            if (seq_id == -1) {
-                clear(true);
-            } else {
-                seq_rm(seq_id, -1, -1);
-            }
-            throw std::runtime_error("failed to restore kv cache");
-        }
-    }
-}
-
-void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id) const {
-    const auto & cells = v_cells[cr.strm];
-
-    for (const auto & range : cr.data) {
-        for (uint32_t i = range.first; i < range.second; ++i) {
-            std::vector<llama_seq_id> seq_ids;
-
-            for (llama_seq_id cur = 0; cur < (int) n_seq_max; ++cur) {
-                if (cur == seq_id || seq_id == -1) {
-                    if (cells.seq_has(i, cur)) {
-                        seq_ids.push_back(cur);
-                    }
-                }
-            }
-
-            const llama_pos pos     = cells.pos_get(i);
-            const uint32_t n_seq_id = seq_ids.size();
-
-            io.write(&pos,      sizeof(pos));
-            io.write(&n_seq_id, sizeof(n_seq_id));
-
-            for (const auto & seq_id : seq_ids) {
-                io.write(&seq_id, sizeof(seq_id));
-            }
-        }
-    }
-}
-
-void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const {
-    const auto & cells = v_cells[cr.strm];
-
-    const uint32_t v_trans = this->v_trans ? 1 : 0;
-    const uint32_t n_layer = layers.size();
-
-    io.write(&v_trans, sizeof(v_trans));
-    io.write(&n_layer, sizeof(n_layer));
-
-    std::vector<uint8_t> tmp_buf;
-
-    // Iterate and write all the keys first, each row is a cell
-    // Get whole range at a time
-    for (const auto & layer : layers) {
-        const uint32_t il = layer.il;
-
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-
-        auto * k = layer.k_stream[cr.strm];
-
-        // Write key type
-        const int32_t k_type_i = (int32_t) k->type;
-        io.write(&k_type_i, sizeof(k_type_i));
-
-        // Write row size of key
-        const uint64_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
-        io.write(&k_size_row, sizeof(k_size_row));
-
-        // Read each range of cells of k_size length each into tmp_buf and write out
-        for (const auto & range : cr.data) {
-            const size_t range_size = range.second - range.first;
-            const size_t buf_size = range_size * k_size_row;
-            io.write_tensor(k, range.first * k_size_row, buf_size);
-        }
-    }
-
-    if (!v_trans) {
-        for (const auto & layer : layers) {
-            const uint32_t il = layer.il;
-
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-
-            auto * v = layer.v_stream[cr.strm];
-
-            // Write value type
-            const int32_t v_type_i = (int32_t) v->type;
-            io.write(&v_type_i, sizeof(v_type_i));
-
-            // Write row size of value
-            const uint64_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
-            io.write(&v_size_row, sizeof(v_size_row));
-
-            // Read each range of cells of v_size length each into tmp_buf and write out
-            for (const auto & range : cr.data) {
-                const size_t range_size = range.second - range.first;
-                const size_t buf_size = range_size * v_size_row;
-                io.write_tensor(v, range.first * v_size_row, buf_size);
-            }
-        }
-    } else {
-        // When v is transposed, we also need the element size and get the element ranges from each row
-        const uint32_t kv_size = cells.size();
-
-        for (const auto & layer : layers) {
-            const uint32_t il = layer.il;
-
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-
-            auto * v = layer.v_stream[cr.strm];
-
-            // Write value type
-            const int32_t v_type_i = (int32_t) v->type;
-            io.write(&v_type_i, sizeof(v_type_i));
-
-            // Write element size
-            const uint32_t v_size_el = ggml_type_size(v->type);
-            io.write(&v_size_el, sizeof(v_size_el));
-
-            // Write GQA embedding size
-            io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
-
-            // For each row, we get the element values of each cell
-            for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                // Read each range of cells of v_size_el length each into tmp_buf and write out
-                for (const auto & range : cr.data) {
-                    const size_t range_size = range.second - range.first;
-                    const size_t src_offset = (range.first + j * kv_size) * v_size_el;
-                    const size_t buf_size = range_size * v_size_el;
-                    io.write_tensor(v, src_offset, buf_size);
-                }
-            }
-        }
-    }
-}
-
-bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) {
-    auto & cells = v_cells[strm];
-    auto & head  = v_heads[strm];
-
-    if (dest_seq_id != -1) {
-        // single sequence
-        seq_rm(dest_seq_id, -1, -1);
-
-        llama_batch_allocr balloc(hparams.n_pos_per_embd());
-
-        llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
-
-        ubatch.seq_id_unq[0] = dest_seq_id;
-
-        for (uint32_t i = 0; i < cell_count; ++i) {
-            llama_pos pos;
-            uint32_t n_seq_id;
-
-            io.read_to(&pos,      sizeof(pos));
-            io.read_to(&n_seq_id, sizeof(n_seq_id));
-
-            if (n_seq_id != 1) {
-                LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
-                return false;
-            }
-
-            // read the sequence id, but directly discard it - we will use dest_seq_id instead
-            {
-                llama_seq_id seq_id;
-                io.read_to(&seq_id, sizeof(seq_id));
-            }
-
-            ubatch.pos[i]      = pos;
-            ubatch.n_seq_id[i] = n_seq_id;
-            ubatch.seq_id[i]   = &dest_seq_id;
-        }
-
-        const auto sinfo = find_slot(ubatch, true);
-        if (sinfo.empty()) {
-            LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
-            return false;
-        }
-
-        apply_ubatch(sinfo, ubatch);
-
-        const auto head_cur = sinfo.head();
-
-        // keep the head at the old position because we will read the KV data into it in state_read_data()
-        head = head_cur;
-
-        LLAMA_LOG_DEBUG("%s: head_cur = %d, head = %d, cell_count = %d, dest_seq_id = %d\n", __func__, head_cur, head, cell_count, dest_seq_id);
-
-        // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
-        // Assume that this is one contiguous block of cells
-        GGML_ASSERT(head_cur + cell_count <= cells.size());
-        GGML_ASSERT(cells.pos_get(head_cur)                  == ubatch.pos[0]);
-        GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == ubatch.pos[cell_count - 1]);
-        GGML_ASSERT(cells.seq_has(head_cur,                  dest_seq_id));
-        GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id));
-    } else {
-        // whole KV cache restore
-
-        if (cell_count > cells.size()) {
-            LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
-            return false;
-        }
-
-        clear(true);
-
-        for (uint32_t i = 0; i < cell_count; ++i) {
-            llama_pos pos;
-            uint32_t  n_seq_id;
-
-            io.read_to(&pos,      sizeof(pos));
-            io.read_to(&n_seq_id, sizeof(n_seq_id));
-
-            cells.pos_set(i, pos);
-
-            for (uint32_t j = 0; j < n_seq_id; ++j) {
-                llama_seq_id seq_id;
-                io.read_to(&seq_id, sizeof(seq_id));
-
-                if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) {
-                    LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max);
-                    return false;
-                }
-
-                cells.seq_add(i, seq_id);
-            }
-        }
-
-        head = 0;
-    }
-
-    return true;
-}
-
-bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) {
-    auto & cells = v_cells[strm];
-    auto & head  = v_heads[strm];
-
-    uint32_t v_trans;
-    uint32_t n_layer;
-
-    io.read_to(&v_trans, sizeof(v_trans));
-    io.read_to(&n_layer, sizeof(n_layer));
-
-    if (n_layer != layers.size()) {
-        LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size());
-        return false;
-    }
-
-    if (cell_count > cells.size()) {
-        LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, cells.size());
-        return false;
-    }
-
-    if (this->v_trans != (bool) v_trans) {
-        LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
-        return false;
-    }
-
-    // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
-    for (const auto & layer : layers) {
-        const uint32_t il = layer.il;
-
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-
-        auto * k = layer.k_stream[strm];
-
-        // Read type of key
-        int32_t k_type_i_ref;
-        io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
-        const int32_t k_type_i = (int32_t) k->type;
-        if (k_type_i != k_type_i_ref) {
-            LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
-            return false;
-        }
-
-        // Read row size of key
-        uint64_t k_size_row_ref;
-        io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
-        const size_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
-        if (k_size_row != k_size_row_ref) {
-            LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
-            return false;
-        }
-
-        if (cell_count) {
-            // Read and set the keys for the whole cell range
-            ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
-        }
-    }
-
-    if (!this->v_trans) {
-        for (const auto & layer : layers) {
-            const uint32_t il = layer.il;
-
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-
-            auto * v = layer.v_stream[strm];
-
-            // Read type of value
-            int32_t v_type_i_ref;
-            io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
-            const int32_t v_type_i = (int32_t) v->type;
-            if (v_type_i != v_type_i_ref) {
-                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
-                return false;
-            }
-
-            // Read row size of value
-            uint64_t v_size_row_ref;
-            io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
-            const size_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
-            if (v_size_row != v_size_row_ref) {
-                LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
-                return false;
-            }
-
-            if (cell_count) {
-                // Read and set the values for the whole cell range
-                ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
-            }
-        }
-    } else {
-        // For each layer, read the values for each cell (transposed)
-        for (const auto & layer : layers) {
-            const uint32_t il = layer.il;
-
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-
-            auto * v = layer.v_stream[strm];
-
-            // Read type of value
-            int32_t v_type_i_ref;
-            io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
-            const int32_t v_type_i = (int32_t) v->type;
-            if (v_type_i != v_type_i_ref) {
-                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
-                return false;
-            }
-
-            // Read element size of value
-            uint32_t v_size_el_ref;
-            io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
-            const size_t v_size_el = ggml_type_size(v->type);
-            if (v_size_el != v_size_el_ref) {
-                LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
-                return false;
-            }
-
-            // Read GQA embedding size
-            uint32_t n_embd_v_gqa_ref;
-            io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
-            if (n_embd_v_gqa != n_embd_v_gqa_ref) {
-                LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
-                return false;
-            }
-
-            if (cell_count) {
-                // For each row in the transposed matrix, read the values for the whole cell range
-                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                    const size_t dst_offset = (head + j * cells.size()) * v_size_el;
-                    ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
-                }
-            }
-        }
-    }
-
-    return true;
-}
-
-//
-// llama_kv_cache_unified_context
-//
-
-llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_status status) : status(status) {}
-
-llama_kv_cache_unified_context::llama_kv_cache_unified_context(
-        llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
-    n_kv = kv->get_size();
-
-    const uint32_t n_stream = kv->get_n_stream();
-
-    // create a dummy slot info - the actual data is irrelevant. we just need to build the graph
-    sinfos.resize(1);
-    sinfos[0].s0 = 0;
-    sinfos[0].s1 = n_stream - 1;
-    sinfos[0].idxs.resize(n_stream);
-    for (uint32_t s = 0; s < n_stream; ++s) {
-        sinfos[0].strm.push_back(s);
-        sinfos[0].idxs[s].resize(1, 0);
-    }
-}
-
-llama_kv_cache_unified_context::llama_kv_cache_unified_context(
-        llama_kv_cache_unified * kv,
-        llama_context * lctx,
-        bool do_shift,
-        defrag_info dinfo,
-        stream_copy_info sc_info) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)), sc_info(std::move(sc_info)) {
-    if (!do_shift && this->dinfo.empty() && this->sc_info.empty()) {
-        status = LLAMA_MEMORY_STATUS_NO_UPDATE;
-    }
-}
-
-llama_kv_cache_unified_context::llama_kv_cache_unified_context(
-        llama_kv_cache_unified * kv,
-        llama_kv_cache_unified::slot_info_vec_t sinfos,
-        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) {
-}
-
-llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
-
-bool llama_kv_cache_unified_context::next() {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
-
-    if (++i_cur >= ubatches.size()) {
-        return false;
-    }
-
-    return true;
-}
-
-bool llama_kv_cache_unified_context::apply() {
-    assert(!llama_memory_status_is_fail(status));
-
-    // no ubatches -> this is a KV cache update
-    if (ubatches.empty()) {
-        kv->update(lctx, do_shift, dinfo, sc_info);
-
-        return true;
-    }
-
-    kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]);
-
-    n_kv = kv->get_n_kv();
-
-    return true;
-}
-
-llama_memory_status llama_kv_cache_unified_context::get_status() const {
-    return status;
-}
-
-const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
-
-    return ubatches[i_cur];
-}
-
-uint32_t llama_kv_cache_unified_context::get_n_kv() const {
-    return n_kv;
-}
-
-bool llama_kv_cache_unified_context::get_supports_set_rows() const {
-    return kv->get_supports_set_rows();
-}
-
-ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const {
-    return kv->get_k(ctx, il, n_kv, sinfos[i_cur]);
-}
-
-ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const {
-    return kv->get_v(ctx, il, n_kv, sinfos[i_cur]);
-}
-
-ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
-    return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]);
-}
-
-ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const {
-    return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
-}
-
-ggml_tensor * llama_kv_cache_unified_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
-    return kv->build_input_k_idxs(ctx, ubatch);
-}
-
-ggml_tensor * llama_kv_cache_unified_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
-    return kv->build_input_v_idxs(ctx, ubatch);
-}
-
-void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
-    kv->set_input_k_shift(dst);
-}
-
-void llama_kv_cache_unified_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
-    kv->set_input_k_idxs(dst, ubatch, sinfos[i_cur]);
-}
-
-void llama_kv_cache_unified_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
-    kv->set_input_v_idxs(dst, ubatch, sinfos[i_cur]);
-}
-
-void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
-    kv->set_input_kq_mask(dst, ubatch, causal_attn);
-}
-
-void llama_kv_cache_unified_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
-    kv->set_input_pos_bucket(dst, ubatch);
-}
-
-uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) {
-    // the FA kernels require padding to avoid extra runtime boundary checks
-    return cparams.flash_attn ? 256u : 32u;
-}
diff --git a/src/llama-kv-cache-unified.h b/src/llama-kv-cache-unified.h
deleted file mode 100644 (file)
index 07a7c9e..0000000
+++ /dev/null
@@ -1,399 +0,0 @@
-#pragma once
-
-#include "llama-batch.h"
-#include "llama-graph.h"
-#include "llama-kv-cells.h"
-#include "llama-memory.h"
-
-#include <unordered_map>
-#include <vector>
-
-struct llama_cparams;
-struct llama_hparams;
-struct llama_model;
-struct llama_context;
-
-//
-// llama_kv_cache_unified
-//
-
-class llama_kv_cache_unified : public llama_memory_i {
-public:
-    static uint32_t get_padding(const llama_cparams & cparams);
-
-    // this callback is used to filter out layers that should not be included in the cache
-    using layer_filter_cb = std::function<bool(int32_t il)>;
-
-    struct defrag_info {
-        bool empty() const {
-            return ids.empty();
-        }
-
-        // contains information about which cell moves where:
-        //  - cell i moves to ids[i]
-        //  - if ids[i] == i || ids[i] == ids.size(), then cell i is not moved
-        std::vector<uint32_t> ids;
-    };
-
-    struct stream_copy_info {
-        bool empty() const {
-            assert(ssrc.size() == sdst.size());
-            return ssrc.empty();
-        }
-
-        std::vector<uint32_t> ssrc;
-        std::vector<uint32_t> sdst;
-    };
-
-    // for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the
-    //   KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]]
-    struct slot_info {
-        // data for ggml_set_rows
-        using idx_vec_t = std::vector<uint32_t>;
-
-        // number of streams: ns = s1 - s0 + 1
-        llama_seq_id s0;
-        llama_seq_id s1;
-
-        std::vector<llama_seq_id> strm; // [ns]
-        std::vector<idx_vec_t>    idxs; // [ns]
-
-        uint32_t head() const {
-            GGML_ASSERT(idxs.size() == 1);
-            GGML_ASSERT(!idxs[0].empty());
-
-            return idxs[0][0];
-        }
-
-        void resize(size_t n) {
-            strm.resize(n);
-            idxs.resize(n);
-        }
-
-        size_t size() const {
-            GGML_ASSERT(idxs.size() == strm.size());
-            GGML_ASSERT(!idxs.empty());
-
-            return idxs[0].size();
-        }
-
-        size_t n_stream() const {
-            return strm.size();
-        }
-
-        bool empty() const {
-            return idxs.empty();
-        }
-
-        void clear() {
-            idxs.clear();
-        }
-    };
-
-    using slot_info_vec_t = std::vector<slot_info>;
-
-    llama_kv_cache_unified(
-            const llama_model &  model,
-              layer_filter_cb && filter,
-                    ggml_type    type_k,
-                    ggml_type    type_v,
-                         bool    v_trans,
-                         bool    offload,
-                         bool    unified,
-                     uint32_t    kv_size,
-                     uint32_t    n_seq_max,
-                     uint32_t    n_pad,
-                     uint32_t    n_swa,
-               llama_swa_type    swa_type);
-
-    ~llama_kv_cache_unified() = default;
-
-    //
-    // llama_memory_i
-    //
-
-    llama_memory_context_ptr init_batch(
-            llama_batch_allocr & balloc,
-            uint32_t n_ubatch,
-            bool embd_all) override;
-
-    llama_memory_context_ptr init_full() override;
-
-    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
-
-    bool get_can_shift() const override;
-
-    void clear(bool data) override;
-
-    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
-    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
-    void seq_keep(llama_seq_id seq_id)                                                          override;
-    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
-    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
-
-    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
-    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
-
-    // state write/load
-
-    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
-    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
-
-    //
-    // llama_kv_cache_unified specific API
-    //
-
-    uint32_t get_size()     const;
-    uint32_t get_n_stream() const;
-
-    bool get_has_shift() const;
-
-    //
-    // graph_build API
-    //
-
-    uint32_t get_n_kv() const;
-
-    // TODO: temporary
-    bool get_supports_set_rows() const;
-
-    // get views of the current state of the cache
-    ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
-    ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
-
-    // store k_cur and v_cur in the cache based on the provided head location
-    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
-    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const;
-
-    //
-    // preparation API
-    //
-
-    // find places for the provided ubatches in the cache, returns the slot infos
-    // return empty vector on failure
-    slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);
-
-    bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info);
-
-    // find a slot of kv cells that can hold the ubatch
-    // if cont == true, then the slot must be continuous
-    // return empty slot_info on failure
-    slot_info find_slot(const llama_ubatch & ubatch, bool cont) const;
-
-    // emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]]
-    void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch);
-
-    //
-    // input API
-    //
-
-    ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
-    ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
-
-    void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
-    void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
-
-    void set_input_k_shift(ggml_tensor * dst) const;
-
-    void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
-    void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
-
-private:
-    const llama_model & model;
-    const llama_hparams & hparams;
-
-    struct kv_layer {
-        // layer index in the model
-        // note: can be different from the layer index in the KV cache
-        uint32_t il;
-
-        ggml_tensor * k;
-        ggml_tensor * v;
-
-        std::vector<ggml_tensor *> k_stream;
-        std::vector<ggml_tensor *> v_stream;
-    };
-
-    bool v_trans = true;  // the value tensor is transposed
-
-    const uint32_t n_seq_max = 1;
-    const uint32_t n_stream  = 1;
-
-    // required padding
-    const uint32_t n_pad = 1;
-
-    // SWA
-    const uint32_t n_swa = 0;
-
-    // env: LLAMA_KV_CACHE_DEBUG
-    int debug = 0;
-
-    // env: LLAMA_SET_ROWS (temporary)
-    // ref: https://github.com/ggml-org/llama.cpp/pull/14285
-    bool supports_set_rows = true;
-
-    const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
-
-    std::vector<ggml_context_ptr>        ctxs;
-    std::vector<ggml_backend_buffer_ptr> bufs;
-
-    // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
-    // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
-    std::vector<uint32_t> v_heads;
-
-    std::vector<llama_kv_cells_unified> v_cells;
-
-    // maps from a sequence id to a stream id
-    std::vector<uint32_t> seq_to_stream;
-
-    // pending stream copies that will be applied during the next update
-    stream_copy_info sc_info;
-
-    std::vector<kv_layer> layers;
-
-    // model layer id -> KV cache layer id
-    std::unordered_map<int32_t, int32_t> map_layer_ids;
-
-    // return non-empty vector if cells have been moved
-    defrag_info defrag_prepare(int32_t n_max_nodes) const;
-
-    size_t total_size() const;
-
-    size_t size_k_bytes() const;
-    size_t size_v_bytes() const;
-
-    bool is_masked_swa(llama_pos p0, llama_pos p1) const;
-
-    ggml_tensor * build_rope_shift(
-            const llama_cparams & cparams,
-                   ggml_context * ctx,
-                    ggml_tensor * cur,
-                    ggml_tensor * shift,
-                    ggml_tensor * factors,
-                          float   freq_base,
-                          float   freq_scale) const;
-
-    ggml_cgraph * build_graph_shift(
-               llm_graph_result * res,
-                  llama_context * lctx) const;
-
-    ggml_cgraph * build_graph_defrag(
-               llm_graph_result * res,
-                  llama_context * lctx,
-              const defrag_info & dinfo) const;
-
-    struct cell_ranges_t {
-        uint32_t strm;
-
-        std::vector<std::pair<uint32_t, uint32_t>> data; // ranges, from inclusive, to exclusive
-    };
-
-    void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const;
-    void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const;
-
-    bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
-    bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count);
-};
-
-class llama_kv_cache_unified_context : public llama_memory_context_i {
-public:
-    // some shorthands
-    using slot_info_vec_t  = llama_kv_cache_unified::slot_info_vec_t;
-    using defrag_info      = llama_kv_cache_unified::defrag_info;
-    using stream_copy_info = llama_kv_cache_unified::stream_copy_info;
-
-    // used for errors
-    llama_kv_cache_unified_context(llama_memory_status status);
-
-    // used to create a full-cache context
-    llama_kv_cache_unified_context(
-            llama_kv_cache_unified * kv);
-
-    // used to create an update context
-    llama_kv_cache_unified_context(
-            llama_kv_cache_unified * kv,
-            llama_context * lctx,
-            bool do_shift,
-            defrag_info dinfo,
-            stream_copy_info sc_info);
-
-    // used to create a batch procesing context from a batch
-    llama_kv_cache_unified_context(
-            llama_kv_cache_unified * kv,
-            slot_info_vec_t sinfos,
-            std::vector<llama_ubatch> ubatches);
-
-    virtual ~llama_kv_cache_unified_context();
-
-    //
-    // llama_memory_context_i
-    //
-
-    bool next()  override;
-    bool apply() override;
-
-    llama_memory_status  get_status() const override;
-    const llama_ubatch & get_ubatch() const override;
-
-    //
-    // llama_kv_cache_unified_context specific API
-    //
-
-    uint32_t get_n_kv() const;
-
-    // TODO: temporary
-    bool get_supports_set_rows() const;
-
-    // get views of the current state of the cache
-    ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
-    ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
-
-    // store k_cur and v_cur in the cache based on the provided head location
-    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
-    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;
-
-    ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
-    ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
-
-    void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
-    void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
-
-    void set_input_k_shift   (ggml_tensor * dst) const;
-    void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
-    void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
-
-private:
-    llama_memory_status status;
-
-    llama_kv_cache_unified * kv;
-    llama_context * lctx;
-
-    //
-    // update context
-    //
-
-    bool do_shift = false;
-
-    defrag_info dinfo;
-
-    stream_copy_info sc_info;
-
-    //
-    // batch processing context
-    //
-
-    // the index of the cur ubatch to process
-    size_t i_cur = 0;
-
-    slot_info_vec_t sinfos;
-
-    std::vector<llama_ubatch> ubatches;
-
-    //
-    // data needed for building the compute graph for the current ubatch:
-    //
-
-    // a heuristic, to avoid attending the full cache if it is not yet utilized
-    // as the cache gets filled, the benefit from this heuristic disappears
-    int32_t n_kv;
-};
diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
new file mode 100644 (file)
index 0000000..bb490cf
--- /dev/null
@@ -0,0 +1,2410 @@
+#include "llama-kv-cache.h"
+
+#include "llama-impl.h"
+#include "llama-io.h"
+#include "llama-model.h"
+#include "llama-context.h"
+
+#include <algorithm>
+#include <cassert>
+#include <cmath>
+#include <limits>
+#include <map>
+#include <stdexcept>
+
+//
+// llama_kv_cache
+//
+
+llama_kv_cache::llama_kv_cache(
+        const llama_model &  model,
+          layer_filter_cb && filter,
+                ggml_type    type_k,
+                ggml_type    type_v,
+                     bool    v_trans,
+                     bool    offload,
+                     bool    unified,
+                 uint32_t    kv_size,
+                 uint32_t    n_seq_max,
+                 uint32_t    n_pad,
+                 uint32_t    n_swa,
+           llama_swa_type    swa_type) :
+    model(model), hparams(model.hparams), v_trans(v_trans),
+    n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
+
+    GGML_ASSERT(kv_size % n_pad == 0);
+
+    // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
+    auto n_layer_cache = hparams.n_layer;
+    if (model.arch == LLM_ARCH_GEMMA3N) {
+        n_layer_cache = 20;
+    }
+    if (model.arch == LLM_ARCH_GLM4_MOE) {
+        // GLM-4.5: Only process up to last layer, skip final NextN layer
+        n_layer_cache = hparams.n_layer - hparams.nextn_predict_layers;
+    }
+
+    // create a context for each buffer type
+    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
+    auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
+        auto it = ctx_map.find(buft);
+        if (it == ctx_map.end()) {
+            ggml_init_params params = {
+                /*.mem_size   =*/ size_t(2u*(1 + n_stream)*n_layer_cache*ggml_tensor_overhead()),
+                /*.mem_buffer =*/ NULL,
+                /*.no_alloc   =*/ true,
+            };
+
+            ggml_context * ctx = ggml_init(params);
+            if (!ctx) {
+                return nullptr;
+            }
+
+            ctx_map[buft] = ctx;
+            ctxs.emplace_back(ctx);
+
+            return ctx;
+        }
+
+        return it->second;
+    };
+
+    GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max);
+
+    v_heads.resize(n_stream);
+    for (uint32_t s = 0; s < n_stream; ++s) {
+        v_heads[s] = 0;
+    }
+
+    v_cells.resize(n_stream);
+    for (uint32_t s = 0; s < n_stream; ++s) {
+        v_cells[s].resize(kv_size);
+    }
+
+    // by default, all sequence ids are mapped to the 0th stream
+    seq_to_stream.resize(LLAMA_MAX_SEQ, 0);
+
+    if (n_stream > 1) {
+        seq_to_stream.resize(n_stream, 0);
+        for (uint32_t s = 0; s < n_stream; ++s) {
+            seq_to_stream[s] = s;
+        }
+    }
+
+    // [TAG_V_CACHE_VARIABLE]
+    if (v_trans && hparams.is_n_embd_v_gqa_variable()) {
+        LLAMA_LOG_WARN("%s: the V embeddings have different sizes across layers and FA is not enabled - padding V cache to %d\n",
+                __func__, hparams.n_embd_v_gqa_max());
+    }
+
+    for (uint32_t il = 0; il < n_layer_cache; il++) {
+        if (filter && !filter(il)) {
+            LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
+            continue;
+        }
+
+        // [TAG_V_CACHE_VARIABLE]
+        const uint32_t n_embd_k_gqa =            hparams.n_embd_k_gqa(il);
+        const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max();
+
+        const char * dev_name = "CPU";
+
+        ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
+
+        if (offload) {
+            auto * dev = model.dev_layer(il);
+            buft = ggml_backend_dev_buffer_type(dev);
+
+            dev_name = ggml_backend_dev_name(dev);
+        }
+
+        LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, il, dev_name);
+
+        ggml_context * ctx = ctx_for_buft(buft);
+        if (!ctx) {
+            throw std::runtime_error("failed to create ggml context for kv cache");
+        }
+
+        ggml_tensor * k;
+        ggml_tensor * v;
+
+        k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream);
+        v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream);
+
+        ggml_format_name(k, "cache_k_l%d", il);
+        ggml_format_name(v, "cache_v_l%d", il);
+
+        std::vector<ggml_tensor *> k_stream;
+        std::vector<ggml_tensor *> v_stream;
+
+        for (uint32_t s = 0; s < n_stream; ++s) {
+            k_stream.push_back(ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]));
+            v_stream.push_back(ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]));
+        }
+
+        map_layer_ids[il] = layers.size();
+
+        layers.push_back({ il, k, v, k_stream, v_stream, });
+    }
+
+    // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
+    if (model.arch == LLM_ARCH_GEMMA3N) {
+        LLAMA_LOG_DEBUG("%s: GEMMA3N: reuse layers [%d, %d]\n", __func__, n_layer_cache, hparams.n_layer - 1);
+
+        for (uint32_t il = n_layer_cache; il < hparams.n_layer; il++) {
+            if (filter && !filter(il)) {
+                LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
+                continue;
+            }
+
+            const bool     is_swa   = hparams.is_swa(il);
+            const uint32_t il_reuse = n_layer_cache - (is_swa ? 2 : 1);
+
+            GGML_ASSERT(map_layer_ids.find(il_reuse) != map_layer_ids.end());
+            map_layer_ids[il] = map_layer_ids[il_reuse];
+
+            LLAMA_LOG_DEBUG("%s: layer %3d: reuse layer %d, isw = %d\n", __func__, il, il_reuse, is_swa);
+        }
+    }
+
+    // allocate tensors and initialize the buffers to avoid NaNs in the padding
+    for (auto it : ctx_map) {
+        auto * buft = it.first;
+        auto * ctx  = it.second;
+
+        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+        if (!buf) {
+            throw std::runtime_error("failed to allocate buffer for kv cache");
+        }
+
+        LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
+
+        ggml_backend_buffer_clear(buf, 0);
+        bufs.emplace_back(buf);
+    }
+
+    {
+        const size_t memory_size_k = size_k_bytes();
+        const size_t memory_size_v = size_v_bytes();
+
+        LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u/%u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
+                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, n_stream,
+                ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
+                ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
+    }
+
+    const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
+    debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
+
+    const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
+    supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) != 0 : supports_set_rows;
+
+    if (!supports_set_rows) {
+        // ref: https://github.com/ggml-org/llama.cpp/pull/14363
+        GGML_ASSERT(unified && "cannot use non-unified KV cache without ggml_set_rows() support");
+    }
+
+    if (!supports_set_rows) {
+        LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__);
+    }
+}
+
+void llama_kv_cache::clear(bool data) {
+    for (uint32_t s = 0; s < n_stream; ++s) {
+        v_cells[s].reset();
+        v_heads[s] = 0;
+    }
+
+    if (data) {
+        for (auto & buf : bufs) {
+            ggml_backend_buffer_clear(buf.get(), 0);
+        }
+    }
+}
+
+bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
+
+    if (p0 < 0) {
+        p0 = 0;
+    }
+
+    if (p1 < 0) {
+        p1 = std::numeric_limits<llama_pos>::max();
+    }
+
+    if (seq_id >= 0) {
+        auto & cells = v_cells[seq_to_stream[seq_id]];
+        auto & head  = v_heads[seq_to_stream[seq_id]];
+
+        uint32_t new_head = cells.size();
+
+        for (uint32_t i = 0; i < cells.size(); ++i) {
+            if (!cells.pos_in(i, p0, p1)) {
+                continue;
+            }
+
+            if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
+                if (new_head == cells.size()) {
+                    new_head = i;
+                }
+            }
+        }
+
+        // If we freed up a slot, set head to it so searching can start there.
+        if (new_head != cells.size() && new_head < head) {
+            head = new_head;
+        }
+    } else {
+        // match any sequence
+        for (uint32_t s = 0; s < n_stream; ++s) {
+            auto & cells = v_cells[s];
+            auto & head  = v_heads[s];
+
+            uint32_t new_head = cells.size();
+
+            for (uint32_t i = 0; i < cells.size(); ++i) {
+                if (!cells.pos_in(i, p0, p1)) {
+                    continue;
+                }
+
+                cells.rm(i);
+
+                if (new_head == cells.size()) {
+                    new_head = i;
+                }
+            }
+
+            // If we freed up a slot, set head to it so searching can start there.
+            if (new_head != cells.size() && new_head < head) {
+                head = new_head;
+            }
+        }
+    }
+
+    return true;
+}
+
+void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+    GGML_ASSERT(seq_id_src >= 0 && (size_t) seq_id_src < seq_to_stream.size());
+    GGML_ASSERT(seq_id_dst >= 0 && (size_t) seq_id_dst < seq_to_stream.size());
+
+    const auto s0 = seq_to_stream[seq_id_src];
+    const auto s1 = seq_to_stream[seq_id_dst];
+
+    if (s0 == s1) {
+        // since both sequences are in the same stream, no data copy is necessary
+        // we just have to update the cells meta data
+
+        auto & cells = v_cells[s0];
+
+        if (seq_id_src == seq_id_dst) {
+            return;
+        }
+
+        if (p0 < 0) {
+            p0 = 0;
+        }
+
+        if (p1 < 0) {
+            p1 = std::numeric_limits<llama_pos>::max();
+        }
+
+        for (uint32_t i = 0; i < cells.size(); ++i) {
+            if (!cells.pos_in(i, p0, p1)) {
+                continue;
+            }
+
+            if (cells.seq_has(i, seq_id_src)) {
+                cells.seq_add(i, seq_id_dst);
+            }
+        }
+
+        return;
+    }
+
+    // cross-stream sequence copies require to copy the actual buffer data
+
+    bool is_full = true;
+
+    if (p0 > 0 && p0 + 1 < (int) get_size()) {
+        is_full = false;
+    }
+
+    if (p1 > 0 && p1 + 1 < (int) get_size()) {
+        is_full = false;
+    }
+
+    GGML_ASSERT(is_full && "seq_cp() is only supported for full KV buffers");
+
+    // enqueue the copy operation - the buffer copy will be performed during the next update
+    sc_info.ssrc.push_back(s0);
+    sc_info.sdst.push_back(s1);
+
+    v_cells[s1].reset();
+    for (uint32_t i = 0; i < v_cells[s0].size(); ++i) {
+        if (v_cells[s0].seq_has(i, seq_id_src)) {
+            llama_pos pos   = v_cells[s0].pos_get(i);
+            llama_pos shift = v_cells[s0].get_shift(i);
+
+            if (shift != 0) {
+                pos -= shift;
+                assert(pos >= 0);
+            }
+
+            v_cells[s1].pos_set(i, pos);
+            v_cells[s1].seq_add(i, seq_id_dst);
+
+            if (shift != 0) {
+                v_cells[s1].pos_add(i, shift);
+            }
+        }
+    }
+
+    v_heads[s1] = v_heads[s0];
+
+    //for (uint32_t s = 0; s < n_stream; ++s) {
+    //    LLAMA_LOG_WARN("%s: seq %d: min = %d, max = %d\n", __func__, s, v_cells[s].seq_pos_min(s), v_cells[s].seq_pos_max(s));
+    //}
+}
+
+void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
+    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+
+    auto & cells = v_cells[seq_to_stream[seq_id]];
+    auto & head  = v_heads[seq_to_stream[seq_id]];
+
+    uint32_t new_head = cells.size();
+
+    for (uint32_t i = 0; i < cells.size(); ++i) {
+        if (cells.seq_keep(i, seq_id)) {
+            if (new_head == cells.size()) {
+                new_head = i;
+            }
+        }
+    }
+
+    // If we freed up a slot, set head to it so searching can start there.
+    if (new_head != cells.size() && new_head < head) {
+        head = new_head;
+    }
+}
+
+void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+
+    auto & cells = v_cells[seq_to_stream[seq_id]];
+    auto & head  = v_heads[seq_to_stream[seq_id]];
+
+    if (shift == 0) {
+        return;
+    }
+
+    uint32_t new_head = cells.size();
+
+    if (p0 < 0) {
+        p0 = 0;
+    }
+
+    if (p1 < 0) {
+        p1 = std::numeric_limits<llama_pos>::max();
+    }
+
+    // If there is no range then return early to avoid looping over all cells.
+    if (p0 == p1) {
+        return;
+    }
+
+    for (uint32_t i = 0; i < cells.size(); ++i) {
+        if (!cells.pos_in(i, p0, p1)) {
+            continue;
+        }
+
+        if (cells.seq_has(i, seq_id)) {
+            if (cells.pos_add(i, shift)) {
+                if (new_head == cells.size()) {
+                    new_head = i;
+                }
+            }
+        }
+    }
+
+    // If we freed up a slot, set head to it so searching can start there.
+    // Otherwise we just start the next search from the beginning.
+    head = new_head != cells.size() ? new_head : 0;
+}
+
+void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+
+    auto & cells = v_cells[seq_to_stream[seq_id]];
+
+    if (d == 1) {
+        return;
+    }
+
+    if (p0 < 0) {
+        p0 = 0;
+    }
+
+    if (p1 < 0) {
+        p1 = std::numeric_limits<llama_pos>::max();
+    }
+
+    // If there is no range then return early to avoid looping over the cache.
+    if (p0 == p1) {
+        return;
+    }
+
+    for (uint32_t i = 0; i < cells.size(); ++i) {
+        if (!cells.pos_in(i, p0, p1)) {
+            continue;
+        }
+
+        if (cells.seq_has(i, seq_id)) {
+            cells.pos_div(i, d);
+        }
+    }
+}
+
+llama_pos llama_kv_cache::seq_pos_min(llama_seq_id seq_id) const {
+    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+
+    const auto & cells = v_cells[seq_to_stream[seq_id]];
+
+    return cells.seq_pos_min(seq_id);
+}
+
+llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const {
+    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+
+    const auto & cells = v_cells[seq_to_stream[seq_id]];
+
+    return cells.seq_pos_max(seq_id);
+}
+
+llama_memory_context_ptr llama_kv_cache::init_batch(
+            llama_batch_allocr & balloc,
+            uint32_t n_ubatch,
+            bool embd_all) {
+    GGML_UNUSED(embd_all);
+
+    do {
+        balloc.split_reset();
+
+        std::vector<llama_ubatch> ubatches;
+        while (true) {
+            auto ubatch = n_stream == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true);
+
+            if (ubatch.n_tokens == 0) {
+                break;
+            }
+
+            ubatches.push_back(std::move(ubatch)); // NOLINT
+        }
+
+        if (balloc.get_n_used() < balloc.get_n_tokens()) {
+            // failed to find a suitable split
+            break;
+        }
+
+        auto sinfos = prepare(ubatches);
+        if (sinfos.empty()) {
+            break;
+        }
+
+        return std::make_unique<llama_kv_cache_context>(
+                this, std::move(sinfos), std::move(ubatches));
+    } while (false);
+
+    return std::make_unique<llama_kv_cache_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+}
+
+llama_memory_context_ptr llama_kv_cache::init_full() {
+    return std::make_unique<llama_kv_cache_context>(this);
+}
+
+llama_memory_context_ptr llama_kv_cache::init_update(llama_context * lctx, bool optimize) {
+    bool do_shift = get_has_shift();
+
+    defrag_info dinfo;
+
+    // see if we need to defrag
+    if (n_stream == 1) {
+        // note : for now do not consider defrag for n_stream > 1
+        const auto & cells = v_cells[seq_to_stream[0]];
+
+        bool do_defrag = optimize;
+
+        const auto thold = lctx->get_cparams().defrag_thold;
+
+        if (!do_defrag && thold > 0.0f) {
+            const auto n_kv = cells.used_max_p1();
+
+            // - do not defrag small contexts (i.e. < 2048 tokens)
+            // - count the padding towards the number of used tokens
+            const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
+
+            if (fragmentation > thold) {
+                LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
+
+                do_defrag = true;
+            }
+        }
+
+        if (do_defrag) {
+            dinfo = defrag_prepare(lctx->graph_max_nodes());
+        }
+    }
+
+    return std::make_unique<llama_kv_cache_context>(this, lctx, do_shift, std::move(dinfo), std::move(sc_info));
+}
+
+llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector<llama_ubatch> & ubatches) {
+    llama_kv_cache::slot_info_vec_t res;
+
+    struct state_t {
+        slot_info sinfo; // slot info for the ubatch
+
+        std::vector<uint32_t> v_heads_old; // old positions of the heads, before placing the ubatch
+
+        std::vector<llama_kv_cells> v_cells; // copy of the old cells, before placing the ubatch
+    };
+
+    // remember the old state of the cells so we can restore it in the end
+    std::vector<state_t> states;
+
+    bool success = true;
+
+    for (const auto & ubatch : ubatches) {
+        // non-continuous slots require support for ggml_set_rows()
+        const bool cont = supports_set_rows ? false : true;
+
+        // only find a suitable slot for the ubatch. don't modify the cells yet
+        const auto sinfo_new = find_slot(ubatch, cont);
+        if (sinfo_new.empty()) {
+            success = false;
+            break;
+        }
+
+        // remeber the position that we found
+        res.push_back(sinfo_new);
+
+        // store the old state of the cells in the recovery stack
+        {
+            state_t state = { sinfo_new, v_heads, {} };
+
+            for (uint32_t s = 0; s < sinfo_new.n_stream(); ++s) {
+                auto & cells = v_cells[sinfo_new.strm[s]];
+
+                state.v_cells.push_back(cells.cp(sinfo_new.idxs[s]));
+            }
+
+            states.push_back(std::move(state));
+        }
+
+        // now emplace the ubatch
+        apply_ubatch(sinfo_new, ubatch);
+    }
+
+    GGML_ASSERT(!states.empty() || !success);
+
+    // iterate backwards and restore the cells to their original state
+    for (auto it = states.rbegin(); it != states.rend(); ++it) {
+        const auto & sinfo = it->sinfo;
+
+        for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+            auto & cells = v_cells[sinfo.strm[s]];
+            auto & head  = v_heads[sinfo.strm[s]];
+
+            cells.set(sinfo.idxs[s], it->v_cells[s]);
+            head = it->v_heads_old[s];
+        }
+    }
+
+    if (!success) {
+        return {};
+    }
+
+    return res;
+}
+
+bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info) {
+    bool updated = false;
+
+    auto * sched = lctx->get_sched();
+
+    if (!sc_info.empty()) {
+        assert(n_stream > 1 && "stream copy should never happen with a single stream");
+
+        llama_synchronize(lctx);
+
+        const size_t n_copy = sc_info.ssrc.size();
+
+        for (size_t i = 0; i < n_copy; ++i) {
+            const auto ssrc = sc_info.ssrc[i];
+            const auto sdst = sc_info.sdst[i];
+
+            assert(ssrc < n_stream);
+            assert(sdst < n_stream);
+
+            LLAMA_LOG_DEBUG("%s: copying KV buffer: stream %d to stream %d\n", __func__, ssrc, sdst);
+
+            assert(ssrc != sdst);
+
+            for (uint32_t il = 0; il < layers.size(); ++il) {
+                const auto & layer = layers[il];
+
+                ggml_backend_tensor_copy(layer.k_stream[ssrc], layer.k_stream[sdst]);
+                ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]);
+            }
+        }
+    }
+
+    if (do_shift) {
+        if (!get_can_shift()) {
+            GGML_ABORT("The current KV cache / model configuration does not support K-shift");
+        }
+
+        LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
+
+        // apply K-shift if needed
+        if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
+            ggml_backend_sched_reset(sched);
+
+            auto * res = lctx->get_gf_res_reserve();
+
+            res->reset();
+
+            auto * gf = build_graph_shift(res, lctx);
+            if (!ggml_backend_sched_alloc_graph(sched, gf)) {
+                LLAMA_LOG_ERROR("%s: failed to allocate compute graph for K-shift\n", __func__);
+                return updated;
+            }
+
+            res->set_inputs(nullptr);
+
+            if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
+                LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__);
+                return updated;
+            }
+
+            updated = true;
+        }
+
+        for (uint32_t s = 0; s < n_stream; ++s) {
+            auto & cells = v_cells[s];
+
+            cells.reset_shift();
+        }
+    }
+
+    if (!dinfo.empty()) {
+        LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
+
+        // note: for now do not consider defrag for n_stream > 1
+        auto & cells = v_cells[seq_to_stream[0]];
+        auto & head  = v_heads[seq_to_stream[0]];
+
+        // apply moves:
+        {
+            const auto n_kv = dinfo.ids.size();
+
+            for (uint32_t i = 0; i < n_kv; ++i) {
+                assert(dinfo.ids[i] <= n_kv);
+
+                if (dinfo.ids[i] == n_kv || dinfo.ids[i] == i) {
+                    continue;
+                }
+
+                cells.mv(i, dinfo.ids[i]);
+            }
+
+            // reset the head so we can find the first free slot during the next ubatch
+            head = 0;
+        }
+
+        ggml_backend_sched_reset(sched);
+
+        auto * res = lctx->get_gf_res_reserve();
+
+        res->reset();
+
+        auto * gf = build_graph_defrag(res, lctx, dinfo);
+        if (!ggml_backend_sched_alloc_graph(sched, gf)) {
+            LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
+            return updated;
+        }
+
+        res->set_inputs(nullptr);
+
+        if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
+            LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
+            return updated;
+        }
+
+        updated = true;
+    }
+
+    return updated;
+}
+
+llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch, bool cont) const {
+
+    if (debug > 0) {
+        for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+            const auto seq_id = ubatch.seq_id_unq[s];
+            const auto stream_id = seq_to_stream[seq_id];
+            const auto & cells = v_cells[stream_id];
+            const uint32_t head_cur = v_heads[stream_id];
+
+            LLAMA_LOG_DEBUG("%s: stream[%d], n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n",
+                    __func__, stream_id, cells.used_max_p1(), cells.get_used(), head_cur, get_size(), n_swa);
+
+            if ((debug == 2 && n_swa > 0) || debug > 2) {
+                std::string ss;
+                for (uint32_t i = 0; i < cells.size(); ++i) {
+                    if (cells.is_empty(i)) {
+                        ss += '.';
+                    } else {
+                        assert(cells.seq_count(i) >= 1);
+
+                        if (cells.seq_count(i) == 1) {
+                            ss += std::to_string(cells.seq_get(i));
+                        } else {
+                            ss += 'M';
+                        }
+                    }
+                    if (i%256 == 255) {
+                        ss += " *";
+                        ss += '\n';
+                    }
+                }
+                LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
+            }
+
+            if ((debug == 2 && n_swa > 0) || debug > 2) {
+                std::string ss;
+                for (uint32_t i = 0; i < cells.size(); ++i) {
+                    std::string cur;
+                    if (cells.is_empty(i)) {
+                        cur = '.';
+                    } else {
+                        cur = std::to_string(cells.pos_get(i));
+                    }
+                    const int n = cur.size();
+                    for (int j = 0; j < 5 - n; ++j) {
+                        cur += ' ';
+                    }
+                    ss += cur;
+                    if (i%256 == 255) {
+                        ss += " *";
+                    }
+                    if (i%64 == 63) {
+                        ss += '\n';
+                    }
+                }
+                LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
+            }
+
+            for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
+                if (cells.seq_pos_min(s) < 0) {
+                    continue;
+                }
+
+                LLAMA_LOG_DEBUG("%s: stream[%d] min[%d] = %5d, max[%d] = %5d\n", __func__, stream_id, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s));
+            }
+        }
+    }
+
+    uint32_t n_tokens = ubatch.n_tokens;
+    uint32_t n_seqs   = 1;
+
+    if (n_stream > 1) {
+        GGML_ASSERT(n_tokens % ubatch.n_seqs_unq == 0);
+
+        n_seqs   = ubatch.n_seqs_unq;
+        n_tokens = n_tokens / n_seqs;
+    }
+
+    slot_info res = {
+        /*.s0   =*/ LLAMA_MAX_SEQ,
+        /*.s1   =*/ 0,
+        /*.strm =*/ { },
+        /*.idxs =*/ { },
+    };
+
+    res.resize(n_seqs);
+
+    for (uint32_t s = 0; s < n_seqs; ++s) {
+        const auto seq_id = ubatch.seq_id_unq[s];
+
+        if (n_stream > 1) {
+            GGML_ASSERT(ubatch.n_seq_id[s*n_tokens]    == 1);
+            GGML_ASSERT(ubatch.seq_id  [s*n_tokens][0] == seq_id);
+        }
+
+        res.s0 = std::min<llama_seq_id>(res.s0, seq_to_stream[seq_id]);
+        res.s1 = std::max<llama_seq_id>(res.s1, seq_to_stream[seq_id]);
+
+        res.strm[s] = seq_to_stream[seq_id];
+        res.idxs[s].reserve(n_tokens);
+
+        const auto & cells = v_cells[seq_to_stream[seq_id]];
+
+        uint32_t head_cur = v_heads[seq_to_stream[seq_id]];
+
+        // if we have enough unused cells before the current head ->
+        //   better to start searching from the beginning of the cache, hoping to fill it
+        if (head_cur > cells.get_used() + 2*n_tokens) {
+            head_cur = 0;
+        }
+
+        if (n_tokens > cells.size()) {
+            LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
+            return { };
+        }
+
+        uint32_t n_tested = 0;
+
+        // for continuous slots, we test that all tokens in the ubatch fit, starting from the current head
+        // for non-continuous slots, we test the tokens one by one
+        const uint32_t n_test = cont ? n_tokens : 1;
+
+        while (true) {
+            if (head_cur + n_test > cells.size()) {
+                n_tested += cells.size() - head_cur;
+                head_cur = 0;
+                continue;
+            }
+
+            for (uint32_t i = 0; i < n_test; i++) {
+                const auto idx = head_cur;
+
+                head_cur++;
+                n_tested++;
+
+                //const llama_pos    pos    = ubatch.pos[i];
+                //const llama_seq_id seq_id = ubatch.seq_id[i][0];
+
+                // can we use this cell? either:
+                //  - the cell is empty
+                //  - the cell is occupied only by one sequence:
+                //    - (disabled) mask causally, if the sequence is the same as the one we are inserting
+                //    - mask SWA, using current max pos for that sequence in the cache
+                //                always insert in the cell with minimum pos
+                bool can_use = cells.is_empty(idx);
+
+                if (!can_use && cells.seq_count(idx) == 1) {
+                    const llama_pos pos_cell = cells.pos_get(idx);
+
+                    // (disabled) causal mask
+                    // note: it's better to purge any "future" tokens beforehand
+                    //if (cells.seq_has(idx, seq_id)) {
+                    //    can_use = pos_cell >= pos;
+                    //}
+
+                    if (!can_use) {
+                        const llama_seq_id seq_id_cell = cells.seq_get(idx);
+
+                        // SWA mask
+                        if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
+                            can_use = true;
+                        }
+                    }
+                }
+
+                if (can_use) {
+                    res.idxs[s].push_back(idx);
+                } else {
+                    if (cont) {
+                        break;
+                    }
+                }
+            }
+
+            if (res.idxs[s].size() == n_tokens) {
+                break;
+            }
+
+            if (cont) {
+                res.idxs[s].clear();
+            }
+
+            if (n_tested >= cells.size()) {
+                //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
+                return { };
+            }
+        }
+
+        // we didn't find a suitable slot - return empty result
+        if (res.idxs[s].size() < n_tokens) {
+            return { };
+        }
+    }
+
+    assert(res.s1 >= res.s0);
+
+    return res;
+}
+
+void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
+    // keep track of the max sequence position that we would overwrite with this ubatch
+    // for non-SWA cache, this would be always empty
+    llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
+    for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+        seq_pos_max_rm[s] = -1;
+    }
+
+    assert(ubatch.n_tokens == sinfo.n_stream()*sinfo.size());
+
+    for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+        for (uint32_t ii = 0; ii < sinfo.size(); ++ii) {
+            const uint32_t i = s*sinfo.size() + ii;
+
+            auto & cells = v_cells[sinfo.strm[s]];
+
+            const auto idx = sinfo.idxs[s][ii];
+
+            if (!cells.is_empty(idx)) {
+                assert(cells.seq_count(idx) == 1);
+
+                const llama_seq_id seq_id = cells.seq_get(idx);
+                const llama_pos    pos    = cells.pos_get(idx);
+
+                seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
+
+                cells.rm(idx);
+            }
+
+            cells.pos_set(idx, ubatch.pos[i]);
+
+            for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
+                cells.seq_add(idx, ubatch.seq_id[i][s]);
+            }
+        }
+    }
+
+    // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
+    //       will be present in the cache. so we have to purge any position which is less than those we would overwrite
+    //       ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
+    for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+        if (seq_pos_max_rm[s] == -1) {
+            continue;
+        }
+
+        GGML_ASSERT(s < seq_to_stream.size());
+
+        auto & cells = v_cells[seq_to_stream[s]];
+
+        if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
+            LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
+                    __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
+
+            seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
+        }
+    }
+
+    // move the head at the end of the slot
+    for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+        auto & head = v_heads[sinfo.strm[s]];
+
+        head = sinfo.idxs[s].back() + 1;
+    }
+}
+
+bool llama_kv_cache::get_can_shift() const {
+    return true;
+}
+
+uint32_t llama_kv_cache::get_size() const {
+    const auto & cells = v_cells[seq_to_stream[0]];
+
+    return cells.size();
+}
+
+uint32_t llama_kv_cache::get_n_stream() const {
+    return n_stream;
+}
+
+bool llama_kv_cache::get_has_shift() const {
+    bool result = false;
+
+    for (uint32_t s = 0; s < n_stream; ++s) {
+        result |= v_cells[s].get_has_shift();
+    }
+
+    return result;
+}
+
+uint32_t llama_kv_cache::get_n_kv() const {
+    uint32_t result = 0;
+
+    for (uint32_t s = 0; s < n_stream; ++s) {
+        const auto & cells = v_cells[s];
+
+        result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result);
+    }
+
+    return result;
+}
+
+bool llama_kv_cache::get_supports_set_rows() const {
+    return supports_set_rows;
+}
+
+ggml_tensor * llama_kv_cache::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
+    const int32_t ikv = map_layer_ids.at(il);
+
+    auto * k = layers[ikv].k;
+
+    const uint64_t kv_size      = get_size();
+    const uint64_t n_embd_k_gqa = k->ne[0];
+
+    assert(n_embd_k_gqa == hparams.n_embd_k_gqa(il));
+
+    const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
+
+    return ggml_view_4d(ctx, k,
+            hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, ns,
+            ggml_row_size(k->type, hparams.n_embd_head_k),
+            ggml_row_size(k->type, n_embd_k_gqa),
+            ggml_row_size(k->type, n_embd_k_gqa*kv_size),
+            ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0);
+}
+
+ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
+    const int32_t ikv = map_layer_ids.at(il);
+
+    auto * v = layers[ikv].v;
+
+    const uint64_t kv_size      = get_size();
+    const uint64_t n_embd_v_gqa = v->ne[0];
+
+    // [TAG_V_CACHE_VARIABLE]
+    assert(n_embd_v_gqa >= hparams.n_embd_v_gqa(il));
+
+    const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
+
+    if (!v_trans) {
+        // note: v->nb[1] <= v->nb[2]
+        return ggml_view_4d(ctx, v,
+                hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns,
+                ggml_row_size(v->type, hparams.n_embd_head_v),            // v->nb[1]
+                ggml_row_size(v->type, n_embd_v_gqa),         // v->nb[2]
+                ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3]
+                ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0);
+    }
+
+    // note: v->nb[1] > v->nb[2]
+    return ggml_view_4d(ctx, v,
+            n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns,
+            ggml_row_size(v->type, kv_size*hparams.n_embd_head_v),    // v->nb[1]
+            ggml_row_size(v->type, kv_size),                          // v->nb[2]
+            ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3]
+            ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0);
+}
+
+ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
+    const int32_t ikv = map_layer_ids.at(il);
+
+    auto * k = layers[ikv].k;
+
+    const int64_t n_embd_k_gqa = k->ne[0];
+    const int64_t n_tokens = k_cur->ne[2];
+
+    k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
+
+    if (k_idxs && supports_set_rows) {
+        if (k->ne[2] > 1) {
+            k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
+        }
+
+        return ggml_set_rows(ctx, k, k_cur, k_idxs);
+    }
+
+    // TODO: fallback to old ggml_cpy() method for backwards compatibility
+    //       will be removed when ggml_set_rows() is adopted by all backends
+
+    GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS");
+
+    ggml_tensor * k_view = ggml_view_1d(ctx, k,
+            n_tokens*n_embd_k_gqa,
+            ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head());
+
+    return ggml_cpy(ctx, k_cur, k_view);
+}
+
+ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
+    const int32_t ikv = map_layer_ids.at(il);
+
+    auto * v = layers[ikv].v;
+
+    const int64_t n_embd_v_gqa = v_cur->ne[0]*v_cur->ne[1];
+    const int64_t n_tokens     = v_cur->ne[2];
+
+    v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
+
+    if (v_idxs && supports_set_rows) {
+        if (!v_trans) {
+            if (v->ne[2] > 1) {
+                v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
+            }
+
+            return ggml_set_rows(ctx, v, v_cur, v_idxs);
+        }
+
+        // [TAG_V_CACHE_VARIABLE]
+        if (n_embd_v_gqa < v->ne[0]) {
+            v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0);
+        }
+
+        // the row becomes a single element
+        ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);
+
+        v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]);
+
+        return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
+    }
+
+    // TODO: fallback to old ggml_cpy() method for backwards compatibility
+    //       will be removed when ggml_set_rows() is adopted by all backends
+
+    GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS");
+
+    ggml_tensor * v_view = nullptr;
+
+    if (!v_trans) {
+        v_view = ggml_view_1d(ctx, v,
+                n_tokens*n_embd_v_gqa,
+                ggml_row_size(v->type, n_embd_v_gqa)*sinfo.head());
+    } else {
+        v_cur = ggml_transpose(ctx, v_cur);
+
+        v_view = ggml_view_2d(ctx, v, n_tokens, n_embd_v_gqa,
+                (v->ne[1]    )*ggml_element_size(v),
+                (sinfo.head())*ggml_element_size(v));
+    }
+
+    return ggml_cpy(ctx, v_cur, v_view);
+}
+
+ggml_tensor * llama_kv_cache::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    const uint32_t n_tokens = ubatch.n_tokens;
+
+    ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+
+    ggml_set_input(k_idxs);
+
+    return k_idxs;
+}
+
+ggml_tensor * llama_kv_cache::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    const uint32_t n_tokens = ubatch.n_tokens;
+
+    ggml_tensor * v_idxs;
+
+    if (!v_trans) {
+        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+    } else {
+        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa_max());
+    }
+
+    ggml_set_input(v_idxs);
+
+    return v_idxs;
+}
+
+void llama_kv_cache::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+    if (!supports_set_rows) {
+        return;
+    }
+
+    const uint32_t n_tokens = ubatch->n_tokens;
+    GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+    int64_t * data = (int64_t *) dst->data;
+
+    for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+        const int64_t offs = sinfo.strm[s]*get_size();
+
+        for (uint32_t i = 0; i < sinfo.size(); ++i) {
+            data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
+        }
+    }
+}
+
+void llama_kv_cache::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+    if (!supports_set_rows) {
+        return;
+    }
+
+    const uint32_t n_tokens = ubatch->n_tokens;
+    GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+    int64_t * data = (int64_t *) dst->data;
+
+    if (!v_trans) {
+        for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+            const int64_t offs = sinfo.strm[s]*get_size();
+
+            for (uint32_t i = 0; i < sinfo.size(); ++i) {
+                data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
+            }
+        }
+    } else {
+        // note: the V cache is transposed when not using flash attention
+        const int64_t kv_size = get_size();
+
+        const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa_max();
+
+        for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+            const int64_t offs = sinfo.strm[s]*kv_size*n_embd_v_gqa;
+
+            for (uint32_t i = 0; i < sinfo.size(); ++i) {
+                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                    data[s*sinfo.size()*n_embd_v_gqa + i*n_embd_v_gqa + j] = offs + j*kv_size + sinfo.idxs[s][i];
+                }
+            }
+        }
+    }
+}
+
+void llama_kv_cache::set_input_k_shift(ggml_tensor * dst) const {
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+
+    int32_t * data = (int32_t *) dst->data;
+
+    for (uint32_t s = 0; s < n_stream; ++s) {
+        const auto & cells = v_cells[s];
+
+        for (uint32_t i = 0; i < cells.size(); ++i) {
+            data[s*cells.size() + i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
+        }
+    }
+}
+
+void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
+    const uint32_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+    float * data = (float *) dst->data;
+
+    const int64_t n_kv     = dst->ne[0];
+    const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch
+
+    GGML_ASSERT(n_tokens%n_stream == 0);
+
+    // n_tps == n_tokens_per_stream
+    const int64_t n_tps     = n_tokens/n_stream;
+    const int64_t n_tps_pad = GGML_PAD(n_tps, GGML_KQ_MASK_PAD);
+
+    std::fill(data, data + ggml_nelements(dst), -INFINITY);
+
+    // Use only the previous KV cells of the correct sequence for each token of the ubatch.
+    // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
+    // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
+    //   Causal mask:
+    //      xxx-------
+    //      xxxx------
+    //      xxxxx-----
+    //   Non-causal mask:
+    //      xxxxx-----
+    //      xxxxx-----
+    //      xxxxx-----
+    // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
+    // TODO: optimize this section
+    for (uint32_t h = 0; h < 1; ++h) {
+        for (uint32_t s = 0; s < n_stream; ++s) {
+            for (uint32_t ii = 0; ii < n_tps; ++ii) {
+                const uint32_t i = s*n_tps + ii;
+
+                const llama_seq_id seq_id = ubatch->seq_id[i][0];
+
+                const auto & cells = v_cells[seq_to_stream[seq_id]];
+
+                const llama_pos p1 = ubatch->pos[i];
+
+                const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
+
+                for (uint32_t j = 0; j < n_kv; ++j) {
+                    if (cells.is_empty(j)) {
+                        continue;
+                    }
+
+                    // mask the token if not the same sequence
+                    if (!cells.seq_has(j, seq_id)) {
+                        continue;
+                    }
+
+                    const llama_pos p0 = cells.pos_get(j);
+
+                    // mask future tokens
+                    if (causal_attn && p0 > p1) {
+                        continue;
+                    }
+
+                    // apply SWA if any
+                    if (is_masked_swa(p0, p1)) {
+                        continue;
+                    }
+
+                    data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
+                }
+            }
+        }
+    }
+}
+
+void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    const int64_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(n_stream == 1 && "TODO: support multiple streams");
+    const auto & cells = v_cells[0];
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+    GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
+
+    int32_t * data = (int32_t *) dst->data;
+
+    const int32_t n_kv = dst->ne[0];
+
+    for (int h = 0; h < 1; ++h) {
+        for (int i = 0; i < n_tokens; ++i) {
+            for (int j = 0; j < n_kv; ++j) {
+                // the position when the cells is empty is irrelevant - it will be masked out later in the attention
+                const llama_pos p0 = cells.is_empty(j) ? -1 : cells.pos_get(j);
+
+                data[h*(n_kv*n_tokens) + i*n_kv + j] = llama_relative_position_bucket(p0, ubatch->pos[i], hparams.n_rel_attn_bkts, false);
+            }
+        }
+    }
+}
+
+size_t llama_kv_cache::total_size() const {
+    size_t size = 0;
+
+    for (const auto & buf : bufs) {
+        size += ggml_backend_buffer_get_size(buf.get());
+    }
+
+    return size;
+}
+
+size_t llama_kv_cache::size_k_bytes() const {
+    size_t size_k_bytes = 0;
+
+    for (const auto & layer : layers) {
+        size_k_bytes += ggml_nbytes(layer.k);
+    }
+
+    return size_k_bytes;
+}
+
+size_t llama_kv_cache::size_v_bytes() const {
+    size_t size_v_bytes = 0;
+
+    for (const auto & layer : layers) {
+        size_v_bytes += ggml_nbytes(layer.v);
+    }
+
+    return size_v_bytes;
+}
+
+ggml_tensor * llama_kv_cache::build_rope_shift(
+        const llama_cparams & cparams,
+               ggml_context * ctx,
+                ggml_tensor * cur,
+                ggml_tensor * shift,
+                ggml_tensor * factors,
+                      float   freq_base,
+                      float   freq_scale) const {
+    const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
+
+    const auto & yarn_ext_factor = cparams.yarn_ext_factor;
+    const auto & yarn_beta_fast  = cparams.yarn_beta_fast;
+    const auto & yarn_beta_slow  = cparams.yarn_beta_slow;
+
+    const auto & n_rot     = hparams.n_rot;
+    const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE
+                                // @ngxson : this is a workaround
+                                // for M-RoPE, we want to rotate the whole vector when doing KV shift
+                                // a normal RoPE should work, we just need to use the correct ordering
+                                // ref: https://github.com/ggml-org/llama.cpp/pull/13870
+                                ? LLAMA_ROPE_TYPE_NEOX
+                                : hparams.rope_type;
+
+    // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
+    // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
+    const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2
+                                    ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale))
+                                    : cparams.yarn_attn_factor;
+
+    ggml_tensor * tmp;
+
+    if (ggml_is_quantized(cur->type)) {
+        // dequantize to f32 -> RoPE -> quantize back
+        tmp = ggml_cast(ctx, cur, GGML_TYPE_F32);
+
+        tmp = ggml_rope_ext(ctx, tmp,
+                shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
+
+        tmp = ggml_cpy(ctx, tmp, cur);
+    } else {
+        // we rotate only the first n_rot dimensions
+        tmp = ggml_rope_ext_inplace(ctx, cur,
+                shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
+    }
+
+    return tmp;
+}
+
+class llm_graph_input_k_shift : public llm_graph_input_i {
+public:
+    llm_graph_input_k_shift(const llama_kv_cache * kv_self) : kv_self(kv_self) {}
+    virtual ~llm_graph_input_k_shift() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * k_shift; // I32 [kv_size*n_stream]
+
+    const llama_kv_cache * kv_self;
+};
+
+void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
+    GGML_UNUSED(ubatch);
+
+    if (k_shift) {
+        kv_self->set_input_k_shift(k_shift);
+    }
+}
+
+ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
+    auto * ctx = res->get_ctx();
+    auto * gf  = res->get_gf();
+
+    const auto & n_embd_head_k = hparams.n_embd_head_k;
+  //const auto & n_embd_head_v = hparams.n_embd_head_v;
+
+    auto inp = std::make_unique<llm_graph_input_k_shift>(this);
+
+    inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
+    ggml_set_input(inp->k_shift);
+
+    const auto & cparams = lctx->get_cparams();
+
+    for (const auto & layer : layers) {
+        const uint32_t il = layer.il;
+
+        const int64_t n_head_kv    = hparams.n_head_kv(il);
+        const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+
+        const float freq_base_l  = model.get_rope_freq_base (cparams, il);
+        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
+
+        ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
+
+        ggml_tensor * k =
+            ggml_view_3d(ctx, layer.k,
+                n_embd_head_k, n_head_kv, get_size()*n_stream,
+                ggml_row_size(layer.k->type, n_embd_head_k),
+                ggml_row_size(layer.k->type, n_embd_k_gqa),
+                0);
+
+        ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
+
+        ggml_build_forward_expand(gf, cur);
+    }
+
+    res->add_input(std::move(inp));
+
+    return gf;
+}
+
+ggml_cgraph * llama_kv_cache::build_graph_defrag(
+         llm_graph_result * res,
+            llama_context * lctx,
+        const defrag_info & dinfo) const {
+    auto * ctx = res->get_ctx();
+    auto * gf  = res->get_gf();
+
+    GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag");
+
+    const auto & cells = v_cells[0];
+
+    const auto & ids = dinfo.ids;
+
+    const auto & cparams = lctx->get_cparams();
+
+#if 0
+    // CPU defrag
+    //
+    // TODO: optimizations are possible:
+    //       - multiple threads
+    //       - avoid copying to the host memory when already there
+    //
+    // likely not worth the effort, as we have ggml_graph based defrag
+    //
+
+    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
+    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+
+    const uint32_t kv_size = size;
+
+    std::vector<uint8_t> buf_k;
+    std::vector<uint8_t> buf_v;
+
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
+        const size_t k_size     = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size);
+
+        const size_t v_size_el = ggml_type_size(v_l[il]->type);
+        const size_t v_size    = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size);
+
+        buf_k.resize(k_size);
+        buf_v.resize(v_size);
+
+        ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size());
+        ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size());
+
+        // batch move [i, i+nm) to [id, id+nm)
+        // note: cells can move only to a lower index
+        for (uint32_t i = 0; i < n_kv; ++i) {
+            const uint32_t id = ids[i];
+
+            if (i == id || id == n_kv) {
+                continue;
+            }
+
+            uint32_t nm = 1;
+
+            while (i + nm < n_kv && ids[i + nm] == id + nm) {
+                nm++;
+            }
+
+            // move keys
+            {
+                const int64_t os =  i*k_size_row;
+                const int64_t od = id*k_size_row;
+
+                memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
+            }
+
+            // move values (note: they are transposed)
+            {
+                const int64_t os =  i;
+                const int64_t od = id;
+
+                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                    memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
+                }
+            }
+
+            i += nm - 1;
+        }
+
+        ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size());
+        ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
+    }
+#else
+    for (uint32_t i = 0; i < ids.size(); ++i) {
+        const uint32_t id = ids[i];
+
+        if (i == id || id == ids.size()) {
+            continue;
+        }
+
+        uint32_t nm = 1;
+
+        while (i + nm < ids.size() && ids[i + nm] == id + nm) {
+            nm++;
+        }
+
+        for (const auto & layer : layers) {
+            const uint32_t il = layer.il;
+
+            const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+            const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+
+            ggml_tensor * view_k_src = ggml_view_2d(ctx, layer.k,
+                    n_embd_k_gqa, nm,
+                    ggml_row_size(layer.k->type, n_embd_k_gqa),
+                    ggml_row_size(layer.k->type, n_embd_k_gqa*i));
+
+            ggml_tensor * view_k_dst = ggml_view_2d(ctx, layer.k,
+                    n_embd_k_gqa, nm,
+                    ggml_row_size(layer.k->type, n_embd_k_gqa),
+                    ggml_row_size(layer.k->type, n_embd_k_gqa*id));
+
+            ggml_tensor * view_v_src;
+            ggml_tensor * view_v_dst;
+
+            if (cparams.flash_attn) {
+                // NOTE: the V cache is not transposed when using flash attention
+                view_v_src = ggml_view_2d(ctx, layer.v,
+                        n_embd_v_gqa, nm,
+                        ggml_row_size(layer.v->type, n_embd_v_gqa),
+                        ggml_row_size(layer.v->type, n_embd_v_gqa*i));
+
+                view_v_dst = ggml_view_2d(ctx, layer.v,
+                        n_embd_v_gqa, nm,
+                        ggml_row_size(layer.v->type, n_embd_v_gqa),
+                        ggml_row_size(layer.v->type, n_embd_v_gqa*id));
+            } else {
+                view_v_src = ggml_view_2d(ctx, layer.v,
+                        nm, n_embd_v_gqa,
+                        ggml_row_size(layer.v->type, cells.size()),
+                        ggml_row_size(layer.v->type, i));
+
+                view_v_dst = ggml_view_2d(ctx, layer.v,
+                        nm, n_embd_v_gqa,
+                        ggml_row_size(layer.v->type, cells.size()),
+                        ggml_row_size(layer.v->type, id));
+            }
+
+            ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst));
+            ggml_build_forward_expand(gf, ggml_cpy(ctx, view_v_src, view_v_dst));
+        }
+
+        i += nm - 1;
+    }
+
+    //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
+#endif
+
+    return gf;
+}
+
+llama_kv_cache::defrag_info llama_kv_cache::defrag_prepare(int32_t n_max_nodes) const {
+    GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag");
+
+    const auto & cells = v_cells[0];
+
+    const uint32_t n_layer = layers.size();
+
+    const uint32_t n_kv   = cells.used_max_p1();
+    const uint32_t n_used = cells.get_used();
+
+    assert(n_used <= n_kv);
+
+    //const int64_t t_start = ggml_time_us();
+
+    // number of cells moved
+    uint32_t n_moves = 0;
+
+    // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag)
+    //   - source view, destination view, copy operation
+    //   - x2 for keys and values
+    //const uint32_t max_moves = max_nodes()/(6*n_layer);
+    // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
+    const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
+
+    // determine which KV cells to move where
+    defrag_info res;
+    auto & ids = res.ids;
+
+    ids.resize(n_kv, n_kv);
+
+    for (uint32_t i0 = 0; i0 < n_used; ++i0) {
+        if (!cells.is_empty(i0)) {
+            ids[i0] = i0;
+
+            continue;
+        }
+
+        // found a hole - fill it with data from the end of the cache
+
+        uint32_t nh = 1;
+
+        // determine the size of the hole
+        while (i0 + nh < n_used && cells.is_empty(i0 + nh)) {
+            nh++;
+        }
+
+        uint32_t nf = 0;
+        uint32_t is = n_kv - 1;
+
+        // starting from the end, find nh non-empty cells
+        for (; is > i0; --is) {
+            if (cells.is_empty(is) || ids[is] != n_kv) {
+                continue;
+            }
+
+            // non-empty cell which is not yet moved
+            nf++;
+
+            if (nf == nh) {
+                break;
+            }
+        }
+
+        // this can only happen if `n_used` is not accurate, which would be a bug
+        GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh");
+
+        nf = 0;
+
+        uint32_t i1 = is;
+
+        // are we moving a continuous block of memory?
+        bool cont = false;
+
+        // should we stop searching for the next move?
+        bool stop = false;
+
+        // go back and move the nf cells to the hole
+        for (; i1 < n_kv; ++i1) {
+            if (cells.is_empty(i1) || ids[i1] != n_kv) {
+                if (n_moves == max_moves) {
+                    stop = true;
+                    break;
+                }
+
+                cont = false;
+                continue;
+            }
+
+            // this cell goes to (i0 + nf)
+            ids[i1] = i0 + nf;
+
+            if (!cont) {
+                n_moves++;
+                cont = true;
+            }
+
+            nf++;
+
+            if (nf == nh) {
+                break;
+            }
+        }
+
+        if (stop || n_moves == max_moves) {
+            break;
+        }
+
+        //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
+
+        i0 += nh - 1;
+    }
+
+    if (n_moves == 0) {
+        return {};
+    }
+
+    LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
+
+    LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);
+
+    return res;
+}
+
+bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const {
+    assert(p0 >= 0 && p1 >= 0);
+
+    switch (swa_type) {
+        case LLAMA_SWA_TYPE_NONE:
+            {
+            } break;
+        case LLAMA_SWA_TYPE_STANDARD:
+            {
+                if (p1 - p0 >= (int32_t) n_swa) {
+                    return true;
+                }
+            } break;
+        case LLAMA_SWA_TYPE_CHUNKED:
+            {
+                const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
+
+                if (p0 < pos_chunk_start) {
+                    return true;
+                }
+            } break;
+    }
+
+    return false;
+}
+
+void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
+    GGML_UNUSED(flags);
+
+    io.write(&n_stream, sizeof(n_stream));
+
+    for (uint32_t s = 0; s < n_stream; ++s) {
+        cell_ranges_t cr { s, {} };
+
+        uint32_t cell_count = 0;
+
+        const auto & cells = v_cells[s];
+
+        // Count the number of cells with the specified seq_id
+        // Find all the ranges of cells with this seq id (or all, when -1)
+        uint32_t cell_range_begin = cells.size();
+
+        for (uint32_t i = 0; i < cells.size(); ++i) {
+            if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) {
+                ++cell_count;
+                if (cell_range_begin == cells.size()) {
+                    cell_range_begin = i;
+                }
+            } else {
+                if (cell_range_begin != cells.size()) {
+                    cr.data.emplace_back(cell_range_begin, i);
+                    cell_range_begin = cells.size();
+                }
+            }
+        }
+
+        if (cell_range_begin != cells.size()) {
+            cr.data.emplace_back(cell_range_begin, cells.size());
+        }
+
+        // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
+        uint32_t cell_count_check = 0;
+        for (const auto & range : cr.data) {
+            cell_count_check += range.second - range.first;
+        }
+        GGML_ASSERT(cell_count == cell_count_check);
+
+        io.write(&cell_count, sizeof(cell_count));
+
+        // skip empty streams
+        if (cell_count == 0) {
+            continue;
+        }
+
+        state_write_meta(io, cr, seq_id);
+        state_write_data(io, cr);
+    }
+}
+
+void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
+    GGML_UNUSED(flags);
+
+    GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
+
+    uint32_t n_stream_cur;
+    io.read_to(&n_stream_cur, sizeof(n_stream_cur));
+    if (n_stream_cur != n_stream) {
+        throw std::runtime_error("n_stream mismatch");
+    }
+
+    for (uint32_t s = 0; s < n_stream; ++s) {
+        uint32_t cell_count;
+        io.read_to(&cell_count, sizeof(cell_count));
+
+        if (cell_count == 0) {
+            continue;
+        }
+
+        const uint32_t strm = seq_id == -1 ? s : seq_to_stream[seq_id];
+
+        bool res = true;
+        res = res && state_read_meta(io, strm, cell_count, seq_id);
+        res = res && state_read_data(io, strm, cell_count);
+
+        if (!res) {
+            if (seq_id == -1) {
+                clear(true);
+            } else {
+                seq_rm(seq_id, -1, -1);
+            }
+            throw std::runtime_error("failed to restore kv cache");
+        }
+    }
+}
+
+void llama_kv_cache::state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id) const {
+    const auto & cells = v_cells[cr.strm];
+
+    for (const auto & range : cr.data) {
+        for (uint32_t i = range.first; i < range.second; ++i) {
+            std::vector<llama_seq_id> seq_ids;
+
+            for (llama_seq_id cur = 0; cur < (int) n_seq_max; ++cur) {
+                if (cur == seq_id || seq_id == -1) {
+                    if (cells.seq_has(i, cur)) {
+                        seq_ids.push_back(cur);
+                    }
+                }
+            }
+
+            const llama_pos pos     = cells.pos_get(i);
+            const uint32_t n_seq_id = seq_ids.size();
+
+            io.write(&pos,      sizeof(pos));
+            io.write(&n_seq_id, sizeof(n_seq_id));
+
+            for (const auto & seq_id : seq_ids) {
+                io.write(&seq_id, sizeof(seq_id));
+            }
+        }
+    }
+}
+
+void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const {
+    const auto & cells = v_cells[cr.strm];
+
+    const uint32_t v_trans = this->v_trans ? 1 : 0;
+    const uint32_t n_layer = layers.size();
+
+    io.write(&v_trans, sizeof(v_trans));
+    io.write(&n_layer, sizeof(n_layer));
+
+    std::vector<uint8_t> tmp_buf;
+
+    // Iterate and write all the keys first, each row is a cell
+    // Get whole range at a time
+    for (const auto & layer : layers) {
+        const uint32_t il = layer.il;
+
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+
+        auto * k = layer.k_stream[cr.strm];
+
+        // Write key type
+        const int32_t k_type_i = (int32_t) k->type;
+        io.write(&k_type_i, sizeof(k_type_i));
+
+        // Write row size of key
+        const uint64_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
+        io.write(&k_size_row, sizeof(k_size_row));
+
+        // Read each range of cells of k_size length each into tmp_buf and write out
+        for (const auto & range : cr.data) {
+            const size_t range_size = range.second - range.first;
+            const size_t buf_size = range_size * k_size_row;
+            io.write_tensor(k, range.first * k_size_row, buf_size);
+        }
+    }
+
+    if (!v_trans) {
+        for (const auto & layer : layers) {
+            const uint32_t il = layer.il;
+
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+
+            auto * v = layer.v_stream[cr.strm];
+
+            // Write value type
+            const int32_t v_type_i = (int32_t) v->type;
+            io.write(&v_type_i, sizeof(v_type_i));
+
+            // Write row size of value
+            const uint64_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
+            io.write(&v_size_row, sizeof(v_size_row));
+
+            // Read each range of cells of v_size length each into tmp_buf and write out
+            for (const auto & range : cr.data) {
+                const size_t range_size = range.second - range.first;
+                const size_t buf_size = range_size * v_size_row;
+                io.write_tensor(v, range.first * v_size_row, buf_size);
+            }
+        }
+    } else {
+        // When v is transposed, we also need the element size and get the element ranges from each row
+        const uint32_t kv_size = cells.size();
+
+        for (const auto & layer : layers) {
+            const uint32_t il = layer.il;
+
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+
+            auto * v = layer.v_stream[cr.strm];
+
+            // Write value type
+            const int32_t v_type_i = (int32_t) v->type;
+            io.write(&v_type_i, sizeof(v_type_i));
+
+            // Write element size
+            const uint32_t v_size_el = ggml_type_size(v->type);
+            io.write(&v_size_el, sizeof(v_size_el));
+
+            // Write GQA embedding size
+            io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
+
+            // For each row, we get the element values of each cell
+            for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                // Read each range of cells of v_size_el length each into tmp_buf and write out
+                for (const auto & range : cr.data) {
+                    const size_t range_size = range.second - range.first;
+                    const size_t src_offset = (range.first + j * kv_size) * v_size_el;
+                    const size_t buf_size = range_size * v_size_el;
+                    io.write_tensor(v, src_offset, buf_size);
+                }
+            }
+        }
+    }
+}
+
+bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) {
+    auto & cells = v_cells[strm];
+    auto & head  = v_heads[strm];
+
+    if (dest_seq_id != -1) {
+        // single sequence
+        seq_rm(dest_seq_id, -1, -1);
+
+        llama_batch_allocr balloc(hparams.n_pos_per_embd());
+
+        llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
+
+        ubatch.seq_id_unq[0] = dest_seq_id;
+
+        for (uint32_t i = 0; i < cell_count; ++i) {
+            llama_pos pos;
+            uint32_t n_seq_id;
+
+            io.read_to(&pos,      sizeof(pos));
+            io.read_to(&n_seq_id, sizeof(n_seq_id));
+
+            if (n_seq_id != 1) {
+                LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
+                return false;
+            }
+
+            // read the sequence id, but directly discard it - we will use dest_seq_id instead
+            {
+                llama_seq_id seq_id;
+                io.read_to(&seq_id, sizeof(seq_id));
+            }
+
+            ubatch.pos[i]      = pos;
+            ubatch.n_seq_id[i] = n_seq_id;
+            ubatch.seq_id[i]   = &dest_seq_id;
+        }
+
+        const auto sinfo = find_slot(ubatch, true);
+        if (sinfo.empty()) {
+            LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
+            return false;
+        }
+
+        apply_ubatch(sinfo, ubatch);
+
+        const auto head_cur = sinfo.head();
+
+        // keep the head at the old position because we will read the KV data into it in state_read_data()
+        head = head_cur;
+
+        LLAMA_LOG_DEBUG("%s: head_cur = %d, head = %d, cell_count = %d, dest_seq_id = %d\n", __func__, head_cur, head, cell_count, dest_seq_id);
+
+        // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
+        // Assume that this is one contiguous block of cells
+        GGML_ASSERT(head_cur + cell_count <= cells.size());
+        GGML_ASSERT(cells.pos_get(head_cur)                  == ubatch.pos[0]);
+        GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == ubatch.pos[cell_count - 1]);
+        GGML_ASSERT(cells.seq_has(head_cur,                  dest_seq_id));
+        GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id));
+    } else {
+        // whole KV cache restore
+
+        if (cell_count > cells.size()) {
+            LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
+            return false;
+        }
+
+        clear(true);
+
+        for (uint32_t i = 0; i < cell_count; ++i) {
+            llama_pos pos;
+            uint32_t  n_seq_id;
+
+            io.read_to(&pos,      sizeof(pos));
+            io.read_to(&n_seq_id, sizeof(n_seq_id));
+
+            cells.pos_set(i, pos);
+
+            for (uint32_t j = 0; j < n_seq_id; ++j) {
+                llama_seq_id seq_id;
+                io.read_to(&seq_id, sizeof(seq_id));
+
+                if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) {
+                    LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max);
+                    return false;
+                }
+
+                cells.seq_add(i, seq_id);
+            }
+        }
+
+        head = 0;
+    }
+
+    return true;
+}
+
+bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) {
+    auto & cells = v_cells[strm];
+    auto & head  = v_heads[strm];
+
+    uint32_t v_trans;
+    uint32_t n_layer;
+
+    io.read_to(&v_trans, sizeof(v_trans));
+    io.read_to(&n_layer, sizeof(n_layer));
+
+    if (n_layer != layers.size()) {
+        LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size());
+        return false;
+    }
+
+    if (cell_count > cells.size()) {
+        LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, cells.size());
+        return false;
+    }
+
+    if (this->v_trans != (bool) v_trans) {
+        LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
+        return false;
+    }
+
+    // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
+    for (const auto & layer : layers) {
+        const uint32_t il = layer.il;
+
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+
+        auto * k = layer.k_stream[strm];
+
+        // Read type of key
+        int32_t k_type_i_ref;
+        io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
+        const int32_t k_type_i = (int32_t) k->type;
+        if (k_type_i != k_type_i_ref) {
+            LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
+            return false;
+        }
+
+        // Read row size of key
+        uint64_t k_size_row_ref;
+        io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
+        const size_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
+        if (k_size_row != k_size_row_ref) {
+            LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
+            return false;
+        }
+
+        if (cell_count) {
+            // Read and set the keys for the whole cell range
+            ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
+        }
+    }
+
+    if (!this->v_trans) {
+        for (const auto & layer : layers) {
+            const uint32_t il = layer.il;
+
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+
+            auto * v = layer.v_stream[strm];
+
+            // Read type of value
+            int32_t v_type_i_ref;
+            io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
+            const int32_t v_type_i = (int32_t) v->type;
+            if (v_type_i != v_type_i_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
+                return false;
+            }
+
+            // Read row size of value
+            uint64_t v_size_row_ref;
+            io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
+            const size_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
+            if (v_size_row != v_size_row_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
+                return false;
+            }
+
+            if (cell_count) {
+                // Read and set the values for the whole cell range
+                ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
+            }
+        }
+    } else {
+        // For each layer, read the values for each cell (transposed)
+        for (const auto & layer : layers) {
+            const uint32_t il = layer.il;
+
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+
+            auto * v = layer.v_stream[strm];
+
+            // Read type of value
+            int32_t v_type_i_ref;
+            io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
+            const int32_t v_type_i = (int32_t) v->type;
+            if (v_type_i != v_type_i_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
+                return false;
+            }
+
+            // Read element size of value
+            uint32_t v_size_el_ref;
+            io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
+            const size_t v_size_el = ggml_type_size(v->type);
+            if (v_size_el != v_size_el_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
+                return false;
+            }
+
+            // Read GQA embedding size
+            uint32_t n_embd_v_gqa_ref;
+            io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
+            if (n_embd_v_gqa != n_embd_v_gqa_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
+                return false;
+            }
+
+            if (cell_count) {
+                // For each row in the transposed matrix, read the values for the whole cell range
+                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                    const size_t dst_offset = (head + j * cells.size()) * v_size_el;
+                    ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
+                }
+            }
+        }
+    }
+
+    return true;
+}
+
+//
+// llama_kv_cache_context
+//
+
+llama_kv_cache_context::llama_kv_cache_context(llama_memory_status status) : status(status) {}
+
+llama_kv_cache_context::llama_kv_cache_context(
+        llama_kv_cache * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
+    n_kv = kv->get_size();
+
+    const uint32_t n_stream = kv->get_n_stream();
+
+    // create a dummy slot info - the actual data is irrelevant. we just need to build the graph
+    sinfos.resize(1);
+    sinfos[0].s0 = 0;
+    sinfos[0].s1 = n_stream - 1;
+    sinfos[0].idxs.resize(n_stream);
+    for (uint32_t s = 0; s < n_stream; ++s) {
+        sinfos[0].strm.push_back(s);
+        sinfos[0].idxs[s].resize(1, 0);
+    }
+}
+
+llama_kv_cache_context::llama_kv_cache_context(
+        llama_kv_cache * kv,
+        llama_context * lctx,
+        bool do_shift,
+        defrag_info dinfo,
+        stream_copy_info sc_info) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)), sc_info(std::move(sc_info)) {
+    if (!do_shift && this->dinfo.empty() && this->sc_info.empty()) {
+        status = LLAMA_MEMORY_STATUS_NO_UPDATE;
+    }
+}
+
+llama_kv_cache_context::llama_kv_cache_context(
+        llama_kv_cache * kv,
+        llama_kv_cache::slot_info_vec_t sinfos,
+        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) {
+}
+
+llama_kv_cache_context::~llama_kv_cache_context() = default;
+
+bool llama_kv_cache_context::next() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    if (++i_cur >= ubatches.size()) {
+        return false;
+    }
+
+    return true;
+}
+
+bool llama_kv_cache_context::apply() {
+    assert(!llama_memory_status_is_fail(status));
+
+    // no ubatches -> this is a KV cache update
+    if (ubatches.empty()) {
+        kv->update(lctx, do_shift, dinfo, sc_info);
+
+        return true;
+    }
+
+    kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]);
+
+    n_kv = kv->get_n_kv();
+
+    return true;
+}
+
+llama_memory_status llama_kv_cache_context::get_status() const {
+    return status;
+}
+
+const llama_ubatch & llama_kv_cache_context::get_ubatch() const {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    return ubatches[i_cur];
+}
+
+uint32_t llama_kv_cache_context::get_n_kv() const {
+    return n_kv;
+}
+
+bool llama_kv_cache_context::get_supports_set_rows() const {
+    return kv->get_supports_set_rows();
+}
+
+ggml_tensor * llama_kv_cache_context::get_k(ggml_context * ctx, int32_t il) const {
+    return kv->get_k(ctx, il, n_kv, sinfos[i_cur]);
+}
+
+ggml_tensor * llama_kv_cache_context::get_v(ggml_context * ctx, int32_t il) const {
+    return kv->get_v(ctx, il, n_kv, sinfos[i_cur]);
+}
+
+ggml_tensor * llama_kv_cache_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
+    return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]);
+}
+
+ggml_tensor * llama_kv_cache_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const {
+    return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
+}
+
+ggml_tensor * llama_kv_cache_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    return kv->build_input_k_idxs(ctx, ubatch);
+}
+
+ggml_tensor * llama_kv_cache_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    return kv->build_input_v_idxs(ctx, ubatch);
+}
+
+void llama_kv_cache_context::set_input_k_shift(ggml_tensor * dst) const {
+    kv->set_input_k_shift(dst);
+}
+
+void llama_kv_cache_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    kv->set_input_k_idxs(dst, ubatch, sinfos[i_cur]);
+}
+
+void llama_kv_cache_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    kv->set_input_v_idxs(dst, ubatch, sinfos[i_cur]);
+}
+
+void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
+    kv->set_input_kq_mask(dst, ubatch, causal_attn);
+}
+
+void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    kv->set_input_pos_bucket(dst, ubatch);
+}
+
+uint32_t llama_kv_cache::get_padding(const llama_cparams & cparams) {
+    // the FA kernels require padding to avoid extra runtime boundary checks
+    return cparams.flash_attn ? 256u : 32u;
+}
diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h
new file mode 100644 (file)
index 0000000..5ca618e
--- /dev/null
@@ -0,0 +1,399 @@
+#pragma once
+
+#include "llama-batch.h"
+#include "llama-graph.h"
+#include "llama-kv-cells.h"
+#include "llama-memory.h"
+
+#include <unordered_map>
+#include <vector>
+
+struct llama_cparams;
+struct llama_hparams;
+struct llama_model;
+struct llama_context;
+
+//
+// llama_kv_cache
+//
+
+class llama_kv_cache : public llama_memory_i {
+public:
+    static uint32_t get_padding(const llama_cparams & cparams);
+
+    // this callback is used to filter out layers that should not be included in the cache
+    using layer_filter_cb = std::function<bool(int32_t il)>;
+
+    struct defrag_info {
+        bool empty() const {
+            return ids.empty();
+        }
+
+        // contains information about which cell moves where:
+        //  - cell i moves to ids[i]
+        //  - if ids[i] == i || ids[i] == ids.size(), then cell i is not moved
+        std::vector<uint32_t> ids;
+    };
+
+    struct stream_copy_info {
+        bool empty() const {
+            assert(ssrc.size() == sdst.size());
+            return ssrc.empty();
+        }
+
+        std::vector<uint32_t> ssrc;
+        std::vector<uint32_t> sdst;
+    };
+
+    // for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the
+    //   KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]]
+    struct slot_info {
+        // data for ggml_set_rows
+        using idx_vec_t = std::vector<uint32_t>;
+
+        // number of streams: ns = s1 - s0 + 1
+        llama_seq_id s0;
+        llama_seq_id s1;
+
+        std::vector<llama_seq_id> strm; // [ns]
+        std::vector<idx_vec_t>    idxs; // [ns]
+
+        uint32_t head() const {
+            GGML_ASSERT(idxs.size() == 1);
+            GGML_ASSERT(!idxs[0].empty());
+
+            return idxs[0][0];
+        }
+
+        void resize(size_t n) {
+            strm.resize(n);
+            idxs.resize(n);
+        }
+
+        size_t size() const {
+            GGML_ASSERT(idxs.size() == strm.size());
+            GGML_ASSERT(!idxs.empty());
+
+            return idxs[0].size();
+        }
+
+        size_t n_stream() const {
+            return strm.size();
+        }
+
+        bool empty() const {
+            return idxs.empty();
+        }
+
+        void clear() {
+            idxs.clear();
+        }
+    };
+
+    using slot_info_vec_t = std::vector<slot_info>;
+
+    llama_kv_cache(
+            const llama_model &  model,
+              layer_filter_cb && filter,
+                    ggml_type    type_k,
+                    ggml_type    type_v,
+                         bool    v_trans,
+                         bool    offload,
+                         bool    unified,
+                     uint32_t    kv_size,
+                     uint32_t    n_seq_max,
+                     uint32_t    n_pad,
+                     uint32_t    n_swa,
+               llama_swa_type    swa_type);
+
+    ~llama_kv_cache() = default;
+
+    //
+    // llama_memory_i
+    //
+
+    llama_memory_context_ptr init_batch(
+            llama_batch_allocr & balloc,
+            uint32_t n_ubatch,
+            bool embd_all) override;
+
+    llama_memory_context_ptr init_full() override;
+
+    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
+
+    bool get_can_shift() const override;
+
+    void clear(bool data) override;
+
+    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
+    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+    void seq_keep(llama_seq_id seq_id)                                                          override;
+    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
+    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
+
+    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
+    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
+
+    // state write/load
+
+    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
+    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
+
+    //
+    // llama_kv_cache specific API
+    //
+
+    uint32_t get_size()     const;
+    uint32_t get_n_stream() const;
+
+    bool get_has_shift() const;
+
+    //
+    // graph_build API
+    //
+
+    uint32_t get_n_kv() const;
+
+    // TODO: temporary
+    bool get_supports_set_rows() const;
+
+    // get views of the current state of the cache
+    ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
+    ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
+
+    // store k_cur and v_cur in the cache based on the provided head location
+    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
+    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const;
+
+    //
+    // preparation API
+    //
+
+    // find places for the provided ubatches in the cache, returns the slot infos
+    // return empty vector on failure
+    slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);
+
+    bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info);
+
+    // find a slot of kv cells that can hold the ubatch
+    // if cont == true, then the slot must be continuous
+    // return empty slot_info on failure
+    slot_info find_slot(const llama_ubatch & ubatch, bool cont) const;
+
+    // emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]]
+    void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch);
+
+    //
+    // input API
+    //
+
+    ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+    ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+
+    void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
+    void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
+
+    void set_input_k_shift(ggml_tensor * dst) const;
+
+    void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
+    void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
+
+private:
+    const llama_model & model;
+    const llama_hparams & hparams;
+
+    struct kv_layer {
+        // layer index in the model
+        // note: can be different from the layer index in the KV cache
+        uint32_t il;
+
+        ggml_tensor * k;
+        ggml_tensor * v;
+
+        std::vector<ggml_tensor *> k_stream;
+        std::vector<ggml_tensor *> v_stream;
+    };
+
+    bool v_trans = true;  // the value tensor is transposed
+
+    const uint32_t n_seq_max = 1;
+    const uint32_t n_stream  = 1;
+
+    // required padding
+    const uint32_t n_pad = 1;
+
+    // SWA
+    const uint32_t n_swa = 0;
+
+    // env: LLAMA_KV_CACHE_DEBUG
+    int debug = 0;
+
+    // env: LLAMA_SET_ROWS (temporary)
+    // ref: https://github.com/ggml-org/llama.cpp/pull/14285
+    bool supports_set_rows = true;
+
+    const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
+
+    std::vector<ggml_context_ptr>        ctxs;
+    std::vector<ggml_backend_buffer_ptr> bufs;
+
+    // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
+    // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
+    std::vector<uint32_t> v_heads;
+
+    std::vector<llama_kv_cells> v_cells;
+
+    // maps from a sequence id to a stream id
+    std::vector<uint32_t> seq_to_stream;
+
+    // pending stream copies that will be applied during the next update
+    stream_copy_info sc_info;
+
+    std::vector<kv_layer> layers;
+
+    // model layer id -> KV cache layer id
+    std::unordered_map<int32_t, int32_t> map_layer_ids;
+
+    // return non-empty vector if cells have been moved
+    defrag_info defrag_prepare(int32_t n_max_nodes) const;
+
+    size_t total_size() const;
+
+    size_t size_k_bytes() const;
+    size_t size_v_bytes() const;
+
+    bool is_masked_swa(llama_pos p0, llama_pos p1) const;
+
+    ggml_tensor * build_rope_shift(
+            const llama_cparams & cparams,
+                   ggml_context * ctx,
+                    ggml_tensor * cur,
+                    ggml_tensor * shift,
+                    ggml_tensor * factors,
+                          float   freq_base,
+                          float   freq_scale) const;
+
+    ggml_cgraph * build_graph_shift(
+               llm_graph_result * res,
+                  llama_context * lctx) const;
+
+    ggml_cgraph * build_graph_defrag(
+               llm_graph_result * res,
+                  llama_context * lctx,
+              const defrag_info & dinfo) const;
+
+    struct cell_ranges_t {
+        uint32_t strm;
+
+        std::vector<std::pair<uint32_t, uint32_t>> data; // ranges, from inclusive, to exclusive
+    };
+
+    void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const;
+    void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const;
+
+    bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
+    bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count);
+};
+
+class llama_kv_cache_context : public llama_memory_context_i {
+public:
+    // some shorthands
+    using slot_info_vec_t  = llama_kv_cache::slot_info_vec_t;
+    using defrag_info      = llama_kv_cache::defrag_info;
+    using stream_copy_info = llama_kv_cache::stream_copy_info;
+
+    // used for errors
+    llama_kv_cache_context(llama_memory_status status);
+
+    // used to create a full-cache context
+    llama_kv_cache_context(
+            llama_kv_cache * kv);
+
+    // used to create an update context
+    llama_kv_cache_context(
+            llama_kv_cache * kv,
+            llama_context * lctx,
+            bool do_shift,
+            defrag_info dinfo,
+            stream_copy_info sc_info);
+
+    // used to create a batch procesing context from a batch
+    llama_kv_cache_context(
+            llama_kv_cache * kv,
+            slot_info_vec_t sinfos,
+            std::vector<llama_ubatch> ubatches);
+
+    virtual ~llama_kv_cache_context();
+
+    //
+    // llama_memory_context_i
+    //
+
+    bool next()  override;
+    bool apply() override;
+
+    llama_memory_status  get_status() const override;
+    const llama_ubatch & get_ubatch() const override;
+
+    //
+    // llama_kv_cache_context specific API
+    //
+
+    uint32_t get_n_kv() const;
+
+    // TODO: temporary
+    bool get_supports_set_rows() const;
+
+    // get views of the current state of the cache
+    ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
+    ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
+
+    // store k_cur and v_cur in the cache based on the provided head location
+    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
+    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;
+
+    ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+    ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+
+    void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
+    void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
+
+    void set_input_k_shift   (ggml_tensor * dst) const;
+    void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
+    void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
+
+private:
+    llama_memory_status status;
+
+    llama_kv_cache * kv;
+    llama_context * lctx;
+
+    //
+    // update context
+    //
+
+    bool do_shift = false;
+
+    defrag_info dinfo;
+
+    stream_copy_info sc_info;
+
+    //
+    // batch processing context
+    //
+
+    // the index of the cur ubatch to process
+    size_t i_cur = 0;
+
+    slot_info_vec_t sinfos;
+
+    std::vector<llama_ubatch> ubatches;
+
+    //
+    // data needed for building the compute graph for the current ubatch:
+    //
+
+    // a heuristic, to avoid attending the full cache if it is not yet utilized
+    // as the cache gets filled, the benefit from this heuristic disappears
+    int32_t n_kv;
+};
index 0d0dd316fd0415f0952952750f510a3361f8f9ea..2651e30331fd692a38b1caab1ff3dabd837f36a0 100644 (file)
@@ -11,7 +11,7 @@
 
 // meta information about KV cells that can be part of multiple sequences at the same time
 // TODO: add unit tests
-class llama_kv_cells_unified {
+class llama_kv_cells {
 public:
     void reset() {
         for (uint32_t i = 0; i < pos.size(); ++i) {
@@ -97,10 +97,10 @@ public:
     }
 
     // copy the state of cells [i, i + n) (used for save/restore the state of the cells)
-    llama_kv_cells_unified cp(uint32_t i, uint32_t n) const {
+    llama_kv_cells cp(uint32_t i, uint32_t n) const {
         assert(i + n <= pos.size());
 
-        llama_kv_cells_unified res;
+        llama_kv_cells res;
 
         res.resize(n);
 
@@ -117,8 +117,8 @@ public:
     }
 
     // copy the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
-    llama_kv_cells_unified cp(const std::vector<uint32_t> & idxs) const {
-        llama_kv_cells_unified res;
+    llama_kv_cells cp(const std::vector<uint32_t> & idxs) const {
+        llama_kv_cells res;
 
         res.resize(idxs.size());
 
@@ -135,7 +135,7 @@ public:
     }
 
     // set the state of cells [i, i + other.pos.size()) (used for save/restore the state of the cells)
-    void set(uint32_t i, const llama_kv_cells_unified & other) {
+    void set(uint32_t i, const llama_kv_cells & other) {
         assert(i + other.pos.size() <= pos.size());
 
         for (uint32_t j = 0; j < other.pos.size(); ++j) {
@@ -165,7 +165,7 @@ public:
     }
 
     // set the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
-    void set(const std::vector<uint32_t> & idxs, const llama_kv_cells_unified & other) {
+    void set(const std::vector<uint32_t> & idxs, const llama_kv_cells & other) {
         assert(idxs.size() == other.pos.size());
 
         for (uint32_t j = 0; j < other.pos.size(); ++j) {
index cbeeb21344ecede21046e05666b2b4863f1a4187..f8303dacbf8adf952e7a5b583f746b44cd9a58a2 100644 (file)
@@ -30,7 +30,7 @@ llama_memory_hybrid::llama_memory_hybrid(
       layer_filter_cb && filter_attn,
       layer_filter_cb && filter_recr) :
     hparams(model.hparams),
-    mem_attn(new llama_kv_cache_unified(
+    mem_attn(new llama_kv_cache(
         model,
         filter_attn == nullptr ?
             [&](int32_t il) { return !hparams.is_recurrent(il); }
@@ -179,7 +179,7 @@ void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id,
     mem_recr->state_read(io, seq_id);
 }
 
-llama_kv_cache_unified * llama_memory_hybrid::get_mem_attn() const {
+llama_kv_cache * llama_memory_hybrid::get_mem_attn() const {
     return mem_attn.get();
 }
 
@@ -210,7 +210,7 @@ llama_memory_hybrid_context::llama_memory_hybrid_context(
         std::vector<llama_ubatch>   ubatches) :
     ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
-    ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
+    ctx_attn(new llama_kv_cache_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
     ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(),                        this->ubatches)),
     status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
 }
@@ -248,8 +248,8 @@ const llama_ubatch & llama_memory_hybrid_context::get_ubatch() const {
     return ubatches[i_next];
 }
 
-const llama_kv_cache_unified_context * llama_memory_hybrid_context::get_attn() const {
-    return static_cast<const llama_kv_cache_unified_context *>(ctx_attn.get());
+const llama_kv_cache_context * llama_memory_hybrid_context::get_attn() const {
+    return static_cast<const llama_kv_cache_context *>(ctx_attn.get());
 }
 
 const llama_memory_recurrent_context * llama_memory_hybrid_context::get_recr() const {
index acdbc26bfb624c1a7ac22ce55d4808c6e3398148..e9c64ee40aae42a8ca8edaa7c6af1d826b789b3b 100644 (file)
@@ -2,7 +2,7 @@
 
 #include "llama-batch.h"
 #include "llama-graph.h"
-#include "llama-kv-cache-unified.h"
+#include "llama-kv-cache.h"
 #include "llama-memory.h"
 #include "llama-memory-recurrent.h"
 
@@ -13,7 +13,7 @@
 // llama_memory_hybrid
 //
 
-// utilizes instances of llama_memory_recurrent and llama_kv_cache_unified to
+// utilizes instances of llama_memory_recurrent and llama_kv_cache to
 //   support models where each layer may be either attention-based or recurrent
 
 class llama_memory_hybrid : public llama_memory_i {
@@ -81,19 +81,19 @@ public:
     // llama_memory_hybrid specific API
     //
 
-    llama_kv_cache_unified * get_mem_attn() const;
+    llama_kv_cache * get_mem_attn() const;
     llama_memory_recurrent * get_mem_recr() const;
 
 private:
     const llama_hparams & hparams;
 
-    const std::unique_ptr<llama_kv_cache_unified> mem_attn;
+    const std::unique_ptr<llama_kv_cache> mem_attn;
     const std::unique_ptr<llama_memory_recurrent> mem_recr;
 };
 
 class llama_memory_hybrid_context : public llama_memory_context_i {
 public:
-    using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
+    using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
 
     // init failure
     explicit llama_memory_hybrid_context(llama_memory_status status);
@@ -125,7 +125,7 @@ public:
     // llama_memory_hybrid_context
     //
 
-    const llama_kv_cache_unified_context * get_attn() const;
+    const llama_kv_cache_context * get_attn() const;
     const llama_memory_recurrent_context * get_recr() const;
 
 private:
index 95c617b2c94bdb53480d7129a0f1474e7cfbbfea..c8e8623602f78ad683d24b7c144190976b2d6287 100644 (file)
@@ -12,7 +12,7 @@
 //
 
 // TODO: extract the cache state used for graph computation into llama_memory_recurrent_context_i
-//       see the implementation of llama_kv_cache_unified_context_i for an example how to do it
+//       see the implementation of llama_kv_cache_context_i for an example how to do it
 class llama_memory_recurrent : public llama_memory_i {
 public:
 
index 171d312cc99d91c3b0665322de6f8e59f2c29de0..42a7145c2f38703aa61acb8680c33c667034cdce 100644 (file)
@@ -36,8 +36,8 @@ bool llama_memory_status_is_fail(llama_memory_status status);
 
 // the interface for managing the memory context during batch processing
 // this interface is implemented per memory type. see:
-//   - llama_kv_cache_unified_context
-//   - llama_kv_cache_unified_iswa_context
+//   - llama_kv_cache_context
+//   - llama_kv_cache_iswa_context
 //   ...
 //
 // the only method that should mutate the memory and the memory context is llama_memory_i::apply()
@@ -109,8 +109,3 @@ struct llama_memory_i {
 };
 
 using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
-
-// TODO: temporary until the llama_kv_cache is removed from the public API
-struct llama_kv_cache : public llama_memory_i {
-    virtual ~llama_kv_cache() = default;
-};
index 431102edea18e8529b44806d53a4fb71a6fe73eb..cbb7bc875831d63d1f66eca2fc22ac666bd30614 100644 (file)
@@ -6,8 +6,8 @@
 #include "llama-cparams.h"
 #include "llama-model-loader.h"
 
-#include "llama-kv-cache-unified.h"
-#include "llama-kv-cache-unified-iswa.h"
+#include "llama-kv-cache.h"
+#include "llama-kv-cache-iswa.h"
 #include "llama-memory-hybrid.h"
 #include "llama-memory-recurrent.h"
 
@@ -5986,7 +5986,7 @@ struct llm_build_llama : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
 
@@ -6146,7 +6146,7 @@ struct llm_build_llama_iswa : public llm_graph_context {
         ggml_tensor * inp_attn_scale = nullptr;
         inp_attn_scale = build_inp_attn_scale();
 
-        auto * inp_attn = build_attn_inp_kv_unified_iswa();
+        auto * inp_attn = build_attn_inp_kv_iswa();
 
         const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
 
@@ -6325,7 +6325,7 @@ struct llm_build_deci : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
 
@@ -6481,7 +6481,7 @@ struct llm_build_baichuan : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = model.type == LLM_TYPE_7B ? build_inp_pos() : nullptr;
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -6603,7 +6603,7 @@ struct llm_build_xverse : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -6717,7 +6717,7 @@ struct llm_build_falcon : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -6841,7 +6841,7 @@ struct llm_build_grok : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -7001,7 +7001,7 @@ struct llm_build_dbrx : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -7125,7 +7125,7 @@ struct llm_build_starcoder : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
         cb(pos, "pos_embd", -1);
@@ -7230,7 +7230,7 @@ struct llm_build_refact : public llm_graph_context {
 
         inpL = build_inp_embd(model.tok_embd);
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -7632,7 +7632,7 @@ struct llm_build_bloom : public llm_graph_context {
 
         inpL = build_inp_embd(model.tok_embd);
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         inpL = build_norm(inpL,
                 model.tok_norm,
@@ -7739,7 +7739,7 @@ struct llm_build_mpt : public llm_graph_context {
 
         inpL = build_inp_embd(model.tok_embd);
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         if (model.pos_embd) {
             // inp_pos - contains the positions
@@ -7889,7 +7889,7 @@ struct llm_build_stablelm : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -8041,7 +8041,7 @@ struct llm_build_qwen : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -8156,7 +8156,7 @@ struct llm_build_qwen2 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -8481,7 +8481,7 @@ struct llm_build_qwen2vl : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         int sections[4];
         std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
@@ -8602,7 +8602,7 @@ struct llm_build_qwen2moe : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -8761,7 +8761,7 @@ struct llm_build_qwen3 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -8882,7 +8882,7 @@ struct llm_build_qwen3moe : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -9012,7 +9012,7 @@ struct llm_build_phi2 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -9141,13 +9141,13 @@ struct llm_build_phi3 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_unified_iswa, llm_graph_input_attn_kv_unified>;
+        using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_iswa, llm_graph_input_attn_kv>;
         inp_attn_type * inp_attn = nullptr;
 
         if constexpr (iswa) {
-            inp_attn = build_attn_inp_kv_unified_iswa();
+            inp_attn = build_attn_inp_kv_iswa();
         } else {
-            inp_attn = build_attn_inp_kv_unified();
+            inp_attn = build_attn_inp_kv();
         }
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -9299,7 +9299,7 @@ struct llm_build_plamo : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -9415,7 +9415,7 @@ struct llm_build_gpt2 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
         cb(pos, "pos_embd", -1);
@@ -9525,7 +9525,7 @@ struct llm_build_codeshell : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -9638,7 +9638,7 @@ struct llm_build_orion : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -9765,7 +9765,7 @@ struct llm_build_internlm2 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -9901,7 +9901,7 @@ struct llm_build_minicpm3 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -10096,7 +10096,7 @@ struct llm_build_gemma : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -10212,7 +10212,7 @@ struct llm_build_gemma2_iswa : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified_iswa();
+        auto * inp_attn = build_attn_inp_kv_iswa();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -10346,7 +10346,7 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
         ggml_tensor * inp_pos = build_inp_pos();
 
         // TODO: is causal == true correct? might need some changes
-        auto * inp_attn = build_attn_inp_kv_unified_iswa();
+        auto * inp_attn = build_attn_inp_kv_iswa();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -10497,7 +10497,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
         ggml_tensor * inp_pos = build_inp_pos();
 
         // TODO: is causal == true correct? might need some changes
-        auto * inp_attn = build_attn_inp_kv_unified_iswa();
+        auto * inp_attn = build_attn_inp_kv_iswa();
 
         // inp_per_layer shape: [n_embd_altup, n_tokens, n_layer]
         ggml_tensor * inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs());
@@ -10904,7 +10904,7 @@ struct llm_build_starcoder2 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -11473,7 +11473,7 @@ struct llm_build_command_r : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -11620,7 +11620,7 @@ struct llm_build_cohere2_iswa : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified_iswa();
+        auto * inp_attn = build_attn_inp_kv_iswa();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -11755,7 +11755,7 @@ struct llm_build_olmo : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -11883,7 +11883,7 @@ struct llm_build_olmo2 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -12012,7 +12012,7 @@ struct llm_build_olmoe : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -12138,7 +12138,7 @@ struct llm_build_openelm : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -12269,7 +12269,7 @@ struct llm_build_gptneox : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -12415,7 +12415,7 @@ struct llm_build_arctic : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -12553,7 +12553,7 @@ struct llm_build_deepseek : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
 
@@ -12730,7 +12730,7 @@ struct llm_build_deepseek2 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -12977,7 +12977,7 @@ struct llm_build_bitnet : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -13241,7 +13241,7 @@ struct llm_build_t5_dec : public llm_graph_context {
 
         const int64_t n_outputs_enc = embd_enc->ne[1];
 
-        auto * inp_attn_self  = build_attn_inp_kv_unified();
+        auto * inp_attn_self  = build_attn_inp_kv();
         auto * inp_attn_cross = build_attn_inp_cross();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -13406,7 +13406,7 @@ struct llm_build_jais : public llm_graph_context {
 
         inpL = build_inp_embd(model.tok_embd);
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -13504,7 +13504,7 @@ struct llm_build_chatglm : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -13637,7 +13637,7 @@ struct llm_build_glm4 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -13787,7 +13787,7 @@ struct llm_build_glm4_moe : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -13947,7 +13947,7 @@ struct llm_build_nemotron : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -14076,7 +14076,7 @@ struct llm_build_exaone : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -14208,13 +14208,13 @@ struct llm_build_exaone4 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_unified_iswa, llm_graph_input_attn_kv_unified>;
+        using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_iswa, llm_graph_input_attn_kv>;
         inp_attn_type * inp_attn = nullptr;
 
         if constexpr (iswa) {
-            inp_attn = build_attn_inp_kv_unified_iswa();
+            inp_attn = build_attn_inp_kv_iswa();
         } else {
-            inp_attn = build_attn_inp_kv_unified();
+            inp_attn = build_attn_inp_kv();
         }
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -15097,7 +15097,7 @@ struct llm_build_granite : public llm_graph_context {
             inp_pos = build_inp_pos();
         }
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -15148,12 +15148,12 @@ struct llm_build_granite : public llm_graph_context {
     }
 
     ggml_tensor * build_attention_layer(
-              ggml_tensor                     * cur,
-              ggml_tensor                     * inp_pos,
-              llm_graph_input_attn_kv_unified * inp_attn,
-        const llama_model                     & model,
-        const int64_t                           n_embd_head,
-        const int                               il) {
+              ggml_tensor             * cur,
+              ggml_tensor             * inp_pos,
+              llm_graph_input_attn_kv * inp_attn,
+        const llama_model             & model,
+        const int64_t                 n_embd_head,
+        const int                     il) {
 
         // compute Q and K and (optionally) RoPE them
         ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -15367,12 +15367,12 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba {
     }
 
     ggml_tensor * build_attention_layer(
-              ggml_tensor                     * cur,
-              ggml_tensor                     * inp_pos,
-              llm_graph_input_attn_kv_unified * inp_attn,
-        const llama_model                     & model,
-        const int64_t                           n_embd_head,
-        const int                               il) {
+              ggml_tensor             * cur,
+              ggml_tensor             * inp_pos,
+              llm_graph_input_attn_kv * inp_attn,
+        const llama_model             & model,
+        const int64_t                 n_embd_head,
+        const int                     il) {
 
         // compute Q and K and (optionally) RoPE them
         ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -15529,7 +15529,7 @@ struct llm_build_chameleon : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -15860,7 +15860,7 @@ struct llm_build_plm : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -16025,7 +16025,7 @@ struct llm_build_bailingmoe : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -16174,7 +16174,7 @@ struct llm_build_dots1 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -16324,7 +16324,7 @@ struct llm_build_ernie4_5 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
@@ -16454,7 +16454,7 @@ struct llm_build_ernie4_5_moe : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -16828,7 +16828,7 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
 
 private:
     ggml_tensor * build_plamo2_attn_layer(
-            llm_graph_input_attn_kv_unified * inp,
+            llm_graph_input_attn_kv * inp,
             ggml_tensor * inp_pos,
             ggml_tensor * cur,
             const llama_model & model,
@@ -17061,7 +17061,7 @@ struct llm_build_arcee : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
 
@@ -17196,7 +17196,7 @@ struct llm_build_hunyuan_moe : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         const float kq_scale = 1.0f / sqrtf(float(n_embd_head));
 
@@ -17357,7 +17357,7 @@ struct llm_build_hunyuan_dense : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         const float kq_scale = 1.0f / sqrtf(float(n_embd_head));
 
@@ -17495,7 +17495,7 @@ struct llm_build_smollm3 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
 
@@ -17627,7 +17627,7 @@ struct llm_build_openai_moe_iswa : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified_iswa();
+        auto * inp_attn = build_attn_inp_kv_iswa();
 
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
@@ -17809,10 +17809,10 @@ struct llm_build_lfm2 : public llm_graph_context {
         return cur;
     }
 
-    ggml_tensor * build_attn_block(ggml_tensor                     * cur,
-                                   ggml_tensor                     * inp_pos,
-                                   llm_graph_input_attn_kv_unified * inp_attn,
-                                   int                               il) const {
+    ggml_tensor * build_attn_block(ggml_tensor             * cur,
+                                   ggml_tensor             * inp_pos,
+                                   llm_graph_input_attn_kv * inp_attn,
+                                   int                     il) const {
         GGML_ASSERT(hparams.n_embd_v_gqa(il) == hparams.n_embd_k_gqa(il));
         auto const n_embd_head = hparams.n_embd_head_v;
         auto const n_head_kv = hparams.n_head_kv(il);
@@ -17940,13 +17940,13 @@ struct llm_build_smallthinker : public llm_graph_context{
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_unified_iswa, llm_graph_input_attn_kv_unified>;
+        using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_iswa, llm_graph_input_attn_kv>;
         inp_attn_type * inp_attn = nullptr;
 
         if constexpr (iswa) {
-            inp_attn = build_attn_inp_kv_unified_iswa();
+            inp_attn = build_attn_inp_kv_iswa();
         } else {
-            inp_attn = build_attn_inp_kv_unified();
+            inp_attn = build_attn_inp_kv();
         }
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -18076,7 +18076,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                             std::max((uint32_t) 1, cparams.n_seq_max),
                             cparams.n_seq_max);
                 } else if (llm_arch_is_hybrid(arch)) {
-                    const auto padding = llama_kv_cache_unified::get_padding(cparams);
+                    const auto padding = llama_kv_cache::get_padding(cparams);
 
                     cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
 
@@ -18098,7 +18098,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                         /* filter_attn       */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr,
                         /* filter_recr       */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr);
                 } else {
-                    const auto padding = llama_kv_cache_unified::get_padding(cparams);
+                    const auto padding = llama_kv_cache::get_padding(cparams);
 
                     uint32_t n_ctx_per_stream = cparams.n_ctx;
 
@@ -18118,7 +18118,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                     if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
                         GGML_ASSERT(hparams.is_swa_any());
 
-                        res = new llama_kv_cache_unified_iswa(
+                        res = new llama_kv_cache_iswa(
                                 *this,
                                 params.type_k,
                                 params.type_v,
@@ -18133,7 +18133,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                     } else {
                         GGML_ASSERT(!hparams.is_swa_any());
 
-                        res = new llama_kv_cache_unified(
+                        res = new llama_kv_cache(
                                 *this,
                                 nullptr,
                                 params.type_k,