git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
batch : add n_used count (#14512)
authorGeorgi Gerganov <redacted>
Fri, 4 Jul 2025 06:04:59 +0000 (09:04 +0300)
committerGitHub <redacted>
Fri, 4 Jul 2025 06:04:59 +0000 (09:04 +0300)
ggml-ci

src/llama-batch.cpp
src/llama-batch.h
src/llama-kv-cache-unified-iswa.cpp
src/llama-kv-cache-unified.cpp
src/llama-memory-hybrid.cpp
src/llama-memory-recurrent.cpp

index 91b1d6078a2529e4c31c43a5295f26c35b70e090..8d84c68053d327860a1ff1a15964d9d4048b5c03 100644 (file)
@@ -405,6 +405,10 @@ uint32_t llama_batch_allocr::get_n_outputs() const {
     return n_outputs;
 }
 
+uint32_t llama_batch_allocr::get_n_used() const {
+    return n_used;
+}
+
 std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
     return out_ids;
 }
@@ -420,6 +424,8 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
 void llama_batch_allocr::split_reset() {
     out_ids.clear();
 
+    n_used = 0;
+
     used.clear();
     used.resize(get_n_tokens(), false);
 
@@ -444,6 +450,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
         idxs.push_back(cur_idx);
 
         used[cur_idx] = true;
+        ++n_used;
 
         ++cur_idx;
 
@@ -529,6 +536,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
             idxs_per_seq[s].push_back(idx);
 
             used[idx] = true;
+            ++n_used;
 
             ++cur_idx[s];
         }
@@ -570,6 +578,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
         idxs.push_back(cur_idx);
 
         used[cur_idx] = true;
+        ++n_used;
 
         if (idxs.size() >= n_ubatch) {
             break;
index d2c5376188a0bd985ef7039cea5231e047198b91..edff8cdd63cfcab0b2fa4385470f9066bd3e36d5 100644 (file)
@@ -54,6 +54,7 @@ public:
 
     uint32_t get_n_tokens()  const;
     uint32_t get_n_outputs() const;
+    uint32_t get_n_used()    const;
 
     // the array of output indices in the order they were encountered during the ubatch splitting
     std::vector<int32_t> & get_out_ids();
@@ -125,6 +126,8 @@ private:
     // batch indices of the output
     std::vector<int32_t> out_ids;
 
+    uint32_t n_used;
+
     // used[i] indicates if token i has already been used in a previous ubatch
     std::vector<bool> used;
 
index ee202cc710bd65849f08dde43385d127e9028117..ab4c41c780397afee783cdfb7d55f738457b85e3 100644 (file)
@@ -113,6 +113,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
+        if (balloc.get_n_used() < balloc.get_n_tokens()) {
+            // failed to find a suitable split
+            break;
+        }
+
         auto sinfos_base = kv_base->prepare(ubatches);
         if (sinfos_base.empty()) {
             break;
@@ -144,6 +149,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
+        if (balloc.get_n_used() < balloc.get_n_tokens()) {
+            // failed to find a suitable split
+            break;
+        }
+
         auto sinfos_base = kv_base->prepare(ubatches);
         if (sinfos_base.empty()) {
             break;
index ff22079851b2a44964671b9638f9b3f467bc3ca5..d3129cc53281e6589ebd7e7a3cae1ea6407878c4 100644 (file)
@@ -360,6 +360,11 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
+        if (balloc.get_n_used() < balloc.get_n_tokens()) {
+            // failed to find a suitable split
+            break;
+        }
+
         auto sinfos = prepare(ubatches);
         if (sinfos.empty()) {
             break;
index 03d974d852039be21960347eccd0941d785b1c15..908e927faa2aacfda2ad26f095fec0bbad0337aa 100644 (file)
@@ -80,6 +80,11 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
+        if (balloc.get_n_used() < balloc.get_n_tokens()) {
+            // failed to find a suitable split
+            break;
+        }
+
         // prepare the recurrent batches first
         if (!mem_recr->prepare(ubatches)) {
             // TODO: will the recurrent cache be in an undefined context at this point?
index 6ed84057ccfe25b36006faee916fbcd7218e95c9..ca0c8dd56f2fa73a9b3351fdf675a3c8cff62a94 100644 (file)
@@ -377,7 +377,8 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
                 ubatch = balloc.split_equal(n_ubatch);
             }
 
-            if (ubatch.n_tokens == 0) {
+            if (balloc.get_n_used() < balloc.get_n_tokens()) {
+                // failed to find a suitable split
                 break;
             }