struct common_sampler {
common_params_sampling params;
+ struct llama_sampler * grmr;
struct llama_sampler * chain;
- bool grammar;
-
ring_buffer<llama_token> prev;
std::vector<llama_token_data> cur;
llama_token_data_array cur_p; // current candidates, (re)initialized by set_logits
};
lparams.no_perf = params.no_perf;
+ llama_sampler * grmr = nullptr;
llama_sampler * chain = llama_sampler_chain_init(lparams);
- bool grammar = false;
std::vector<llama_sampler *> samplers;
if (params.grammar.compare(0, 11, "%llguidance") == 0) {
#ifdef LLAMA_USE_LLGUIDANCE
- samplers.push_back(llama_sampler_init_llg(vocab, "lark", params.grammar.c_str()));
- grammar = true;
+ grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
#else
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
#endif // LLAMA_USE_LLGUIDANCE
} else {
if (!params.grammar.empty()) {
if (params.grammar_lazy) {
- samplers.push_back(
- llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
- trigger_patterns_c.data(), trigger_patterns_c.size(),
- trigger_tokens.data(), trigger_tokens.size()));
+ grmr = llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
+ trigger_patterns_c.data(), trigger_patterns_c.size(),
+ trigger_tokens.data(), trigger_tokens.size());
} else {
- samplers.push_back(llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"));
+ grmr = llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
}
-
- grammar = true;
}
}
auto * result = new common_sampler {
/* .params = */ params,
+ /* .grmr = */ grmr,
/* .chain = */ chain,
- /* .grammar = */ grammar,
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
/* .cur = */ {},
/* .cur_p = */ {},
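For orientation, a minimal caller-side sketch of how the grammar fields feed into this constructor. The GBNF string is a toy example, and `common_sampler_init(model, params)` is the usual entry point wrapping the code above (its signature is outside this excerpt, so treat it as assumed):

    common_params_sampling sparams;
    sparams.grammar      = "root ::= \"yes\" | \"no\"";  // constrain output to "yes" or "no"
    sparams.grammar_lazy = false;                        // eager: constrain from the first token

    // model: a loaded llama_model *
    struct common_sampler * gsmpl = common_sampler_init(model, sparams);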
void common_sampler_free(struct common_sampler * gsmpl) {
if (gsmpl) {
+ llama_sampler_free(gsmpl->grmr);
llama_sampler_free(gsmpl->chain);
delete gsmpl;
}
}
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
const auto tm = gsmpl->tm();
- if (gsmpl->grammar) {
- const int n_smpl = llama_sampler_chain_n(gsmpl->chain);
-
- for (int i = 0; i < n_smpl; i++) {
- auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
-
- // the grammar sampler is always the first one
- if (i == 0) {
- if (accept_grammar) {
- llama_sampler_accept(smpl, token);
- }
- } else {
- llama_sampler_accept(smpl, token);
- }
- }
- } else {
- llama_sampler_accept(gsmpl->chain, token);
+ if (gsmpl->grmr && accept_grammar) {
+ llama_sampler_accept(gsmpl->grmr, token);
}
+ llama_sampler_accept(gsmpl->chain, token);
+
gsmpl->prev.push_back(token);
}
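Taken together with common_sampler_sample further down, the intended call sequence per generated token is sample-then-accept. A sketch of a typical decode loop, not part of the patch (assumes an initialized ctx, gsmpl, and vocab, and omits batching and error handling):

    while (n_decoded < n_predict) {
        // idx = -1 samples from the logits of the last token in the previous batch
        const llama_token id = common_sampler_sample(gsmpl, ctx, /*idx=*/-1);

        // accept_grammar = true advances the grammar state as well as the chain
        // state (penalties, etc.); pass false for tokens forced from outside the grammar
        common_sampler_accept(gsmpl, id, /*accept_grammar=*/true);

        if (llama_vocab_is_eog(vocab, id)) {
            break;
        }

        // ... decode `id` with llama_decode and continue ...
        n_decoded++;
    }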
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
return new common_sampler {
/* .params = */ gsmpl->params,
+ /* .grmr = */ gsmpl->grmr ? llama_sampler_clone(gsmpl->grmr) : nullptr,
/* .chain = */ llama_sampler_clone(gsmpl->chain),
- /* .grammar = */ gsmpl->grammar,
/* .prev = */ gsmpl->prev,
/* .cur = */ gsmpl->cur,
/* .cur_p = */ gsmpl->cur_p,
return gsmpl->chain;
}
-llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx) {
+llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
llama_synchronize(ctx);
// start measuring sampling time after the llama_context synchronization so that ongoing async operations are not measured
llama_token id = LLAMA_TOKEN_NULL;
+ auto & grmr = gsmpl->grmr;
auto & chain = gsmpl->chain;
auto & cur_p = gsmpl->cur_p; // initialized by set_logits
gsmpl->set_logits(ctx, idx);
+ if (grammar_first && grmr) {
+ llama_sampler_apply(grmr, &cur_p);
+ }
+
+ llama_sampler_apply(chain, &cur_p);
+
+ GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
+
+ id = cur_p.data[cur_p.selected].id;
+
+ // with no grammar, or when the grammar has already been applied, we are done
+ if (grmr == nullptr || grammar_first) {
+ return id;
+ }
+
+ // check if the sampled token fits the grammar (grammar-based rejection sampling)
+ {
+ llama_token_data single_token_data = { id, 1.0f, 0.0f };
+ llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };
+
+ llama_sampler_apply(grmr, &single_token_data_array);
+
+ const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
+ if (is_valid) {
+ return id;
+ }
+ }
+
+ // resampling:
+ // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
+ gsmpl->set_logits(ctx, idx);
+
+ llama_sampler_apply(grmr, &cur_p);
llama_sampler_apply(chain, &cur_p);
GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
id = cur_p.data[cur_p.selected].id;
return id;
}
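The single-token check above is self-contained and can be lifted into a standalone predicate; a sketch (hypothetical helper, not part of this patch):

    // true if the grammar sampler would allow `token` in its current state;
    // grammar samplers mask forbidden tokens by setting their logit to -INFINITY
    static bool common_sampler_token_fits_grammar(struct llama_sampler * grmr, llama_token token) {
        llama_token_data       data  = { token, 1.0f, 0.0f };
        llama_token_data_array cur_p = { &data, 1, -1, false };

        llama_sampler_apply(grmr, &cur_p);

        return cur_p.data[0].logit != -INFINITY;
    }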
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft) {
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
std::vector<llama_token> result;
size_t i = 0;
for (; i < draft.size(); i++) {
- const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]);
+ const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
common_sampler_accept(gsmpl, id, true);
result.push_back(id);
// stop at the first sampled token that disagrees with the draft
if (draft[i] != id) {
break;
}
}
if (i == draft.size()) {
- const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]);
+ const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
common_sampler_accept(gsmpl, id, true);
result.push_back(id);
}
return result;
}
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft) {
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
std::vector<int> idxs(draft.size() + 1);
for (size_t i = 0; i < idxs.size(); ++i) {
idxs[i] = i;
}
- return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft);
+ return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
}
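Usage sketch for the speculative-decoding path (assumes the draft tokens were already decoded into ctx so that logits exist at indices 0..draft.size()):

    llama_tokens draft = { /* tokens proposed by a smaller draft model */ };

    // samples at indices 0..draft.size(), stopping at the first disagreement
    const auto ids = common_sampler_sample_and_accept_n(gsmpl, ctx, draft);

    const size_t n_accepted = ids.size() - 1;  // draft tokens confirmed by the target model
    const llama_token next  = ids.back();      // first token sampled past (or instead of) the draft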
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
// - check if the token fits the grammar (if any)
// - if not: resample by first applying the grammar constraints and then sampling again (slower path)
//
-llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx);
+// if grammar_first is true, the grammar is applied before the samplers (slower)
+// useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
+//
+llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
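For example (sketch), a caller that needs every candidate to be grammar-valid, e.g. to display top-k probabilities restricted to the grammar, would opt into the slower path:

    // apply the grammar before the sampling chain so that every candidate,
    // not only the returned token, satisfies the grammar
    const llama_token id = common_sampler_sample(gsmpl, ctx, idx, /*grammar_first=*/true);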
// generalized version of common_sampler_sample
//
// samples at each index in idxs and cross-references the results with the draft tokens,
// stopping at the first sampled token that does not match the draft
//
// requires: idxs.size() == draft.size() + 1
//
// returns at least 1 token, up to idxs.size()
//
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft);
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
// assume idxs == [ 0, 1, 2, ..., draft.size() ]
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft);
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);