]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
kv-cache : use ggml_set_rows (#14285)
authorGeorgi Gerganov <redacted>
Thu, 3 Jul 2025 07:53:35 +0000 (10:53 +0300)
committerGitHub <redacted>
Thu, 3 Jul 2025 07:53:35 +0000 (10:53 +0300)
* kv-cache : use ggml_set_rows

ggml-ci

* graph : separate k and v indices

ggml-ci

* cont : remove redundant ifs

ggml-ci

* kv-cache : improve find_slot impl

* kv-cache : bounds-check when accessing slot_info indices

* kv-cache : add comments

ggml-ci

* ggml : add TODOs for adding GGML_OP_SET_ROWS support in the backends

ggml-ci

13 files changed:
ggml/src/ggml-cann/ggml-cann.cpp
ggml/src/ggml-opencl/ggml-opencl.cpp
ggml/src/ggml-sycl/ggml-sycl.cpp
ggml/src/ggml-vulkan/ggml-vulkan.cpp
src/llama-graph.cpp
src/llama-graph.h
src/llama-kv-cache-unified-iswa.cpp
src/llama-kv-cache-unified-iswa.h
src/llama-kv-cache-unified.cpp
src/llama-kv-cache-unified.h
src/llama-kv-cells.h
src/llama-memory-hybrid.cpp
src/llama-memory-hybrid.h

index 8a3d2026ed898c459b78d4d9afb6852a7ef4203a..eae575cc040cdfd19f4575ec93d4efd533f99d11 100755 (executable)
@@ -2086,6 +2086,12 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
                     return false;
             }
         } break;
+        case GGML_OP_SET_ROWS:
+            {
+                // TODO: add support
+                // ref: https://github.com/ggml-org/llama.cpp/pull/14274
+                return false;
+            } break;
         case GGML_OP_CPY: {
             ggml_tensor *src = op->src[0];
             if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16) ||
index 71829c05a5bf626e85fc72848c7fed8534aab5aa..9436e6ea9a08de7801b60fe4e4291dc3d931048b 100644 (file)
@@ -2222,6 +2222,12 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
                 default:
                     return false;
             }
+        case GGML_OP_SET_ROWS:
+            {
+                // TODO: add support
+                // ref: https://github.com/ggml-org/llama.cpp/pull/14274
+                return false;
+            } break;
         case GGML_OP_CPY:
         case GGML_OP_DUP:
         case GGML_OP_CONT:
index 1d41d7a4be8206182bca3c3bbb3f86d48badf644..ee4c8f7b2fa5cb9a1470519dd820cf2ba42b3e0c 100644 (file)
@@ -4285,6 +4285,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                         return false;
                 }
             }
+        case GGML_OP_SET_ROWS:
+            {
+                // TODO: add support
+                // ref: https://github.com/ggml-org/llama.cpp/pull/14274
+                return false;
+            } break;
         case GGML_OP_CPY:
             {
                 ggml_type src0_type = op->src[0]->type;
index ccb0d47b6759b696a4bb902d194afc7a76d56dc7..8114cb42046e47bf201d24e3154b978eb7e7010e 100644 (file)
@@ -10339,6 +10339,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                         return false;
                 }
             } break;
+        case GGML_OP_SET_ROWS:
+            {
+                // TODO: add support
+                // ref: https://github.com/ggml-org/llama.cpp/pull/14274
+                return false;
+            } break;
         case GGML_OP_CONT:
         case GGML_OP_CPY:
         case GGML_OP_DUP:
index f2fae6d1b71aa44a69d810fcde6da50d4df150a8..4443420132f2b5860a02ce897268af5f7377a012 100644 (file)
@@ -281,19 +281,22 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
-    if (self_kq_mask) {
-        mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
-    }
+    mctx->set_input_k_idxs(self_k_idxs, ubatch);
+    mctx->set_input_v_idxs(self_v_idxs, ubatch);
+
+    mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
 }
 
 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
-    if (self_kq_mask) {
-        mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
-    }
+    mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
+    mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
 
-    if (self_kq_mask_swa) {
-        mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
-    }
+    mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+
+    mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
+    mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
+
+    mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
 }
 
 void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
@@ -333,9 +336,10 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
-    if (self_kq_mask) {
-        mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
-    }
+    mctx->get_attn()->set_input_k_idxs(self_k_idxs, ubatch);
+    mctx->get_attn()->set_input_v_idxs(self_v_idxs, ubatch);
+
+    mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
 
     const int64_t n_rs = mctx->get_recr()->get_n_rs();
 
@@ -350,7 +354,8 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-void llm_graph_input_one::set_input(const llama_ubatch *) {
+void llm_graph_input_one::set_input(const llama_ubatch * ubatch) {
+    GGML_UNUSED(ubatch);
     GGML_ASSERT(one && ggml_nelements(one) == 1);
     float f_one = 1.0f;
     ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));
@@ -997,6 +1002,9 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
 
         const auto n_kv = inp->mctx->get_attn()->get_n_kv();
 
+        inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch);
+        inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch);
+
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
         ggml_set_input(inp->self_kq_mask);
@@ -1198,8 +1206,10 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
 
         const auto n_kv = mctx_cur->get_n_kv();
 
+        inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
+        inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
+
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask, "KQ_mask", -1);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1230,8 +1240,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+        const auto & k_idxs = inp->get_k_idxs();
+        const auto & v_idxs = inp->get_v_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
@@ -1290,11 +1303,15 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // optionally store to KV cache
     if (k_cur) {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+        const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
     }
 
     if (v_cur) {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+        const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }
 
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
@@ -1398,8 +1415,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+        const auto & k_idxs = inp->get_k_idxs();
+        const auto & v_idxs = inp->get_v_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
@@ -1434,8 +1454,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
+        inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
+        inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
+
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask, "KQ_mask", -1);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1446,8 +1468,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 
         const auto n_kv = mctx_cur->get_swa()->get_n_kv();
 
+        inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
+        inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
+
         inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
         ggml_set_input(inp->self_kq_mask_swa);
 
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
index db4e14805caa32ef32f58c4a0e6b428a6e0ba4f9..c8b74a14741e2677aeeadddecd2546f8ecd09e51 100644 (file)
@@ -249,8 +249,14 @@ public:
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    ggml_tensor * get_k_idxs() const { return self_k_idxs; }
+    ggml_tensor * get_v_idxs() const { return self_v_idxs; }
+
     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
 
+    ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
+
     ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch]
     ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch]
 
@@ -274,9 +280,19 @@ public:
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    ggml_tensor * get_k_idxs()     const { return self_k_idxs; }
+    ggml_tensor * get_v_idxs()     const { return self_v_idxs; }
+    ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
+    ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; }
+
     ggml_tensor * get_kq_mask()     const { return self_kq_mask_cnv; }
     ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
 
+    ggml_tensor * self_k_idxs     = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs     = nullptr; // I64 [n_batch]
+    ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch]
+
     ggml_tensor * self_kq_mask         = nullptr; // F32 [n_kv, n_batch]
     ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_kv, n_batch]
     ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_kv, n_batch]
@@ -319,8 +335,14 @@ public:
 
     ggml_tensor * s_copy; // I32 [kv_size]
 
+    ggml_tensor * get_k_idxs() const { return self_k_idxs; }
+    ggml_tensor * get_v_idxs() const { return self_v_idxs; }
+
     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
 
+    ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
+
     ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch]
     ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch]
 
@@ -336,7 +358,7 @@ public:
     llm_graph_input_one() {}
     virtual ~llm_graph_input_one() = default;
 
-    void set_input(const llama_ubatch *) override;
+    void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * one = nullptr; // F32
 };
index d1f839b63aaf55fd61bd6a422f722ceca4adaac4..ee202cc710bd65849f08dde43385d127e9028117 100644 (file)
@@ -113,20 +113,20 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
-        auto heads_base = kv_base->prepare(ubatches);
-        if (heads_base.empty()) {
+        auto sinfos_base = kv_base->prepare(ubatches);
+        if (sinfos_base.empty()) {
             break;
         }
 
-        auto heads_swa = kv_swa->prepare(ubatches);
-        if (heads_swa.empty()) {
+        auto sinfos_swa = kv_swa->prepare(ubatches);
+        if (sinfos_swa.empty()) {
             break;
         }
 
-        assert(heads_base.size() == heads_swa.size());
+        assert(sinfos_base.size() == sinfos_swa.size());
 
         return std::make_unique<llama_kv_cache_unified_iswa_context>(
-                this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+                this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
     } while (false);
 
     // if it fails, try equal split
@@ -144,20 +144,20 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
-        auto heads_base = kv_base->prepare(ubatches);
-        if (heads_base.empty()) {
+        auto sinfos_base = kv_base->prepare(ubatches);
+        if (sinfos_base.empty()) {
             break;
         }
 
-        auto heads_swa = kv_swa->prepare(ubatches);
-        if (heads_swa.empty()) {
+        auto sinfos_swa = kv_swa->prepare(ubatches);
+        if (sinfos_swa.empty()) {
             break;
         }
 
-        assert(heads_base.size() == heads_swa.size());
+        assert(sinfos_base.size() == sinfos_swa.size());
 
         return std::make_unique<llama_kv_cache_unified_iswa_context>(
-                this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+                this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
     } while (false);
 
     // TODO: if we fail again, we should attempt different splitting strategies
@@ -220,13 +220,13 @@ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
 
 llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
         llama_kv_cache_unified_iswa * kv,
-        std::vector<uint32_t> heads_base,
-        std::vector<uint32_t> heads_swa,
+        slot_info_vec_t sinfos_base,
+        slot_info_vec_t sinfos_swa,
         std::vector<llama_ubatch> ubatches) :
     ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
-    ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(heads_base), this->ubatches)),
-    ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(heads_swa),  this->ubatches)),
+    ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
+    ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(sinfos_swa),  this->ubatches)),
     status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
 }
 
index 46c1ed614f2f0166960d86035396d26a76520420..23205d826b23b231bb2026f082debe064f092500 100644 (file)
@@ -74,6 +74,8 @@ private:
 
 class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
 public:
+    using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
+
     // used for errors
     llama_kv_cache_unified_iswa_context(llama_memory_status status);
 
@@ -90,8 +92,8 @@ public:
     // used to create a batch processing context from a batch
     llama_kv_cache_unified_iswa_context(
             llama_kv_cache_unified_iswa * kv,
-            std::vector<uint32_t> heads_base,
-            std::vector<uint32_t> heads_swa,
+            slot_info_vec_t sinfos_base,
+            slot_info_vec_t sinfos_swa,
             std::vector<llama_ubatch> ubatches);
 
     virtual ~llama_kv_cache_unified_iswa_context();
index 7f7b162ffd7cefd524ab9676110d0c61051574ef..ff22079851b2a44964671b9638f9b3f467bc3ca5 100644 (file)
@@ -156,6 +156,13 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 
     const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
     debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
+
+    const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
+    supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0;
+
+    if (!supports_set_rows) {
+        LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__);
+    }
 }
 
 void llama_kv_cache_unified::clear(bool data) {
@@ -353,13 +360,13 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
-        auto heads = prepare(ubatches);
-        if (heads.empty()) {
+        auto sinfos = prepare(ubatches);
+        if (sinfos.empty()) {
             break;
         }
 
         return std::make_unique<llama_kv_cache_unified_context>(
-                this, std::move(heads), std::move(ubatches));
+                this, std::move(sinfos), std::move(ubatches));
     } while (false);
 
     return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
@@ -402,12 +409,13 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct
     return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
 }
 
-llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
-    llama_kv_cache_unified::ubatch_heads res;
+llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
+    llama_kv_cache_unified::slot_info_vec_t res;
 
     struct state {
         uint32_t head_old; // old position of the head, before placing the ubatch
-        uint32_t head_new; // new position of the head, after placing the ubatch
+
+        slot_info sinfo; // slot info for the ubatch
 
         llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch
     };
@@ -418,26 +426,29 @@ llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::
     bool success = true;
 
     for (const auto & ubatch : ubatches) {
+        // non-continuous slots require support for ggml_set_rows()
+        const bool cont = supports_set_rows ? false : true;
+
         // only find a suitable slot for the ubatch. don't modify the cells yet
-        const int32_t head_new = find_slot(ubatch);
-        if (head_new < 0) {
+        const auto sinfo_new = find_slot(ubatch, cont);
+        if (sinfo_new.empty()) {
             success = false;
             break;
         }
 
         // remeber the position that we found
-        res.push_back(head_new);
+        res.push_back(sinfo_new);
 
         // store the old state of the cells in the recovery stack
-        states.push_back({head, (uint32_t) head_new, cells.cp(head_new, ubatch.n_tokens)});
+        states.push_back({head, sinfo_new, cells.cp(sinfo_new.idxs)});
 
         // now emplace the ubatch
-        apply_ubatch(head_new, ubatch);
+        apply_ubatch(sinfo_new, ubatch);
     }
 
     // iterate backwards and restore the cells to their original state
     for (auto it = states.rbegin(); it != states.rend(); ++it) {
-        cells.set(it->head_new, it->cells);
+        cells.set(it->sinfo.idxs, it->cells);
         head = it->head_old;
     }
 
@@ -539,7 +550,7 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
     return updated;
 }
 
-int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
+llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
     const uint32_t n_tokens = ubatch.n_tokens;
 
     uint32_t head_cur = this->head;
@@ -552,7 +563,7 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
 
     if (n_tokens > cells.size()) {
         LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
-        return -1;
+        return { };
     }
 
     if (debug > 0) {
@@ -615,15 +626,26 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
 
     uint32_t n_tested = 0;
 
+    // for continuous slots, we test that all tokens in the ubatch fit, starting from the current head
+    // for non-continuous slots, we test the tokens one by one
+    const uint32_t n_test = cont ? n_tokens : 1;
+
+    slot_info res;
+
+    auto & idxs = res.idxs;
+
+    idxs.reserve(n_tokens);
+
     while (true) {
-        if (head_cur + n_tokens > cells.size()) {
+        if (head_cur + n_test > cells.size()) {
             n_tested += cells.size() - head_cur;
             head_cur = 0;
             continue;
         }
 
-        bool found = true;
-        for (uint32_t i = 0; i < n_tokens; i++) {
+        for (uint32_t i = 0; i < n_test; i++) {
+            const auto idx = head_cur;
+
             //const llama_pos    pos    = ubatch.pos[i];
             //const llama_seq_id seq_id = ubatch.seq_id[i][0];
 
@@ -633,19 +655,19 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
             //    - (disabled) mask causally, if the sequence is the same as the one we are inserting
             //    - mask SWA, using current max pos for that sequence in the cache
             //                always insert in the cell with minimum pos
-            bool can_use = cells.is_empty(head_cur + i);
+            bool can_use = cells.is_empty(idx);
 
-            if (!can_use && cells.seq_count(head_cur + i) == 1) {
-                const llama_pos pos_cell = cells.pos_get(head_cur + i);
+            if (!can_use && cells.seq_count(idx) == 1) {
+                const llama_pos pos_cell = cells.pos_get(idx);
 
                 // (disabled) causal mask
                 // note: it's better to purge any "future" tokens beforehand
-                //if (cells.seq_has(head_cur + i, seq_id)) {
+                //if (cells.seq_has(idx, seq_id)) {
                 //    can_use = pos_cell >= pos;
                 //}
 
                 if (!can_use) {
-                    const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);
+                    const llama_seq_id seq_id_cell = cells.seq_get(idx);
 
                     // SWA mask
                     if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
@@ -654,28 +676,39 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
                 }
             }
 
-            if (!can_use) {
-                found = false;
-                head_cur += i + 1;
-                n_tested += i + 1;
+            head_cur++;
+            n_tested++;
+
+            if (can_use) {
+                idxs.push_back(idx);
+            } else {
                 break;
             }
         }
 
-        if (found) {
+        if (idxs.size() == n_tokens) {
             break;
         }
 
+        if (cont) {
+            idxs.clear();
+        }
+
         if (n_tested >= cells.size()) {
             //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
-            return -1;
+            return { };
         }
     }
 
-    return head_cur;
+    // we didn't find a suitable slot - return empty result
+    if (idxs.size() < n_tokens) {
+        res.clear();
+    }
+
+    return res;
 }
 
-void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
+void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
     // keep track of the max sequence position that we would overwrite with this ubatch
     // for non-SWA cache, this would be always empty
     llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
@@ -683,22 +716,26 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
         seq_pos_max_rm[s] = -1;
     }
 
+    assert(ubatch.n_tokens == sinfo.idxs.size());
+
     for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
-        if (!cells.is_empty(head_cur + i)) {
-            assert(cells.seq_count(head_cur + i) == 1);
+        const auto idx = sinfo.idxs.at(i);
 
-            const llama_seq_id seq_id = cells.seq_get(head_cur + i);
-            const llama_pos    pos    = cells.pos_get(head_cur + i);
+        if (!cells.is_empty(idx)) {
+            assert(cells.seq_count(idx) == 1);
+
+            const llama_seq_id seq_id = cells.seq_get(idx);
+            const llama_pos    pos    = cells.pos_get(idx);
 
             seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
 
-            cells.rm(head_cur + i);
+            cells.rm(idx);
         }
 
-        cells.pos_set(head_cur + i, ubatch.pos[i]);
+        cells.pos_set(idx, ubatch.pos[i]);
 
         for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
-            cells.seq_add(head_cur + i, ubatch.seq_id[i][s]);
+            cells.seq_add(idx, ubatch.seq_id[i][s]);
         }
     }
 
@@ -719,7 +756,7 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
     }
 
     // move the head at the end of the slot
-    head = head_cur + ubatch.n_tokens;
+    head = sinfo.idxs.back() + 1;
 }
 
 bool llama_kv_cache_unified::get_can_shift() const {
@@ -772,47 +809,133 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
             0);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const {
+ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * k = layers[ikv].k;
 
+    const int64_t n_embd_k_gqa = k->ne[0];
     const int64_t n_tokens = k_cur->ne[2];
 
+    k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
+
+    if (k_idxs && supports_set_rows) {
+        return ggml_set_rows(ctx, k, k_cur, k_idxs);
+    }
+
+    // TODO: fallback to old ggml_cpy() method for backwards compatibility
+    //       will be removed when ggml_set_rows() is adopted by all backends
+
     ggml_tensor * k_view = ggml_view_1d(ctx, k,
-            n_tokens*hparams.n_embd_k_gqa(il),
-            ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head_cur);
+            n_tokens*n_embd_k_gqa,
+            ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head());
 
     return ggml_cpy(ctx, k_cur, k_view);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const {
+ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * v = layers[ikv].v;
 
+    const int64_t n_embd_v_gqa = v->ne[0];
     const int64_t n_tokens = v_cur->ne[2];
 
-    v_cur = ggml_reshape_2d(ctx, v_cur, hparams.n_embd_v_gqa(il), n_tokens);
+    v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
+
+    if (v_idxs && supports_set_rows) {
+        if (!v_trans) {
+            return ggml_set_rows(ctx, v, v_cur, v_idxs);
+        }
+
+        // the row becomes a single element
+        ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
+
+        // note: the V cache is transposed when not using flash attention
+        v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
+
+        // note: we can be more explicit here at the cost of extra cont
+        //       however, above we take advantage that a row of single element is always continuous regardless of the row stride
+        //v_cur = ggml_transpose(ctx, v_cur);
+        //v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
+
+        // we broadcast the KV indices n_embd_v_gqa times
+        // v      [1,        n_kv,     n_embd_v_gqa]
+        // v_cur  [1,        n_tokens, n_embd_v_gqa]
+        // v_idxs [n_tokens, 1,        1]
+        return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
+    }
+
+    // TODO: fallback to old ggml_cpy() method for backwards compatibility
+    //       will be removed when ggml_set_rows() is adopted by all backends
 
     ggml_tensor * v_view = nullptr;
 
     if (!v_trans) {
         v_view = ggml_view_1d(ctx, v,
-                n_tokens*hparams.n_embd_v_gqa(il),
-                ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur);
+                n_tokens*n_embd_v_gqa,
+                ggml_row_size(v->type, n_embd_v_gqa)*sinfo.head());
     } else {
-        // note: the V cache is transposed when not using flash attention
-        v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
-                (v->ne[1])*ggml_element_size(v),
-                (head_cur)*ggml_element_size(v));
-
         v_cur = ggml_transpose(ctx, v_cur);
+
+        v_view = ggml_view_2d(ctx, v, n_tokens, n_embd_v_gqa,
+                (v->ne[1]    )*ggml_element_size(v),
+                (sinfo.head())*ggml_element_size(v));
     }
 
     return ggml_cpy(ctx, v_cur, v_view);
 }
 
+ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    const uint32_t n_tokens = ubatch.n_tokens;
+
+    ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+
+    ggml_set_input(k_idxs);
+
+    return k_idxs;
+}
+
+ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    const uint32_t n_tokens = ubatch.n_tokens;
+
+    ggml_tensor * v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+
+    ggml_set_input(v_idxs);
+
+    return v_idxs;
+}
+
+void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+    if (!supports_set_rows) {
+        return;
+    }
+
+    const uint32_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+    int64_t * data = (int64_t *) dst->data;
+
+    for (int64_t i = 0; i < n_tokens; ++i) {
+        data[i] = sinfo.idxs.at(i);
+    }
+}
+
+void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+    if (!supports_set_rows) {
+        return;
+    }
+
+    const uint32_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+    int64_t * data = (int64_t *) dst->data;
+
+    for (int64_t i = 0; i < n_tokens; ++i) {
+        data[i] = sinfo.idxs.at(i);
+    }
+}
+
 void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     const uint32_t n_tokens = ubatch->n_tokens;
 
@@ -1552,13 +1675,15 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
             ubatch.seq_id[i]   = &dest_seq_id;
         }
 
-        const auto head_cur = find_slot(ubatch);
-        if (head_cur < 0) {
+        const auto sinfo = find_slot(ubatch, true);
+        if (sinfo.empty()) {
             LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
             return false;
         }
 
-        apply_ubatch(head_cur, ubatch);
+        apply_ubatch(sinfo, ubatch);
+
+        const auto head_cur = sinfo.head();
 
         // keep the head at the old position because we will read the KV data into it in state_read_data()
         head = head_cur;
@@ -1744,7 +1869,11 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_stat
 llama_kv_cache_unified_context::llama_kv_cache_unified_context(
         llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
     n_kv = kv->get_size();
-    head = 0;
+
+    // create a dummy slot info - the actual data is irrelevant. we just need to build the graph
+    sinfos.resize(1);
+    sinfos[0].idxs.resize(1);
+    sinfos[0].idxs[0] = 0;
 }
 
 llama_kv_cache_unified_context::llama_kv_cache_unified_context(
@@ -1759,8 +1888,8 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
 
 llama_kv_cache_unified_context::llama_kv_cache_unified_context(
         llama_kv_cache_unified * kv,
-        llama_kv_cache_unified::ubatch_heads heads,
-        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), heads(std::move(heads)), ubatches(std::move(ubatches)) {
+        llama_kv_cache_unified::slot_info_vec_t sinfos,
+        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) {
 }
 
 llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
@@ -1768,7 +1897,7 @@ llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
 bool llama_kv_cache_unified_context::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    if (++i_next >= ubatches.size()) {
+    if (++i_cur >= ubatches.size()) {
         return false;
     }
 
@@ -1785,10 +1914,9 @@ bool llama_kv_cache_unified_context::apply() {
         return true;
     }
 
-    kv->apply_ubatch(heads[i_next], ubatches[i_next]);
+    kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]);
 
     n_kv = kv->get_n_kv();
-    head = heads[i_next];
 
     return true;
 }
@@ -1800,7 +1928,7 @@ llama_memory_status llama_kv_cache_unified_context::get_status() const {
 const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    return ubatches[i_next];
+    return ubatches[i_cur];
 }
 
 uint32_t llama_kv_cache_unified_context::get_n_kv() const {
@@ -1815,18 +1943,34 @@ ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t
     return kv->get_v(ctx, il, n_kv);
 }
 
-ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
-    return kv->cpy_k(ctx, k_cur, il, head);
+ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
+    return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]);
+}
+
+ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const {
+    return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
+}
+
+ggml_tensor * llama_kv_cache_unified_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    return kv->build_input_k_idxs(ctx, ubatch);
 }
 
-ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
-    return kv->cpy_v(ctx, v_cur, il, head);
+ggml_tensor * llama_kv_cache_unified_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    return kv->build_input_v_idxs(ctx, ubatch);
 }
 
 void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
     kv->set_input_k_shift(dst);
 }
 
+void llama_kv_cache_unified_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    kv->set_input_k_idxs(dst, ubatch, sinfos[i_cur]);
+}
+
+void llama_kv_cache_unified_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    kv->set_input_v_idxs(dst, ubatch, sinfos[i_cur]);
+}
+
 void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     kv->set_input_kq_mask(dst, ubatch, causal_attn);
 }
index 4c53f1273ab88326cb21f6bc25fd999935a8b761..b8b0356e830c89d84483e6afc474f6cf7492e1f2 100644 (file)
@@ -24,8 +24,6 @@ public:
     // this callback is used to filter out layers that should not be included in the cache
     using layer_filter_cb = std::function<bool(int32_t il)>;
 
-    using ubatch_heads = std::vector<uint32_t>;
-
     struct defrag_info {
         bool empty() const {
             return ids.empty();
@@ -37,6 +35,32 @@ public:
         std::vector<uint32_t> ids;
     };
 
+    // for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the
+    //   KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]]
+    struct slot_info {
+        // data for ggml_set_rows
+        using idx_vec_t = std::vector<uint32_t>;
+
+        idx_vec_t idxs;
+
+        uint32_t head() const {
+            return idxs.at(0);
+        }
+
+        bool empty() const {
+            return idxs.empty();
+        }
+
+        void clear() {
+            idxs.clear();
+        }
+
+        // TODO: implement
+        //std::vector<idx_vec_t> seq_idxs;
+    };
+
+    using slot_info_vec_t = std::vector<slot_info>;
+
     llama_kv_cache_unified(
             const llama_model &  model,
               layer_filter_cb && filter,
@@ -102,30 +126,37 @@ public:
     ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
 
     // store k_cur and v_cur in the cache based on the provided head location
-    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const;
-    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const;
+    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
+    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const;
 
     //
     // preparation API
     //
 
-    // find places for the provided ubatches in the cache, returns the head locations
+    // find places for the provided ubatches in the cache, returns the slot infos
     // return empty vector on failure
-    ubatch_heads prepare(const std::vector<llama_ubatch> & ubatches);
+    slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);
 
     bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
 
-    // return the cell position where we can insert the ubatch
-    // return -1 on failure to find a contiguous slot of kv cells
-    int32_t find_slot(const llama_ubatch & ubatch) const;
+    // find a slot of kv cells that can hold the ubatch
+    // if cont == true, then the slot must be continuous
+    // return empty slot_info on failure
+    slot_info find_slot(const llama_ubatch & ubatch, bool cont) const;
 
-    // emplace the ubatch context into slot: [head_cur, head_cur + ubatch.n_tokens)
-    void apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch);
+    // emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]]
+    void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch);
 
     //
-    // set_input API
+    // input API
     //
 
+    ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+    ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+
+    void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
+    void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
+
     void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
     void set_input_k_shift   (ggml_tensor * dst) const;
     void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
@@ -157,8 +188,13 @@ private:
     // SWA
     const uint32_t n_swa = 0;
 
+    // env: LLAMA_KV_CACHE_DEBUG
     int debug = 0;
 
+    // env: LLAMA_SET_ROWS (temporary)
+    // ref: https://github.com/ggml-org/llama.cpp/pull/14285
+    int supports_set_rows = false;
+
     const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
 
     std::vector<ggml_context_ptr>        ctxs;
@@ -211,8 +247,8 @@ private:
 class llama_kv_cache_unified_context : public llama_memory_context_i {
 public:
     // some shorthands
-    using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
-    using defrag_info  = llama_kv_cache_unified::defrag_info;
+    using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
+    using defrag_info     = llama_kv_cache_unified::defrag_info;
 
     // used for errors
     llama_kv_cache_unified_context(llama_memory_status status);
@@ -231,7 +267,7 @@ public:
     // used to create a batch procesing context from a batch
     llama_kv_cache_unified_context(
             llama_kv_cache_unified * kv,
-            ubatch_heads heads,
+            slot_info_vec_t sinfos,
             std::vector<llama_ubatch> ubatches);
 
     virtual ~llama_kv_cache_unified_context();
@@ -257,11 +293,16 @@ public:
     ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
 
     // store k_cur and v_cur in the cache based on the provided head location
-    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
-    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
+    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
+    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;
+
+    ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+    ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
 
-    void set_input_k_shift(ggml_tensor * dst) const;
+    void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
+    void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
 
+    void set_input_k_shift   (ggml_tensor * dst) const;
     void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
     void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
 
@@ -283,10 +324,10 @@ private:
     // batch processing context
     //
 
-    // the index of the next ubatch to process
-    size_t i_next = 0;
+    // the index of the cur ubatch to process
+    size_t i_cur = 0;
 
-    ubatch_heads heads;
+    slot_info_vec_t sinfos;
 
     std::vector<llama_ubatch> ubatches;
 
@@ -297,7 +338,4 @@ private:
     // a heuristic, to avoid attending the full cache if it is not yet utilized
     // as the cache gets filled, the benefit from this heuristic disappears
     int32_t n_kv;
-
-    // the beginning of the current slot in which the ubatch will be inserted
-    int32_t head;
 };
index c95d635948b5d61190cbaa5a128d339cac64910c..0d0dd316fd0415f0952952750f510a3361f8f9ea 100644 (file)
@@ -105,10 +105,30 @@ public:
         res.resize(n);
 
         for (uint32_t j = 0; j < n; ++j) {
-            res.pos[j] = pos[i + j];
-            res.seq[j] = seq[i + j];
+            const auto idx = i + j;
 
-            assert(shift[i + j] == 0);
+            res.pos[j] = pos[idx];
+            res.seq[j] = seq[idx];
+
+            assert(shift[idx] == 0);
+        }
+
+        return res;
+    }
+
+    // copy the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
+    llama_kv_cells_unified cp(const std::vector<uint32_t> & idxs) const {
+        llama_kv_cells_unified res;
+
+        res.resize(idxs.size());
+
+        for (uint32_t j = 0; j < idxs.size(); ++j) {
+            const auto idx = idxs[j];
+
+            res.pos[j] = pos[idx];
+            res.seq[j] = seq[idx];
+
+            assert(shift[idx] == 0);
         }
 
         return res;
@@ -119,26 +139,58 @@ public:
         assert(i + other.pos.size() <= pos.size());
 
         for (uint32_t j = 0; j < other.pos.size(); ++j) {
-            if (pos[i + j] == -1 && other.pos[j] != -1) {
+            const auto idx = i + j;
+
+            if (pos[idx] == -1 && other.pos[j] != -1) {
                 used.insert(i + j);
             }
 
-            if (pos[i + j] != -1 && other.pos[j] == -1) {
+            if (pos[idx] != -1 && other.pos[j] == -1) {
                 used.erase(i + j);
             }
 
-            if (pos[i + j] != -1) {
+            if (pos[idx] != -1) {
                 seq_pos_rm(i + j);
             }
 
-            pos[i + j] = other.pos[j];
-            seq[i + j] = other.seq[j];
+            pos[idx] = other.pos[j];
+            seq[idx] = other.seq[j];
 
-            if (pos[i + j] != -1) {
+            if (pos[idx] != -1) {
                 seq_pos_add(i + j);
             }
 
-            assert(shift[i + j] == 0);
+            assert(shift[idx] == 0);
+        }
+    }
+
+    // set the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
+    void set(const std::vector<uint32_t> & idxs, const llama_kv_cells_unified & other) {
+        assert(idxs.size() == other.pos.size());
+
+        for (uint32_t j = 0; j < other.pos.size(); ++j) {
+            const auto idx = idxs[j];
+
+            if (pos[idx] == -1 && other.pos[j] != -1) {
+                used.insert(idx);
+            }
+
+            if (pos[idx] != -1 && other.pos[j] == -1) {
+                used.erase(idx);
+            }
+
+            if (pos[idx] != -1) {
+                seq_pos_rm(idx);
+            }
+
+            pos[idx] = other.pos[j];
+            seq[idx] = other.seq[j];
+
+            if (pos[idx] != -1) {
+                seq_pos_add(idx);
+            }
+
+            assert(shift[idx] == 0);
         }
     }
 
index 67cbf955482354a1097c5454107acc4afe4c8c36..03d974d852039be21960347eccd0941d785b1c15 100644 (file)
@@ -195,11 +195,11 @@ llama_memory_hybrid_context::llama_memory_hybrid_context(
 
 llama_memory_hybrid_context::llama_memory_hybrid_context(
               llama_memory_hybrid * mem,
-            std::vector<uint32_t>   heads_attn,
+                  slot_info_vec_t   sinfos_attn,
         std::vector<llama_ubatch>   ubatches) :
     ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
-    ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)),
+    ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
     ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(),                        this->ubatches)),
     status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
 }
index f0c2420e9a2df5bc04093c07e972a837b8ccd1aa..4ac318175785e50d410b32addf15e2674ef3a39b 100644 (file)
@@ -92,6 +92,8 @@ private:
 
 class llama_memory_hybrid_context : public llama_memory_context_i {
 public:
+    using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
+
     // init failure
     explicit llama_memory_hybrid_context(llama_memory_status status);
 
@@ -107,7 +109,7 @@ public:
     // init success
     llama_memory_hybrid_context(
               llama_memory_hybrid * mem,
-            std::vector<uint32_t>   heads_attn,
+                  slot_info_vec_t   sinfos_attn,
         std::vector<llama_ubatch>   ubatches);
 
     ~llama_memory_hybrid_context() = default;