]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
sampling: reuse token data buffer in llama_sampler_sample (#18365)
authorJay Zenith <redacted>
Tue, 30 Dec 2025 14:27:49 +0000 (06:27 -0800)
committerGitHub <redacted>
Tue, 30 Dec 2025 14:27:49 +0000 (16:27 +0200)
* sampling: reuse token data buffer in llama_sampler_sample

* move cur buffer before timing section, after samplers

* minor : fix build

---------

Co-authored-by: Georgi Gerganov <redacted>
src/llama-sampling.cpp
src/llama-sampling.h

index d96f619ae1e0d96d95c0ed6cc3ad63b0d6e3d6e4..f3891453e4b8a2b3f2a34cf7c227c36d6acf9549 100644 (file)
@@ -421,39 +421,6 @@ void llama_sampler_free(struct llama_sampler * smpl) {
     delete smpl;
 }
 
-llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
-    const auto * logits = llama_get_logits_ith(ctx, idx);
-
-    const llama_model * model = llama_get_model(ctx);
-    const llama_vocab * vocab = llama_model_get_vocab(model);
-
-    const int n_vocab = llama_vocab_n_tokens(vocab);
-
-    // TODO: do not allocate each time
-    std::vector<llama_token_data> cur;
-    cur.reserve(n_vocab);
-    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
-    }
-
-    llama_token_data_array cur_p = {
-        /* .data       = */ cur.data(),
-        /* .size       = */ cur.size(),
-        /* .selected   = */ -1,
-        /* .sorted     = */ false,
-    };
-
-    llama_sampler_apply(smpl, &cur_p);
-
-    GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
-
-    auto token = cur_p.data[cur_p.selected].id;
-
-    llama_sampler_accept(smpl, token);
-
-    return token;
-}
-
 // sampler chain
 
 static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
@@ -527,12 +494,56 @@ struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_param
         /* .ctx   = */ new llama_sampler_chain {
             /* .params      = */ params,
             /* .samplers    = */ {},
+            /* .cur         = */ {},
             /* .t_sample_us = */ 0,
             /* .n_sample    = */ 0,
         }
     );
 }
 
+llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
+    const auto * logits = llama_get_logits_ith(ctx, idx);
+
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
+    const int n_vocab = llama_vocab_n_tokens(vocab);
+
+    // use pre-allocated buffer from chain if available, otherwise allocate locally
+    std::vector<llama_token_data> * cur_ptr;
+    std::vector<llama_token_data> cur_local;
+
+    if (smpl->iface == &llama_sampler_chain_i) {
+        auto * chain = (llama_sampler_chain *) smpl->ctx;
+        cur_ptr = &chain->cur;
+    } else {
+        cur_ptr = &cur_local;
+    }
+
+    auto & cur = *cur_ptr;
+    cur.resize(n_vocab);
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+    }
+
+    llama_token_data_array cur_p = {
+        /* .data       = */ cur.data(),
+        /* .size       = */ cur.size(),
+        /* .selected   = */ -1,
+        /* .sorted     = */ false,
+    };
+
+    llama_sampler_apply(smpl, &cur_p);
+
+    GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
+
+    auto token = cur_p.data[cur_p.selected].id;
+
+    llama_sampler_accept(smpl, token);
+
+    return token;
+}
+
 void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
     auto * p = (llama_sampler_chain *) chain->ctx;
     p->samplers.push_back(smpl);
index 759dd7dcb7042e182013a34a232769e534e70220..1e3de4e2ec4988ae4d5b41cc420099ca928c7810 100644 (file)
@@ -16,6 +16,9 @@ struct llama_sampler_chain {
 
     std::vector<struct llama_sampler *> samplers;
 
+    // pre-allocated buffer for llama_sampler_sample to avoid repeated allocations
+    std::vector<llama_token_data> cur;
+
     // timing
 
     mutable int64_t t_sample_us;