]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
batch : auto-gen positions + verify multi-sequence input (#14177)
authorGeorgi Gerganov <redacted>
Sun, 15 Jun 2025 06:18:37 +0000 (09:18 +0300)
committerGitHub <redacted>
Sun, 15 Jun 2025 06:18:37 +0000 (09:18 +0300)
* batch : verify multi-sequence input batches

ggml-ci

* cont : auto-gen positions + verify multi-seq input

ggml-ci

* cont : first print debug info, then perform validation

ggml-ci

* cont : fix position auto-gen + add comments

ggml-ci

include/llama.h
src/llama-batch.cpp
src/llama-batch.h
src/llama-context.cpp
src/llama-cparams.h

index 015a57898e22d2411296108fbc7446713f1dc2c9..d5e4cef68c213f0e2bb7fdb54ee84ed58acea6f4 100644 (file)
@@ -243,14 +243,14 @@ extern "C" {
 
     typedef bool (*llama_progress_callback)(float progress, void * user_data);
 
-    // Input data for llama_decode
+    // Input data for llama_encode/llama_decode
     // A llama_batch object can contain input about one or many sequences
     // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
     //
     // - token  : the token ids of the input (used when embd is NULL)
     // - embd   : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
     // - pos    : the positions of the respective token in the sequence
-    //            (if set to NULL, the token position will be tracked automatically by llama_decode)
+    //            (if set to NULL, the token position will be tracked automatically by llama_encode/llama_decode)
     // - seq_id : the sequence to which the respective token belongs
     //            (if set to NULL, the sequence ID will be assumed to be 0)
     // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
index bdbf766266f90fef435d06fc035d562adcc2eb81..2265db9b235b8b3f6d61e5ca33156c88fde89348 100644 (file)
@@ -3,6 +3,7 @@
 #include "llama-impl.h"
 #include "llama-cparams.h"
 #include "llama-vocab.h"
+#include "llama-memory.h"
 
 #include <cassert>
 #include <cstring>
@@ -287,21 +288,27 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
 llama_batch_allocr::llama_batch_allocr() {
     const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
     debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
+
+    seq_pos.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
+    seq_cpl.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
+    for (auto & cur : seq_cpl) {
+        cur.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
+    }
 }
 
-bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0) {
+bool llama_batch_allocr::init(
+        const llama_batch & batch_inp,
+        const llama_vocab & vocab,
+        const llama_memory_i * memory) {
     clear();
 
     batch = batch_inp;
 
     GGML_ASSERT(batch.n_tokens > 0);
 
-    if (!batch.pos) {
-        if (batch.seq_id) {
-            LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
-            return false;
-        }
-    }
+    //
+    // validate input batch
+    //
 
     if (batch.token) {
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
@@ -323,14 +330,9 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
         }
     }
 
-    if (!batch.pos) {
-        assert(p0 >= 0);
-        pos.resize(batch.n_tokens);
-        for (int32_t i = 0; i < batch.n_tokens; i++) {
-            pos[i] = p0 + i;
-        }
-        batch.pos = pos.data();
-    }
+    //
+    // auto-generate missing fields
+    //
 
     if (!batch.n_seq_id) {
         n_seq_id.resize(batch.n_tokens);
@@ -349,6 +351,32 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
         batch.seq_id = seq_id.data();
     }
 
+    if (!batch.pos) {
+        pos.resize(batch.n_tokens);
+
+        // initialize the starting position for each sequence based on the positions in the memory
+        llama_pos p0[LLAMA_MAX_PARALLEL_SEQUENCES];
+        for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+            if (!memory) {
+                p0[s] = 0;
+            } else {
+                p0[s] = memory->seq_pos_max(s) + 1;
+            }
+        }
+
+        for (int32_t i = 0; i < batch.n_tokens; i++) {
+            const llama_seq_id seq_id = batch.seq_id[i][0];
+
+            pos[i] = p0[seq_id];
+
+            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+                p0[batch.seq_id[i][s]] = pos[i] + 1;
+            }
+        }
+
+        batch.pos = pos.data();
+    }
+
     if (!batch.logits) {
         // by default return the output only for the last token
         output.resize(batch.n_tokens);
@@ -356,13 +384,36 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
         batch.logits = output.data();
     }
 
+    //
+    // compute stats
+    //
+
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
         n_outputs += batch.logits[i] != 0;
     }
 
+    // determine coupled sequences
+    // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
+    for (int32_t i = 0; i < batch.n_tokens; ++i) {
+        for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+            seq_pos[batch.seq_id[i][s]].insert(batch.pos[i]);
+
+            if (s > 0) {
+                const llama_seq_id s0 = batch.seq_id[i][0];
+                const llama_seq_id s1 = batch.seq_id[i][s];
+
+                // mark that sequence s1 is coupled to s0
+                seq_cpl[s1][s0] = true;
+
+                // note: the other way around is not necessary for now
+                //seq_cpl[s0][s1] = true;
+            }
+        }
+    }
+
     if (debug > 0) {
-        LLAMA_LOG_DEBUG("%s: input batch info (p0 = %d):\n", __func__, p0);
-        LLAMA_LOG_DEBUG("%s:   n_tokens  = %d\n", __func__, batch.n_tokens);
+        LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
+        LLAMA_LOG_DEBUG("%s:   n_tokens  = %d\n", __func__,          batch.n_tokens);
         LLAMA_LOG_DEBUG("%s:   token     = %p\n", __func__, (void *) batch.token);
         LLAMA_LOG_DEBUG("%s:   embd      = %p\n", __func__, (void *) batch.embd);
         LLAMA_LOG_DEBUG("%s:   pos       = %p\n", __func__, (void *) batch.pos);
@@ -404,6 +455,58 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
                         batch.pos[i], batch.n_seq_id[i], ss.str().c_str(), batch.logits[i]);
             }
             LLAMA_LOG_DEBUG("%s:   ]\n", __func__);
+
+            LLAMA_LOG_DEBUG("%s:   seq       = [\n", __func__);
+            for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) {
+                if (seq_pos[s0].empty()) {
+                    continue;
+                }
+
+                std::stringstream ss;
+                for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) {
+                    if (seq_cpl[s0][s1]) {
+                        ss << s1 << " ";
+                    }
+                }
+
+                LLAMA_LOG_DEBUG("%s:  %4d: pos = [%4d, %4d], cpl = %s\n",
+                        __func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str());
+            }
+            LLAMA_LOG_DEBUG("%s:   ]\n", __func__);
+        }
+    }
+
+    //
+    // consistency checks
+    //
+
+    for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        if (seq_pos[s].empty()) {
+            continue;
+        }
+
+        if (memory && seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
+            LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
+            return false;
+        }
+
+        if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
+            LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
+            return false;
+        }
+    }
+
+    if (memory) {
+        for (int32_t s0 = 0; s0 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s0) {
+            for (int32_t s1 = 0; s1 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s1) {
+                if (seq_cpl[s0][s1]) {
+                    if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
+                        memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
+                        LLAMA_LOG_ERROR("%s: sequence %d is coupled to %d in the input batch, but have divereged\n", __func__, s0, s1);
+                        return false;
+                    }
+                }
+            }
         }
     }
 
@@ -418,6 +521,14 @@ uint32_t llama_batch_allocr::get_n_outputs() const {
     return n_outputs;
 }
 
+llama_pos llama_batch_allocr::seq_pos_min(llama_seq_id seq_id) const {
+    return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].begin();
+}
+
+llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
+    return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].rbegin();
+}
+
 void llama_batch_allocr::clear() {
     n_outputs = 0;
 
@@ -426,6 +537,14 @@ void llama_batch_allocr::clear() {
     n_seq_id.clear();
     seq_id.clear();
     output.clear();
+
+    for (auto & cur : seq_pos) {
+        cur.clear();
+    }
+
+    for (auto & cur : seq_cpl) {
+        std::fill(cur.begin(), cur.end(), false);
+    }
 }
 
 //
index 1e0be8ac2c6ce8ad24a433131e91a73a46c30347..04501ce5d424c729749a4ef4696453f1f98a5239 100644 (file)
@@ -4,6 +4,7 @@
 
 #include <array>
 #include <vector>
+#include <set>
 
 // very similar to llama_batch,
 // but has more metadata about sequences
@@ -77,18 +78,25 @@ struct llama_sbatch {
     llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
 };
 
-// temporary allocate memory for the input batch if needed
+// a helper for sanitizing and fulfilling a batch
 class llama_batch_allocr {
 public:
     llama_batch_allocr();
 
-    // optionally fulfill the batch returned by llama_batch_get_one
-    bool init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0);
+    // sanitize and auto-gen missing data in the input batch
+    // memory is optional. if provided will be used to check for sequence continuity and to determine the positions
+    bool init(
+            const llama_batch & batch_inp,
+            const llama_vocab & vocab,
+            const llama_memory_i * memory);
 
     const llama_batch & get_batch() const;
 
     uint32_t get_n_outputs() const;
 
+    llama_pos seq_pos_min(llama_seq_id seq_id) const;
+    llama_pos seq_pos_max(llama_seq_id seq_id) const;
+
 private:
     void clear();
 
@@ -103,5 +111,8 @@ private:
     std::vector<llama_seq_id *> seq_id;
     std::vector<int8_t>         output;
 
+    std::vector<std::set<llama_pos>> seq_pos; // seq_pos[s]: the set of positions in sequence s
+    std::vector<std::vector<bool>>   seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
+
     int debug;
 };
index ec1e1189b219a44764f1e56afe64393932a19aab..47c60e960dc0199f2e034313ff8a2c506a768ec4 100644 (file)
@@ -727,9 +727,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
         return -1;
     }
 
-    // temporary allocate memory for the input batch if needed
     // note: during encode, we always pass the full sequence starting from pos = 0
-    if (!batch_allocr->init(batch_inp, model.vocab, batch_inp.pos ? -1 : 0)) {
+    if (!batch_allocr->init(batch_inp, model.vocab, nullptr)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
@@ -895,8 +894,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
         return -1;
     }
 
-    // temporary allocate memory for the input batch if needed
-    if (!batch_allocr->init(batch_inp, model.vocab, batch_inp.pos ? -1 : memory->seq_pos_max(0) + 1)) {
+    if (!batch_allocr->init(batch_inp, model.vocab, memory.get())) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
index 2871031ef09619bbce252126c1ddd13fb4681dcd..51ebe5d17efa7212b825f47ada5e101883038a9f 100644 (file)
@@ -4,6 +4,7 @@
 
 #include <cstdint>
 
+// TODO: rename to something shorter
 #define LLAMA_MAX_PARALLEL_SEQUENCES 64
 
 struct llama_cparams {