--- /dev/null
+#include "regex-partial.h"
+#include "common.h"
+#include <functional>
+#include <optional>
+
+// Stores `pattern` and compiles two matchers from it: `rx`, used by search() for regular
+// matching, and `rx_reversed_partial`, built from regex_to_reversed_partial_regex(pattern)
+// and used by search() to detect a partial match at the end of the input.
+common_regex::common_regex(const std::string & pattern)
+    : pattern(pattern)
+    , rx(pattern)
+    , rx_reversed_partial(regex_to_reversed_partial_regex(pattern))
+{
+}
+
+// Looks for the pattern in `input`, starting at offset `pos`.
+//
+//  - as_match == false: the pattern may occur anywhere in input[pos..) (std::regex_search).
+//  - as_match == true : the whole of input[pos..) must match (std::regex_match).
+//
+// On success the result has type COMMON_REGEX_MATCH_TYPE_FULL and holds one {begin, end}
+// range per regex group, expressed as offsets into `input`.
+// Otherwise, if the end of the input is a non-empty beginning of a potential match, the
+// result has type COMMON_REGEX_MATCH_TYPE_PARTIAL and holds a single range that runs from
+// the start of that partial match to the end of `input`.
+// If neither applies, a value-initialized result is returned.
+//
+// Throws std::runtime_error when `pos` lies past the end of `input`.
+common_regex_match common_regex::search(const std::string & input, size_t pos, bool as_match) const {
+    if (pos > input.size()) {
+        throw std::runtime_error("Position out of bounds");
+    }
+    const auto start = input.begin() + pos;
+
+    std::smatch match;
+    const bool found = as_match
+        ? std::regex_match(start, input.end(), match, rx)
+        : std::regex_search(start, input.end(), match, rx);
+    if (found) {
+        common_regex_match res;
+        res.type = COMMON_REGEX_MATCH_TYPE_FULL;
+        for (size_t i = 0; i < match.size(); ++i) {
+            // match.position() is relative to `start`; shift it to an offset into `input`.
+            auto begin = pos + match.position(i);
+            res.groups.emplace_back(begin, begin + match.length(i));
+        }
+        return res;
+    }
+
+    // No full match: run the reversed partial regex over the reversed input[pos..).
+    // The end of its capture group 1, mapped back to a forward iterator, is the position
+    // in `input` where the partial match begins.
+    std::match_results<std::string::const_reverse_iterator> srmatch;
+    if (std::regex_match(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial)) {
+        // An empty capture means nothing of the pattern was actually seen.
+        if (srmatch[1].length() != 0) {
+            const auto it = srmatch[1].second.base();
+            // In as_match mode the partial match has to begin exactly where matching began.
+            // Comparing with `start` (rather than input.begin()) keeps this true for pos > 0.
+            if (!as_match || it == start) {
+                common_regex_match res;
+                res.type = COMMON_REGEX_MATCH_TYPE_PARTIAL;
+                const size_t begin = std::distance(input.begin(), it);
+                const size_t end = input.size();
+                if (begin == std::string::npos || end == std::string::npos || begin > end) {
+                    throw std::runtime_error("Invalid range");
+                }
+                res.groups.push_back({begin, end});
+                return res;
+            }
+        }
+    }
+    return {};
+}
+
+/*
+ Transforms a regex pattern to a partial match pattern that operates on a reversed input string to find partial final matches of the original pattern.
+
+ Ideally we'd like to use boost::match_partial (https://beta.boost.org/doc/libs/1_59_0/libs/regex/doc/html/boost_regex/partial_matches.html)
+ to see if a string ends with a partial regex match, but it's not in std::regex yet.
+ Instead, we'll convert the regex into a partial match regex operating as a full match on the reverse iterators of the input.
+
+ (In the examples below, the trailing .* stands for the [\s\S]* suffix that is actually emitted.)
+
+ - /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:d)?c)?b)?a).*
+ - /a|b/ -> (a|b).*
+ - /a*?/ -> (a*).* (could match ""; an empty capture is not reported as a partial match)
+ - /a*b/ -> ((?:b)?a*).* (TODO: final repetitions could become possessive, i.e. a*+)
+ - /.*?ab/ -> ((?:(?:b)?a)?.*).* (the reluctant *? is emitted as a plain *; it is not merged into the trailing .*)
+ - /a.*?b/ -> ((?:(?:b)?.*)?a).* (the reluctant *? is emitted as a plain *)
+ - /a(bc)d/ -> ((?:(?:d)?(?:(?:c)?b))?a).*
+ - /a(bc|de)/ -> ((?:(?:(?:c)?b|(?:e)?d))?a).*
+ - /ab{2,4}c/ -> abbb?b?c -> ((?:(?:(?:(?:(?:c)?b?)?b?)?b)?b)?a).*
+
+ The regex will match a reversed string fully, and the end of the first (and only) capturing group will indicate the reversed start of the original partial pattern
+ (i.e. just where the final .* starts in the inverted pattern; all other groups are turned into non-capturing groups, and the reluctant marker of *? is ignored)
+*/
+// Rewrites `pattern` into a regex that is meant to be matched in full against the reversed
+// input. The result has the shape "(<reversed, progressively optional pattern>)[\s\S]*":
+// the single capturing group covers the partially matched pattern, and every group of the
+// original pattern is emitted as a non-capturing group.
+//
+// Throws std::runtime_error for an unmatched '[', '{', '(' or ')', for a quantifier or
+// repetition that has nothing to apply to, and for an invalid repetition range.
+// Repetition bounds are parsed with std::stoi, which throws its own exceptions on
+// non-numeric text.
+std::string regex_to_reversed_partial_regex(const std::string & pattern) {
+    auto it = pattern.begin();
+    const auto end = pattern.end();
+
+    // Parses one alternation level (up to the matching ')' or the end of the pattern) and
+    // returns its reversed-partial form. It recurses for each nested group and shares the
+    // cursor `it` with the enclosing function.
+    std::function<std::string()> process = [&]() {
+        std::vector<std::vector<std::string>> alternatives(1);
+        std::vector<std::string> * sequence = &alternatives.back();
+
+        while (it != end) {
+            if (*it == '[') {
+                // Character class: copied verbatim as a single element.
+                auto start = it;
+                ++it;
+                while (it != end && *it != ']') {
+                    if (*it == '\\' && it + 1 != end) {
+                        it += 2; // skip the escaped character, so "\]" does not close the class
+                    } else {
+                        ++it;
+                    }
+                }
+                if (it == end) {
+                    throw std::runtime_error("Unmatched '[' in pattern");
+                }
+                ++it;
+                sequence->push_back(std::string(start, it));
+            } else if (*it == '*' || *it == '?' || *it == '+') {
+                if (sequence->empty()) {
+                    throw std::runtime_error("Quantifier without preceding element");
+                }
+                sequence->back() += *it;
+                const bool is_star = *it == '*';
+                ++it;
+                // Drop the reluctant marker of "*?". The bounds check matters when the
+                // pattern ends with '*'.
+                if (is_star && it != end && *it == '?') {
+                    ++it;
+                }
+            } else if (*it == '{') {
+                if (sequence->empty()) {
+                    throw std::runtime_error("Repetition without preceding element");
+                }
+                ++it;
+                auto start = it;
+                while (it != end && *it != '}') {
+                    ++it;
+                }
+                if (it == end) {
+                    throw std::runtime_error("Unmatched '{' in pattern");
+                }
+                auto parts = string_split(std::string(start, it), ",");
+                ++it;
+                if (parts.empty() || parts.size() > 2) {
+                    throw std::runtime_error("Invalid repetition range in pattern");
+                }
+
+                auto parseOptInt = [&](const std::string & s, const std::optional<int> & def = std::nullopt) -> std::optional<int> {
+                    if (s.empty()) {
+                        return def;
+                    }
+                    return std::stoi(s);
+                };
+                // "{n}" -> min = max = n; "{n,}" -> no upper bound; "{,m}" -> min = 0.
+                auto min = parseOptInt(parts[0], 0);
+                auto max = parts.size() == 1 ? min : parseOptInt(parts[1]);
+                if (min && max && *max < *min) {
+                    throw std::runtime_error("Invalid repetition range in pattern");
+                }
+                // Brutal but... let's repeat at least min times, then ? for the delta between min & max (or * for unbounded)
+                auto part = sequence->back();
+                sequence->pop_back();
+                for (int i = 0; i < *min; i++) {
+                    sequence->push_back(part);
+                }
+                if (max) {
+                    for (int i = *min; i < *max; i++) {
+                        sequence->push_back(part + "?");
+                    }
+                } else {
+                    sequence->push_back(part + "*");
+                }
+            } else if (*it == '(') {
+                ++it;
+                // An existing "?:" prefix is consumed: every group becomes non-capturing anyway.
+                if (it != end && *it == '?' && (it + 1 != end) && *(it + 1) == ':') {
+                    it += 2;
+                }
+                auto sub = process();
+                // `it == end` means the group was never closed; do not dereference it.
+                if (it == end || *it != ')') {
+                    throw std::runtime_error("Unmatched '(' in pattern");
+                }
+                ++it;
+                auto & part = sequence->emplace_back("(?:");
+                part += sub;
+                part += ")";
+            } else if (*it == ')') {
+                // Left for the caller: either the enclosing group or the top-level check.
+                break;
+            } else if (*it == '|') {
+                ++it;
+                alternatives.emplace_back();
+                sequence = &alternatives.back();
+            } else if (*it == '\\') {
+                ++it;
+                // An escape sequence is kept together as one element. A lone backslash at
+                // the very end of the pattern is dropped.
+                if (it != end) {
+                    sequence->push_back(std::string("\\") + *it);
+                    ++it;
+                }
+            } else {
+                sequence->push_back(std::string(1, *it));
+                ++it;
+            }
+        }
+
+        // /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:d)?c)?b)?a).*
+        // if n(=4) parts, opening n-1(=3) non-capturing groups after the 1 capturing group
+        // We'll do the outermost capturing group and final .* in the enclosing function.
+        std::vector<std::string> res_alts;
+        for (const auto & parts : alternatives) {
+            auto & res = res_alts.emplace_back();
+            // Start at 1 instead of looping to size() - 1: for an empty alternative the
+            // unsigned subtraction would wrap around and the loop would never finish.
+            for (size_t i = 1; i < parts.size(); i++) {
+                res += "(?:";
+            }
+            for (auto part_it = parts.rbegin(); part_it != parts.rend(); ++part_it) {
+                res += *part_it;
+                if (part_it != parts.rend() - 1) {
+                    res += ")?";
+                }
+            }
+        }
+        return string_join(res_alts, "|");
+    };
+    auto res = process();
+    // process() only stops early on a ')', so anything left over is a ')' without a '('.
+    if (it != end) {
+        throw std::runtime_error("Unmatched ')' in pattern");
+    }
+
+    return "(" + res + ")[\\s\\S]*";
+}
--- /dev/null
+// Tests common_regex (esp. its partial final matches support).
+
+#include "common.h"
+#include "regex-partial.h"
+
+#include <sstream>
+#include <iostream>
+#include <optional>
+
+// Test helper: does nothing when `expected != actual` is false. Otherwise it writes both
+// values to std::cerr and throws std::runtime_error("Test failed").
+template <class T> static void assert_equals(const T & expected, const T & actual) {
+    const bool differs = expected != actual;
+    if (!differs) {
+        return;
+    }
+    std::cerr << "Expected: " << expected << std::endl;
+    std::cerr << "  Actual: " << actual << std::endl;
+    std::cerr << std::flush;
+    throw std::runtime_error("Test failed");
+}
+
+// One regex pattern plus the inputs to run through it and the result expected for each.
+struct test_case {
+    // A single input string paired with the match that searching it should produce.
+    struct input_output {
+        std::string        input;
+        common_regex_match output;
+    };
+
+    std::string               pattern;
+    std::vector<input_output> inputs_outputs;
+};
+
+// Returns the enumerator name of `type` as text, or "?" for a value that is not one of
+// the three enumerators handled below.
+static std::string common_regex_match_type_name(common_regex_match_type type) {
+    const char * name = "?";
+    switch (type) {
+        case COMMON_REGEX_MATCH_TYPE_NONE:
+            name = "COMMON_REGEX_MATCH_TYPE_NONE";
+            break;
+        case COMMON_REGEX_MATCH_TYPE_PARTIAL:
+            name = "COMMON_REGEX_MATCH_TYPE_PARTIAL";
+            break;
+        case COMMON_REGEX_MATCH_TYPE_FULL:
+            name = "COMMON_REGEX_MATCH_TYPE_FULL";
+            break;
+    }
+    return name;
+}
+
+// Runs common_regex::search() from offset 0 over a table of patterns and inputs that covers
+// full matches, partial matches at the end of the input, and non-matches.
+// On the first mismatch it prints the expected result, the actual result and the reversed
+// partial pattern, then throws std::runtime_error("Test failed").
+static void test_regex() {
+    printf("[%s]\n", __func__);
+
+    // Formats a match for the diagnostics below. The sanity check looks at the match being
+    // printed, so a "no match" expectation cannot abort the report before "Got:" is shown.
+    auto match_to_str = [](const common_regex_match & m) {
+        std::ostringstream ss;
+        if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) {
+            ss << "<no match>";
+        } else {
+            GGML_ASSERT(!m.groups.empty());
+            std::vector<std::string> parts;
+            for (const auto & g : m.groups) {
+                parts.push_back("{" + std::to_string(g.begin) + ", " + std::to_string(g.end) + "}");
+            }
+            ss << "{" << common_regex_match_type_name(m.type) << ", {" << string_join(parts, ", ") << "}}";
+        }
+        return ss.str();
+    };
+
+    auto test = [&](const test_case & tc) {
+        common_regex cr(tc.pattern);
+        std::cout << "Testing pattern: /" << tc.pattern << "/\n";
+        for (const auto & input_output : tc.inputs_outputs) {
+            std::cout << " Input: " << input_output.input << '\n';
+            auto m = cr.search(input_output.input, 0);
+            if (m != input_output.output) {
+                std::cout << " Expected: " << match_to_str(input_output.output) << '\n';
+                std::cout << " Got: " << match_to_str(m) << '\n';
+                std::cout << " Inverted pattern: /" << regex_to_reversed_partial_regex(tc.pattern) << "/\n";
+
+                throw std::runtime_error("Test failed");
+            }
+        }
+    };
+    test({
+        "a",
+        {
+            {"a", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 1}}}},
+            {"b", {COMMON_REGEX_MATCH_TYPE_NONE, {}}},
+            {"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 1}}}},
+            {"ba", {COMMON_REGEX_MATCH_TYPE_FULL, {{1, 2}}}},
+        }
+    });
+    test({
+        "abcd",
+        {
+            {"abcd", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
+            {"abcde", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
+            {"abc", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
+            {"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
+            {"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
+            {"d", {}},
+            {"bcd", {}},
+            {"cde", {}},
+            {"cd", {}},
+            {"yeah ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{5, 7}}}},
+            {"abbie", {}},
+            {"", {}},
+        }
+    });
+    test({
+        ".*?ab",
+        {
+            {"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
+            {"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
+            {"dab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
+            {"dabc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
+            {"da", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
+            {"d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
+        }
+    });
+    test({
+        "a.*?b",
+        {
+            {"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
+            {"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
+            {"a b", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
+            {"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
+            {"argh", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
+            {"d", {}},
+            {"b", {}},
+        }
+    });
+    test({
+        "ab(?:cd){2,4}ef",
+        {
+            {"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
+            {"abcd", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
+            {"abcde", {}},
+            {"abcdef", {}},
+            {"abcdcd", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
+            {"abcdcde", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 7}}}},
+            {"abcdcdef", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}}}},
+            {"abcdcdcdcdef", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 12}}}},
+            {"abcdcdcdcdcdef", {}},
+            {"abcde", {}},
+            {"yea", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{2, 3}}}},
+        }
+    });
+    test({
+        "a(?:rte| pure )fact",
+        {
+            {"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
+            {"art", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
+            {"artefa", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
+            {"fact", {}},
+            {"an arte", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{3, 7}}}},
+            {"artefact", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}}}},
+            {"an artefact", {COMMON_REGEX_MATCH_TYPE_FULL, {{3, 11}}}},
+            {"a pure", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
+            {"a pure fact", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 11}}}},
+            {"it's a pure fact", {COMMON_REGEX_MATCH_TYPE_FULL, {{5, 16}}}},
+            {"" , {}},
+            {"pure", {}},
+            {"pure fact", {}},
+        }
+    });
+    test({
+        "abc",
+        {
+            {" abcc", {COMMON_REGEX_MATCH_TYPE_FULL, {{1, 4}}}},
+            {"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
+            {"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
+            {" ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{1, 3}}}},
+            {"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
+            {"b", {}},
+            {"c", {}},
+            {"", {}},
+        }
+    });
+
+    test({
+        "(?:abc)?\\s*def",
+        {
+            {"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
+            {"abc", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
+            {"abc ", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
+            {"abc d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 5}}}},
+            {"abc de", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
+            {"abc def", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
+            {"abc defg", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
+            {"abc defgh", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
+            {"abcde", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 5}}}},
+            {"abcdefgh", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 6}}}},
+            {" d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
+            {"def", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
+        }
+    });
+
+    test({
+        "a+b",
+        {
+            {"aaab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
+            {"aaa", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
+            {"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
+        }
+    });
+
+    test({
+        "(?:"
+            "(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start)
+            "(" // match 2 (open_tag)
+                "<tool_call>"
+                "|<function_call>"
+                "|<tool>"
+                "|<tools>"
+                "|<response>"
+                "|<json>"
+                "|<xml>"
+                "|<JSON>"
+            ")?"
+            "(\\s*\\{\\s*\"name\"\\s*:)" // match 3 (named tool call)
+        ")"
+        "|<function=([^>]+)>" // match 4 (function name)
+        "|<function name=\"([^\"]+)\">", // match 5 (function name again)
+        {
+            {"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}, {54, 54}, {54, 54}, {0, 8}, {54, 54}, {54, 54}}}},
+            {"<tool_call> {\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 18}}}},
+            {"<tool_call>{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 17}}}},
+            {"Let's call something\n<tool_call>{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{21, 38}}}},
+            {"Ok then<tool_call>{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{7, 24}}}},
+            {"{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
+            {"Ok then{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{7, 13}}}},
+            {"<tool_call> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 20}, {66, 66}, {0, 11}, {11, 20}, {66, 66}, {66, 66}}}},
+            {"<function_call> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 24}, {70, 70}, {0, 15}, {15, 24}, {70, 70}, {70, 70}}}},
+            {"<function name=\"special_function\"> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 34}, {89, 89}, {89, 89}, {89, 89}, {89, 89}, {16, 32}}}},
+            {"<function=all>", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 14}, {14, 14}, {14, 14}, {14, 14}, {10, 13}, {14, 14}}}},
+
+        }
+    });
+}
+
+// Checks the exact text produced by regex_to_reversed_partial_regex() for a list of
+// patterns, in order. The first mismatch makes assert_equals() throw.
+static void test_regex_to_reversed_partial_regex() {
+    printf("[%s]\n", __func__);
+
+    struct expectation {
+        const char * expected;
+        const char * pattern;
+    };
+    const expectation cases[] = {
+        { "((?:(?:c)?b)?a)[\\s\\S]*",                       "abc"        },
+        { "(a+)[\\s\\S]*",                                  "a+"         },
+        { "(a*)[\\s\\S]*",                                  "a*"         },
+        { "(a?)[\\s\\S]*",                                  "a?"         },
+        { "([a-z])[\\s\\S]*",                               "[a-z]"      },
+        { "((?:\\w+)?[a-z])[\\s\\S]*",                      "[a-z]\\w+"  },
+        { "((?:a|b))[\\s\\S]*",                             "(?:a|b)"    },
+        { "((?:(?:(?:d)?c)?b)?a)[\\s\\S]*",                 "abcd"       },
+        { "((?:b)?a*)[\\s\\S]*",                            "a*b"        }, // TODO: ((?:b)?a*+).* ??
+        { "((?:(?:b)?a)?.*)[\\s\\S]*",                      ".*?ab"      },
+        { "((?:(?:b)?.*)?a)[\\s\\S]*",                      "a.*?b"      },
+        { "((?:(?:d)?(?:(?:c)?b))?a)[\\s\\S]*",             "a(bc)d"     },
+        { "((?:(?:(?:c)?b|(?:e)?d))?a)[\\s\\S]*",           "a(bc|de)"   },
+        { "((?:(?:(?:(?:(?:c)?b?)?b?)?b)?b)?a)[\\s\\S]*",   "ab{2,4}c"   },
+    };
+
+    for (const auto & c : cases) {
+        assert_equals<std::string>(c.expected, regex_to_reversed_partial_regex(c.pattern));
+    }
+}
+
+// Entry point: runs the pattern-rewriting checks, then the search checks, and prints a
+// success line once both have returned. Neither call is wrapped in a try block, so an
+// exception thrown by a failing check propagates out of main().
+int main() {
+    test_regex_to_reversed_partial_regex();
+    test_regex();
+
+    std::cout << "All tests passed.\n";
+    return 0;
+}