]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
kv-cells : fix tracking of seq_pos (#14339)
authorGeorgi Gerganov <redacted>
Mon, 23 Jun 2025 09:27:35 +0000 (12:27 +0300)
committerGitHub <redacted>
Mon, 23 Jun 2025 09:27:35 +0000 (12:27 +0300)
* kv-cells : fix tracking of seq_pos during cache reuse

ggml-ci

* cont : improve error message

ggml-ci

* cont : add more comments

include/llama.h
src/llama-batch.cpp
src/llama-context.cpp
src/llama-kv-cells.h
tools/server/server.cpp

index f4123d14ac1d8e31e298bff9017b92dd853e4d55..3eda9bc68608c0f0b55bca1a74512d4e28876d80 100644 (file)
@@ -944,12 +944,14 @@ extern "C" {
     // Requires the context to have a memory.
     // For encode-decoder contexts, processes the batch using the decoder.
     // Positive return values does not mean a fatal error, but rather a warning.
-    // Upon non-zero return values, the memory state is restored to the state before this call
+    // Upon fatal-error or abort, the ubatches that managed to be been processed will remain in the memory state of the context
+    //   To handle this correctly, query the memory state using llama_memory_seq_pos_min() and llama_memory_seq_pos_max()
+    // Upon other return values, the memory state is restored to the state before this call
     //    0 - success
     //    1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
-    //    2 - aborted
+    //    2 - aborted     (processed ubatches will remain in the context's memory)
     //   -1 - invalid input batch
-    // < -1 - error
+    // < -1 - fatal error (processed ubatches will remain in the context's memory)
     LLAMA_API int32_t llama_decode(
             struct llama_context * ctx,
               struct llama_batch   batch);
index b3c996e18ab41183ad1ebcb8a7c2773b32e52ee5..401e11364dbc97dfcb11e1c9d88de45df55c1686 100644 (file)
@@ -245,10 +245,11 @@ bool llama_batch_allocr::init(
         }
 
         if (memory) {
+            bool ok = true;
+
             if (batch.token) {
                 if (seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
-                    LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
-                    return false;
+                    ok = false;
                 }
             } else {
                 assert(batch.embd);
@@ -256,10 +257,20 @@ bool llama_batch_allocr::init(
                 // for embeddings (typically used as vision input), we allow them to have repeating positions
                 // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
                 if (seq_pos_min(s) != memory->seq_pos_max(s) && seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
-                    LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
-                    return false;
+                    ok = false;
                 }
             }
+
+            if (!ok) {
+                LLAMA_LOG_ERROR(
+                        "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
+                        " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
+                        " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
+                        " it is required that the sequence positions remain consecutive: Y = X + 1\n",
+                        __func__, s, s, memory->seq_pos_max(s), s, seq_pos_min(s));
+
+                return false;
+            }
         }
 
         if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
index e352d81e4ed7c4d0da6153f67d667d14b8306f14..06e93b19cbf4087b24284ea3a473f12441c31b6a 100644 (file)
@@ -1018,7 +1018,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
                 pos_min[s] = std::numeric_limits<llama_pos>::max();
             }
 
-            // TODO: fix sequence indexing
             for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
                 const auto & seq_id = ubatch.seq_id[i][0];
 
index 349e9032e2484b71b55c5ae14664d0d498d3d651..c95d635948b5d61190cbaa5a128d339cac64910c 100644 (file)
@@ -7,6 +7,7 @@
 #include <cassert>
 #include <vector>
 #include <set>
+#include <map>
 
 // meta information about KV cells that can be part of multiple sequences at the same time
 // TODO: add unit tests
@@ -164,7 +165,7 @@ public:
         assert(seq_id >= 0);
 
         seq[i].reset(seq_id);
-        seq_pos[seq_id].erase(pos[i]);
+        seq_pos_dec(seq_id, pos[i]);
 
         if (seq[i].none()) {
             pos[i] = -1;
@@ -187,7 +188,7 @@ public:
             seq[i].reset();
 
             seq[i].set(seq_id);
-            seq_pos[seq_id].insert(pos[i]);
+            seq_pos_inc(seq_id, pos[i]);
 
             return false;
         }
@@ -232,7 +233,7 @@ public:
         assert(!seq[i].test(seq_id));
 
         seq[i].set(seq_id);
-        seq_pos[seq_id].insert(pos[i]);
+        seq_pos_inc(seq_id, pos[i]);
     }
 
     // return the sequence id of this cell
@@ -259,7 +260,9 @@ public:
             return -1;
         }
 
-        return *seq_pos[seq_id].begin();
+        assert(seq_pos[seq_id].begin()->second > 0);
+
+        return seq_pos[seq_id].begin()->first;
     }
 
     // the maximum position of sequence seq_id currently present in any of the cells
@@ -272,7 +275,9 @@ public:
             return -1;
         }
 
-        return *seq_pos[seq_id].rbegin();
+        assert(seq_pos[seq_id].rbegin()->second > 0);
+
+        return seq_pos[seq_id].rbegin()->first;
     }
 
     // note: call only if the cell is not empty
@@ -389,17 +394,36 @@ private:
     // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
     std::vector<seq_set_t> seq;
 
-    // the set seq_pos[s] tells us which positions are currently present for sequence s
+    // the set seq_pos[s][p] tells us how many times the position p is currently present for sequence s
+    // if the position p is not present, seq_pos[s][p] is not set
     // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
-    std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ];
+    //
+    // note that we cannot a use an std::set because in some cases a position can occur more than once for the same seq:
+    //  - during performing a cache reuse via (rm + add)
+    //  - some vision models have input embeddings with repeating positions
+    //
+    std::map<llama_pos, int> seq_pos[LLAMA_MAX_SEQ];
 
     // helper functions for updating `seq_pos`, once cell at a time:
 
+    void seq_pos_dec(llama_seq_id s, llama_pos p) {
+        auto it = seq_pos[s].find(p);
+        assert(it != seq_pos[s].end());
+
+        if (--it->second == 0) {
+            seq_pos[s].erase(it);
+        }
+    }
+
+    void seq_pos_inc(llama_seq_id s, llama_pos p) {
+        seq_pos[s][p]++;
+    }
+
     // remove cell i
     void seq_pos_rm(uint32_t i) {
         for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
             if (seq[i].test(s)) {
-                seq_pos[s].erase(pos[i]);
+                seq_pos_dec(s, pos[i]);
             }
         }
     }
@@ -408,7 +432,7 @@ private:
     void seq_pos_add(uint32_t i) {
         for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
             if (seq[i].test(s)) {
-                seq_pos[s].insert(pos[i]);
+                seq_pos_inc(s, pos[i]);
             }
         }
     }
index aa18513e393b467b8065d7af52d2d14dcf51a696..852352383bdbe8eb27ac5acaa81a8f76b5036588 100644 (file)
@@ -3418,9 +3418,12 @@ struct server_context {
                     }
 
                     if (ret < -1) {
+                        // TODO: update slot state based on llama_memory_seq_pos_min() and llama_memory_seq_pos_max()
                         err = "Compute error.";
                     }
 
+                    // TODO: handle ret == 2 (abort) when we start aborting
+
                     if (!err.empty()) {
                         SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
                         for (auto & slot : slots) {