return rejects;
}
+// Depth-first search for left recursion reachable from rules[rule_index].
+//
+// A rule is left-recursive when it can reach itself again by following only rule references that sit at
+// the leftmost position of an alternate, or that are preceded solely by references to rules which may
+// produce the empty string.
+//
+// Parameters (all three vectors are indexed by rule index and must have one entry per rule):
+//   rules              - the grammar; each rule is a flat element list in which alternates are delimited
+//                        by elements for which llama_grammar_is_end_of_sequence() returns true
+//   rule_index         - the rule to explore
+//   rules_visited      - in/out: rules that were fully explored without finding left recursion
+//   rules_in_progress  - in/out: rules currently on the recursion stack
+//   rules_may_be_empty - in/out: rules known to be able to produce the empty string
+//
+// Returns true as soon as left recursion is found. In that case the in-progress flags of the rules on
+// the recursion stack are left set, so the vectors should not be reused for further queries.
+static bool llama_grammar_detect_left_recursion(
+        const std::vector<std::vector<llama_grammar_element>> & rules,
+        size_t rule_index,
+        std::vector<bool> * rules_visited,
+        std::vector<bool> * rules_in_progress,
+        std::vector<bool> * rules_may_be_empty) {
+    if ((*rules_in_progress)[rule_index]) {
+        // we came back to a rule that is still being explored, using leftmost positions only
+        return true;
+    }
+    if ((*rules_visited)[rule_index]) {
+        // Already explored completely without finding a cycle, and its may-be-empty flag is final.
+        // Skipping it keeps the search linear in the grammar size instead of re-walking shared rules
+        // once per path that leads to them.
+        return false;
+    }
+
+    (*rules_in_progress)[rule_index] = true;
+
+    const std::vector<llama_grammar_element> & rule = rules[rule_index];
+
+    // `prefix_may_be_empty` is true while everything consumed so far in the current alternate can produce
+    // the empty string. This holds at the start of every alternate and stays true across references to
+    // rules that may be empty. While it holds:
+    //   - a rule reference is in a (possibly next-)leftmost position, so we have to recurse into it
+    //   - reaching the end of the alternate means the whole alternate, and therefore this rule, may be
+    //     empty; this also covers alternates made up only of references to rules that may be empty
+    bool prefix_may_be_empty = true;
+    for (const llama_grammar_element & elem : rule) {
+        if (llama_grammar_is_end_of_sequence(&elem)) {
+            if (prefix_may_be_empty) {
+                (*rules_may_be_empty)[rule_index] = true;
+            }
+            // the next alternate (if any) starts with an empty prefix again
+            prefix_may_be_empty = true;
+        } else if (elem.type == LLAMA_GRETYPE_RULE_REF && prefix_may_be_empty) {
+            const size_t ref_index = static_cast<size_t>(elem.value);
+            if (llama_grammar_detect_left_recursion(rules, ref_index, rules_visited, rules_in_progress, rules_may_be_empty)) {
+                return true;
+            }
+            // the referenced rule is fully explored at this point, so its flag can be relied upon
+            prefix_may_be_empty = (*rules_may_be_empty)[ref_index];
+        } else {
+            // any other element (or a reference behind a non-empty prefix) ends the leftmost region
+            prefix_may_be_empty = false;
+        }
+    }
+
+    (*rules_in_progress)[rule_index] = false;
+    (*rules_visited)[rule_index] = true;
+    return false;
+}
+
//
// grammar - external
//
vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
}
+ // Check for left recursion
+ std::vector<bool> rules_visited(n_rules);
+ std::vector<bool> rules_in_progress(n_rules);
+ std::vector<bool> rules_may_be_empty(n_rules);
+ for (size_t i = 0; i < n_rules; i++) {
+ if (rules_visited[i]) {
+ continue;
+ }
+ if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
+ throw std::runtime_error(format("unsupported grammar, left recursion detected for nonterminal at index %zu", i));
+ }
+ }
+
// loop over alternates of start rule to build initial stacks
std::vector<std::vector<const llama_grammar_element *>> stacks;
pos = vec_rules[start_rule_index].data();
}
} while (true);
+ // Important: vec_rules has to be moved here, not copied, because stacks contains
+ // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
+ // then the pointers would be invalidated when the local vec_rules goes out of scope.
return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
}
return grammar;
}
+// Tries to build a grammar from `grammar_str` and reports whether the build was rejected.
+//
+// Returns true when build_grammar() throws a std::exception, false when it returns normally.
+// Exceptions that do not derive from std::exception are not caught and propagate to the caller.
+// All progress output goes to stderr so that the result mark stays next to its header line.
+//
+// NOTE(review): the value returned by build_grammar() is discarded; if it is an owning pointer it is
+// leaked whenever the build unexpectedly succeeds - confirm against build_grammar().
+static bool test_build_grammar_fails(const std::string & grammar_str) {
+    fprintf(stderr, "⚫ Testing failure for grammar: %s\n", grammar_str.c_str());
+    try {
+        build_grammar(grammar_str);
+    } catch (const std::exception &) {
+        // the exception itself is not inspected: any std::exception counts as the expected rejection
+        fprintf(stderr, " ✅︎\n");
+        return true;
+    }
+    fprintf(stderr, " ❌ Expected build failure, but succeeded\n");
+    return false;
+}
+
static bool match_string(const std::string & input, llama_grammar* grammar) {
auto decoded = decode_utf8(input, {});
fprintf(stderr, " ✅︎ Passed\n");
}
+// Checks that building a grammar is rejected for several left-recursive grammars of increasing
+// complexity: direct recursion, recursion behind the root rule, recursion through a second rule, and
+// recursion that is only leftmost because the preceding rule may be empty.
+static void test_failure_left_recursion() {
+    fprintf(stderr, "⚫ Testing left recursion detection:\n");
+
+    // The build is performed outside of assert() on purpose: an expression inside assert() is not
+    // evaluated at all when NDEBUG is defined, which would silently turn this test into a no-op.
+    const auto expect_rejected = [](const std::string & grammar_str) {
+        const bool rejected = test_build_grammar_fails(grammar_str);
+        assert(rejected);
+        (void) rejected; // avoids an unused-variable warning when assert() is compiled out
+    };
+
+    // Test simple left recursion detection
+    const std::string simple_str = R"""(root ::= "a" | root "a")""";
+    expect_rejected(simple_str);
+
+    // Test more complicated left recursion detection
+    const std::string medium_str = R"""(
+root ::= asdf
+asdf ::= "a" | asdf "a"
+)""";
+    expect_rejected(medium_str);
+
+    // Test even more complicated left recursion detection
+    const std::string hard_str = R"""(
+root ::= asdf
+asdf ::= "a" | foo "b"
+foo ::= "c" | asdf "d" | "e")""";
+    expect_rejected(hard_str);
+
+    // Test yet even more complicated left recursion detection
+    const std::string hardest_str = R"""(
+root ::= asdf
+asdf ::= "a" | foo "b"
+foo ::= "c" | empty asdf "d" | "e"
+empty ::= "blah" | )""";
+    expect_rejected(hardest_str);
+
+    fprintf(stderr, " ✅︎ Passed\n");
+}
+
// Test driver: runs every grammar integration test in a fixed order and returns 0 once all of them
// have returned.
int main() {
fprintf(stdout, "Running grammar integration tests...\n");
// positive tests
test_simple_grammar();
test_quantifiers();
// negative tests
test_failure_missing_root();
test_failure_missing_reference();
+ test_failure_left_recursion();
// printed only after every test function above has returned
fprintf(stdout, "All tests passed.\n");
return 0;
}