]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
batch : add optional for sequential equal split (#14511)
authorGeorgi Gerganov <redacted>
Fri, 4 Jul 2025 06:08:59 +0000 (09:08 +0300)
committerGitHub <redacted>
Fri, 4 Jul 2025 06:08:59 +0000 (09:08 +0300)
ggml-ci

src/llama-batch.cpp
src/llama-batch.h
src/llama-kv-cache-unified-iswa.cpp
src/llama-memory-hybrid.cpp
src/llama-memory-recurrent.cpp

index 8d84c68053d327860a1ff1a15964d9d4048b5c03..3bc8554e51ccf518e781ba5076780ae757c294a9 100644 (file)
@@ -166,6 +166,8 @@ bool llama_batch_allocr::init(
 
                 // note: tracking the other way around is not necessary for now
                 //seq_cpl[s0][s1] = true;
+
+                has_cpl = true;
             }
         }
     }
@@ -466,9 +468,17 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
     return ubatch_add(idxs, idxs.size(), false);
 }
 
-llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
+llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) {
+    if (sequential && has_cpl) {
+        LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch\n", __func__);
+
+        return {};
+    }
+
     std::vector<seq_set_t> cur_seq_set;
 
+    llama_seq_id last_seq_id = -1;
+
     // determine the non-overlapping sequence sets participating in this ubatch
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
         if (used[i]) {
@@ -485,9 +495,16 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
             }
         }
 
+        // accept only increasing sequence ids
+        if (sequential) {
+            add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
+        }
+
         if (add) {
             cur_seq_set.push_back(seq_set[i]);
 
+            last_seq_id = batch.seq_id[i][0];
+
             if (cur_seq_set.size() > n_ubatch) {
                 break;
             }
index edff8cdd63cfcab0b2fa4385470f9066bd3e36d5..3420803ff946967319e96947497b315ae74d1f6c 100644 (file)
@@ -70,7 +70,8 @@ public:
     llama_ubatch split_simple(uint32_t n_ubatch);
 
     // make ubatches of equal-length sequences sets
-    llama_ubatch split_equal(uint32_t n_ubatch);
+    // if sequential == true, the tokens in the ubatch will have increasing sequential sequence ids
+    llama_ubatch split_equal(uint32_t n_ubatch, bool sequential);
 
     // sequence-set-wise split - each ubatch contains a single sequence-set
     llama_ubatch split_seq(uint32_t n_ubatch);
@@ -113,6 +114,9 @@ private:
     using pos_set_t = std::set<llama_pos>;
     using seq_cpl_t = std::vector<bool>;
 
+    // helper flag to quickly determine if there are any coupled sequences in the batch
+    bool has_cpl;
+
     std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
     std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
 
index ab4c41c780397afee783cdfb7d55f738457b85e3..fe207ad536032d34930bc6f4bbfc4a35b4226bb5 100644 (file)
@@ -140,7 +140,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
 
         std::vector<llama_ubatch> ubatches;
         while (true) {
-            auto ubatch = balloc.split_equal(n_ubatch);
+            auto ubatch = balloc.split_equal(n_ubatch, false);
 
             if (ubatch.n_tokens == 0) {
                 break;
index 908e927faa2aacfda2ad26f095fec0bbad0337aa..6cd10db06b77571675aa2a6dfa6ab8977552d199 100644 (file)
@@ -70,7 +70,7 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
                 // if all tokens are output, split by sequence
                 ubatch = balloc.split_seq(n_ubatch);
             } else {
-                ubatch = balloc.split_equal(n_ubatch);
+                ubatch = balloc.split_equal(n_ubatch, false);
             }
 
             if (ubatch.n_tokens == 0) {
index ca0c8dd56f2fa73a9b3351fdf675a3c8cff62a94..4b90dac7a327cf4a0c29396b3a3215a1e2ad84d0 100644 (file)
@@ -374,7 +374,7 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
                 // if all tokens are output, split by sequence
                 ubatch = balloc.split_seq(n_ubatch);
             } else {
-                ubatch = balloc.split_equal(n_ubatch);
+                ubatch = balloc.split_equal(n_ubatch, false);
             }
 
             if (balloc.get_n_used() < balloc.get_n_tokens()) {