static bool llama_grammar_validate(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) {
const auto cpts = unicode_cpts_from_utf8(input_str);
- const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
- llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
+ auto & stacks_cur = llama_grammar_get_stacks(grammar);
size_t pos = 0;
for (const auto & cpt : cpts) {
- const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
-
- llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur);
+ llama_grammar_accept(grammar, cpt);
if (stacks_cur.empty()) {
error_pos = pos;
error_msg = "Unexpected character '" + unicode_cpt_to_utf8(cpt) + "'";
- stacks_cur = stacks_prev;
return false;
}
++pos;
llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root");
if (grammar == nullptr) {
- throw std::runtime_error("Failed to initialize llama_grammar");
+ fprintf(stdout, "Failed to initialize llama_grammar\n");
+ return 1;
}
// Read the input file
std::string input_str;
return grammar->stacks;
}
-void llama_grammar_accept(
- const llama_grammar_rules & rules,
- const llama_grammar_stacks & stacks,
- const uint32_t chr,
- llama_grammar_stacks & stacks_new) {
- stacks_new.clear();
- stacks_new.reserve(stacks.size());
+void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) {
+ llama_grammar_stacks stacks_new;
+ stacks_new.reserve(grammar->stacks.size());
- for (const auto & stack : stacks) {
+ for (const auto & stack : grammar->stacks) {
if (stack.empty()) {
continue;
}
if (!llama_grammar_is_end_of_sequence(pos)) {
new_stack.push_back(pos);
}
- llama_grammar_advance_stack(rules, new_stack, stacks_new);
+ llama_grammar_advance_stack(grammar->rules, new_stack, stacks_new);
}
}
+
+ grammar->stacks = std::move(stacks_new);
}
llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
}
struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
- llama_grammar * result = new llama_grammar { grammar.vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, };
+ llama_grammar * result = new llama_grammar {
+ grammar.vocab,
+ grammar.rules,
+ grammar.stacks,
+ grammar.partial_utf8,
+ };
// redirect elements in stacks to point to new rules
for (size_t is = 0; is < result->stacks.size(); is++) {
for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) {
for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) {
if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) {
- result->stacks[is][ie] = &result->rules[ir0][ir1];
+ result->stacks[is][ie] = &result->rules[ir0][ir1];
}
}
}
const auto decoded = decode_utf8(piece, grammar.partial_utf8);
const auto & code_points = decoded.first;
- llama_grammar_stacks stacks_new;
-
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
- llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new);
- grammar.stacks = std::move(stacks_new);
+ llama_grammar_accept(&grammar, *it);
}
grammar.partial_utf8 = decoded.second;
using llama_grammar_stacks = std::vector<llama_grammar_stack>;
using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
+// TODO: remove, needed for tests atm
const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
// be positioned at a character range (see `llama_grammar_advance_stack`), and
// produces the N possible stacks if the given char is accepted at those
// positions
-void llama_grammar_accept(
- const llama_grammar_rules & rules,
- const llama_grammar_stacks & stacks,
- uint32_t chr,
- llama_grammar_stacks & stacks_new);
+void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr);
std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
const llama_grammar_rules & rules,
static bool match_string(const std::string & input, llama_grammar * grammar) {
const auto cpts = unicode_cpts_from_utf8(input);
- const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
- llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
+ auto & stacks_cur = llama_grammar_get_stacks(grammar);
for (const auto & cpt : cpts) {
- const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
-
- llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur);
+ llama_grammar_accept(grammar, cpt);
if (stacks_cur.empty()) {
// no stacks means that the grammar failed to match at this point
auto * grammar = build_grammar(grammar_str);
// Save the original grammar stacks so that we can reset after every new string we want to test
- const llama_grammar_stacks stacks_org = llama_grammar_get_stacks(grammar);
+ const llama_grammar_stacks stacks_org = llama_grammar_get_stacks(grammar); // copy
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
}
}
- llama_grammar * grammar = NULL;
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
- grammar = llama_grammar_init_impl(nullptr, grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
- if (grammar == nullptr)
- {
+ llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
+ if (grammar == nullptr) {
throw std::runtime_error("Failed to initialize llama_grammar");
}