]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
context : fix output reorder with backend sampling (#19638)
authorGeorgi Gerganov <redacted>
Sun, 15 Feb 2026 12:57:40 +0000 (14:57 +0200)
committerGitHub <redacted>
Sun, 15 Feb 2026 12:57:40 +0000 (14:57 +0200)
src/llama-context.cpp
src/llama-context.h

index ac17e1a0fe2d99bcf8616874172878d6ba305342..99035b6caced81c35cbb9c45554219812be0bb1a 100644 (file)
@@ -878,6 +878,7 @@ const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) {
         }
     } catch (const std::exception & err) {
         // fallback to full vocab list
+        GGML_UNUSED(err);
     }
 
     return sampling.token_ids_full_vocab.data();
@@ -1809,7 +1810,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
 //
 
 uint32_t llama_context::output_reserve(int32_t n_outputs) {
-
     const auto & hparams = model.hparams;
     const auto & vocab   = model.vocab;
 
@@ -1893,11 +1893,6 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     embd = has_embd ? buffer_view<float>{(float *) (base + offset), embd.size} : buffer_view<float>{nullptr, 0};
     offset += embd.size * sizeof(float);
 
-    sampling.logits     = {nullptr, 0};
-    sampling.probs      = {nullptr, 0};
-    sampling.sampled    = {nullptr, 0};
-    sampling.candidates = {nullptr, 0};
-
     if (has_sampling) {
         sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
         offset += sampling.logits.size * sizeof(float);
@@ -1923,6 +1918,15 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
         std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0);
 
         std::fill_n(sampling.sampled.data, sampling.sampled.size, LLAMA_TOKEN_NULL);
+    } else {
+        sampling.logits     = {nullptr, 0};
+        sampling.probs      = {nullptr, 0};
+        sampling.sampled    = {nullptr, 0};
+        sampling.candidates = {nullptr, 0};
+
+        sampling.logits_count.clear();
+        sampling.probs_count.clear();
+        sampling.candidates_count.clear();
     }
 
     // set all ids as invalid (negative)
@@ -1953,37 +1957,30 @@ void llama_context::output_reorder() {
             }
         }
 
-        if (sampling.logits.has_data()) {
+        if (!sampling.samplers.empty()) {
+            assert(sampling.logits.size > 0);
+            assert(sampling.probs.size > 0);
+            assert(sampling.candidates.size > 0);
+            assert(sampling.sampled.size > 0);
+            assert(sampling.logits_count.size() > 0);
+            assert(sampling.probs_count.size() > 0);
+            assert(sampling.candidates_count.size() > 0);
+
             for (uint64_t k = 0; k < n_vocab; ++k) {
                 std::swap(sampling.logits.data[i0*n_vocab + k], sampling.logits.data[i1*n_vocab + k]);
             }
-        }
 
-        if (sampling.probs.has_data()) {
             for (uint64_t k = 0; k < n_vocab; ++k) {
                 std::swap(sampling.probs.data[i0*n_vocab + k], sampling.probs.data[i1*n_vocab + k]);
             }
-        }
 
-        if (sampling.candidates.has_data()) {
             for (uint64_t k = 0; k < n_vocab; ++k) {
                 std::swap(sampling.candidates.data[i0*n_vocab + k], sampling.candidates.data[i1*n_vocab + k]);
             }
-        }
-
-        if (sampling.sampled.has_data()) {
-            std::swap(sampling.sampled.data[i0], sampling.sampled.data[i1]);
-        }
-
-        if (!sampling.logits_count.empty()) {
-            std::swap(sampling.logits_count[i0], sampling.logits_count[i1]);
-        }
-
-        if (!sampling.probs_count.empty()) {
-            std::swap(sampling.probs_count[i0], sampling.probs_count[i1]);
-        }
 
-        if (!sampling.candidates_count.empty()) {
+            std::swap(sampling.sampled.data[i0],     sampling.sampled.data[i1]);
+            std::swap(sampling.logits_count[i0],     sampling.logits_count[i1]);
+            std::swap(sampling.probs_count[i0],      sampling.probs_count[i1]);
             std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]);
         }
     }
index 37117ba7b67563857600bcd88dc2e349df846c81..a8e53f335cc63ded1cd6e24e69d1d0dbe53bdc1c 100644 (file)
@@ -265,24 +265,26 @@ private:
     std::unique_ptr<llama_memory_i> memory;
 
     // decode output (2-dimensional array: [n_outputs][n_vocab])
-    struct buffer_view<float>  logits = {nullptr, 0};
+    buffer_view<float> logits = {nullptr, 0};
 
     // embeddings output (2-dimensional array: [n_outputs][n_embd])
     // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
-    struct buffer_view<float>  embd = {nullptr, 0};
+    buffer_view<float> embd = {nullptr, 0};
 
     struct sampling_info {
+        // !samplers.empty() to check if any samplers are active
         std::map<llama_seq_id, llama_sampler *> samplers;
 
-        struct buffer_view<float>       logits     = {nullptr, 0};
-        struct buffer_view<llama_token> sampled    = {nullptr, 0};
-        struct buffer_view<float>       probs      = {nullptr, 0};
-        struct buffer_view<llama_token> candidates = {nullptr, 0};
+        buffer_view<float>       logits     = {nullptr, 0};
+        buffer_view<llama_token> sampled    = {nullptr, 0};
+        buffer_view<float>       probs      = {nullptr, 0};
+        buffer_view<llama_token> candidates = {nullptr, 0};
 
         std::vector<uint32_t> logits_count;
         std::vector<uint32_t> probs_count;
         std::vector<uint32_t> candidates_count;
 
+        // optimization
         std::vector<llama_token> token_ids_full_vocab;
     };