]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
kv-cells : track min/max used cells and per-sequence positions (#13808)
authorGeorgi Gerganov <redacted>
Tue, 27 May 2025 10:49:41 +0000 (13:49 +0300)
committerGitHub <redacted>
Tue, 27 May 2025 10:49:41 +0000 (13:49 +0300)
* kv-cells : track min/max used cells and per-sequence positions

ggml-ci

* kv-cells : fix pos-modification updates for seq_pos

ggml-ci

* kv-cells : add comments

ggml-ci

src/llama-kv-cache.cpp
src/llama-kv-cache.h
src/llama-kv-cells.h

index ae2d2684f8cbad82d55d198455600f3fbaa81200..4a42d6ecdc4556f7528810afce958d2854d50f6f 100644 (file)
@@ -286,31 +286,11 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
 }
 
 llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
-    llama_pos result = std::numeric_limits<llama_pos>::max();
-
-    for (uint32_t i = 0; i < cells.size(); ++i) {
-        if (cells.seq_has(i, seq_id)) {
-            result = std::min(result, cells.pos_get(i));
-        }
-    }
-
-    if (result == std::numeric_limits<llama_pos>::max()) {
-        result = -1;
-    }
-
-    return result;
+    return cells.seq_pos_min(seq_id);
 }
 
 llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
-    llama_pos result = -1;
-
-    for (uint32_t i = 0; i < cells.size(); ++i) {
-        if (cells.seq_has(i, seq_id)) {
-            result = std::max(result, cells.pos_get(i));
-        }
-    }
-
-    return result;
+    return cells.seq_pos_max(seq_id);
 }
 
 void llama_kv_cache_unified::restore() {
@@ -504,7 +484,7 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
     // a heuristic, to avoid attending the full cache if it is not yet utilized
     // after enough generations, the benefit from this heuristic disappears
     // if we start defragmenting the cache, the benefit from this will be more important
-    n = std::min(cells.size(), std::max(n_pad, GGML_PAD(cell_max(), n_pad)));
+    n = std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
 
 #ifdef FIND_SLOT_DEBUG
     LLAMA_LOG_WARN("end:   n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
@@ -1018,7 +998,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
 bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
     const uint32_t n_layer = layers.size();
 
-    const uint32_t n_kv   = cell_max();
+    const uint32_t n_kv   = cells.used_max_p1();
     const uint32_t n_used = cells.get_used();
 
     assert(n_used <= n_kv);
@@ -1144,16 +1124,6 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
     return true;
 }
 
-uint32_t llama_kv_cache_unified::cell_max() const {
-    for (uint32_t i = cells.size(); i > 0; --i) {
-        if (!cells.is_empty(i - 1)) {
-            return i;
-        }
-    }
-
-    return 0;
-}
-
 bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
     assert(p0 >= 0 && p1 >= 0);
 
index 86a96820e2420e43f55a439f3d4467290ea7dc35..ce6261e45a6e17e0052d8fa6ffaeff32b9b91641 100644 (file)
@@ -246,10 +246,6 @@ private:
     // return true if cells have been moved
     bool defrag_prepare(int32_t n_max_nodes);
 
-    // find how many cells are currently in use
-    // TODO: optimize
-    uint32_t cell_max() const;
-
     size_t total_size() const;
 
     size_t size_k_bytes() const;
index 138545533ed2226b00c2bedaa31a1a8ff95900c1..dbbd03fcba2817834f8b82eab42f981dc3fec759 100644 (file)
@@ -6,6 +6,7 @@
 #include <bitset>
 #include <cassert>
 #include <vector>
+#include <set>
 
 // meta information about KV cells that can be part of multiple sequences at the same time
 // TODO: add unit tests
@@ -18,8 +19,13 @@ public:
             seq[i].reset();
         }
 
-        used      = 0;
         has_shift = false;
+
+        used.clear();
+
+        for (uint32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+            seq_pos[s].clear();
+        }
     }
 
     void reset_shift() {
@@ -50,7 +56,25 @@ public:
     }
 
     uint32_t get_used() const {
-        return used;
+        return used.size();
+    }
+
+    // the index of the first cell that is used
+    // return 0 if no cells are used
+    uint32_t used_min() const {
+        return used.empty() ? 0 : *used.begin();
+    }
+
+    // the index of the last cell that is used + 1
+    // return 0 if no cells are used
+    uint32_t used_max_p1() const {
+#if 0
+        if (!seq_pos[0].empty()) printf("kv_cells: min[0] = %5d, max[0] = %5d\n", *seq_pos[0].begin(), *seq_pos[0].rbegin());
+        if (!seq_pos[1].empty()) printf("kv_cells: min[1] = %5d, max[1] = %5d\n", *seq_pos[1].begin(), *seq_pos[1].rbegin());
+        if (!seq_pos[2].empty()) printf("kv_cells: min[2] = %5d, max[2] = %5d\n", *seq_pos[2].begin(), *seq_pos[2].rbegin());
+#endif
+
+        return used.empty() ? 0 : *used.rbegin() + 1;
     }
 
     bool get_has_shift() const {
@@ -69,6 +93,9 @@ public:
         pos  [isrc] = -1;
         shift[isrc] =  0;
         seq  [isrc].reset();
+
+        used.erase (isrc);
+        used.insert(idst);
     }
 
     // copy the state of cells [i, i + n) (used for save/restore the state of the cells)
@@ -95,16 +122,24 @@ public:
 
         for (uint32_t j = 0; j < other.pos.size(); ++j) {
             if (pos[i + j] == -1 && other.pos[j] != -1) {
-                used++;
+                used.insert(i + j);
             }
 
             if (pos[i + j] != -1 && other.pos[j] == -1) {
-                used--;
+                used.erase(i + j);
+            }
+
+            if (pos[i + j] != -1) {
+                seq_pos_rm(i + j);
             }
 
             pos[i + j] = other.pos[j];
             seq[i + j] = other.seq[j];
 
+            if (pos[i + j] != -1) {
+                seq_pos_add(i + j);
+            }
+
             assert(shift[i + j] == 0);
         }
     }
@@ -118,11 +153,12 @@ public:
         assert(seq_id >= 0);
 
         seq[i].reset(seq_id);
+        seq_pos[seq_id].erase(pos[i]);
 
         if (seq[i].none()) {
             pos[i] = -1;
 
-            used--;
+            used.erase(i);
 
             return true;
         }
@@ -135,17 +171,22 @@ public:
         assert(i < pos.size());
 
         if (seq[i].test(seq_id)) {
+            seq_pos_rm(i);
             seq[i].reset();
+
             seq[i].set(seq_id);
+            seq_pos[seq_id].insert(pos[i]);
 
             return false;
         }
 
         if (seq[i].any()) {
+            seq_pos_rm(i);
             seq[i].reset();
+
             pos[i] = -1;
 
-            used--;
+            used.erase(i);
 
             return true;
         }
@@ -169,6 +210,33 @@ public:
         assert(!seq[i].test(seq_id));
 
         seq[i].set(seq_id);
+        seq_pos[seq_id].insert(pos[i]);
+    }
+
+    // the minimum position of sequence seq_id currently present in any of the cells
+    // return -1 if the sequence is not present
+    llama_pos seq_pos_min(llama_seq_id seq_id) const {
+        assert(seq_id >= 0);
+        assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
+
+        if (seq_pos[seq_id].empty()) {
+            return -1;
+        }
+
+        return *seq_pos[seq_id].begin();
+    }
+
+    // the maximum position of sequence seq_id currently present in any of the cells
+    // return -1 if the sequence is not present
+    llama_pos seq_pos_max(llama_seq_id seq_id) const {
+        assert(seq_id >= 0);
+        assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
+
+        if (seq_pos[seq_id].empty()) {
+            return -1;
+        }
+
+        return *seq_pos[seq_id].rbegin();
     }
 
     // note: call only if the cell is not empty
@@ -202,7 +270,8 @@ public:
         assert(pos[i] == -1);
 
         pos[i] = p;
-        used++;
+
+        used.insert(i);
     }
 
     // pos[i] = pos[i] + d
@@ -212,16 +281,22 @@ public:
         assert(i < pos.size());
         assert(pos[i] != -1);
 
+        seq_pos_rm(i);
+
         pos[i]   += d;
         shift[i] += d;
 
+        seq_pos_add(i);
+
         has_shift = true;
 
         if (pos[i] < 0) {
-            pos[i] = -1;
+            seq_pos_rm(i);
+
             seq[i].reset();
+            pos[i] = -1;
 
-            used--;
+            used.erase(i);
 
             return true;
         }
@@ -238,17 +313,22 @@ public:
 
         const llama_pos p_old = pos[i];
 
+        seq_pos_rm(i);
+
         pos[i]   /= d;
         shift[i] += p_old - pos[i];
 
+        seq_pos_add(i);
+
         has_shift = true;
     }
 
 private:
-    uint32_t used = 0; // used cells (i.e. pos[i] != -1, allowed to not have any seq_id)
-
     bool has_shift = false;
 
+    // set of indices of used cells (i.e. pos[i] != -1, allowed to not have any seq_id)
+    std::set<uint32_t> used;
+
     std::vector<llama_pos> pos;
 
     // this array accumulates any applied shifts to the pos array since the last reset_shift() call
@@ -268,6 +348,32 @@ private:
     //
     std::vector<llama_pos> shift;
 
-    std::vector<std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>> seq;
-};
+    using bits_t = std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>;
+
+    // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
+    std::vector<bits_t> seq;
+
+    // the set seq_pos[s] tells us which positions are currently present for sequence s
+    // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
+    std::set<llama_pos> seq_pos[LLAMA_MAX_PARALLEL_SEQUENCES];
+
+    // helper functions for updating `seq_pos`, once cell at a time:
+
+    // remove cell i
+    void seq_pos_rm(uint32_t i) {
+        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+            if (seq[i].test(s)) {
+                seq_pos[s].erase(pos[i]);
+            }
+        }
+    }
 
+    // add cell i
+    void seq_pos_add(uint32_t i) {
+        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+            if (seq[i].test(s)) {
+                seq_pos[s].insert(pos[i]);
+            }
+        }
+    }
+};