]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
`tool-call`: fix Qwen 2.5 Coder support, add micro benchmarks, support trigger patter...
authorOlivier Chafik <redacted>
Wed, 5 Mar 2025 13:05:13 +0000 (13:05 +0000)
committerGitHub <redacted>
Wed, 5 Mar 2025 13:05:13 +0000 (13:05 +0000)
* sampler: turn lazy grammar trigger words to regexes

* add scripts/tool_bench.sh & .py

* constrain llama json output regardless of function name if matches at beginning

* update relaxed newline space rule in grammar tests

* support add_generation_prompt query parameter (useful for /apply_template)

* Update src/llama-grammar.cpp

Co-authored-by: Georgi Gerganov <redacted>
---------

Co-authored-by: Georgi Gerganov <redacted>
26 files changed:
README.md
common/chat.cpp
common/common.cpp
common/common.h
common/json-schema-to-grammar.cpp
common/json-schema-to-grammar.h
common/sampling.cpp
examples/json_schema_to_grammar.py
examples/server/public_legacy/json-schema-to-grammar.mjs
examples/server/server.cpp
examples/server/tests/unit/test_tool_call.py [changed mode: 0644->0755]
examples/server/tests/utils.py
examples/server/utils.hpp
include/llama.h
models/templates/README.md
requirements.txt
requirements/requirements-all.txt
requirements/requirements-tool_bench.txt [new file with mode: 0644]
scripts/fetch_server_test_models.py
scripts/tool_bench.py [new file with mode: 0755]
scripts/tool_bench.sh [new file with mode: 0755]
src/llama-grammar.cpp
src/llama-grammar.h
src/llama-sampling.cpp
tests/test-chat.cpp
tests/test-json-schema-to-grammar.cpp

index ffb59a4229026d7a928daac5b054efdd74bc6a04..d73b0495d8e36f0300e00c6529edae0f34a4b00f 100644 (file)
--- a/README.md
+++ b/README.md
@@ -25,7 +25,7 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)
 
 - **How to use [MTLResidencySet](https://developer.apple.com/documentation/metal/mtlresidencyset?language=objc) to keep the GPU memory active?** https://github.com/ggml-org/llama.cpp/pull/11427
 - **VS Code extension for FIM completions:** https://github.com/ggml-org/llama.vscode
-- Universal tool call support in `llama-server`: https://github.com/ggml-org/llama.cpp/pull/9639
+- Universal [tool call support](./docs/function-calling.md) in `llama-server` https://github.com/ggml-org/llama.cpp/pull/9639
 - Vim/Neovim plugin for FIM completions: https://github.com/ggml-org/llama.vim
 - Introducing GGUF-my-LoRA https://github.com/ggml-org/llama.cpp/discussions/10123
 - Hugging Face Inference Endpoints now support GGUF out of the box! https://github.com/ggml-org/llama.cpp/discussions/9669
index 9ebe4c5784cbcf98014cde575fc9295b6fa08cbc..1b10219ccab04bdd121a401307ed5ca389a324ba 100644 (file)
@@ -449,12 +449,6 @@ std::string common_chat_format_name(common_chat_format format) {
     }
 }
 
-const common_grammar_options grammar_options {
-    /* .dotall = */ false,
-    /* .compact_spaces = */ false,
-    // /* .compact_spaces = */ true,
-};
-
 static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) {
     // // https://json.nlohmann.me/features/parsing/sax_interface/
     struct json_error_locator : public nlohmann::json_sax<json> {
@@ -500,6 +494,34 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons
     }
 }
 
+static bool parse_literal(std::string::const_iterator & it, const std::string::const_iterator & end, const std::string & expected) {
+    auto expected_it = expected.begin();
+    auto tmp_it = it;
+    while (tmp_it != end && expected_it != expected.end() && *tmp_it == *expected_it) {
+        ++tmp_it;
+        ++expected_it;
+    }
+    if (expected_it == expected.end()) {
+        it = tmp_it;
+        return true;
+    }
+    return false;
+}
+
+static std::optional<std::smatch> parse_pattern(std::string::const_iterator & it, const std::string::const_iterator & end, const std::regex & expected) {
+    std::smatch match;
+    if (std::regex_match(it, end, match, expected)) {
+        it = match.suffix().first;
+        return match;
+    }
+    return std::nullopt;
+}
+
+static void consume_spaces(std::string::const_iterator & it, const std::string::const_iterator & end) {
+    while (it != end && std::isspace(*it)) {
+        ++it;
+    }
+}
 
 /**
  * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
@@ -509,7 +531,8 @@ static common_chat_msg parse_json_tool_calls(
     const std::string& input,
     const std::optional<std::regex> & trigger_opt,
     const std::regex & function_regex,
-    const std::regex & close_regex) {
+    const std::regex & close_regex,
+    bool allow_raw_python = false) {
     std::smatch match;
 
     common_chat_msg result;
@@ -540,14 +563,19 @@ static common_chat_msg parse_json_tool_calls(
         it = rit->suffix().first;
 
         json arguments;
-        if (!parse_json(it, end, arguments)) {
+        if (parse_json(it, end, arguments)) {
+            if (!std::regex_search(it, end, match, close_regex)) {
+                throw std::runtime_error("Malformed input, missing closing pattern: " + input);
+            }
+            it = match.suffix().first;
+            result.tool_calls.push_back({name, arguments.is_string() ? arguments.get<std::string>() : arguments.dump(), /* id= */ ""});
+        } else {
+            if (allow_raw_python && name == "python") {
+                result.tool_calls.push_back({name, json({{"code", std::string(it, end)}}).dump(), /* id= */ ""});
+                break;
+            }
             throw std::runtime_error("Failed to parse json tool call arguments: " + input);
         }
-        if (!std::regex_search(it, end, match, close_regex)) {
-            throw std::runtime_error("Malformed input, missing closing pattern: " + input);
-        }
-        it = match.suffix().first;
-        result.tool_calls.push_back({name, arguments.is_string() ? arguments.get<std::string>() : arguments.dump(), /* id= */ ""});
     }
 
     if (!result.tool_calls.empty()) {
@@ -559,29 +587,29 @@ static common_chat_msg parse_json_tool_calls(
     return result;
 }
 
+static common_chat_tool_call process_tool_call(const json & tool_call) {
+    const auto & arguments = tool_call.at("arguments");
+    return {
+        /* .name = */ tool_call.at("name"),
+        /* .arguments = */ arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
+        /* .id = */ tool_call.contains("id") ? tool_call.at("id") : "",
+    };
+}
 static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& input, const std::string & prefix, size_t rstrip_prefix = 0) {
     auto content_end = input.find(prefix);
     size_t tc_start = std::string::npos;
 
     common_chat_msg result;
     result.role = "assistant";
-    const auto process_tool_calls = [&](const json & tool_calls) {
-        for (const auto & tool_call : tool_calls) {
-            const auto & arguments = tool_call.at("arguments");
-            result.tool_calls.push_back({
-                tool_call.at("name"),
-                arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
-                tool_call.contains("id") ? tool_call.at("id") : "",
-            });
-        }
-    };
     if (content_end == std::string::npos) {
         result.content = input;
     } else {
         tc_start = content_end + prefix.size() - rstrip_prefix;
         result.content = input.substr(0, content_end);
         auto tool_calls = json::parse(input.substr(tc_start));
-        process_tool_calls(tool_calls);
+        for (const auto & tool_call : tool_calls) {
+            result.tool_calls.emplace_back(process_tool_call(tool_call));
+        }
     }
     return result;
 }
@@ -700,7 +728,7 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
     data.grammar_lazy = false;
     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
         builder.add_schema("root", schema);
-    }, grammar_options);
+    });
 
     auto tweaked_messages = common_chat_template::add_system(
         inputs.messages,
@@ -770,8 +798,11 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
             schema["maxItems"] = 1;
         }
         builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
-    }, grammar_options);
-    data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});
+    });
+    data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "[TOOL_CALLS]"});
+    data.preserved_tokens = {
+        "[TOOL_CALLS]",
+    };
     data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
     data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO;
     return data;
@@ -813,14 +844,18 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
             schema["maxItems"] = 1;
         }
         builder.add_rule("root", "\"<|START_ACTION|>\" " + builder.add_schema("tool_calls", schema) + " \"<|END_ACTION|>\"");
-    }, grammar_options);
-    data.grammar_triggers.push_back({"<|START_ACTION|>", /* .at_start = */ false});
+    });
+    data.grammar_triggers.push_back({
+        COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
+        "<|START_ACTION|>",
+    });
     data.preserved_tokens = {
+        "<|START_ACTION|>",
+        "<|END_ACTION|>",
         "<|START_RESPONSE|>",
         "<|END_RESPONSE|>",
         "<|START_THINKING|>",
         "<|END_THINKING|>",
-        "<|END_ACTION|>",
     };
     auto adjusted_messages = json::array();
     for (const auto & msg : inputs.messages) {
@@ -840,9 +875,9 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
     return data;
 }
 static common_chat_msg common_chat_parse_command_r7b(const std::string & input, bool extract_reasoning) {
-    static std::regex thought_regex("(<\\|START_THINKING\\|>([\\s\\S\\n\\r]*?)<\\|END_THINKING\\|>)([\\s\\S\\n\\r]*)");
-    static std::regex action_regex("<\\|START_ACTION\\|>([\\s\\S\\n\\r]*?)<\\|END_ACTION\\|>");
-    static std::regex response_regex("(?:<\\|START_RESPONSE\\|>)?([\\s\\S\\n\\r]*?)<\\|END_RESPONSE\\|>");
+    static std::regex thought_regex("(<\\|START_THINKING\\|>([\\s\\S]*?)<\\|END_THINKING\\|>)([\\s\\S]*)");
+    static std::regex action_regex("<\\|START_ACTION\\|>([\\s\\S]*?)<\\|END_ACTION\\|>");
+    static std::regex response_regex("(?:<\\|START_RESPONSE\\|>)?([\\s\\S]*?)<\\|END_RESPONSE\\|>");
 
     std::smatch match;
 
@@ -945,23 +980,23 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
                 builder.add_rule(
                     name + "-call",
                     "\"{\" space "
-                    "( \"\\\"type\\\":\" space \"\\\"function\\\",\" space )? "
-                    "\"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
-                        builder.add_schema(name + "-args", parameters) +
-                    " \"}\""));
-            data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true});
+                    "( \"\\\"type\\\"\"       space \":\" space \"\\\"function\\\"\"     space \",\" space )? "
+                    "  \"\\\"name\\\"\"       space \":\" space \"\\\"" + name + "\\\"\" space \",\" space "
+                    "  \"\\\"parameters\\\"\" space \":\" space " + builder.add_schema(name + "-args", parameters) + " "
+                    "\"}\" space"));
+        });
+        // Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name.
+        data.grammar_triggers.push_back({
+            COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
+            "\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"", // + name + "\"[\\s\\S]*",
         });
-        data.grammar_triggers.push_back({"{\"name\":", /* .at_start = */ true});
-        data.grammar_triggers.push_back({"{\n  \"name\":", /* .at_start = */ true});
-        data.grammar_triggers.push_back({"{\n    \"name\":", /* .at_start = */ true});
-        data.grammar_triggers.push_back({"{\"type\": \"function\"", /* .at_start = */ true});
-        data.grammar_triggers.push_back({"{\n  \"type\": \"function\"", /* .at_start = */ true});
-        data.grammar_triggers.push_back({"{\n    \"type\": \"function\"", /* .at_start = */ true});
         if (!builtin_tools.empty()) {
-            data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
+            data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
+            data.preserved_tokens.push_back("<|python_tag|>");
         }
+        // Allow a few empty lines on top of the usual constrained json schema space rule.
         builder.add_rule("root", string_join(tool_rules, " | "));
-    }, grammar_options);
+    });
     data.additional_stops.push_back("<|eom_id|>");
     data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
         {"tools_in_user_message", false},
@@ -974,33 +1009,33 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
 }
 static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) {
     // TODO: tighten & simplify the parser, don't accept leading text context.
-    static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": ");
-    static std::regex close_regex("\\}");
-    static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\.call\\((.*)\\)");
+    static std::regex function_regex(
+        "\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: ");
+    static std::regex close_regex("\\}\\s*");
+    static std::regex builtin_call_regex("<\\|python_tag\\|>\\s*([^.(]+)\\s*\\.\\s*call\\s*\\(\\s*([\\w]+)\\s*=\\s*([\\s\\S]*?)\\)");
 
     if (with_builtin_tools) {
         std::smatch match;
         if (std::regex_match(input, match, builtin_call_regex)) {
-            auto name = match[1].str();
-            auto raw_args = match[2].str();
-
-            // TODO: if/when builtin tools start accepting more than 1 argument, use parse_json for real parsing.
-            auto it_eq = raw_args.find('=');
-            auto arg_name = raw_args.substr(0, it_eq);
-            auto arg_value_str = raw_args.substr(it_eq + 1);
-            auto arg_value = json::parse(arg_value_str);
-
-            common_chat_msg msg;
-            msg.role = "assistant";
-            msg.content = match.prefix().str();
-            msg.tool_calls.push_back({
-                /* .name = */ name,
-                /* .arguments = */ (json {
-                    {arg_name, arg_value},
-                }).dump(),
-                /* .id = */ "",
-            });
-            return msg;
+            try {
+                auto name = match[1].str();
+                auto arg_name = match[2].str();
+                auto arg_value_str = match[3].str();
+                auto arg_value = json::parse(arg_value_str);
+
+                common_chat_msg msg;
+                msg.role = "assistant";
+                msg.tool_calls.push_back({
+                    /* .name = */ name,
+                    /* .arguments = */ (json {
+                        {arg_name, arg_value},
+                    }).dump(),
+                    /* .id = */ "",
+                });
+                return msg;
+            } catch (const std::exception & e) {
+                LOG_WRN("Failed to parse builtin tool call arguments (%s): %s", e.what(), input.c_str());
+            }
         }
     }
     return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
@@ -1017,10 +1052,10 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
                 std::string name = function.at("name");
                 auto parameters = function.at("parameters");
                 builder.resolve_refs(parameters);
-                auto args_rule = builder.add_schema(name + "-args", parameters);
                 tool_rules.push_back(builder.add_rule(name + "-call",
                     "\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n"
-                    "```json\\n\" " + args_rule + " \"```<|tool▁call▁end|>\""));
+                    "```json\\n\" " + builder.add_schema(name + "-args", parameters) + " "
+                    "\"```<|tool▁call▁end|>\""));
             });
             // Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag,
             // so we accept common variants (then it's all constrained)
@@ -1029,18 +1064,20 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
                 "(" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " "
                 "\"<|tool▁calls▁end|>\""
                 " space");
-            data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false});
-            data.grammar_triggers.push_back({"<|tool_calls_begin|>", /* .at_start = */ false});
-            data.grammar_triggers.push_back({"<|tool calls begin|>", /* .at_start = */ false});
-            data.grammar_triggers.push_back({"<|tool\\_calls\\_begin|>", /* .at_start = */ false});
+            data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool▁calls▁begin|>"});
+            data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool_calls_begin|>"});
+            data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool calls begin|>"});
+            data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool\\_calls\\_begin|>"});
             data.preserved_tokens = {
                 "<think>",
                 "</think>",
+                "<|tool▁calls▁begin|>",
+                "<|tool▁call▁begin|>",
                 "<|tool▁sep|>",
-                "<|tool▁calls▁end|",
                 "<|tool▁call▁end|>",
+                "<|tool▁calls▁end|",
             };
-        }, grammar_options);
+        });
     }
     auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
 
@@ -1129,8 +1166,11 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c
                 schema["maxItems"] = 1;
             }
             builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
-        }, grammar_options);
-        data.grammar_triggers.push_back({" functools[", /* .at_start = */ false});
+        });
+        data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, " functools["});
+        data.preserved_tokens = {
+            " functools[",
+        };
         data.format = COMMON_CHAT_FORMAT_FIREFUNCTION_V2;
     } else {
         data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
@@ -1158,11 +1198,28 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
                 auto parameters = function.at("parameters");
                 builder.resolve_refs(parameters);
                 auto args_rule = builder.add_schema(name + "-args", parameters);
-                first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
+                first_tool_rules.push_back(builder.add_rule(name + "-call", "( \"assistant<|end_header_id|>\\n\" )? \"" + name + "\\n\" " + args_rule));
                 subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
-                data.grammar_triggers.push_back({name, /* .at_start = */ true});
-                data.grammar_triggers.push_back({">>>" + name, /* .at_start = */ false});
+                data.grammar_triggers.push_back({
+                    COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
+                    regex_escape(name + "\n"),
+                });
+                data.grammar_triggers.push_back({
+                    COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
+                    regex_escape("assistant<|end_header_id|>\n" + name + "\n"),
+                });
+                data.grammar_triggers.push_back({
+                    COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
+                    regex_escape(">>>" + name + "\n"),
+                });
+                data.grammar_triggers.push_back({
+                    COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
+                    ">>>assistant<|end_header_id|>\n" + name,
+                });
             });
+            data.preserved_tokens = {
+                "<|end_header_id|>",
+            };
             auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
             if (inputs.parallel_tool_calls) {
                 auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space";
@@ -1171,34 +1228,20 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
                 builder.add_rule("root", first_rule);
             }
 
-        }, grammar_options);
+        });
     }
     return data;
 }
 
-static bool consume(std::string::const_iterator & it, const std::string::const_iterator & end, const std::string & expected) {
-    auto expected_it = expected.begin();
-    auto tmp_it = it;
-    while (tmp_it != end && expected_it != expected.end() && *tmp_it == *expected_it) {
-        ++tmp_it;
-        ++expected_it;
-    }
-    if (expected_it == expected.end()) {
-        it = tmp_it;
-        return true;
-    }
-    return false;
-}
-
 static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & input) {
-    static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
+    static std::regex function_regex(R"((?:>>>)?(?:assistant<|end_header_id|>\n)?(\w+)\n)");
     static std::regex close_regex(R"($|(?=>>>))");
 
     std::string content;
     auto it = input.begin();
     const auto end = input.end();
 
-    if (consume(it, end, "all\n")) {
+    if (parse_literal(it, end, "all\n")) {
         std::smatch match;
         if (std::regex_search(it, end, match, function_regex)) {
             auto fun_it = match.prefix().second;
@@ -1213,7 +1256,7 @@ static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & in
     }
     // TODO: tighten & simplify.
     try {
-        auto res = parse_json_tool_calls(std::string(it, end), std::nullopt, function_regex, close_regex);
+        auto res = parse_json_tool_calls(std::string(it, end), std::nullopt, function_regex, close_regex, /* allow_raw_python= */ true);
         res.content = content + res.content;
         return res;
     } catch (const std::exception & e) {
@@ -1266,12 +1309,13 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
         });
         if (has_raw_python) {
             tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
-            data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
+            data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
+            data.preserved_tokens.push_back("<|python_tag|>");
         }
         auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
         builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
-        data.grammar_triggers.push_back({"<function=", /* .at_start = */ false});
-    }, grammar_options);
+        data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function="});
+    });
 
     data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
     // TODO: if (has_raw_python)
@@ -1306,6 +1350,7 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
     data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
         std::vector<std::string> tool_rules;
+        std::vector<std::string> tool_call_alts;
         foreach_function(inputs.tools, [&](const json & tool) {
             const auto & function = tool.at("function");
             std::string name = function.at("name");
@@ -1319,57 +1364,173 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
                 }},
                 {"required", json::array({"name", "arguments"})},
             }));
+            tool_call_alts.push_back(builder.add_rule(
+                name + "-function-tag",
+                "\"<function\" ( \"=" + name + "\" | \" name=\\\"" + name + "\\\"\" ) \">\" space " +
+                builder.add_schema(name + "-args", parameters) + " "
+                "\"</function>\" space"));
+
+            data.grammar_triggers.push_back({
+                COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
+                "<function=" + name + ">",
+            });
+            auto escaped_name = regex_escape(name);
+            data.grammar_triggers.push_back({
+                COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
+                "<function\\s+name\\s*=\\s*\"" + escaped_name + "\"",
+            });
         });
-        auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"</tool_call>\" space";
+        auto any_tool_call = builder.add_rule("any_tool_call", "( " + string_join(tool_rules, " | ") + " ) space");
+        std::vector<std::string> alt_tags {
+            any_tool_call,
+            "\"<tool_call>\" space "     + any_tool_call + " \"</tool_call>\"",
+            // The rest is just to accommodate common "good bad" outputs.
+            "\"<function_call>\" space " + any_tool_call + " \"</function_call>\"",
+            "\"<response>\"  space "     + any_tool_call + " \"</response>\"",
+            "\"<tools>\"     space "     + any_tool_call + " \"</tools>\"",
+            "\"<json>\"      space "     + any_tool_call + " \"</json>\"",
+            "\"<xml>\"      space "     + any_tool_call + " \"</xml>\"",
+            "\"<JSON>\"      space "     + any_tool_call + " \"</JSON>\"",
+        };
+        auto wrappable_tool_call = builder.add_rule("wrappable_tool_call", "( " + string_join(alt_tags, " | ") + " ) space");
+        tool_call_alts.push_back(wrappable_tool_call);
+        tool_call_alts.push_back(
+            "( \"```\\n\" | \"```json\\n\" | \"```xml\\n\" ) space " + wrappable_tool_call + " space \"```\" space ");
+        auto tool_call = builder.add_rule("tool_call", string_join(tool_call_alts, " | "));
         builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
-        data.grammar_triggers.push_back({"<tool_call>", /* .at_start = */ false});
-        data.preserved_tokens = { "</tool_call>" };
-    }, grammar_options);
+        data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<tool_call>"});
+        data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function"});
+        // Trigger on some common known "good bad" outputs (only from the start and with a json that's about a specific argument name to avoid false positives)
+        data.grammar_triggers.push_back({
+            COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
+            "(?:```(?:json|xml)?\n\\s*)?(?:<function_call>|<tools>|<xml><json>|<response>)?\\s*\\{\\s*\"", //name\"\\s*:\\s*\"" + escaped_name + "\"",
+        });
+        data.preserved_tokens = {
+            "<tool_call>",
+            "</tool_call>",
+            "<function",
+            "<tools>",
+            "</tools>",
+            "<response>",
+            "</response>",
+            "<function_call>",
+            "</function_call>",
+            "<json>",
+            "</json>",
+            "<JSON>",
+            "</JSON>",
+            "```",
+            "```json",
+            "```xml",
+        };
+    });
 
     data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
     data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO;
     return data;
 }
-static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input) {
+static common_chat_msg common_chat_parse_hermes_2_pro(const std::string& input) {
+    const static std::regex open_regex(
+        "(?:"
+        "(```(?:xml|json)?\\n\\s*)?"         // match 1 (block_start)
+        "(<tool_call>"                   // match 2 (open_tag)
+        "|<function_call>"
+        "|<tool>"
+        "|<tools>"
+        "|<response>"
+        "|<json>"
+        "|<xml>"
+        "|<JSON>"
+        ")?"
+        "(\\s*\\{\\s*\"name\"\\s*:[\\s\\S]*)"    // match 3 (named tool call + rest)
+        ")"
+        "|"
+        "(?:<function=([^>]+)>"            // match 4 (function name)
+        "|<function name=\"([^\"]+)\">)" // match 5 (function name again)
+        "([\\s\\S]*)"                   // match 6 (function arguments + rest)})"
+    );
+
     try {
-        std::regex start_pattern(R"([\n\s]*<tool_call>)");
-        std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
-        std::regex end_pattern(R"([\n\s]*</tool_call>[\n\s]*$)");
 
         common_chat_msg msg;
         msg.role = "assistant";
 
-        auto end = input.end();
-        std::sregex_iterator rend;
-        std::sregex_iterator rit(input.begin(), end, start_pattern);
-        if (rit == rend) {
-            msg.content = input;
-            return msg;
-        }
-
-        msg.content = rit->prefix();
+        std::string::const_iterator it = input.begin();
+        const std::string::const_iterator end = input.end();
+        std::smatch match;
 
-        auto it = rit->suffix().first;
         while (it != end) {
-            json call;
-            if (!parse_json(it, end, call)) {
-                throw std::runtime_error("Failed to parse json tool call");
-            }
-            const auto & arguments = call.at("arguments");
-            msg.tool_calls.push_back({
-                call.at("name"),
-                arguments.dump(),
-                // arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
-                /* id= */ "",
-            });
-            rit = {it, end, middle_pattern};
-            if (rit != rend) {
-                it = rit->suffix().first;
-            } else {
-                rit = {it, end, end_pattern};
-                if (rit == rend) {
-                    throw std::runtime_error("Malformed input, missing </tool_call>");
+            if (std::regex_search(it, end, match, open_regex)) {
+                // Add content before the match
+                msg.content += std::string(it, match[0].first);
+
+                auto block_start = match[1].str();
+                std::string block_end = block_start.empty() ? "" : "```";
+
+                auto open_tag = match[2].str();
+                std::string close_tag;
+
+                if (match[3].matched) {
+                    close_tag = open_tag.empty() ? "" : "</" + open_tag.substr(1);
+                    auto json_it = match[3].first;
+                    json tool_call;
+                    if (parse_json(json_it, end, tool_call) && tool_call.contains("name") && tool_call.contains("arguments")) {
+
+                        msg.tool_calls.emplace_back(process_tool_call(tool_call));
+                        it = json_it;  // Move iterator past parsed JSON
+
+                        // Handle close tags
+                        consume_spaces(it, end);
+                        if (!close_tag.empty() && !parse_literal(it, end, close_tag)) {
+                            throw std::runtime_error("Failed to parse closing tag");
+                        }
+                        consume_spaces(it, end);
+                        if (!block_end.empty() && !parse_literal(it, end, block_end)) {
+                            throw std::runtime_error("Failed to parse block end");
+                        }
+                        consume_spaces(it, end);
+                    } else {
+                        // Not a valid tool call, treat as content
+                        msg.content += std::string(match[0].first, match[0].second);
+                        it = match[0].second;
+                    }
+                } else {
+                    auto function_name = match[4].str();
+                    if (function_name.empty()) {
+                        function_name = match[5].str();
+                    }
+                    GGML_ASSERT(!function_name.empty());
+
+                    close_tag = "</function>";
+                    // Start parsing from after the opening tags
+                    auto json_it = match[6].first;
+                    json arguments;
+                    if (parse_json(json_it, end, arguments)) {
+                        msg.tool_calls.emplace_back(process_tool_call({
+                            {"name", function_name},
+                            {"arguments", arguments},
+                        }));
+                        it = json_it;  // Move iterator past parsed JSON
+
+                        // Handle close tags
+                        consume_spaces(it, end);
+                        if (!close_tag.empty() && !parse_literal(it, end, close_tag)) {
+                            throw std::runtime_error("Failed to parse closing tag");
+                        }
+                        consume_spaces(it, end);
+                        if (!block_end.empty() && !parse_literal(it, end, block_end)) {
+                            throw std::runtime_error("Failed to parse block end");
+                        }
+                        consume_spaces(it, end);
+                    } else {
+                        // Not a valid tool call, treat as content
+                        msg.content += std::string(match[0].first, match[0].second);
+                        it = match[0].second;
+                    }
                 }
+            } else {
+                // Add remaining content
+                msg.content += std::string(it, end);
                 break;
             }
         }
index d2b0d50e3ee3975e90ff12abcb2bedaef041ef21..6448b7b03d6d2614b3f7da8b5e7b0cf1a6ab0879 100644 (file)
@@ -10,7 +10,6 @@
 // Change JSON_ASSERT from assert() to GGML_ASSERT:
 #define JSON_ASSERT GGML_ASSERT
 #include "json.hpp"
-#include "json-schema-to-grammar.h"
 #include "llama.h"
 
 #include <algorithm>
@@ -483,6 +482,11 @@ void string_replace_all(std::string & s, const std::string & search, const std::
     s = std::move(builder);
 }
 
+std::string regex_escape(const std::string & s) {
+    static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
+    return std::regex_replace(s, special_chars, "\\$0");
+}
+
 std::string string_join(const std::vector<std::string> & values, const std::string & separator) {
     std::ostringstream result;
     for (size_t i = 0; i < values.size(); ++i) {
@@ -2026,3 +2030,25 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c
     return result;
 }
 
+template <>
+json common_grammar_trigger::to_json() const {
+    json out {
+        {"type", (int) type},
+        {"value", value},
+    };
+    if (type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
+        out["token"] = (int) token;
+    }
+    return out;
+}
+
+template <>
+common_grammar_trigger common_grammar_trigger::from_json(const json & in) {
+    common_grammar_trigger out;
+    out.type = (common_grammar_trigger_type) in.at("type").get<int>();
+    out.value = in.at("value").get<std::string>();
+    if (out.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
+        out.token = (llama_token) in.at("token").get<int>();
+    }
+    return out;
+}
index f4b4a96fb68ff131ba690c78596a932a2c7c5943..733f7f1c8d662a7048235d7f9dd7a44356c961e3 100644 (file)
@@ -110,9 +110,21 @@ enum common_conversation_mode {
     COMMON_CONVERSATION_MODE_AUTO     = 2,
 };
 
+enum common_grammar_trigger_type {
+    COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN,
+    COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
+    COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
+    COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
+};
+
 struct common_grammar_trigger {
-    std::string word;
-    bool at_start;
+    common_grammar_trigger_type type;
+    std::string value;
+    llama_token token = LLAMA_TOKEN_NULL;
+
+    // T can only be nlohmann::ordered_json
+    template <class T> T to_json() const;
+    template <class T> static common_grammar_trigger from_json(const T & in);
 };
 
 // sampling parameters
@@ -163,8 +175,7 @@ struct common_params_sampling {
 
     std::string                         grammar; // optional BNF-like grammar to constrain sampling
     bool                                grammar_lazy = false;
-    std::vector<common_grammar_trigger> grammar_trigger_words;  // optional trigger words to trigger lazy grammar
-    std::vector<llama_token>            grammar_trigger_tokens; // optional trigger tokens to trigger lazy grammar and print trigger special tokens.
+    std::vector<common_grammar_trigger> grammar_triggers; // optional triggers (for lazy grammars)
     std::set<llama_token>               preserved_tokens;
 
     std::vector<llama_logit_bias> logit_bias; // logit biases to apply
@@ -458,6 +469,8 @@ std::string string_repeat(const std::string & str, size_t n);
 
 void string_replace_all(std::string & s, const std::string & search, const std::string & replace);
 
+std::string regex_escape(const std::string & s);
+
 template<class T>
 static std::vector<T> string_split(const std::string & str, char delim) {
     static_assert(!std::is_same<T, std::string>::value, "Please use the specialized version for std::string");
index 3ebcc3d9fbbf8fe718c66a4f8a348b54bbae8f1a..90679822571204c02c66f0c828c7872febd61f3b 100644 (file)
@@ -264,7 +264,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
     throw std::runtime_error("At least one of min_value or max_value must be set");
 }
 
-const std::string SPACE_RULE = "| \" \" | \"\\n\" [ \\t]{0,20}";
+const std::string SPACE_RULE = "| \" \" | \"\\n\"{1,2} [ \\t]{0,20}";
 
 struct BuiltinRule {
     std::string content;
@@ -764,11 +764,10 @@ private:
 public:
     SchemaConverter(
         const std::function<json(const std::string &)> & fetch_json,
-        bool dotall,
-        bool compact_spaces)
+        bool dotall)
           : _fetch_json(fetch_json), _dotall(dotall)
     {
-        _rules["space"] = compact_spaces ? "\" \"?" : SPACE_RULE;
+        _rules["space"] = SPACE_RULE;
     }
 
     void resolve_refs(json & schema, const std::string & url) {
@@ -1007,7 +1006,7 @@ std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
 }
 
 std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options) {
-    SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall, options.compact_spaces);
+    SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall);
     common_grammar_builder builder {
         /* .add_rule = */ [&](const std::string & name, const std::string & rule) {
             return converter._add_rule(name, rule);
index 62a3b0a4477cca0f9558183cec9832ed2c421973..4613f5d9f910ce589ba425a3619ffe353283fcea 100644 (file)
@@ -16,7 +16,6 @@ struct common_grammar_builder {
 
 struct common_grammar_options {
     bool dotall = false;
-    bool compact_spaces = false;
 };
 
 std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options = {});
index 972fbe9dab465984e4502338a41ec8a6330463df..baf22066dca157b2db6543b614baa7c43e6d3c15 100644 (file)
@@ -160,16 +160,53 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
         GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
 #endif // LLAMA_USE_LLGUIDANCE
     } else {
-        std::vector<const char *> trigger_words;
-        trigger_words.reserve(params.grammar_trigger_words.size());
-        for (const auto & str : params.grammar_trigger_words) {
-            trigger_words.push_back(str.word.c_str());
+        std::vector<std::string> patterns_at_start;
+        std::vector<std::string> patterns_anywhere;
+        std::vector<llama_token> trigger_tokens;
+        for (const auto & trigger : params.grammar_triggers) {
+            switch (trigger.type) {
+                case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
+                {
+                    const auto & word = trigger.value;
+                    patterns_anywhere.push_back(regex_escape(word));
+                    break;
+                }
+                case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
+                case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
+                {
+                    const auto & pattern = trigger.value;
+                    (trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START ? patterns_at_start : patterns_anywhere).push_back(pattern);
+                    break;
+                }
+                case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
+                {
+                    const auto token = trigger.token;
+                    trigger_tokens.push_back(token);
+                    break;
+                }
+                default:
+                    GGML_ASSERT(false && "unknown trigger type");
+            }
+        }
+
+        std::vector<std::string> trigger_patterns;
+        if (!patterns_at_start.empty()) {
+            trigger_patterns.push_back("^(" + string_join(patterns_at_start, "|") + ")[\\s\\S]*");
+        }
+        if (!patterns_anywhere.empty()) {
+            trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
+        }
+
+        std::vector<const char *> trigger_patterns_c;
+        trigger_patterns_c.reserve(trigger_patterns.size());
+        for (const auto & regex : trigger_patterns) {
+            trigger_patterns_c.push_back(regex.c_str());
         }
 
         grmr = params.grammar_lazy
-             ? llama_sampler_init_grammar_lazy(vocab, params.grammar.c_str(), "root",
-                                               trigger_words.data(), trigger_words.size(),
-                                               params.grammar_trigger_tokens.data(), params.grammar_trigger_tokens.size())
+             ? llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
+                                                        trigger_patterns_c.data(), trigger_patterns_c.size(),
+                                                        trigger_tokens.data(), trigger_tokens.size())
              :      llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
     }
 
index fc9f0097f5f8fea03ca7fb83973825105258705b..55f94c0b0a864d596b5b8c48be0fd2c45e9adefe 100755 (executable)
@@ -195,7 +195,7 @@ class BuiltinRule:
         self.deps = deps or []
 
 # Constraining spaces to prevent model "running away".
-SPACE_RULE = '| " " | "\\n" [ \\t]{0,20}'
+SPACE_RULE = '| " " | "\\n"{1,2} [ \\t]{0,20}'
 
 PRIMITIVE_RULES = {
     'boolean'      : BuiltinRule('("true" | "false") space', []),
index e67bb15c176333fc9ebb6af48eab73e4986599bc..f767ce7b7200848599f1a8cf3ce01364a459a253 100644 (file)
@@ -1,5 +1,5 @@
 // WARNING: This file was ported from json_schema_to_grammar.py, please fix bugs / add features there first.
-const SPACE_RULE = '| " " | "\\n" [ \\t]{0,20}';
+const SPACE_RULE = '| " " | "\\n"{1,2} [ \\t]{0,20}';
 
 function _buildRepetition(itemRule, minItems, maxItems, opts={}) {
   if (minItems === 0 && maxItems === 1) {
index e4f7e43fdcd46931c79f08a213d3691c63dd4bb1..2a526b0e7acdd4b115d627010333d70bce877c15 100644 (file)
@@ -131,9 +131,9 @@ struct slot_params {
             lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
         }
 
-        std::vector<std::string> grammar_trigger_words;
-        for (const auto & trigger : sampling.grammar_trigger_words) {
-            grammar_trigger_words.push_back(trigger.word);
+        auto grammar_triggers = json::array();
+        for (const auto & trigger : sampling.grammar_triggers) {
+            grammar_triggers.push_back(trigger.to_json<json>());
         }
 
         return json {
@@ -170,8 +170,8 @@ struct slot_params {
             {"n_probs",                   sampling.n_probs},
             {"min_keep",                  sampling.min_keep},
             {"grammar",                   sampling.grammar},
-            {"grammar_trigger_words",     grammar_trigger_words},
-            {"grammar_trigger_tokens",    sampling.grammar_trigger_tokens},
+            {"grammar_lazy",              sampling.grammar_lazy},
+            {"grammar_triggers",          grammar_triggers},
             {"preserved_tokens",          sampling.preserved_tokens},
             {"chat_format",               common_chat_format_name(oaicompat_chat_format)},
             {"samplers",                  samplers},
@@ -356,24 +356,6 @@ struct server_task {
         }
 
         {
-            const auto grammar_triggers = data.find("grammar_triggers");
-            if (grammar_triggers != data.end()) {
-                for (const auto & t : *grammar_triggers) {
-                    common_grammar_trigger trigger;
-                    trigger.word = t.at("word");
-                    trigger.at_start = t.at("at_start");
-
-                    auto ids = common_tokenize(vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
-                    if (ids.size() == 1) {
-                        SRV_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
-                        params.sampling.grammar_trigger_tokens.push_back(ids[0]);
-                        params.sampling.preserved_tokens.insert(ids[0]);
-                        continue;
-                    }
-                    SRV_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
-                    params.sampling.grammar_trigger_words.push_back(trigger);
-                }
-            }
             const auto preserved_tokens = data.find("preserved_tokens");
             if (preserved_tokens != data.end()) {
                 for (const auto & t : *preserved_tokens) {
@@ -383,12 +365,38 @@ struct server_task {
                         params.sampling.preserved_tokens.insert(ids[0]);
                     } else {
                         // This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens.
-                        SRV_WRN("Not preserved because more than 1 token (wrong chat template override?): %s\n", t.get<std::string>().c_str());
+                        SRV_DBG("Not preserved because more than 1 token: %s\n", t.get<std::string>().c_str());
+                    }
+                }
+            }
+            const auto grammar_triggers = data.find("grammar_triggers");
+            if (grammar_triggers != data.end()) {
+                for (const auto & t : *grammar_triggers) {
+                    auto ct = common_grammar_trigger::from_json(t);
+                    if (ct.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
+                        const auto & word = ct.value;
+                        auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true);
+                        if (ids.size() == 1) {
+                            auto token = ids[0];
+                            if (std::find(params.sampling.preserved_tokens.begin(), params.sampling.preserved_tokens.end(), (llama_token) token) == params.sampling.preserved_tokens.end()) {
+                                throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word);
+                            }
+                            SRV_DBG("Grammar trigger token: %d (`%s`)\n", token, word.c_str());
+                            common_grammar_trigger trigger;
+                            trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
+                            trigger.value = (llama_token) token;
+                            params.sampling.grammar_triggers.push_back(trigger);
+                        } else {
+                            SRV_DBG("Grammar trigger word: `%s`\n", word.c_str());
+                            params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
+                        }
+                    } else {
+                        params.sampling.grammar_triggers.push_back(ct);
                     }
                 }
             }
-            if (params.sampling.grammar_lazy) {
-                GGML_ASSERT(params.sampling.grammar_trigger_tokens.size() > 0 || params.sampling.grammar_trigger_words.size() > 0);
+            if (params.sampling.grammar_lazy && params.sampling.grammar_triggers.empty()) {
+                throw std::runtime_error("Error: no triggers set for lazy grammar!");
             }
         }
 
@@ -2045,7 +2053,7 @@ struct server_context {
 
         if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
             // Might be better to reject the request with a 400 ?
-            SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.params.n_predict, slot.n_predict);
+            SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d\n", slot.params.n_predict, slot.n_predict);
             slot.params.n_predict = slot.n_predict;
         }
 
old mode 100644 (file)
new mode 100755 (executable)
index a91a2f3..25bddba
@@ -1,4 +1,12 @@
+#!/usr/bin/env python
 import pytest
+
+# ensure grandparent path is in sys.path
+from pathlib import Path
+import sys
+path = Path(__file__).resolve().parents[1]
+sys.path.insert(0, str(path))
+
 from utils import *
 
 server: ServerProcess
@@ -66,15 +74,8 @@ WEATHER_TOOL = {
 }
 
 
-def do_test_completion_with_required_tool_tiny(template_name: str, tool: dict, argument_key: str | None):
-    global server
-    n_predict = 512
-    # server = ServerPreset.stories15m_moe()
-    server.jinja = True
-    server.n_predict = n_predict
-    server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
-    res = server.make_request("POST", "/chat/completions", data={
+def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict, argument_key: str | None, n_predict, **kwargs):
+    res = server.make_request("POST", "/v1/chat/completions", data={
         "max_tokens": n_predict,
         "messages": [
             {"role": "system", "content": "You are a coding assistant."},
@@ -83,16 +84,14 @@ def do_test_completion_with_required_tool_tiny(template_name: str, tool: dict, a
         "tool_choice": "required",
         "tools": [tool],
         "parallel_tool_calls": False,
-        "temperature": 0.0,
-        "top_k": 1,
-        "top_p": 1.0,
+        **kwargs,
     })
     assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
     choice = res.body["choices"][0]
     tool_calls = choice["message"].get("tool_calls")
     assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
     tool_call = tool_calls[0]
-    assert choice["message"].get("content") is None, f'Expected no content in {choice["message"]}'
+    assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
     expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
     assert expected_function_name == tool_call["function"]["name"]
     actual_arguments = tool_call["function"]["arguments"]
@@ -108,7 +107,14 @@ def do_test_completion_with_required_tool_tiny(template_name: str, tool: dict, a
     ("meta-llama-Llama-3.3-70B-Instruct",             PYTHON_TOOL,          "code"),
 ])
 def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None):
-    do_test_completion_with_required_tool_tiny(template_name, tool, argument_key)
+    global server
+    n_predict = 512
+    # server = ServerPreset.stories15m_moe()
+    server.jinja = True
+    server.n_predict = n_predict
+    server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
+    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, temperature=0.0, top_k=1, top_p=1.0)
 
 
 @pytest.mark.slow
@@ -130,10 +136,17 @@ def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict,
     ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B",      TEST_TOOL,            "success"),
     ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B",      PYTHON_TOOL,          "code"),
     ("fireworks-ai-llama-3-firefunction-v2",          TEST_TOOL,            "success"),
-    ("fireworks-ai-llama-3-firefunction-v2",          PYTHON_TOOL,          "code"),
+    ("fireworks-ai-llama-3-firefunction-v2",          PYTHON_TOOL,          "code"),
 ])
 def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, argument_key: str | None):
-    do_test_completion_with_required_tool_tiny(template_name, tool, argument_key)
+    global server
+    n_predict = 512
+    # server = ServerPreset.stories15m_moe()
+    server.jinja = True
+    server.n_predict = n_predict
+    server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
+    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict)
 
 
 @pytest.mark.slow
@@ -142,25 +155,33 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
     (PYTHON_TOOL,  "code",     "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
     (PYTHON_TOOL,  "code",     "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
 
-    # Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it.
     (TEST_TOOL,    "success",  "bartowski/gemma-2-2b-it-GGUF:Q4_K_M",              None),
     (PYTHON_TOOL,  "code",     "bartowski/gemma-2-2b-it-GGUF:Q4_K_M",              None),
+    (PYTHON_TOOL,  "code",     "bartowski/gemma-2-2b-it-GGUF:Q4_K_M",              "chatml"),
 
     (TEST_TOOL,    "success",  "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M",      None),
     (PYTHON_TOOL,  "code",     "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M",      None),
     (PYTHON_TOOL,  "code",     "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M",      "chatml"),
 
+    (TEST_TOOL,    "success",  "bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M",      None),
+    (PYTHON_TOOL,  "code",     "bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M",      None),
+    (PYTHON_TOOL,  "code",     "bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M",      "chatml"),
+
+    (TEST_TOOL,    "success",  "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M",      None),
+    (PYTHON_TOOL,  "code",     "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M",      None),
+    (PYTHON_TOOL,  "code",     "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M",      "chatml"),
+
     (TEST_TOOL,    "success",  "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M",        None),
     (PYTHON_TOOL,  "code",     "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M",        None),
     (PYTHON_TOOL,  "code",     "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M",        "chatml"),
 
     (TEST_TOOL,    "success",  "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
     (PYTHON_TOOL,  "code",     "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
-    (PYTHON_TOOL,  "code",     "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),
+    (PYTHON_TOOL,  "code",     "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),
 
     (TEST_TOOL,    "success",  "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M",   ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
     (PYTHON_TOOL,  "code",     "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M",   ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
-    (PYTHON_TOOL,  "code",     "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M",   "chatml"),
+    (PYTHON_TOOL,  "code",     "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M",   "chatml"),
 
     (TEST_TOOL,    "success",  "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
     (PYTHON_TOOL,  "code",     "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
@@ -176,10 +197,10 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
 
     (TEST_TOOL,    "success",  "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M",      ("meta-llama/Llama-3.2-3B-Instruct", None)),
     (PYTHON_TOOL,  "code",     "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M",      ("meta-llama/Llama-3.2-3B-Instruct", None)),
-    (PYTHON_TOOL,  "code",     "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M",      "chatml"),
-    # TODO: fix these
-    (TEST_TOOL,    "success",  "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
-    (PYTHON_TOOL,  "code",     "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
+    (PYTHON_TOOL,  "code",     "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M",      "chatml"),
+
+    (TEST_TOOL,    "success",  "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
+    (PYTHON_TOOL,  "code",     "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
 ])
 def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
     global server
@@ -197,7 +218,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
     elif isinstance(template_override, str):
         server.chat_template = template_override
     server.start(timeout_seconds=TIMEOUT_SERVER_START)
-    res = server.make_request("POST", "/chat/completions", data={
+    res = server.make_request("POST", "/v1/chat/completions", data={
         "max_tokens": n_predict,
         "messages": [
             {"role": "system", "content": "You are a coding assistant."},
@@ -215,7 +236,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
     tool_calls = choice["message"].get("tool_calls")
     assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
     tool_call = tool_calls[0]
-    assert choice["message"].get("content") is None, f'Expected no content in {choice["message"]}'
+    # assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
     expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
     assert expected_function_name == tool_call["function"]["name"]
     actual_arguments = tool_call["function"]["arguments"]
@@ -225,13 +246,8 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
         assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}"
 
 
-def do_test_completion_without_tool_call(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
-    global server
-    server.jinja = True
-    server.n_predict = n_predict
-    server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
-    res = server.make_request("POST", "/chat/completions", data={
+def do_test_completion_without_tool_call(server: ServerProcess, n_predict: int, tools: list[dict], tool_choice: str | None, **kwargs):
+    res = server.make_request("POST", "/v1/chat/completions", data={
         "max_tokens": n_predict,
         "messages": [
             {"role": "system", "content": "You are a coding assistant."},
@@ -239,9 +255,7 @@ def do_test_completion_without_tool_call(template_name: str, n_predict: int, too
         ],
         "tools": tools if tools else None,
         "tool_choice": tool_choice,
-        "temperature": 0.0,
-        "top_k": 1,
-        "top_p": 1.0,
+        **kwargs,
     }, timeout=TIMEOUT_HTTP_REQUEST)
     assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
     choice = res.body["choices"][0]
@@ -254,7 +268,12 @@ def do_test_completion_without_tool_call(template_name: str, n_predict: int, too
     ("meta-llama-Llama-3.3-70B-Instruct",         128, [PYTHON_TOOL], 'none'),
 ])
 def test_completion_without_tool_call_fast(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
-    do_test_completion_without_tool_call(template_name, n_predict, tools, tool_choice)
+    global server
+    server.jinja = True
+    server.n_predict = n_predict
+    server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
+    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    do_test_completion_without_tool_call(server, n_predict, tools, tool_choice)
 
 
 @pytest.mark.slow
@@ -270,7 +289,12 @@ def test_completion_without_tool_call_fast(template_name: str, n_predict: int, t
     ("meta-llama-Llama-3.2-3B-Instruct",              256, [PYTHON_TOOL], 'none'),
 ])
 def test_completion_without_tool_call_slow(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
-    do_test_completion_without_tool_call(template_name, n_predict, tools, tool_choice)
+    global server
+    server.jinja = True
+    server.n_predict = n_predict
+    server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
+    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    do_test_completion_without_tool_call(server, n_predict, tools, tool_choice)
 
 
 @pytest.mark.slow
@@ -281,6 +305,12 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
     ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M",      None),
     ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M",      "chatml"),
 
+    ("bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M",      None),
+    ("bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M",      "chatml"),
+
+    ("bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M",      None),
+    ("bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M",      "chatml"),
+
     ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M",        None),
     ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M",        "chatml"),
 
@@ -324,48 +354,52 @@ def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] |
     elif isinstance(template_override, str):
         server.chat_template = template_override
     server.start(timeout_seconds=TIMEOUT_SERVER_START)
-    res = server.make_request("POST", "/chat/completions", data={
-        "max_tokens": n_predict,
+    do_test_weather(server, max_tokens=n_predict)
+
+
+def do_test_weather(server: ServerProcess, **kwargs):
+    res = server.make_request("POST", "/v1/chat/completions", data={
         "messages": [
             {"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things."},
             {"role": "user", "content": "What is the weather in Istanbul?"},
         ],
         "tools": [WEATHER_TOOL],
+        **kwargs,
     }, timeout=TIMEOUT_HTTP_REQUEST)
     assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
     choice = res.body["choices"][0]
     tool_calls = choice["message"].get("tool_calls")
     assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
     tool_call = tool_calls[0]
-    assert choice["message"].get("content") is None, f'Expected no content in {choice["message"]}'
-    assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"]
+    # assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
+    assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"], f'Expected weather tool call, got {tool_call["function"]["name"]}'
     actual_arguments = json.loads(tool_call["function"]["arguments"])
     assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}"
     location = actual_arguments["location"]
     assert isinstance(location, str), f"Expected location to be a string, got {type(location)}: {json.dumps(location)}"
-    assert re.match('^Istanbul((TR|Turkey|Türkiye))?$', location), f'Expected Istanbul for location, got {location}'
+    assert re.match('^Istanbul(( |, ?)(TR|Turkey|Türkiye))?$', location), f'Expected Istanbul for location, got {location}'
 
 
 @pytest.mark.slow
 @pytest.mark.parametrize("result_override,n_predict,hf_repo,template_override", [
     (None,                                           128,  "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M",       "chatml"),
-    (None,                                           128,  "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M",         None),
+    (None,                                           128,  "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None),
+    (None,                                           128,  "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", "chatml"),
     (None,                                           128,  "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M",         "chatml"),
     (None,                                           128,  "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M",     ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
     (None,                                           128,  "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M",       ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
     (None,                                           128,  "bartowski/functionary-small-v3.2-GGUF:Q8_0",        ("meetkai/functionary-medium-v3.2", None)),
-    (None,                                           128,  "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M",  None),
     (None,                                           128,  "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M",  None),
     (None,                                           128,  "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M",  "chatml"),
     (None,                                           128,  "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M",       None),
+    ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)",            8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
 
     # TODO: fix these (wrong results, either didn't respect decimal instruction or got wrong value)
-    ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)",            8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
-    # ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)",            8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
+    # (None,                                           128,  "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M",  None),
+    # ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)",            8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
 ])
 def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
     global server
-    # n_predict = 512
     server.n_slots = 1
     server.jinja = True
     server.n_ctx = 8192 * 2
@@ -379,10 +413,14 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
     elif isinstance(template_override, str):
         server.chat_template = template_override
     server.start(timeout_seconds=TIMEOUT_SERVER_START)
-    res = server.make_request("POST", "/chat/completions", data={
+    do_test_calc_result(server, result_override, n_predict)
+
+
+def do_test_calc_result(server: ServerProcess, result_override: str | None, n_predict: int, **kwargs):
+    res = server.make_request("POST", "/v1/chat/completions", data={
         "max_tokens": n_predict,
         "messages": [
-            {"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things, and provide very concise answers. Do not explain your reasoning to the user. Provide any numerical values back to the user with at most two decimals."},
+            {"role": "system", "content": "You are a tools-calling assistant. You express numerical values with at most two decimals."},
             {"role": "user", "content": "What's the y coordinate of a point on the unit sphere at angle 30 degrees?"},
             {
                 "role": "assistant",
@@ -423,7 +461,8 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
                     }
                 }
             }
-        ]
+        ],
+        **kwargs,
     }, timeout=TIMEOUT_HTTP_REQUEST)
     assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
     choice = res.body["choices"][0]
@@ -434,19 +473,19 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
     if result_override is not None:
         assert re.match(result_override, content), f'Expected {result_override}, got {content}'
     else:
-        assert re.match('^[\\s\\S]*?The (y[ -])?coordinate [\\s\\S]*?is (approximately )?0\\.56\\b|^0\\.56$', content), \
+        assert re.match('^[\\s\\S]*?((That\'s|\\bis) (approximately )?)?\\b0\\.(5\\b|56\\b|556)', content), \
             f'Expected something like "The y coordinate is 0.56.", got {content}'
 
 
 @pytest.mark.slow
 @pytest.mark.parametrize("n_predict,reasoning_format,expect_content,expect_reasoning_content,hf_repo,template_override", [
-    (128, 'deepseek',  "^The sum of 102 and 7 is 109.*",                        None,                                          "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M",       None),
-    (128,  None,        "^The sum of 102 and 7 is 109.*",                       None,                                          "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M",       None),
+    (128, 'deepseek',  "^The sum of 102 and 7 is 109[\\s\\S]*",                        None,                                          "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M",       None),
+    (128,  None,        "^The sum of 102 and 7 is 109[\\s\\S]*",                       None,                                          "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M",       None),
 
-    (1024, 'deepseek',  "To find the sum of.*",                                 "I need to calculate the sum of 102 and 7.*",  "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
-    (1024, 'none',      "^I need[\\s\\S]*?</think>\n?To find.*",                None,                                          "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
+    (1024, 'deepseek',  "To find the sum of[\\s\\S]*",                                 "I need to calculate the sum of 102 and 7[\\s\\S]*",  "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
+    (1024, 'none',      "^(<think>\\s*)?I need[\\s\\S]*?</think>\\s*To find[\\s\\S]*",                None,                                          "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
 
-    (1024, 'deepseek',  "To find the sum of.*",                                 "First, I [\\s\\S]*",                          "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
+    (1024, 'deepseek',  "To find the sum of[\\s\\S]*",                                 "First, I [\\s\\S]*",                          "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
 ])
 def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] | None, expect_content: str | None, expect_reasoning_content: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
     global server
@@ -464,7 +503,7 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none']
     elif isinstance(template_override, str):
         server.chat_template = template_override
     server.start(timeout_seconds=TIMEOUT_SERVER_START)
-    res = server.make_request("POST", "/chat/completions", data={
+    res = server.make_request("POST", "/v1/chat/completions", data={
         "max_tokens": n_predict,
         "messages": [
             {"role": "user", "content": "What's the sum of 102 and 7?"},
@@ -476,7 +515,7 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none']
 
     content = choice["message"].get("content")
     if expect_content is None:
-        assert content is None, f'Expected no content in {choice["message"]}'
+        assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
     else:
         assert re.match(expect_content, content), f'Expected {expect_content}, got {content}'
 
@@ -488,46 +527,46 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none']
 
 
 @pytest.mark.slow
-@pytest.mark.parametrize("expected_arguments_override,hf_repo,template_override", [
-    (None,                 "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
-    # (None,                 "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", "chatml"),
+@pytest.mark.parametrize("hf_repo,template_override", [
+    ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
 
-    (None,                 "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M",      None),
-    (None,                 "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M",      "chatml"),
+    ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M",      None),
+    ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M",      "chatml"),
 
-    (None,                 "bartowski/functionary-small-v3.2-GGUF:Q8_0",       ("meetkai-functionary-medium-v3.2", None)),
-    (None,                 "bartowski/functionary-small-v3.2-GGUF:Q8_0",       "chatml"),
+    ("bartowski/functionary-small-v3.2-GGUF:Q8_0",       ("meetkai-functionary-medium-v3.2", None)),
+    ("bartowski/functionary-small-v3.2-GGUF:Q8_0",       "chatml"),
 
-    ('{"code":"print("}',  "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
-    (None,                 "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
+    # ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
+    ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
 
-    (None,                 "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M",      ("meta-llama-Llama-3.2-3B-Instruct", None)),
-    (None,                 "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M",      "chatml"),
+    ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M",      ("meta-llama-Llama-3.2-3B-Instruct", None)),
+    ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M",      None),
 
-    ('{"code":"print("}',  "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M",      ("meta-llama-Llama-3.2-3B-Instruct", None)),
-    (None,                 "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M",      "chatml"),
+    ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M",      ("meta-llama-Llama-3.2-3B-Instruct", None)),
+    ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M",      None),
 
-    (None,                 "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M",        None),
-    (None,                 "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M",        "chatml"),
+    ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M",        None),
+    ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M",        "chatml"),
 
-    (None,                 "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M",    ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
-    (None,                 "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M",    "chatml"),
+    ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M",    ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
+    ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M",    "chatml"),
 
-    (None,                 "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M",      ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
-    (None,                 "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M",      "chatml"),
+    ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M",      ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
+    ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M",      "chatml"),
 
-    (None,                 "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
-    (None,                 "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
+    ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
+    ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
 
-    # Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it.
-    (None,                 "bartowski/gemma-2-2b-it-GGUF:Q4_K_M",              None),
+    ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M",              None),
+    ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M",              "chatml"),
 ])
-def test_hello_world(expected_arguments_override: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
+def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | None] | None):
     global server
+    n_predict = 512 # High because of DeepSeek R1
     server.n_slots = 1
     server.jinja = True
     server.n_ctx = 8192
-    server.n_predict = 512 # High because of DeepSeek R1
+    server.n_predict = n_predict
     server.model_hf_repo = hf_repo
     server.model_hf_file = None
     if isinstance(template_override, tuple):
@@ -537,31 +576,28 @@ def test_hello_world(expected_arguments_override: str | None, hf_repo: str, temp
     elif isinstance(template_override, str):
         server.chat_template = template_override
     server.start(timeout_seconds=TIMEOUT_SERVER_START)
-    res = server.make_request("POST", "/chat/completions", data={
-        "max_tokens": 256,
+
+    do_test_hello_world(server, max_tokens=n_predict)
+
+
+def do_test_hello_world(server: ServerProcess, **kwargs):
+    res = server.make_request("POST", "/v1/chat/completions", data={
         "messages": [
-            {"role": "system", "content": "You are a coding assistant."},
+            {"role": "system", "content": "You are a tool-calling agent."},
             {"role": "user", "content": "say hello world with python"},
         ],
         "tools": [PYTHON_TOOL],
-        # Note: without these greedy params, Functionary v3.2 writes `def hello_world():\n    print("Hello, World!")\nhello_world()` which is correct but a pain to test.
-        "temperature": 0.0,
-        "top_k": 1,
-        "top_p": 1.0,
+        **kwargs,
     }, timeout=TIMEOUT_HTTP_REQUEST)
     assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
     choice = res.body["choices"][0]
     tool_calls = choice["message"].get("tool_calls")
     assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
     tool_call = tool_calls[0]
-    assert choice["message"].get("content") is None, f'Expected no content in {choice["message"]}'
+    # assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
     assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"]
-    actual_arguments = tool_call["function"]["arguments"]
-    if expected_arguments_override is not None:
-        assert actual_arguments == expected_arguments_override
-    else:
-        actual_arguments = json.loads(actual_arguments)
-        assert 'code' in actual_arguments, f"code not found in {json.dumps(actual_arguments)}"
-        code = actual_arguments["code"]
-        assert isinstance(code, str), f"Expected code to be a string, got {type(code)}: {json.dumps(code)}"
-        assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', code), f'Expected hello world, got {code}'
+    actual_arguments = json.loads(tool_call["function"]["arguments"])
+    assert 'code' in actual_arguments, f"code not found in {json.dumps(actual_arguments)}"
+    code = actual_arguments["code"]
+    assert isinstance(code, str), f"Expected code to be a string, got {type(code)}: {json.dumps(code)}"
+    assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', code), f'Expected hello world, got {code}'
index f32a439f6a7daf1df88ae7a638b62034c61d521b..ec2d8ec55853c36526d9746e8aa41c92934420d8 100644 (file)
@@ -67,6 +67,9 @@ class ServerProcess:
     id_slot: int | None = None
     cache_prompt: bool | None = None
     n_slots: int | None = None
+    ctk: str | None = None
+    ctv: str | None = None
+    fa: bool | None = None
     server_continuous_batching: bool | None = False
     server_embeddings: bool | None = False
     server_reranking: bool | None = False
@@ -84,6 +87,7 @@ class ServerProcess:
     reasoning_format: Literal['deepseek', 'none'] | None = None
     chat_template: str | None = None
     chat_template_file: str | None = None
+    server_path: str | None = None
 
     # session variables
     process: subprocess.Popen | None = None
@@ -97,7 +101,9 @@ class ServerProcess:
             self.server_port = int(os.environ["PORT"])
 
     def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None:
-        if "LLAMA_SERVER_BIN_PATH" in os.environ:
+        if self.server_path is not None:
+            server_path = self.server_path
+        elif "LLAMA_SERVER_BIN_PATH" in os.environ:
             server_path = os.environ["LLAMA_SERVER_BIN_PATH"]
         elif os.name == "nt":
             server_path = "../../../build/bin/Release/llama-server.exe"
@@ -151,6 +157,12 @@ class ServerProcess:
             server_args.extend(["--ctx-size", self.n_ctx])
         if self.n_slots:
             server_args.extend(["--parallel", self.n_slots])
+        if self.ctk:
+            server_args.extend(["-ctk", self.ctk])
+        if self.ctv:
+            server_args.extend(["-ctv", self.ctv])
+        if self.fa is not None:
+            server_args.append("-fa")
         if self.n_predict:
             server_args.extend(["--n-predict", self.n_predict])
         if self.slot_save_path:
@@ -184,7 +196,7 @@ class ServerProcess:
             server_args.extend(["--chat-template-file", self.chat_template_file])
 
         args = [str(arg) for arg in [server_path, *server_args]]
-        print(f"bench: starting server with: {' '.join(args)}")
+        print(f"tests: starting server with: {' '.join(args)}")
 
         flags = 0
         if "nt" == os.name:
@@ -215,6 +227,10 @@ class ServerProcess:
                     return  # server is ready
             except Exception as e:
                 pass
+            # Check if process died
+            if self.process.poll() is not None:
+                raise RuntimeError(f"Server process died with return code {self.process.returncode}")
+
             print(f"Waiting for server to start...")
             time.sleep(0.5)
         raise TimeoutError(f"Server did not start within {timeout_seconds} seconds")
index 144d914c298efe2d8afcababbc62b8e971ec3b55..393e3927c7fd6bab6bcd6acc380aecfd69ddfe18 100644 (file)
@@ -607,6 +607,7 @@ static json oaicompat_completion_params_parse(
     inputs.use_jinja             = use_jinja;
     inputs.parallel_tool_calls   = json_value(body, "parallel_tool_calls", false);
     inputs.extract_reasoning     = reasoning_format != COMMON_REASONING_FORMAT_NONE;
+    inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
     if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) {
         throw std::runtime_error("Cannot use custom grammar constraints with tools.");
     }
@@ -620,10 +621,7 @@ static json oaicompat_completion_params_parse(
     llama_params["grammar_lazy"]     = chat_params.grammar_lazy;
     auto grammar_triggers = json::array();
     for (const auto & trigger : chat_params.grammar_triggers) {
-        grammar_triggers.push_back({
-            {"word", trigger.word},
-            {"at_start", trigger.at_start},
-        });
+        grammar_triggers.push_back(trigger.to_json<json>());
     }
     llama_params["grammar_triggers"] = grammar_triggers;
     llama_params["preserved_tokens"] = chat_params.preserved_tokens;
index ee6e73915f136aa999f74ce4020b58e17d1d9028..d62792c0a67609ec39e2f29e3e165411c3e9dd8a 100644 (file)
@@ -1205,17 +1205,29 @@ extern "C" {
                           const char * grammar_str,
                           const char * grammar_root);
 
-    /// @details Lazy grammar sampler, introduced in https://github.com/ggml-org/llama.cpp/pull/9639
-    /// @param trigger_words A list of words that will trigger the grammar sampler. This may be updated to a loose regex syntax (w/ ^) in a near future.
-    /// @param trigger_tokens A list of tokens that will trigger the grammar sampler.
-    LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy(
+    DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy(
             const struct llama_vocab * vocab,
                           const char * grammar_str,
                           const char * grammar_root,
                          const char ** trigger_words,
                                 size_t num_trigger_words,
                    const llama_token * trigger_tokens,
-                                size_t num_trigger_tokens);
+                                size_t num_trigger_tokens),
+        "use llama_sampler_init_grammar_lazy_patterns instead");
+
+
+    /// @details Lazy grammar sampler, introduced in https://github.com/ggml-org/llama.cpp/pull/9639
+    /// @param trigger_patterns A list of patterns that will trigger the grammar sampler. Pattern will be matched from the start of the generation output, and grammar sampler will be fed content starting from its first match group.
+    /// @param trigger_tokens A list of tokens that will trigger the grammar sampler. Grammar sampler will be fed content starting from the trigger token included.
+    LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy_patterns(
+        const struct llama_vocab * vocab,
+                      const char * grammar_str,
+                      const char * grammar_root,
+                     const char ** trigger_patterns,
+                            size_t num_trigger_patterns,
+               const llama_token * trigger_tokens,
+                            size_t num_trigger_tokens);
+
 
     /// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first.
     LLAMA_API struct llama_sampler * llama_sampler_init_penalties(
index 72c30d1e1e08eccc807ea750a5a8b9ff924c6246..e4fd104fc9fe6d0c17293cfa98a0cdf9c44f45af 100644 (file)
@@ -9,7 +9,7 @@ These templates can be updated with the following commands:
 ./scripts/get_chat_template.py deepseek-ai/DeepSeek-R1-Distill-Qwen-32B      > models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja
 ./scripts/get_chat_template.py fireworks-ai/llama-3-firefunction-v2          > models/templates/fireworks-ai-llama-3-firefunction-v2.jinja
 ./scripts/get_chat_template.py google/gemma-2-2b-it                          > models/templates/google-gemma-2-2b-it.jinja
-./scripts/get_chat_template.py meetkai/functionary-medium-v3.                > models/templates/meetkai-functionary-medium-v3.jinja
+./scripts/get_chat_template.py meetkai/functionary-medium-v3.1               > models/templates/meetkai-functionary-medium-v3.1.jinja
 ./scripts/get_chat_template.py meetkai/functionary-medium-v3.2               > models/templates/meetkai-functionary-medium-v3.2.jinja
 ./scripts/get_chat_template.py meta-llama/Llama-3.1-8B-Instruct              > models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja
 ./scripts/get_chat_template.py meta-llama/Llama-3.2-3B-Instruct              > models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja
index 9e190ae27de389fd3dc61ef78d2074fed046abd5..f2a18d62879b4e37249b566d6d85fd9485fb20e2 100644 (file)
@@ -10,3 +10,4 @@
 -r ./requirements/requirements-convert_hf_to_gguf_update.txt
 -r ./requirements/requirements-convert_llama_ggml_to_gguf.txt
 -r ./requirements/requirements-convert_lora_to_gguf.txt
+-r ./requirements/requirements-tool_bench.txt
index 94de59d7e18603b8be9dbb96ce289956b25a2c61..439db888636dc3d314bd2cdee60d5852a9b34812 100644 (file)
@@ -10,3 +10,4 @@
 -r ./requirements-convert_hf_to_gguf_update.txt
 -r ./requirements-convert_legacy_llama.txt
 -r ./requirements-convert_llama_ggml_to_gguf.txt
+-r ./requirements-tool_bench.txt
diff --git a/requirements/requirements-tool_bench.txt b/requirements/requirements-tool_bench.txt
new file mode 100644 (file)
index 0000000..b94521f
--- /dev/null
@@ -0,0 +1,12 @@
+aiohttp~=3.9.3
+pytest~=8.3.3
+huggingface_hub~=0.23.2
+matplotlib~=3.10.0
+numpy~=1.26.4
+openai~=1.55.3
+pandas~=2.2.3
+prometheus-client~=0.20.0
+requests~=2.32.3
+wget~=3.2
+typer~=0.15.1
+seaborn~=0.13.2
index 05690b1385468b6db1ca20ee1f4fc1e9d368d311..e6775bfc5867c9ecef886dd7783796872cae786e 100755 (executable)
@@ -75,7 +75,7 @@ if __name__ == '__main__':
         logging.info(f'  - {m.hf_repo} / {m.hf_file}')
 
     cli_path = os.environ.get(
-        'LLAMA_SERVER_BIN_PATH',
+        'LLAMA_CLI_BIN_PATH',
         os.path.join(
             os.path.dirname(__file__),
             '../build/bin/Release/llama-cli.exe' if os.name == 'nt' else '../build/bin/llama-cli'))
diff --git a/scripts/tool_bench.py b/scripts/tool_bench.py
new file mode 100755 (executable)
index 0000000..0f406bc
--- /dev/null
@@ -0,0 +1,368 @@
+#!/usr/bin/env uv run
+'''
+    Simplistic tool call benchmarks for llama-server and ollama.
+
+    Essentially runs the tests at server/examples/server/tests/unit/test_tool_call.py N times, at different temperatures and on different backends (current llama-server, baseline llama-server and ollama),
+    and plots the results of multiple runs (from same .jsonl file or multiple ones) as a success rate heatmap.
+
+    Simple usage example:
+
+        cmake -B build -DLLAMA_CURL=1 && cmake --build build --config Release -j -t llama-server
+
+        export LLAMA_SERVER_BIN_PATH=$PWD/build/bin/llama-server
+        export LLAMA_CACHE=${LLAMA_CACHE:-$HOME/Library/Caches/llama.cpp}
+
+        ./scripts/tool_bench.py run --n 30 --temp -1 --temp 0 --temp 1 --model "Qwen 2.5 1.5B Q4_K_M"      --output qwen1.5b.jsonl  --hf bartowski/Qwen2.5-1.5B-Instruct-GGUF      --ollama qwen2.5:1.5b-instruct-q4_K_M
+        ./scripts/tool_bench.py run --n 30 --temp -1 --temp 0 --temp 1 --model "Qwen 2.5 Coder 7B Q4_K_M"  --output qwenc7b.jsonl   --hf bartowski/Qwen2.5-Coder-7B-Instruct-GGUF  --ollama qwen2.5-coder:7b
+
+        ./scripts/tool_bench.py plot *.jsonl                         # Opens window w/ heatmap
+        ./scripts/tool_bench.py plot qwen*.jsonl  --output qwen.png  # Saves heatmap to qwen.png
+
+    (please see ./scripts/tool_bench.sh for a more complete example)
+'''
+# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+#     "pytest",
+#     "pandas",
+#     "matplotlib",
+#     "seaborn",
+#     "requests",
+#     "wget",
+#     "typer",
+# ]
+# ///
+from contextlib import contextmanager
+from pathlib import Path
+import re
+from statistics import mean, median
+from typing import Annotated, Dict, List, Optional, Tuple
+import atexit
+import json
+import logging
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import seaborn as sns
+import subprocess
+import sys
+import time
+import typer
+
+sys.path.insert(0, Path(__file__).parent.parent.as_posix())
+if True:
+    from examples.server.tests.utils import ServerProcess
+    from examples.server.tests.unit.test_tool_call import TIMEOUT_SERVER_START, do_test_calc_result, do_test_hello_world, do_test_weather
+
+
+@contextmanager
+def scoped_server(sp: ServerProcess):
+    def stop():
+        nonlocal sp
+        if sp is not None:
+            sp.stop()
+            sp = None # type: ignore
+    atexit.register(stop)
+    yield sp
+    stop()
+
+
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
+
+app = typer.Typer()
+
+
+@app.command()
+def plot(files: List[Path], output: Optional[Path] = None, test_regex: Optional[str] = None, server_regex: Optional[str] = None):
+
+    lines: List[Dict] = []
+    for file in files:
+        if not file.exists():
+            logger.error(f"File not found: {file}")
+            continue
+
+        try:
+            with file.open() as f:
+                raw_data = f.read()
+            logger.info(f"Reading {file} ({len(raw_data)} bytes)")
+
+            for line_num, line in enumerate(raw_data.split('\n'), 1):
+                line = line.strip()
+                if not line:
+                    continue
+                try:
+                    record = json.loads(line)
+                    lines.append(record)
+                except json.JSONDecodeError as e:
+                    logger.warning(f"Invalid JSON at {file}:{line_num} - {e}")
+        except Exception as e:
+            logger.error(f"Error processing {file}: {e}")
+
+    if not lines:
+        raise Exception("No valid data was loaded")
+
+    data_dict: Dict[Tuple, float] = {}
+    models: List[str] = []
+    temps = set()
+    tests = set()
+    server_names = set()
+    total_counts = set()
+    for rec in lines:
+        try:
+            model = rec["model"]
+            temp = rec["temp"]
+            server_name = rec["server_name"]
+            test = rec["test"]
+            success = rec["success_ratio"]
+            success_count = rec["success_count"]
+            failure_count = rec["failure_count"]
+            total_count = success_count + failure_count
+            total_counts.add(total_count)
+
+            if test_regex and not re.search(test_regex, test):
+                continue
+
+            if server_regex and not re.search(server_regex, server_name):
+                continue
+
+            data_dict[(model, temp, server_name, test)] = success
+
+            if model not in models:
+                models.append(model)
+            temps.add(temp)
+            tests.add(test)
+            server_names.add(server_name)
+
+        except KeyError as e:
+            logger.warning(f"Missing required field in record: {e}")
+
+    if len(total_counts) > 1:
+        logger.warning(f"Total counts are not consistent: {total_counts}")
+
+    # Sort the collected values
+    temps = list(sorted(temps, key=lambda x: x if x is not None else -1))
+    tests = list(sorted(tests))
+    server_names = list(sorted(server_names))
+
+    logger.info(f"Processed {len(lines)} lines")
+    logger.info(f"Found {len(data_dict)} valid data points")
+    logger.info(f"Models: {models}")
+    logger.info(f"Temperatures: {temps}")
+    logger.info(f"Tests: {tests}")
+    logger.info(f"Servers: {server_names}")
+
+    matrix: list[list[float]] = []
+    index: list[str] = []
+
+    all_cols = [
+        (server_name, test)
+        for server_name in server_names
+        for test in tests
+    ]
+    for model in models:
+        for temp in temps:
+            index.append(f"{model} @ {temp}")
+            row_vals = [
+                data_dict.get((model, temp, server_name, test), np.nan)
+                for server_name, test in all_cols
+            ]
+            matrix.append(row_vals)
+
+    columns: list[str] = [f"{server_name}\n{test}" for server_name, test in all_cols]
+
+    df = pd.DataFrame(matrix, index=np.array(index), columns=np.array(columns))
+
+    plt.figure(figsize=(12, 6))
+
+    sns.heatmap(
+        df, annot=True, cmap="RdYlGn", vmin=0.0, vmax=1.0, cbar=True, fmt=".2f", center=0.5, square=True, linewidths=0.5,
+        cbar_kws={"label": "Success Ratio"},
+    )
+
+    plt.title(f"Tool Call Bench (n = {str(min(total_counts)) if len(total_counts) == 1 else f'{min(total_counts)}-{max(total_counts)}'})\nSuccess Ratios by Server & Test", pad=20)
+    plt.xlabel("Server & Test", labelpad=10)
+    plt.ylabel("Model @ Temperature", labelpad=10)
+
+    plt.xticks(rotation=45, ha='right')
+    plt.yticks(rotation=0)
+
+    plt.tight_layout()
+
+    if output:
+        plt.savefig(output, dpi=300, bbox_inches='tight')
+        logger.info(f"Plot saved to {output}")
+    else:
+        plt.show()
+
+
+@app.command()
+def run(
+    output: Annotated[Path, typer.Option(help="Output JSON file")],
+    model: Annotated[Optional[str], typer.Option(help="Name of the model to test (server agnostic)")] = None,
+    hf: Annotated[Optional[str], typer.Option(help="GGUF huggingface model repo id (+ optional quant) to test w/ llama-server")] = None,
+    chat_template: Annotated[Optional[str], typer.Option(help="Chat template override for llama-server")] = None,
+    ollama: Annotated[Optional[str], typer.Option(help="Ollama model tag to test")] = None,
+    llama_baseline: Annotated[Optional[str], typer.Option(help="llama-server baseline binary path to use as baseline")] = None,
+    n: Annotated[int, typer.Option(help="Number of times to run each test")] = 10,
+    temp: Annotated[Optional[List[float]], typer.Option(help="Set of temperatures to test")] = None,
+    top_p: Annotated[Optional[float], typer.Option(help="top_p")] = None,
+    top_k: Annotated[Optional[int], typer.Option(help="top_k")] = None,
+    ctk: Annotated[Optional[str], typer.Option(help="ctk")] = None,
+    ctv: Annotated[Optional[str], typer.Option(help="ctv")] = None,
+    fa: Annotated[Optional[bool], typer.Option(help="fa")] = None,
+    seed: Annotated[Optional[int], typer.Option(help="Random seed")] = None,
+    port: Annotated[int, typer.Option(help="llama-server port")] = 8084,
+    force: Annotated[bool, typer.Option(help="Force overwrite of output file")] = False,
+    append: Annotated[bool, typer.Option(help="Append to output file")] = False,
+
+    test_hello_world: Annotated[bool, typer.Option(help="Whether to run the hello world test")] = True,
+    test_weather: Annotated[bool, typer.Option(help="Whether to run the weather test")] = True,
+    test_calc_result: Annotated[bool, typer.Option(help="Whether to run the calc result test")] = False,
+):
+    # Check only one of output and append
+
+    n_predict = 512 # High because of DeepSeek R1
+    # n_ctx = 8192
+    n_ctx = 2048
+
+    assert force or append or not output.exists(), f"Output file already exists: {output}; use --force to overwrite"
+
+    with output.open('a' if append else 'w') as output_file:
+
+        def run(server: ServerProcess, *, server_name: str, model_id: str, temp: Optional[float] = None, output_kwargs={}, request_kwargs={}):
+            request_kwargs = {**request_kwargs}
+            if temp is not None:
+                request_kwargs['temperature'] = temp
+            if top_p is not None:
+                request_kwargs['top_p'] = top_p
+            if top_k is not None:
+                request_kwargs['top_k'] = top_k
+            if seed is not None:
+                request_kwargs['seed'] = seed
+
+            request_kwargs['cache_prompt'] = False
+
+            tests = {}
+            if test_hello_world:
+                tests["hello world"] = lambda server: do_test_hello_world(server, **request_kwargs)
+            if test_weather:
+                tests["weather"] = lambda server: do_test_weather(server, **request_kwargs)
+            if test_calc_result:
+                tests["calc result"] = lambda server: do_test_calc_result(server, None, 512, **request_kwargs)
+
+            for test_name, test in tests.items():
+                success_count = 0
+                failure_count = 0
+                failures = []
+                success_times = []
+                failure_times = []
+                logger.info(f"Running {test_name} ({server_name}, {model}): ")
+                for i in range(n):
+                    start_time = time.time()
+
+                    def elapsed():
+                        return time.time() - start_time
+
+                    try:
+                        test(server)
+                        success_times.append(elapsed())
+                        success_count += 1
+                        logger.info('success')
+                    except Exception as e:
+                        logger.error(f'failure: {e}')
+                        failure_count += 1
+                        failure_times.append(elapsed())
+                        failures.append(str(e))
+                        # import traceback
+                        # traceback.print_exc()
+                output_file.write(json.dumps({**output_kwargs, **dict(
+                    model=model,
+                    server_name=server_name,
+                    model_id=model_id,
+                    test=test_name,
+                    temp=t,
+                    top_p=top_p,
+                    top_k=top_k,
+                    ctk=ctk,
+                    ctv=ctv,
+                    seed=seed,
+                    success_ratio=float(success_count) / n,
+                    avg_time=mean(success_times + failure_times),
+                    median_time=median(success_times + failure_times),
+                    success_count=success_count,
+                    success_times=success_times,
+                    failure_count=failure_count,
+                    failure_times=failure_times,
+                    failures=list(set(failures)),
+                )}) + '\n')
+                output_file.flush()
+
+        for t in [None] if temp is None else [t if t >= 0 else None for t in temp]:
+            if hf is not None:
+
+                servers: list[Tuple[str, Optional[str]]] = [('llama-server', None)]
+                if llama_baseline is not None:
+                    servers.append(('llama-server (baseline)', llama_baseline))
+
+                for server_name, server_path in servers:
+                    server = ServerProcess()
+                    server.n_ctx = n_ctx
+                    server.n_slots = 1
+                    server.jinja = True
+                    server.ctk = ctk
+                    server.ctv = ctv
+                    server.fa = fa
+                    server.n_predict = n_predict
+                    server.model_hf_repo = hf
+                    server.model_hf_file = None
+                    server.chat_template = chat_template
+                    server.server_path = server_path
+                    if port is not None:
+                        server.server_port = port
+                    # server.debug = True
+
+                    with scoped_server(server):
+                        server.start(timeout_seconds=TIMEOUT_SERVER_START)
+                        for ignore_chat_grammar in [False]:
+                            run(
+                                server,
+                                server_name=server_name,
+                                model_id=hf,
+                                temp=t,
+                                output_kwargs=dict(
+                                    chat_template=chat_template,
+                                ),
+                                request_kwargs=dict(
+                                    ignore_chat_grammar=ignore_chat_grammar,
+                                ),
+                            )
+
+            if ollama is not None:
+                server = ServerProcess()
+                server.server_port = 11434
+                server.server_host = "localhost"
+                subprocess.check_call(["ollama", "pull", ollama])
+
+                with scoped_server(server):
+                    run(
+                        server,
+                        server_name="ollama",
+                        model_id=ollama,
+                        temp=t,
+                        output_kwargs=dict(
+                            chat_template=None,
+                        ),
+                        request_kwargs=dict(
+                            model=ollama,
+                            max_tokens=n_predict,
+                            num_ctx = n_ctx,
+                        ),
+                    )
+
+
+if __name__ == "__main__":
+    app()
diff --git a/scripts/tool_bench.sh b/scripts/tool_bench.sh
new file mode 100755 (executable)
index 0000000..6c7616a
--- /dev/null
@@ -0,0 +1,66 @@
+#!/bin/bash
+set -euo pipefail
+
+cmake --build build -j
+
+export LLAMA_CACHE=${LLAMA_CACHE:-$HOME/Library/Caches/llama.cpp}
+export LLAMA_SERVER_BIN_PATH=$PWD/build/bin/llama-server
+
+if [ ! -x "$LLAMA_SERVER_BIN_PATH" ]; then
+    echo "Could not find llama-server binary at $LLAMA_SERVER_BIN_PATH"
+    exit 1
+fi
+if [ ! -d "$LLAMA_CACHE" ]; then
+    echo "Could not find llama cache at $LLAMA_CACHE, please set LLAMA_CACHE explicitly."
+    exit 1
+fi
+
+export ARGS=(
+    --llama-baseline="$(which llama-server)"
+    --n 30
+    --temp -1  # Leaves temperature parameter unset (use the server's default, e.g. 0.6 for ollama)
+    --temp 0
+    --temp 0.5
+    --temp 0.75
+    --temp 1
+    --temp 1.5
+    --temp 2
+    --temp 5
+    "$@"
+)
+
+./scripts/tool_bench.py run ${ARGS[@]} --model "Qwen 2.5 Coder 0.5B Q4_K_M"           --output ../qwenc0.5b.jsonl --hf bartowski/Qwen2.5-Coder-0.5B-Instruct-GGUF:Q4_K_M --ollama qwen2.5-coder:0.5b-instruct-q4_K_M
+./scripts/tool_bench.py run ${ARGS[@]} --model "Qwen 2.5 Coder 1.5B Q4_K_M"           --output ../qwenc1.5b.jsonl --hf bartowski/Qwen2.5-Coder-1.5B-Instruct-GGUF:Q4_K_M --ollama qwen2.5-coder:1.5b-instruct-q4_K_M
+./scripts/tool_bench.py run ${ARGS[@]} --model "Qwen 2.5 Coder 3B Q4_K_M"             --output ../qwenc3b.jsonl   --hf bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M   --ollama qwen2.5-coder:3b-instruct-q4_K_M
+./scripts/tool_bench.py run ${ARGS[@]} --model "Qwen 2.5 Coder 7B Q4_K_M"             --output ../qwenc7b.jsonl   --hf bartowski/Qwen2.5-Coder-7B-Instruct-GGUF:Q4_K_M   --ollama qwen2.5-coder:7b-instruct-q4_K_M
+./scripts/tool_bench.py run ${ARGS[@]} --model "Qwen 2.5 Coder 32B Q4_K_M"            --output ../qwenc32b.jsonl  --hf bartowski/Qwen2.5-Coder-32B-Instruct-GGUF:Q4_K_M  --ollama qwen2.5-coder:32B-instruct-q4_K_M
+./scripts/tool_bench.py run ${ARGS[@]} --model "Qwen 2.5 1.5B Q4_K_M"                 --output ../qwen1.5b.jsonl  --hf bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M       --ollama qwen2.5:1.5b-instruct-q4_K_M
+./scripts/tool_bench.py run ${ARGS[@]} --model "Qwen 2.5 3B Q4_K_M"                   --output ../qwen3b.jsonl    --hf bartowski/Qwen2.5-3B-Instruct-GGUF:Q4_K_M         --ollama qwen2.5:3b-instruct-q4_K_M
+./scripts/tool_bench.py run ${ARGS[@]} --model "Qwen 2.5 7B Q4_K_M"                   --output ../qwen7b.jsonl    --hf bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M         --ollama qwen2.5:7b-instruct-q4_K_M
+
+./scripts/tool_bench.py run ${ARGS[@]} --model "Llama 3.2 Instruct 1B Q4_K_M"         --output ../llama1b.jsonl   --hf bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M       --ollama llama3.2:1b-instruct-q4_K_M
+./scripts/tool_bench.py run ${ARGS[@]} --model "Llama 3.2 Instruct 3B Q4_K_M"         --output ../llama3b.jsonl   --hf bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M       --ollama llama3.2:3b-instruct-q4_K_M
+./scripts/tool_bench.py run ${ARGS[@]} --model "Llama 3.1 Instruct 8B Q4_K_M"         --output ../llama8b.jsonl   --hf bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M  --ollama llama3.1:8b-instruct-q4_K_M
+./scripts/tool_bench.py run ${ARGS[@]} --model "Llama 3.3 70B Q4_K_M"                 --output ../llama70b.jsonl  --hf bartowski/Llama-3.3-70B-Instruct-GGUF:Q4_K_M
+
+./scripts/tool_bench.py run ${ARGS[@]} --model "Mistral Nemo Q4_K_M"                  --output ../nemo.jsonl      --hf bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M  --ollama mistral-nemo:12b-instruct-2407-q4_K_M
+
+./scripts/tool_bench.py run ${ARGS[@]} --model "Hermes 3 Llama 3.1 8B Q4_K_M"         --output ../hermes3.jsonl   --hf bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M       --ollama hermes3:8b-llama3.1-q4_K_M  --chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use )
+./scripts/tool_bench.py run ${ARGS[@]} --model "Hermes 2 Pro Llama 3 8B Q4_K_M"       --output ../hermes2.jsonl   --hf bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M     --ollama hermes2:8b-llama3-q4_K_M    --chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B tool_use )
+
+./scripts/tool_bench.py run ${ARGS[@]} --model "Functionary Small V3.2 Q4_K_M"        --output ../funct3.2.jsonl  --hf bartowski/functionary-small-v3.2-GGUF:Q4_K_M
+./scripts/tool_bench.py run ${ARGS[@]} --model "FireFunction V2 IQ1_M"                --output ../firef2.jsonl    --hf bartowski/firefunction-v2-GGUF:IQ1_M                                                   --chat-template-file <( python scripts/get_chat_template.py fireworks-ai/llama-3-firefunction-v2 tool_use )
+
+./scripts/tool_bench.py run ${ARGS[@]} --model "Command R7B 12-2024 Q6_K_L"           --output ../c4ai.jsonl      --hf bartowski/c4ai-command-r7b-12-2024-GGUF:Q6_K_L                                         --chat-template-file <( python scripts/get_chat_template.py CohereForAI/c4ai-command-r7b-12-2024 tool_use )
+
+./scripts/tool_bench.py run ${ARGS[@]} --model "Gemma 2 2B Q8_0"                      --output ../gemma2.jsonl    --hf bartowski/gemma-2-2b-it-GGUF:Q8_0
+./scripts/tool_bench.py run ${ARGS[@]} --model "Phi 4 Instruct Q4_K_M"                --output ../phi4.jsonl      --hf bartowski/phi-4-GGUF:Q4_K_M                       # --ollama phi4
+./scripts/tool_bench.py run ${ARGS[@]} --model "Phi 3.5 Mini Instruct Q4_K_M"         --output ../phi3.5.jsonl    --hf bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M       # --ollama phi3.5:3.8b-mini-instruct-q4_K_M
+
+# ./scripts/tool_bench.py run ${ARGS[@]} --model "DeepSeek R1 Distill Qwen 7B Q6_K_L"   --output ../dsqw7.jsonl     --hf bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q6_K_L --chat-template-file <( python scripts/get_chat_template.py NousResearch/DeepSeek-R1-Distill-Qwen-7B tool_use )
+# ./scripts/tool_bench.py run ${ARGS[@]} --model "DeepSeek R1 Distill Qwen 32B Q4_K_M"  --output ../dsqw32.jsonl    --hf bartowski/DeepSeek-R1-Distill-Qwen-32B-GGUF:Q4_K_M --chat-template-file <( python scripts/get_chat_template.py NousResearch/DeepSeek-R1-Distill-Qwen-32B tool_use )
+
+
+for f in ../*.jsonl; do
+    ./scripts/tool_bench.py plot "$f" --output ${f%.jsonl}.png || true
+done
index 98af1ba3976aefe26906d6296530280a5ebe1fc5..973b47ae063b08a6faec73294a6ab05d78e4ac56 100644 (file)
@@ -969,7 +969,7 @@ struct llama_grammar * llama_grammar_init_impl(
         /* .awaiting_trigger = */ false,
         /* .trigger_buffer = */   "",
         /* .trigger_tokens   = */ {},
-        /* .trigger_words    = */ {},
+        /* .trigger_patterns    = */ {},
     };
 }
 
@@ -978,19 +978,15 @@ struct llama_grammar * llama_grammar_init_impl(
                       const char * grammar_str,
                       const char * grammar_root,
                               bool lazy,
-                     const char ** trigger_words,
-                            size_t num_trigger_words,
+                     const char ** trigger_patterns,
+                            size_t num_trigger_patterns,
                const llama_token * trigger_tokens,
                             size_t num_trigger_tokens) {
     llama_grammar_parser parser;
 
     // if there is a grammar, parse it
-    if (!parser.parse(grammar_str)) {
-        return nullptr;
-    }
-
-    // will be empty (default) if there are parse errors
-    if (parser.rules.empty()) {
+    // rules will be empty (default) if there are parse errors
+    if (!parser.parse(grammar_str) || parser.rules.empty()) {
         fprintf(stderr, "%s: failed to parse grammar\n", __func__);
         return nullptr;
     }
@@ -1054,14 +1050,16 @@ struct llama_grammar * llama_grammar_init_impl(
     } while (true);
 
     std::vector<llama_token>    vec_trigger_tokens;
-    std::vector<std::string> vec_trigger_words;
+    std::vector<llama_grammar_trigger_pattern> vec_trigger_patterns;
     for (size_t i = 0; i < num_trigger_tokens; i++) {
         GGML_ASSERT(trigger_tokens != nullptr);
         vec_trigger_tokens.push_back(trigger_tokens[i]);
     }
-    for (size_t i = 0; i < num_trigger_words; i++) {
-        GGML_ASSERT(trigger_words != nullptr);
-        vec_trigger_words.push_back(trigger_words[i]);
+    for (size_t i = 0; i < num_trigger_patterns; i++) {
+        GGML_ASSERT(trigger_patterns != nullptr);
+        auto & trigger = vec_trigger_patterns.emplace_back();
+        trigger.pattern = trigger_patterns[i];
+        trigger.regex = std::regex(trigger.pattern);
     }
 
     // Important: vec_rules has to be moved here, not copied, because stacks contains
@@ -1076,7 +1074,7 @@ struct llama_grammar * llama_grammar_init_impl(
         /* .awaiting_trigger = */ lazy,
         /* .trigger_buffer = */   "",
         std::move(vec_trigger_tokens),
-        std::move(vec_trigger_words),
+        std::move(vec_trigger_patterns),
     };
 }
 
@@ -1089,7 +1087,7 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) {
 }
 
 struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
-    llama_grammar * result = new llama_grammar {
+    auto * result = new llama_grammar {
         grammar.vocab,
         grammar.rules,
         grammar.stacks,
@@ -1098,7 +1096,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
         grammar.awaiting_trigger,
         grammar.trigger_buffer,
         grammar.trigger_tokens,
-        grammar.trigger_words,
+        grammar.trigger_patterns,
     };
 
     // redirect elements in stacks to point to new rules
@@ -1173,16 +1171,18 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
             LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str());
             return;
         } else {
-            // TODO: consider a smarter incremental substring search algorithm (store last position to search from).
             grammar.trigger_buffer += piece;
-            for (const auto & word : grammar.trigger_words) {
-                auto pos = grammar.trigger_buffer.find(word);
-                if (pos != std::string::npos) {
+
+            std::smatch match;
+            for (const auto & trigger_pattern : grammar.trigger_patterns) {
+                if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) {
                     grammar.awaiting_trigger = false;
-                    auto constrained_str = grammar.trigger_buffer.substr(pos);
+                    // get from the first match to the end of the string
+                    auto constrained_str = grammar.trigger_buffer.substr(match.position(1));
+                    // std::string constrained_str(match[1].first, grammar.trigger_buffer.end());
                     grammar.trigger_buffer.clear();
                     llama_grammar_accept_str(grammar, constrained_str);
-                    LLAMA_LOG_DEBUG("Grammar triggered on word `%s`", word.c_str());
+                    LLAMA_LOG_DEBUG("Grammar triggered on regex: '%s'\n", constrained_str.c_str());
                     return;
                 }
             }
index b143d834cfabe621003cef5b169872167940bf53..f8c291de999ac6f320936e7b1d6255a57ad65ac7 100644 (file)
@@ -3,6 +3,7 @@
 #include "llama.h"
 
 #include <map>
+#include <regex>
 #include <string>
 #include <vector>
 
@@ -105,6 +106,11 @@ struct llama_grammar_parser {
     void print(FILE * file);
 };
 
+struct llama_grammar_trigger_pattern {
+    std::string pattern;
+    std::regex  regex;
+};
+
 struct llama_grammar {
     // note: allow null vocab for testing (not great)
     const llama_vocab * vocab;
@@ -122,7 +128,10 @@ struct llama_grammar {
     bool                     awaiting_trigger = false; // Initialized to true for lazy grammars only
     std::string              trigger_buffer;           // Output buffered by lazy grammar. Will be cleared once trigger is found.
     std::vector<llama_token> trigger_tokens;           // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special).
-    std::vector<std::string> trigger_words;
+    std::vector<llama_grammar_trigger_pattern>
+                             trigger_patterns;         // Regular expressions that trigger a lazy grammar. Must be a full match of the entire generated
+                                                       // string, and the grammar will be given the string from the first match group onwards.
+
 };
 
 //
@@ -141,8 +150,8 @@ struct llama_grammar * llama_grammar_init_impl(
                       const char * grammar_str,
                       const char * grammar_root,
                               bool lazy,
-                     const char ** trigger_words,
-                            size_t num_trigger_words,
+                     const char ** trigger_patterns,
+                            size_t num_trigger_patterns,
                const llama_token * trigger_tokens,
                             size_t num_trigger_tokens);
 
index f40bf2db83a80ffe6b82ec79798fbf0c49dad42c..c25977ca3a35c9879af75a6aabdd6bf92e3e1c50 100644 (file)
@@ -1449,7 +1449,9 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
                      const char ** trigger_words,
                             size_t num_trigger_words,
                const llama_token * trigger_tokens,
-                            size_t num_trigger_tokens);
+                            size_t num_trigger_tokens,
+                     const char ** trigger_patterns,
+                            size_t num_trigger_patterns);
 
 static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
     auto * ctx = (llama_sampler_grammar *) smpl->ctx;
@@ -1457,12 +1459,14 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
         return;
     }
 
-    std::vector<const char *>  trigger_words;
-    for (auto & word : ctx->grammar->trigger_words) {
-        trigger_words.push_back(word.c_str());
+    std::vector<const char *>  trigger_patterns_c;
+    trigger_patterns_c.reserve(ctx->grammar->trigger_patterns.size());
+    for (auto & trigger_pattern : ctx->grammar->trigger_patterns) {
+        trigger_patterns_c.push_back(trigger_pattern.pattern.c_str());
     }
+
     auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(),
-                                                 ctx->grammar->lazy, trigger_words.data(), trigger_words.size(),
+                                                 ctx->grammar->lazy, trigger_patterns_c.data(), trigger_patterns_c.size(),
                                                  ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size());
 
     llama_grammar_free_impl(ctx->grammar);
@@ -1472,7 +1476,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
 static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
     const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
 
-    auto * result = llama_sampler_init_grammar_impl(ctx->vocab, nullptr, nullptr, false, nullptr, 0, nullptr, 0);
+    auto * result = llama_sampler_init_grammar_impl(ctx->vocab, nullptr, nullptr, false, nullptr, 0, nullptr, 0, nullptr, 0);
 
     // copy the state
     {
@@ -1516,15 +1520,33 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
                      const char ** trigger_words,
                             size_t num_trigger_words,
                const llama_token * trigger_tokens,
-                            size_t num_trigger_tokens) {
+                            size_t num_trigger_tokens,
+                     const char ** trigger_patterns,
+                            size_t num_trigger_patterns) {
     auto * ctx = new llama_sampler_grammar;
 
     if (grammar_str != nullptr && grammar_str[0] != '\0') {
+        // TODO: remove trigger_words support.
+        if (trigger_words != nullptr && num_trigger_words > 0) {
+            GGML_ASSERT(trigger_patterns == nullptr && num_trigger_patterns == 0);
+            std::string trigger_pattern("[\\s\\S]*?(");
+            for (size_t i = 0; i < num_trigger_words; ++i) {
+                static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
+                if (i > 0) {
+                    trigger_pattern += "|";
+                }
+                trigger_pattern += std::regex_replace(trigger_words[i], special_chars, "\\$0");
+            }
+            trigger_pattern += ")[\\s\\S]*";
+            auto trigger_pattern_c = trigger_pattern.c_str();
+            trigger_patterns = &trigger_pattern_c;
+            num_trigger_patterns = 1;
+        }
         *ctx = {
             /* .vocab        = */ vocab,
             /* .grammar_str  = */ grammar_str,
             /* .grammar_root = */ grammar_root,
-            /* .grammar      = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens),
+            /* .grammar      = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens),
         };
     } else {
         *ctx = {
@@ -1545,7 +1567,7 @@ struct llama_sampler * llama_sampler_init_grammar(
         const struct llama_vocab * vocab,
                       const char * grammar_str,
                       const char * grammar_root) {
-    return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ false, nullptr, 0, nullptr, 0);
+    return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ false, nullptr, 0, nullptr, 0, nullptr, 0);
 }
 
 struct llama_sampler * llama_sampler_init_grammar_lazy(
@@ -1556,7 +1578,18 @@ struct llama_sampler * llama_sampler_init_grammar_lazy(
                             size_t num_trigger_words,
                const llama_token * trigger_tokens,
                             size_t num_trigger_tokens) {
-    return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens);
+    return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens, nullptr, 0);
+}
+
+struct llama_sampler * llama_sampler_init_grammar_lazy_patterns(
+        const struct llama_vocab * vocab,
+                      const char * grammar_str,
+                      const char * grammar_root,
+                     const char ** trigger_patterns,
+                            size_t num_trigger_patterns,
+               const llama_token * trigger_tokens,
+                            size_t num_trigger_tokens) {
+    return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, nullptr, 0, trigger_tokens, num_trigger_tokens, trigger_patterns, num_trigger_patterns);
 }
 
 // penalties
index 64359230548591aa9f733d3d103d8b6786775731..35a307c632e689d5be93acc10c91329e2a2a8f91 100644 (file)
@@ -237,12 +237,35 @@ static void test_templates(const struct common_chat_templates * tmpls, const std
             auto earliest_trigger_pos = std::string::npos;
             auto constrained = data.delta;
             for (const auto & trigger : data.params.grammar_triggers) {
-                auto pos = constrained.find(trigger.word);
-                if (pos == std::string::npos) {
-                    continue;
+                size_t pos = std::string::npos;
+                std::smatch match;
+                switch (trigger.type) {
+                    case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
+                    {
+                        const auto & word = trigger.value;
+                        pos = constrained.find(word);
+                        break;
+                    }
+                    case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
+                    {
+                        const auto & pattern = trigger.value;
+                        if (std::regex_search(constrained, match, std::regex(pattern))) {
+                            pos = match.position();
+                        }
+                        break;
+                    }
+                    case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
+                    {
+                        const auto & pattern = trigger.value;
+                        if (std::regex_search(constrained, match, std::regex(pattern)) && match.position() == 0) {
+                            pos = 0;
+                        }
+                        break;
+                    }
+                    default:
+                        throw std::runtime_error("Unknown trigger type");
                 }
-                if (pos > 0 && trigger.at_start) {
-                    fprintf(stderr, "Trigger %s not at start of message, skipping:\n\n%s\n\n", trigger.word.c_str(), constrained.c_str());
+                if (pos == std::string::npos) {
                     continue;
                 }
                 if (earliest_trigger_pos == std::string::npos || pos < earliest_trigger_pos) {
@@ -260,7 +283,8 @@ static void test_templates(const struct common_chat_templates * tmpls, const std
 
             if (grammar_triggered && test_grammar_if_triggered && !match_string(constrained, grammar.get())) {
                 throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta +
-                                            "\n\nGrammar: " + data.params.grammar);
+                    "\n\nConstrained: " + constrained +
+                    "\n\nGrammar: " + data.params.grammar);
             }
         }
     }
@@ -640,6 +664,93 @@ static void test_template_output_parsers() {
                 inputs_tools)
                 .format);
 
+        // Test parsing
+        assert_msg_equals(message_assist_call, common_chat_parse(
+            "<tool_call>\n"
+            "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+            "</tool_call>",
+            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+        assert_msg_equals(message_assist_call, common_chat_parse(
+            "<function=special_function>{\"arg1\": 1}</function>",
+            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+        assert_msg_equals(message_assist_call, common_chat_parse(
+            "<function name=\"special_function\">\n"
+            "{\"arg1\": 1}\n"
+            "</function>",
+            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+        assert_msg_equals(message_assist_call, common_chat_parse(
+            "<tool>\n"
+            "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+            "</tool>",
+            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+        assert_msg_equals(message_assist_call, common_chat_parse(
+            "<tools>\n"
+            "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+            "</tools>",
+            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+        assert_msg_equals(message_assist_call, common_chat_parse(
+            "<response>\n"
+            "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+            "</response>",
+            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+        assert_msg_equals(message_assist_call, common_chat_parse(
+            "```xml\n"
+            "<response>\n"
+            "    {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+            "</response>\n"
+            "```",
+            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+        assert_msg_equals(message_assist_call, common_chat_parse(
+            "```xml\n"
+            "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+            "```",
+            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+        assert_msg_equals(message_assist_call, common_chat_parse(
+            "```\n"
+            "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+            "```",
+            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+        assert_msg_equals(message_assist_call, common_chat_parse(
+            "```\n"
+            "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+            "```",
+            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+        assert_msg_equals(message_assist_call, common_chat_parse(
+            "```json\n"
+            "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+            "```",
+            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+        assert_msg_equals(message_assist_call, common_chat_parse(
+            "```json\n"
+            "\n"
+            "                    <function_call> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}} \n"
+            "                    </function_call> \n"
+            "``` ",
+            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+        assert_msg_equals(message_assist_call, common_chat_parse(
+            "<json>\n"
+            "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+            "</json>",
+            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+        assert_msg_equals(message_assist_call, common_chat_parse(
+            "<xml>\n"
+            "  {\n"
+            "    \"name\": \"special_function\", \"arguments\": {\"arg1\": 1}\n"
+            "  }\n"
+            "</xml>",
+            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+        assert_msg_equals(message_assist_call, common_chat_parse(
+            "<JSON>\n"
+            "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+            "</JSON>",
+            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+        assert_msg_equals(message_assist_call, common_chat_parse(
+            "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}",
+            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+        assert_msg_equals(message_assist_call, common_chat_parse(
+            "{\n  \"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}",
+            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+
         test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
         test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
                       "<tool_call>\n"
@@ -789,7 +900,7 @@ static void test_template_output_parsers() {
 }
 
 int main(int argc, char ** argv) {
-    try {
+    // try {
 #ifndef _WIN32
         if (argc > 1) {
             common_chat_templates_inputs inputs;
@@ -827,8 +938,8 @@ int main(int argc, char ** argv) {
             std::cout << "\n[chat] All tests passed!" << '\n';
         }
         return 0;
-    } catch (const std::exception & e) {
-        std::cerr << "Error: " << e.what() << '\n';
-        return 1;
-    }
+    // } catch (const std::exception & e) {
+    //     std::cerr << "Error: " << e.what() << '\n';
+    //     return 1;
+    // }
 }
index f38994c925e09a99edaea831da1e066d4011449f..4d78e914269f3481c78a77b960e4c599fa760d0e 100755 (executable)
@@ -91,7 +91,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         })""",
         R"""(
             root ::= ([0] | [1-9] [0-9]{0,15}) space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -104,7 +104,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         })""",
         R"""(
             root ::= ([1-9] [0-9]{0,15}) space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -117,7 +117,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         })""",
         R"""(
             root ::= ([1-2] [0-9]{1,15} | [3-9] [0-9]{0,15}) space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -130,7 +130,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         })""",
         R"""(
             root ::= ([1-8] [0-9]{1,15} | [9] [0-9]{0,15}) space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -143,7 +143,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         })""",
         R"""(
             root ::= ([1] ([0-9]{1,15}) | [2-9] [0-9]{1,15}) space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -156,7 +156,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         })""",
         R"""(
             root ::= ([1] [0-9]{2,15} | [2] ([0-4] [0-9]{1,14} | [5-9] [0-9]{0,14}) | [3-9] [0-9]{1,15}) space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -169,7 +169,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         })""",
         R"""(
             root ::= ("-" [1-9] [0-9]{0,15} | [0-9] | ([1-2] [0-9] | [3] "0")) space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -182,7 +182,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         })""",
         R"""(
             root ::= ("-" ([0-5]) | [0] | [1-9] [0-9]{0,15}) space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -195,7 +195,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         })""",
         R"""(
             root ::= ("-" ([0-9] | ([1-8] [0-9] | [9] [0-9]) | "1" ([0-1] [0-9] | [2] [0-3])) | [0] | [1-9] [0-9]{0,15}) space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -208,7 +208,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         })""",
         R"""(
             root ::= ("-" ([0-4] [0-9]{1,15} | [5-9] [0-9]{0,15})) space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -221,7 +221,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         })""",
         R"""(
             root ::= ("-" [1-9] [0-9]{0,15} | [0-1]) space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -234,7 +234,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         })""",
         R"""(
             root ::= ("-" [1-9] [0-9]{0,15} | [0-9] | ([1-8] [0-9] | [9] [0-9]) | "100") space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -248,7 +248,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         })""",
         R"""(
             root ::= ([0-9] | ([1] [0-9] | [2] [0-3])) space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -262,7 +262,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         })""",
         R"""(
             root ::= (([1] ([5-9]) | [2-9] [0-9]) | ([1-2] [0-9]{2} | [3] "00")) space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -276,7 +276,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         })""",
         R"""(
             root ::= ([5-9] | ([1-2] [0-9] | [3] "0")) space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -290,7 +290,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         })""",
         R"""(
             root ::= ("-" ([0-9] | ([1-8] [0-9] | [9] [0-9]) | "1" ([0-1] [0-9] | [2] [0-3])) | [0-9] | ([1-3] [0-9] | [4] [0-2])) space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -304,7 +304,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         })""",
         R"""(
             root ::= ("-" ([0-9] | "10") | [0-9] | "10") space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -340,7 +340,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
             number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
             object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
             root ::= object
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
             string ::= "\"" char* "\"" space
             value ::= object | array | string | number | boolean | null
         )"""
@@ -363,7 +363,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
             date-time ::= date "T" time
             date-time-string ::= "\"" date-time "\"" space
             root ::= "[" space tuple-0 "," space uuid "," space tuple-2 "," space tuple-3 "]" space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
             time ::= ([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9]{3} )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )
             time-string ::= "\"" time "\"" space
             tuple-0 ::= date-string
@@ -382,7 +382,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         R"""(
             char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
             root ::= "\"" char* "\"" space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -396,7 +396,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         R"""(
             char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
             root ::= "\"" char+ "\"" space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -410,7 +410,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         R"""(
             char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
             root ::= "\"" char{3,} "\"" space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -424,7 +424,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         R"""(
             char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
             root ::= "\"" char{0,3} "\"" space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -439,7 +439,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         R"""(
             char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
             root ::= "\"" char{1,4} "\"" space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -451,7 +451,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         })""",
         R"""(
             root ::= ("true" | "false") space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -464,7 +464,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         R"""(
             integral-part ::= [0] | [1-9] [0-9]{0,15}
             root ::= ("-"? integral-part) space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -476,7 +476,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         })""",
         R"""(
             root ::= "\"foo\"" space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -488,7 +488,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         })""",
         R"""(
             root ::= "123" space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -500,7 +500,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         })""",
         R"""(
             root ::= ("\"red\"" | "\"amber\"" | "\"green\"" | "null" | "42" | "[\"foo\"]") space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -514,7 +514,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         R"""(
             char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
             root ::= "[" space (string ("," space string)*)? "]" space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
             string ::= "\"" char* "\"" space
         )"""
     });
@@ -531,7 +531,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
             char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
             null ::= "null" space
             root ::= alternative-0 | null
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
             string ::= "\"" char* "\"" space
         )"""
     });
@@ -545,7 +545,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         R"""(
             char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
             root ::= "[" space string "]" space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
             string ::= "\"" char* "\"" space
         )"""
     });
@@ -562,7 +562,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
             integral-part ::= [0] | [1-9] [0-9]{0,15}
             number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
             root ::= "[" space string "," space number "]" space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
             string ::= "\"" char* "\"" space
         )"""
     });
@@ -577,7 +577,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
             decimal-part ::= [0-9]{1,16}
             integral-part ::= [0] | [1-9] [0-9]{0,15}
             root ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -593,7 +593,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         R"""(
             boolean ::= ("true" | "false") space
             root ::= "[" space boolean ("," space boolean)+ "]" space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -609,7 +609,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         R"""(
             boolean ::= ("true" | "false") space
             root ::= "[" space boolean? "]" space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -625,7 +625,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         R"""(
             boolean ::= ("true" | "false") space
             root ::= "[" space (boolean ("," space boolean)?)? "]" space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -646,7 +646,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
             item ::= number | integer
             number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
             root ::= "[" space item ("," space item){2,4} "]" space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -665,7 +665,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         R"""(
             item ::= ("-" ([0-9] | "1" [0-2]) | [0-9] | ([1-8] [0-9] | [9] [0-9]) | ([1] [0-9]{2} | [2] "0" [0-7])) space
             root ::= "[" space item ("," space item){2,4} "]" space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -684,7 +684,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         R"""(
             item ::= (([1] ([2-9]) | [2-9] [0-9]) | ([1] [0-9]{2} | [2] "0" [0-7])) space
             root ::= "[" space item ("," space item){2,4} "]" space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -697,7 +697,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         })""",
         R"""(
             root ::= "\"" ("ab" "c"? "d"* "ef" "g"+ ("hij")? "kl") "\"" space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -710,7 +710,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         })""",
         R"""(
             root ::= "\"" ("[]{}()|+*?") "\"" space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -723,7 +723,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         })""",
         R"""(
             root ::= "\"" ("\"") "\"" space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -736,7 +736,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         })""",
         R"""(
             root ::= "\"" ("A" | "B" | "C" | "D") "\"" space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -751,7 +751,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
             dot ::= [^\x0A\x0D]
             root ::= "\"" (("(" root-1{1,3} ")")? root-1{3,3} "-" root-1{4,4} " " "a"{3,5} "nd" dot dot dot) "\"" space
             root-1 ::= [0-9]
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -779,7 +779,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
             c-kv ::= "\"c\"" space ":" space string
             char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
             root ::= "{" space b-kv "," space c-kv "," space a-kv "}" space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
             string ::= "\"" char* "\"" space
         )"""
     });
@@ -799,7 +799,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
             a-kv ::= "\"a\"" space ":" space string
             char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
             root ::= "{" space  (a-kv )? "}" space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
             string ::= "\"" char* "\"" space
         )"""
     });
@@ -823,7 +823,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
             c-kv ::= "\"c\"" space ":" space string
             char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
             root ::= "{" space  (a-kv a-rest | b-kv b-rest | c-kv )? "}" space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
             string ::= "\"" char* "\"" space
         )"""
     });
@@ -849,7 +849,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
             d-kv ::= "\"d\"" space ":" space string
             d-rest ::= ( "," space c-kv )?
             root ::= "{" space b-kv "," space a-kv ( "," space ( d-kv d-rest | c-kv ) )? "}" space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
             string ::= "\"" char* "\"" space
         )"""
     });
@@ -869,7 +869,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
             integral-part ::= [0] | [1-9] [0-9]{0,15}
             number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
             root ::= "{" space  (additional-kv ( "," space additional-kv )* )? "}" space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
             string ::= "\"" char* "\"" space
         )"""
     });
@@ -891,7 +891,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
             number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
             object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
             root ::= object
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
             string ::= "\"" char* "\"" space
             value ::= object | array | string | number | boolean | null
         )"""
@@ -913,7 +913,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
             number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
             object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
             root ::= object
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
             string ::= "\"" char* "\"" space
             value ::= object | array | string | number | boolean | null
         )"""
@@ -928,7 +928,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         })""",
         R"""(
             root ::= "{" space  "}" space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -952,7 +952,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
             integral-part ::= [0] | [1-9] [0-9]{0,15}
             number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
             root ::= "{" space a-kv ( "," space ( additional-kv ( "," space additional-kv )* ) )? "}" space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
             string ::= "\"" char* "\"" space
         )"""
     });
@@ -977,7 +977,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
             integral-part ::= [0] | [1-9] [0-9]{0,15}
             number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
             root ::= "{" space  (a-kv a-rest | additional-kv ( "," space additional-kv )* )? "}" space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -1004,7 +1004,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
             integral-part ::= [0] | [1-9] [0-9]{0,15}
             number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
             root ::= "{" space and-kv ( "," space ( also-kv also-rest | additional-kv ( "," space additional-kv )* ) )? "}" space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -1030,7 +1030,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
             integral-part ::= [0] | [1-9] [0-9]{0,15}
             root ::= ("-"? integral-part) space
             root0 ::= "{" space  (-kv -rest | a-kv a-rest | additional-kv ( "," space additional-kv )* )? "}" space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -1055,7 +1055,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
             integer ::= ("-"? integral-part) space
             integral-part ::= [0] | [1-9] [0-9]{0,15}
             root ::= "{" space  (a-kv a-rest | aa-kv aa-rest | additional-kv ( "," space additional-kv )* )? "}" space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -1080,7 +1080,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
             integer ::= ("-"? integral-part) space
             integral-part ::= [0] | [1-9] [0-9]{0,15}
             root ::= "{" space  (ab-kv ab-rest | ac-kv ac-rest | additional-kv ( "," space additional-kv )* )? "}" space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -1109,7 +1109,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
             foo ::= "{" space foo-a-kv "}" space
             foo-a-kv ::= "\"a\"" space ":" space string
             root ::= foo
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
             string ::= "\"" char* "\"" space
         )"""
     });
@@ -1143,7 +1143,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
             integral-part ::= [0] | [1-9] [0-9]{0,15}
             number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
             root ::= alternative-0 | alternative-1
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -1187,7 +1187,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
             integral-part ::= [0] | [1-9] [0-9]{0,15}
             number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
             root ::= "{" space a-kv "," space b-kv ( "," space ( d-kv d-rest | c-kv ) )? "}" space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 
@@ -1235,7 +1235,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
             number-number-kv ::= "\"number\"" space ":" space number-number
             number-number-root-kv ::= "\"root\"" space ":" space number
             root ::= "{" space number-kv "}" space
-            space ::= | " " | "\n" [ \t]{0,20}
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
         )"""
     });
 }