]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
grammar: Fix grammar root symbol check (#19761)
authorAsbjørn Olling <redacted>
Thu, 12 Mar 2026 11:04:56 +0000 (12:04 +0100)
committerGitHub <redacted>
Thu, 12 Mar 2026 11:04:56 +0000 (12:04 +0100)
* grammar: fix bad check for root symbol, correct error logging

* add tests to demonstrate root symbol check failure

src/llama-grammar.cpp
tests/test-grammar-integration.cpp

index 3b7a625234e77cc1a8f388d71e7fae5dfd1b3f37..aac0d41f2b41a5c5e25530d485f2ebe7c80d3165 100644 (file)
@@ -1160,13 +1160,13 @@ struct llama_grammar * llama_grammar_init_impl(
     // if there is a grammar, parse it
     // rules will be empty (default) if there are parse errors
     if (!parser.parse(grammar_str) || parser.rules.empty()) {
-        fprintf(stderr, "%s: failed to parse grammar\n", __func__);
+        LLAMA_LOG_ERROR("failed to parse grammar\n");
         return nullptr;
     }
 
-    // Ensure that there is a "root" node.
-    if (parser.symbol_ids.find("root") == parser.symbol_ids.end()) {
-        fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__);
+    // Ensure that the grammar contains the start symbol
+    if (parser.symbol_ids.find(grammar_root) == parser.symbol_ids.end()) {
+        LLAMA_LOG_ERROR("grammar does not contain a '%s' symbol\n", grammar_root);
         return nullptr;
     }
 
@@ -1195,7 +1195,7 @@ struct llama_grammar * llama_grammar_init_impl(
             continue;
         }
         if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
-            LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i);
+            LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu\n", i);
             return nullptr;
         }
     }
index 7aa7e58a5c63cfb2f38f1208d0a0f49b1912dc42..526470a224f023cfade47ed074c48155a7b0a8f9 100644 (file)
 
 using json = nlohmann::ordered_json;
 
+static llama_grammar * build_grammar_with_root(const std::string & grammar_str, const char * grammar_root) {
+    return llama_grammar_init_impl(nullptr, grammar_str.c_str(), grammar_root, false, nullptr, 0, nullptr, 0);
+}
+
 static llama_grammar * build_grammar(const std::string & grammar_str) {
-    return llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0);
+    return build_grammar_with_root(grammar_str, "root");
 }
 
 static bool test_build_grammar_fails(const std::string & grammar_str) {
@@ -860,6 +864,36 @@ static void test_failure_left_recursion() {
     fprintf(stderr, "  ✅︎ Passed\n");
 }
 
+static void test_failure_missing_root_symbol() {
+    fprintf(stderr, "⚫ Testing missing root symbol:\n");
+
+    const std::string grammar_str = R"""(
+        root ::= "foobar"
+    )""";
+
+    llama_grammar * failure_result = build_grammar_with_root(grammar_str, "nonexistent");
+    assert(failure_result == nullptr);
+
+    fprintf(stderr, "  ✅︎ Passed\n");
+}
+
+static void test_custom_root_symbol_check() {
+    fprintf(stderr, "⚫ Testing custom root symbol check:\n");
+
+    const std::string custom_root_grammar_str = R"""(
+        foobar ::= "foobar"
+    )""";
+
+    llama_grammar * failure_result = build_grammar_with_root(custom_root_grammar_str, "root");
+    assert(failure_result == nullptr);
+
+    llama_grammar * success_result = build_grammar_with_root(custom_root_grammar_str, "foobar");
+    assert(success_result != nullptr);
+    llama_grammar_free_impl(success_result);
+
+    fprintf(stderr, "  ✅︎ Passed\n");
+}
+
 static void test_json_schema() {
     // Note that this is similar to the regular grammar tests,
     //  but we convert each json schema to a grammar before parsing.
@@ -1433,6 +1467,8 @@ int main() {
     test_failure_missing_root();
     test_failure_missing_reference();
     test_failure_left_recursion();
+    test_failure_missing_root_symbol();
+    test_custom_root_symbol_check();
     test_json_schema();
     fprintf(stdout, "All tests passed.\n");
     return 0;