]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
common/grammar : replace problematic backtracking regex `[\s\S]*` (#18342)
authorAldehir Rojas <redacted>
Sat, 3 Jan 2026 22:02:43 +0000 (16:02 -0600)
committerGitHub <redacted>
Sat, 3 Jan 2026 22:02:43 +0000 (16:02 -0600)
* grammar : add support for std::regex_search() with trigger patterns

* common : update hermes2 pro trigger to search instead of match

* common : use regex_search with anchoring for partial matching

* common : adjust regex partial tests to use new pattern

* grammar : check pattern directly instead of adding a type

* common : adjust existing patterns to match new semantics

common/chat.cpp
common/regex-partial.cpp
common/sampling.cpp
src/llama-grammar.cpp
src/llama-grammar.h
tests/test-regex-partial.cpp

index b98ab21ce1cf21e37e1c09245e1c64bd518c7c45..22e527bab882d06882c55a38457f55453c56c1af 100644 (file)
@@ -2065,7 +2065,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
             // Trigger on tool calls that appear in the commentary channel
             data.grammar_triggers.push_back({
                 COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
-                "<\\|channel\\|>(commentary|analysis) to"
+                "<\\|channel\\|>(?:commentary|analysis) to"
             });
 
             // Trigger tool calls that appear in the role section, either at the
@@ -2398,17 +2398,17 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
                 (inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call));
             // Trigger on some common known "good bad" outputs (only from the start and with a json that's about a specific argument name to avoid false positives)
             data.grammar_triggers.push_back({
-                COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
+                COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
                 // If thinking_forced_open, then we capture the </think> tag in the grammar,
                 // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar)
-                std::string(data.thinking_forced_open ? "[\\s\\S]*?(</think>\\s*)" : "(?:<think>[\\s\\S]*?</think>\\s*)?") + (
+                std::string(data.thinking_forced_open ? "(</think>\\s*)" : "") + (
                     "\\s*("
                     "(?:<tool_call>"
                     "|<function"
                     "|(?:```(?:json|xml)?\n\\s*)?(?:<function_call>|<tools>|<xml><json>|<response>)?"
                     "\\s*\\{\\s*\"name\"\\s*:\\s*\"(?:" + string_join(escaped_names, "|") + ")\""
                     ")"
-                    ")[\\s\\S]*"
+                    ")"
                 ),
             });
             data.preserved_tokens = {
index 4bff6b66336e246283e133cb929960722a6e9cd3..e667a209e9877c978da2fce13b588b013cc5fcb4 100644 (file)
@@ -27,7 +27,7 @@ common_regex_match common_regex::search(const std::string & input, size_t pos, b
         return res;
     }
     std::match_results<std::string::const_reverse_iterator> srmatch;
-    if (std::regex_match(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial)) {
+    if (std::regex_search(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial, std::regex_constants::match_continuous)) {
         auto group = srmatch[1].str();
         if (group.length() != 0) {
             auto it = srmatch[1].second.base();
@@ -55,18 +55,18 @@ common_regex_match common_regex::search(const std::string & input, size_t pos, b
   to see if a string ends with a partial regex match, but it's not in std::regex yet.
   Instead, we'll convert the regex into a partial match regex operating as an anchored search on the reverse iterators of the input.
 
-  - /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:(?:d)?c)?b)?a).*
-  - /a|b/ -> (a|b).*
+  - /abcd/ -> ^(dcba|cba|ba|a) -> ^((?:(?:(?:(?:d)?c)?b)?a)
+  - /a|b/ -> ^(a|b)
   - /a*?/ -> error, could match ""
-  - /a*b/ -> ((?:b)?a*+).* (final repetitions become eager)
-  - /.*?ab/ -> ((?:b)?a).* (merge .*)
-  - /a.*?b/ -> ((?:b)?.*?a).* (keep reluctant matches)
-  - /a(bc)d/ -> ((?:(?:d)?(?:(?:c)?b))?a).*
-  - /a(bc|de)/ -> ((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a).*
-  - /ab{2,4}c/ -> abbb?b?c -> ((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a).*
+  - /a*b/ -> ^((?:b)?a*+) (final repetitions become eager)
+  - /.*?ab/ -> ^((?:b)?a) (omit .*)
+  - /a.*?b/ -> ^((?:b)?.*?a) (keep reluctant matches)
+  - /a(bc)d/ -> ^((?:(?:d)?(?:(?:c)?b))?a)
+  - /a(bc|de)/ -> ^((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a)
+  - /ab{2,4}c/ -> ^cbbb?b?a -> ^((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a)
 
-  The regex will match a reversed string fully, and the end of the first (And only) capturing group will indicate the reversed start of the original partial pattern
-  (i.e. just where the final .* starts in the inverted pattern; all other groups are turned into non-capturing groups, and reluctant quantifiers are ignored)
+  The regex will match the start of the reversed string, and the end of the first (and only) capturing group will indicate the reversed start of the original partial pattern.
+  All other groups are turned into non-capturing groups, and reluctant quantifiers are ignored.
 */
 std::string regex_to_reversed_partial_regex(const std::string & pattern) {
     auto it = pattern.begin();
@@ -177,7 +177,7 @@ std::string regex_to_reversed_partial_regex(const std::string & pattern) {
             }
         }
 
-        // /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:d)?c)?b)?a).*
+        // /abcd/ -> ^(dcba|cba|ba|a) -> ^((?:(?:(?:d)?c)?b)?a)
         // if n(=4) parts, opening n-1(=3) non-capturing groups after the 1 capturing group
         // We'll add the leading ^ anchor and the outermost capturing group in the enclosing function.
         std::vector<std::string> res_alts;
@@ -200,5 +200,5 @@ std::string regex_to_reversed_partial_regex(const std::string & pattern) {
         throw std::runtime_error("Unmatched '(' in pattern");
     }
 
-    return "(" + res + ")[\\s\\S]*";
+    return "^(" + res + ")";
 }
index c66f935c65c53b5b1ed97ec1d6d88afe6d8bf74e..68e36e8744064c0826c111dbfd34c43c8ddb0e75 100644 (file)
@@ -179,24 +179,30 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
 #endif // LLAMA_USE_LLGUIDANCE
     } else {
         std::vector<std::string> trigger_patterns;
-        std::vector<std::string> patterns_anywhere;
         std::vector<llama_token> trigger_tokens;
         for (const auto & trigger : params.grammar_triggers) {
             switch (trigger.type) {
                 case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
                 {
                     const auto & word = trigger.value;
-                    patterns_anywhere.push_back(regex_escape(word));
+                    trigger_patterns.push_back(regex_escape(word));
                     break;
                 }
                 case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
                 {
-                    patterns_anywhere.push_back(trigger.value);
+                    trigger_patterns.push_back(trigger.value);
                     break;
                 }
                 case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
                 {
-                    trigger_patterns.push_back(trigger.value);
+                    const auto & pattern = trigger.value;
+                    std::string anchored = "^$";
+                    if (!pattern.empty()) {
+                        anchored = (pattern.front() != '^' ? "^" : "")
+                            + pattern
+                            + (pattern.back() != '$' ? "$" : "");
+                    }
+                    trigger_patterns.push_back(anchored);
                     break;
                 }
                 case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
@@ -210,10 +216,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
             }
         }
 
-        if (!patterns_anywhere.empty()) {
-            trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
-        }
-
         std::vector<const char *> trigger_patterns_c;
         trigger_patterns_c.reserve(trigger_patterns.size());
         for (const auto & regex : trigger_patterns) {
index 75d5d750c3998a74a86153955961440bf81f190a..64ea2fd00a9ac90cf39722d6b4a247290ce5bb8c 100644 (file)
@@ -369,6 +369,44 @@ static void print_rule(
     fprintf(file, "\n");
 }
 
+//
+// Regex utilities
+//
+
+size_t llama_grammar_trigger_pattern::find(const std::string & input) const { // returns the offset in `input` where the trigger starts, or std::string::npos if the pattern does not match
+    auto find_start_pos = [](const std::smatch & match) {
+        // use the position of the first capturing group that matched non-empty text (the trigger runs from there to the end of the string)
+        size_t start = std::string::npos;
+        for (auto i = 1u; i < match.size(); i++) {
+            if (match.length(i) > 0) {
+                start = match.position(i);
+                break;
+            }
+        }
+        if (start == std::string::npos) { // no group captured any text: fall back to the start of the whole match
+            start = match.position(0);
+        }
+        return start;
+    };
+
+    if (!pattern.empty() && pattern.front() == '^' && pattern.back() == '$') {
+        // pattern text is anchored on both ends (^...$): first try to match against the entire input
+        std::smatch match;
+        if (std::regex_match(input, match, regex)) {
+            return find_start_pos(match);
+        }
+    } // a failed full match does not return here; it falls through to the search below
+
+    // search anywhere in the input (also reached when the full match above failed)
+    std::smatch match;
+    if (std::regex_search(input, match, regex)) {
+        return find_start_pos(match);
+    }
+
+    return std::string::npos;
+}
+
+
 //
 // implementation
 //
@@ -1312,21 +1350,10 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
             grammar.trigger_buffer_positions.push_back(std::make_pair(token, position));
             grammar.trigger_buffer += piece;
 
-            std::smatch match;
             for (const auto & trigger_pattern : grammar.trigger_patterns) {
-                if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) {
+                auto start = trigger_pattern.find(grammar.trigger_buffer);
+                if (start != std::string::npos) {
                     grammar.awaiting_trigger = false;
-                    // get from the first matched capturing group to the end of the string
-                    size_t start = std::string::npos;
-                    for (auto i = 1u; i < match.size(); i++) {
-                        if (match.length(i) > 0) {
-                            start = match.position(i);
-                            break;
-                        }
-                    }
-                    if (start == std::string::npos) {
-                        start = match.position(0);
-                    }
 
                     // replay tokens that overlap with [start, end)
                     for (const auto & [tok, tok_pos] : grammar.trigger_buffer_positions) {
index a4c978ac1154549ba4df8a4590d2ec214c94a044..b5a0e588e903be1f3277026ebfbc81cb966216f2 100644 (file)
@@ -119,6 +119,8 @@ struct llama_grammar_parser {
 struct llama_grammar_trigger_pattern {
     std::string pattern; // pattern source text; find() inspects its first/last character for '^' / '$' anchors
     std::regex  regex;   // regex evaluated by find(); presumably compiled from `pattern` - confirm where instances are constructed
+
+    size_t find(const std::string & input) const; // offset in `input` where the trigger starts, or std::string::npos when there is no match
 };
 
 struct llama_grammar {
index ffad1897860a59cb70c4c9daf80e98291797917f..70af6d75a1532ab51239956b9080855788b0f022 100644 (file)
@@ -232,52 +232,52 @@ static void test_regex_to_reversed_partial_regex() {
     printf("[%s]\n", __func__);
 
     assert_equals<std::string>(
-        "((?:(?:c)?b)?a)[\\s\\S]*",
+        "^((?:(?:c)?b)?a)",
         regex_to_reversed_partial_regex("abc"));
 
     assert_equals<std::string>(
-        "(a+)[\\s\\S]*",
+        "^(a+)",
         regex_to_reversed_partial_regex("a+"));
 
     assert_equals<std::string>(
-        "(a*)[\\s\\S]*",
+        "^(a*)",
         regex_to_reversed_partial_regex("a*"));
 
     assert_equals<std::string>(
-        "(a?)[\\s\\S]*",
+        "^(a?)",
         regex_to_reversed_partial_regex("a?"));
 
     assert_equals<std::string>(
-        "([a-z])[\\s\\S]*",
+        "^([a-z])",
         regex_to_reversed_partial_regex("[a-z]"));
 
     assert_equals<std::string>(
-        "((?:\\w+)?[a-z])[\\s\\S]*",
+        "^((?:\\w+)?[a-z])",
         regex_to_reversed_partial_regex("[a-z]\\w+"));
 
     assert_equals<std::string>(
-        "((?:a|b))[\\s\\S]*",
+        "^((?:a|b))",
         regex_to_reversed_partial_regex("(?:a|b)"));
     assert_equals<std::string>(
-        "((?:(?:(?:d)?c)?b)?a)[\\s\\S]*",
+        "^((?:(?:(?:d)?c)?b)?a)",
         regex_to_reversed_partial_regex("abcd"));
     assert_equals<std::string>(
-        "((?:b)?a*)[\\s\\S]*", // TODO: ((?:b)?a*+).* ??
+        "^((?:b)?a*)", // TODO: ((?:b)?a*+).* ??
         regex_to_reversed_partial_regex("a*b"));
     assert_equals<std::string>(
-        "((?:(?:b)?a)?.*)[\\s\\S]*",
+        "^((?:(?:b)?a)?.*)",
         regex_to_reversed_partial_regex(".*?ab"));
     assert_equals<std::string>(
-        "((?:(?:b)?.*)?a)[\\s\\S]*",
+        "^((?:(?:b)?.*)?a)",
         regex_to_reversed_partial_regex("a.*?b"));
     assert_equals<std::string>(
-        "((?:(?:d)?(?:(?:c)?b))?a)[\\s\\S]*",
+        "^((?:(?:d)?(?:(?:c)?b))?a)",
         regex_to_reversed_partial_regex("a(bc)d"));
     assert_equals<std::string>(
-        "((?:(?:(?:c)?b|(?:e)?d))?a)[\\s\\S]*",
+        "^((?:(?:(?:c)?b|(?:e)?d))?a)",
         regex_to_reversed_partial_regex("a(bc|de)"));
     assert_equals<std::string>(
-        "((?:(?:(?:(?:(?:c)?b?)?b?)?b)?b)?a)[\\s\\S]*",
+        "^((?:(?:(?:(?:(?:c)?b?)?b?)?b)?b)?a)",
         regex_to_reversed_partial_regex("ab{2,4}c"));
 }