]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : rename batch_all to batch (#8881)
authorDaniel Bevenius <redacted>
Thu, 17 Oct 2024 23:41:51 +0000 (01:41 +0200)
committerGitHub <redacted>
Thu, 17 Oct 2024 23:41:51 +0000 (01:41 +0200)
This commit addresses the TODO in the code to rename the `batch_all`
parameter to `batch` in `llama_decode_internal`.

src/llama.cpp

index ffaa6f789b310fafc8438c1b539667e893e76efa..dcb015d12544be8083a19f11db35adc78a664b25 100644 (file)
@@ -17134,10 +17134,10 @@ static void llama_graph_compute(
 //
 static int llama_decode_internal(
          llama_context & lctx,
-           llama_batch   batch_all) { // TODO: rename back to batch
+           llama_batch   batch) {
 
     lctx.is_encoding = false;
-    const uint32_t n_tokens_all = batch_all.n_tokens;
+    const uint32_t n_tokens_all = batch.n_tokens;
 
     if (n_tokens_all == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
@@ -17148,12 +17148,12 @@ static int llama_decode_internal(
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
 
-    GGML_ASSERT((!batch_all.token && batch_all.embd) || (batch_all.token && !batch_all.embd)); // NOLINT
+    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
-    if (batch_all.token) {
+    if (batch.token) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            if (batch_all.token[i] < 0 || (uint32_t)batch_all.token[i] >= model.vocab.n_vocab) {
-                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch_all.token[i]);
+            if (batch.token[i] < 0 || (uint32_t)batch.token[i] >= model.vocab.n_vocab) {
+                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
                 return -1;
             }
         }
@@ -17184,9 +17184,9 @@ static int llama_decode_internal(
     lctx.embd_seq.clear();
 
     // count outputs
-    if (batch_all.logits && !embd_pooled) {
+    if (batch.logits && !embd_pooled) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            n_outputs += batch_all.logits[i] != 0;
+            n_outputs += batch.logits[i] != 0;
         }
     } else if (lctx.logits_all || embd_pooled) {
         n_outputs = n_tokens_all;
@@ -17195,7 +17195,7 @@ static int llama_decode_internal(
         n_outputs = 1;
     }
 
-    lctx.sbatch.from_batch(batch_all, n_embd,
+    lctx.sbatch.from_batch(batch, n_embd,
         /* simple_split */ !kv_self.recurrent,
         /* logits_all   */ n_outputs == n_tokens_all);