]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : fix kv shift bug (#3835)
authorGeorgi Gerganov <redacted>
Sun, 29 Oct 2023 16:32:51 +0000 (18:32 +0200)
committerGitHub <redacted>
Sun, 29 Oct 2023 16:32:51 +0000 (18:32 +0200)
ggml-ci

llama.cpp

index 1d1db8fc97fa388fabe75c8349c2b339684bea80..d8510a5cf01254f638cb6dda404eac6ccca7a68b 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -1552,14 +1552,14 @@ static void llama_kv_cache_seq_shift(
 
     for (uint32_t i = 0; i < cache.size; ++i) {
         if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
-            cache.cells[i].pos += delta;
+            cache.has_shift = true;
+            cache.cells[i].pos   += delta;
+            cache.cells[i].delta += delta;
+
             if (cache.cells[i].pos < 0) {
                 cache.cells[i].pos = -1;
                 cache.cells[i].seq_id.clear();
                 if (new_head == cache.size) new_head = i;
-            } else {
-                cache.has_shift = true;
-                cache.cells[i].delta = delta;
             }
         }
     }
@@ -6073,11 +6073,20 @@ static int llama_decode_internal(
 #endif
 
     // update the kv ring buffer
-    lctx.kv_self.has_shift  = false;
-    lctx.kv_self.head      += n_tokens;
-    // Ensure kv cache head points to a valid index.
-    if (lctx.kv_self.head >= lctx.kv_self.size) {
-        lctx.kv_self.head = 0;
+    {
+        if (kv_self.has_shift) {
+            kv_self.has_shift = false;
+            for (uint32_t i = 0; i < kv_self.size; ++i) {
+                kv_self.cells[i].delta = 0;
+            }
+        }
+
+        kv_self.head += n_tokens;
+
+        // Ensure kv cache head points to a valid index.
+        if (kv_self.head >= kv_self.size) {
+            kv_self.head = 0;
+        }
     }
 
 #ifdef GGML_PERF