]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
`common`: add partial regex support (#12808)
authorOlivier Chafik <redacted>
Wed, 14 May 2025 18:50:57 +0000 (19:50 +0100)
committerGitHub <redacted>
Wed, 14 May 2025 18:50:57 +0000 (19:50 +0100)
* move string_find_partial_stop & string_ends_with to common

* add common_regex (supports partial matches)

Co-authored-by: Georgi Gerganov <redacted>
* Update common/regex-partial.cpp

Co-authored-by: Georgi Gerganov <redacted>
* Update common/regex-partial.cpp

Co-authored-by: Georgi Gerganov <redacted>
* Update common/regex-partial.h

Co-authored-by: Georgi Gerganov <redacted>
* partial regex: add missing iterator end checks

* string utils: use string_views

* direct throw to avoid ggml.h include

* regex-partial: replace missed ggml_asserts

---------

Co-authored-by: ochafik <redacted>
Co-authored-by: Georgi Gerganov <redacted>
common/CMakeLists.txt
common/common.cpp
common/common.h
common/regex-partial.cpp [new file with mode: 0644]
common/regex-partial.h [new file with mode: 0644]
tests/CMakeLists.txt
tests/test-regex-partial.cpp [new file with mode: 0644]
tools/server/server.cpp

index 6b0011e4df84e25e583036e0eaf99de4f56e5f83..db5d4094a17b843c2afcd2d61be71a7b049b7ec4 100644 (file)
@@ -73,6 +73,8 @@ add_library(${TARGET} STATIC
     minja/minja.hpp
     ngram-cache.cpp
     ngram-cache.h
+    regex-partial.cpp
+    regex-partial.h
     sampling.cpp
     sampling.h
     speculative.cpp
index 2b1d8da592ee9bdc4161ada3a061fcc2f6b85900..62e922a99c0921869d0d24acc55ed25bf6fd94a3 100644 (file)
@@ -443,6 +443,25 @@ void string_replace_all(std::string & s, const std::string & search, const std::
     s = std::move(builder);
 }
 
+bool string_ends_with(const std::string_view & str, const std::string_view & suffix) {
+    return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
+}
+size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop) {
+    if (!str.empty() && !stop.empty()) {
+        const char text_last_char = str.back();
+        for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
+            if (stop[char_index] == text_last_char) {
+                const auto current_partial = stop.substr(0, char_index + 1);
+                if (string_ends_with(str, current_partial)) {
+                    return str.size() - char_index - 1;
+                }
+            }
+        }
+    }
+
+    return std::string::npos;
+}
+
 std::string regex_escape(const std::string & s) {
     static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
     return std::regex_replace(s, special_chars, "\\$0");
index dea34267c8de827d611c2914d508dca03678c317..a99a36029a53c2e0cad86be38421076118443fce 100644 (file)
@@ -6,6 +6,7 @@
 
 #include <set>
 #include <string>
+#include <string_view>
 #include <vector>
 #include <sstream>
 
@@ -503,10 +504,9 @@ static bool string_starts_with(const std::string & str,
     return str.rfind(prefix, 0) == 0;
 }
 
-static bool string_ends_with(const std::string & str,
-                               const std::string & suffix) {  // While we wait for C++20's std::string::ends_with...
-    return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
-}
+// While we wait for C++20's std::string::ends_with...
+bool string_ends_with(const std::string_view & str, const std::string_view & suffix);
+size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop);
 
 bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
 void string_process_escapes(std::string & input);
diff --git a/common/regex-partial.cpp b/common/regex-partial.cpp
new file mode 100644 (file)
index 0000000..4bff6b6
--- /dev/null
@@ -0,0 +1,204 @@
+#include "regex-partial.h"
+#include "common.h"
+#include <functional>
+#include <optional>
+
+common_regex::common_regex(const std::string & pattern) :
+    pattern(pattern),
+    rx(pattern),
+    rx_reversed_partial(regex_to_reversed_partial_regex(pattern)) {}
+
+common_regex_match common_regex::search(const std::string & input, size_t pos, bool as_match) const {
+    std::smatch match;
+    if (pos > input.size()) {
+        throw std::runtime_error("Position out of bounds");
+    }
+    auto start = input.begin() + pos;
+    auto found = as_match
+        ? std::regex_match(start, input.end(), match, rx)
+        : std::regex_search(start, input.end(), match, rx);
+    if (found) {
+        common_regex_match res;
+        res.type = COMMON_REGEX_MATCH_TYPE_FULL;
+        for (size_t i = 0; i < match.size(); ++i) {
+            auto begin = pos + match.position(i);
+            res.groups.emplace_back(begin, begin + match.length(i));
+        }
+        return res;
+    }
+    std::match_results<std::string::const_reverse_iterator> srmatch;
+    if (std::regex_match(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial)) {
+        auto group = srmatch[1].str();
+        if (group.length() != 0) {
+            auto it = srmatch[1].second.base();
+            // auto position = static_cast<size_t>(std::distance(input.begin(), it));
+            if ((!as_match) || it == input.begin()) {
+                common_regex_match res;
+                res.type = COMMON_REGEX_MATCH_TYPE_PARTIAL;
+                const size_t begin = std::distance(input.begin(), it);
+                const size_t end = input.size();
+                if (begin == std::string::npos || end == std::string::npos || begin > end) {
+                    throw std::runtime_error("Invalid range");
+                }
+                res.groups.push_back({begin, end});
+                return res;
+            }
+        }
+    }
+    return {};
+}
+
+/*
+  Transforms a regex pattern to a partial match pattern that operates on a reversed input string to find partial final matches of the original pattern.
+
+  Ideally we'd like to use boost::match_partial (https://beta.boost.org/doc/libs/1_59_0/libs/regex/doc/html/boost_regex/partial_matches.html)
+  to see if a string ends with a partial regex match, but but it's not in std::regex yet.
+  Instead, we'll the regex into a partial match regex operating as a full match on the reverse iterators of the input.
+
+  - /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:(?:d)?c)?b)?a).*
+  - /a|b/ -> (a|b).*
+  - /a*?/ -> error, could match ""
+  - /a*b/ -> ((?:b)?a*+).* (final repetitions become eager)
+  - /.*?ab/ -> ((?:b)?a).* (merge .*)
+  - /a.*?b/ -> ((?:b)?.*?a).* (keep reluctant matches)
+  - /a(bc)d/ -> ((?:(?:d)?(?:(?:c)?b))?a).*
+  - /a(bc|de)/ -> ((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a).*
+  - /ab{2,4}c/ -> abbb?b?c -> ((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a).*
+
+  The regex will match a reversed string fully, and the end of the first (And only) capturing group will indicate the reversed start of the original partial pattern
+  (i.e. just where the final .* starts in the inverted pattern; all other groups are turned into non-capturing groups, and reluctant quantifiers are ignored)
+*/
+std::string regex_to_reversed_partial_regex(const std::string & pattern) {
+    auto it = pattern.begin();
+    const auto end = pattern.end();
+
+    std::function<std::string()> process = [&]() {
+        std::vector<std::vector<std::string>> alternatives(1);
+        std::vector<std::string> * sequence = &alternatives.back();
+
+        while (it != end) {
+            if (*it == '[') {
+                auto start = it;
+                ++it;
+                while (it != end) {
+                    if ((*it == '\\') && (++it != end)) {
+                        ++it;
+                    } else if ((it != end) && (*it == ']')) {
+                        break;
+                    } else {
+                        ++it;
+                    }
+                }
+                if (it == end) {
+                    throw std::runtime_error("Unmatched '[' in pattern");
+                }
+                ++it;
+                sequence->push_back(std::string(start, it));
+            } else if (*it == '*' || *it == '?' || *it == '+') {
+                if (sequence->empty()) {
+                    throw std::runtime_error("Quantifier without preceding element");
+                }
+                sequence->back() += *it;
+                auto is_star = *it == '*';
+                ++it;
+                if (is_star) {
+                    if (*it == '?') {
+                        ++it;
+                    }
+                }
+            } else if (*it == '{') {
+                if (sequence->empty()) {
+                    throw std::runtime_error("Repetition without preceding element");
+                }
+                ++it;
+                auto start = it;
+                while (it != end && *it != '}') {
+                    ++it;
+                }
+                if (it == end) {
+                    throw std::runtime_error("Unmatched '{' in pattern");
+                }
+                auto parts = string_split(std::string(start, it), ",");
+                ++it;
+                if (parts.size() > 2) {
+                    throw std::runtime_error("Invalid repetition range in pattern");
+                }
+
+                auto parseOptInt = [&](const std::string & s, const std::optional<int> & def = std::nullopt) -> std::optional<int> {
+                    if (s.empty()) {
+                        return def;
+                    }
+                    return std::stoi(s);
+                };
+                auto min = parseOptInt(parts[0], 0);
+                auto max = parts.size() == 1 ? min : parseOptInt(parts[1]);
+                if (min && max && *max < *min) {
+                    throw std::runtime_error("Invalid repetition range in pattern");
+                }
+                // Brutal but... let's repeat at least min times, then ? for the delta between min & max (or * for unbounded)
+                auto part = sequence->back();
+                sequence->pop_back();
+                for (int i = 0; i < *min; i++) {
+                    sequence->push_back(part);
+                }
+                if (max) {
+                    for (int i = *min; i < *max; i++) {
+                        sequence->push_back(part + "?");
+                    }
+                } else {
+                    sequence->push_back(part + "*");
+                }
+            } else if (*it == '(') {
+                ++it;
+                if (it != end && *it == '?' && (it + 1 != end) && *(it + 1) == ':') {
+                    it += 2;
+                }
+                auto sub = process();
+                if (*it != ')') {
+                    throw std::runtime_error("Unmatched '(' in pattern");
+                }
+                ++it;
+                auto & part = sequence->emplace_back("(?:");
+                part += sub;
+                part += ")";
+            } else if (*it == ')') {
+                break;
+            } else if (*it == '|') {
+                ++it;
+                alternatives.emplace_back();
+                sequence = &alternatives.back();
+            } else if (*it == '\\' && (++it != end)) {
+                auto str = std::string("\\") + *it;
+                sequence->push_back(str);
+                ++it;
+            } else if (it != end) {
+                sequence->push_back(std::string(1, *it));
+                ++it;
+            }
+        }
+
+        // /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:d)?c)?b)?a).*
+        // if n(=4) parts, opening n-1(=3) non-capturing groups after the 1 capturing group
+        // We'll do the outermost capturing group and final .* in the enclosing function.
+        std::vector<std::string> res_alts;
+        for (const auto & parts : alternatives) {
+            auto & res = res_alts.emplace_back();
+            for (size_t i = 0; i < parts.size() - 1; i++) {
+                res += "(?:";
+            }
+            for (auto it = parts.rbegin(); it != parts.rend(); ++it) {
+                res += *it;
+                if (it != parts.rend() - 1) {
+                    res += ")?";
+                }
+            }
+        }
+        return string_join(res_alts, "|");
+    };
+    auto res = process();
+    if (it != end) {
+        throw std::runtime_error("Unmatched '(' in pattern");
+    }
+
+    return "(" + res + ")[\\s\\S]*";
+}
diff --git a/common/regex-partial.h b/common/regex-partial.h
new file mode 100644 (file)
index 0000000..634cb40
--- /dev/null
@@ -0,0 +1,56 @@
+#pragma once
+
+#include <regex>
+#include <string>
+
+enum common_regex_match_type {
+    COMMON_REGEX_MATCH_TYPE_NONE,
+    COMMON_REGEX_MATCH_TYPE_PARTIAL,
+    COMMON_REGEX_MATCH_TYPE_FULL,
+};
+
+struct common_string_range {
+    size_t begin;
+    size_t end;
+    common_string_range(size_t begin, size_t end) : begin(begin), end(end) {
+        if (begin > end) {
+            throw std::runtime_error("Invalid range");
+        }
+    }
+    // prevent default ctor
+    common_string_range() = delete;
+    bool empty() const {
+        return begin == end;
+    }
+    bool operator==(const common_string_range & other) const {
+        return begin == other.begin && end == other.end;
+    }
+};
+
+struct common_regex_match {
+    common_regex_match_type type = COMMON_REGEX_MATCH_TYPE_NONE;
+    std::vector<common_string_range> groups;
+
+    bool operator==(const common_regex_match & other) const {
+        return type == other.type && groups == other.groups;
+    }
+    bool operator!=(const common_regex_match & other) const {
+        return !(*this == other);
+    }
+};
+
+class common_regex {
+    std::string pattern;
+    std::regex rx;
+    std::regex rx_reversed_partial;
+
+  public:
+    explicit common_regex(const std::string & pattern);
+
+    common_regex_match search(const std::string & input, size_t pos, bool as_match = false) const;
+
+    const std::string & str() const { return pattern; }
+};
+
+// For testing only (pretty print of failures).
+std::string regex_to_reversed_partial_regex(const std::string & pattern);
index 709d5ad96afba7f7fc08bb980d671c7cb1ef4d2e..083347d188880f0aea3c85592e3f4e808fd49f74 100644 (file)
@@ -144,6 +144,7 @@ endif()
 
 llama_build_and_test(test-log.cpp)
 llama_build_and_test(test-chat-template.cpp)
+llama_build_and_test(test-regex-partial.cpp)
 
 # this fails on windows (github hosted runner) due to curl DLL not found (exit code 0xc0000135)
 if (NOT WIN32)
diff --git a/tests/test-regex-partial.cpp b/tests/test-regex-partial.cpp
new file mode 100644 (file)
index 0000000..ffad189
--- /dev/null
@@ -0,0 +1,288 @@
+//  Tests common_regex (esp. its partial final matches support).
+
+#include "common.h"
+#include "regex-partial.h"
+
+#include <sstream>
+#include <iostream>
+#include <optional>
+
+template <class T> static void assert_equals(const T & expected, const T & actual) {
+    if (expected != actual) {
+        std::cerr << "Expected: " << expected << std::endl;
+        std::cerr << "  Actual: " << actual << std::endl;
+        std::cerr << std::flush;
+        throw std::runtime_error("Test failed");
+    }
+}
+
+struct test_case {
+    std::string pattern;
+    struct input_output {
+        std::string input;
+        common_regex_match output;
+    };
+    std::vector<input_output> inputs_outputs;
+};
+
+static std::string common_regex_match_type_name(common_regex_match_type type) {
+    switch (type) {
+        case COMMON_REGEX_MATCH_TYPE_NONE:
+            return "COMMON_REGEX_MATCH_TYPE_NONE";
+        case COMMON_REGEX_MATCH_TYPE_PARTIAL:
+            return "COMMON_REGEX_MATCH_TYPE_PARTIAL";
+        case COMMON_REGEX_MATCH_TYPE_FULL:
+            return "COMMON_REGEX_MATCH_TYPE_FULL";
+    }
+    return "?";
+}
+
+static void test_regex() {
+    printf("[%s]\n", __func__);
+    auto test = [](const test_case & test_case) {
+        common_regex cr(test_case.pattern);
+        std::cout << "Testing pattern: /" << test_case.pattern << "/\n";
+        // std::cout << "    partial rev: " << cr.reversed_partial_pattern.str() << '\n';
+        for (const auto & input_output : test_case.inputs_outputs) {
+            std::cout << "  Input: " << input_output.input << '\n';
+            auto m = cr.search(input_output.input, 0);
+            if (m != input_output.output) {
+                auto match_to_str = [&](const std::optional<common_regex_match> & m) {
+                    std::ostringstream ss;
+                    if (m->type == COMMON_REGEX_MATCH_TYPE_NONE) {
+                        ss << "<no match>";
+                    } else {
+                        GGML_ASSERT(!input_output.output.groups.empty());
+                        std::vector<std::string> parts;
+                        for (const auto & g : m->groups) {
+                            parts.push_back("{" + std::to_string(g.begin) + ", " + std::to_string(g.end) + "}");
+                        }
+                        ss << "{" << common_regex_match_type_name(m->type) << ", {" << string_join(parts, ", ") << "}}";
+                    }
+                    return ss.str();
+                };
+                std::cout << "    Expected: " << match_to_str(input_output.output) << '\n';
+                std::cout << "         Got: " << match_to_str(m) << '\n';
+                std::cout << " Inverted pattern: /" << regex_to_reversed_partial_regex(test_case.pattern) << "/\n";
+
+                throw std::runtime_error("Test failed");
+            }
+        }
+    };
+    test({
+        "a",
+        {
+            {"a", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 1}}}},
+            {"b", {COMMON_REGEX_MATCH_TYPE_NONE, {}}},
+            {"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 1}}}},
+            {"ba", {COMMON_REGEX_MATCH_TYPE_FULL, {{1, 2}}}},
+        }
+    });
+    test({
+        "abcd",
+        {
+            {"abcd", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
+            {"abcde", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
+            {"abc", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
+            {"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
+            {"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
+            {"d", {}},
+            {"bcd", {}},
+            {"cde", {}},
+            {"cd", {}},
+            {"yeah ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{5, 7}}}},
+            {"abbie", {}},
+            {"", {}},
+        }
+    });
+    test({
+        ".*?ab",
+        {
+            {"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
+            {"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
+            {"dab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
+            {"dabc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
+            {"da", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
+            {"d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
+        }
+    });
+    test({
+        "a.*?b",
+        {
+            {"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
+            {"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
+            {"a b", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
+            {"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
+            {"argh", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
+            {"d", {}},
+            {"b", {}},
+        }
+    });
+    test({
+        "ab(?:cd){2,4}ef",
+        {
+            // {"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, 0, {}}},
+            {"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
+            {"abcd", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
+            {"abcde", {}},
+            {"abcdef", {}},
+            {"abcdcd", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
+            {"abcdcde", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 7}}}},
+            {"abcdcdef", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}}}},
+            {"abcdcdcdcdef", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 12}}}},
+            {"abcdcdcdcdcdef", {}},
+            {"abcde", {}},
+            {"yea", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{2, 3}}}},
+        }
+    });
+    test({
+        "a(?:rte| pure )fact",
+        {
+            {"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
+            {"art", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
+            {"artefa", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
+            {"fact", {}},
+            {"an arte", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{3, 7}}}},
+            {"artefact", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}}}},
+            {"an artefact", {COMMON_REGEX_MATCH_TYPE_FULL, {{3, 11}}}},
+            {"a pure", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
+            {"a pure fact", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 11}}}},
+            {"it's a pure fact", {COMMON_REGEX_MATCH_TYPE_FULL, {{5, 16}}}},
+            {"" , {}},
+            {"pure", {}},
+            {"pure fact", {}},
+        }
+    });
+    test({
+        "abc",
+        {
+            {" abcc", {COMMON_REGEX_MATCH_TYPE_FULL, {{1, 4}}}},
+            {"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
+            {"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
+            {" ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{1, 3}}}},
+            {"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
+            {"b", {}},
+            {"c", {}},
+            {"", {}},
+        }
+    });
+
+    test({
+        "(?:abc)?\\s*def",
+        {
+            {"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
+            {"abc", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
+            {"abc ", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
+            {"abc d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 5}}}},
+            {"abc de", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
+            {"abc def", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
+            {"abc defg", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
+            {"abc defgh", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
+            {"abcde", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 5}}}},
+            {"abcdefgh", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 6}}}},
+            {" d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
+            {"def", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
+        }
+    });
+
+    test({
+        "a+b",
+        {
+            {"aaab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
+            {"aaa", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
+            {"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
+        }
+    });
+
+    test({
+        "(?:"
+            "(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start)
+            "("                          // match 2 (open_tag)
+                "<tool_call>"
+                "|<function_call>"
+                "|<tool>"
+                "|<tools>"
+                "|<response>"
+                "|<json>"
+                "|<xml>"
+                "|<JSON>"
+            ")?"
+            "(\\s*\\{\\s*\"name\"\\s*:)" // match 3 (named tool call)
+        ")"
+        "|<function=([^>]+)>"            // match 4 (function name)
+        "|<function name=\"([^\"]+)\">", // match 5 (function name again)
+        {
+            {"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}, {54, 54}, {54, 54}, {0, 8}, {54, 54}, {54, 54}}}},
+            {"<tool_call> {\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 18}}}},
+            {"<tool_call>{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 17}}}},
+            {"Let's call something\n<tool_call>{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{21, 38}}}},
+            {"Ok then<tool_call>{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{7, 24}}}},
+            {"{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
+            {"Ok then{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{7, 13}}}},
+            {"<tool_call> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 20}, {66, 66}, {0, 11}, {11, 20}, {66, 66}, {66, 66}}}},
+            {"<function_call> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 24}, {70, 70}, {0, 15}, {15, 24}, {70, 70}, {70, 70}}}},
+            {"<function name=\"special_function\"> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 34}, {89, 89}, {89, 89}, {89, 89}, {89, 89}, {16, 32}}}},
+            {"<function=all>", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 14}, {14, 14}, {14, 14}, {14, 14}, {10, 13}, {14, 14}}}},
+
+        }
+    });
+}
+
+static void test_regex_to_reversed_partial_regex() {
+    printf("[%s]\n", __func__);
+
+    assert_equals<std::string>(
+        "((?:(?:c)?b)?a)[\\s\\S]*",
+        regex_to_reversed_partial_regex("abc"));
+
+    assert_equals<std::string>(
+        "(a+)[\\s\\S]*",
+        regex_to_reversed_partial_regex("a+"));
+
+    assert_equals<std::string>(
+        "(a*)[\\s\\S]*",
+        regex_to_reversed_partial_regex("a*"));
+
+    assert_equals<std::string>(
+        "(a?)[\\s\\S]*",
+        regex_to_reversed_partial_regex("a?"));
+
+    assert_equals<std::string>(
+        "([a-z])[\\s\\S]*",
+        regex_to_reversed_partial_regex("[a-z]"));
+
+    assert_equals<std::string>(
+        "((?:\\w+)?[a-z])[\\s\\S]*",
+        regex_to_reversed_partial_regex("[a-z]\\w+"));
+
+    assert_equals<std::string>(
+        "((?:a|b))[\\s\\S]*",
+        regex_to_reversed_partial_regex("(?:a|b)"));
+    assert_equals<std::string>(
+        "((?:(?:(?:d)?c)?b)?a)[\\s\\S]*",
+        regex_to_reversed_partial_regex("abcd"));
+    assert_equals<std::string>(
+        "((?:b)?a*)[\\s\\S]*", // TODO: ((?:b)?a*+).* ??
+        regex_to_reversed_partial_regex("a*b"));
+    assert_equals<std::string>(
+        "((?:(?:b)?a)?.*)[\\s\\S]*",
+        regex_to_reversed_partial_regex(".*?ab"));
+    assert_equals<std::string>(
+        "((?:(?:b)?.*)?a)[\\s\\S]*",
+        regex_to_reversed_partial_regex("a.*?b"));
+    assert_equals<std::string>(
+        "((?:(?:d)?(?:(?:c)?b))?a)[\\s\\S]*",
+        regex_to_reversed_partial_regex("a(bc)d"));
+    assert_equals<std::string>(
+        "((?:(?:(?:c)?b|(?:e)?d))?a)[\\s\\S]*",
+        regex_to_reversed_partial_regex("a(bc|de)"));
+    assert_equals<std::string>(
+        "((?:(?:(?:(?:(?:c)?b?)?b?)?b)?b)?a)[\\s\\S]*",
+        regex_to_reversed_partial_regex("ab{2,4}c"));
+}
+
+int main() {
+    test_regex_to_reversed_partial_regex();
+    test_regex();
+    std::cout << "All tests passed.\n";
+}
index d81579eb1028f60aa719ec4edae00e6b206b3e2f..f32f3c86aad2c992b69f69e07fbe08c41f96e85b 100644 (file)
@@ -1429,7 +1429,7 @@ struct server_slot {
                 pos = text.find(word, from_pos);
             } else {
                 // otherwise, partial stop
-                pos = find_partial_stop_string(word, text);
+                pos = string_find_partial_stop(text, word);
             }
 
             if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) {