]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
batch : remove logits_all flag (#14141)
authorGeorgi Gerganov <redacted>
Thu, 12 Jun 2025 08:49:26 +0000 (11:49 +0300)
committerGitHub <redacted>
Thu, 12 Jun 2025 08:49:26 +0000 (11:49 +0300)
ggml-ci

src/llama-batch.cpp
src/llama-batch.h
src/llama-context.cpp
src/llama-kv-cache-recurrent.cpp
src/llama-kv-cache-recurrent.h
src/llama-kv-cache-unified-iswa.cpp
src/llama-kv-cache-unified-iswa.h
src/llama-kv-cache-unified.cpp
src/llama-kv-cache-unified.h
src/llama-memory.h

index 6a19a243118d344bfd9f33a881a356dc74929138..58787fdba0d4408fd49644b9ba538e37fc30745b 100644 (file)
@@ -105,12 +105,7 @@ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & s
             ubatch.seq_id = batch->seq_id + seq.offset;
         }
     }
-    if (logits_all) {
-        for (size_t i = 0; i < length; ++i) {
-            ubatch.output[ubatch.n_tokens + i] = 1;
-            out_ids.push_back(ids[seq.offset + i]);
-        }
-    } else if (batch->logits) {
+    if (batch->logits) {
         if (ubatch.equal_seqs) {
             for (size_t i = 0; i < length; ++i) {
                 size_t id = ids[seq.offset + i];
@@ -197,11 +192,10 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
     return ubatch;
 }
 
-llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
+llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split) {
     GGML_ASSERT(batch.n_tokens >= 0);
     this->batch = &batch;
     this->n_embd = n_embd;
-    this->logits_all = logits_all;
 
     n_tokens = batch.n_tokens;
     ids.resize(n_tokens);
index b8260b94fd2d0aaf301347fdf70af299251556ef..989fb6cf9d95c742fef9ddafb717da0816d7bbc1 100644 (file)
@@ -39,8 +39,6 @@ struct llama_sbatch {
 
     size_t n_embd;
 
-    bool logits_all; // TODO: remove once lctx.logits_all is removed too
-
     // sorted indices into the batch
     std::vector<int64_t> ids;
     // batch indices of the output
@@ -76,7 +74,7 @@ struct llama_sbatch {
     llama_ubatch split_seq(size_t n_ubatch);
 
     llama_sbatch() = default;
-    llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
+    llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
 };
 
 // temporary allocate memory for the input batch if needed
index 8cea21d6989efd5a1b7aad0486ac03a35bcf5b57..ebcba6993c471bfc1659c16a8ba36dec268e459d 100644 (file)
@@ -764,7 +764,7 @@ int llama_context::encode(llama_batch & inp_batch) {
 
     const int64_t n_embd = hparams.n_embd;
 
-    llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
+    llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true);
 
     const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
 
@@ -976,7 +976,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     llama_memory_state_ptr mstate;
 
     while (true) {
-        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
+        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
         if (!mstate) {
             return -2;
         }
@@ -2080,7 +2080,7 @@ void llama_context::opt_epoch_iter(
 
         int64_t n_outputs_all = n_tokens_all;
 
-        auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
+        auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
         if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
             LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
             break;
index f8cdd52808d7be521c71e9bbbc8f41fdf7128db6..de23b4ad23bcec457538630a4ff4340b78d8df99 100644 (file)
@@ -359,10 +359,10 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
     return result;
 }
 
-llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
+llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
     GGML_UNUSED(embd_pooled);
 
-    auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
+    auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
 
     std::vector<llama_ubatch> ubatches;
 
index 4b33bafd71cca510374ae55f3993100b61fe21c4..d7c02ea8721609150e3a346d9b9df688d9883d0a 100644 (file)
@@ -32,8 +32,7 @@ public:
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) override;
+            bool embd_pooled) override;
 
     llama_memory_state_ptr init_full() override;
 
index caa58ea9aa3b0e6274c264fcca3cc309bc2383a1..9814f766312036399c51623275f9cb81923c48a0 100644 (file)
@@ -95,12 +95,12 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
     return kv_swa->seq_pos_max(seq_id);
 }
 
-llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
+llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
     GGML_UNUSED(embd_pooled);
 
     // first try simple split
     do {
-        auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
 
         std::vector<llama_ubatch> ubatches;
 
@@ -128,7 +128,7 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch
 
     // if it fails, try equal split
     do {
-        auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
 
         std::vector<llama_ubatch> ubatches;
 
index 3dbf33ed7b960d3985804e84a63eccddbb308198..d114c7378fbe94f4c3b57dca729d7fb9ccebc1a0 100644 (file)
@@ -34,8 +34,7 @@ public:
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) override;
+            bool embd_pooled) override;
 
     llama_memory_state_ptr init_full() override;
 
index ddeb138f38fb960533390ac0153e0362ad5a3a42..89606c598fc4fc612c5ac486a798d590730e2001 100644 (file)
@@ -310,12 +310,11 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
 llama_memory_state_ptr llama_kv_cache_unified::init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) {
+            bool embd_pooled) {
     GGML_UNUSED(embd_pooled);
 
     do {
-        auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
 
         std::vector<llama_ubatch> ubatches;
         while (sbatch.n_tokens > 0) {
index cf4c691babd1ed0c7dfa908e1218fb3513bbbdaa..d6dcd19f2507e04b37525df8d5b1d0ebac87d9e6 100644 (file)
@@ -59,8 +59,7 @@ public:
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) override;
+            bool embd_pooled) override;
 
     llama_memory_state_ptr init_full() override;
 
index 991aae781ba57003d2d975994f848ff69735a931..42e226dc0ed61ebf3a8d4a45a46211392fc79a53 100644 (file)
@@ -73,8 +73,7 @@ struct llama_memory_i {
     virtual llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) = 0;
+            bool embd_pooled) = 0;
 
     // simulate full cache, used for allocating worst-case compute buffers
     virtual llama_memory_state_ptr init_full() = 0;