]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
kv-cache : fix shift and defrag logic (#14081)
authorGeorgi Gerganov <redacted>
Mon, 9 Jun 2025 20:04:35 +0000 (23:04 +0300)
committerGitHub <redacted>
Mon, 9 Jun 2025 20:04:35 +0000 (23:04 +0300)
* kv-cache : fix shift

ggml-ci

* cont : reset shift[i]

ggml-ci

* cont : fix defrag erasing cells that didn't move

ggml-ci

src/llama-kv-cache-unified.cpp
src/llama-kv-cells.h

index 3a40463fd29cafcaad9cdd8093db442ff8fd9874..3566d5fd4d72bdf5a7e8a968f18197e7743815ae 100644 (file)
@@ -462,7 +462,7 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
             for (uint32_t i = 0; i < n_kv; ++i) {
                 assert(dinfo.ids[i] <= n_kv);
 
-                if (dinfo.ids[i] == n_kv) {
+                if (dinfo.ids[i] == n_kv || dinfo.ids[i] == i) {
                     continue;
                 }
 
@@ -944,11 +944,9 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
     const auto & n_embd_head_k = hparams.n_embd_head_k;
   //const auto & n_embd_head_v = hparams.n_embd_head_v;
 
-    //GGML_ASSERT(kv_self->size == n_ctx);
-
     auto inp = std::make_unique<llm_graph_input_k_shift>(this);
 
-    inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cparams.n_ctx);
+    inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cells.size());
     ggml_set_input(inp->k_shift);
 
     for (const auto & layer : layers) {
index 9e2c4d927699d72300ed013b17135a9b12c63bdb..acf30aebec69b071bc0b36223771e71ce0a5f9fd 100644 (file)
@@ -80,6 +80,9 @@ public:
         assert(isrc < pos.size());
         assert(idst < pos.size());
 
+        assert(pos[idst] == -1);
+        assert(pos[isrc] != -1);
+
         pos  [idst] = pos  [isrc];
         shift[idst] = shift[isrc];
         seq  [idst] = seq  [isrc];
@@ -144,9 +147,10 @@ public:
         assert(pos[i] != -1);
 
         seq_pos_rm(i);
+        seq[i].reset();
 
         pos[i] = -1;
-        seq[i].reset();
+        shift[i] = 0;
 
         used.erase(i);
     }
@@ -164,6 +168,7 @@ public:
 
         if (seq[i].none()) {
             pos[i] = -1;
+            shift[i] = 0;
 
             used.erase(i);
 
@@ -192,6 +197,7 @@ public:
             seq[i].reset();
 
             pos[i] = -1;
+            shift[i] = 0;
 
             used.erase(i);
 
@@ -317,21 +323,20 @@ public:
         pos[i]   += d;
         shift[i] += d;
 
-        seq_pos_add(i);
-
         has_shift = true;
 
         if (pos[i] < 0) {
-            seq_pos_rm(i);
-
             seq[i].reset();
             pos[i] = -1;
+            shift[i] = 0;
 
             used.erase(i);
 
             return true;
         }
 
+        seq_pos_add(i);
+
         return false;
     }