]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
kv-cache : rework kv_cell (#13706)
authorGeorgi Gerganov <redacted>
Sun, 25 May 2025 13:34:36 +0000 (16:34 +0300)
committerGitHub <redacted>
Sun, 25 May 2025 13:34:36 +0000 (16:34 +0300)
* kv-cache : rework kv_cell

ggml-ci

* kv-cells : use "shift" instead of "delta" consistently

ggml-ci

* llama : add llama_max_parallel_sequences()

ggml-ci

* kv-cells : update comments [no ci]

* context : fail upon construction if sequences exceed max value

ggml-ci

* kv-cells : get_pos() -> pos_get() + comments

ggml-ci

* kv-cells : fix tracking of "used" cells

ggml-ci

include/llama.h
src/llama-context.cpp
src/llama-cparams.cpp
src/llama-cparams.h
src/llama-kv-cache.cpp
src/llama-kv-cache.h
src/llama-kv-cells.h [new file with mode: 0644]
src/llama-memory.h

index 52cd7a5a037ef3b4f21d322911bd6390c6ec211f..eafab7323d9bf9be18c2a1cdc885698272348840 100644 (file)
@@ -471,6 +471,7 @@ extern "C" {
     LLAMA_API int64_t llama_time_us(void);
 
     LLAMA_API size_t llama_max_devices(void);
+    LLAMA_API size_t llama_max_parallel_sequences(void);
 
     LLAMA_API bool llama_supports_mmap       (void);
     LLAMA_API bool llama_supports_mlock      (void);
index 85b4324b699e66292a681235a96f414f7d2c1dd3..98ecb7c8249ceaa484eb04ea4e12da120912867b 100644 (file)
@@ -25,7 +25,11 @@ llama_context::llama_context(
 
     const auto & hparams = model.hparams;
 
-    cparams.n_seq_max        = std::max(1u, params.n_seq_max);
+    cparams.n_seq_max = std::max(1u, params.n_seq_max);
+    if (cparams.n_seq_max > LLAMA_MAX_PARALLEL_SEQUENCES) {
+        throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_PARALLEL_SEQUENCES));
+    }
+
     cparams.n_threads        = params.n_threads;
     cparams.n_threads_batch  = params.n_threads_batch;
     cparams.yarn_ext_factor  = params.yarn_ext_factor;
index 28369be365252724d929420f10f243bbfcd27937..f7b36590fe3e3f0a9dc9dfee9e7f749d2610aee7 100644 (file)
@@ -1 +1,5 @@
 #include "llama-cparams.h"
+
+size_t llama_max_parallel_sequences(void) {
+    return LLAMA_MAX_PARALLEL_SEQUENCES;
+}
index 246fa5777deea1f6d4b94581d9b07b258b9434a2..2871031ef09619bbce252126c1ddd13fb4681dcd 100644 (file)
@@ -4,6 +4,8 @@
 
 #include <cstdint>
 
+#define LLAMA_MAX_PARALLEL_SEQUENCES 64
+
 struct llama_cparams {
     uint32_t n_ctx;           // context size used during inference
     uint32_t n_batch;
index a2624d71589b5acbdd8695f6587b08354cf317c9..ae2d2684f8cbad82d55d198455600f3fbaa81200 100644 (file)
@@ -65,8 +65,6 @@ llama_kv_cache_unified::llama_kv_cache_unified(
     };
 
     head = 0;
-    size = kv_size;
-    used = 0;
 
     cells.resize(kv_size);
 
@@ -138,13 +136,9 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 }
 
 void llama_kv_cache_unified::clear() {
-    for (uint32_t i = 0; i < size; ++i) {
-        cells[i].pos = -1;
-        cells[i].seq_id.clear();
-    }
+    cells.reset();
 
     head = 0;
-    used = 0;
 
     for (auto & buf : bufs) {
         ggml_backend_buffer_clear(buf.get(), 0);
@@ -152,7 +146,7 @@ void llama_kv_cache_unified::clear() {
 }
 
 bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
-    uint32_t new_head = size;
+    uint32_t new_head = cells.size();
 
     if (p0 < 0) {
         p0 = 0;
@@ -162,33 +156,20 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
         p1 = std::numeric_limits<llama_pos>::max();
     }
 
-    for (uint32_t i = 0; i < size; ++i) {
-        if (cells[i].pos >= p0 && cells[i].pos < p1) {
-            if (seq_id < 0) {
-                cells[i].seq_id.clear();
-            } else if (cells[i].has_seq_id(seq_id)) {
-                cells[i].seq_id.erase(seq_id);
-            } else {
-                continue;
-            }
-
-            if (cells[i].is_empty()) {
-                // keep count of the number of used cells
-                if (cells[i].pos >= 0) {
-                    used--;
-                }
-
-                cells[i].pos = -1;
+    for (uint32_t i = 0; i < cells.size(); ++i) {
+        if (!cells.pos_in(i, p0, p1)) {
+            continue;
+        }
 
-                if (new_head == size) {
-                    new_head = i;
-                }
+        if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
+            if (new_head == cells.size()) {
+                new_head = i;
             }
         }
     }
 
     // If we freed up a slot, set head to it so searching can start there.
-    if (new_head != size && new_head < head) {
+    if (new_head != cells.size() && new_head < head) {
         head = new_head;
     }
 
@@ -208,49 +189,40 @@ void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id
         p1 = std::numeric_limits<llama_pos>::max();
     }
 
-    // otherwise, this is the KV of a Transformer-like model
-    head = 0;
+    for (uint32_t i = 0; i < cells.size(); ++i) {
+        if (!cells.pos_in(i, p0, p1)) {
+            continue;
+        }
 
-    for (uint32_t i = 0; i < size; ++i) {
-        if (cells[i].has_seq_id(seq_id_src) && cells[i].pos >= p0 && cells[i].pos < p1) {
-            cells[i].seq_id.insert(seq_id_dst);
+        if (cells.seq_has(i, seq_id_src)) {
+            cells.seq_add(i, seq_id_dst);
         }
     }
 }
 
 void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
-    uint32_t new_head = size;
+    uint32_t new_head = cells.size();
 
-    for (uint32_t i = 0; i < size; ++i) {
-        if (!cells[i].has_seq_id(seq_id)) {
-            if (cells[i].pos >= 0) {
-                used--;
-            }
-
-            cells[i].pos = -1;
-            cells[i].seq_id.clear();
-
-            if (new_head == size){
+    for (uint32_t i = 0; i < cells.size(); ++i) {
+        if (cells.seq_keep(i, seq_id)) {
+            if (new_head == cells.size()) {
                 new_head = i;
             }
-        } else {
-            cells[i].seq_id.clear();
-            cells[i].seq_id.insert(seq_id);
         }
     }
 
     // If we freed up a slot, set head to it so searching can start there.
-    if (new_head != size && new_head < head) {
+    if (new_head != cells.size() && new_head < head) {
         head = new_head;
     }
 }
 
-void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
-    if (delta == 0) {
+void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+    if (shift == 0) {
         return;
     }
 
-    uint32_t new_head = size;
+    uint32_t new_head = cells.size();
 
     if (p0 < 0) {
         p0 = 0;
@@ -260,25 +232,19 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po
         p1 = std::numeric_limits<llama_pos>::max();
     }
 
-    // If there is no range then return early to avoid looping over the
+    // If there is no range then return early to avoid looping over all cells.
     if (p0 == p1) {
         return;
     }
 
-    for (uint32_t i = 0; i < size; ++i) {
-        if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
-            has_shift = true;
-
-            cells[i].pos   += delta;
-            cells[i].delta += delta;
+    for (uint32_t i = 0; i < cells.size(); ++i) {
+        if (!cells.pos_in(i, p0, p1)) {
+            continue;
+        }
 
-            if (cells[i].pos < 0) {
-                if (!cells[i].is_empty()) {
-                    used--;
-                }
-                cells[i].pos = -1;
-                cells[i].seq_id.clear();
-                if (new_head == size) {
+        if (cells.seq_has(i, seq_id)) {
+            if (cells.pos_add(i, shift)) {
+                if (new_head == cells.size()) {
                     new_head = i;
                 }
             }
@@ -287,7 +253,7 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po
 
     // If we freed up a slot, set head to it so searching can start there.
     // Otherwise we just start the next search from the beginning.
-    head = new_head != size ? new_head : 0;
+    head = new_head != cells.size() ? new_head : 0;
 }
 
 void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
@@ -308,15 +274,13 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
         return;
     }
 
-    for (uint32_t i = 0; i < size; ++i) {
-        if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
-            has_shift = true;
+    for (uint32_t i = 0; i < cells.size(); ++i) {
+        if (!cells.pos_in(i, p0, p1)) {
+            continue;
+        }
 
-            {
-                llama_pos p_old = cells[i].pos;
-                cells[i].pos   /= d;
-                cells[i].delta += cells[i].pos - p_old;
-            }
+        if (cells.seq_has(i, seq_id)) {
+            cells.pos_div(i, d);
         }
     }
 }
@@ -324,9 +288,9 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
 llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
     llama_pos result = std::numeric_limits<llama_pos>::max();
 
-    for (uint32_t i = 0; i < size; ++i) {
-        if (cells[i].has_seq_id(seq_id)) {
-            result = std::min(result, cells[i].pos);
+    for (uint32_t i = 0; i < cells.size(); ++i) {
+        if (cells.seq_has(i, seq_id)) {
+            result = std::min(result, cells.pos_get(i));
         }
     }
 
@@ -340,9 +304,9 @@ llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
 llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
     llama_pos result = -1;
 
-    for (uint32_t i = 0; i < size; ++i) {
-        if (cells[i].has_seq_id(seq_id)) {
-            result = std::max(result, cells[i].pos);
+    for (uint32_t i = 0; i < cells.size(); ++i) {
+        if (cells.seq_has(i, seq_id)) {
+            result = std::max(result, cells.pos_get(i));
         }
     }
 
@@ -350,25 +314,15 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
 }
 
 void llama_kv_cache_unified::restore() {
-    for (const auto & [id, cell] : recovery.cells) {
-        // TODO: move to new `struct kv_cells`
-        const bool is_empty0 = cells[id].is_empty();
-        const bool is_empty1 = cell.is_empty();
-
-        if (!is_empty0 && is_empty1) {
-            used--;
-        } else if (is_empty0 && !is_empty1) {
-            used++;
-        }
-
-        cells[id] = cell;
+    for (auto & state : recovery.states) {
+        cells.set(state.i, state.cells);
     }
 
     recovery.clear();
 }
 
 void llama_kv_cache_unified::commit() {
-    if (recovery.cells.empty()) {
+    if (recovery.states.empty()) {
         LLAMA_LOG_WARN("%s: the recovery information upon a commit was empty - might indicate a bug (ref: %s)\n",
                 __func__, "https://github.com/ggml-org/llama.cpp/pull/13194");
         return;
@@ -382,7 +336,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
 
     auto * sched = lctx.get_sched();
 
-    if (has_shift) {
+    if (cells.get_has_shift()) {
         if (!get_can_shift()) {
             GGML_ABORT("The current KV cache / model configuration does not support K-shift");
         }
@@ -406,13 +360,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
             need_reserve = true;
         }
 
-        {
-            has_shift = false;
-
-            for (uint32_t i = 0; i < size; ++i) {
-                cells[i].delta = 0;
-            }
-        }
+        cells.reset_shift();
     }
 
     if (do_defrag) {
@@ -443,7 +391,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
 void llama_kv_cache_unified::defrag_sched(float thold) {
     // - do not defrag small contexts (i.e. < 2048 tokens)
     // - count the padding towards the number of used tokens
-    const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(used + n_pad)/n)) : 0.0f;
+    const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n)) : 0.0f;
 
     // queue defragmentation for next llama_kv_cache_update
     if (fragmentation > thold) {
@@ -454,7 +402,7 @@ void llama_kv_cache_unified::defrag_sched(float thold) {
 }
 
 void llama_kv_cache_unified::set_full() {
-    n = size;
+    n = cells.size();
 
     // when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not
     //   affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views.
@@ -478,14 +426,14 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
 
     // if we have enough unused cells before the current head ->
     //   better to start searching from the beginning of the cache, hoping to fill it
-    if (head > used + 2*ubatch.n_tokens) {
+    if (head > cells.get_used() + 2*ubatch.n_tokens) {
         head = 0;
     }
 
     // otherwise, one cell per token.
 
-    if (n_tokens > size) {
-        LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size);
+    if (n_tokens > cells.size()) {
+        LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
         return false;
     }
 
@@ -498,10 +446,10 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
         std::string ss;
         if (n_swa > 0) {
             for (uint32_t i = 0; i < size; ++i) {
-                if (cells[i].pos == -1) {
+                if (cells.is_empty(i)) {
                     ss += '.';
                 } else {
-                    ss += std::to_string(*cells[i].seq_id.begin());
+                    ss += 'x';
                 }
                 if (i%256 == 255) {
                     ss += '\n';
@@ -515,15 +463,16 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
     uint32_t n_tested = 0;
 
     while (true) {
-        if (head + n_tokens > size) {
-            n_tested += size - head;
+        if (head + n_tokens > cells.size()) {
+            n_tested += cells.size() - head;
             head = 0;
             continue;
         }
 
         bool found = true;
         for (uint32_t i = 0; i < n_tokens; i++) {
-            if (cells[head + i].pos >= 0) {
+            // TODO: improve to accept cells that are masked by the SWA
+            if (!cells.is_empty(head + i)) {
                 found = false;
                 head     += i + 1;
                 n_tested += i + 1;
@@ -535,31 +484,27 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
             break;
         }
 
-        if (n_tested >= size) {
+        if (n_tested >= cells.size()) {
             //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
             return false;
         }
     }
 
-    for (uint32_t i = 0; i < n_tokens; ++i) {
-        // remember the original state
-        if (recovery.cells.find(head + i) == recovery.cells.end()) {
-            recovery.cells[head + i] = cells[head + i];
-        }
+    // store the old state of the cells in the recovery stack
+    recovery.states.push_back({head, cells.cp(head, n_tokens)});
 
-        cells[head + i].pos = ubatch.pos[i];
+    for (uint32_t i = 0; i < n_tokens; ++i) {
+        cells.pos_set(head + i, ubatch.pos[i]);
 
         for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) {
-            cells[head + i].seq_id.insert(ubatch.seq_id[i][j]);
+            cells.seq_add(head + i, ubatch.seq_id[i][j]);
         }
     }
 
-    used += n_tokens;
-
     // a heuristic, to avoid attending the full cache if it is not yet utilized
     // after enough generations, the benefit from this heuristic disappears
     // if we start defragmenting the cache, the benefit from this will be more important
-    n = std::min(size, std::max(n_pad, GGML_PAD(cell_max(), n_pad)));
+    n = std::min(cells.size(), std::max(n_pad, GGML_PAD(cell_max(), n_pad)));
 
 #ifdef FIND_SLOT_DEBUG
     LLAMA_LOG_WARN("end:   n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
@@ -577,7 +522,7 @@ uint32_t llama_kv_cache_unified::get_n() const {
 }
 
 uint32_t llama_kv_cache_unified::get_size() const {
-    return size;
+    return cells.size();
 }
 
 ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il) const {
@@ -661,30 +606,19 @@ void llama_kv_cache_unified::prune_swa(llama_seq_id seq_id, llama_pos pmin, llam
 
     int n_attended = 0;
 
-    for (uint32_t i = 0; i < size; ++i) {
-        const llama_pos p0 = cells[i].pos;
+    for (uint32_t i = 0; i < cells.size(); ++i) {
+        if (!cells.seq_has(i, seq_id)) {
+            continue;
+        }
+
+        const llama_pos p0 = cells.pos_get(i);
 
         if (p0 <= pmin && !is_masked_swa(p0, pmin)) {
             n_attended++;
         }
 
         if (is_masked_swa(p0, pmax)) {
-            if (seq_id < 0) {
-                cells[i].seq_id.clear();
-            } else if (cells[i].has_seq_id(seq_id)) {
-                cells[i].seq_id.erase(seq_id);
-            } else {
-                continue;
-            }
-
-            if (cells[i].is_empty()) {
-                // keep count of the number of used cells
-                if (cells[i].pos >= 0) {
-                    used--;
-                }
-
-                cells[i].pos = -1;
-            }
+            cells.seq_rm(i, seq_id);
         }
     }
 
@@ -723,25 +657,31 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
                 const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j];
 
                 for (int i = 0; i < n_kv; ++i) {
-                    const llama_pos p0 = cells[i].pos;
+                    float f = 0.0f;
 
                     bool masked = false;
 
-                    // mask the token if not the same sequence
-                    masked = masked || (!cells[i].has_seq_id(seq_id));
+                    if (cells.is_empty(i)) {
+                        masked = true;
+                    } else {
+                        const llama_pos p0 = cells.pos_get(i);
 
-                    // mask future tokens
-                    masked = masked || (causal_attn && p0 > p1);
+                        // mask the token if not the same sequence
+                        masked = masked || (!cells.seq_has(i, seq_id));
 
-                    // apply SWA if any
-                    masked = masked || (is_masked_swa(p0, p1));
+                        // mask future tokens
+                        masked = masked || (causal_attn && p0 > p1);
 
-                    float f = 0.0f;
+                        // apply SWA if any
+                        masked = masked || (is_masked_swa(p0, p1));
+
+                        if (!masked && hparams.use_alibi) {
+                            f = -std::abs(p0 - p1);
+                        }
+                    }
 
                     if (masked) {
                         f = -INFINITY;
-                    } else if (hparams.use_alibi) {
-                        f = -std::abs(p0 - p1);
                     }
 
                     data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
@@ -765,8 +705,8 @@ void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
 
     int32_t * data = (int32_t *) dst->data;
 
-    for (uint32_t i = 0; i < size; ++i) {
-        data[i] = cells[i].delta;
+    for (uint32_t i = 0; i < cells.size(); ++i) {
+        data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
     }
 }
 
@@ -783,7 +723,10 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama
     for (int h = 0; h < 1; ++h) {
         for (int j = 0; j < n_tokens; ++j) {
             for (int i = 0; i < n_kv; ++i) {
-                data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(cells[i].pos, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
+                // the position when the cells is empty is irrelevant - it will be masked out later in the attention
+                const llama_pos p0 = cells.is_empty(i) ? -1 : cells.pos_get(i);
+
+                data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(p0, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
             }
         }
     }
@@ -910,7 +853,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
 
         ggml_tensor * k =
             ggml_view_3d(ctx, layer.k,
-                n_embd_head_k, n_head_kv, size,
+                n_embd_head_k, n_head_kv, cells.size(),
                 ggml_row_size(layer.k->type, n_embd_head_k),
                 ggml_row_size(layer.k->type, n_embd_k_gqa),
                 0);
@@ -1050,12 +993,12 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
             } else {
                 view_v_src = ggml_view_2d(ctx, layer.v,
                         nm, n_embd_v_gqa,
-                        ggml_row_size(layer.v->type, size),
+                        ggml_row_size(layer.v->type, cells.size()),
                         ggml_row_size(layer.v->type, i));
 
                 view_v_dst = ggml_view_2d(ctx, layer.v,
                         nm, n_embd_v_gqa,
-                        ggml_row_size(layer.v->type, size),
+                        ggml_row_size(layer.v->type, cells.size()),
                         ggml_row_size(layer.v->type, id));
             }
 
@@ -1076,7 +1019,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
     const uint32_t n_layer = layers.size();
 
     const uint32_t n_kv   = cell_max();
-    const uint32_t n_used = used;
+    const uint32_t n_used = cells.get_used();
 
     assert(n_used <= n_kv);
 
@@ -1104,9 +1047,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
     ids.resize(n_kv, n_kv);
 
     for (uint32_t i0 = 0; i0 < n_used; ++i0) {
-        const auto & cell0 = cells[i0];
-
-        if (!cell0.is_empty()) {
+        if (!cells.is_empty(i0)) {
             ids[i0] = i0;
 
             continue;
@@ -1117,7 +1058,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
         uint32_t nh = 1;
 
         // determine the size of the hole
-        while (i0 + nh < n_used && cells[i0 + nh].is_empty()) {
+        while (i0 + nh < n_used && cells.is_empty(i0 + nh)) {
             nh++;
         }
 
@@ -1126,9 +1067,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
 
         // starting from the end, find nh non-empty cells
         for (; is > i0; --is) {
-            const auto & cell1 = cells[is];
-
-            if (cell1.is_empty() || ids[is] != n_kv) {
+            if (cells.is_empty(is) || ids[is] != n_kv) {
                 continue;
             }
 
@@ -1155,9 +1094,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
 
         // go back and move the nf cells to the hole
         for (; i1 < n_kv; ++i1) {
-            auto & cell1 = cells[i1];
-
-            if (cell1.is_empty() || ids[i1] != n_kv) {
+            if (cells.is_empty(i1) || ids[i1] != n_kv) {
                 if (n_moves == max_moves) {
                     stop = true;
                     break;
@@ -1171,10 +1108,8 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
             ids[i1] = i0 + nf;
 
             // move the cell meta data
-            cells[i0 + nf] = cell1;
+            cells.mv(i1, i0 + nf);
 
-            // clear the old cell and move the head there
-            cell1 = kv_cell();
             head = n_used;
 
             if (!cont) {
@@ -1210,10 +1145,8 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
 }
 
 uint32_t llama_kv_cache_unified::cell_max() const {
-    for (uint32_t i = size; i > 0; --i) {
-        const kv_cell & cell = cells[i - 1];
-
-        if (cell.pos >= 0 && !cell.is_empty()) {
+    for (uint32_t i = cells.size(); i > 0; --i) {
+        if (!cells.is_empty(i - 1)) {
             return i;
         }
     }
@@ -1222,9 +1155,7 @@ uint32_t llama_kv_cache_unified::cell_max() const {
 }
 
 bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
-    if (p0 < 0) {
-        return true;
-    }
+    assert(p0 >= 0 && p1 >= 0);
 
     switch (swa_type) {
         case LLAMA_SWA_TYPE_NONE:
@@ -1255,23 +1186,24 @@ void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq
 
     // Count the number of cells with the specified seq_id
     // Find all the ranges of cells with this seq id (or all, when -1)
-    uint32_t cell_range_begin = size;
-    for (uint32_t i = 0; i < size; ++i) {
-        const auto & cell = cells[i];
-        if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
+    uint32_t cell_range_begin = cells.size();
+
+    for (uint32_t i = 0; i < cells.size(); ++i) {
+        if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) {
             ++cell_count;
-            if (cell_range_begin == size) {
+            if (cell_range_begin == cells.size()) {
                 cell_range_begin = i;
             }
         } else {
-            if (cell_range_begin != size) {
+            if (cell_range_begin != cells.size()) {
                 cell_ranges.emplace_back(cell_range_begin, i);
-                cell_range_begin = size;
+                cell_range_begin = cells.size();
             }
         }
     }
-    if (cell_range_begin != size) {
-        cell_ranges.emplace_back(cell_range_begin, size);
+
+    if (cell_range_begin != cells.size()) {
+        cell_ranges.emplace_back(cell_range_begin, cells.size());
     }
 
     // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
@@ -1308,17 +1240,24 @@ void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_i
 void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
     for (const auto & range : cell_ranges) {
         for (uint32_t i = range.first; i < range.second; ++i) {
-            const auto & cell = cells[i];
-            const llama_pos pos      = cell.pos;
-            const uint32_t  n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;
+            std::vector<llama_seq_id> seq_ids;
+
+            for (llama_seq_id cur = 0; cur < (int) n_seq_max; ++cur) {
+                if (cur == seq_id || seq_id == -1) {
+                    if (cells.seq_has(i, cur)) {
+                        seq_ids.push_back(cur);
+                    }
+                }
+            }
+
+            const llama_pos pos     = cells.pos_get(i);
+            const uint32_t n_seq_id = seq_ids.size();
 
             io.write(&pos,      sizeof(pos));
             io.write(&n_seq_id, sizeof(n_seq_id));
 
-            if (n_seq_id) {
-                for (auto seq_id : cell.seq_id) {
-                    io.write(&seq_id, sizeof(seq_id));
-                }
+            for (const auto & seq_id : seq_ids) {
+                io.write(&seq_id, sizeof(seq_id));
             }
         }
     }
@@ -1379,7 +1318,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
         }
     } else {
         // When v is transposed, we also need the element size and get the element ranges from each row
-        const uint32_t kv_size = size;
+        const uint32_t kv_size = cells.size();
 
         for (const auto & layer : layers) {
             const uint32_t il = layer.il;
@@ -1429,14 +1368,20 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
             io.read_to(&pos,      sizeof(pos));
             io.read_to(&n_seq_id, sizeof(n_seq_id));
 
-            if (n_seq_id != 0) {
+            if (n_seq_id != 1) {
                 LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
                 return false;
             }
 
-            batch.pos[i] = pos;
-            batch.n_seq_id[i] = 1;
-            batch.seq_id[i] = &dest_seq_id;
+            // read the sequence id, but directly discard it - we will use dest_seq_id instead
+            {
+                llama_seq_id seq_id;
+                io.read_to(&seq_id, sizeof(seq_id));
+            }
+
+            batch.pos[i]      = pos;
+            batch.n_seq_id[i] = n_seq_id;
+            batch.seq_id[i]   = &dest_seq_id;
         }
 
         if (!find_slot(batch)) {
@@ -1448,15 +1393,15 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
 
         // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
         // Assume that this is one contiguous block of cells
-        GGML_ASSERT(head + cell_count <= size);
-        GGML_ASSERT(cells[head].pos == batch.pos[0]);
-        GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]);
-        GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
-        GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
+        GGML_ASSERT(head + cell_count <= cells.size());
+        GGML_ASSERT(cells.pos_get(head)                  == batch.pos[0]);
+        GGML_ASSERT(cells.pos_get(head + cell_count - 1) == batch.pos[cell_count - 1]);
+        GGML_ASSERT(cells.seq_has(head,                  dest_seq_id));
+        GGML_ASSERT(cells.seq_has(head + cell_count - 1, dest_seq_id));
     } else {
         // whole KV cache restore
 
-        if (cell_count > size) {
+        if (cell_count > cells.size()) {
             LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
             return false;
         }
@@ -1464,15 +1409,13 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
         clear();
 
         for (uint32_t i = 0; i < cell_count; ++i) {
-            kv_cell & cell = cells[i];
-
             llama_pos pos;
             uint32_t  n_seq_id;
 
             io.read_to(&pos,      sizeof(pos));
             io.read_to(&n_seq_id, sizeof(n_seq_id));
 
-            cell.pos = pos;
+            cells.pos_set(i, pos);
 
             for (uint32_t j = 0; j < n_seq_id; ++j) {
                 llama_seq_id seq_id;
@@ -1483,12 +1426,11 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
                     return false;
                 }
 
-                cell.seq_id.insert(seq_id);
+                cells.seq_add(i, seq_id);
             }
         }
 
         head = 0;
-        used = cell_count;
     }
 
     return true;
@@ -1505,8 +1447,8 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
         LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size());
         return false;
     }
-    if (cell_count > size) {
-        LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
+    if (cell_count > cells.size()) {
+        LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, cells.size());
         return false;
     }
     if (this->v_trans != (bool) v_trans) {
@@ -1609,7 +1551,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
             if (cell_count) {
                 // For each row in the transposed matrix, read the values for the whole cell range
                 for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                    const size_t dst_offset = (head + j * size) * v_size_el;
+                    const size_t dst_offset = (head + j * cells.size()) * v_size_el;
                     ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
                 }
             }
@@ -1689,9 +1631,9 @@ void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) {
     kv_swa ->seq_keep(seq_id);
 }
 
-void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
-    kv_base->seq_add(seq_id, p0, p1, delta);
-    kv_swa ->seq_add(seq_id, p0, p1, delta);
+void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+    kv_base->seq_add(seq_id, p0, p1, shift);
+    kv_swa ->seq_add(seq_id, p0, p1, shift);
 }
 
 void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
@@ -2063,8 +2005,8 @@ void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) {
     }
 }
 
-void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
-    if (delta == 0) {
+void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+    if (shift == 0) {
         return;
     }
 
@@ -2087,7 +2029,7 @@ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_
         if (tail_id >= 0) {
             kv_cell & cell = cells[tail_id];
             if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
-                cell.pos += delta;
+                cell.pos += shift;
             }
         }
     }
index 191a1090a1252da447fb66b4f21f1b411a02d8c3..86a96820e2420e43f55a439f3d4467290ea7dc35 100644 (file)
@@ -4,6 +4,7 @@
 #include "llama-io.h"
 #include "llama-graph.h"
 #include "llama-memory.h"
+#include "llama-kv-cells.h"
 
 #include "ggml-cpp.h"
 
@@ -35,6 +36,7 @@ struct llama_kv_cache : public llama_memory_i {
     virtual void defrag_sched(float thold) = 0;
 
     // simulate full cache, used for allocating worst-case compute buffers
+    // TODO: remove
     virtual void set_full() = 0;
 
     //
@@ -42,7 +44,7 @@ struct llama_kv_cache : public llama_memory_i {
     //
 
     // =============================================================================================================
-    // TODO: refactor  and simplify this
+    // TODO: refactor and simplify this [TAG: KV_API]
 
     virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;
 
@@ -121,7 +123,7 @@ public:
     bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
     void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
     void seq_keep(llama_seq_id seq_id)                                                          override;
-    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos delta) override;
+    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
     void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
 
     llama_pos seq_pos_min(llama_seq_id seq_id) const override;
@@ -159,7 +161,7 @@ public:
     // llama_kv_cache_unified specific API
     //
 
-    uint32_t get_n() const;
+    uint32_t get_n()    const;
     uint32_t get_size() const;
 
     // get views of the current state of the cache
@@ -180,26 +182,6 @@ private:
     const llama_model & model;
     const llama_hparams & hparams;
 
-    struct kv_cell {
-        llama_pos pos   = -1;
-        llama_pos delta =  0;
-
-        // TODO: replace with bitset uint64_t
-        std::set<llama_seq_id> seq_id;
-
-        bool has_seq_id(const llama_seq_id & id) const {
-            return seq_id.find(id) != seq_id.end();
-        }
-
-        bool is_empty() const {
-            return seq_id.empty();
-        }
-
-        bool is_same_seq(const kv_cell & other) const {
-            return seq_id == other.seq_id;
-        }
-    };
-
     struct kv_layer {
         // layer index in the model
         // note: can be different from the layer index in the KV cache
@@ -209,15 +191,13 @@ private:
         ggml_tensor * v;
     };
 
-    bool has_shift = false;
     bool do_defrag = false;
     bool v_trans   = true;  // the value tensor is transposed
 
     uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
-    uint32_t size = 0; // total number of cells, shared across all sequences
-    uint32_t used = 0; // used cells (i.e. at least one seq_id) (TODO: add `struct kv_cells` and keep track automaticallt)
 
     // computed before each graph build
+    // TODO: cells should start to maintain this value dynamically based on the edits
     uint32_t n = 0;
 
     const uint32_t n_seq_max = 1;
@@ -233,19 +213,29 @@ private:
     std::vector<ggml_context_ptr>        ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;
 
-    std::vector<kv_cell>  cells;  // TODO: replace with `struct kv_cells`
+    llama_kv_cells_unified cells;
+
     std::vector<kv_layer> layers;
 
     // model layer id -> KV cache layer id
     std::unordered_map<int32_t, int32_t> map_layer_ids;
 
     // recovery information used to restore the KV cells to their original state in case of a failure
+    // TODO: do not store as a state in the llama_kv_cache object, instead return upon batch preparation
+    //       to achieve that, first need to refactor the llama_kv_cache interface [TAG: KV_API]
     struct {
         void clear() {
-            cells.clear();
+            states.clear();
         }
 
-        std::unordered_map<uint32_t, kv_cell> cells;
+        struct state {
+            uint32_t i;
+
+            llama_kv_cells_unified cells;
+        };
+
+        // stack with the partial states before each ubatch
+        std::vector<state> states;
     } recovery;
 
     // defrag
@@ -257,6 +247,7 @@ private:
     bool defrag_prepare(int32_t n_max_nodes);
 
     // find how many cells are currently in use
+    // TODO: optimize
     uint32_t cell_max() const;
 
     size_t total_size() const;
@@ -325,7 +316,7 @@ public:
     bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
     void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
     void seq_keep(llama_seq_id seq_id)                                                          override;
-    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos delta) override;
+    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
     void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
 
     llama_pos seq_pos_min(llama_seq_id seq_id) const override;
@@ -431,7 +422,7 @@ public:
     bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
     void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
     void seq_keep(llama_seq_id seq_id)                                                          override;
-    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos delta) override;
+    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
     void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
 
     llama_pos seq_pos_min(llama_seq_id seq_id) const override;
diff --git a/src/llama-kv-cells.h b/src/llama-kv-cells.h
new file mode 100644 (file)
index 0000000..1385455
--- /dev/null
@@ -0,0 +1,273 @@
+#pragma once
+
+#include "llama.h"
+#include "llama-cparams.h"
+
+#include <bitset>
+#include <cassert>
+#include <vector>
+
+// meta information about KV cells that can be part of multiple sequences at the same time
+// TODO: add unit tests
+class llama_kv_cells_unified {
+public:
+    void reset() {
+        for (uint32_t i = 0; i < pos.size(); ++i) {
+            pos[i]   = -1;
+            shift[i] =  0;
+            seq[i].reset();
+        }
+
+        used      = 0;
+        has_shift = false;
+    }
+
+    void reset_shift() {
+        has_shift = false;
+
+        for (uint32_t i = 0; i < shift.size(); ++i) {
+            shift[i] = 0;
+        }
+    }
+
+    uint32_t size() const {
+        return pos.size();
+    }
+
+    void resize(uint32_t n) {
+        pos.resize(n);
+        shift.resize(n);
+        seq.resize(n);
+
+        reset();
+    }
+
+    bool is_empty(uint32_t i) const {
+        assert(i < pos.size());
+        assert((pos[i] < 0 && pos[i] == -1) || pos[i] >= 0);
+
+        return pos[i] == -1;
+    }
+
+    uint32_t get_used() const {
+        return used;
+    }
+
+    bool get_has_shift() const {
+        return has_shift;
+    }
+
+    // move cell isrc to idst (used during defrag)
+    void mv(uint32_t isrc, uint32_t idst) {
+        assert(isrc < pos.size());
+        assert(idst < pos.size());
+
+        pos  [idst] = pos  [isrc];
+        shift[idst] = shift[isrc];
+        seq  [idst] = seq  [isrc];
+
+        pos  [isrc] = -1;
+        shift[isrc] =  0;
+        seq  [isrc].reset();
+    }
+
+    // copy the state of cells [i, i + n) (used for save/restore the state of the cells)
+    llama_kv_cells_unified cp(uint32_t i, uint32_t n) const {
+        assert(i + n <= pos.size());
+
+        llama_kv_cells_unified res;
+
+        res.resize(n);
+
+        for (uint32_t j = 0; j < n; ++j) {
+            res.pos[j] = pos[i + j];
+            res.seq[j] = seq[i + j];
+
+            assert(shift[i + j] == 0);
+        }
+
+        return res;
+    }
+
+    // set the state of cells [i, i + other.pos.size()) (used for save/restore the state of the cells)
+    void set(uint32_t i, const llama_kv_cells_unified & other) {
+        assert(i + other.pos.size() <= pos.size());
+
+        for (uint32_t j = 0; j < other.pos.size(); ++j) {
+            if (pos[i + j] == -1 && other.pos[j] != -1) {
+                used++;
+            }
+
+            if (pos[i + j] != -1 && other.pos[j] == -1) {
+                used--;
+            }
+
+            pos[i + j] = other.pos[j];
+            seq[i + j] = other.seq[j];
+
+            assert(shift[i + j] == 0);
+        }
+    }
+
+    // note: call only if the cell has seq_id
+    // return true if the cell becomes empty
+    bool seq_rm(uint32_t i, llama_seq_id seq_id) {
+        assert(i < pos.size());
+        assert(seq[i].test(seq_id));
+        assert(pos[i] != -1);
+        assert(seq_id >= 0);
+
+        seq[i].reset(seq_id);
+
+        if (seq[i].none()) {
+            pos[i] = -1;
+
+            used--;
+
+            return true;
+        }
+
+        return false;
+    }
+
+    // return true if the cell becomes empty (i.e. it did not contain seq_id before the call)
+    bool seq_keep(uint32_t i, llama_seq_id seq_id) {
+        assert(i < pos.size());
+
+        if (seq[i].test(seq_id)) {
+            seq[i].reset();
+            seq[i].set(seq_id);
+
+            return false;
+        }
+
+        if (seq[i].any()) {
+            seq[i].reset();
+            pos[i] = -1;
+
+            used--;
+
+            return true;
+        }
+
+        assert(pos[i] == -1);
+
+        return false;
+    }
+
+    bool seq_has(uint32_t i, llama_seq_id seq_id) const {
+        assert(i < pos.size());
+        assert(seq_id >= 0);
+
+        return seq[i].test(seq_id);
+    }
+
+    // note: call only if the cell is not empty and the seq_id is not in the cell
+    void seq_add(uint32_t i, llama_seq_id seq_id) {
+        assert(i < pos.size());
+        assert(pos[i] != -1);
+        assert(!seq[i].test(seq_id));
+
+        seq[i].set(seq_id);
+    }
+
+    // note: call only if the cell is not empty
+    llama_pos pos_get(uint32_t i) const {
+        assert(i < pos.size());
+        assert(pos[i] != -1);
+
+        return pos[i];
+    }
+
+    // note: call only if the cell is not empty
+    llama_pos get_shift(uint32_t i) const {
+        assert(i < pos.size());
+        assert(pos[i] != -1);
+
+        return shift[i];
+    }
+
+    // check if a cell is not empty and its position is within [p0, p1)
+    bool pos_in(uint32_t i, llama_pos p0, llama_pos p1) const {
+        assert(i < pos.size());
+
+        return pos[i] >= p0 && pos[i] < p1;
+    }
+
+    // set the position of an empty cell
+    // does not modify "has_shift"
+    // note: call only if the cell is empty
+    void pos_set(uint32_t i, llama_pos p) {
+        assert(i < pos.size());
+        assert(pos[i] == -1);
+
+        pos[i] = p;
+        used++;
+    }
+
+    // pos[i] = pos[i] + d
+    // sets "has_shift" to true
+    // note: call only if the cell is not empty
+    bool pos_add(uint32_t i, llama_pos d) {
+        assert(i < pos.size());
+        assert(pos[i] != -1);
+
+        pos[i]   += d;
+        shift[i] += d;
+
+        has_shift = true;
+
+        if (pos[i] < 0) {
+            pos[i] = -1;
+            seq[i].reset();
+
+            used--;
+
+            return true;
+        }
+
+        return false;
+    }
+
+    // pos[i] = pos[i] / d
+    // sets "has_shift" to true
+    // note: call only if the cell is not empty
+    void pos_div(uint32_t i, int d) {
+        assert(i < pos.size());
+        assert(pos[i] != -1);
+
+        const llama_pos p_old = pos[i];
+
+        pos[i]   /= d;
+        shift[i] += p_old - pos[i];
+
+        has_shift = true;
+    }
+
+private:
+    uint32_t used = 0; // used cells (i.e. pos[i] != -1, allowed to not have any seq_id)
+
+    bool has_shift = false;
+
+    std::vector<llama_pos> pos;
+
+    // this array accumulates any applied shifts to the pos array since the last reset_shift() call
+    // this is used to queue multiple updates to the pos array, which in the end can be applied in one go:
+    //
+    //   cells.pos_add(x, shift_x);
+    //   cells.pos_div(y, shift_y);
+    //   ...
+    //
+    //   if (cells.has_shift()) {
+    //      for (int i = 0; i < n; ++i) {
+    //          auto shift_i = cells.get_shift(i);
+    //          ...
+    //      }
+    //      cells.reset_shift();
+    //   }
+    //
+    std::vector<llama_pos> shift;
+
+    std::vector<std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>> seq;
+};
+
index c2571edc715e1c37e51782b22eb6851a8ebe92e6..a2d250434affa8c58fe6cb7639279b34b5397b67 100644 (file)
@@ -22,7 +22,7 @@ public:
     virtual bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) = 0;
     virtual void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
     virtual void seq_keep(llama_seq_id seq_id) = 0;
-    virtual void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos delta) = 0;
+    virtual void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) = 0;
     virtual void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) = 0;
 
     virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;