]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
kv cache slot search improvements (#3493)
authorKerfuffle <redacted>
Fri, 6 Oct 2023 16:10:13 +0000 (10:10 -0600)
committerGitHub <redacted>
Fri, 6 Oct 2023 16:10:13 +0000 (10:10 -0600)
* kv cache slot search improvements

* Use n_ctx in kv find slot for consistency

* Ensure kv cache head points to a valid slot in llama_decode internal

* Add some comments to prevent dumb people (like me) from getting confused.

llama.cpp

index 1a7d37b8dec47471f6f57983015437e89a6e887c..79ea2b235602eb167be70a61e78bca62a7a592f2 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -1082,6 +1082,9 @@ struct llama_kv_cell {
 struct llama_kv_cache {
     bool has_shift = false;
 
+    // Note: The value of head isn't only used to optimize searching
+    // for a free KV slot. llama_decode_internal also uses it, so it
+    // cannot be freely changed after a slot has been allocated.
     uint32_t head = 0;
     uint32_t size = 0;
 
@@ -1339,6 +1342,8 @@ static bool llama_kv_cache_init(
 
 // find an empty slot of size "n_tokens" in the cache
 // updates the cache head
+// Note: On success, it's important that cache.head points
+// to the first cell of the slot.
 static bool llama_kv_cache_find_slot(
            struct llama_kv_cache & cache,
         const struct llama_batch & batch) {
@@ -1354,8 +1359,8 @@ static bool llama_kv_cache_find_slot(
 
     while (true) {
         if (cache.head + n_tokens > n_ctx) {
+            n_tested += n_ctx - cache.head;
             cache.head = 0;
-            n_tested   += n_ctx - cache.head;
             continue;
         }
 
@@ -1406,6 +1411,9 @@ static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0,
         cache.cells[i].pos = -1;
         cache.cells[i].seq_id.clear();
     }
+
+    // Searching for a free slot can start here since we know it will be empty.
+    cache.head = uint32_t(c0);
 }
 
 static void llama_kv_cache_seq_rm(
@@ -1413,6 +1421,8 @@ static void llama_kv_cache_seq_rm(
                  llama_seq_id   seq_id,
                     llama_pos   p0,
                     llama_pos   p1) {
+    uint32_t new_head = cache.size;
+
     if (p0 < 0) p0 = 0;
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
 
@@ -1421,9 +1431,13 @@ static void llama_kv_cache_seq_rm(
             cache.cells[i].seq_id.erase(seq_id);
             if (cache.cells[i].seq_id.empty()) {
                 cache.cells[i].pos = -1;
+                if (new_head == cache.size) new_head = i;
             }
         }
     }
+
+    // If we freed up a slot, set head to it so searching can start there.
+    if (new_head != cache.size) cache.head = new_head;
 }
 
 static void llama_kv_cache_seq_cp(
@@ -1435,6 +1449,8 @@ static void llama_kv_cache_seq_cp(
     if (p0 < 0) p0 = 0;
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
 
+    cache.head = 0;
+
     for (uint32_t i = 0; i < cache.size; ++i) {
         if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
             cache.cells[i].seq_id.insert(seq_id_dst);
@@ -1443,12 +1459,18 @@ static void llama_kv_cache_seq_cp(
 }
 
 static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) {
+    uint32_t new_head = cache.size;
+
     for (uint32_t i = 0; i < cache.size; ++i) {
         if (!cache.cells[i].has_seq_id(seq_id)) {
             cache.cells[i].pos = -1;
             cache.cells[i].seq_id.clear();
+            if (new_head == cache.size) new_head = i;
         }
     }
+
+    // If we freed up a slot, set head to it so searching can start there.
+    if (new_head != cache.size) cache.head = new_head;
 }
 
 static void llama_kv_cache_seq_shift(
@@ -1457,6 +1479,8 @@ static void llama_kv_cache_seq_shift(
                     llama_pos   p0,
                     llama_pos   p1,
                     llama_pos   delta) {
+    uint32_t new_head = cache.size;
+
     if (p0 < 0) p0 = 0;
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
 
@@ -1466,12 +1490,17 @@ static void llama_kv_cache_seq_shift(
             if (cache.cells[i].pos < 0) {
                 cache.cells[i].pos = -1;
                 cache.cells[i].seq_id.clear();
+                if (new_head == cache.size) new_head = i;
             } else {
                 cache.has_shift = true;
                 cache.cells[i].delta = delta;
             }
         }
     }
+
+    // If we freed up a slot, set head to it so searching can start there.
+    // Otherwise we just start the next search from the beginning.
+    cache.head = new_head != cache.size ? new_head : 0;
 }
 
 //
@@ -4492,10 +4521,6 @@ static int llama_decode_internal(
         batch.seq_id = seq_id.data();
     }
 
-    // we always start to search for a free slot from the start of the cache
-    // TODO: better strategies can be implemented
-    kv_self.head = 0;
-
     if (!llama_kv_cache_find_slot(kv_self, batch)) {
         return 1;
     }
@@ -4581,8 +4606,12 @@ static int llama_decode_internal(
 #endif
 
     // update the kv ring buffer
-    lctx.kv_self.head      += n_tokens;
     lctx.kv_self.has_shift  = false;
+    lctx.kv_self.head      += n_tokens;
+    // Ensure kv cache head points to a valid index.
+    if (lctx.kv_self.head >= lctx.kv_self.size) {
+        lctx.kv_self.head = 0;
+    }
 
 #ifdef GGML_PERF
     // print timing information per ggml operation (for debugging purposes)