]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
recurrent : call balloc split_reset() in init_batch() (#14414)
authorGeorgi Gerganov <redacted>
Fri, 27 Jun 2025 14:55:45 +0000 (17:55 +0300)
committerGitHub <redacted>
Fri, 27 Jun 2025 14:55:45 +0000 (17:55 +0300)
ggml-ci

src/llama-memory-recurrent.cpp

index 1b1e95d567a6cf9bec8284ffe40fa2863462a82e..e52156bf308b66b417ff9743e216dca93660fab2 100644 (file)
@@ -363,30 +363,35 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
 }
 
 llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
-    std::vector<llama_ubatch> ubatches;
+    do {
+        balloc.split_reset();
 
-    while (true) {
-        llama_ubatch ubatch;
+        std::vector<llama_ubatch> ubatches;
+        while (true) {
+            llama_ubatch ubatch;
 
-        if (embd_all) {
-            // if all tokens are output, split by sequence
-            ubatch = balloc.split_seq(n_ubatch);
-        } else {
-            ubatch = balloc.split_equal(n_ubatch);
+            if (embd_all) {
+                // if all tokens are output, split by sequence
+                ubatch = balloc.split_seq(n_ubatch);
+            } else {
+                ubatch = balloc.split_equal(n_ubatch);
+            }
+
+            if (ubatch.n_tokens == 0) {
+                break;
+            }
+
+            ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
-        if (ubatch.n_tokens == 0) {
+        if (!prepare(ubatches)) {
             break;
         }
 
-        ubatches.push_back(std::move(ubatch)); // NOLINT
-    }
-
-    if (!prepare(ubatches)) {
-        return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-    }
+        return std::make_unique<llama_memory_recurrent_context>(this, std::move(ubatches));
+    } while (false);
 
-    return std::make_unique<llama_memory_recurrent_context>(this, std::move(ubatches));
+    return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }
 
 llama_memory_context_ptr llama_memory_recurrent::init_full() {