#include <algorithm>
#include <cmath>
#include <cstdint>
#include <functional>
#include <set>
#include <stdexcept>
#include <vector>
#define MAX_REPETITION_THRESHOLD 2000
bool is_nested) {
size_t last_sym_start = rule.size();
const char * pos = src;
+ uint64_t n_prev_rules = 1;
// use UINT64_MAX as the empty value because we aligned to the proper uint64_t type so -1 can't be used
// (though it's technically the same as -1 now)
// S' ::= S |
llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end());
+ // Calculate the total number of rules that will be generated by this repetition
+ uint64_t total_rules = 1; // Start with 1 for the original rule
+ if (!no_max && max_times > 0) {
+ total_rules = max_times;
+ } else if (min_times > 0) {
+ total_rules = min_times;
+ }
+
+ if (n_prev_rules * total_rules >= MAX_REPETITION_THRESHOLD) {
+ throw std::runtime_error("number of rules that are going to be repeated multiplied by the new repetition exceeds sane defaults, please reduce the number of repetitions or rule complexity");
+ }
+
if (min_times == 0) {
rule.resize(last_sym_start);
} else {
if (n_opt > 0) {
rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
}
+ n_prev_rules *= total_rules;
+ GGML_ASSERT(n_prev_rules >= 1);
};
while (*pos) {
if (*pos == '"') { // literal string
pos++;
last_sym_start = rule.size();
+ n_prev_rules = 1;
while (*pos != '"') {
if (!*pos) {
throw std::runtime_error("unexpected end of input");
start_type = LLAMA_GRETYPE_CHAR_NOT;
}
last_sym_start = rule.size();
+ n_prev_rules = 1;
while (*pos != ']') {
if (!*pos) {
throw std::runtime_error("unexpected end of input");
auto token_pair = parse_token(vocab, pos);
const char * token_end = token_pair.second;
last_sym_start = rule.size();
+ n_prev_rules = 1;
rule.push_back({type, token_pair.first});
pos = parse_space(token_end, is_nested);
} else if (is_word_char(*pos)) { // rule reference
uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
pos = parse_space(name_end, is_nested);
last_sym_start = rule.size();
+ n_prev_rules = 1;
rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
} else if (*pos == '(') { // grouping
// parse nested alternates into synthesized rule
pos = parse_space(pos + 1, true);
+ uint32_t n_rules_before = symbol_ids.size();
uint32_t sub_rule_id = generate_symbol_id(rule_name);
pos = parse_alternates(pos, rule_name, sub_rule_id, true);
+ n_prev_rules = std::max(1u, (uint32_t)symbol_ids.size() - n_rules_before);
last_sym_start = rule.size();
// output reference to synthesized rule
rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '.') { // any char
last_sym_start = rule.size();
+ n_prev_rules = 1;
rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '*') {
static void llama_grammar_advance_stack(
const llama_grammar_rules & rules,
const llama_grammar_stack & stack,
- llama_grammar_stacks & new_stacks) {
- if (stack.empty()) {
- if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
- new_stacks.emplace_back(stack);
+ llama_grammar_stacks & new_stacks) {
+ std::vector<llama_grammar_stack> todo;
+ todo.push_back(stack);
+
+ auto stack_cmp = [](const llama_grammar_stack & a, const llama_grammar_stack & b) {
+ return std::lexicographical_compare(a.begin(), a.end(), b.begin(), b.end(),
+ [](const llama_grammar_element * pa, const llama_grammar_element * pb) {
+ return pa < pb; // Compare pointer addresses
+ }
+ );
+ };
+
+ std::set<llama_grammar_stack, decltype(stack_cmp)> seen(stack_cmp);
+
+ while (!todo.empty()) {
+ llama_grammar_stack curr_stack = std::move(todo.back());
+ todo.pop_back();
+
+ if (seen.find( curr_stack) != seen.end()) {
+ continue;
}
- return;
- }
+ seen.insert(curr_stack);
- const llama_grammar_element * pos = stack.back();
+ if (curr_stack.empty()) {
+ if (std::find(new_stacks.begin(), new_stacks.end(), curr_stack) == new_stacks.end()) {
+ new_stacks.emplace_back(std::move(curr_stack));
+ }
+ continue;
+ }
- switch (pos->type) {
+ const llama_grammar_element * pos = curr_stack.back();
+
+ switch (pos->type) {
case LLAMA_GRETYPE_RULE_REF: {
const size_t rule_id = static_cast<size_t>(pos->value);
const llama_grammar_element * subpos = rules[rule_id].data();
do {
// init new stack without the top (pos)
- llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
+ llama_grammar_stack next_stack(curr_stack.begin(), curr_stack.end() - 1);
if (!llama_grammar_is_end_of_sequence(pos + 1)) {
// if this rule ref is followed by another element, add that to stack
- new_stack.push_back(pos + 1);
+ next_stack.push_back(pos + 1);
}
if (!llama_grammar_is_end_of_sequence(subpos)) {
// if alternate is nonempty, add to stack
- new_stack.push_back(subpos);
+ next_stack.push_back(subpos);
}
- llama_grammar_advance_stack(rules, new_stack, new_stacks);
+ todo.push_back(std::move(next_stack));
while (!llama_grammar_is_end_of_sequence(subpos)) {
// scan to end of alternate def
subpos++;
case LLAMA_GRETYPE_CHAR_ANY:
case LLAMA_GRETYPE_TOKEN:
case LLAMA_GRETYPE_TOKEN_NOT:
- if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
+ if (std::find(new_stacks.begin(), new_stacks.end(), curr_stack) == new_stacks.end()) {
// only add the stack if it's not a duplicate of one we already have
- new_stacks.emplace_back(stack);
+ new_stacks.emplace_back(std::move(curr_stack));
}
break;
default:
// (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
// those
GGML_ABORT("fatal error");
+ }
}
}
std::vector<std::vector<llama_grammar_element>> expected_stacks = {
{
- {LLAMA_GRETYPE_RULE_REF, 5},
{LLAMA_GRETYPE_CHAR, 61},
{LLAMA_GRETYPE_RULE_REF, 7},
- {LLAMA_GRETYPE_CHAR, 97},
+ {LLAMA_GRETYPE_CHAR, 40},
},
{
- {LLAMA_GRETYPE_RULE_REF, 5},
{LLAMA_GRETYPE_CHAR, 61},
{LLAMA_GRETYPE_RULE_REF, 7},
{LLAMA_GRETYPE_RULE_REF, 3},
{LLAMA_GRETYPE_CHAR, 48},
},
{
- {LLAMA_GRETYPE_RULE_REF, 5},
{LLAMA_GRETYPE_CHAR, 61},
{LLAMA_GRETYPE_RULE_REF, 7},
{LLAMA_GRETYPE_RULE_REF, 3},
{LLAMA_GRETYPE_CHAR, 48},
},
{
- {LLAMA_GRETYPE_RULE_REF, 5},
{LLAMA_GRETYPE_CHAR, 61},
{LLAMA_GRETYPE_RULE_REF, 7},
- {LLAMA_GRETYPE_CHAR, 40},
+ {LLAMA_GRETYPE_CHAR, 97},
},
{
+ {LLAMA_GRETYPE_RULE_REF, 5},
{LLAMA_GRETYPE_CHAR, 61},
{LLAMA_GRETYPE_RULE_REF, 7},
- {LLAMA_GRETYPE_CHAR, 97},
+ {LLAMA_GRETYPE_CHAR, 40},
},
{
+ {LLAMA_GRETYPE_RULE_REF, 5},
{LLAMA_GRETYPE_CHAR, 61},
{LLAMA_GRETYPE_RULE_REF, 7},
{LLAMA_GRETYPE_RULE_REF, 3},
{LLAMA_GRETYPE_CHAR, 48},
},
{
+ {LLAMA_GRETYPE_RULE_REF, 5},
{LLAMA_GRETYPE_CHAR, 61},
{LLAMA_GRETYPE_RULE_REF, 7},
{LLAMA_GRETYPE_RULE_REF, 3},
{LLAMA_GRETYPE_CHAR, 48},
},
{
+ {LLAMA_GRETYPE_RULE_REF, 5},
{LLAMA_GRETYPE_CHAR, 61},
{LLAMA_GRETYPE_RULE_REF, 7},
- {LLAMA_GRETYPE_CHAR, 40},
+ {LLAMA_GRETYPE_CHAR, 97},
}};
auto index = 0;
}
std::vector<llama_grammar_candidate> next_candidates;
- next_candidates.resize(24);
+ next_candidates.resize(23);
- for (size_t i = 0; i < 24; ++i)
+ for (size_t i = 0; i < 23; ++i)
{
uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point
cp[0] = 37 + i;
{0, 37},
{1, 38},
{2, 39},
- {3, 40},
{4, 41},
{5, 42},
{6, 43},
{0, 37},
{1, 38},
{2, 39},
+ {3, 40},
{4, 41},
{5, 42},
{6, 43},
{20, 57},
{21, 58},
{22, 59},
- {23, 60},
},
{
{0, 37},
{1, 38},
{2, 39},
- {3, 40},
{4, 41},
{5, 42},
{6, 43},
{0, 37},
{1, 38},
{2, 39},
+ {3, 40},
{4, 41},
{5, 42},
{6, 43},
{20, 57},
{21, 58},
{22, 59},
- {23, 60},
},
};