std::string grammar_kind;
std::string grammar_data;
LlgTokenizer * tokenizer;
- LlgConstraint * grammar;
- LlgMaskResult llg_res;
- bool has_llg_res;
+ LlgMatcher * grammar;
};
-static LlgConstraint * llama_sampler_llg_new(LlgTokenizer * tokenizer, const char * grammar_kind,
- const char * grammar_data) {
+static LlgMatcher * llama_sampler_llg_new(LlgTokenizer * tokenizer, const char * grammar_kind,
+ const char * grammar_data) {
LlgConstraintInit cinit;
llg_constraint_init_set_defaults(&cinit, tokenizer);
const char * log_level = getenv("LLGUIDANCE_LOG_LEVEL");
if (log_level && *log_level) {
cinit.log_stderr_level = atoi(log_level);
}
- auto c = llg_new_constraint_any(&cinit, grammar_kind, grammar_data);
- if (llg_get_error(c)) {
- LOG_ERR("llg error: %s\n", llg_get_error(c));
- llg_free_constraint(c);
+ auto c = llg_new_matcher(&cinit, grammar_kind, grammar_data);
+ if (llg_matcher_get_error(c)) {
+ LOG_ERR("llg error: %s\n", llg_matcher_get_error(c));
+ llg_free_matcher(c);
return nullptr;
}
+
return c;
}
static void llama_sampler_llg_accept_impl(llama_sampler * smpl, llama_token token) {
auto * ctx = (llama_sampler_llg *) smpl->ctx;
if (ctx->grammar) {
- LlgCommitResult res;
- llg_commit_token(ctx->grammar, token, &res);
- ctx->has_llg_res = false;
+ llg_matcher_consume_token(ctx->grammar, token);
}
}
static void llama_sampler_llg_apply(llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_llg *) smpl->ctx;
if (ctx->grammar) {
- if (!ctx->has_llg_res) {
- if (llg_compute_mask(ctx->grammar, &ctx->llg_res) == 0) {
- ctx->has_llg_res = true;
+ const uint32_t * mask = llg_matcher_get_mask(ctx->grammar);
+ if (mask == nullptr) {
+ if (llg_matcher_compute_mask(ctx->grammar) == 0) {
+ mask = llg_matcher_get_mask(ctx->grammar);
} else {
- LOG_ERR("llg error: %s\n", llg_get_error(ctx->grammar));
- llg_free_constraint(ctx->grammar);
+ LOG_ERR("llg error: %s\n", llg_matcher_get_error(ctx->grammar));
+ llg_free_matcher(ctx->grammar);
ctx->grammar = nullptr;
+ return;
}
}
- if (ctx->has_llg_res) {
- if (ctx->llg_res.is_stop) {
- for (size_t i = 0; i < cur_p->size; ++i) {
- if (!llama_vocab_is_eog(ctx->vocab, cur_p->data[i].id)) {
- cur_p->data[i].logit = -INFINITY;
- }
- }
- } else {
- const uint32_t * mask = ctx->llg_res.sample_mask;
- for (size_t i = 0; i < cur_p->size; ++i) {
- auto token = cur_p->data[i].id;
- if ((mask[token / 32] & (1 << (token % 32))) == 0) {
- cur_p->data[i].logit = -INFINITY;
- }
- }
+
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ auto token = cur_p->data[i].id;
+ if ((mask[token / 32] & (1 << (token % 32))) == 0) {
+ cur_p->data[i].logit = -INFINITY;
}
}
}
static void llama_sampler_llg_reset(llama_sampler * smpl) {
auto * ctx = (llama_sampler_llg *) smpl->ctx;
- if (!ctx->grammar) {
- return;
+ if (ctx->grammar) {
+ llg_matcher_reset(ctx->grammar);
}
-
- auto * grammar_new = llama_sampler_llg_new(ctx->tokenizer, ctx->grammar_kind.c_str(), ctx->grammar_data.c_str());
- llg_free_constraint(ctx->grammar);
- ctx->grammar = grammar_new;
- ctx->has_llg_res = false;
}
static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) {
if (ctx->grammar) {
result_ctx->grammar_kind = ctx->grammar_kind;
result_ctx->grammar_data = ctx->grammar_data;
- result_ctx->grammar = llg_clone_constraint(ctx->grammar);
+ result_ctx->grammar = llg_clone_matcher(ctx->grammar);
result_ctx->tokenizer = llg_clone_tokenizer(ctx->tokenizer);
}
}
const auto * ctx = (llama_sampler_llg *) smpl->ctx;
if (ctx->grammar) {
- llg_free_constraint(ctx->grammar);
+ llg_free_matcher(ctx->grammar);
llg_free_tokenizer(ctx->tokenizer);
}
/* .grammar_data = */ grammar_data,
/* .tokenizer = */ tokenizer,
/* .grammar = */ llama_sampler_llg_new(tokenizer, grammar_kind, grammar_data),
- /* .llg_res = */ {},
- /* .has_llg_res = */ false,
};
+ if (ctx->grammar) {
+ GGML_ASSERT(((size_t) llama_vocab_n_tokens(vocab) + 31) / 32 * 4 ==
+ llg_matcher_get_mask_byte_size(ctx->grammar));
+ }
} else {
*ctx = {
/* .vocab = */ vocab,
/* .grammar_data = */ {},
/* .tokenizer = */ nullptr,
/* .grammar = */ nullptr,
- /* .llg_res = */ {},
- /* .has_llg_res = */ false,
};
}
return llama_sampler_init(
/* .iface = */ &llama_sampler_llg_i,
- /* .ctx = */ ctx
- );
+ /* .ctx = */ ctx);
}
#else
});
}
+static void one_hot(llama_token_data_array & tok_arr, llama_token selected) {
+ auto n_vocab = tok_arr.size;
+
+ tok_arr.selected = -1;
+ tok_arr.sorted = false;
+ for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
+ tok_arr.data[token_id].id = token_id;
+ tok_arr.data[token_id].logit = 0.0f;
+ }
+
+ tok_arr.data[selected].logit = 100.0f;
+}
+
+static void test_sampler_chain(void) {
+ auto sparams = llama_sampler_chain_default_params();
+ sparams.no_perf = false;
+ llama_sampler * sampler = llama_sampler_chain_init(sparams);
+
+ const auto grammar_data = R"(%llguidance {}
+start: /[A-Z ]*/)";
+
+ llama_sampler_chain_add(sampler, llama_sampler_init_llg(vocab, "lark", grammar_data));
+ llama_sampler_chain_add(sampler, llama_sampler_init_dist(42));
+
+ auto input = "ALL YOUR BASE ARE BELONG TO US";
+ auto tokens = common_tokenize(vocab, input, false, false);
+
+ auto n_vocab = llama_vocab_n_tokens(vocab);
+
+ std::vector<llama_token_data> cur;
+ cur.reserve(n_vocab);
+ for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
+ cur.emplace_back(llama_token_data{ token_id, 0.0f, 0.0f });
+ }
+ auto tok_arr = llama_token_data_array{ cur.data(), cur.size(), -1, false };
+
+ for (const auto token : tokens) {
+ one_hot(tok_arr, token);
+
+ fprintf(stderr, "applying token: %d\n", token);
+ llama_sampler_apply(sampler, &tok_arr);
+
+ auto idx = tok_arr.selected;
+ fprintf(stderr, " -> %d %f\n", cur[idx].id, cur[idx].logit);
+ assert(cur[tok_arr.selected].id == token);
+ llama_sampler_accept(sampler, token);
+ }
+
+ auto tok_eos = llama_vocab_eot(vocab);
+ if (tok_eos == LLAMA_TOKEN_NULL) {
+ tok_eos = llama_vocab_eos(vocab);
+ }
+
+ one_hot(tok_arr, tok_eos);
+
+ llama_sampler_apply(sampler, &tok_arr);
+ assert(cur[tok_arr.selected].id == tok_eos);
+}
+
int main(int argc, const char ** argv) {
fprintf(stdout, "Running llguidance integration tests...\n");
test_special_chars();
test_quantifiers();
test_json_schema();
+
+ test_sampler_chain();
+
fprintf(stdout, "All tests passed.\n");
return 0;
}