]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
`common`: utils to split / join / repeat strings (from json converter) (#11342)
authorOlivier Chafik <redacted>
Wed, 22 Jan 2025 09:51:44 +0000 (09:51 +0000)
committerGitHub <redacted>
Wed, 22 Jan 2025 09:51:44 +0000 (09:51 +0000)
* Factor string_join, string_split, string_repeat into common

* json: refactor to surface a versatile builder

* Update common.cpp

common/common.cpp
common/common.h
common/json-schema-to-grammar.cpp
common/json-schema-to-grammar.h

index 727ab0a109ec8f95798f59c1763d749c34676de3..6dea8e3d25238f0e6d2564ff49efc2a73ba11b80 100644 (file)
@@ -484,6 +484,48 @@ void string_replace_all(std::string & s, const std::string & search, const std::
     s = std::move(builder);
 }
 
+std::string string_join(const std::vector<std::string> & values, const std::string & separator) {
+    std::ostringstream result;
+    for (size_t i = 0; i < values.size(); ++i) {
+        if (i > 0) {
+            result << separator;
+        }
+        result << values[i];
+    }
+    return result.str();
+}
+
+std::vector<std::string> string_split(const std::string & str, const std::string & delimiter) {
+    std::vector<std::string> parts;
+    size_t start = 0;
+    size_t end = str.find(delimiter);
+
+    while (end != std::string::npos) {
+        parts.push_back(str.substr(start, end - start));
+        start = end + delimiter.length();
+        end = str.find(delimiter, start);
+    }
+
+    parts.push_back(str.substr(start));
+
+    return parts;
+}
+
+std::string string_repeat(const std::string & str, size_t n) {
+    if (n == 0) {
+        return "";
+    }
+
+    std::string result;
+    result.reserve(str.length() * n);
+
+    for (size_t i = 0; i < n; ++i) {
+        result += str;
+    }
+
+    return result;
+}
+
 std::string string_from(bool value) {
     return value ? "true" : "false";
 }
index 7c9d73ce1e49e85b32f1d1932b65723ca552a214..571260372090f7882b6fbc017c58a1059f0d8615 100644 (file)
@@ -429,6 +429,10 @@ std::string string_format(const char * fmt, ...);
 std::string string_strip(const std::string & str);
 std::string string_get_sortable_timestamp();
 
+std::string string_join(const std::vector<std::string> & values, const std::string & separator);
+std::vector<std::string> string_split(const std::string & str, const std::string & delimiter);
+std::string string_repeat(const std::string & str, size_t n);
+
 void string_replace_all(std::string & s, const std::string & search, const std::string & replace);
 
 template<class T>
index dadc18c8b352f678326c8c8da6a506dfdcca05a4..4d426b6bd1e7d19d758705ca44c52b78a6cd8ea7 100644 (file)
@@ -1,4 +1,6 @@
 #include "json-schema-to-grammar.h"
+#include "common.h"
+
 #include <algorithm>
 #include <fstream>
 #include <map>
 
 using json = nlohmann::ordered_json;
 
-template <typename Iterator>
-static std::string join(Iterator begin, Iterator end, const std::string & separator);
-
-static std::string repeat(const std::string & str, size_t n);
-
 static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") {
     auto has_max = max_items != std::numeric_limits<int>::max();
 
@@ -128,8 +125,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
                 if (sub_len > 0) {
                     auto from_sub = from.substr(i + 1);
                     auto to_sub = to.substr(i + 1);
-                    auto sub_zeros = repeat("0", sub_len);
-                    auto sub_nines = repeat("9", sub_len);
+                    auto sub_zeros = string_repeat("0", sub_len);
+                    auto sub_nines = string_repeat("9", sub_len);
 
                     auto to_reached = false;
                     out << "(";
@@ -188,8 +185,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
         auto max_digits = max_s.length();
 
         for (auto digits = min_digits; digits < max_digits; digits++) {
-            uniform_range(min_s, repeat("9", digits));
-            min_s = "1" + repeat("0", digits);
+            uniform_range(min_s, string_repeat("9", digits));
+            min_s = "1" + string_repeat("0", digits);
             out << " | ";
         }
         uniform_range(min_s, max_s);
@@ -318,49 +315,6 @@ std::unordered_map<char, std::string> GRAMMAR_LITERAL_ESCAPES = {
 std::unordered_set<char> NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'};
 std::unordered_set<char> ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'^', '$', '.', '[', ']', '(', ')', '|', '{', '}', '*', '+', '?'};
 
-template <typename Iterator>
-std::string join(Iterator begin, Iterator end, const std::string & separator) {
-    std::ostringstream result;
-    if (begin != end) {
-        result << *begin;
-        for (Iterator it = begin + 1; it != end; ++it) {
-            result << separator << *it;
-        }
-    }
-    return result.str();
-}
-
-static std::vector<std::string> split(const std::string & str, const std::string & delimiter) {
-    std::vector<std::string> tokens;
-    size_t start = 0;
-    size_t end = str.find(delimiter);
-
-    while (end != std::string::npos) {
-        tokens.push_back(str.substr(start, end - start));
-        start = end + delimiter.length();
-        end = str.find(delimiter, start);
-    }
-
-    tokens.push_back(str.substr(start));
-
-    return tokens;
-}
-
-static std::string repeat(const std::string & str, size_t n) {
-    if (n == 0) {
-        return "";
-    }
-
-    std::string result;
-    result.reserve(str.length() * n);
-
-    for (size_t i = 0; i < n; ++i) {
-        result += str;
-    }
-
-    return result;
-}
-
 static std::string replacePattern(const std::string & input, const std::regex & regex, const std::function<std::string(const std::smatch  &)> & replacement) {
     std::smatch match;
     std::string result;
@@ -389,6 +343,7 @@ static std::string format_literal(const std::string & literal) {
 
 class SchemaConverter {
 private:
+    friend std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb);
     std::function<json(const std::string &)> _fetch_json;
     bool _dotall;
     std::map<std::string, std::string> _rules;
@@ -418,7 +373,7 @@ private:
         for (size_t i = 0; i < alt_schemas.size(); i++) {
             rules.push_back(visit(alt_schemas[i], name + (name.empty() ? "alternative-" : "-") + std::to_string(i)));
         }
-        return join(rules.begin(), rules.end(), " | ");
+        return string_join(rules, " | ");
     }
 
     std::string _visit_pattern(const std::string & pattern, const std::string & name) {
@@ -481,7 +436,7 @@ private:
                 for (const auto & item : ret) {
                     results.push_back(to_rule(item));
                 }
-                return std::make_pair(join(results.begin(), results.end(), " "), false);
+                return std::make_pair(string_join(results, " "), false);
             };
 
             while (i < length) {
@@ -539,7 +494,7 @@ private:
                     }
                     curly_brackets += '}';
                     i++;
-                    auto nums = split(curly_brackets.substr(1, curly_brackets.length() - 2), ",");
+                    auto nums = string_split(curly_brackets.substr(1, curly_brackets.length() - 2), ",");
                     int min_times = 0;
                     int max_times = std::numeric_limits<int>::max();
                     try {
@@ -854,7 +809,7 @@ public:
                             return;
                         }
                         std::string pointer = ref.substr(ref.find('#') + 1);
-                        std::vector<std::string> tokens = split(pointer, "/");
+                        std::vector<std::string> tokens = string_split(pointer, "/");
                         for (size_t i = 1; i < tokens.size(); ++i) {
                             std::string sel = tokens[i];
                             if (target.is_null() || !target.contains(sel)) {
@@ -905,7 +860,7 @@ public:
             for (const auto & v : schema["enum"]) {
                 enum_values.push_back(_generate_constant_rule(v));
             }
-            return _add_rule(rule_name, "(" + join(enum_values.begin(), enum_values.end(), " | ") + ") space");
+            return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ") space");
         } else if ((schema_type.is_null() || schema_type == "object")
                 && (schema.contains("properties") ||
                     (schema.contains("additionalProperties") && schema["additionalProperties"] != true))) {
@@ -1019,10 +974,10 @@ public:
 
     void check_errors() {
         if (!_errors.empty()) {
-            throw std::runtime_error("JSON schema conversion failed:\n" + join(_errors.begin(), _errors.end(), "\n"));
+            throw std::runtime_error("JSON schema conversion failed:\n" + string_join(_errors, "\n"));
         }
         if (!_warnings.empty()) {
-            fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", join(_warnings.begin(), _warnings.end(), "; ").c_str());
+            fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", string_join(_warnings, "; ").c_str());
         }
     }
 
@@ -1036,10 +991,27 @@ public:
 };
 
 std::string json_schema_to_grammar(const json & schema) {
-    SchemaConverter converter([](const std::string &) { return json::object(); }, /* dotall= */ false);
-    auto copy = schema;
-    converter.resolve_refs(copy, "input");
-    converter.visit(copy, "");
+    return build_grammar([&](const llama_grammar_builder & callbacks) {
+        auto copy = schema;
+        callbacks.resolve_refs(copy);
+        callbacks.add_schema("", copy);
+    });
+}
+
+std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb) {
+    SchemaConverter converter([&](const std::string &) { return json(); }, /* dotall= */ false);
+    llama_grammar_builder builder {
+        /* .add_rule = */ [&](const std::string & name, const std::string & rule) {
+            return converter._add_rule(name, rule);
+        },
+        /* .add_schema = */ [&](const std::string & name, const nlohmann::ordered_json & schema) {
+            return converter.visit(schema, name == "root" ? "" : name);
+        },
+        /* .resolve_refs = */ [&](nlohmann::ordered_json & schema) {
+            converter.resolve_refs(schema, "");
+        }
+    };
+    cb(builder);
     converter.check_errors();
     return converter.format_grammar();
 }
index 41623b34645287ab14a24bcb729c79141129cf77..4f43ab3a52360963082b0066c7da08f3380b05ac 100644 (file)
@@ -5,4 +5,12 @@
 #define JSON_ASSERT GGML_ASSERT
 #include "json.hpp"
 
-std::string json_schema_to_grammar(const nlohmann::ordered_json& schema);
+std::string json_schema_to_grammar(const nlohmann::ordered_json & schema);
+
+struct llama_grammar_builder {
+    std::function<std::string(const std::string &, const std::string &)> add_rule;
+    std::function<std::string(const std::string &, const nlohmann::ordered_json &)> add_schema;
+    std::function<void(nlohmann::ordered_json &)> resolve_refs;
+};
+
+std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb);