delete smpl;
}
-llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
- const auto * logits = llama_get_logits_ith(ctx, idx);
-
- const llama_model * model = llama_get_model(ctx);
- const llama_vocab * vocab = llama_model_get_vocab(model);
-
- const int n_vocab = llama_vocab_n_tokens(vocab);
-
- // TODO: do not allocate each time
- std::vector<llama_token_data> cur;
- cur.reserve(n_vocab);
- for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
- cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
- }
-
- llama_token_data_array cur_p = {
- /* .data = */ cur.data(),
- /* .size = */ cur.size(),
- /* .selected = */ -1,
- /* .sorted = */ false,
- };
-
- llama_sampler_apply(smpl, &cur_p);
-
- GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
-
- auto token = cur_p.data[cur_p.selected].id;
-
- llama_sampler_accept(smpl, token);
-
- return token;
-}
-
// sampler chain
static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
/* .ctx = */ new llama_sampler_chain {
/* .params = */ params,
/* .samplers = */ {},
+ /* .cur = */ {},
/* .t_sample_us = */ 0,
/* .n_sample = */ 0,
}
);
}
+llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
+ const auto * logits = llama_get_logits_ith(ctx, idx);
+
+ const llama_model * model = llama_get_model(ctx);
+ const llama_vocab * vocab = llama_model_get_vocab(model);
+
+ const int n_vocab = llama_vocab_n_tokens(vocab);
+
+ // use pre-allocated buffer from chain if available, otherwise allocate locally
+ std::vector<llama_token_data> * cur_ptr;
+ std::vector<llama_token_data> cur_local;
+
+ if (smpl->iface == &llama_sampler_chain_i) {
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
+ cur_ptr = &chain->cur;
+ } else {
+ cur_ptr = &cur_local;
+ }
+
+ auto & cur = *cur_ptr;
+ cur.resize(n_vocab);
+ for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+ cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+ }
+
+ llama_token_data_array cur_p = {
+ /* .data = */ cur.data(),
+ /* .size = */ cur.size(),
+ /* .selected = */ -1,
+ /* .sorted = */ false,
+ };
+
+ llama_sampler_apply(smpl, &cur_p);
+
+ GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
+
+ auto token = cur_p.data[cur_p.selected].id;
+
+ llama_sampler_accept(smpl, token);
+
+ return token;
+}
+
void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
auto * p = (llama_sampler_chain *) chain->ctx;
p->samplers.push_back(smpl);