]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
completion : simplify batch (embd) processing (#19286)
authorDaniel Bevenius <redacted>
Wed, 4 Feb 2026 04:43:28 +0000 (05:43 +0100)
committerGitHub <redacted>
Wed, 4 Feb 2026 04:43:28 +0000 (05:43 +0100)
* completion : simplify batch (embd) processing

This commit simplifies the processing of embd by removing the for loop
that currently exists which uses params.n_batch as its increment. This
commit also removes the clamping of n_eval as the size of embd is always
at most the size of params.n_batch.

The motivation is to clarify the code as it is currently a little
confusing when looking at this for loop in isolation and thinking that
it can process multiple batches.

* add an assert to verify n_eval is not greater than n_batch

tools/completion/completion.cpp

index f368a2f4c65c8a90524fbebb61190d086bd0803d..977132756f7f9bb084782a9665610832fd2f06d8 100644 (file)
@@ -674,15 +674,12 @@ int main(int argc, char ** argv) {
                 }
             }
 
-            for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
-                int n_eval = (int) embd.size() - i;
-                if (n_eval > params.n_batch) {
-                    n_eval = params.n_batch;
-                }
-
+            if (!embd.empty()) {
+                int n_eval = (int) embd.size();
                 LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
 
-                if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) {
+                GGML_ASSERT(n_eval <= params.n_batch);
+                if (llama_decode(ctx, llama_batch_get_one(embd.data(), n_eval))) {
                     LOG_ERR("%s : failed to eval\n", __func__);
                     return 1;
                 }
@@ -743,7 +740,7 @@ int main(int argc, char ** argv) {
                 common_sampler_accept(smpl, embd_inp[n_consumed], /* accept_grammar= */ false);
 
                 ++n_consumed;
-                if ((int) embd.size() >= params.n_batch) {
+                if ((int) embd.size() == params.n_batch) {
                     break;
                 }
             }