]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
grammar : fix JSON Schema for string regex with top-level alt. (#9903)
authorJoe Eli McIlvain <redacted>
Wed, 16 Oct 2024 16:03:24 +0000 (09:03 -0700)
committerGitHub <redacted>
Wed, 16 Oct 2024 16:03:24 +0000 (19:03 +0300)
Prior to this commit, using a JSON Schema containing a string
with `pattern` regular expression that uses top-level alternation
(e.g. `"pattern": "^A|B|C|D$"`) would result in invalid JSON
output from the constrained sampling grammar, because it
ended up creating a grammar rule like this for the string:

```
thing ::= "\"" "A" | "B" | "C" | "D" "\"" space
```

Note that this rule will only match a starting quote for the "A" case,
and will only match an ending quote for the "D" case,
so this rule will always produce invalid JSON when used for sampling
(that is, the JSON will always be lacking the starting quote,
the ending quote, or both).

This was fixed in a simple way by adding parentheses to the
generated rule (for all string pattern rules, to keep it simple),
such that the new generated rule looks like this (correct):

```
thing ::= "\"" ("A" | "B" | "C" | "D") "\"" space
```

common/json-schema-to-grammar.cpp
examples/json_schema_to_grammar.py
examples/server/public/json-schema-to-grammar.mjs
tests/test-json-schema-to-grammar.cpp

index 881eb49e3389e1bddccb0766da6c00f7575a1a85..dadc18c8b352f678326c8c8da6a506dfdcca05a4 100644 (file)
@@ -611,7 +611,7 @@ private:
             }
             return join_seq();
         };
-        return _add_rule(name, "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space");
+        return _add_rule(name, "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space");
     }
 
     /*
index a8779bf3bca38ee211dd5ef21f6327db91db78ae..fc9f0097f5f8fea03ca7fb83973825105258705b 100755 (executable)
@@ -540,7 +540,7 @@ class SchemaConverter:
         return self._add_rule(
             name,
             to_rule(transform()) if self._raw_pattern \
-                else "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space")
+                else "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space")
 
 
     def _resolve_ref(self, ref):
index 7267f3f9c7fada1bb0ff1a5a7d44ad65a695a727..e67bb15c176333fc9ebb6af48eab73e4986599bc 100644 (file)
@@ -529,7 +529,7 @@ export class SchemaConverter {
       return joinSeq();
     };
 
-    return this._addRule(name, "\"\\\"\" " + toRule(transform()) + " \"\\\"\" space")
+    return this._addRule(name, "\"\\\"\" (" + toRule(transform()) + ") \"\\\"\" space")
   }
 
   _notStrings(strings) {
index 3a89598a82edb04777c4a810d9fe962a66727b05..9d2db91f52c35de2844174c031e18393ff3811d5 100755 (executable)
@@ -696,7 +696,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
             "pattern": "^abc?d*efg+(hij)?kl$"
         })""",
         R"""(
-            root ::= "\"" "ab" "c"? "d"* "ef" "g"+ ("hij")? "kl" "\"" space
+            root ::= "\"" ("ab" "c"? "d"* "ef" "g"+ ("hij")? "kl") "\"" space
             space ::= | " " | "\n" [ \t]{0,20}
         )"""
     });
@@ -709,7 +709,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
             "pattern": "^\\[\\]\\{\\}\\(\\)\\|\\+\\*\\?$"
         })""",
         R"""(
-            root ::= "\"" "[]{}()|+*?" "\"" space
+            root ::= "\"" ("[]{}()|+*?") "\"" space
             space ::= | " " | "\n" [ \t]{0,20}
         )"""
     });
@@ -722,7 +722,20 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
             "pattern": "^\"$"
         })""",
         R"""(
-            root ::= "\"" "\"" "\"" space
+            root ::= "\"" ("\"") "\"" space
+            space ::= | " " | "\n" [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "regexp with top-level alternation",
+        R"""({
+            "type": "string",
+            "pattern": "^A|B|C|D$"
+        })""",
+        R"""(
+            root ::= "\"" ("A" | "B" | "C" | "D") "\"" space
             space ::= | " " | "\n" [ \t]{0,20}
         )"""
     });
@@ -736,7 +749,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         })""",
         R"""(
             dot ::= [^\x0A\x0D]
-            root ::= "\"" ("(" root-1{1,3} ")")? root-1{3,3} "-" root-1{4,4} " " "a"{3,5} "nd" dot dot dot "\"" space
+            root ::= "\"" (("(" root-1{1,3} ")")? root-1{3,3} "-" root-1{4,4} " " "a"{3,5} "nd" dot dot dot) "\"" space
             root-1 ::= [0-9]
             space ::= | " " | "\n" [ \t]{0,20}
         )"""