]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ubatch : new splitting logic (#14217)
authorGeorgi Gerganov <redacted>
Fri, 20 Jun 2025 07:14:14 +0000 (10:14 +0300)
committerGitHub <redacted>
Fri, 20 Jun 2025 07:14:14 +0000 (10:14 +0300)
ggml-ci

19 files changed:
src/llama-batch.cpp
src/llama-batch.h
src/llama-context.cpp
src/llama-context.h
src/llama-graph.cpp
src/llama-graph.h
src/llama-hparams.cpp
src/llama-hparams.h
src/llama-kv-cache-unified-iswa.cpp
src/llama-kv-cache-unified-iswa.h
src/llama-kv-cache-unified.cpp
src/llama-kv-cache-unified.h
src/llama-kv-cells.h
src/llama-memory-hybrid.cpp
src/llama-memory-hybrid.h
src/llama-memory-recurrent.cpp
src/llama-memory-recurrent.h
src/llama-memory.h
tools/server/server.cpp

index 8b6d14fe8813c3d0874a4ad2aaca71c9e6f0aa4b..b3c996e18ab41183ad1ebcb8a7c2773b32e52ee5 100644 (file)
@@ -1,7 +1,6 @@
 #include "llama-batch.h"
 
 #include "llama-impl.h"
-#include "llama-cparams.h"
 #include "llama-vocab.h"
 #include "llama-memory.h"
 
 #include <algorithm>
 #include <sstream>
 
-llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
-    // clear empty sequences
-    // the previous ubatch is assumed to be gone,
-    // so nothing should refer to values in these sequences anymore.
-    for (size_t i = seq.size(); i-- > 0;) {
-        if (seq[i].length == 0) {
-            seq.pop_back();
-        } else {
-            break;
-        }
-    }
-
-    udatas.push_back({});
-
-    auto & udata = udatas.back();
-
-    udata.token.resize(!has_embd ? n_ubatch : 0);
-    udata.embd.resize(has_embd ? n_embd * n_ubatch : 0);
-    udata.pos.resize(n_ubatch);
-    udata.n_seq_id.resize(n_ubatch);
-    udata.seq_id.resize(n_ubatch);
-    udata.output.resize(n_ubatch);
-
-    llama_ubatch ubatch = {
-        /*equal_seqs   =*/ true,
-        /*n_tokens     =*/ 0,
-        /*n_seq_tokens =*/ 0,
-        /*n_seqs       =*/ 0,
-        /*token        =*/ !has_embd ? udata.token.data() : nullptr,
-        /*embd         =*/ has_embd  ? udata.embd.data()  : nullptr,
-        /*pos          =*/ udata.pos.data(),
-        /*n_seq_id     =*/ udata.n_seq_id.data(),
-        /*seq_id       =*/ udata.seq_id.data(),
-        /*output       =*/ udata.output.data(),
-    };
-
-    return ubatch;
-}
-
-void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) {
-    GGML_ASSERT(batch != nullptr);
-    GGML_ASSERT(length <= seq.length);
-    // Can only add sequences of equal lengths to a batch,
-    // otherwise it isn't clear to which sequence a token belongs
-    GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs);
-    GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
-    // NOTE: loops are separated for cache-friendliness
-    if (batch->token) {
-        if (ubatch.equal_seqs) {
-            for (size_t i = 0; i < length; ++i) {
-                ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
-            }
-        } else {
-            // simple split
-            ubatch.token = batch->token + seq.offset;
-        }
-    } else {
-        ubatch.token = nullptr;
-    }
-    if (batch->embd) {
-        if (ubatch.equal_seqs) {
-            for (size_t i = 0; i < length; ++i) {
-                memcpy(
-                        ubatch.embd + (n_embd * (ubatch.n_tokens + i)),
-                        batch->embd + (n_embd * ids[seq.offset + i]),
-                        n_embd * sizeof(float)
-                      );
-            }
-        } else {
-            // simple split
-            ubatch.embd = batch->embd + (n_embd * seq.offset);
-        }
-    } else {
-        ubatch.embd = nullptr;
-    }
-    if (ubatch.equal_seqs) {
-        for (size_t i = 0; i < length; ++i) {
-            ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
-        }
-    } else {
-        // simple split
-        ubatch.pos = batch->pos + seq.offset;
-    }
-    if (ubatch.equal_seqs) {
-        ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
-        if (seq.seq_id) {
-            ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
-        }
-    } else {
-        // simple split
-        if (batch->n_seq_id) {
-            ubatch.n_seq_id = batch->n_seq_id + seq.offset;
-        } else {
-            for (size_t i = 0; i < length; ++i) {
-                ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
-            }
-        }
-        if (batch->seq_id) {
-            ubatch.seq_id = batch->seq_id + seq.offset;
-        }
-    }
-    if (batch->logits) {
-        if (ubatch.equal_seqs) {
-            for (size_t i = 0; i < length; ++i) {
-                size_t id = ids[seq.offset + i];
-                int8_t is_output = batch->logits[id];
-                ubatch.output[ubatch.n_tokens + i] = is_output;
-                if (is_output) { out_ids.push_back(id); }
-            }
-        } else {
-            // simple split
-            ubatch.output = batch->logits + seq.offset;
-            for (size_t i = 0; i < length; ++i) {
-                if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
-            }
-        }
-    } else {
-        // only get last output
-        for (size_t i = 0; i < length; ++i) {
-            size_t id = ids[seq.offset + i];
-            int8_t is_last = id == ids.size() - 1;
-            ubatch.output[ubatch.n_tokens + i] = is_last;
-            if (is_last) { out_ids.push_back(id); }
-        }
-    }
-    if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
-        ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
-    }
-    ubatch.n_tokens += length;
-    ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
-    seq.offset += length;
-    seq.length -= length;
-    n_tokens -= length;
-    GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
-}
-
-llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) {
-    n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
-    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
-    ubatch.equal_seqs = false;
-    if (!seq.empty()) {
-        llama_sbatch_seq & s = seq[0];
-        size_t length = s.length < n_ubatch ? s.length : n_ubatch;
-        GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
-        add_seq_to_ubatch(ubatch, s, length);
-    }
-    return ubatch;
-}
-
-llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
-    n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
-    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
-    if (!seq.empty()) {
-        size_t length = 0;
-        size_t n_tokens_in_ubatch = 0;
-        GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
-                                          // smallest first, because it's easier to split this way;
-                                          // starting from the end to pop in constant time.
-        for (size_t i = seq.size(); i-- > 0;) {
-            llama_sbatch_seq & s = seq[i];
-            GGML_ASSERT(s.length > 0);
-            if (length == 0) {
-                length = s.length < n_ubatch ? s.length : n_ubatch;
-            }
-            add_seq_to_ubatch(ubatch, s, length);
-            n_tokens_in_ubatch += length;
-            // shared prompts can't be mixed with any of their sequences,
-            // so it's safer to compute them in their own ubatch
-            if (s.n_seq_id > 1) { break; }
-            // stop when there isn't enough space for another sequence
-            if (length + n_tokens_in_ubatch > n_ubatch) { break; }
-        }
-    }
-    return ubatch;
-}
-
-llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
-    n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
-    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
-    if (!seq.empty()) {
-        llama_sbatch_seq & s = seq[seq.size() - 1];
-        size_t length = s.length < n_ubatch ? s.length : n_ubatch;
-        GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
-        add_seq_to_ubatch(ubatch, s, length);
-    }
-    return ubatch;
-}
-
-llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split) {
-    GGML_ASSERT(batch.n_tokens >= 0);
-    this->batch = &batch;
-    this->n_embd = n_embd;
-
-    n_tokens = batch.n_tokens;
-    ids.resize(n_tokens);
-    out_ids.clear();
-    // TODO: reserve out_ids and seq
-
-    for (size_t i = 0; i < n_tokens; ++i) {
-        ids[i] = i;
-    }
-
-    if (simple_split) {
-        seq.resize(1);
-        llama_sbatch_seq & s = seq[0];
-        s.n_seq_id = 0;
-        s.seq_id = nullptr;
-        s.offset = 0;
-        s.length = n_tokens;
-        return;
-    }
-
-    std::sort(ids.begin(), ids.end(),
-            [&batch](size_t a, size_t b) {
-                int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
-                int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
-                // sort by seq_id, then by pos
-                if (n_seq_a == n_seq_b) {
-                    if (batch.seq_id) {
-                        for (int32_t i = 0; i < n_seq_a; ++i) {
-                            llama_seq_id seq_id_a = batch.seq_id[a][i];
-                            llama_seq_id seq_id_b = batch.seq_id[b][i];
-                            // smaller seq_ids go first
-                            if (seq_id_a != seq_id_b) {
-                                return seq_id_a < seq_id_b;
-                            }
-                        }
-                    }
-                    // when all else is equal, sort by pos
-                    if (batch.pos) {
-                        return batch.pos[a] < batch.pos[b];
-                    }
-                    // no pos, sort by id
-                    return a < b;
-                }
-                // shared prompts go first
-                return n_seq_a > n_seq_b;
-            }
-    );
-
-    // init seq
-    llama_sbatch_seq * last_seq = nullptr;
-
-    for (size_t i = 0; i < n_tokens; ++i) {
-        const size_t bi = ids[i];
-        const int32_t n_seqs = batch.n_seq_id[bi];
-        llama_seq_id * seq_ids = batch.seq_id[bi];
-        if (last_seq != nullptr) {
-            bool same = n_seqs == last_seq->n_seq_id;
-            for (int32_t j = 0; same && j < n_seqs; ++j) {
-                if (seq_ids[j] != last_seq->seq_id[j]) {
-                    same = false;
-                }
-            }
-            if (same) {
-                last_seq->length += 1;
-                continue;
-            }
-        }
-        llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
-        seq.push_back(new_seq);
-        last_seq = &seq.back();
-    }
-
-    // keep shared prompts first at the end, then sort by length descending.
-    std::sort(seq.begin(), seq.end(),
-            [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
-                if (a.n_seq_id == b.n_seq_id) {
-                    return a.length > b.length;
-                }
-                return a.n_seq_id < b.n_seq_id;
-            }
-            );
-}
-
-llama_batch_allocr::llama_batch_allocr() {
+llama_batch_allocr::llama_batch_allocr(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {
     const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
     debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
 
@@ -294,17 +18,22 @@ llama_batch_allocr::llama_batch_allocr() {
     for (auto & cur : seq_cpl) {
         cur.resize(LLAMA_MAX_SEQ);
     }
+
+    seq_idx.resize(LLAMA_MAX_SEQ, -1);
 }
 
 bool llama_batch_allocr::init(
         const llama_batch & batch_inp,
         const llama_vocab & vocab,
         const llama_memory_i * memory,
-        bool embd_all) {
+        uint32_t n_embd,
+        bool output_all) {
     clear();
 
     batch = batch_inp;
 
+    this->vocab = &vocab;
+
     GGML_ASSERT(batch.n_tokens > 0);
 
     //
@@ -359,6 +88,7 @@ bool llama_batch_allocr::init(
         llama_pos p0[LLAMA_MAX_SEQ];
         for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
             if (!memory) {
+                // if no memory -> start from 0
                 p0[s] = 0;
             } else {
                 p0[s] = memory->seq_pos_max(s) + 1;
@@ -370,8 +100,11 @@ bool llama_batch_allocr::init(
 
             pos[i] = p0[seq_id];
 
+            // update the starting position for all sequences that are assigned to the this token
             for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
-                p0[batch.seq_id[i][s]] = pos[i] + 1;
+                const llama_seq_id seq_id = batch.seq_id[i][s];
+
+                p0[seq_id] = pos[i] + 1;
             }
         }
 
@@ -379,7 +112,7 @@ bool llama_batch_allocr::init(
     }
 
     if (!batch.logits) {
-        if (embd_all) {
+        if (output_all) {
             // return the output for all tokens
             output.resize(batch.n_tokens, true);
         } else {
@@ -389,7 +122,7 @@ bool llama_batch_allocr::init(
         }
 
         batch.logits = output.data();
-    } else if (embd_all) {
+    } else if (output_all) {
         bool warn = false;
 
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
@@ -410,6 +143,9 @@ bool llama_batch_allocr::init(
     // compute stats
     //
 
+    this->n_embd = n_embd;
+
+    // count the outputs in this batch
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
         n_outputs += batch.logits[i] != 0;
     }
@@ -417,85 +153,86 @@ bool llama_batch_allocr::init(
     // determine coupled sequences
     // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
+        const llama_seq_id s0 = batch.seq_id[i][0];
+
         for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
-            seq_pos[batch.seq_id[i][s]].insert(batch.pos[i]);
+            const llama_seq_id s1 = batch.seq_id[i][s];
 
-            if (s > 0) {
-                const llama_seq_id s0 = batch.seq_id[i][0];
-                const llama_seq_id s1 = batch.seq_id[i][s];
+            seq_pos[s1].insert(batch.pos[i]);
 
+            if (s > 0) {
                 // mark that sequence s1 is coupled to s0
                 seq_cpl[s1][s0] = true;
 
-                // note: the other way around is not necessary for now
+                // note: tracking the other way around is not necessary for now
                 //seq_cpl[s0][s1] = true;
             }
         }
     }
 
-    if (debug > 0) {
-        LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
-        LLAMA_LOG_DEBUG("%s:   n_tokens  = %d\n", __func__,          batch.n_tokens);
-        LLAMA_LOG_DEBUG("%s:   token     = %p\n", __func__, (void *) batch.token);
-        LLAMA_LOG_DEBUG("%s:   embd      = %p\n", __func__, (void *) batch.embd);
-        LLAMA_LOG_DEBUG("%s:   pos       = %p\n", __func__, (void *) batch.pos);
-        LLAMA_LOG_DEBUG("%s:   n_seq_id  = %p\n", __func__, (void *) batch.n_seq_id);
-        LLAMA_LOG_DEBUG("%s:   seq_id    = %p\n", __func__, (void *) batch.seq_id);
-        LLAMA_LOG_DEBUG("%s:   logits    = %p\n", __func__, (void *) batch.logits);
-        LLAMA_LOG_DEBUG("%s:   n_outputs = %d\n", __func__, n_outputs);
+    // precompute the sequence sets for each token and determine the unique sequence ids that participate in the batch
+    {
+        seq_set_t seq_set_unq;
 
-        if (debug > 1) {
-            int seq_id_max = 0;
-            for (int32_t i = 0; i < batch.n_tokens; ++i) {
-                for (int s = 0; s < batch.n_seq_id[i]; ++s) {
-                    for (int s = 0; s < batch.n_seq_id[i]; ++s) {
-                        seq_id_max = std::max(seq_id_max, batch.seq_id[i][s]);
-                    }
-                }
+        for (int32_t i = 0; i < batch.n_tokens; ++i) {
+            seq_set_t cur;
+            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+                const llama_seq_id seq_id = batch.seq_id[i][s];
+
+                cur        .set(seq_id);
+                seq_set_unq.set(seq_id);
             }
-            ++seq_id_max;
 
-            LLAMA_LOG_DEBUG("%s:   token     = [\n", __func__);
-            for (int32_t i = 0; i < batch.n_tokens; ++i) {
-                std::vector<int8_t> seq_id(seq_id_max);
+            seq_set.push_back(cur);
+            seq_set_map[cur].push_back(i);
+        }
 
-                for (int s = 0; s < batch.n_seq_id[i]; ++s) {
-                    seq_id[batch.seq_id[i][s]] = 1;
-                }
+        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+            if (seq_set_unq.test(s)) {
+                seq_idx[s] = seq_id_unq.size();
+                seq_id_unq.push_back(s);
+            }
+        }
+    }
 
-                std::stringstream ss;
-                for (int s = 0; s < seq_id_max; ++s) {
-                    if (seq_id[s]) {
-                        ss << s%10;
-                    } else {
-                        ss << ".";
-                    }
-                }
+    if (debug > 0) {
+        LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
 
-                LLAMA_LOG_DEBUG("%s:  %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
-                        __func__, i, batch.token[i], vocab.token_to_piece(batch.token[i]).c_str(),
-                        batch.pos[i], batch.n_seq_id[i], ss.str().c_str(), batch.logits[i]);
+        llama_ubatch ubatch {
+            /*.equal_seqs   =*/ false,
+            /*.n_tokens     =*/ (uint32_t) batch.n_tokens,
+            /*.n_seq_tokens =*/ (uint32_t) 1,
+            /*.n_seqs       =*/ (uint32_t) batch.n_tokens,
+            /*.n_seqs_unq   =*/ (uint32_t) this->seq_id_unq.size(),
+            /*.token        =*/ batch.token,
+            /*.embd         =*/ batch.embd,
+            /*.pos          =*/ batch.pos,
+            /*.n_seq_id     =*/ batch.n_seq_id,
+            /*.seq_id       =*/ batch.seq_id,
+            /*.seq_id_unq   =*/ this->seq_id_unq.data(),
+            /*.seq_idx      =*/ this->seq_idx.data(),
+            /*.output       =*/ batch.logits,
+        };
+
+        ubatch_print(ubatch, debug);
+
+        LLAMA_LOG_DEBUG("%s:   seq       = [\n", __func__);
+        for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) {
+            if (seq_pos[s0].empty()) {
+                continue;
             }
-            LLAMA_LOG_DEBUG("%s:   ]\n", __func__);
-
-            LLAMA_LOG_DEBUG("%s:   seq       = [\n", __func__);
-            for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) {
-                if (seq_pos[s0].empty()) {
-                    continue;
-                }
 
-                std::stringstream ss;
-                for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) {
-                    if (seq_cpl[s0][s1]) {
-                        ss << s1 << " ";
-                    }
+            std::stringstream ss;
+            for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) {
+                if (seq_cpl[s0][s1]) {
+                    ss << s1 << " ";
                 }
-
-                LLAMA_LOG_DEBUG("%s:  %4d: pos = [%4d, %4d], cpl = %s\n",
-                        __func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str());
             }
-            LLAMA_LOG_DEBUG("%s:   ]\n", __func__);
+
+            LLAMA_LOG_DEBUG("%s:  %4d: pos = [%4d, %4d], cpl = %s\n",
+                    __func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str());
         }
+        LLAMA_LOG_DEBUG("%s:   ]\n", __func__);
     }
 
     //
@@ -507,9 +244,22 @@ bool llama_batch_allocr::init(
             continue;
         }
 
-        if (memory && seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
-            LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
-            return false;
+        if (memory) {
+            if (batch.token) {
+                if (seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
+                    LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
+                    return false;
+                }
+            } else {
+                assert(batch.embd);
+
+                // for embeddings (typically used as vision input), we allow them to have repeating positions
+                // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
+                if (seq_pos_min(s) != memory->seq_pos_max(s) && seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
+                    LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
+                    return false;
+                }
+            }
         }
 
         if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
@@ -532,17 +282,120 @@ bool llama_batch_allocr::init(
         }
     }
 
+    // disallow partial sequence sub-sets:
+    //
+    // invalid:          x
+    //            i: 0 1 2 ...
+    // ---------------------------------------
+    // seq_id[i][0]: 0 0 1
+    // seq_id[i][1]: 1 1 2
+    // seq_id[i][2]: 2
+    //
+    // disallow decreasing sequence positions:
+    //
+    // invalid:                  x
+    //            i: 0 1 2 3 4 5 6 ...
+    // ---------------------------------------
+    //       pos[i]: 4 5 0 1 6 2 3
+    // seq_id[i][0]: 0 0 1 1 0 1 0
+    //
+    {
+        seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
+        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+            cur_seq_set[s].set();
+        }
+
+        llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
+        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+            cur_seq_pos[s] = -1;
+        }
+
+        for (int32_t i = 0; i < batch.n_tokens; ++i) {
+            const llama_pos pos = batch.pos[i];
+
+            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+                const llama_seq_id seq_id = batch.seq_id[i][s];
+
+                cur_seq_set[seq_id] &= seq_set[i];
+
+                if (cur_seq_set[seq_id].none()) {
+                    LLAMA_LOG_ERROR("%s: sequence %d belongs to incompatible sequence sets (not allowed)\n", __func__, seq_id);
+                    return false;
+                }
+
+                if (pos < cur_seq_pos[seq_id]) {
+                    LLAMA_LOG_ERROR("%s: sequence %d positions are decreasing (not allowed)\n", __func__, seq_id);
+                    return false;
+                }
+            }
+        }
+    }
+
+    split_reset();
+
     return true;
 }
 
+llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs) {
+    const uint32_t n_tokens = n_seq_tokens*n_seqs;
+
+    clear();
+    split_reset();
+
+    ubatches.emplace_back();
+
+    auto & ubatch = ubatches.back();
+
+    ubatch.token     .resize(n_tokens);
+    ubatch.embd      .clear();
+    ubatch.pos       .resize(n_tokens);
+    ubatch.n_seq_id  .resize(n_tokens);
+    ubatch.seq_id    .resize(n_tokens);
+    ubatch.seq_id_unq.resize(0);
+    ubatch.seq_idx   .resize(LLAMA_MAX_SEQ, -1);
+    ubatch.output    .resize(n_tokens);
+
+    for (uint32_t s = 0; s < n_seqs; ++s) {
+        ubatch.seq_idx[s] = s;
+        ubatch.seq_id_unq.push_back(s);
+    }
+
+    llama_ubatch res {
+        /*.equal_seqs   =*/ true,
+        /*.n_tokens     =*/ n_tokens,
+        /*.n_seq_tokens =*/ n_seq_tokens,
+        /*.n_seqs       =*/ n_seqs,
+        /*.n_seqs_unq   =*/ n_seqs,
+
+        /*.token        =*/ ubatch.token.data(),
+        /*.embd         =*/ nullptr,
+        /*.pos          =*/ ubatch.pos.data(),
+        /*.n_seq_id     =*/ ubatch.n_seq_id.data(),
+        /*.seq_id       =*/ ubatch.seq_id.data(),
+        /*.seq_id_unq   =*/ ubatch.seq_id_unq.data(),
+        /*.seq_idx      =*/ ubatch.seq_idx.data(),
+        /*.output       =*/ ubatch.output.data(),
+    };
+
+    return res;
+}
+
 const llama_batch & llama_batch_allocr::get_batch() const {
     return batch;
 }
 
+uint32_t llama_batch_allocr::get_n_tokens() const {
+    return batch.n_tokens;
+}
+
 uint32_t llama_batch_allocr::get_n_outputs() const {
     return n_outputs;
 }
 
+std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
+    return out_ids;
+}
+
 llama_pos llama_batch_allocr::seq_pos_min(llama_seq_id seq_id) const {
     return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].begin();
 }
@@ -551,14 +404,188 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
     return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].rbegin();
 }
 
+void llama_batch_allocr::split_reset() {
+    out_ids.clear();
+
+    used.clear();
+    used.resize(get_n_tokens(), false);
+
+    ubatches.clear();
+}
+
+llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
+    // find the first unused token
+    uint32_t cur_idx = 0;
+    while (cur_idx < used.size() && used[cur_idx]) {
+        ++cur_idx;
+    }
+
+    // we are done
+    if (cur_idx >= used.size()) {
+        return {};
+    }
+
+    std::vector<int32_t> idxs;
+
+    while (true) {
+        idxs.push_back(cur_idx);
+
+        used[cur_idx] = true;
+
+        ++cur_idx;
+
+        if (cur_idx >= used.size()) {
+            break;
+        }
+
+        if (idxs.size() >= n_ubatch) {
+            break;
+        }
+    }
+
+    return ubatch_add(idxs, idxs.size(), false);
+}
+
+llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
+    std::vector<seq_set_t> cur_seq_set;
+
+    // determine the non-overlapping sequence sets participating in this ubatch
+    for (int32_t i = 0; i < batch.n_tokens; ++i) {
+        if (used[i]) {
+            continue;
+        }
+
+        bool add = true;
+
+        for (uint32_t s = 0; s < cur_seq_set.size(); ++s) {
+            // no overlap with existing sequence sets:
+            if (!(cur_seq_set[s] & seq_set[i]).none()) {
+                add = false;
+                break;
+            }
+        }
+
+        if (add) {
+            cur_seq_set.push_back(seq_set[i]);
+
+            if (cur_seq_set.size() > n_ubatch) {
+                break;
+            }
+        }
+    }
+
+    const uint32_t n_seqs = cur_seq_set.size();
+
+    // we are done
+    if (n_seqs == 0) {
+        return {};
+    }
+
+    // the current batch index of each sequence set
+    std::vector<int32_t> cur_idx(n_seqs, 0);
+
+    for (uint32_t s = 0; s < n_seqs; ++s) {
+        while (used[seq_set_map[cur_seq_set[s]][cur_idx[s]]]) {
+            ++cur_idx[s];
+        }
+    }
+
+    // the list of batch indices for each sequence set
+    // at the end we will concat these to get the final ubatch
+    std::vector<idx_vec_t> idxs_per_seq(n_seqs);
+
+    while (true) {
+        // we can only add new n_seq_tokens tokens if all the sequence sets have at least one more unused token and
+        //   if we haven't reached n_ubatch
+        bool can_expand = true;
+
+        for (uint32_t s = 0; s < n_seqs; ++s) {
+            if (cur_idx[s] >= (int32_t) seq_set_map[cur_seq_set[s]].size()) {
+                can_expand = false;
+                break;
+            }
+        }
+
+        if (!can_expand) {
+            break;
+        }
+
+        for (uint32_t s = 0; s < n_seqs; ++s) {
+            const int32_t idx = seq_set_map[cur_seq_set[s]][cur_idx[s]];
+
+            idxs_per_seq[s].push_back(idx);
+
+            used[idx] = true;
+
+            ++cur_idx[s];
+        }
+
+        if  ((idxs_per_seq[0].size() + 1)*n_seqs > n_ubatch) {
+            break;
+        }
+    }
+
+    // concat the per-sequence-set lists
+    std::vector<int32_t> idxs;
+
+    for (uint32_t s = 0; s < n_seqs; ++s) {
+        idxs.insert(idxs.end(), idxs_per_seq[s].begin(), idxs_per_seq[s].end());
+    }
+
+    return ubatch_add(idxs, n_seqs, true);
+}
+
+llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
+    // find the first unused token
+    uint32_t cur_idx = 0;
+    while (cur_idx < used.size() && used[cur_idx]) {
+        ++cur_idx;
+    }
+
+    // we are done
+    if (cur_idx >= used.size()) {
+        return {};
+    }
+
+    // this is the starting sequence set
+    // we allow adding tokens only if their sequence set is a subset of the current sequence set
+    auto cur_seq_set = seq_set[cur_idx];
+
+    std::vector<int32_t> idxs;
+
+    while (true) {
+        idxs.push_back(cur_idx);
+
+        used[cur_idx] = true;
+
+        if (idxs.size() >= n_ubatch) {
+            break;
+        }
+
+        do {
+            ++cur_idx;
+        } while (cur_idx < get_n_tokens() && (used[cur_idx] || ((cur_seq_set & seq_set[cur_idx]) != seq_set[cur_idx])));
+
+        if (cur_idx == get_n_tokens()) {
+            break;
+        }
+
+        cur_seq_set = seq_set[cur_idx];
+    }
+
+    return ubatch_add(idxs, 1, true);
+}
+
 void llama_batch_allocr::clear() {
     n_outputs = 0;
 
     batch = {};
-    pos.clear();
-    n_seq_id.clear();
-    seq_id.clear();
-    output.clear();
+
+    pos       .clear();
+    n_seq_id  .clear();
+    seq_id    .clear();
+    seq_id_unq.clear();
+    output    .clear();
 
     for (auto & cur : seq_pos) {
         cur.clear();
@@ -567,6 +594,177 @@ void llama_batch_allocr::clear() {
     for (auto & cur : seq_cpl) {
         std::fill(cur.begin(), cur.end(), false);
     }
+
+    seq_set.clear();
+
+    seq_set_map.clear();
+
+    std::fill(seq_idx.begin(), seq_idx.end(), -1);
+}
+
+llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs) {
+    const uint32_t n_tokens = idxs.size();
+
+    assert(n_tokens%n_seqs == 0);
+
+    ubatches.emplace_back();
+
+    auto & ubatch = ubatches.back();
+
+    const int32_t n_pos_cur = batch.embd ? n_pos_per_embd : 1;
+
+    const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
+    const int64_t n_pos_all  =              (int64_t) n_tokens*n_pos_cur;
+
+    ubatch.token     .resize(n_tokens);
+    ubatch.embd      .resize(n_embd_all);
+    ubatch.pos       .resize(n_pos_all);
+    ubatch.n_seq_id  .resize(n_tokens);
+    ubatch.seq_id    .resize(n_tokens);
+    ubatch.seq_id_unq.resize(0);
+    ubatch.seq_idx   .resize(LLAMA_MAX_SEQ, -1);
+    ubatch.output    .resize(n_tokens);
+
+    seq_set_t seq_set_unq;
+
+    for (size_t i = 0; i < idxs.size(); ++i) {
+        if (batch.token) {
+            ubatch.token[i] = batch.token[idxs[i]];
+        }
+
+        if (batch.embd) {
+            memcpy(ubatch.embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
+        }
+
+        for (int j = 0; j < n_pos_cur; ++j) {
+            ubatch.pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]];
+        }
+
+        ubatch.n_seq_id[i] = batch.n_seq_id[idxs[i]];
+        ubatch.seq_id[i]   = batch.seq_id[idxs[i]];
+        ubatch.output[i]   = batch.logits[idxs[i]];
+
+        for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
+            seq_set_unq.set(ubatch.seq_id[i][s]);
+        }
+
+        if (ubatch.output[i]) {
+            out_ids.push_back(idxs[i]);
+        }
+    }
+
+    for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+        if (seq_set_unq.test(s)) {
+            ubatch.seq_idx[s] = ubatch.seq_id_unq.size();
+            ubatch.seq_id_unq.push_back(s);
+        }
+    }
+
+    llama_ubatch res {
+        /*.equal_seqs   =*/ equal_seqs,
+        /*.n_tokens     =*/ n_tokens,
+        /*.n_seq_tokens =*/ n_tokens/n_seqs,
+        /*.n_seqs       =*/ n_seqs,
+        /*.n_seqs_unq   =*/ (uint32_t) ubatch.seq_id_unq.size(),
+
+        /*.token        =*/ batch.token ? ubatch.token.data() : nullptr,
+        /*.embd         =*/ batch.embd ? ubatch.embd.data() : nullptr,
+        /*.pos          =*/ ubatch.pos.data(),
+        /*.n_seq_id     =*/ ubatch.n_seq_id.data(),
+        /*.seq_id       =*/ ubatch.seq_id.data(),
+        /*.seq_id_unq   =*/ ubatch.seq_id_unq.data(),
+        /*.seq_idx      =*/ ubatch.seq_idx.data(),
+        /*.output       =*/ ubatch.output.data(),
+    };
+
+    if (debug > 0) {
+        LLAMA_LOG_DEBUG("%s: added ubatch %d to split:\n", __func__, (int) ubatches.size() - 1);
+
+        ubatch_print(res, debug);
+    }
+
+    return res;
+}
+
+void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) {
+    if (debug > 0) {
+        LLAMA_LOG_DEBUG("%s:   equal_seqs   = %d\n", __func__, ubatch.equal_seqs);
+        LLAMA_LOG_DEBUG("%s:   n_tokens     = %d\n", __func__, ubatch.n_tokens);
+        LLAMA_LOG_DEBUG("%s:   n_seq_tokens = %d\n", __func__, ubatch.n_seq_tokens);
+        LLAMA_LOG_DEBUG("%s:   n_seqs       = %d\n", __func__, ubatch.n_seqs);
+        LLAMA_LOG_DEBUG("%s:   n_seqs_unq   = %d\n", __func__, ubatch.n_seqs_unq);
+
+        std::stringstream ss_seq_id_unq;
+        std::stringstream ss_seq_idx;
+
+        ss_seq_id_unq << "[ ";
+        ss_seq_idx << "[";
+
+        for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+            ss_seq_id_unq << ubatch.seq_id_unq[s] << " ";
+        }
+
+        for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+            if (ubatch.seq_idx[s] >= 0) {
+                ss_seq_idx << ubatch.seq_idx[s]%10;
+            } else {
+                ss_seq_idx << ".";
+            }
+        }
+
+        ss_seq_id_unq << "]";
+        ss_seq_idx    << "]";
+
+        LLAMA_LOG_DEBUG("%s:   token      = %p\n", __func__, (void *) ubatch.token);
+        LLAMA_LOG_DEBUG("%s:   embd       = %p\n", __func__, (void *) ubatch.embd);
+        LLAMA_LOG_DEBUG("%s:   pos        = %p\n", __func__, (void *) ubatch.pos);
+        LLAMA_LOG_DEBUG("%s:   n_seq_id   = %p\n", __func__, (void *) ubatch.n_seq_id);
+        LLAMA_LOG_DEBUG("%s:   seq_id     = %p\n", __func__, (void *) ubatch.seq_id);
+        LLAMA_LOG_DEBUG("%s:   seq_id_unq = %s\n", __func__, ss_seq_id_unq.str().c_str());
+        LLAMA_LOG_DEBUG("%s:   seq_idx    = %s\n", __func__, ss_seq_idx.str().c_str());
+        LLAMA_LOG_DEBUG("%s:   output     = %p\n", __func__, (void *) ubatch.output);
+        LLAMA_LOG_DEBUG("%s:   n_outputs  = %d\n", __func__, n_outputs);
+
+        if (debug > 1) {
+            int seq_id_max = 0;
+            for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
+                for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
+                    for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
+                        seq_id_max = std::max(seq_id_max, ubatch.seq_id[i][s]);
+                    }
+                }
+            }
+            ++seq_id_max;
+
+            LLAMA_LOG_DEBUG("%s:   token     = [\n", __func__);
+            for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
+                std::vector<int8_t> seq_id(seq_id_max);
+
+                for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
+                    seq_id[ubatch.seq_id[i][s]] = 1;
+                }
+
+                std::stringstream ss;
+                for (int s = 0; s < seq_id_max; ++s) {
+                    if (seq_id[s]) {
+                        ss << s%10;
+                    } else {
+                        ss << ".";
+                    }
+                }
+
+                if (ubatch.token) {
+                    LLAMA_LOG_DEBUG("%s:  %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
+                            __func__, i, ubatch.token[i], vocab->token_to_piece(ubatch.token[i]).c_str(),
+                            ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]);
+                } else {
+                    LLAMA_LOG_DEBUG("%s:  %4d: [embd], pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
+                            __func__, i, ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]);
+                }
+            }
+            LLAMA_LOG_DEBUG("%s:   ]\n", __func__);
+        }
+    }
 }
 
 //
@@ -577,25 +775,25 @@ struct llama_batch llama_batch_get_one(
              llama_token * tokens,
                  int32_t   n_tokens) {
     return {
-        /*n_tokens       =*/ n_tokens,
-        /*tokens         =*/ tokens,
-        /*embd           =*/ nullptr,
-        /*pos            =*/ nullptr,
-        /*n_seq_id       =*/ nullptr,
-        /*seq_id         =*/ nullptr,
-        /*logits         =*/ nullptr,
+        /*n_tokens =*/ n_tokens,
+        /*tokens   =*/ tokens,
+        /*embd     =*/ nullptr,
+        /*pos      =*/ nullptr,
+        /*n_seq_id =*/ nullptr,
+        /*seq_id   =*/ nullptr,
+        /*logits   =*/ nullptr,
     };
 }
 
 struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
     llama_batch batch = {
-        /*n_tokens       =*/ 0,
-        /*tokens         =*/ nullptr,
-        /*embd           =*/ nullptr,
-        /*pos            =*/ nullptr,
-        /*n_seq_id       =*/ nullptr,
-        /*seq_id         =*/ nullptr,
-        /*logits         =*/ nullptr,
+        /*n_tokens =*/ 0,
+        /*tokens   =*/ nullptr,
+        /*embd     =*/ nullptr,
+        /*pos      =*/ nullptr,
+        /*n_seq_id =*/ nullptr,
+        /*seq_id   =*/ nullptr,
+        /*logits   =*/ nullptr,
     };
 
     if (embd) {
index a555c157234be82933b0bf7b35b43c4fed2593e4..d2c5376188a0bd985ef7039cea5231e047198b91 100644 (file)
@@ -2,86 +2,44 @@
 
 #include "llama.h"
 
+#include "llama-cparams.h"
+
 #include <array>
 #include <vector>
 #include <set>
+#include <bitset>
+#include <unordered_map>
 
-// very similar to llama_batch,
-// but has more metadata about sequences
+// keep this struct lightweight
+// it points to data in `llama_batch_allocr`
 struct llama_ubatch {
     bool equal_seqs;
     // TODO: whole_seqs for embeddings?
 
     uint32_t n_tokens;     // total tokens (n_seq_tokens * n_seqs)
-    uint32_t n_seq_tokens; // tokens per sequence
-    uint32_t n_seqs;
-
-    llama_token  *  token;    // [n_tokens]
-    float        *  embd;     // [n_embd, n_tokens]
-    llama_pos    *  pos;      // [n_tokens]
-    int32_t      *  n_seq_id; // [n_seqs]
-    llama_seq_id ** seq_id;   // [n_seqs]
-    int8_t       *  output;   // [n_tokens]
-};
-
-struct llama_sbatch_seq {
-    int32_t n_seq_id;
-
-    llama_seq_id * seq_id;
-
-    size_t offset;
-    size_t length;
-};
-
-// sequence-length-aware batch splitting
-struct llama_sbatch {
-    // tokens left in this batch
-    size_t n_tokens;
-
-    size_t n_embd;
-
-    // sorted indices into the batch
-    std::vector<int64_t> ids;
-    // batch indices of the output
-    std::vector<int64_t> out_ids;
-    std::vector<llama_sbatch_seq> seq;
-
-    const llama_batch * batch = nullptr;
-
-    // buffers for the ubatches
-    // TODO: very hacky, this needs a complete rework
-    struct ubatch_data {
-        std::vector<llama_token>    token;
-        std::vector<float>          embd;
-        std::vector<llama_pos>      pos;
-        std::vector<int32_t>        n_seq_id;
-        std::vector<llama_seq_id *> seq_id;
-        std::vector<int8_t>         output;
-    };
-
-    std::vector<ubatch_data> udatas;
-
-    llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);
-
-    void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length);
-
-    // simple split, unknown number of sequences of unequal lengths
-    llama_ubatch split_simple(size_t n_ubatch);
-
-    // make batches of equal-length sequences
-    llama_ubatch split_equal(size_t n_ubatch);
-
-    // sequence-wise split
-    llama_ubatch split_seq(size_t n_ubatch);
-
-    llama_sbatch() = default;
-    llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
+    uint32_t n_seq_tokens; // tokens per sequence set
+    uint32_t n_seqs;       // sequence sets in the ubatch
+    uint32_t n_seqs_unq;   // unique sequence ids in the ubatch
+
+    // seq_id_unq: unique sequence ids in the ubatch
+    // seq_idx:    indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
+    //             used for extracting sequence pooled embeddings
+
+    //                          // size               | idx | val
+    llama_token  *  token;      // [n_tokens]         | i   | id, token
+    float        *  embd;       // [n_embd, n_tokens] | i   | embd
+    llama_pos    *  pos;        // [n_tokens]         | i   | pos
+    int32_t      *  n_seq_id;   // [n_tokens]         | i   | -
+    llama_seq_id ** seq_id;     // [n_tokens]         | s   | s0, s1, seq_id
+    llama_seq_id *  seq_id_unq; // [n_seqs_unq]       | s   | seq_id
+    int32_t      *  seq_idx;    // [LLAMA_MAX_SEQ]    | -   | seq_idx
+    int8_t       *  output;     // [n_tokens]         | i   | -
 };
 
-// a helper for sanitizing and fulfilling a batch
+// a helper for sanitizing, fulfilling and splitting a batch
 class llama_batch_allocr {
 public:
-    llama_batch_allocr();
+    llama_batch_allocr(uint32_t n_pos_per_embd);
 
     // sanitize and auto-gen missing data in the input batch
     // memory is optional. if provided will be used to check for sequence continuity and to determine the positions
@@ -89,20 +47,57 @@ public:
             const llama_batch & batch_inp,
             const llama_vocab & vocab,
             const llama_memory_i * memory,
-            bool embd_all);
+            uint32_t n_embd,
+            bool output_all);
 
     const llama_batch & get_batch() const;
 
+    uint32_t get_n_tokens()  const;
     uint32_t get_n_outputs() const;
 
+    // the array of output indices in the order they were encountered during the ubatch splitting
+    std::vector<int32_t> & get_out_ids();
+
+    // min/max positions of each sequence in the current ubatch
     llama_pos seq_pos_min(llama_seq_id seq_id) const;
     llama_pos seq_pos_max(llama_seq_id seq_id) const;
 
+    // call once before splitting the batch to reset the internal state
+    void split_reset();
+
+    // simple split, unknown number of sequence sets of unequal lengths
+    llama_ubatch split_simple(uint32_t n_ubatch);
+
+    // make ubatches of equal-length sequences sets
+    llama_ubatch split_equal(uint32_t n_ubatch);
+
+    // sequence-set-wise split - each ubatch contains a single sequence-set
+    llama_ubatch split_seq(uint32_t n_ubatch);
+
+    // a helper method for creating a well-defined ubatch of tokens
+    // TODO: support embeddings if needed in the future
+    llama_ubatch ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs);
+
 private:
     void clear();
 
+    // create the next ubatch based on the provided batch indices (idxs) and the number of sequence sets (n_seqs)
+    // return llama_ubatch.n_tokens == 0 if the entire batch was consumed
+    llama_ubatch ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs);
+
+    // for debugging, start with LLAMA_BATCH_DEBUG=2
+    void ubatch_print(const llama_ubatch & ubatch, int debug);
+
     llama_batch batch;
 
+    // only for debugging purposes
+    const llama_vocab * vocab;
+
+    // TODO: this is more of a temporary solution until we have a better way to handle multiple positions per token/embd
+    //       ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
+    const uint32_t n_pos_per_embd;
+
+    uint32_t n_embd;
     uint32_t n_outputs;
 
     std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
@@ -110,10 +105,43 @@ private:
     std::vector<llama_pos>      pos;
     std::vector<int32_t>        n_seq_id;
     std::vector<llama_seq_id *> seq_id;
+    std::vector<llama_seq_id>   seq_id_unq;
+    std::vector<int32_t>        seq_idx;
     std::vector<int8_t>         output;
 
-    std::vector<std::set<llama_pos>> seq_pos; // seq_pos[s]: the set of positions in sequence s
-    std::vector<std::vector<bool>>   seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
+    using pos_set_t = std::set<llama_pos>;
+    using seq_cpl_t = std::vector<bool>;
+
+    std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
+    std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
+
+    using idx_vec_t = std::vector<int32_t>;
+    using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
+
+    std::vector<seq_set_t> seq_set; // seq_set[i]: the sequence set of token i
+
+    std::unordered_map<seq_set_t, idx_vec_t> seq_set_map; // the indices at which the sequence set appears
+
+    // batch indices of the output
+    std::vector<int32_t> out_ids;
+
+    // used[i] indicates if token i has already been used in a previous ubatch
+    std::vector<bool> used;
+
+    // llama_ubatch points to this data:
+    struct ubatch {
+        std::vector<llama_token>    token;
+        std::vector<float>          embd;
+        std::vector<llama_pos>      pos;
+        std::vector<int32_t>        n_seq_id;
+        std::vector<llama_seq_id *> seq_id;
+        std::vector<llama_seq_id>   seq_id_unq;
+        std::vector<int32_t>        seq_idx;
+        std::vector<int8_t>         output;
+    };
+
+    // current splitting state:
+    std::vector<ubatch> ubatches;
 
     int debug;
 };
index f56a58e9b6ec6208c92e0e38c4a879c5af44937c..5a18a4fb3939a10082f2ecb2a4f17ec9815a1320 100644 (file)
@@ -20,7 +20,7 @@ llama_context::llama_context(
         const llama_model & model,
               llama_context_params params) :
     model(model),
-    batch_allocr(std::make_unique<llama_batch_allocr>()) {
+    balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) {
     LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
 
     t_start_us = model.t_start_us;
@@ -722,22 +722,26 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
 }
 
 int llama_context::encode(const llama_batch & batch_inp) {
+    GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
+
     if (batch_inp.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
         return -1;
     }
 
+    const auto & hparams = model.hparams;
+
+    const int64_t n_embd = hparams.n_embd;
+
     // note: during encode, we always pass the full sequence starting from pos = 0
-    if (!batch_allocr->init(batch_inp, model.vocab, nullptr, true)) {
+    if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
 
-    const llama_batch & batch = batch_allocr->get_batch();
+    const uint32_t n_tokens = balloc->get_n_tokens();
 
-    const uint32_t n_tokens = batch.n_tokens;
-
-    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+    const llama_ubatch ubatch = balloc->split_simple(n_tokens);
 
     // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
     GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
@@ -751,14 +755,6 @@ int llama_context::encode(const llama_batch & batch_inp) {
 
     n_queued_tokens += n_tokens;
 
-    const auto & hparams = model.hparams;
-
-    const int64_t n_embd = hparams.n_embd;
-
-    llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true);
-
-    const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
-
     // reserve output buffer
     if (output_reserve(n_tokens) < n_tokens) {
         LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
@@ -817,34 +813,28 @@ int llama_context::encode(const llama_batch & batch_inp) {
                 {
                     // extract sequence embeddings
                     auto & embd_seq_out = embd_seq;
-                    embd_seq_out.clear();
 
-                    GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
+                    for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+                        const llama_seq_id seq_id  = ubatch.seq_id_unq[s];
+                        const int32_t      seq_idx = ubatch.seq_idx[seq_id];
 
-                    // TODO: fix indexing [UBATCH_IDX]
-                    for (uint32_t i = 0; i < n_tokens; i++) {
-                        const llama_seq_id seq_id = ubatch.seq_id[i][0];
-                        if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
-                            continue;
-                        }
                         embd_seq_out[seq_id].resize(n_embd);
-                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
+                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
                     }
                 } break;
             case LLAMA_POOLING_TYPE_RANK:
                 {
                     // extract the rerank score - n_cls_out floats per sequence
                     auto & embd_seq_out = embd_seq;
+
                     const uint32_t n_cls_out = hparams.n_cls_out;
 
-                    // TODO: fix indexing [UBATCH_IDX]
-                    for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
-                        const llama_seq_id seq_id = ubatch.seq_id[s][0];
-                        if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
-                            continue;
-                        }
+                    for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+                        const llama_seq_id seq_id  = ubatch.seq_id_unq[s];
+                        const int32_t      seq_idx = ubatch.seq_idx[seq_id];
+
                         embd_seq_out[seq_id].resize(n_cls_out);
-                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_id)*sizeof(float), n_cls_out*sizeof(float));
+                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
                     }
                 } break;
             case LLAMA_POOLING_TYPE_UNSPECIFIED:
@@ -869,12 +859,16 @@ int llama_context::encode(const llama_batch & batch_inp) {
         cross.v_embd.resize(cross.n_embd*cross.n_enc);
         memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd));
 
+        const auto & batch = balloc->get_batch();
+
         // remember the sequence ids used during the encoding - needed for cross attention later
         cross.seq_ids_enc.resize(n_tokens);
         for (uint32_t i = 0; i < n_tokens; i++) {
             cross.seq_ids_enc[i].clear();
+
             for (int s = 0; s < batch.n_seq_id[i]; s++) {
-                llama_seq_id seq_id = batch.seq_id[i][s];
+                const llama_seq_id seq_id = batch.seq_id[i][s];
+
                 cross.seq_ids_enc[i].insert(seq_id);
             }
         }
@@ -884,6 +878,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
 }
 
 int llama_context::decode(const llama_batch & batch_inp) {
+    GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
+
     if (!memory) {
         LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
         return encode(batch_inp);
@@ -894,29 +890,24 @@ int llama_context::decode(const llama_batch & batch_inp) {
         return -1;
     }
 
-    // when computing embeddings, all tokens are output
-    const bool embd_all = cparams.embeddings;
-
-    if (!batch_allocr->init(batch_inp, model.vocab, memory.get(), embd_all)) {
-        LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
-        return -1;
-    }
-
-    const llama_batch & batch = batch_allocr->get_batch();
-
     const auto & vocab   = model.vocab;
     const auto & hparams = model.hparams;
 
     const int32_t n_vocab = vocab.n_tokens();
     const int64_t n_embd  = hparams.n_embd;
 
-    const uint32_t n_tokens_all = batch.n_tokens;
+    // when computing embeddings, all tokens are output
+    const bool output_all = cparams.embeddings;
 
-    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, output_all)) {
+        LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
+        return -1;
+    }
 
-    const uint32_t n_outputs_all = batch_allocr->get_n_outputs();
+    const uint32_t n_tokens_all  = balloc->get_n_tokens();
+    const uint32_t n_outputs_all = balloc->get_n_outputs();
 
-    if (embd_all) {
+    if (output_all) {
         // require that all tokens are output
         if (n_outputs_all != n_tokens_all) {
             LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
@@ -945,7 +936,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     llama_memory_state_ptr mstate;
 
     while (true) {
-        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_all);
+        mstate = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
         if (!mstate) {
             return -2;
         }
@@ -966,19 +957,19 @@ int llama_context::decode(const llama_batch & batch_inp) {
                         did_optimize = true;
 
                         if (kv_self_update(true)) {
-                            LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, batch.n_tokens);
+                            LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens());
 
                             continue;
                         }
                     }
 
-                    LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, batch.n_tokens);
+                    LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, balloc->get_n_tokens());
 
                     return 1;
                 }
             case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
                 {
-                    LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, batch.n_tokens);
+                    LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, balloc->get_n_tokens());
 
                     return -2;
                 }
@@ -1005,7 +996,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
             if (n_outputs_all == n_tokens_all) {
                 n_outputs_new = ubatch.n_tokens;
             } else {
-                GGML_ASSERT(ubatch.output);
                 for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
                     n_outputs_new += (int32_t) (ubatch.output[i] != 0);
                 }
@@ -1105,27 +1095,27 @@ int llama_context::decode(const llama_batch & batch_inp) {
                         // extract sequence embeddings (cleared before processing each batch)
                         auto & embd_seq_out = embd_seq;
 
-                        for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
-                            const llama_seq_id seq_id = ubatch.seq_id[s][0];
-                            if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
-                                continue;
-                            }
+                        for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+                            const llama_seq_id seq_id  = ubatch.seq_id_unq[s];
+                            const int32_t      seq_idx = ubatch.seq_idx[seq_id];
+
                             embd_seq_out[seq_id].resize(n_embd);
-                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
+                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
                         }
                     } break;
                 case LLAMA_POOLING_TYPE_RANK:
                     {
-                        // extract the rerank score - a single float per sequence
+                        // extract the rerank score - n_cls_out floats per sequence
                         auto & embd_seq_out = embd_seq;
 
-                        for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
-                            const llama_seq_id seq_id = ubatch.seq_id[s][0];
-                            if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
-                                continue;
-                            }
-                            embd_seq_out[seq_id].resize(1);
-                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
+                        const uint32_t n_cls_out = hparams.n_cls_out;
+
+                        for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+                            const llama_seq_id seq_id  = ubatch.seq_id_unq[s];
+                            const int32_t      seq_idx = ubatch.seq_idx[seq_id];
+
+                            embd_seq_out[seq_id].resize(n_cls_out);
+                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
                         }
                     } break;
                 case LLAMA_POOLING_TYPE_UNSPECIFIED:
@@ -1145,7 +1135,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     if (n_outputs > 0) {
         bool sorted_output = true;
 
-        auto & out_ids = mstate->out_ids();
+        auto & out_ids = balloc->get_out_ids();
 
         GGML_ASSERT(out_ids.size() == (size_t) n_outputs);
 
@@ -1318,8 +1308,8 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
 
     this->n_outputs = n_outputs;
 
-    llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-    llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+    llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
+    llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
 
     auto * gf = graph_init();
     auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate);
@@ -2039,7 +2029,12 @@ void llama_context::opt_epoch_iter(
             batch.logits  [pos_batch]    = true;
         }
 
-        const auto n_tokens_all = batch.n_tokens;
+        if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, true)) {
+            LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
+            return;
+        }
+
+        const uint32_t n_tokens_all = balloc->get_n_tokens();
 
         n_queued_tokens += n_tokens_all;
 
@@ -2047,7 +2042,7 @@ void llama_context::opt_epoch_iter(
 
         uint32_t n_outputs_all = n_tokens_all;
 
-        auto mstate = memory->init_batch(batch, cparams.n_ubatch, true);
+        auto mstate = memory->init_batch(*balloc, cparams.n_ubatch, true);
         if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
             LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
             break;
index 040f03ae42e65b242f05f7e7f49a6cea577eb6ba..7d300c14572e9ed2b2009f4a11f69c36eafc3303 100644 (file)
@@ -247,7 +247,7 @@ private:
     std::map<llama_seq_id, std::vector<float>> embd_seq;
 
     // reuse the batch_allocr to avoid unnecessary memory allocations
-    std::unique_ptr<llama_batch_allocr> batch_allocr;
+    std::unique_ptr<llama_batch_allocr> balloc;
 
     uint32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
 
index 65d98cbbb3987036a27180d61f8f7083114d1455..083366fd68d074d73e274f53da458ca6750d2f7c 100644 (file)
@@ -130,110 +130,97 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
     if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
         const int64_t n_tokens     = ubatch->n_tokens;
         const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-        const int64_t n_seqs       = ubatch->n_seqs;
+        const int64_t n_seqs_unq   = ubatch->n_seqs_unq;
 
         GGML_ASSERT(mean);
         GGML_ASSERT(ggml_backend_buffer_is_host(mean->buffer));
 
         float * data = (float *) mean->data;
-        memset(mean->data, 0, n_tokens * n_tokens * ggml_element_size(mean));
+        memset(mean->data, 0, n_tokens*n_seqs_unq*ggml_element_size(mean));
 
-        std::vector<uint64_t> sum(n_tokens, 0);
+        std::vector<uint64_t> sums(n_seqs_unq, 0);
+        for (int i = 0; i < n_tokens; i += n_seq_tokens) {
+            for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
+                const llama_seq_id seq_id  = ubatch->seq_id[i][s];
+                const int32_t      seq_idx = ubatch->seq_idx[seq_id];
 
-        // TODO: fix indexing [UBATCH_IDX]
-        for (int s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = ubatch->seq_id[s][0];
-
-            // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
-            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
-
-            sum[seq_id] += ubatch->n_seq_tokens;
+                sums[seq_idx] += ubatch->n_seq_tokens;
+            }
         }
 
-        std::vector<float> div(n_tokens, 0.0f);
-        for (int i = 0; i < n_tokens; ++i) {
-            const uint64_t s = sum[i];
-            if (s > 0) {
-                div[i] = 1.0f/float(s);
+        std::vector<float> div(n_seqs_unq, 0.0f);
+        for (int s = 0; s < n_seqs_unq; ++s) {
+            const uint64_t sum = sums[s];
+            if (sum > 0) {
+                div[s] = 1.0f/float(sum);
             }
         }
 
-        // TODO: fix indexing [UBATCH_IDX]
-        for (int s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = ubatch->seq_id[s][0];
+        for (int i = 0; i < n_tokens; i += n_seq_tokens) {
+            for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
+                const llama_seq_id seq_id  = ubatch->seq_id[i][s];
+                const int32_t      seq_idx = ubatch->seq_idx[seq_id];
 
-            for (int i = 0; i < n_seq_tokens; ++i) {
-                data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id];
+                for (int j = 0; j < n_seq_tokens; ++j) {
+                    data[seq_idx*n_tokens + i + j] = div[seq_idx];
+                }
             }
         }
     }
 }
 
 void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
-    if (cparams.embeddings && (
-                cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
-                cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
-        const int64_t n_tokens     = ubatch->n_tokens;
-        const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-        const int64_t n_seqs       = ubatch->n_seqs;
+    const int64_t n_tokens     = ubatch->n_tokens;
+    const int64_t n_seq_tokens = ubatch->n_seq_tokens;
+    const int64_t n_seqs_unq   = ubatch->n_seqs_unq;
 
+    if (cparams.embeddings && (
+            cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
+            cparams.pooling_type == LLAMA_POOLING_TYPE_RANK
+        )) {
         GGML_ASSERT(cls);
         GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
 
         uint32_t * data = (uint32_t *) cls->data;
-        memset(cls->data, 0, n_tokens * ggml_element_size(cls));
+        memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
 
-        // TODO: fix indexing [UBATCH_IDX]
-        for (int s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = ubatch->seq_id[s][0];
+        for (int i = 0; i < n_tokens; i += n_seq_tokens) {
+            for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
+                const llama_seq_id seq_id  = ubatch->seq_id[i][s];
+                const int32_t      seq_idx = ubatch->seq_idx[seq_id];
 
-            // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
-            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
-
-            for (int i = 0; i < n_seq_tokens; ++i) {
-                const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
-
-                if (pos == 0) {
-                    data[seq_id] = s*n_seq_tokens + i;
-                }
+                data[seq_idx] = i;
             }
         }
     }
 
     if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
-        const int64_t n_tokens     = ubatch->n_tokens;
-        const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-        const int64_t n_seqs       = ubatch->n_seqs;
-
         GGML_ASSERT(cls);
         GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
 
         uint32_t * data = (uint32_t *) cls->data;
-        memset(cls->data, 0, n_tokens * ggml_element_size(cls));
+        memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
 
-        std::vector<int> last_pos(n_tokens, -1);
-        std::vector<int> last_row(n_tokens, -1);
+        std::vector<int> last_pos(n_seqs_unq, -1);
+        std::vector<int> last_row(n_seqs_unq, -1);
 
-        // TODO: fix indexing [UBATCH_IDX]
-        for (int s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = ubatch->seq_id[s][0];
-
-            // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
-            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
+        for (int i = 0; i < n_tokens; ++i) {
+            const llama_pos pos = ubatch->pos[i];
 
-            for (int i = 0; i < n_seq_tokens; ++i) {
-                const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
+            for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
+                const llama_seq_id seq_id  = ubatch->seq_id[i][s];
+                const int32_t      seq_idx = ubatch->seq_idx[seq_id];
 
-                if (pos >= last_pos[seq_id]) {
-                    last_pos[seq_id] = pos;
-                    last_row[seq_id] = s*n_seq_tokens + i;
+                if (pos >= last_pos[seq_idx]) {
+                    last_pos[seq_idx] = pos;
+                    last_row[seq_idx] = i;
                 }
             }
         }
 
-        for (int i = 0; i < n_tokens; ++i) {
-            if (last_row[i] >= 0) {
-                data[i] = last_row[i];
+        for (int s = 0; s < n_seqs_unq; ++s) {
+            if (last_row[s] >= 0) {
+                data[s] = last_row[s];
             }
         }
     }
@@ -266,89 +253,36 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
-    if (kq_mask) {
-        if (cparams.causal_attn) {
-            const int64_t n_kv         = ubatch->n_tokens;
-            const int64_t n_tokens     = ubatch->n_tokens;
-            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-            const int64_t n_seqs       = ubatch->n_seqs;
-
-            GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
-            float * data = (float *) kq_mask->data;
-
-            for (int h = 0; h < 1; ++h) {
-                for (int s1 = 0; s1 < n_seqs; ++s1) {
-                    const llama_seq_id seq_id = ubatch->seq_id[s1][0];
-
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const int32_t tj = s1*n_seq_tokens + j;
-
-                        for (int s0 = 0; s0 < n_seqs; ++s0) {
-                            for (int i = 0; i < n_seq_tokens; ++i) {
-                                const int32_t ti = s0*n_seq_tokens + i;
-                                float f = -INFINITY;
-
-                                // TODO: fix indexing [UBATCH_IDX]
-                                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
-                                    if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
-                                        if (hparams.use_alibi) {
-                                            f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
-                                        } else {
-                                            f = 0.0f;
-                                        }
-                                        break;
-                                    }
-                                }
-
-                                data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
-                            }
-                        }
-                    }
-                }
-            }
-        } else {
-            const int64_t n_tokens     = ubatch->n_tokens;
-            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-            const int64_t n_seqs       = ubatch->n_seqs;
-            const int64_t n_stride     = ubatch->n_tokens;
-
-            GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
-
-            float * data = (float *) kq_mask->data;
-
-            for (int h = 0; h < 1; ++h) {
-                for (int s1 = 0; s1 < n_seqs; ++s1) {
-                    const llama_seq_id seq_id = ubatch->seq_id[s1][0];
-
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const int32_t tj = s1*n_seq_tokens + j;
-
-                        for (int s0 = 0; s0 < n_seqs; ++s0) {
-                            for (int i = 0; i < n_seq_tokens; ++i) {
-                                const int32_t ti = s0*n_seq_tokens + i;
-                                float f = -INFINITY;
-
-                                // TODO: fix indexing [UBATCH_IDX]
-                                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
-                                    if (ubatch->seq_id[s0][s] == seq_id) {
-                                        if (hparams.use_alibi) {
-                                            f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
-                                        } else {
-                                            f = 0.0f;
-                                        }
-                                        break;
-                                    }
-                                }
-
-                                data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
-                            }
-                        }
+    const int64_t n_kv     = ubatch->n_tokens;
+    const int64_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(kq_mask);
+    GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
+
+    float * data = (float *) kq_mask->data;
+
+    for (int h = 0; h < 1; ++h) {
+        for (int i1 = 0; i1 < n_tokens; ++i1) {
+            const llama_seq_id s1 = ubatch->seq_id[i1][0];
+
+            for (int i0 = 0; i0 < n_tokens; ++i0) {
+                float f = -INFINITY;
+
+                for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
+                    const llama_seq_id s0 = ubatch->seq_id[i0][0];
 
-                        for (int i = n_tokens; i < n_stride; ++i) {
-                            data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
+                    // TODO: reimplement this like in llama_kv_cache_unified
+                    if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) {
+                        if (hparams.use_alibi) {
+                            f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
+                        } else {
+                            f = 0.0f;
                         }
+                        break;
                     }
                 }
+
+                data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
             }
         }
     }
@@ -371,34 +305,36 @@ void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch
 }
 
 void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
-    if (cross_kq_mask) {
-        const int64_t n_enc    = cross_kq_mask->ne[0];
-        const int64_t n_tokens = ubatch->n_tokens;
+    GGML_ASSERT(cross_kq_mask);
 
-        GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
-        GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
+    const int64_t n_enc    = cross_kq_mask->ne[0];
+    const int64_t n_tokens = ubatch->n_tokens;
 
-        float * data = (float *) cross_kq_mask->data;
+    GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
+    GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
 
-        for (int h = 0; h < 1; ++h) {
-            for (int j = 0; j < n_tokens; ++j) {
-                for (int i = 0; i < n_enc; ++i) {
-                    float f = -INFINITY;
-                    // TODO: fix indexing [UBATCH_IDX]
-                    for (int s = 0; s < ubatch->n_seq_id[j]; ++s) {
-                        const llama_seq_id seq_id = ubatch->seq_id[j][s];
-                        if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {
-                            f = 0.0f;
-                        }
+    float * data = (float *) cross_kq_mask->data;
+
+    for (int h = 0; h < 1; ++h) {
+        for (int i = 0; i < n_tokens; ++i) {
+            for (int j = 0; j < n_enc; ++j) {
+                float f = -INFINITY;
+
+                for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
+                    const llama_seq_id seq_id = ubatch->seq_id[i][s];
+
+                    if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
+                        f = 0.0f;
                     }
-                    data[h*(n_enc*n_tokens) + j*n_enc + i] = f;
                 }
+
+                data[h*(n_enc*n_tokens) + i*n_enc + j] = f;
             }
+        }
 
-            for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                for (int j = 0; j < n_enc; ++j) {
-                    data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
-                }
+        for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+            for (int j = 0; j < n_enc; ++j) {
+                data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
             }
         }
     }
@@ -467,10 +403,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     res              (std::make_unique<llm_graph_result>()) {
     }
 
-int64_t llm_graph_context::n_pos_per_embd() const {
-    return hparams.rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
-}
-
 void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
     if (cb_func) {
         cb_func(ubatch, cur, name, il);
@@ -915,11 +847,11 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_pos() const {
-    auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_embd());
+    auto inp = std::make_unique<llm_graph_input_pos>(hparams.n_pos_per_embd());
 
     auto & cur = inp->pos;
 
-    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_embd());
+    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, (int64_t)n_tokens*hparams.n_pos_per_embd());
     ggml_set_input(cur);
 
     res->add_input(std::move(inp));
@@ -959,7 +891,7 @@ ggml_tensor * llm_graph_context::build_inp_mean() const {
 
     auto & cur = inp->mean;
 
-    cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
+    cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, ubatch.n_seqs_unq);
     ggml_set_input(cur);
 
     res->add_input(std::move(inp));
@@ -972,7 +904,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
 
     auto & cur = inp->cls;
 
-    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_seqs_unq);
     ggml_set_input(cur);
 
     res->add_input(std::move(inp));
index 58845e284abed87f06e4cb5f15035a42a521f316..9e62fa60720d78b9f3aded7ff29bf1d4212197bb 100644 (file)
@@ -95,14 +95,14 @@ public:
 
 class llm_graph_input_pos : public llm_graph_input_i {
 public:
-    llm_graph_input_pos(int64_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
+    llm_graph_input_pos(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
     virtual ~llm_graph_input_pos() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * pos = nullptr; // I32 [n_batch]
 
-    const int64_t n_pos_per_embd = 1;
+    const uint32_t n_pos_per_embd = 1;
 };
 
 // temperature tuning, used by llama4
@@ -464,8 +464,6 @@ struct llm_graph_context {
 
     llm_graph_context(const llm_graph_params & params);
 
-    int64_t n_pos_per_embd() const;
-
     void cb(ggml_tensor * cur, const char * name, int il) const;
 
     //
index b40566ced99eed9c9554182bacf5d7f738125e3c..bba7a12dc5496eaa99c82a904c118af353fe384c 100644 (file)
@@ -90,6 +90,10 @@ bool llama_hparams::is_recurrent(uint32_t il) const {
     return recurrent_layer_arr[il];
 }
 
+uint32_t llama_hparams::n_pos_per_embd() const {
+    return rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
+}
+
 bool llama_hparams::is_swa(uint32_t il) const {
     if (il < n_layer) {
         return swa_layers[il];
index 82bb5b60849460a37bfcf74765daed7e585345ca..7b315a9a74b1da6669ba64eb2c6efe3ab88c8d81 100644 (file)
@@ -192,6 +192,8 @@ struct llama_hparams {
     // whether or not the given layer is recurrent (for hybrid models)
     bool is_recurrent(uint32_t il) const;
 
+    uint32_t n_pos_per_embd() const;
+
     bool is_swa(uint32_t il) const;
 };
 
index a869b1de8c2a321bfa3ed0945a726de83d48a244..0ced340dec6c5c787f767163d4d3849785f31d61 100644 (file)
@@ -95,19 +95,22 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
     return kv_swa->seq_pos_max(seq_id);
 }
 
-llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
+llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
     GGML_UNUSED(embd_all);
 
     // first try simple split
     do {
-        auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
+        balloc.split_reset();
 
         std::vector<llama_ubatch> ubatches;
+        while (true) {
+            auto ubatch = balloc.split_simple(n_ubatch);
 
-        while (sbatch.n_tokens > 0) {
-            auto ubatch = sbatch.split_simple(n_ubatch);
+            if (ubatch.n_tokens == 0) {
+                break;
+            }
 
-            ubatches.push_back(ubatch);
+            ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
         auto heads_base = kv_base->prepare(ubatches);
@@ -123,19 +126,22 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch
         assert(heads_base.size() == heads_swa.size());
 
         return std::make_unique<llama_kv_cache_unified_iswa_state>(
-                this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+                this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
     } while (false);
 
     // if it fails, try equal split
     do {
-        auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
+        balloc.split_reset();
 
         std::vector<llama_ubatch> ubatches;
+        while (true) {
+            auto ubatch = balloc.split_equal(n_ubatch);
 
-        while (sbatch.n_tokens > 0) {
-            auto ubatch = sbatch.split_equal(n_ubatch);
+            if (ubatch.n_tokens == 0) {
+                break;
+            }
 
-            ubatches.push_back(ubatch);
+            ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
         auto heads_base = kv_base->prepare(ubatches);
@@ -151,7 +157,7 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch
         assert(heads_base.size() == heads_swa.size());
 
         return std::make_unique<llama_kv_cache_unified_iswa_state>(
-                this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+                this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
     } while (false);
 
     // TODO: if we fail again, we should attempt different splitting strategies
@@ -214,15 +220,13 @@ llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
 
 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
         llama_kv_cache_unified_iswa * kv,
-        llama_sbatch sbatch,
         std::vector<uint32_t> heads_base,
         std::vector<uint32_t> heads_swa,
         std::vector<llama_ubatch> ubatches) :
-    sbatch(std::move(sbatch)),
     ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
-    state_base(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches)),
-    state_swa (new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa),  this->ubatches)),
+    state_base(new llama_kv_cache_unified_state(kv->get_base(), std::move(heads_base), this->ubatches)),
+    state_swa (new llama_kv_cache_unified_state(kv->get_swa (), std::move(heads_swa),  this->ubatches)),
     status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
 }
 
@@ -252,12 +256,6 @@ bool llama_kv_cache_unified_iswa_state::apply() {
     return res;
 }
 
-std::vector<int64_t> & llama_kv_cache_unified_iswa_state::out_ids() {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
-
-    return sbatch.out_ids;
-}
-
 llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
     return status;
 }
index 813eaf39b25b0202b8a7c9b463f4d882112c1ee8..071041585db38588217547ef556470ddbae2fdc1 100644 (file)
@@ -32,7 +32,7 @@ public:
     //
 
     llama_memory_state_ptr init_batch(
-            const llama_batch & batch,
+            llama_batch_allocr & balloc,
             uint32_t n_ubatch,
             bool embd_all) override;
 
@@ -90,7 +90,6 @@ public:
     // used to create a state from a batch
     llama_kv_cache_unified_iswa_state(
             llama_kv_cache_unified_iswa * kv,
-            llama_sbatch sbatch,
             std::vector<uint32_t> heads_base,
             std::vector<uint32_t> heads_swa,
             std::vector<llama_ubatch> ubatches);
@@ -104,8 +103,6 @@ public:
     bool next()  override;
     bool apply() override;
 
-    std::vector<int64_t> & out_ids() override;
-
     llama_memory_status  get_status() const override;
     const llama_ubatch & get_ubatch() const override;
 
@@ -119,8 +116,6 @@ public:
 private:
     //llama_kv_cache_unified_iswa * kv;
 
-    llama_sbatch sbatch;
-
     // the index of the next ubatch to process
     size_t i_next = 0;
 
index d4412288925c3d2d4bb67996531e3f486ccc7b98..6897b797153dbe894b8d66cf9f626463efe47a9c 100644 (file)
@@ -308,17 +308,23 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
 }
 
 llama_memory_state_ptr llama_kv_cache_unified::init_batch(
-            const llama_batch & batch,
+            llama_batch_allocr & balloc,
             uint32_t n_ubatch,
             bool embd_all) {
     GGML_UNUSED(embd_all);
 
     do {
-        auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
+        balloc.split_reset();
 
         std::vector<llama_ubatch> ubatches;
-        while (sbatch.n_tokens > 0) {
-            ubatches.push_back(sbatch.split_simple(n_ubatch));
+        while (true) {
+            auto ubatch = balloc.split_simple(n_ubatch);
+
+            if (ubatch.n_tokens == 0) {
+                break;
+            }
+
+            ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
         auto heads = prepare(ubatches);
@@ -327,7 +333,7 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch(
         }
 
         return std::make_unique<llama_kv_cache_unified_state>(
-                this, std::move(sbatch), std::move(heads), std::move(ubatches));
+                this, std::move(heads), std::move(ubatches));
     } while (false);
 
     return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
@@ -644,12 +650,6 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
 }
 
 void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
-    if (debug > 0) {
-        LLAMA_LOG_DEBUG("%s: ubatch info:\n", __func__);
-        LLAMA_LOG_DEBUG("%s:   n_tokens = %d, equal_seqs = %d\n", __func__, ubatch.n_tokens, ubatch.equal_seqs);
-        LLAMA_LOG_DEBUG("%s:   n_seq_tokens = %d, n_seqs = %d\n", __func__, ubatch.n_seq_tokens, ubatch.n_seqs);
-    }
-
     // keep track of the max sequence position that we would overwrite with this ubatch
     // for non-SWA cache, this would be always empty
     llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
@@ -657,27 +657,22 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
         seq_pos_max_rm[s] = -1;
     }
 
-    for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
-        for (uint32_t j = 0; j < ubatch.n_seq_tokens; ++j) {
-            const uint32_t idx = s*ubatch.n_seq_tokens + j;
-
-            if (!cells.is_empty(head_cur + idx)) {
-                assert(cells.seq_count(head_cur + idx) == 1);
+    for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
+        if (!cells.is_empty(head_cur + i)) {
+            assert(cells.seq_count(head_cur + i) == 1);
 
-                const llama_seq_id seq_id = cells.seq_get(head_cur + idx);
-                const llama_pos    pos    = cells.pos_get(head_cur + idx);
+            const llama_seq_id seq_id = cells.seq_get(head_cur + i);
+            const llama_pos    pos    = cells.pos_get(head_cur + i);
 
-                seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
+            seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
 
-                cells.rm(head_cur + idx);
-            }
+            cells.rm(head_cur + i);
+        }
 
-            cells.pos_set(head_cur + idx, ubatch.pos[idx]);
+        cells.pos_set(head_cur + i, ubatch.pos[i]);
 
-            // TODO: fix indexing [UBATCH_IDX]
-            for (int32_t i = 0; i < ubatch.n_seq_id[s]; i++) {
-                cells.seq_add(head_cur + idx, ubatch.seq_id[s][i]);
-            }
+        for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
+            cells.seq_add(head_cur + i, ubatch.seq_id[i][s]);
         }
     }
 
@@ -696,6 +691,7 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
             seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
         }
     }
+
     // move the head at the end of the slot
     head = head_cur + ubatch.n_tokens;
 }
@@ -792,9 +788,7 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 }
 
 void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
-    const uint32_t n_tokens     = ubatch->n_tokens;
-    const uint32_t n_seq_tokens = ubatch->n_seq_tokens;
-    const uint32_t n_seqs       = ubatch->n_seqs;
+    const uint32_t n_tokens = ubatch->n_tokens;
 
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     float * data = (float *) dst->data;
@@ -814,52 +808,48 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
     //      xxxxx-----
     // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
     for (uint32_t h = 0; h < 1; ++h) {
-        for (uint32_t s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = ubatch->seq_id[s][0];
+        for (uint32_t i = 0; i < n_tokens; ++i) {
+            const llama_seq_id seq_id = ubatch->seq_id[i][0];
 
-            for (uint32_t j = 0; j < n_seq_tokens; ++j) {
-                const uint32_t idx = s*n_seq_tokens + j;
+            const llama_pos p1 = ubatch->pos[i];
 
-                const llama_pos p1 = ubatch->pos[idx];
+            for (uint32_t j = 0; j < n_kv; ++j) {
+                float f = 0.0f;
 
-                for (uint32_t i = 0; i < n_kv; ++i) {
-                    float f = 0.0f;
+                bool masked = false;
 
-                    bool masked = false;
-
-                    if (cells.is_empty(i)) {
-                        masked = true;
-                    } else {
-                        const llama_pos p0 = cells.pos_get(i);
-
-                        // mask the token if not the same sequence
-                        masked = masked || (!cells.seq_has(i, seq_id));
+                if (cells.is_empty(j)) {
+                    masked = true;
+                } else {
+                    const llama_pos p0 = cells.pos_get(j);
 
-                        // mask future tokens
-                        masked = masked || (causal_attn && p0 > p1);
+                    // mask the token if not the same sequence
+                    masked = masked || (!cells.seq_has(j, seq_id));
 
-                        // apply SWA if any
-                        masked = masked || (is_masked_swa(p0, p1));
+                    // mask future tokens
+                    masked = masked || (causal_attn && p0 > p1);
 
-                        if (!masked && hparams.use_alibi) {
-                            f = -std::abs(p0 - p1);
-                        }
-                    }
+                    // apply SWA if any
+                    masked = masked || (is_masked_swa(p0, p1));
 
-                    if (masked) {
-                        f = -INFINITY;
+                    if (!masked && hparams.use_alibi) {
+                        f = -std::abs(p0 - p1);
                     }
+                }
 
-                    data[h*(n_kv*n_tokens) + idx*n_kv + i] = f;
+                if (masked) {
+                    f = -INFINITY;
                 }
+
+                data[h*(n_kv*n_tokens) + i*n_kv + j] = f;
             }
         }
 
         // mask padded tokens
         if (data) {
-            for (uint32_t j = n_tokens; j < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++j) {
-                for (uint32_t i = 0; i < n_kv; ++i) {
-                    data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
+            for (uint32_t i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                for (uint32_t j = 0; j < n_kv; ++j) {
+                    data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
                 }
             }
         }
@@ -887,12 +877,12 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama
     const int32_t n_kv = dst->ne[0];
 
     for (int h = 0; h < 1; ++h) {
-        for (int j = 0; j < n_tokens; ++j) {
-            for (int i = 0; i < n_kv; ++i) {
+        for (int i = 0; i < n_tokens; ++i) {
+            for (int j = 0; j < n_kv; ++j) {
                 // the position when the cells is empty is irrelevant - it will be masked out later in the attention
-                const llama_pos p0 = cells.is_empty(i) ? -1 : cells.pos_get(i);
+                const llama_pos p0 = cells.is_empty(j) ? -1 : cells.pos_get(j);
 
-                data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(p0, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
+                data[h*(n_kv*n_tokens) + i*n_kv + j] = llama_relative_position_bucket(p0, ubatch->pos[i], hparams.n_rel_attn_bkts, false);
             }
         }
     }
@@ -1509,12 +1499,9 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
 
         seq_rm(dest_seq_id, -1, -1);
 
-        llama_sbatch sbatch;
-        llama_ubatch ubatch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
+        llama_batch_allocr balloc(hparams.n_pos_per_embd());
 
-        ubatch.n_tokens = cell_count;
-        ubatch.n_seq_tokens = cell_count;
-        ubatch.n_seqs = 1;
+        llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
 
         for (uint32_t i = 0; i < cell_count; ++i) {
             llama_pos pos;
@@ -1746,9 +1733,8 @@ llama_kv_cache_unified_state::llama_kv_cache_unified_state(
 
 llama_kv_cache_unified_state::llama_kv_cache_unified_state(
         llama_kv_cache_unified * kv,
-        llama_sbatch sbatch,
         llama_kv_cache_unified::ubatch_heads heads,
-        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sbatch(std::move(sbatch)), heads(std::move(heads)), ubatches(std::move(ubatches)) {
+        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), heads(std::move(heads)), ubatches(std::move(ubatches)) {
 }
 
 llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default;
@@ -1781,12 +1767,6 @@ bool llama_kv_cache_unified_state::apply() {
     return true;
 }
 
-std::vector<int64_t> & llama_kv_cache_unified_state::out_ids() {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
-
-    return sbatch.out_ids;
-}
-
 llama_memory_status llama_kv_cache_unified_state::get_status() const {
     return status;
 }
index d96571d952b81db3e47b8653018b79b5e3235325..1560640045c82d9267bc59b472eaf7b25fd98c3f 100644 (file)
@@ -57,7 +57,7 @@ public:
     //
 
     llama_memory_state_ptr init_batch(
-            const llama_batch & batch,
+            llama_batch_allocr & balloc,
             uint32_t n_ubatch,
             bool embd_all) override;
 
@@ -231,7 +231,6 @@ public:
     // used to create a decode state from a batch
     llama_kv_cache_unified_state(
             llama_kv_cache_unified * kv,
-            llama_sbatch sbatch,
             ubatch_heads heads,
             std::vector<llama_ubatch> ubatches);
 
@@ -244,8 +243,6 @@ public:
     bool next()  override;
     bool apply() override;
 
-    std::vector<int64_t> & out_ids() override;
-
     llama_memory_status  get_status() const override;
     const llama_ubatch & get_ubatch() const override;
 
@@ -286,8 +283,6 @@ private:
     // batch processing state
     //
 
-    llama_sbatch sbatch;
-
     // the index of the next ubatch to process
     size_t i_next = 0;
 
index 1d4e70f4d321249882287e0bf6b1f56f1c8110dc..349e9032e2484b71b55c5ae14664d0d498d3d651 100644 (file)
@@ -384,10 +384,10 @@ private:
     //
     std::vector<llama_pos> shift;
 
-    using bits_t = std::bitset<LLAMA_MAX_SEQ>;
+    using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
 
     // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
-    std::vector<bits_t> seq;
+    std::vector<seq_set_t> seq;
 
     // the set seq_pos[s] tells us which positions are currently present for sequence s
     // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
index d4b260db4c8e7f5aee41cc79060c3939a36677d9..1b16686819eff8582dc7ab6eea4334cb304f69f8 100644 (file)
@@ -32,7 +32,7 @@ llama_memory_hybrid::llama_memory_hybrid(
     mem_attn(new llama_kv_cache_unified(
         model,
         filter_attn == nullptr ?
-            [&](int32_t il) { return !model.hparams.is_recurrent(il); }
+            [&](int32_t il) { return !hparams.is_recurrent(il); }
             : filter_attn,
         type_k,
         type_v,
@@ -47,7 +47,7 @@ llama_memory_hybrid::llama_memory_hybrid(
     mem_recr(new llama_memory_recurrent(
         model,
         filter_recr == nullptr ?
-            [&](int32_t il) { return model.hparams.is_recurrent(il); }
+            [&](int32_t il) { return hparams.is_recurrent(il); }
             : filter_recr,
         type_r,
         type_s,
@@ -56,42 +56,49 @@ llama_memory_hybrid::llama_memory_hybrid(
         n_seq_max
     )) {}
 
-llama_memory_state_ptr llama_memory_hybrid::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
+llama_memory_state_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
+    do {
+        balloc.split_reset();
 
-    // since this includes a recurrent cache, we cannot use split_simple
-    auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
+        // follow the recurrent pattern for creating the ubatch splits
+        std::vector<llama_ubatch> ubatches;
 
-    // follow the recurrent pattern for creating the ubatch splits
-    std::vector<llama_ubatch> ubatches;
-    while (sbatch.n_tokens > 0) {
-        llama_ubatch ubatch;
+        while (true) {
+            llama_ubatch ubatch;
 
-        if (embd_pooled) {
-            // Pooled embeddings cannot be split across ubatches (yet)
-            ubatch = sbatch.split_seq(n_ubatch);
-        } else {
-            ubatch = sbatch.split_equal(n_ubatch);
+            if (embd_all) {
+                // if all tokens are output, split by sequence
+                ubatch = balloc.split_seq(n_ubatch);
+            } else {
+                ubatch = balloc.split_equal(n_ubatch);
+            }
+
+            if (ubatch.n_tokens == 0) {
+                break;
+            }
+
+            ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
-        ubatches.push_back(ubatch);
-    }
+        // prepare the recurrent batches first
+        if (!mem_recr->prepare(ubatches)) {
+            // TODO: will the recurrent cache be in an undefined state at this point?
+            LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
+            return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+        }
 
-    // prepare the recurrent batches first
-    if (!mem_recr->prepare(ubatches)) {
-        // TODO: will the recurrent cache be in an undefined state at this point?
-        LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
-        return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-    }
+        // prepare the attention cache
+        auto heads_attn = mem_attn->prepare(ubatches);
+        if (heads_attn.empty()) {
+            LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
+            return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+        }
 
-    // prepare the attention cache
-    auto heads_attn = mem_attn->prepare(ubatches);
-    if (heads_attn.empty()) {
-        LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
-        return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-    }
+        return std::make_unique<llama_memory_hybrid_state>(
+                this, std::move(heads_attn), std::move(ubatches));
+    } while(false);
 
-    return std::make_unique<llama_memory_hybrid_state>(
-        this, std::move(sbatch), std::move(heads_attn), std::move(ubatches));
+    return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }
 
 llama_memory_state_ptr llama_memory_hybrid::init_full() {
@@ -188,15 +195,13 @@ llama_memory_hybrid_state::llama_memory_hybrid_state(
 
 llama_memory_hybrid_state::llama_memory_hybrid_state(
               llama_memory_hybrid * mem,
-                     llama_sbatch   sbatch,
             std::vector<uint32_t>   heads_attn,
         std::vector<llama_ubatch>   ubatches) :
-    sbatch(std::move(sbatch)),
     ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
-    state_attn(new llama_kv_cache_unified_state(mem->get_mem_attn(), {}, std::move(heads_attn), this->ubatches)),
-    state_recr(new llama_memory_recurrent_state(mem->get_mem_recr(), {},                        this->ubatches)),
-    status(LLAMA_MEMORY_STATUS_SUCCESS) {
+    state_attn(new llama_kv_cache_unified_state(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)),
+    state_recr(new llama_memory_recurrent_state(mem->get_mem_recr(),                        this->ubatches)),
+    status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
 }
 
 bool llama_memory_hybrid_state::next() {
@@ -223,12 +228,6 @@ bool llama_memory_hybrid_state::apply() {
     return res;
 }
 
-std::vector<int64_t> & llama_memory_hybrid_state::out_ids() {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
-
-    return sbatch.out_ids;
-}
-
 llama_memory_status llama_memory_hybrid_state::get_status() const {
     return status;
 }
index b5700c5225f18396e591add470461ab8f911fc29..4d27ab896aa05177b8711b7fad6a1e39b13653e6 100644 (file)
@@ -50,9 +50,9 @@ public:
     //
 
     llama_memory_state_ptr init_batch(
-            const llama_batch & batch,
+            llama_batch_allocr & balloc,
             uint32_t n_ubatch,
-            bool embd_pooled) override;
+            bool embd_all) override;
 
     llama_memory_state_ptr init_full() override;
 
@@ -107,7 +107,6 @@ public:
     // init success
     llama_memory_hybrid_state(
               llama_memory_hybrid * mem,
-                     llama_sbatch   sbatch,
             std::vector<uint32_t>   heads_attn,
         std::vector<llama_ubatch>   ubatches);
 
@@ -116,8 +115,6 @@ public:
     bool next()  override;
     bool apply() override;
 
-    std::vector<int64_t> & out_ids() override;
-
     llama_memory_status  get_status() const override;
     const llama_ubatch & get_ubatch() const override;
 
@@ -129,8 +126,6 @@ public:
     const llama_memory_recurrent_state * get_state_recr() const;
 
 private:
-    llama_sbatch sbatch;
-
     // the index of the next ubatch to process
     size_t i_next = 0;
 
index c4f9a6f1ddc981d00976a64aa921ceed1d2df5a8..b064da0084c5295f3a28b93c9dbfb1a15452f6f1 100644 (file)
@@ -362,29 +362,31 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
     return result;
 }
 
-llama_memory_state_ptr llama_memory_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
-    auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
-
+llama_memory_state_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
     std::vector<llama_ubatch> ubatches;
 
-    while (sbatch.n_tokens > 0) {
+    while (true) {
         llama_ubatch ubatch;
 
         if (embd_all) {
             // if all tokens are output, split by sequence
-            ubatch = sbatch.split_seq(n_ubatch);
+            ubatch = balloc.split_seq(n_ubatch);
         } else {
-            ubatch = sbatch.split_equal(n_ubatch);
+            ubatch = balloc.split_equal(n_ubatch);
+        }
+
+        if (ubatch.n_tokens == 0) {
+            break;
         }
 
-        ubatches.push_back(ubatch);
+        ubatches.push_back(std::move(ubatch)); // NOLINT
     }
 
     if (!prepare(ubatches)) {
         return std::make_unique<llama_memory_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
     }
 
-    return std::make_unique<llama_memory_recurrent_state>(this, std::move(sbatch), std::move(ubatches));
+    return std::make_unique<llama_memory_recurrent_state>(this, std::move(ubatches));
 }
 
 llama_memory_state_ptr llama_memory_recurrent::init_full() {
@@ -423,9 +425,8 @@ bool llama_memory_recurrent::prepare(const std::vector<llama_ubatch> & ubatches)
 }
 
 bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
-    const uint32_t n_seqs = ubatch.n_seqs;
-
     const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
+    const uint32_t n_seqs       = ubatch.n_seqs;
 
     // if we have enough unused cells before the current head ->
     //   better to start searching from the beginning of the cache, hoping to fill it
@@ -445,9 +446,11 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
 
     // everything should fit if all seq_ids are smaller than the max
     for (uint32_t s = 0; s < n_seqs; ++s) {
-        const uint32_t n_seq_id = ubatch.n_seq_id[s];
+        const uint32_t i = s*n_seq_tokens; // first token of sequence set s
+        const uint32_t n_seq_id = ubatch.n_seq_id[i];
+
         for (uint32_t j = 0; j < n_seq_id; ++j) {
-            const llama_seq_id seq_id = ubatch.seq_id[s][j];
+            const llama_seq_id seq_id = ubatch.seq_id[i][j];
 
             if (seq_id < 0 || (uint32_t) seq_id >= size) {
                 // too big seq_id
@@ -506,7 +509,8 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
 
     // find usable cell range
     for (uint32_t s = 0; s < n_seqs; ++s) {
-        const llama_seq_id seq_id = ubatch.seq_id[s][0];
+        const uint32_t i = s*n_seq_tokens;
+        const llama_seq_id seq_id = ubatch.seq_id[i][0];
         auto & seq_meta = cells[seq_id];
         bool has_cell = false;
         if (seq_meta.tail >= 0) {
@@ -530,7 +534,7 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
             seq_meta.tail = next_empty_cell;
             // find next empty cell
             if (s + 1 < n_seqs) {
-                for (uint32_t i = 0; i < size; ++i) {
+                for (uint32_t j = 0; j < size; ++j) {
                     next_empty_cell += 1;
                     if (next_empty_cell >= size) { next_empty_cell -= size; }
                     auto & cell = cells[next_empty_cell];
@@ -544,8 +548,9 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
 
     // gather and re-order
     for (uint32_t s = 0; s < n_seqs; ++s) {
+        const uint32_t i = s*n_seq_tokens;
         const int32_t dst_id = s + min;
-        const int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
+        const int32_t src_id = cells[ubatch.seq_id[i][0]].tail;
         if (dst_id != src_id) {
             auto & dst_cell = cells[dst_id];
             auto & src_cell = cells[src_id];
@@ -555,8 +560,8 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
             std::swap(dst_cell.seq_id, src_cell.seq_id);
 
             // swap tails
-            for (uint32_t i = 0; i < size; ++i) {
-                int32_t & tail = cells[i].tail;
+            for (uint32_t j = 0; j < size; ++j) {
+                int32_t & tail = cells[j].tail;
                 if (tail == src_id) {
                     tail = dst_id;
                 } else if (tail == dst_id) {
@@ -568,7 +573,8 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
 
     // update the pos of the used seqs
     for (uint32_t s = 0; s < n_seqs; ++s) {
-        const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
+        const uint32_t i = s*n_seq_tokens;
+        const llama_pos last_pos = ubatch.pos[i + n_seq_tokens - 1];
         const int32_t cell_id = s + min;
         auto & cell = cells[cell_id];
 
@@ -576,12 +582,12 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
             // What should happen when the pos backtracks or skips a value?
             // Clearing the state mid-batch would require special-casing which isn't done.
             LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
-                __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
+                __func__, last_pos, cell.pos, ubatch.seq_id[i][0], n_seq_tokens);
         }
         cell.pos = last_pos;
         cell.seq_id.clear();
-        for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
-            const llama_seq_id seq_id = ubatch.seq_id[s][j];
+        for (int32_t j = 0; j < ubatch.n_seq_id[i]; ++j) {
+            const llama_seq_id seq_id = ubatch.seq_id[i][j];
             cell.seq_id.insert(seq_id);
             cells[seq_id].tail = cell_id;
         }
@@ -827,12 +833,9 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell
 
         seq_rm(dest_seq_id, -1, -1);
 
-        llama_sbatch sbatch;
-        llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
+        llama_batch_allocr balloc(hparams.n_pos_per_embd());
 
-        batch.n_tokens = cell_count;
-        batch.n_seq_tokens = cell_count;
-        batch.n_seqs = 1;
+        llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
 
         for (uint32_t i = 0; i < cell_count; ++i) {
             llama_pos pos;
@@ -846,12 +849,12 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell
                 return false;
             }
 
-            batch.pos[i] = pos;
+            ubatch.pos[i] = pos;
         }
-        batch.n_seq_id[0] = 1;
-        batch.seq_id[0] = &dest_seq_id;
+        ubatch.n_seq_id[0] = 1;
+        ubatch.seq_id[0] = &dest_seq_id;
 
-        if (!find_slot(batch)) {
+        if (!find_slot(ubatch)) {
             LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
             return false;
         }
@@ -859,8 +862,8 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell
         // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
         // Assume that this is one contiguous block of cells
         GGML_ASSERT(head + cell_count <= size);
-        GGML_ASSERT(cells[head].pos == batch.pos[0]);
-        GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]);
+        GGML_ASSERT(cells[head].pos == ubatch.pos[0]);
+        GGML_ASSERT(cells[head + cell_count - 1].pos == ubatch.pos[cell_count - 1]);
         GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
         GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
     } else {
@@ -1048,8 +1051,7 @@ llama_memory_recurrent_state::llama_memory_recurrent_state(
 
 llama_memory_recurrent_state::llama_memory_recurrent_state(
         llama_memory_recurrent * mem,
-        llama_sbatch sbatch,
-        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {}
+        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {}
 
 llama_memory_recurrent_state::~llama_memory_recurrent_state() = default;
 
@@ -1071,12 +1073,6 @@ bool llama_memory_recurrent_state::apply() {
     return true;
 }
 
-std::vector<int64_t> & llama_memory_recurrent_state::out_ids() {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
-
-    return sbatch.out_ids;
-}
-
 llama_memory_status llama_memory_recurrent_state::get_status() const {
     return status;
 }
index 290cc84ab3fbcfcd5c4560315c06dc3be28c5823..be58dae7cfe33184cfe1e08bc9c95c1b9e1ffa49 100644 (file)
@@ -35,7 +35,7 @@ public:
     //
 
     llama_memory_state_ptr init_batch(
-            const llama_batch & batch,
+            llama_batch_allocr & balloc,
             uint32_t n_ubatch,
             bool embd_all) override;
 
@@ -137,7 +137,6 @@ public:
     // used to create a state from a batch
     llama_memory_recurrent_state(
             llama_memory_recurrent * mem,
-            llama_sbatch sbatch,
             std::vector<llama_ubatch> ubatches);
 
     virtual ~llama_memory_recurrent_state();
@@ -149,8 +148,6 @@ public:
     bool next()  override;
     bool apply() override;
 
-    std::vector<int64_t> & out_ids() override;
-
     llama_memory_status  get_status() const override;
     const llama_ubatch & get_ubatch() const override;
 
@@ -173,8 +170,6 @@ private:
 
     llama_memory_recurrent * mem;
 
-    llama_sbatch sbatch;
-
     size_t i_next = 0;
 
     std::vector<llama_ubatch> ubatches;
index 24668f861b976243bb8bb59fd149e615c71cb87a..d2ef0c2a3b4aafaae17a2bede37b97340f21b64a 100644 (file)
@@ -7,6 +7,8 @@
 
 struct llama_ubatch;
 
+class llama_batch_allocr;
+
 class llama_io_write_i;
 class llama_io_read_i;
 
@@ -50,9 +52,6 @@ struct llama_memory_state_i {
     // return false on failure
     virtual bool apply() = 0;
 
-    // TODO: this might get reworked in the future when refactoring llama_batch
-    virtual std::vector<int64_t> & out_ids() = 0;
-
     // get the current ubatch
     virtual const llama_ubatch & get_ubatch() const = 0;
 
@@ -71,7 +70,7 @@ struct llama_memory_i {
     // return a state object containing the ubatches and KV cache state required to process them
     // check the llama_memory_state_i::get_status() for the result
     virtual llama_memory_state_ptr init_batch(
-            const llama_batch & batch,
+            llama_batch_allocr & balloc,
             uint32_t n_ubatch,
             bool embd_all) = 0;
 
index 9d55b3338bcfe07863a00695173278e50dc1ede5..aa18513e393b467b8065d7af52d2d14dcf51a696 100644 (file)
@@ -3385,38 +3385,6 @@ struct server_context {
             llama_set_embeddings(ctx, slot_batched->need_embd());
         }
 
-        // pad the batch so that batch.n_tokens >= n_slots
-        // TODO: temporary workaround for https://github.com/ggml-org/llama.cpp/issues/13689
-        if (slot_batched->need_embd()) {
-            const int n_slots = slots.size();
-
-            if (batch.n_tokens < n_slots) {
-                std::set<llama_seq_id> seq_ids;
-                for (int j = 0; j < batch.n_tokens; ++j) {
-                    seq_ids.insert(batch.seq_id[j][0]);
-                }
-
-                // find unused sequence id
-                llama_seq_id seq_id = -1;
-                for (int i = 0; i < n_slots; ++i) {
-                    if (seq_ids.find(i) == seq_ids.end()) {
-                        seq_id = i;
-                    }
-                }
-
-                const int n_add = n_slots - batch.n_tokens;
-
-                SRV_WRN("adding %d dummy tokens to the batch, seq_id = %d\n", n_add, seq_id);
-
-                for (int j = 0; j < n_add; ++j) {
-                    common_batch_add(batch, 0, j, { seq_id }, true);
-                }
-
-                slots[seq_id].cache_tokens.clear();
-                llama_memory_seq_rm(llama_get_memory(ctx), seq_id, -1, -1);
-            }
-        }
-
         int32_t i_next = 0;
 
         // process the created batch of tokens