]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Add left recursion check: quit early instead of going into an infinite loop (#7083)
authorHaggai Nuchi <redacted>
Tue, 14 May 2024 05:25:56 +0000 (22:25 -0700)
committerGitHub <redacted>
Tue, 14 May 2024 05:25:56 +0000 (15:25 +1000)
* Add left recursion check: quit early instead of going into an infinite loop

* Remove custom enum, rename left recursion check and move to "grammar internal" section, add handling for edge case where a leftmost nonterminal may be empty

* Remove unnecessary declaration

llama.cpp
tests/test-grammar-integration.cpp

index 202bf94c80e144c9c4d3e280197d9b87f00ae6ba..01a35dfb638f4008231d4f94613b4fe9b1a6da97 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -13182,6 +13182,58 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates(
     return rejects;
 }
 
+static bool llama_grammar_detect_left_recursion(
+        const std::vector<std::vector<llama_grammar_element>> & rules,
+        size_t                                                  rule_index,
+        std::vector<bool>                                     * rules_visited,
+        std::vector<bool>                                     * rules_in_progress,
+        std::vector<bool>                                     * rules_may_be_empty) {
+    if ((*rules_in_progress)[rule_index]) {
+        return true;
+    }
+
+    (*rules_in_progress)[rule_index] = true;
+
+    const std::vector<llama_grammar_element> & rule = rules[rule_index];
+
+    // First check if the rule might produce the empty string. This could be done combined with the second
+    // step but it's more readable as two steps.
+    bool at_rule_start = true;
+    for (size_t i = 0; i < rule.size(); i++) {
+        if (llama_grammar_is_end_of_sequence(&rule[i])) {
+            if (at_rule_start) {
+                (*rules_may_be_empty)[rule_index] = true;
+                break;
+            }
+            at_rule_start = true;
+        } else {
+            at_rule_start = false;
+        }
+    }
+
+    // Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may
+    // be empty)
+    bool recurse_into_nonterminal = true;
+    for (size_t i = 0; i < rule.size(); i++) {
+        if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) {
+            if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) {
+                return true;
+            }
+            if (!((*rules_may_be_empty)[(size_t)rule[i].value])) {
+                recurse_into_nonterminal = false;
+            }
+        } else if (llama_grammar_is_end_of_sequence(&rule[i])) {
+            recurse_into_nonterminal = true;
+        } else {
+            recurse_into_nonterminal = false;
+        }
+    }
+
+    (*rules_in_progress)[rule_index] = false;
+    (*rules_visited)[rule_index] = true;
+    return false;
+}
+
 //
 // grammar - external
 //
@@ -13201,6 +13253,19 @@ struct llama_grammar * llama_grammar_init(
         vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
     }
 
+    // Check for left recursion
+    std::vector<bool> rules_visited(n_rules);
+    std::vector<bool> rules_in_progress(n_rules);
+    std::vector<bool> rules_may_be_empty(n_rules);
+    for (size_t i = 0; i < n_rules; i++) {
+        if (rules_visited[i]) {
+            continue;
+        }
+        if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
+            throw std::runtime_error(format("unsupported grammar, left recursion detected for nonterminal at index %zu", i));
+        }
+    }
+
     // loop over alternates of start rule to build initial stacks
     std::vector<std::vector<const llama_grammar_element *>> stacks;
     pos = vec_rules[start_rule_index].data();
@@ -13223,6 +13288,9 @@ struct llama_grammar * llama_grammar_init(
         }
     } while (true);
 
+    // Important: vec_rules has to be moved here, not copied, because stacks contains
+    // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
+    // then the pointers would be invalidated when the local vec_rules goes out of scope.
     return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
 }
 
index 1a4004e2ab1755d9bbfedcd3e6079b03994fa77f..01c5bb27aabb9555f522bb815147ad29edffbc2c 100644 (file)
@@ -28,6 +28,19 @@ static llama_grammar* build_grammar(const std::string & grammar_str) {
     return grammar;
 }
 
+static bool test_build_grammar_fails(const std::string & grammar_str) {
+    fprintf(stderr, "⚫ Testing failure for grammar: %s\n", grammar_str.c_str());
+    bool grammar_fails = false;
+    try {
+        build_grammar(grammar_str);
+        fprintf(stderr, "  ❌ Expected build failure, but succeeded\n");
+    } catch (const std::exception & err) {
+        grammar_fails = true;
+        fprintf(stdout, "  ✅︎\n");
+    }
+    return grammar_fails;
+}
+
 static bool match_string(const std::string & input, llama_grammar* grammar) {
     auto decoded = decode_utf8(input, {});
 
@@ -320,6 +333,38 @@ number ::= [0-9]+)""";
     fprintf(stderr, "  ✅︎ Passed\n");
 }
 
+static void test_failure_left_recursion() {
+    fprintf(stderr, "⚫ Testing left recursion detection:\n");
+
+    // Test simple left recursion detection
+    const std::string simple_str = R"""(root ::= "a" | root "a")""";
+    assert(test_build_grammar_fails(simple_str));
+
+    // Test more complicated left recursion detection
+    const std::string medium_str = R"""(
+root ::= asdf
+asdf ::= "a" | asdf "a"
+)""";
+    assert(test_build_grammar_fails(medium_str));
+
+    // Test even more complicated left recursion detection
+    const std::string hard_str = R"""(
+root ::= asdf
+asdf ::= "a" | foo "b"
+foo ::= "c" | asdf "d" | "e")""";
+    assert(test_build_grammar_fails(hard_str));
+
+    // Test yet even more complicated left recursion detection
+    const std::string hardest_str = R"""(
+root ::= asdf
+asdf ::= "a" | foo "b"
+foo ::= "c" | empty asdf "d" | "e"
+empty ::= "blah" | )""";
+    assert(test_build_grammar_fails(hardest_str));
+
+    fprintf(stderr, "  ✅︎ Passed\n");
+}
+
 int main() {
     fprintf(stdout, "Running grammar integration tests...\n");
     test_simple_grammar();
@@ -327,6 +372,7 @@ int main() {
     test_quantifiers();
     test_failure_missing_root();
     test_failure_missing_reference();
+    test_failure_left_recursion();
     fprintf(stdout, "All tests passed.\n");
     return 0;
 }