git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
`server`: inject date_string in llama 3.x template + fix date for firefunction v2...
author: Olivier Chafik <redacted>
Thu, 15 May 2025 01:39:51 +0000 (02:39 +0100)
committer: GitHub <redacted>
Thu, 15 May 2025 01:39:51 +0000 (02:39 +0100)
* Inject date_string in llama 3.x + fix for firefunction v2

https://github.com/ggml-org/llama.cpp/issues/12729

* move/fix detection of functionary v3.1 before llama 3.x, fix & test their non-tool mode

Co-authored-by: Sigbjørn Skjæret <redacted>
* generate more tokens in test_completion_with_required_tool_tiny_fast to avoid truncation

---------

Co-authored-by: ochafik <redacted>
Co-authored-by: Sigbjørn Skjæret <redacted>
common/chat.cpp
common/chat.h
tests/test-chat.cpp
tools/server/tests/unit/test_template.py [new file with mode: 0644]
tools/server/tests/unit/test_tool_call.py

index ad3d4aa99a926a6f38f992cb07f03c0b94278296..f138c7bcafcfafafe8707da23f0df046f77db75b 100644 (file)
@@ -6,6 +6,15 @@
 
 #include <optional>
 
+static std::string format_time(const std::chrono::system_clock::time_point & now, const std::string & format) {
+    auto time = std::chrono::system_clock::to_time_t(now);
+    auto local_time = *std::localtime(&time);
+    std::ostringstream ss;
+    ss << std::put_time(&local_time, format.c_str());
+    auto res = ss.str();
+    return res;
+}
+
 typedef minja::chat_template common_chat_template;
 
 struct common_chat_templates {
@@ -24,6 +33,7 @@ struct templates_params {
     std::string grammar;
     bool add_generation_prompt = true;
     bool extract_reasoning     = true;
+    std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
 };
 
 common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
@@ -939,78 +949,83 @@ static void expect_tool_parameters(const std::string & name, const json & parame
     }
 }
 
-static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
+static common_chat_params common_chat_params_init_llama_3_x(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
     auto builtin_tools = json::array();
     common_chat_params data;
-    data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
-    data.grammar = build_grammar([&](const common_grammar_builder & builder) {
-        std::vector<std::string> tool_rules;
+    if (!inputs.tools.is_null()) {
+        data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
+        data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+            std::vector<std::string> tool_rules;
 
-        auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
-            if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") {
-                // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
-                // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
-                expect_tool_parameters(name, parameters, {"query"});
-            } else if (name == "python" || name == "code_interpreter") {
-                // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
-                expect_tool_parameters(name, parameters, {"code"});
-            } else {
-                return false;
-            }
+            auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
+                if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") {
+                    // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
+                    // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
+                    expect_tool_parameters(name, parameters, {"query"});
+                } else if (name == "python" || name == "code_interpreter") {
+                    // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
+                    expect_tool_parameters(name, parameters, {"code"});
+                } else {
+                    return false;
+                }
 
-            std::vector<std::string> kvs;
-            for (const auto & [key, value] : parameters.at("properties").items()) {
-                kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT
-            }
+                std::vector<std::string> kvs;
+                for (const auto & [key, value] : parameters.at("properties").items()) {
+                    kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT
+                }
 
-            tool_rules.push_back(
-                builder.add_rule(
-                    name + "-call",
-                    "\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\""));
-            builtin_tools.push_back(name);
+                tool_rules.push_back(
+                    builder.add_rule(
+                        name + "-call",
+                        "\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\""));
+                builtin_tools.push_back(name);
 
-            return true;
-        };
+                return true;
+            };
 
-        foreach_function(inputs.tools, [&](const json & tool) {
-            const auto & function = tool.at("function");
-            std::string name = function.at("name");
-            auto parameters = function.at("parameters");
-            builder.resolve_refs(parameters);
+            foreach_function(inputs.tools, [&](const json & tool) {
+                const auto & function = tool.at("function");
+                std::string name = function.at("name");
+                auto parameters = function.at("parameters");
+                builder.resolve_refs(parameters);
 
-            // https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
-            if (allow_python_tag_builtin_tools) {
-                handle_builtin_tool(name, parameters);
+                // https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
+                if (allow_python_tag_builtin_tools) {
+                    handle_builtin_tool(name, parameters);
+                }
+                tool_rules.push_back(
+                    builder.add_rule(
+                        name + "-call",
+                        "\"{\" space "
+                        "( \"\\\"type\\\"\"       space \":\" space \"\\\"function\\\"\"     space \",\" space )? "
+                        "  \"\\\"name\\\"\"       space \":\" space \"\\\"" + name + "\\\"\" space \",\" space "
+                        "  \"\\\"parameters\\\"\" space \":\" space " + builder.add_schema(name + "-args", parameters) + " "
+                        "\"}\" space"));
+            });
+            // Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name.
+            data.grammar_triggers.push_back({
+                COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
+                "\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"", // + name + "\"[\\s\\S]*",
+            });
+            if (!builtin_tools.empty()) {
+                data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
+                data.preserved_tokens.push_back("<|python_tag|>");
             }
-            tool_rules.push_back(
-                builder.add_rule(
-                    name + "-call",
-                    "\"{\" space "
-                    "( \"\\\"type\\\"\"       space \":\" space \"\\\"function\\\"\"     space \",\" space )? "
-                    "  \"\\\"name\\\"\"       space \":\" space \"\\\"" + name + "\\\"\" space \",\" space "
-                    "  \"\\\"parameters\\\"\" space \":\" space " + builder.add_schema(name + "-args", parameters) + " "
-                    "\"}\" space"));
+            // Allow a few empty lines on top of the usual constrained json schema space rule.
+            builder.add_rule("root", string_join(tool_rules, " | "));
+            data.additional_stops.push_back("<|eom_id|>");
         });
-        // Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name.
-        data.grammar_triggers.push_back({
-            COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
-            "\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"", // + name + "\"[\\s\\S]*",
-        });
-        if (!builtin_tools.empty()) {
-            data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
-            data.preserved_tokens.push_back("<|python_tag|>");
-        }
-        // Allow a few empty lines on top of the usual constrained json schema space rule.
-        builder.add_rule("root", string_join(tool_rules, " | "));
-    });
-    data.additional_stops.push_back("<|eom_id|>");
+        data.format = allow_python_tag_builtin_tools && !builtin_tools.empty()
+            ? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS
+            : COMMON_CHAT_FORMAT_LLAMA_3_X;
+    } else {
+        data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
+    }
     data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
+        {"date_string", format_time(inputs.now, "%d %b %Y")},
         {"tools_in_user_message", false},
         {"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
     });
-    data.format = allow_python_tag_builtin_tools && !builtin_tools.empty()
-        ? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS
-        : COMMON_CHAT_FORMAT_LLAMA_3_X;
     return data;
 }
 static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) {
@@ -1150,7 +1165,7 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c
     LOG_DBG("%s\n", __func__);
     common_chat_params data;
     data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
-        {"datetime", "Jan 29 2025 13:00:00 GMT"},
+        {"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")},
         {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
     });
     if (inputs.tools.is_array() && !inputs.tools.empty()) {
@@ -1285,55 +1300,59 @@ static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & in
 static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
     // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
     common_chat_params data;
-    json tools = inputs.tools.is_null() ? inputs.tools : json::array();
-    std::string python_code_argument_name;
-    auto has_raw_python = false;
 
-    data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
-    data.grammar = build_grammar([&](const common_grammar_builder & builder) {
-        std::vector<std::string> tool_rules;
-        foreach_function(inputs.tools, [&](const json & tool) {
-            const auto & function = tool.at("function");
-            const auto & parameters = function.at("parameters");
-            std::string name = function.at("name");
-            if (name == "python" || name == "ipython") {
-                if (!parameters.contains("type")) {
-                    throw std::runtime_error("Missing type in python tool");
-                }
-                has_raw_python = true;
-                const auto & type = parameters.at("type");
-                if (type == "object") {
-                    auto properties = parameters.at("properties");
-                    for (auto it = properties.begin(); it != properties.end(); ++it) {
-                        if (it.value().at("type") == "string") {
-                            if (!python_code_argument_name.empty()) {
-                                throw std::runtime_error("Multiple string arguments found in python tool");
+    if (!inputs.tools.is_null()) {
+        std::string python_code_argument_name;
+        auto has_raw_python = false;
+
+        data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
+        data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+            std::vector<std::string> tool_rules;
+            foreach_function(inputs.tools, [&](const json & tool) {
+                const auto & function = tool.at("function");
+                const auto & parameters = function.at("parameters");
+                std::string name = function.at("name");
+                if (name == "python" || name == "ipython") {
+                    if (!parameters.contains("type")) {
+                        throw std::runtime_error("Missing type in python tool");
+                    }
+                    has_raw_python = true;
+                    const auto & type = parameters.at("type");
+                    if (type == "object") {
+                        auto properties = parameters.at("properties");
+                        for (auto it = properties.begin(); it != properties.end(); ++it) {
+                            if (it.value().at("type") == "string") {
+                                if (!python_code_argument_name.empty()) {
+                                    throw std::runtime_error("Multiple string arguments found in python tool");
+                                }
+                                python_code_argument_name = it.key();
                             }
-                            python_code_argument_name = it.key();
                         }
+                        if (python_code_argument_name.empty()) {
+                            throw std::runtime_error("No string argument found in python tool");
+                        }
+                    } else if (type != "string") {
+                        throw std::runtime_error("Invalid type in python tool: " + type.dump());
                     }
-                    if (python_code_argument_name.empty()) {
-                        throw std::runtime_error("No string argument found in python tool");
-                    }
-                } else if (type != "string") {
-                    throw std::runtime_error("Invalid type in python tool: " + type.dump());
                 }
+                tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
+            });
+            if (has_raw_python) {
+                tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
+                data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
+                data.preserved_tokens.push_back("<|python_tag|>");
             }
-            tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
+            auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
+            builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
+            data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function="});
         });
-        if (has_raw_python) {
-            tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
-            data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
-            data.preserved_tokens.push_back("<|python_tag|>");
-        }
-        auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
-        builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
-        data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function="});
-    });
+        data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1;
+    } else {
+        data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
+    }
 
     data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
     // TODO: if (has_raw_python)
-    data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1;
     return data;
 }
 static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::string & input) {
@@ -1593,6 +1612,7 @@ static common_chat_params common_chat_templates_apply_jinja(
     params.extract_reasoning = inputs.extract_reasoning;
     params.tool_choice = inputs.tool_choice;
     params.grammar = inputs.grammar;
+    params.now = inputs.now;
     if (!inputs.json_schema.empty()) {
         params.json_schema = json::parse(inputs.json_schema);
     }
@@ -1644,21 +1664,21 @@ static common_chat_params common_chat_templates_apply_jinja(
         return common_chat_params_init_firefunction_v2(tmpl, params);
     }
 
-    // Plain handler (no tools)
-    if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
-        return common_chat_params_init_without_tools(tmpl, params);
-    }
-
     // Functionary v3.1 (w/ tools)
     if (src.find("<|start_header_id|>") != std::string::npos
         && src.find("<function=") != std::string::npos) {
         return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, params);
     }
 
-    // Llama 3.1, 3.2, 3.3 (w/ tools)
+    // Llama 3.1, 3.2, 3.3 (also requires date_string so using it even w/o tools)
     if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
         auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
-        return common_chat_params_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools);
+        return common_chat_params_init_llama_3_x(tmpl, params, allow_python_tag_builtin_tools);
+    }
+
+    // Plain handler (no tools)
+    if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
+        return common_chat_params_init_without_tools(tmpl, params);
     }
 
     // Mistral Nemo (w/ tools)
index 9aad84e880448f0c5a4c0076afe9395c964c87ea..d26a09c2f7c4fdc8b82ed53d5272bc0ed57d935f 100644 (file)
@@ -3,6 +3,7 @@
 #pragma once
 
 #include "common.h"
+#include <chrono>
 #include <string>
 #include <vector>
 
@@ -71,6 +72,7 @@ struct common_chat_templates_inputs {
     common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
     bool parallel_tool_calls = false;
     bool extract_reasoning     = true;
+    std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
 };
 
 struct common_chat_params {
index fa7aed82dfaa86108718f6a1e464992e3595f1c0..4d70da8c32c91024f5fb52ace7b40f99249dec9e 100644 (file)
@@ -832,7 +832,9 @@ static void test_template_output_parsers() {
         assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY,
                       common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
-                      common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+            common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY,
+                        common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
 
         test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
         test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
diff --git a/tools/server/tests/unit/test_template.py b/tools/server/tests/unit/test_template.py
new file mode 100644 (file)
index 0000000..cf9f96a
--- /dev/null
@@ -0,0 +1,49 @@
+#!/usr/bin/env python
+import pytest
+
+# ensure grandparent path is in sys.path
+from pathlib import Path
+import sys
+
+from unit.test_tool_call import TEST_TOOL
+path = Path(__file__).resolve().parents[1]
+sys.path.insert(0, str(path))
+
+import datetime
+from utils import *
+
+server: ServerProcess
+
+TIMEOUT_SERVER_START = 15*60
+
+@pytest.fixture(autouse=True)
+def create_server():
+    global server
+    server = ServerPreset.tinyllama2()
+    server.model_alias = "tinyllama-2"
+    server.server_port = 8081
+    server.n_slots = 1
+
+
+@pytest.mark.parametrize("tools", [None, [], [TEST_TOOL]])
+@pytest.mark.parametrize("template_name,format", [
+    ("meta-llama-Llama-3.3-70B-Instruct",    "%d %b %Y"),
+    ("fireworks-ai-llama-3-firefunction-v2", "%b %d %Y"),
+])
+def test_date_inside_prompt(template_name: str, format: str, tools: list[dict]):
+    global server
+    server.jinja = True
+    server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
+    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+
+    res = server.make_request("POST", "/apply-template", data={
+        "messages": [
+            {"role": "user", "content": "What is today?"},
+        ],
+        "tools": tools,
+    })
+    assert res.status_code == 200
+    prompt = res.body["prompt"]
+
+    today_str = datetime.date.today().strftime(format)
+    assert today_str in prompt, f"Expected today's date ({today_str}) in content ({prompt})"
index 569c2a1f8ea31b96fab2f32c5c87a4064c2b485c..1f2c151c1a0fa6817920c60af88ac50c040b5a6c 100755 (executable)
@@ -109,7 +109,7 @@ def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict
 ])
 def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None):
     global server
-    n_predict = 512
+    n_predict = 1024
     # server = ServerPreset.stories15m_moe()
     server.jinja = True
     server.n_predict = n_predict