]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
json : support `enum` values within `allOf` (#15830)
authorAldehir Rojas <redacted>
Mon, 8 Sep 2025 21:14:32 +0000 (16:14 -0500)
committerGitHub <redacted>
Mon, 8 Sep 2025 21:14:32 +0000 (16:14 -0500)
common/json-schema-to-grammar.cpp
examples/json_schema_to_grammar.py
tests/test-json-schema-to-grammar.cpp
tools/server/public_legacy/json-schema-to-grammar.mjs

index 637891f50699c6028df85dfec4dabc224b2e2b0b..182c787544f4e14fb0c969307b9bd41ee399e2b1 100644 (file)
@@ -843,9 +843,10 @@ public:
                 _build_object_rule(
                     properties, required, name,
                     schema.contains("additionalProperties") ? schema["additionalProperties"] : json()));
-        } else if ((schema_type.is_null() || schema_type == "object") && schema.contains("allOf")) {
+        } else if ((schema_type.is_null() || schema_type == "object" || schema_type == "string") && schema.contains("allOf")) {
             std::unordered_set<std::string> required;
             std::vector<std::pair<std::string, json>> properties;
+            std::map<std::string, size_t> enum_values;
             std::string hybrid_name = name;
             std::function<void(const json &, bool)> add_component = [&](const json & comp_schema, bool is_required) {
                 if (comp_schema.contains("$ref")) {
@@ -857,6 +858,14 @@ public:
                             required.insert(prop.key());
                         }
                     }
+                } else if (comp_schema.contains("enum")) {
+                    for (const auto & v : comp_schema["enum"]) {
+                        const auto rule = _generate_constant_rule(v);
+                        if (enum_values.find(rule) == enum_values.end()) {
+                            enum_values[rule] = 0;
+                        }
+                        enum_values[rule] += 1;
+                    }
                 } else {
                   // todo warning
                 }
@@ -870,6 +879,17 @@ public:
                     add_component(t, true);
                 }
             }
+            if (!enum_values.empty()) {
+                std::vector<std::string> enum_intersection;
+                for (const auto & p : enum_values) {
+                    if (p.second == schema["allOf"].size()) {
+                        enum_intersection.push_back(p.first);
+                    }
+                }
+                if (!enum_intersection.empty()) {
+                    return _add_rule(rule_name, "(" + string_join(enum_intersection, " | ") + ") space");
+                }
+            }
             return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json()));
         } else if ((schema_type.is_null() || schema_type == "array") && (schema.contains("items") || schema.contains("prefixItems"))) {
             json items = schema.contains("items") ? schema["items"] : schema["prefixItems"];
index ed379585546c24c68ffed7d15de7f5954487ea5d..2d57549046b882eefa8512f945d110e002eec981 100755 (executable)
@@ -586,9 +586,10 @@ class SchemaConverter:
             properties = list(schema.get('properties', {}).items())
             return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties')))
 
-        elif schema_type in (None, 'object') and 'allOf' in schema:
+        elif schema_type in (None, 'object', 'string') and 'allOf' in schema:
             required = set()
             properties = []
+            enum_sets = []
             hybrid_name = name
             def add_component(comp_schema, is_required):
                 if (ref := comp_schema.get('$ref')) is not None:
@@ -600,6 +601,9 @@ class SchemaConverter:
                         if is_required:
                             required.add(prop_name)
 
+                if 'enum' in comp_schema:
+                    enum_sets.append(set(comp_schema['enum']))
+
             for t in schema['allOf']:
                 if 'anyOf' in t:
                     for tt in t['anyOf']:
@@ -607,6 +611,15 @@ class SchemaConverter:
                 else:
                     add_component(t, is_required=True)
 
+            if enum_sets:
+                enum_intersection = enum_sets[0]
+                for s in enum_sets[1:]:
+                    enum_intersection &= s
+
+                if enum_intersection:
+                    rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in sorted(enum_intersection))) + ') space'
+                    return self._add_rule(rule_name, rule)
+
             return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None))
 
         elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
index 78ee55e246f3d4d34fb36f0f4be94a0008fcc0d6..67df240c6fef3b02a87cb7b066c099dec9cf6411 100755 (executable)
@@ -1209,6 +1209,51 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         )"""
     });
 
+    test({
+        SUCCESS,
+        "allOf with enum schema",
+        R"""({
+            "allOf": [
+                {"$ref": "#/definitions/foo"}
+            ],
+            "definitions": {
+                "foo": {
+                    "type": "string",
+                    "enum": ["a", "b"]
+                }
+            }
+        })""",
+        R"""(
+            root ::= ("\"a\"" | "\"b\"") space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "allOf with multiple enum schemas",
+        R"""({
+            "allOf": [
+                {"$ref": "#/definitions/foo"},
+                {"$ref": "#/definitions/bar"}
+            ],
+            "definitions": {
+                "foo": {
+                    "type": "string",
+                    "enum": ["a", "b", "c"]
+                },
+                "bar": {
+                    "type": "string",
+                    "enum": ["b", "c", "d"]
+                }
+            }
+        })""",
+        R"""(
+            root ::= ("\"b\"" | "\"c\"") space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
     test({
         SUCCESS,
         "conflicting names",
index b12bf2ab0909ac14c304652e796b547f92e167f7..6f0952974496a77507759b1d87fffbbdb00ac4aa 100644 (file)
@@ -631,9 +631,10 @@ export class SchemaConverter {
       const required = new Set(schema.required || []);
       const properties = Object.entries(schema.properties ?? {});
       return this._addRule(ruleName, this._buildObjectRule(properties, required, name, schema.additionalProperties));
-    } else if ((schemaType === undefined || schemaType === 'object') && 'allOf' in schema) {
+    } else if ((schemaType === undefined || schemaType === 'object' || schemaType === 'string') && 'allOf' in schema) {
       const required = new Set();
       const properties = [];
+      const enumSets = [];
       const addComponent = (compSchema, isRequired) => {
         const ref = compSchema.$ref;
         if (ref !== undefined) {
@@ -648,6 +649,10 @@ export class SchemaConverter {
             }
           }
         }
+
+        if ('enum' in compSchema) {
+          enumSets.push(new Set(compSchema.enum || []));
+        }
       };
 
       for (const t of schema.allOf) {
@@ -660,6 +665,14 @@ export class SchemaConverter {
         }
       }
 
+      if (enumSets.length > 0) {
+        const enumIntersection = new Set([...enumSets[0]].filter(v => enumSets.every(s => s.has(v))));
+        if (enumIntersection.size > 0) {
+          const sortedEnums = [...enumIntersection].sort((a, b) => a.localeCompare(b));
+          const rule = '(' + sortedEnums.map(v => this._generateConstantRule(v)).join(' | ') + ') space';
+          return this._addRule(ruleName, rule);
+        }
+      }
       return this._addRule(ruleName, this._buildObjectRule(properties, required, name, null));
     } else if ((schemaType === undefined || schemaType === 'array') && ('items' in schema || 'prefixItems' in schema)) {
       const items = schema.items ?? schema.prefixItems;