]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
context : simplify output counting logic during decode (#14142)
authorGeorgi Gerganov <redacted>
Thu, 12 Jun 2025 08:50:01 +0000 (11:50 +0300)
committerGitHub <redacted>
Thu, 12 Jun 2025 08:50:01 +0000 (11:50 +0300)
* batch : remove logits_all flag

ggml-ci

* context : simplify output counting logic during decode

ggml-ci

* cont : fix comments

src/llama-batch.cpp
src/llama-batch.h
src/llama-context.cpp

index 58787fdba0d4408fd49644b9ba538e37fc30745b..69e0d7549c334795a7511d6c82290cbdbaf3fa8d 100644 (file)
@@ -306,9 +306,10 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
         batch.seq_id = seq_id.data();
     }
     if (!batch.logits) {
-        logits.resize(batch.n_tokens);
-        logits[logits.size() - 1] = true;
-        batch.logits = logits.data();
+        // by default return the output only for the last token
+        output.resize(batch.n_tokens);
+        output[output.size() - 1] = true;
+        batch.logits = output.data();
     }
 }
 
index 989fb6cf9d95c742fef9ddafb717da0816d7bbc1..7ad82b528b18b8b821912c00de5b4d49056ac505 100644 (file)
@@ -85,7 +85,7 @@ struct llama_batch_allocr {
     std::vector<llama_pos>      pos;
     std::vector<int32_t>        n_seq_id;
     std::vector<llama_seq_id *> seq_id;
-    std::vector<int8_t>         logits;
+    std::vector<int8_t>         output;
 
     // optionally fulfill the batch returned by llama_batch_get_one
     llama_batch_allocr(struct llama_batch in_batch, llama_pos p0);
index ebcba6993c471bfc1659c16a8ba36dec268e459d..2e551bf6e111c907e8b649e67863143b8888e67b 100644 (file)
@@ -758,6 +758,7 @@ int llama_context::encode(llama_batch & inp_batch) {
         t_compute_start_us = ggml_time_us();
     }
 
+    // TODO: this clear of the buffer can easily be forgotten - need something better
     embd_seq.clear();
 
     n_queued_tokens += n_tokens;
@@ -940,6 +941,25 @@ int llama_context::decode(llama_batch & inp_batch) {
         }
     }
 
+    // this indicates we are doing pooled embedding
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+    int64_t n_outputs_all = 0;
+
+    // count outputs
+    for (uint32_t i = 0; i < n_tokens_all; ++i) {
+        n_outputs_all += batch.logits[i] != 0;
+    }
+
+    if (embd_pooled) {
+        // require that all tokens are output
+        if (n_outputs_all != n_tokens_all) {
+            LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %" PRId64 ", n_tokens_all = %" PRId64 ")\n",
+                    __func__, n_outputs_all, n_tokens_all);
+            return -1;
+        }
+    }
+
     GGML_ASSERT(n_tokens_all <= cparams.n_batch);
 
     GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
@@ -949,25 +969,9 @@ int llama_context::decode(llama_batch & inp_batch) {
     }
     n_queued_tokens += n_tokens_all;
 
-    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
-    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
+    // TODO: this clear of the buffer can easily be forgotten - need something better
     embd_seq.clear();
 
-    int64_t n_outputs_all = 0;
-
-    // count outputs
-    if (batch.logits && !embd_pooled) {
-        for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            n_outputs_all += batch.logits[i] != 0;
-        }
-    } else if (embd_pooled) {
-        n_outputs_all = n_tokens_all;
-    } else {
-        // keep last output only
-        n_outputs_all = 1;
-    }
-
     bool did_optimize = false;
 
     // handle any pending defrags/shifts
@@ -1029,7 +1033,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     do {
         const auto & ubatch = mstate->get_ubatch();
 
-        // count the outputs in this u_batch
+        // count the outputs in this ubatch
         {
             int32_t n_outputs_new = 0;
 
@@ -2073,7 +2077,7 @@ void llama_context::opt_epoch_iter(
 
         n_queued_tokens += n_tokens_all;
 
-        // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+        // this indicates we are doing pooled embedding
         const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
 
         embd_seq.clear();