]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
`server`: streaming of tool calls and thoughts when `--jinja` is on (#12379)
authorOlivier Chafik <redacted>
Sun, 25 May 2025 00:48:08 +0000 (01:48 +0100)
committerGitHub <redacted>
Sun, 25 May 2025 00:48:08 +0000 (01:48 +0100)
* add common_json w/ support for truncated json healing

* add common_chat_msg_diff

* partial common_chat_parse

* refactor parser w/ optionals

* server: wire chat diffs in stream mode

* fix trigger of thinking models (must happen after thoughts are closed)

* fix functionary v3.2 raw python!

* rename: common_chat_syntax (now contains format)

* rm common_regex.at_start

* don't return empty <think></think>

* accommodate yet another deepseek r1 distill fantasy syntax (`<|tool▁calls|>`)

* fix QwQ 32B tool call parsing after thoughts (hermes2)

* better logs for grammar triggers

* consume spaces after parse_json_tool_calls

* fix required tool calls w/ thinking models that have pre-opened thinking tags

* fix thinking model's initial trigger + test qwq's template

* run most test_tool_call tests in stream + non-stream modes

* make functionary v3.2 parsing more strict (differentiate first match from others)

* send final diff from server, to close off raw python arguments

* support partial content streaming in Generic mode

* tool-call: allow content prelude before hermes2 tool calls (for Qwen2.5)

* Update function-calling.md

* Update tool_bench.py

* chat-parser: remove input from exception (llm output may contain PII)

---------

Co-authored-by: ochafik <redacted>
Co-authored-by: Olivier Chafik <redacted>
23 files changed:
common/CMakeLists.txt
common/chat-parser.cpp [new file with mode: 0644]
common/chat-parser.h [new file with mode: 0644]
common/chat.cpp
common/chat.h
common/common.h
common/json-partial.cpp [new file with mode: 0644]
common/json-partial.h [new file with mode: 0644]
common/sampling.cpp
docs/function-calling.md
models/templates/Qwen-QwQ-32B.jinja [new file with mode: 0644]
models/templates/README.md
scripts/tool_bench.py
src/llama-grammar.cpp
tests/CMakeLists.txt
tests/test-chat-parser.cpp [new file with mode: 0644]
tests/test-chat.cpp
tests/test-json-partial.cpp [new file with mode: 0644]
tools/server/server.cpp
tools/server/tests/unit/test_chat_completion.py
tools/server/tests/unit/test_tool_call.py
tools/server/tests/utils.py
tools/server/utils.hpp

index a7ff3ac16c446985f296c1f8d35d4d7a49225a73..dac4cc770eb9d77d5b3a654568ac5e4b4dea654c 100644 (file)
@@ -60,12 +60,16 @@ add_library(${TARGET} STATIC
     base64.hpp
     chat.cpp
     chat.h
+    chat-parser.cpp
+    chat-parser.h
     common.cpp
     common.h
     console.cpp
     console.h
     json-schema-to-grammar.cpp
     json.hpp
+    json-partial.h
+    json-partial.cpp
     llguidance.cpp
     log.cpp
     log.h
diff --git a/common/chat-parser.cpp b/common/chat-parser.cpp
new file mode 100644 (file)
index 0000000..5447568
--- /dev/null
@@ -0,0 +1,376 @@
+#include "chat-parser.h"
+#include "common.h"
+#include "log.h"
+#include "regex-partial.h"
+
+#include <optional>
+#include <stdexcept>
+#include <string>
+#include <vector>
+
+using json = nlohmann::ordered_json;
+
+common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax)
+    : input_(input), is_partial_(is_partial), syntax_(syntax)
+{
+    result_.role = "assistant";
+
+    while (true) {
+        std::string id = std::to_string(std::rand());
+        if (input.find(id) == std::string::npos) {
+            healing_marker_ = id;
+            break;
+        }
+    }
+}
+
+std::string common_chat_msg_parser::str(const common_string_range & rng) const {
+    GGML_ASSERT(rng.begin <= rng.end);
+    return input_.substr(rng.begin, rng.end - rng.begin);
+}
+
+void common_chat_msg_parser::add_content(const std::string &content) {
+    result_.content += content;
+}
+
+void common_chat_msg_parser::add_reasoning_content(const std::string &reasoning_content) {
+    result_.reasoning_content += reasoning_content;
+}
+
+bool common_chat_msg_parser::add_tool_call(const std::string & name, const std::string & id, const std::string & arguments) {
+    if (name.empty()) {
+        return false;
+    }
+
+    common_chat_tool_call tool_call;
+    tool_call.name = name;
+    tool_call.arguments = arguments;
+    tool_call.id = id;
+
+    // LOG_DBG("Tool call arguments:\n\traw: %s\n\tresult: %s\n", arguments.c_str(), tool_call.arguments.c_str());
+    result_.tool_calls.emplace_back(tool_call);
+    return true;
+}
+bool common_chat_msg_parser::add_tool_call(const json & tool_call) {
+    std::string name = tool_call.contains("name") ? tool_call.at("name") : "";
+    std::string id = tool_call.contains("id") ? tool_call.at("id") : "";
+    std::string arguments = tool_call.contains("arguments") ? tool_call.at("arguments") : "";
+    return add_tool_call(name, id, arguments);
+}
+
+bool common_chat_msg_parser::add_tool_calls(const json & arr) {
+    for (const auto & item : arr) {
+        if (!add_tool_call(item)) {
+            return false;
+        }
+    }
+    return true;
+}
+void common_chat_msg_parser::finish() {
+    if (!is_partial_ && pos_ != input_.size()) {
+        throw std::runtime_error("Unexpected content at end of input");// + input_.substr(pos_));
+    }
+}
+
+bool common_chat_msg_parser::consume_spaces() {
+    const auto length = input_.size();
+    auto consumed = false;
+    while (pos_ < length && std::isspace(input_[pos_])) {
+        ++pos_;
+        consumed = true;
+    }
+    return consumed;
+}
+
+bool common_chat_msg_parser::try_consume_literal(const std::string & literal) {
+    auto pos = pos_;
+    for (auto i = 0u; i < literal.size(); ++i) {
+        if (pos >= input_.size()) {
+            return false;
+        }
+        if (input_[pos] != literal[i]) {
+            return false;
+        }
+        ++pos;
+    }
+    pos_ = pos;
+    return true;
+}
+
+std::optional<common_chat_msg_parser::find_regex_result>  common_chat_msg_parser::try_find_literal(const std::string & literal) {
+    auto idx = input_.find(literal, pos_);
+    if (idx != std::string::npos) {
+        find_regex_result res;
+        res.prelude = input_.substr(pos_, idx - pos_);
+        auto end = idx + literal.size();
+        res.groups.emplace_back(common_string_range{idx, end});
+        move_to(end);
+        return res;
+    }
+    if (is_partial_) {
+        idx = string_find_partial_stop(input_, literal);
+        if (idx != std::string::npos && idx >= pos_) {
+            find_regex_result res;
+            res.prelude = input_.substr(pos_, idx - pos_);
+            auto end = input_.size();
+            res.groups.emplace_back(common_string_range{idx, end});
+            move_to(end);
+            return res;
+        }
+    }
+    return std::nullopt;
+}
+
+void common_chat_msg_parser::consume_literal(const std::string & literal) {
+    if (!try_consume_literal(literal)) {
+        throw common_chat_msg_partial_exception(literal);
+    }
+}
+
+bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think, const std::string & end_think) {
+    auto handle_reasoning = [&](const std::string & reasoning, bool closed) {
+        auto stripped_reasoning = string_strip(reasoning);
+        if (stripped_reasoning.empty()) {
+            return;
+        }
+        if (syntax_.reasoning_in_content) {
+            add_content(syntax_.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "<think>" : start_think);
+            add_content(stripped_reasoning);
+            if (closed) {
+                add_content(syntax_.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "</think>" : end_think);
+            }
+        } else {
+            add_reasoning_content(stripped_reasoning);
+        }
+    };
+    if (syntax_.reasoning_format != COMMON_REASONING_FORMAT_NONE) {
+        if (syntax_.thinking_forced_open || try_consume_literal(start_think)) {
+            if (auto res = try_find_literal(end_think)) {
+                handle_reasoning(res->prelude, /* closed */ true);
+                consume_spaces();
+                return true;
+            }
+            auto rest = consume_rest();
+            if (!rest.empty()) {
+                handle_reasoning(rest, /* closed */ !is_partial());
+            }
+            if (!syntax_.thinking_forced_open) {
+                throw common_chat_msg_partial_exception(end_think);
+            }
+            return true;
+        }
+    }
+    return false;
+}
+
+std::string common_chat_msg_parser::consume_rest() {
+    auto rest = input_.substr(pos_);
+    pos_ = input_.size();
+    return rest;
+}
+
+// Tries to find the regex, consumes it (pos right after it) and gives the prelude (right before it) and the groups to the callback.
+std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_regex(const common_regex & regex, size_t from) {
+    auto m = regex.search(input_, from == std::string::npos ? pos_ : from);
+    if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) {
+        return std::nullopt;
+    }
+    if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) {
+        if (is_partial()) {
+            throw common_chat_msg_partial_exception(regex.str());
+        }
+        return std::nullopt;
+    }
+    auto prelude = input_.substr(pos_, m.groups[0].begin - pos_);
+    pos_ = m.groups[0].end;
+
+    return find_regex_result{prelude, m.groups};
+}
+
+common_chat_msg_parser::find_regex_result common_chat_msg_parser::consume_regex(const common_regex & regex) {
+    if (auto result = try_consume_regex(regex)) {
+        return *result;
+    }
+    throw common_chat_msg_partial_exception(regex.str());
+}
+
+std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_consume_regex(const common_regex & regex) {
+    auto m = regex.search(input_, pos_);
+    if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) {
+        return std::nullopt;
+    }
+    if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) {
+        if (is_partial()) {
+            throw common_chat_msg_partial_exception(regex.str());
+        }
+        return std::nullopt;
+    }
+    if (m.groups[0].begin != pos_) {
+        // Didn't match at the current position.
+        return std::nullopt;
+    }
+    pos_ = m.groups[0].end;
+
+    return find_regex_result {
+        /* .prelude = */ "",
+        m.groups,
+    };
+}
+
+std::optional<common_json> common_chat_msg_parser::try_consume_json() {
+    auto it = input_.cbegin() + pos_;
+    const auto end = input_.cend();
+    common_json result;
+    if (!common_json_parse(it, end, healing_marker_, result)) {
+        return std::nullopt;
+    }
+    pos_ = std::distance(input_.cbegin(), it);
+    if (result.healing_marker.marker.empty()) {
+        // No healing marker, just return the parsed json
+        return result;
+    }
+    if (!is_partial()) {
+        throw common_chat_msg_partial_exception("JSON");
+    }
+    return result;
+}
+
+common_json common_chat_msg_parser::consume_json() {
+    if (auto result = try_consume_json()) {
+        return *result;
+    }
+    throw common_chat_msg_partial_exception("JSON");
+}
+
+common_chat_msg_parser::consume_json_result common_chat_msg_parser::consume_json_with_dumped_args(
+    const std::vector<std::vector<std::string>> & args_paths,
+    const std::vector<std::vector<std::string>> & content_paths
+) {
+    if (auto result = try_consume_json_with_dumped_args(args_paths, content_paths)) {
+        return *result;
+    }
+    throw common_chat_msg_partial_exception("JSON");
+}
+
+std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parser::try_consume_json_with_dumped_args(
+    const std::vector<std::vector<std::string>> & args_paths,
+    const std::vector<std::vector<std::string>> & content_paths
+) {
+    auto partial = try_consume_json();
+    if (!partial) {
+        return std::nullopt;
+    }
+    auto is_arguments_path = [&](const std::vector<std::string> & path) {
+        return std::find(args_paths.begin(), args_paths.end(), path) != args_paths.end();
+    };
+    auto is_content_path = [&](const std::vector<std::string> & path) {
+        return std::find(content_paths.begin(), content_paths.end(), path) != content_paths.end();
+    };
+
+    if (partial->healing_marker.marker.empty()) {
+        if (args_paths.empty()) {
+            // No arguments to dump, and JSON was parsed fully.
+            return consume_json_result {
+                partial->json,
+                /* .is_partial = */ false,
+            };
+        }
+        if (is_arguments_path({})) {
+            // Entire JSON is the arguments and was parsed fully.
+            return consume_json_result {
+                partial->json.dump(),
+                /* .is_partial = */ false,
+            };
+        }
+    }
+
+    LOG_DBG("Parsed partial JSON: %s (json_healing_marker: %s)\n", partial->json.dump().c_str(), partial->healing_marker.json_dump_marker.c_str());
+
+    auto found_healing_marker = false;
+    std::vector<std::string> path;
+    std::function<json(const json &)> remove_unsupported_healings_and_dump_args = [&](const json & j) -> json {
+        if (is_arguments_path(path)) {
+            auto arguments = j.dump();
+            if (is_partial() && !partial->healing_marker.marker.empty()) {
+                auto idx = arguments.find(partial->healing_marker.json_dump_marker);
+                if (idx != std::string::npos) {
+                    arguments.resize(idx);
+                    found_healing_marker = true;
+                }
+                if (arguments == "\"") {
+                    // This happens because of completing `:"$magic` after `"arguments"`
+                    arguments = "";
+                }
+            }
+            return arguments;
+        }
+        if (is_content_path(path)) {
+            if (!j.is_string()) {
+                throw std::runtime_error("Content path must be a string");
+            }
+            std::string str = j;
+            auto idx = str.find(partial->healing_marker.marker); // not using json_dump_marker as we're inside a string
+            if (idx != std::string::npos) {
+                str.resize(idx);
+                found_healing_marker = true;
+            }
+            return str;
+        }
+        if (j.is_object()) {
+            auto obj = json::object();
+            for (const auto & p : j.items()) {
+                const auto & key = p.key();
+                const auto & value = p.value();
+                const std::string key_str = key; // NOLINT
+                auto idx = key_str.find(healing_marker_);
+                if (idx != std::string::npos) {
+                    found_healing_marker = true;
+                    break;
+                }
+                path.push_back(key_str);
+                if (value.is_string()) {
+                    const std::string value_str = value;
+                    if (value_str.find(healing_marker_) != std::string::npos) {
+                        found_healing_marker = true;
+                        if (is_content_path(path)) {
+                            if (partial->healing_marker.marker == partial->healing_marker.json_dump_marker) {
+                                // The healing occurred inside the string: good. Otherwise we just ditch the entire key/value pair.
+                                obj[key] = remove_unsupported_healings_and_dump_args(value);
+                            }
+                        }
+                        break;
+                    }
+                    obj[key] = value;
+                } else {
+                    obj[key] = remove_unsupported_healings_and_dump_args(value);
+                }
+                path.pop_back();
+            }
+            return obj;
+        }
+        if (j.is_array()) {
+            auto arr = json::array();
+            for (const auto & value : j) {
+                if (value.is_string()) {
+                    std::string str = value;
+                    auto idx = str.find(healing_marker_);
+                    if (idx != std::string::npos) {
+                        // Don't heal array values that aren't in the arguments.
+                        found_healing_marker = true;
+                        break;
+                    }
+                }
+                arr.push_back(remove_unsupported_healings_and_dump_args(value));
+            }
+            return arr;
+        }
+        return j;
+    };
+
+    auto cleaned = remove_unsupported_healings_and_dump_args(partial->json);
+    LOG_DBG("Cleaned up JSON %s to %s (json_healing_marker : '%s')\n", partial->json.dump().c_str(), cleaned.dump().c_str(), partial->healing_marker.json_dump_marker.c_str());
+    return consume_json_result {
+        cleaned,
+        /* .is_partial = */ found_healing_marker,
+    };
+}
diff --git a/common/chat-parser.h b/common/chat-parser.h
new file mode 100644 (file)
index 0000000..b21b32b
--- /dev/null
@@ -0,0 +1,116 @@
+#pragma once
+
+#include "chat.h"
+#include "json-partial.h"
+#include "json.hpp"
+#include "regex-partial.h"
+
+#include <optional>
+#include <string>
+#include <vector>
+
+class common_chat_msg_partial_exception : public std::runtime_error {
+  public:
+    common_chat_msg_partial_exception(const std::string & message) : std::runtime_error(message) {}
+};
+
+class common_chat_msg_parser {
+    std::string input_;
+    bool is_partial_;
+    common_chat_syntax syntax_;
+    std::string healing_marker_;
+
+    size_t pos_ = 0;
+    common_chat_msg result_;
+
+  public:
+    common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
+    const std::string & input() const { return input_; }
+    size_t pos() const { return pos_; }
+    const std::string & healing_marker() const { return healing_marker_; }
+    const bool & is_partial() const { return is_partial_; }
+    const common_chat_msg & result() const { return result_; }
+
+    void move_to(size_t pos) {
+        if (pos > input_.size()) {
+            throw std::runtime_error("Invalid position!");
+        }
+        pos_ = pos;
+    }
+    void move_back(size_t n) {
+        if (pos_ < n) {
+            throw std::runtime_error("Can't move back that far!");
+        }
+        pos_ -= n;
+    }
+
+    // Get the substring of the input at the given range
+    std::string str(const common_string_range & rng) const;
+
+    // Appends to the result.content field
+    void add_content(const std::string & content);
+
+    // Appends to the result.reasoning_content field
+    void add_reasoning_content(const std::string & reasoning_content);
+
+    // Adds a tool call to the result. If the tool call is too incomplete (e.g. name empty), it won't add anything.
+    bool add_tool_call(const std::string & name, const std::string & id, const std::string & arguments);
+
+    // Adds a tool call using the "name", "id" and "arguments" fields of the json object
+    bool add_tool_call(const nlohmann::ordered_json & tool_call);
+
+    // Adds an array of tool calls using their "name", "id" and "arguments" fields.
+    bool add_tool_calls(const nlohmann::ordered_json & arr);
+
+    void finish();
+
+    bool consume_spaces();
+
+    void consume_literal(const std::string & literal);
+
+    bool try_parse_reasoning(const std::string & start_think, const std::string & end_think);
+
+    std::string consume_rest();
+
+    struct find_regex_result {
+        std::string prelude;
+        std::vector<common_string_range> groups;
+    };
+
+    std::optional<find_regex_result> try_find_regex(const common_regex & regex, size_t from = std::string::npos);
+
+    bool try_consume_literal(const std::string & literal);
+
+    std::optional<find_regex_result> try_find_literal(const std::string & literal);
+
+    find_regex_result consume_regex(const common_regex & regex);
+
+    std::optional<find_regex_result> try_consume_regex(const common_regex & regex);
+
+    std::optional<common_json> try_consume_json();
+    common_json consume_json();
+
+    struct consume_json_result {
+        nlohmann::ordered_json value;
+        bool is_partial;
+    };
+
+    /*
+        Consume (possibly partial) json and converts specific subtrees to (possibly truncated) JSON strings.
+
+        By default, object keys can't be truncated, nor can string values (their corresponding key is removed,
+        e.g. `{"foo": "bar", "baz": "b` -> `{"foo": "bar"}`
+
+        But one can allow subpaths to be kept truncated, and possibly json-dumped to truncated json strings
+        - with `content_paths={{"foo"}}` -> `{"foo": "b` -> {"foo": "b"}`
+        - with `args_paths={{"foo"}}` -> `{"foo": {"b` -> `{"foo": "{b"}`
+    */
+    consume_json_result consume_json_with_dumped_args(
+        const std::vector<std::vector<std::string>> & args_paths = {},
+        const std::vector<std::vector<std::string>> & content_paths = {}
+    );
+    std::optional<consume_json_result> try_consume_json_with_dumped_args(
+        const std::vector<std::vector<std::string>> & args_paths = {},
+        const std::vector<std::vector<std::string>> & content_paths = {}
+    );
+};
index f138c7bcafcfafafe8707da23f0df046f77db75b..78af5eafa40c3bce4cdbcdf66a010794cc042d2f 100644 (file)
@@ -1,10 +1,21 @@
 #include "chat.h"
+#include "chat-parser.h"
+#include "common.h"
 #include "json-schema-to-grammar.h"
 #include "log.h"
+#include "json-partial.h"
 #include "minja/chat-template.hpp"
 #include "minja/minja.hpp"
+#include "regex-partial.h"
 
+#include <cstdio>
+#include <exception>
+#include <iostream>
 #include <optional>
+#include <stdexcept>
+#include <string>
+#include <vector>
+
 
 static std::string format_time(const std::chrono::system_clock::time_point & now, const std::string & format) {
     auto time = std::chrono::system_clock::to_time_t(now);
@@ -15,6 +26,96 @@ static std::string format_time(const std::chrono::system_clock::time_point & now
     return res;
 }
 
+static std::string string_diff(const std::string & last, const std::string & current) {
+    if (last.empty()) {
+        return current;
+    }
+    if (!string_starts_with(current, last)) {
+        throw std::runtime_error("Invalid diff: '" + last + "' not found at start of '" + current + "'");
+    }
+    return current.substr(last.size());
+}
+
+static bool has_content_or_tool_calls(const common_chat_msg & msg) {
+    return !msg.content.empty() || !msg.tool_calls.empty();
+}
+
+template <>
+json common_chat_msg::to_json_oaicompat() const
+{
+    json message {
+        {"role", "assistant"},
+    };
+    if (!reasoning_content.empty()) {
+        message["reasoning_content"] = reasoning_content;
+    }
+    if (content.empty() && !tool_calls.empty()) {
+        message["content"] = json();
+    } else {
+        message["content"] = content;
+    }
+    if (!tool_calls.empty()) {
+        auto arr = json::array();
+        for (const auto & tc : tool_calls) {
+            arr.push_back({
+                {"type", "function"},
+                {"function", {
+                    {"name", tc.name},
+                    {"arguments", tc.arguments},
+                }},
+                {"id", tc.id},
+                // // Some templates generate and require an id (sometimes in a very specific format, e.g. Mistral Nemo).
+                // // We only generate a random id for the ones that don't generate one by themselves
+                // // (they also won't get to see it as their template likely doesn't use it, so it's all for the client)
+                // {"id", tc.id.empty() ? gen_tool_call_id() : tc.id},
+            });
+        }
+        message["tool_calls"] = arr;
+    }
+    return message;
+}
+
+std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg) {
+    std::vector<common_chat_msg_diff> diffs;
+    // if (previous_msg.reasoning_content != current.reasoning_content) {
+    //     auto & diff = diffs.emplace_back();
+    //     diff.reasoning_content_delta = string_diff(previous_msg.reasoning_content, current.reasoning_content);
+    // }
+    if (previous_msg.content != new_msg.content) {
+        auto & diff = diffs.emplace_back();
+        diff.content_delta = string_diff(previous_msg.content, new_msg.content);
+    }
+
+    if (new_msg.tool_calls.size() < previous_msg.tool_calls.size()) {
+        throw std::runtime_error("Invalid diff: now finding less tool calls!");
+    }
+
+    if (!previous_msg.tool_calls.empty()) {
+        auto idx = previous_msg.tool_calls.size() - 1;
+        const auto & pref = previous_msg.tool_calls[idx];
+        const auto & newf = new_msg.tool_calls[idx];
+        if (pref.name != newf.name) {
+            throw std::runtime_error("Invalid diff: tool call mismatch!");
+        }
+        auto args_diff = string_diff(pref.arguments, newf.arguments);
+        if (!args_diff.empty() || pref.id != newf.id) {
+            auto & diff = diffs.emplace_back();
+            diff.tool_call_index = idx;
+            diff.tool_call_delta.name = newf.name;
+            if (pref.id != newf.id) {
+                diff.tool_call_delta.id = newf.id;
+            }
+            diff.tool_call_delta.arguments = args_diff;
+        }
+    }
+    for (size_t idx = previous_msg.tool_calls.size(); idx < new_msg.tool_calls.size(); ++idx) {
+        auto & diff = diffs.emplace_back();
+        diff.tool_call_index = idx;
+        diff.tool_call_delta = new_msg.tool_calls[idx];
+    }
+    return diffs;
+}
+
 typedef minja::chat_template common_chat_template;
 
 struct common_chat_templates {
@@ -32,7 +133,6 @@ struct templates_params {
     bool stream;
     std::string grammar;
     bool add_generation_prompt = true;
-    bool extract_reasoning     = true;
     std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
 };
 
@@ -277,6 +377,35 @@ json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & t
     return result;
 }
 
+template <> json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) {
+    json delta = json::object();
+    // if (!diff.reasoning_content_delta.empty()) {
+    //     delta["reasoning_content"] = msg.reasoning_content;
+    // }
+    if (!diff.content_delta.empty()) {
+        delta["content"] = diff.content_delta;
+    }
+    if (diff.tool_call_index != std::string::npos) {
+        json function = json::object();
+        if (!diff.tool_call_delta.name.empty()) {
+            function["name"] = diff.tool_call_delta.name;
+        }
+        if (!diff.tool_call_delta.id.empty()) {
+            function["id"] = diff.tool_call_delta.id;
+        }
+        if (!diff.tool_call_delta.arguments.empty()) {
+            function["arguments"] = diff.tool_call_delta.arguments;
+        }
+        delta["tool_calls"] = json::array({
+            json {
+                {"index", diff.tool_call_index},
+                {"function", function}
+            }
+        });
+    }
+    return delta;
+}
+
 bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
     if (use_jinja) {
         try {
@@ -452,182 +581,121 @@ std::string common_chat_format_name(common_chat_format format) {
         case COMMON_CHAT_FORMAT_LLAMA_3_X: return "Llama 3.x";
         case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: return "Llama 3.x with builtin tools";
         case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "DeepSeek R1";
-        case COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING: return "DeepSeek R1 (extract reasoning)";
         case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: return "FireFunction v2";
         case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2";
         case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
         case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
-        case COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING: return "Hermes 2 Pro (extract reasoning)";
         case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
-        case COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING: return "Command R7B (extract reasoning)";
         default:
             throw std::runtime_error("Unknown chat format");
     }
 }
 
-static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) {
-    // // https://json.nlohmann.me/features/parsing/sax_interface/
-    struct json_error_locator : public nlohmann::json_sax<json> {
-        std::size_t position;
-        bool found_error;
-
-        json_error_locator() : position(0), found_error(false) {}
-
-        bool parse_error(std::size_t position, const std::string &, const json::exception &) override { // NOLINT
-            this->position = position - 1;
-            this->found_error = true;
-            return false;
+static std::string wrap_code_as_arguments(common_chat_msg_parser & builder, const std::string & code) {
+    std::string arguments;
+    if (builder.is_partial()) {
+        arguments = (json {{"code", code + builder.healing_marker()}}).dump();
+        auto idx = arguments.find(builder.healing_marker());
+        if (idx != std::string::npos) {
+            arguments.resize(idx);
         }
-        bool null() override { return true; } // NOLINT
-        bool boolean(bool) override { return true; } // NOLINT
-        bool number_integer(number_integer_t) override { return true; } // NOLINT
-        bool number_unsigned(number_unsigned_t) override { return true; } // NOLINT
-        bool number_float(number_float_t, const string_t &) override { return true; } // NOLINT
-        bool string(string_t &) override { return true; } // NOLINT
-        bool binary(binary_t &) override { return true; } // NOLINT
-        bool start_object(std::size_t) override { return true; } // NOLINT
-        bool key(string_t &) override { return true; } // NOLINT
-        bool end_object() override { return true; }
-        bool start_array(std::size_t) override { return true; } // NOLINT
-        bool end_array() override { return true; }
-    };
-    json_error_locator err_loc;
-    json::sax_parse(it, end, &err_loc);
-
-    std::string::const_iterator temptative_end;
-    if (err_loc.found_error) {
-        temptative_end = it + err_loc.position;
     } else {
-        temptative_end = end;
-    }
-    std::string json_sub {it, temptative_end};
-    try {
-        out = json::parse(json_sub);
-        it = temptative_end;
-        return true;
-    } catch (const std::exception &) {
-        return false;
-    }
-}
-
-static bool parse_literal(std::string::const_iterator & it, const std::string::const_iterator & end, const std::string & expected) {
-    auto expected_it = expected.begin();
-    auto tmp_it = it;
-    while (tmp_it != end && expected_it != expected.end() && *tmp_it == *expected_it) {
-        ++tmp_it;
-        ++expected_it;
-    }
-    if (expected_it == expected.end()) {
-        it = tmp_it;
-        return true;
-    }
-    return false;
-}
-
-static std::optional<std::smatch> parse_pattern(std::string::const_iterator & it, const std::string::const_iterator & end, const std::regex & expected) {
-    std::smatch match;
-    if (std::regex_match(it, end, match, expected)) {
-        it = match.suffix().first;
-        return match;
-    }
-    return std::nullopt;
-}
-
-static void consume_spaces(std::string::const_iterator & it, const std::string::const_iterator & end) {
-    while (it != end && std::isspace(*it)) {
-        ++it;
+        arguments = (json {{"code", code}}).dump();
     }
+    return arguments;
 }
 
 /**
  * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
  * Aggregates the prefix, suffix and in-between text into the content.
  */
-static common_chat_msg parse_json_tool_calls(
-    const std::string& input,
-    const std::optional<std::regex> & trigger_opt,
-    const std::regex & function_regex,
-    const std::regex & close_regex,
-    bool allow_raw_python = false) {
-    std::smatch match;
-
-    common_chat_msg result;
-    result.role = "assistant";
-
-
-    auto end = input.end();
-    auto it = input.begin();
-
-    if (trigger_opt) {
-        if (!std::regex_search(it, end, match, *trigger_opt)) {
-            result.content = input;
-            return result;
-        }
-        result.content = match.prefix().str();
-        it = match.suffix().first;
-    }
-
-    while (it != end) {
-        std::sregex_iterator rend;
-        std::sregex_iterator rit(it, end, function_regex);
-        if (rit == rend) {
-            result.content += std::string(it, end);
+static void parse_json_tool_calls(
+    common_chat_msg_parser & builder,
+    const std::optional<common_regex> & block_open,
+    const std::optional<common_regex> & function_regex_start_only,
+    const std::optional<common_regex> & function_regex,
+    const common_regex & close_regex,
+    const std::optional<common_regex> & block_close,
+    bool allow_raw_python = false,
+    const std::function<std::string(const common_chat_msg_parser::find_regex_result & fres)> & get_function_name = nullptr) {
+
+    auto parse_tool_calls = [&]() {
+        size_t from = std::string::npos;
+        auto first = true;
+        while (true) {
+            auto res = function_regex_start_only && first
+                ? builder.try_consume_regex(*function_regex_start_only)
+                : function_regex
+                    ? builder.try_find_regex(*function_regex, from)
+                    : std::nullopt;
+            if (res) {
+                std::string name;
+                if (get_function_name) {
+                    name = get_function_name(*res);
+                } else {
+                    GGML_ASSERT(res->groups.size() == 2);
+                    name = builder.str(res->groups[1]);
+                }
+                first = false;
+                if (name.empty()) {
+                    // get_function_name signalled us that we should skip this match and treat it as content.
+                    from = res->groups[0].begin + 1;
+                    continue;
+                }
+                from = std::string::npos;
+
+                builder.add_content(res->prelude);
+                auto maybe_raw_python = name == "python" && allow_raw_python;
+                if (builder.input()[builder.pos()] == '{' || !maybe_raw_python) {
+                    if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) {
+                        if (!builder.add_tool_call(name, "", arguments->value) || arguments->is_partial) {
+                            throw common_chat_msg_partial_exception("incomplete tool call");
+                        }
+                        builder.consume_regex(close_regex);
+                    }
+                    continue;
+                }
+                if (maybe_raw_python) {
+                    auto arguments = wrap_code_as_arguments(builder, builder.consume_rest());
+                    if (!builder.add_tool_call(name, "", arguments)) {
+                        throw common_chat_msg_partial_exception("incomplete tool call");
+                    }
+                    return;
+                }
+                throw common_chat_msg_partial_exception("incomplete tool call");
+            }
             break;
         }
-        auto name = rit->str(1);
-        result.content += std::string(it, rit->prefix().second);
-        it = rit->suffix().first;
-
-        json arguments;
-        if (parse_json(it, end, arguments)) {
-            if (!std::regex_search(it, end, match, close_regex)) {
-                throw std::runtime_error("Malformed input, missing closing pattern: " + input);
-            }
-            it = match.suffix().first;
-            result.tool_calls.push_back({name, arguments.is_string() ? arguments.get<std::string>() : arguments.dump(), /* id= */ ""});
-        } else {
-            if (allow_raw_python && name == "python") {
-                result.tool_calls.push_back({name, json({{"code", std::string(it, end)}}).dump(), /* id= */ ""});
-                break;
-            }
-            throw std::runtime_error("Failed to parse json tool call arguments: " + input);
+        if (block_close) {
+            builder.consume_regex(*block_close);
         }
-    }
-
-    if (!result.tool_calls.empty()) {
-        if (!string_strip(result.content).empty()) {
-            LOG_WRN("Content found with tool calls: %s\n", result.content.c_str());
+        builder.consume_spaces();
+        builder.add_content(builder.consume_rest());
+    };
+    if (block_open) {
+        if (auto res = builder.try_find_regex(*block_open)) {
+            builder.add_content(res->prelude);
+            parse_tool_calls();
+        } else {
+            builder.add_content(builder.consume_rest());
         }
-        result.content = "";
+    } else {
+        parse_tool_calls();
     }
-    return result;
 }
 
-static common_chat_tool_call process_tool_call(const json & tool_call) {
-    const auto & arguments = tool_call.at("arguments");
-    return {
-        /* .name = */ tool_call.at("name"),
-        /* .arguments = */ arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
-        /* .id = */ tool_call.contains("id") ? tool_call.at("id") : "",
-    };
-}
-static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& input, const std::string & prefix, size_t rstrip_prefix = 0) {
-    auto content_end = input.find(prefix);
-    size_t tc_start = std::string::npos;
-
-    common_chat_msg result;
-    result.role = "assistant";
-    if (content_end == std::string::npos) {
-        result.content = input;
-    } else {
-        tc_start = content_end + prefix.size() - rstrip_prefix;
-        result.content = input.substr(0, content_end);
-        auto tool_calls = json::parse(input.substr(tc_start));
-        for (const auto & tool_call : tool_calls) {
-            result.tool_calls.emplace_back(process_tool_call(tool_call));
+static void parse_prefixed_json_tool_call_array(common_chat_msg_parser & builder, const common_regex & prefix, size_t rstrip_prefix = 0) {
+    static const std::vector<std::vector<std::string>> args_paths = {{"arguments"}};
+    if (auto res = builder.try_find_regex(prefix)) {
+        builder.add_content(res->prelude);
+        builder.move_back(rstrip_prefix);
+        auto tool_calls = builder.consume_json_with_dumped_args(args_paths);
+        if (!builder.add_tool_calls(tool_calls.value) || tool_calls.is_partial) {
+            throw common_chat_msg_partial_exception("incomplete tool call array");
         }
+    } else {
+        builder.add_content(builder.consume_rest());
     }
-    return result;
 }
 
 static void foreach_function(const json & tools, const std::function<void(const json &)> & fn) {
@@ -754,29 +822,32 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
     data.format = COMMON_CHAT_FORMAT_GENERIC;
     return data;
 }
-static common_chat_msg common_chat_parse_generic(const std::string & input) {
-    json data = json::parse(input);
-    common_chat_msg result;
-    result.role = "assistant";
-    if (data.contains("tool_calls")) {
-        for (const auto & tool_call : data.at("tool_calls")) {
-            result.tool_calls.push_back({
-                tool_call.at("name"),
-                tool_call.at("arguments").dump(),
-                tool_call.contains("id") ? tool_call.at("id") : "",
-            });
+static void common_chat_parse_generic(common_chat_msg_parser & builder) {
+    static const std::vector<std::vector<std::string>> content_paths = {
+        {"response"},
+    };
+    static const std::vector<std::vector<std::string>> args_paths = {
+        {"tool_call", "arguments"},
+        {"tool_calls", "arguments"},
+    };
+    auto data = builder.consume_json_with_dumped_args(args_paths, content_paths);
+    if (data.value.contains("tool_calls")) {
+        if (!builder.add_tool_calls(data.value.at("tool_calls")) || data.is_partial) {
+            throw common_chat_msg_partial_exception("incomplete tool calls");
         }
-    } else if (data.contains("tool_call")) {
-        result.tool_calls.push_back({
-            data.at("tool_call").at("name"),
-            data.at("tool_call").at("arguments").dump(),
-            /* id= */ "",
-        });
-    } else if (data.contains("response")) {
-        const auto & response = data.at("response");
-        result.content = response.is_string() ? response.get<std::string>() : response.dump(2);
+    } else if (data.value.contains("tool_call")) {
+        if (!builder.add_tool_call(data.value.at("tool_call")) || data.is_partial) {
+            throw common_chat_msg_partial_exception("incomplete tool call");
+        }
+    } else if (data.value.contains("response")) {
+        const auto & response = data.value.at("response");
+        builder.add_content(response.is_string() ? response.template get<std::string>() : response.dump(2));
+        if (data.is_partial) {
+            throw common_chat_msg_partial_exception("incomplete response");
+        }
+    } else {
+        throw common_chat_msg_partial_exception("Expected 'tool_call', 'tool_calls' or 'response' in JSON");
     }
-    return result;
 }
 
 static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct templates_params & inputs) {
@@ -823,12 +894,33 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
     data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO;
     return data;
 }
-static common_chat_msg common_chat_parse_mistral_nemo(const std::string & input) {
-    return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
+static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) {
+    static const common_regex prefix(regex_escape("[TOOL_CALLS]"));
+    parse_prefixed_json_tool_call_array(builder, prefix);
 }
 
 static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
+
+    auto adjusted_messages = json::array();
+    for (const auto & msg : inputs.messages) {
+        auto has_reasoning_content = msg.contains("reasoning_content") && msg.at("reasoning_content").is_string();
+        auto has_tool_calls = msg.contains("tool_calls") && msg.at("tool_calls").is_array();
+        if (has_reasoning_content && has_tool_calls) {
+            auto adjusted_message = msg;
+            adjusted_message["tool_plan"] = msg.at("reasoning_content");
+            adjusted_message.erase("reasoning_content");
+            adjusted_messages.push_back(adjusted_message);
+        } else {
+            adjusted_messages.push_back(msg);
+        }
+    }
+    data.prompt = apply(tmpl, adjusted_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {});
+    data.format = COMMON_CHAT_FORMAT_COMMAND_R7B;
+    if (string_ends_with(data.prompt, "<|START_THINKING|>")) {
+        data.thinking_forced_open = true;
+    }
+
     data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
         auto schemas = json::array();
@@ -859,11 +951,16 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
         if (!inputs.parallel_tool_calls) {
             schema["maxItems"] = 1;
         }
-        builder.add_rule("root", "\"<|START_ACTION|>\" " + builder.add_schema("tool_calls", schema) + " \"<|END_ACTION|>\"");
+        builder.add_rule("root",
+            std::string(data.thinking_forced_open ? "( \"<|END_THINKING|>\" space )? " : "") +
+            "\"<|START_ACTION|>\" " + builder.add_schema("tool_calls", schema) + " \"<|END_ACTION|>\"");
     });
     data.grammar_triggers.push_back({
-        COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
-        "<|START_ACTION|>",
+        COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
+        // If thinking_forced_open, then we capture the </think> tag in the grammar,
+        // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar)
+        std::string(data.thinking_forced_open ? "[\\s\\S]*?(<\\|END_THINKING\\|>\\s*)" : "(?:<\\|START_THINKING\\|>[\\s\\S]*?<\\|END_THINKING\\|>\\s*)?") +
+            "(<\\|START_ACTION\\|>)[\\s\\S]*"
     });
     data.preserved_tokens = {
         "<|START_ACTION|>",
@@ -873,61 +970,45 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
         "<|START_THINKING|>",
         "<|END_THINKING|>",
     };
-    auto adjusted_messages = json::array();
-    for (const auto & msg : inputs.messages) {
-        auto has_reasoning_content = msg.contains("reasoning_content") && msg.at("reasoning_content").is_string();
-        auto has_tool_calls = msg.contains("tool_calls") && msg.at("tool_calls").is_array();
-        if (has_reasoning_content && has_tool_calls) {
-            auto adjusted_message = msg;
-            adjusted_message["tool_plan"] = msg.at("reasoning_content");
-            adjusted_message.erase("reasoning_content");
-            adjusted_messages.push_back(adjusted_message);
-        } else {
-            adjusted_messages.push_back(msg);
-        }
-    }
-    data.prompt = apply(tmpl, adjusted_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {});
-    data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING : COMMON_CHAT_FORMAT_COMMAND_R7B;
     return data;
 }
-static common_chat_msg common_chat_parse_command_r7b(const std::string & input, bool extract_reasoning) {
-    static const std::regex thought_regex("(<\\|START_THINKING\\|>([\\s\\S]*?)<\\|END_THINKING\\|>)([\\s\\S]*)");
-    static const std::regex action_regex("<\\|START_ACTION\\|>([\\s\\S]*?)<\\|END_ACTION\\|>");
-    static const std::regex response_regex("(?:<\\|START_RESPONSE\\|>)?([\\s\\S]*?)<\\|END_RESPONSE\\|>");
-
-    std::smatch match;
-
-    common_chat_msg result;
-    result.role = "assistant";
 
-    std::string rest = input;
-
-    if (std::regex_match(rest, match, thought_regex)) {
-        if (extract_reasoning) {
-            result.reasoning_content = match[2].str();
-        } else if (!match[2].str().empty()) {
-            // Let the unparsed thinking tags through in content only if their insides aren't empty.
-            result.content = match[1].str();
+static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) {
+    builder.try_parse_reasoning("<|START_THINKING|>", "<|END_THINKING|>");
+
+    static const common_regex start_action_regex("<\\|START_ACTION\\|>");
+    static const common_regex end_action_regex("<\\|END_ACTION\\|>");
+    static const common_regex start_response_regex("<\\|START_RESPONSE\\|>");
+    static const common_regex end_response_regex("<\\|END_RESPONSE\\|>");
+
+    if (auto res = builder.try_find_regex(start_action_regex)) {
+        // If we didn't extract thoughts, prelude includes them.
+        builder.add_content(res->prelude);
+        auto tool_calls = builder.consume_json_with_dumped_args({{"parameters"}});
+        for (const auto & tool_call : tool_calls.value) {
+            std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : "";
+            std::string id = tool_call.contains("tool_call_id") ? tool_call.at("tool_call_id") : "";
+            std::string arguments = tool_call.contains("parameters") ? tool_call.at("parameters") : "";
+            if (!builder.add_tool_call(name, id, arguments) || tool_calls.is_partial) {
+                throw common_chat_msg_partial_exception("incomplete tool call");
+            }
         }
-        rest = match[3].str();
-    }
-    if (std::regex_match(rest, match, action_regex)) {
-        auto actions_str = match[1].str();
-        auto actions = json::parse(actions_str);
-        for (const auto & action : actions) {
-            result.tool_calls.push_back({
-                /* .name = */      action.at("tool_name"),
-                /* .arguments = */ action.at("parameters").dump(),
-                /* .id = */        action.at("tool_call_id"),
-            });
+        if (tool_calls.is_partial) {
+            throw common_chat_msg_partial_exception("incomplete tool call");
+        }
+        builder.consume_regex(end_action_regex);
+    } else if (auto res = builder.try_find_regex(start_response_regex)) {
+        // If we didn't extract thoughts, prelude includes them.
+        builder.add_content(res->prelude);
+        if (auto res = builder.try_find_regex(end_response_regex)) {
+            builder.add_content(res->prelude);
+        } else {
+            builder.add_content(builder.consume_rest());
+            throw common_chat_msg_partial_exception(end_response_regex.str());
         }
-    } else if (std::regex_match(rest, match, response_regex)) {
-        auto response = match[1].str();
-        result.content += response;
     } else {
-        result.content += rest;
+        builder.add_content(builder.consume_rest());
     }
-    return result;
 }
 
 static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector<std::string> & expected_properties) {
@@ -1004,8 +1085,8 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te
             });
             // Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name.
             data.grammar_triggers.push_back({
-                COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
-                "\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"", // + name + "\"[\\s\\S]*",
+                COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
+                "(\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\")[\\s\\S]*", // + name + "\"[\\s\\S]*",
             });
             if (!builtin_tools.empty()) {
                 data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
@@ -1028,42 +1109,86 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te
     });
     return data;
 }
-static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) {
-    // TODO: tighten & simplify the parser, don't accept leading text context.
-    static const std::regex function_regex(
+static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) {
+    static const common_regex function_regex(
         "\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: ");
-    static const std::regex close_regex("\\}\\s*");
-    static const std::regex builtin_call_regex("<\\|python_tag\\|>\\s*([^.(]+)\\s*\\.\\s*call\\s*\\(\\s*([\\w]+)\\s*=\\s*([\\s\\S]*?)\\)");
+    static const common_regex close_regex("\\}\\s*");
+
+    static const common_regex function_name_regex("\\s*(\\w+)\\s*\\.\\s*call\\(");
+    static const common_regex arg_name_regex("\\s*(\\w+)\\s*=\\s*");
 
     if (with_builtin_tools) {
-        std::smatch match;
-        if (std::regex_match(input, match, builtin_call_regex)) {
-            try {
-                auto name = match[1].str();
-                auto arg_name = match[2].str();
-                auto arg_value_str = match[3].str();
-                auto arg_value = json::parse(arg_value_str);
-
-                common_chat_msg msg;
-                msg.role = "assistant";
-                msg.tool_calls.push_back({
-                    /* .name = */ name,
-                    /* .arguments = */ (json {
-                        {arg_name, arg_value},
-                    }).dump(),
-                    /* .id = */ "",
-                });
-                return msg;
-            } catch (const std::exception & e) {
-                LOG_WRN("Failed to parse builtin tool call arguments (%s): %s", e.what(), input.c_str());
+        static const common_regex builtin_call_regex("<\\|python_tag\\|>");
+        if (auto res = builder.try_find_regex(builtin_call_regex)) {
+            builder.add_content(res->prelude);
+
+            auto fun_res = builder.consume_regex(function_name_regex);
+            auto function_name = builder.str(fun_res.groups[1]);
+
+            common_healing_marker healing_marker;
+            json args = json::object();
+            while (true) {
+                if (auto arg_res = builder.try_consume_regex(arg_name_regex)) {
+                    auto arg_name = builder.str(arg_res->groups[1]);
+                    auto partial = builder.consume_json();
+                    args[arg_name] = partial.json;
+                    healing_marker.marker = partial.healing_marker.marker;
+                    healing_marker.json_dump_marker = partial.healing_marker.json_dump_marker;
+                    builder.consume_spaces();
+                    if (!builder.try_consume_literal(",")) {
+                        break;
+                    }
+                } else {
+                    break;
+                }
             }
+            builder.consume_literal(")");
+            builder.consume_spaces();
+
+            auto arguments = args.dump();
+            if (!builder.add_tool_call(function_name, "", arguments)) {
+                throw common_chat_msg_partial_exception("Incomplete tool call");
+            }
+            return;
         }
     }
-    return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
+    parse_json_tool_calls(
+        builder,
+        /* block_open= */ std::nullopt,
+        /* function_regex_start_only= */ function_regex,
+        /* function_regex= */ std::nullopt,
+        close_regex,
+        std::nullopt);
+
 }
 
 static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
+    auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+
+    // Hacks to fix the official (broken) prompt.
+    // It is advisable to use --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja instead,
+    // until the official template is fixed.
+    if (tmpl.source().find("{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}") != std::string::npos) {
+        // Don't leave the chat dangling after tool results
+        if (string_ends_with(prompt, "<|tool▁outputs▁end|>")) {
+            prompt += "<|end▁of▁sentence|>";
+            if (inputs.add_generation_prompt) {
+                prompt += "<|Assistant|>";
+            }
+        }
+        // Fix up tool call delta example added by Minja
+        prompt = std::regex_replace(
+            prompt,
+            std::regex("(<|tool▁call▁end|>)[\\s\\r\\n]*(<|tool▁outputs▁begin|>|<|User|>)"),
+            "$1<|tool▁calls▁end|><|end▁of▁sentence|>$2");
+    }
+    data.prompt = prompt;
+    data.format = COMMON_CHAT_FORMAT_DEEPSEEK_R1;
+    if (string_ends_with(data.prompt, "<think>\n")) {
+        data.thinking_forced_open = true;
+    }
+
     if (inputs.tools.is_array() && !inputs.tools.empty()) {
         data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED && inputs.json_schema.is_null();
         data.grammar = build_grammar([&](const common_grammar_builder & builder) {
@@ -1074,21 +1199,25 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
                 auto parameters = function.at("parameters");
                 builder.resolve_refs(parameters);
                 tool_rules.push_back(builder.add_rule(name + "-call",
-                    "\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n"
+                    "( \"<|tool▁call▁begin|>\" )? \"function<|tool▁sep|>" + name + "\\n"
                     "```json\\n\" " + builder.add_schema(name + "-args", parameters) + " "
                     "\"```<|tool▁call▁end|>\""));
             });
             // Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag,
             // so we accept common variants (then it's all constrained)
             builder.add_rule("root",
-                "( \"<|tool▁calls▁begin|>\" | \"<|tool_calls_begin|>\" | \"<|tool calls begin|>\" | \"<|tool\\\\_calls\\\\_begin|>\" ) "
+                std::string(data.thinking_forced_open ? "( \"</think>\" space )? " : "") +
+                "( \"<|tool▁calls▁begin|>\" | \"<|tool_calls_begin|>\" | \"<|tool calls begin|>\" | \"<|tool\\\\_calls\\\\_begin|>\" | \"<|tool▁calls|>\" ) "
                 "(" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " "
                 "\"<|tool▁calls▁end|>\""
                 " space");
-            data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool▁calls▁begin|>"});
-            data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool_calls_begin|>"});
-            data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool calls begin|>"});
-            data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool\\_calls\\_begin|>"});
+            data.grammar_triggers.push_back({
+                COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
+                // If thinking_forced_open, then we capture the </think> tag in the grammar,
+                // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar)
+                std::string(data.thinking_forced_open ? "[\\s\\S]*?(</think>\\s*)" : "(?:<think>[\\s\\S]*?</think>\\s*)?") +
+                    "(<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)[\\s\\S]*"
+            });
             data.preserved_tokens = {
                 "<think>",
                 "</think>",
@@ -1100,65 +1229,23 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
             };
         });
     }
-    auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
-
-    // Hacks to fix the official (broken) prompt.
-    // It is advisable to use --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja instead,
-    // until the official template is fixed.
-    if (tmpl.source().find("{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}") != std::string::npos) {
-        // Don't leave the chat dangling after tool results
-        if (string_ends_with(prompt, "<|tool▁outputs▁end|>")) {
-            prompt += "<|end▁of▁sentence|>";
-            if (inputs.add_generation_prompt) {
-                prompt += "<|Assistant|>";
-            }
-        }
-        // Fix up tool call delta example added by Minja
-        prompt = std::regex_replace(
-            prompt,
-            std::regex("(<|tool▁call▁end|>)[\\s\\r\\n]*(<|tool▁outputs▁begin|>|<|User|>)"),
-            "$1<|tool▁calls▁end|><|end▁of▁sentence|>$2");
-    }
-    data.prompt = prompt;
-    data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING : COMMON_CHAT_FORMAT_DEEPSEEK_R1;
     return data;
 }
-static common_chat_msg handle_think_tag_prelude(const std::string & input, bool extract_reasoning, const std::function<common_chat_msg(const std::string &)> & rest_parser) {
-    std::smatch match;
-    static const std::regex reasoning_content_regex("((?:<think>)?([\\s\\S\\r\\n]*?)</think>)?([\\s\\S\\r\\n]*)");
-    if (std::regex_match(input, match, reasoning_content_regex)) {
-        auto rest = match[3].str();
-        auto msg = rest_parser(rest);
-        auto reasoning_content = string_strip(match[2].str());
-        if (extract_reasoning) {
-            msg.reasoning_content = reasoning_content;
-        } else if (!reasoning_content.empty()) {
-            std::ostringstream content;
-            content << "<think>" << reasoning_content << "</think>" << msg.content;
-            msg.content = content.str();
-        }
-        return msg;
-    }
-    return rest_parser(input);
-}
-static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input, bool extract_reasoning) {
-    return handle_think_tag_prelude(input, extract_reasoning, [](const std::string & input) {
-        static const std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n");
-        static const std::regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>");
-        static const std::regex tool_calls_regex("[\\s\\r\\n]*(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>)([\\s\\S\\r\\n]*?)<|tool▁calls▁end|>");
-
-        common_chat_msg msg;
-        msg.role = "assistant";
-        std::smatch match;
-        if (std::regex_search(input, match, tool_calls_regex)) {
-            auto tool_calls = match[1].str();
-            auto msg2 = parse_json_tool_calls(tool_calls, std::nullopt, function_regex, close_regex);
-            msg.tool_calls = std::move(msg2.tool_calls);
-        } else {
-            msg.content = input;
-        }
-        return msg;
-    });
+static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
+    builder.try_parse_reasoning("<think>", "</think>");
+
+    static const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)");
+    static const common_regex tool_calls_end("<|tool▁calls▁end|>");
+    static const common_regex function_regex("(?:<|tool▁call▁begin|>)?function<|tool▁sep|>([^\n]+)\n```json\n");
+    static const common_regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>");
+
+    parse_json_tool_calls(
+        builder,
+        /* block_open= */ tool_calls_begin,
+        /* function_regex_start_only= */ std::nullopt,
+        function_regex,
+        close_regex,
+        tool_calls_end);
 }
 
 static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
@@ -1206,13 +1293,15 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c
     }
     return data;
 }
-static common_chat_msg common_chat_parse_firefunction_v2(const std::string & input) {
-    return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
+static void common_chat_parse_firefunction_v2(common_chat_msg_parser & builder) {
+    static const common_regex prefix(regex_escape(" functools["));
+    parse_prefixed_json_tool_call_array(builder, prefix, /* rstrip_prefix= */ 1);
 }
 
 static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct templates_params & inputs) {
     // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
     // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
+    // If the function is python, we also allow raw python code (if the line after `python\n` doesn't start w/ opening `{`), which the model seems to prefer for multiline code.
     common_chat_params data;
     data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
     data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
@@ -1226,24 +1315,21 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
                 std::string name = function.at("name");
                 auto parameters = function.at("parameters");
                 builder.resolve_refs(parameters);
+                std::string args_pattern = "[\\s\\S]*";
                 auto args_rule = builder.add_schema(name + "-args", parameters);
-                first_tool_rules.push_back(builder.add_rule(name + "-call", "( \"assistant<|end_header_id|>\\n\" )? \"" + name + "\\n\" " + args_rule));
-                subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
-                data.grammar_triggers.push_back({
-                    COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
-                    regex_escape(name + "\n"),
-                });
-                data.grammar_triggers.push_back({
-                    COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
-                    regex_escape("assistant<|end_header_id|>\n" + name + "\n"),
-                });
-                data.grammar_triggers.push_back({
-                    COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
-                    regex_escape(">>>" + name + "\n"),
-                });
+                if (name == "python") {
+                    args_rule = builder.add_rule(name + "-maybe-raw-args", args_rule + " | [^{] .*");
+                } else {
+                    args_pattern = "\\{" + args_pattern;
+                }
+                auto call_rule = builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule);
+                first_tool_rules.push_back(call_rule);
+                if (inputs.parallel_tool_calls) {
+                    subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>\" " + call_rule));
+                }
                 data.grammar_triggers.push_back({
-                    COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
-                    ">>>assistant<|end_header_id|>\n" + name,
+                    COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
+                    "((?:[\\s\\S]+?>>>)?" + regex_escape(name) + "\n)" + args_pattern,
                 });
             });
             data.preserved_tokens = {
@@ -1261,40 +1347,33 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
     }
     return data;
 }
-
-static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & input) {
-    static const std::regex function_regex(R"((?:>>>)?(?:assistant<|end_header_id|>\n)?(\w+)\n)");
-    static const std::regex close_regex(R"($|(?=>>>))");
-
-    std::string content;
-    auto it = input.begin();
-    const auto end = input.end();
-
-    if (parse_literal(it, end, "all\n")) {
-        std::smatch match;
-        if (std::regex_search(it, end, match, function_regex)) {
-            auto fun_it = match.prefix().second;
-            content = std::string(it, fun_it);
-            it = fun_it;
-        } else {
-            common_chat_msg res;
-            res.role = "assistant";
-            res.content = std::string(it, end);
-            return res;
-        }
-    }
-    // TODO: tighten & simplify.
-    try {
-        auto res = parse_json_tool_calls(std::string(it, end), std::nullopt, function_regex, close_regex, /* allow_raw_python= */ true);
-        res.content = content + res.content;
-        return res;
-    } catch (const std::exception & e) {
-        LOG_ERR("Failed to parse functionary v3.2 input: %s\n", e.what());
-        common_chat_msg res;
-        res.role = "assistant";
-        res.content = input;
-        return res;
-    }
+static void common_chat_parse_functionary_v3_2(common_chat_msg_parser & builder) {
+    static const common_regex function_regex_start_only(R"((\w+\n\{|python\n|all\n))");
+    static const common_regex function_regex(R"(>>>(\w+\n\{|python\n|all\n))");
+    static const common_regex close_regex(R"(\s*)");
+
+    parse_json_tool_calls(
+        builder,
+        std::nullopt,
+        function_regex_start_only,
+        function_regex,
+        close_regex,
+        std::nullopt,
+        /* allow_raw_python= */ true,
+        /* get_function_name= */ [&](const auto & res) -> std::string {
+            auto at_start = res.groups[0].begin == 0;
+            auto name = builder.str(res.groups[1]);
+            if (!name.empty() && name.back() == '{') {
+                // Unconsume the opening brace '{' to ensure the JSON parsing goes well.
+                builder.move_back(1);
+            }
+            auto idx = name.find_last_not_of("\n{");
+            name = name.substr(0, idx + 1);
+            if (at_start && name == "all") {
+                return "";
+            }
+            return name;
+        });
 }
 
 static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
@@ -1355,35 +1434,44 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
     // TODO: if (has_raw_python)
     return data;
 }
-static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::string & input) {
+static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser & builder) {
     // This version of Functionary still supports the llama 3.1 tool call format for the python tool.
-    static const std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
-    std::smatch match;
-    if (std::regex_search(input, match, python_tag_regex)) {
-        auto code = match[1].str();
-        common_chat_msg msg;
-        msg.role = "assistant";
-        msg.content = match.prefix().str();
-        msg.tool_calls.push_back({
-            /* .name = */ "python",
-            /* .arguments = */ (json {{"code", code}}).dump(),
-            /* .id = */ "",
-        });
-        return msg;
+    static const common_regex python_tag_regex(regex_escape("<|python_tag|>"));
+
+    if (auto res = builder.try_find_regex(python_tag_regex)) {
+        builder.add_content(res->prelude);
+        auto arguments = wrap_code_as_arguments(builder, builder.consume_rest());
+        builder.add_tool_call("python", "", arguments);
+        return;
     }
-    static const std::regex function_regex(R"(<function=(\w+)>)");
-    static const std::regex close_regex(R"(</function>)");
-    // TODO: tighten & simplify.
-    return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
+
+    static const common_regex function_regex(R"(<function=(\w+)>)");
+    static const common_regex close_regex(R"(</function>)");
+
+    parse_json_tool_calls(
+        builder,
+        /* block_open= */ std::nullopt,
+        /* function_regex_start_only= */ std::nullopt,
+        function_regex,
+        close_regex,
+        std::nullopt);
 }
 
 static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
+
+    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO;
+    if (string_ends_with(data.prompt, "<think>\n")) {
+        data.thinking_forced_open = true;
+    }
+
     // (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
     data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
         std::vector<std::string> tool_rules;
         std::vector<std::string> tool_call_alts;
+        std::vector<std::string> escaped_names;
         foreach_function(inputs.tools, [&](const json & tool) {
             const auto & function = tool.at("function");
             std::string name = function.at("name");
@@ -1412,6 +1500,7 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
                 COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
                 "<function\\s+name\\s*=\\s*\"" + escaped_name + "\"",
             });
+            escaped_names.push_back(escaped_name);
         });
         auto any_tool_call = builder.add_rule("any_tool_call", "( " + string_join(tool_rules, " | ") + " ) space");
         std::vector<std::string> alt_tags {
@@ -1430,13 +1519,23 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
         tool_call_alts.push_back(
             "( \"```\\n\" | \"```json\\n\" | \"```xml\\n\" ) space " + wrappable_tool_call + " space \"```\" space ");
         auto tool_call = builder.add_rule("tool_call", string_join(tool_call_alts, " | "));
-        builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
-        data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<tool_call>"});
-        data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function"});
+        builder.add_rule("root",
+            std::string(data.thinking_forced_open ? "( \"</think>\" space )? " : "") +
+            (inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call));
         // Trigger on some common known "good bad" outputs (only from the start and with a json that's about a specific argument name to avoid false positives)
         data.grammar_triggers.push_back({
-            COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
-            "(?:```(?:json|xml)?\n\\s*)?(?:<function_call>|<tools>|<xml><json>|<response>)?\\s*\\{\\s*\"", //name\"\\s*:\\s*\"" + escaped_name + "\"",
+            COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
+            // If thinking_forced_open, then we capture the </think> tag in the grammar,
+            // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar)
+            std::string(data.thinking_forced_open ? "[\\s\\S]*?(</think>\\s*)" : "(?:<think>[\\s\\S]*?</think>\\s*)?") + (
+                "(\\s*"
+                "(?:<tool_call>"
+                "|<function"
+                "|(?:```(?:json|xml)?\n\\s*)?(?:<function_call>|<tools>|<xml><json>|<response>)?"
+                 "\\s*\\{\\s*\"name\"\\s*:\\s*\"(?:" + string_join(escaped_names, "|") + ")\""
+                ")"
+                ")[\\s\\S]*"
+            ),
         });
         data.preserved_tokens = {
             "<think>",
@@ -1460,124 +1559,84 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
         };
     });
 
-    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
-    data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING : COMMON_CHAT_FORMAT_HERMES_2_PRO;
     return data;
 }
-static common_chat_msg common_chat_parse_hermes_2_pro(const std::string& input, bool extract_reasoning) {
-    return handle_think_tag_prelude(input, extract_reasoning, [](const std::string & input) {
-        static const std::regex open_regex(
-            "(?:"
-            "(```(?:xml|json)?\\n\\s*)?"         // match 1 (block_start)
-            "(<tool_call>"                   // match 2 (open_tag)
-            "|<function_call>"
-            "|<tool>"
-            "|<tools>"
-            "|<response>"
-            "|<json>"
-            "|<xml>"
-            "|<JSON>"
+static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
+    builder.try_parse_reasoning("<think>", "</think>");
+
+    static const common_regex open_regex(
+        "(?:"
+            "(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start)
+            "("                          // match 2 (open_tag)
+                "<tool_call>"
+                "|<function_call>"
+                "|<tool>"
+                "|<tools>"
+                "|<response>"
+                "|<json>"
+                "|<xml>"
+                "|<JSON>"
             ")?"
-            "(\\s*\\{\\s*\"name\"\\s*:[\\s\\S]*)"    // match 3 (named tool call + rest)
-            ")"
-            "|"
-            "(?:<function=([^>]+)>"            // match 4 (function name)
-            "|<function name=\"([^\"]+)\">)" // match 5 (function name again)
-            "([\\s\\S]*)"                   // match 6 (function arguments + rest)})"
-        );
-
-        try {
-            common_chat_msg msg;
-            msg.role = "assistant";
+            "(\\s*\\{\\s*\"name\")" // match 3 (named tool call)
+        ")"
+        "|<function=([^>]+)>"            // match 4 (function name)
+        "|<function name=\"([^\"]+)\">"  // match 5 (function name again)
+    );
 
-            std::string::const_iterator it = input.begin();
-            const std::string::const_iterator end = input.end();
-            std::smatch match;
+    if (auto res = builder.try_find_regex(open_regex)) {
+        builder.add_content(res->prelude);
 
-            while (it != end) {
-                if (std::regex_search(it, end, match, open_regex)) {
-                    // Add content before the match
-                    msg.content += std::string(it, match[0].first);
+        const auto & block_start = res->groups[1];
+        std::string block_end = block_start.empty() ? "" : "```";
 
-                    auto block_start = match[1].str();
-                    std::string block_end = block_start.empty() ? "" : "```";
+        const auto & open_tag = res->groups[2];
+        std::string close_tag;
 
-                    auto open_tag = match[2].str();
-                    std::string close_tag;
+        if (!res->groups[3].empty()) {
+            builder.move_to(res->groups[3].begin);
+            close_tag = open_tag.empty() ? "" : "</" + builder.str(open_tag).substr(1);
 
-                    if (match[3].matched) {
-                        close_tag = open_tag.empty() ? "" : "</" + open_tag.substr(1);
-                        auto json_it = match[3].first;
-                        json tool_call;
-                        if (parse_json(json_it, end, tool_call) && tool_call.contains("name") && tool_call.contains("arguments")) {
+            if (auto tool_call = builder.try_consume_json_with_dumped_args({{"arguments"}})) {
+                if (!builder.add_tool_call(tool_call->value) || tool_call->is_partial) {
+                    throw common_chat_msg_partial_exception("incomplete tool call");
+                }
+                builder.consume_spaces();
+                builder.consume_literal(close_tag);
+                builder.consume_spaces();
+                if (!block_end.empty()) {
+                    builder.consume_literal(block_end);
+                    builder.consume_spaces();
+                }
+                builder.add_content(builder.consume_rest());
+            } else {
+                throw common_chat_msg_partial_exception("failed to parse tool call");
+            }
+        } else {
+            auto function_name = builder.str(res->groups[4]);
+            if (function_name.empty()) {
+                function_name = builder.str(res->groups[5]);
+            }
+            GGML_ASSERT(!function_name.empty());
 
-                            msg.tool_calls.emplace_back(process_tool_call(tool_call));
-                            it = json_it;  // Move iterator past parsed JSON
+            close_tag = "</function>";
 
-                            // Handle close tags
-                            consume_spaces(it, end);
-                            if (!close_tag.empty() && !parse_literal(it, end, close_tag)) {
-                                throw std::runtime_error("Failed to parse closing tag");
-                            }
-                            consume_spaces(it, end);
-                            if (!block_end.empty() && !parse_literal(it, end, block_end)) {
-                                throw std::runtime_error("Failed to parse block end");
-                            }
-                            consume_spaces(it, end);
-                        } else {
-                            // Not a valid tool call, treat as content
-                            msg.content += std::string(match[0].first, match[0].second);
-                            it = match[0].second;
-                        }
-                    } else {
-                        auto function_name = match[4].str();
-                        if (function_name.empty()) {
-                            function_name = match[5].str();
-                        }
-                        GGML_ASSERT(!function_name.empty());
-
-                        close_tag = "</function>";
-                        // Start parsing from after the opening tags
-                        auto json_it = match[6].first;
-                        json arguments;
-                        if (parse_json(json_it, end, arguments)) {
-                            msg.tool_calls.emplace_back(process_tool_call({
-                                {"name", function_name},
-                                {"arguments", arguments},
-                            }));
-                            it = json_it;  // Move iterator past parsed JSON
-
-                            // Handle close tags
-                            consume_spaces(it, end);
-                            if (!close_tag.empty() && !parse_literal(it, end, close_tag)) {
-                                throw std::runtime_error("Failed to parse closing tag");
-                            }
-                            consume_spaces(it, end);
-                            if (!block_end.empty() && !parse_literal(it, end, block_end)) {
-                                throw std::runtime_error("Failed to parse block end");
-                            }
-                            consume_spaces(it, end);
-                        } else {
-                            // Not a valid tool call, treat as content
-                            msg.content += std::string(match[0].first, match[0].second);
-                            it = match[0].second;
-                        }
-                    }
-                } else {
-                    // Add remaining content
-                    msg.content += std::string(it, end);
-                    break;
+            if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) {
+                if (!builder.add_tool_call(function_name, "", arguments->value) || arguments->is_partial) {
+                    throw common_chat_msg_partial_exception("incomplete tool call");
+                }
+                builder.consume_spaces();
+                builder.consume_literal(close_tag);
+                builder.consume_spaces();
+                if (!block_end.empty()) {
+                    builder.consume_literal(block_end);
+                    builder.consume_spaces();
                 }
             }
-            return msg;
-        } catch (const std::exception & e) {
-            LOG_ERR("Failed to parse hermes 2 pro input: %s\n", e.what());
-            common_chat_msg msg;
-            msg.role = "assistant";
-            msg.content = input;
-            return msg;
+            builder.add_content(builder.consume_rest());
         }
-    });
+    } else {
+        builder.add_content(builder.consume_rest());
+    }
 }
 
 static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
@@ -1609,7 +1668,6 @@ static common_chat_params common_chat_templates_apply_jinja(
     const auto & caps = tmpl.original_caps();
     params.messages = common_chat_msgs_to_json_oaicompat<json>(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content);
     params.add_generation_prompt = inputs.add_generation_prompt;
-    params.extract_reasoning = inputs.extract_reasoning;
     params.tool_choice = inputs.tool_choice;
     params.grammar = inputs.grammar;
     params.now = inputs.now;
@@ -1758,44 +1816,64 @@ common_chat_params common_chat_templates_apply(
         : common_chat_templates_apply_legacy(tmpls, inputs);
 }
 
-static common_chat_msg common_chat_parse_content_only(const std::string & input) {
-    common_chat_msg msg;
-    msg.role = "assistant";
-    msg.content = input;
-    return msg;
+static void common_chat_parse_content_only(common_chat_msg_parser & builder) {
+    builder.add_content(builder.consume_rest());
 }
 
-common_chat_msg common_chat_parse(const std::string & input, common_chat_format format) {
+static void common_chat_parse(common_chat_msg_parser & builder, common_chat_format format) {
+    LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(format).c_str(), builder.input().c_str());
+
     switch (format) {
         case COMMON_CHAT_FORMAT_CONTENT_ONLY:
-            return common_chat_parse_content_only(input);
+            common_chat_parse_content_only(builder);
+            break;
         case COMMON_CHAT_FORMAT_GENERIC:
-            return common_chat_parse_generic(input);
+            common_chat_parse_generic(builder);
+            break;
         case COMMON_CHAT_FORMAT_MISTRAL_NEMO:
-            return common_chat_parse_mistral_nemo(input);
+            common_chat_parse_mistral_nemo(builder);
+            break;
         case COMMON_CHAT_FORMAT_LLAMA_3_X:
-            return common_chat_parse_llama_3_1(input);
+            common_chat_parse_llama_3_1(builder);
+            break;
         case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS:
-            return common_chat_parse_llama_3_1(input, /* with_builtin_tools= */ true);
+            common_chat_parse_llama_3_1(builder, /* with_builtin_tools= */ true);
+            break;
         case COMMON_CHAT_FORMAT_DEEPSEEK_R1:
-            return common_chat_parse_deepseek_r1(input, /* extract_reasoning= */ false);
-        case COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING:
-            return common_chat_parse_deepseek_r1(input, /* extract_reasoning= */ true);
+            common_chat_parse_deepseek_r1(builder);
+            break;
         case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2:
-            return common_chat_parse_functionary_v3_2(input);
+            common_chat_parse_functionary_v3_2(builder);
+            break;
         case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1:
-            return common_chat_parse_functionary_v3_1_llama_3_1(input);
+            common_chat_parse_functionary_v3_1_llama_3_1(builder);
+            break;
         case COMMON_CHAT_FORMAT_HERMES_2_PRO:
-            return common_chat_parse_hermes_2_pro(input, /* extract_reasoning= */ false);
-        case COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING:
-            return common_chat_parse_hermes_2_pro(input, /* extract_reasoning= */ true);
+            common_chat_parse_hermes_2_pro(builder);
+            break;
         case COMMON_CHAT_FORMAT_FIREFUNCTION_V2:
-            return common_chat_parse_firefunction_v2(input);
+            common_chat_parse_firefunction_v2(builder);
+            break;
         case COMMON_CHAT_FORMAT_COMMAND_R7B:
-            return common_chat_parse_command_r7b(input, /* extract_reasoning= */ false);
-        case COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING:
-            return common_chat_parse_command_r7b(input, /* extract_reasoning= */ true);
+            common_chat_parse_command_r7b(builder);
+            break;
         default:
             throw std::runtime_error("Unsupported format: " + common_chat_format_name(format));
     }
+    builder.finish();
+}
+
+common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) {
+    common_chat_msg_parser builder(input, is_partial, syntax);
+    try {
+        common_chat_parse(builder, syntax.format);
+    } catch (const common_chat_msg_partial_exception & ex) {
+        LOG_DBG("Partial parse: %s\n", ex.what());
+        if (!is_partial) {
+            throw std::runtime_error(ex.what());
+        }
+    }
+    auto msg = builder.result();
+    LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat<json>({msg}).at(0).dump().c_str());
+    return msg;
 }
index d26a09c2f7c4fdc8b82ed53d5272bc0ed57d935f..ce926777ebe91bd99c2d1fe7791e43a10262a657 100644 (file)
@@ -3,6 +3,7 @@
 #pragma once
 
 #include "common.h"
+#include <functional>
 #include <chrono>
 #include <string>
 #include <vector>
@@ -13,11 +14,19 @@ struct common_chat_tool_call {
     std::string name;
     std::string arguments;
     std::string id;
+
+    bool operator==(const common_chat_tool_call & other) const {
+        return name == other.name && arguments == other.arguments && id == other.id;
+    }
 };
 
 struct common_chat_msg_content_part {
     std::string type;
     std::string text;
+
+    bool operator==(const common_chat_msg_content_part & other) const {
+        return type == other.type && text == other.text;
+    }
 };
 
 struct common_chat_msg {
@@ -28,6 +37,51 @@ struct common_chat_msg {
     std::string reasoning_content;
     std::string tool_name;
     std::string tool_call_id;
+
+    template <class T> T to_json_oaicompat() const;
+
+    bool empty() const {
+        return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty();
+    }
+    void ensure_tool_call_ids_set(std::vector<std::string> & ids_cache, const std::function<std::string()> & gen_tool_call_id) {
+        for (auto i = 0u; i < tool_calls.size(); i++) {
+            if (ids_cache.size() <= i) {
+                auto id = tool_calls[i].id;
+                if (id.empty()) {
+                    id = gen_tool_call_id();
+                }
+                ids_cache.push_back(id);
+            }
+            tool_calls[i].id = ids_cache[i];
+        }
+    }
+    bool operator==(const common_chat_msg & other) const {
+        return role == other.role
+            && content == other.content
+            && content_parts == other.content_parts
+            && tool_calls == other.tool_calls
+            && reasoning_content == other.reasoning_content
+            && tool_name == other.tool_name
+            && tool_call_id == other.tool_call_id;
+    }
+    bool operator!=(const common_chat_msg & other) const {
+        return !(*this == other);
+    }
+};
+
+struct common_chat_msg_diff {
+    // std::string reasoning_content_delta;
+    std::string content_delta;
+    size_t tool_call_index = std::string::npos;
+    common_chat_tool_call tool_call_delta;
+
+    static std::vector<common_chat_msg_diff> compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg);
+
+    bool operator==(const common_chat_msg_diff & other) const {
+        return content_delta == other.content_delta
+        && tool_call_index == other.tool_call_index
+        && tool_call_delta == other.tool_call_delta;
+    }
 };
 
 struct common_chat_tool {
@@ -49,14 +103,11 @@ enum common_chat_format {
     COMMON_CHAT_FORMAT_LLAMA_3_X,
     COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
     COMMON_CHAT_FORMAT_DEEPSEEK_R1,
-    COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING,
     COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
     COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
     COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
     COMMON_CHAT_FORMAT_HERMES_2_PRO,
-    COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING,
     COMMON_CHAT_FORMAT_COMMAND_R7B,
-    COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING,
 
     COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
 };
@@ -71,7 +122,7 @@ struct common_chat_templates_inputs {
     std::vector<common_chat_tool> tools;
     common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
     bool parallel_tool_calls = false;
-    bool extract_reasoning     = true;
+    common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
     std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
 };
 
@@ -80,11 +131,20 @@ struct common_chat_params {
     std::string                         prompt;
     std::string                         grammar;
     bool                                grammar_lazy = false;
+    bool                                thinking_forced_open = false;
     std::vector<common_grammar_trigger> grammar_triggers;
     std::vector<std::string>            preserved_tokens;
     std::vector<std::string>            additional_stops;
 };
 
+struct common_chat_syntax {
+    common_chat_format       format                = COMMON_CHAT_FORMAT_CONTENT_ONLY;
+    common_reasoning_format  reasoning_format      = COMMON_REASONING_FORMAT_NONE;
+    // Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
+    bool                     reasoning_in_content  = false;
+    bool                     thinking_forced_open  = false;
+};
+
 // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
 bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
 
@@ -122,7 +182,7 @@ std::string common_chat_format_example(
     bool use_jinja);
 
 std::string               common_chat_format_name(common_chat_format format);
-common_chat_msg           common_chat_parse(      const std::string & input, common_chat_format format);
+common_chat_msg           common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
 
 common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
 
@@ -135,3 +195,5 @@ template <class T> T common_chat_msgs_to_json_oaicompat(const std::vector<common
 // T can be std::string containing JSON or nlohmann::ordered_json
 template <class T> std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const T & tools);
 template <class T> T common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);
+
+template <class T> T common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff);
index aced2a0166ad60105863ec69de77044e159fa02f..f0c52c314b744547b219c9481a64b7ed8fdc3035 100644 (file)
@@ -115,7 +115,7 @@ enum common_grammar_trigger_type {
     COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN,
     COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
     COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
-    COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
+    COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
 };
 
 struct common_grammar_trigger {
diff --git a/common/json-partial.cpp b/common/json-partial.cpp
new file mode 100644 (file)
index 0000000..7591a8e
--- /dev/null
@@ -0,0 +1,255 @@
+#include <json-partial.h>
+#include "ggml.h"
+#include "log.h"
+#include <string>
+
+#include <json.hpp>
+
+using json = nlohmann::ordered_json;
+
+enum common_json_stack_element_type {
+    COMMON_JSON_STACK_ELEMENT_OBJECT,
+    COMMON_JSON_STACK_ELEMENT_KEY,
+    COMMON_JSON_STACK_ELEMENT_ARRAY,
+};
+
+struct common_json_stack_element {
+    common_json_stack_element_type type;
+    std::string key;
+};
+
+bool common_json_parse(
+    const std::string & input,
+    const std::string & healing_marker,
+    common_json & out)
+{
+    std::string::const_iterator it = input.begin();
+    const auto end = input.end();
+    return common_json_parse(it, end, healing_marker, out);
+}
+
+bool common_json_parse(
+    std::string::const_iterator & it,
+    const std::string::const_iterator & end,
+    const std::string & healing_marker,
+    common_json & out)
+{
+    // // https://json.nlohmann.me/features/parsing/sax_interface/
+    struct json_error_locator : public nlohmann::json_sax<json> {
+        std::size_t position;
+        bool found_error;
+        std::string last_token;
+        std::string exception_message;
+        std::vector<common_json_stack_element> stack;
+
+        json_error_locator() : position(0), found_error(false) {}
+
+        bool parse_error(std::size_t position, const std::string & last_token, const json::exception & ex) override { // NOLINT
+            this->position = position - 1;
+            this->found_error = true;
+            this->last_token = last_token;
+            this->exception_message = ex.what();
+            return false;
+        }
+        void close_value() {
+            if (!stack.empty() && (stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY)) {
+                stack.pop_back();
+            }
+        }
+        bool null() override { // NOLINT
+            close_value();
+            return true;
+        }
+        bool boolean(bool) override { // NOLINT
+            close_value();
+            return true;
+        }
+        bool number_integer(number_integer_t) override { // NOLINT
+            close_value();
+            return true;
+        }
+        bool number_unsigned(number_unsigned_t) override { // NOLINT
+            close_value();
+            return true;
+        }
+        bool number_float(number_float_t, const string_t &) override { // NOLINT
+            close_value();
+            return true;
+        }
+        bool string(string_t &) override { // NOLINT
+            close_value();
+            return true;
+        }
+        bool binary(binary_t &) override { // NOLINT
+            close_value();
+            return true;
+        }
+        bool start_object(std::size_t) override { // NOLINT
+            stack.push_back({COMMON_JSON_STACK_ELEMENT_OBJECT, ""});
+            return true;
+        }
+        bool end_object() override {
+            GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT);
+            stack.pop_back();
+            close_value();
+            return true;
+        }
+        bool key(string_t & key) override { // NOLINT
+            stack.push_back({COMMON_JSON_STACK_ELEMENT_KEY, key});
+            return true;
+        }
+        bool start_array(std::size_t) override { // NOLINT
+            stack.push_back({COMMON_JSON_STACK_ELEMENT_ARRAY, ""});
+            return true;
+        }
+        bool end_array() override {
+            GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY);
+            stack.pop_back();
+            close_value();
+            return true;
+        }
+    };
+    json_error_locator err_loc;
+    auto start = it;
+    json::sax_parse(it, end, &err_loc);
+
+    if (err_loc.found_error) {
+        it = start;
+        auto temptative_end = it + err_loc.position;
+        // LOG_DBG("Error at position %zu (is_end = %s): %s\n", err_loc.position, temptative_end == end ? "true" : "false", err_loc.exception_message.c_str());
+
+        auto input = std::string(it, temptative_end);
+        try {
+            out.json = json::parse(input);
+            // out.json = json::parse(it, temptative_end);
+            it = temptative_end;
+            return true;
+        } catch (const std::exception & ex) {
+            // No, needs healing.
+            LOG_DBG("Failed to parse up to error: %s: <<<%s>>>\n", ex.what(), std::string(it, temptative_end).c_str());
+        }
+        auto can_parse = [](const std::string & str) {
+            try {
+                auto _ = json::parse(str); // NOLINT
+                return true;
+            } catch (const std::exception &) {
+                return false;
+            }
+        };
+        if (!healing_marker.empty() && !err_loc.stack.empty()) {
+            std::string str(it, temptative_end);
+            auto last_non_sp_pos = str.find_last_not_of(" \n\r\t");
+            if (last_non_sp_pos == std::string::npos) {
+                throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
+            }
+            auto last_non_sp_char = str[last_non_sp_pos];
+            // Used to detect stops on a number, which may not be complete.
+            auto was_maybe_number = [&]() {
+                if (!str.empty() && std::isspace(str.back())) {
+                    return false;
+                }
+                return std::isdigit(last_non_sp_char) ||
+                    last_non_sp_char == '.' ||
+                    last_non_sp_char == 'e' ||
+                    last_non_sp_char == 'E' ||
+                    last_non_sp_char == '-';
+            };
+
+            std::string closing;
+            for (size_t i = err_loc.stack.size(); i > 0; i--) {
+                auto & el = err_loc.stack[i - 1];
+                if (el.type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
+                    closing += "}";
+                } else if (el.type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
+                    closing += "]";
+                } else if (el.type != COMMON_JSON_STACK_ELEMENT_KEY) {
+                    throw std::runtime_error("Unexpected stack element type");
+                }
+            }
+
+            const auto & magic_seed = out.healing_marker.marker = healing_marker;//"$llama.cpp.json$";
+
+            if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY) {
+                // We're inside an object value
+                if (last_non_sp_char == ':' && can_parse(str + "1" + closing)) {
+                    // Was about to create an object value
+                    str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
+                } else if (can_parse(str + ": 1" + closing)) {
+                    str += (out.healing_marker.json_dump_marker = ":\"" + magic_seed) + "\"" + closing;
+                } else if (last_non_sp_char == '{' && can_parse(str + closing)) {
+                    // Was about to create an object
+                    str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
+                } else if (can_parse(str + "\"" + closing)) {
+                    // Was inside an object value string
+                    str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
+                } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
+                    // Was inside an object value string after an escape
+                    str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
+                } else {
+                    // find last :
+                    auto last_pos = str.find_last_of(':');
+                    if (last_pos == std::string::npos) {
+                        throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
+                    }
+                    // Cutting back to opening : for object value
+                    str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
+                }
+            } else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
+                if ((last_non_sp_char == ',' || last_non_sp_char == '[') && can_parse(str + "1" + closing)) {
+                    // Was about to create an array value
+                    str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
+                } else if (can_parse(str + "\"" + closing)) {
+                    // Was inside an array value string
+                    str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
+                } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
+                    // Was inside an array value string after an escape
+                    str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
+                } else if (!was_maybe_number() && can_parse(str + ", 1" + closing)) {
+                    // Had just finished a value
+                    str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\"" + closing;
+                } else {
+                    auto last_pos = str.find_last_of("[,");
+                    if (last_pos == std::string::npos) {
+                        throw std::runtime_error("Cannot heal a truncated JSON array stopped in an unknown location");
+                    }
+                    // Cutting back to last [ or , for array value
+                    str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
+                }
+            } else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
+                if ((last_non_sp_char == '{' && can_parse(str + closing)) ||
+                        (last_non_sp_char == ',' && can_parse(str + "\"\": 1" + closing))) {
+                    // Was about to create an object key+value
+                    str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
+                } else if (!was_maybe_number() && can_parse(str + ",\"\": 1" + closing)) {
+                    // Was about to create an object key+value
+                    str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\": 1" + closing;
+                } else if (can_parse(str + "\": 1" + closing)) {
+                    // Was inside an object key string
+                    str += (out.healing_marker.json_dump_marker = magic_seed) + "\": 1" + closing;
+                } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\": 1" + closing)) {
+                    // Was inside an object key string after an escape
+                    str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\": 1" + closing;
+                } else {
+                    auto last_pos = str.find_last_of(':');
+                    if (last_pos == std::string::npos) {
+                        throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
+                    }
+                    // fprintf(stderr, "Cutting back to last : for object key+value\n");
+                    str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
+                }
+            } else {
+                throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
+            }
+            // fprintf(stderr, "HEALED:\nSTRING <<<\n%s\n>>>\n\nmagic_cut: <<<\n%s\n>>>\n\n", str.c_str(), out.healing_marker.json_dump_marker.c_str());
+            out.json = json::parse(str);
+            it = temptative_end;
+            return true;
+        }
+        // TODO: handle unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...)
+        // fprintf(stderr, "Closing: TODO\n");
+        return false;
+    }
+    out.json = json::parse(it, end);
+    it = end;
+    return true;
+}
diff --git a/common/json-partial.h b/common/json-partial.h
new file mode 100644 (file)
index 0000000..854db6a
--- /dev/null
@@ -0,0 +1,37 @@
+#pragma once
+#include <json.hpp>
+
+// Healing marker (empty if the JSON was fully parsed / wasn't healed).
+struct common_healing_marker {
+    // Raw marker.
+    std::string marker;
+
+    // Cutting the `common_json.json.dump()` string at the (only) occurrence of this marker should yield the original partial JSON string (modulo spaces / if it had the same dump format).
+    std::string json_dump_marker;
+};
+
+// Represents a parsed JSON object, with its optional healing marker (a JSON dump fragment that can be used to find the position of healing in the JSON dump string)
+struct common_json {
+    nlohmann::ordered_json json;
+
+    common_healing_marker healing_marker;
+};
+
+// Parse the JSON string, healing (closing) any partial JSON if `healing_marker` is not empty.
+//
+// Healing completes partial JSON strings by adding a (possibly modified) healing marker, then whatever is needed to close the JSON.
+// This allows to parse the resulting healed JSON string, yet be able to cut it again if needed at the healing marker.
+// (this is used when parsing JSON outputs from the models, then crafting partial JSONs for the partial tool calls in OAI format).
+//
+// For instance, parsing `{` with a healing marker `foo` will produce a healed JSON `{"foo":1}`, w/ json_dump_marker = `"foo"` (which can be used to break the JSON again).
+bool common_json_parse(
+    const std::string & input,
+    const std::string & healing_marker,
+    common_json & out);
+
+// Parse the JSON string (see overload above), but advancing an iterator to the end of the input when the (potentially partial) parsing succeeds.
+bool common_json_parse(
+    std::string::const_iterator & it,
+    const std::string::const_iterator & end,
+    const std::string & healing_marker,
+    common_json & out);
index 28705e24c0b71fa6a344a92e9536c771f4d1f442..9c04d35fd00a290a0710a82788dc9c754a37e25b 100644 (file)
@@ -161,7 +161,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
         GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
 #endif // LLAMA_USE_LLGUIDANCE
     } else {
-        std::vector<std::string> patterns_at_start;
+        std::vector<std::string> trigger_patterns;
         std::vector<std::string> patterns_anywhere;
         std::vector<llama_token> trigger_tokens;
         for (const auto & trigger : params.grammar_triggers) {
@@ -173,10 +173,13 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                     break;
                 }
                 case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
-                case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
                 {
-                    const auto & pattern = trigger.value;
-                    (trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START ? patterns_at_start : patterns_anywhere).push_back(pattern);
+                    patterns_anywhere.push_back(trigger.value);
+                    break;
+                }
+                case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
+                {
+                    trigger_patterns.push_back(trigger.value);
                     break;
                 }
                 case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
@@ -190,10 +193,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
             }
         }
 
-        std::vector<std::string> trigger_patterns;
-        if (!patterns_at_start.empty()) {
-            trigger_patterns.push_back("^(" + string_join(patterns_at_start, "|") + ")[\\s\\S]*");
-        }
         if (!patterns_anywhere.empty()) {
             trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
         }
index c3873c3fa63d1d2603fd22245b40f99d6ecbc67e..4a72e843ea9e0fcf433815a00bb04ac833549ec6 100644 (file)
@@ -325,36 +325,65 @@ To get the official template from original HuggingFace repos, you can use [scrip
 > [!TIP]
 > If there is no official `tool_use` Jinja template, you may want to set `--chat-template chatml` to use a default that works with many models (YMMV!), or write your own (e.g. we provide a custom [llama-cpp-deepseek-r1.jinja](../models/templates/llama-cpp-deepseek-r1.jinja) for DeepSeek R1 distills)
 
+> [!CAUTION]
+> Beware of extreme KV quantizations (e.g. `-ctk q4_0`), they can substantially degrade the model's tool calling performance.
+
 Test in CLI (or with any library / software that can use OpenAI-compatible API backends):
 
 ```bash
 curl http://localhost:8080/v1/chat/completions -d '{
-"model": "gpt-3.5-turbo",
-"tools": [
-    {
-    "type":"function",
-    "function":{
-        "name":"python",
-        "description":"Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
-        "parameters":{
-        "type":"object",
-        "properties":{
-            "code":{
-            "type":"string",
-            "description":"The code to run in the ipython interpreter."
+    "model": "gpt-3.5-turbo",
+    "tools": [
+        {
+        "type":"function",
+        "function":{
+            "name":"python",
+            "description":"Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
+            "parameters":{
+            "type":"object",
+            "properties":{
+                "code":{
+                "type":"string",
+                "description":"The code to run in the ipython interpreter."
+                }
+            },
+            "required":["code"]
             }
-        },
-        "required":["code"]
         }
-    }
-    }
-],
-"messages": [
-    {
-    "role": "user",
-    "content": "Print a hello world message with python."
-    }
-]
+        }
+    ],
+    "messages": [
+        {
+        "role": "user",
+        "content": "Print a hello world message with python."
+        }
+    ]
+}'
+
+
+curl http://localhost:8080/v1/chat/completions -d '{
+    "model": "gpt-3.5-turbo",
+    "messages": [
+        {"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things."},
+        {"role": "user", "content": "What is the weather in Istanbul?"}
+    ],
+    "tools": [{
+        "type":"function",
+        "function":{
+            "name":"get_current_weather",
+            "description":"Get the current weather in a given location",
+            "parameters":{
+                "type":"object",
+                "properties":{
+                    "location":{
+                        "type":"string",
+                        "description":"The city and country/state, e.g. `San Francisco, CA`, or `Paris, France`"
+                    }
+                },
+                "required":["location"]
+            }
+        }
+    }]
 }'
 ```
 
diff --git a/models/templates/Qwen-QwQ-32B.jinja b/models/templates/Qwen-QwQ-32B.jinja
new file mode 100644 (file)
index 0000000..d475f70
--- /dev/null
@@ -0,0 +1,62 @@
+{%- if tools %}
+    {{- '<|im_start|>system\n' }}
+    {%- if messages[0]['role'] == 'system' %}
+        {{- messages[0]['content'] }}
+    {%- else %}
+        {{- '' }}
+    {%- endif %}
+    {{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
+    {%- for tool in tools %}
+        {{- "\n" }}
+        {{- tool | tojson }}
+    {%- endfor %}
+    {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
+{%- else %}
+    {%- if messages[0]['role'] == 'system' %}
+        {{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }}
+  {%- endif %}
+{%- endif %}
+{%- for message in messages %}
+    {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
+        {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
+    {%- elif message.role == "assistant" and not message.tool_calls %}
+        {%- set content = message.content %}
+        {%- if not loop.last %}
+            {%- set content = message.content.split('</think>')[-1].lstrip('\n') %}
+        {%- endif %}
+        {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
+    {%- elif message.role == "assistant" %}
+        {%- set content = message.content %}
+        {%- if not loop.last %}
+            {%- set content = message.content.split('</think>')[-1].lstrip('\n') %}
+        {%- endif %}
+        {{- '<|im_start|>' + message.role }}
+        {%- if message.content %}
+            {{- '\n' + content }}
+        {%- endif %}
+        {%- for tool_call in message.tool_calls %}
+            {%- if tool_call.function is defined %}
+                {%- set tool_call = tool_call.function %}
+            {%- endif %}
+            {{- '\n<tool_call>\n{"name": "' }}
+            {{- tool_call.name }}
+            {{- '", "arguments": ' }}
+            {{- tool_call.arguments | tojson }}
+            {{- '}\n</tool_call>' }}
+        {%- endfor %}
+        {{- '<|im_end|>\n' }}
+    {%- elif message.role == "tool" %}
+        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %}
+            {{- '<|im_start|>user' }}
+        {%- endif %}
+        {{- '\n<tool_response>\n' }}
+        {{- message.content }}
+        {{- '\n</tool_response>' }}
+        {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
+            {{- '<|im_end|>\n' }}
+        {%- endif %}
+    {%- endif %}
+{%- endfor %}
+{%- if add_generation_prompt %}
+    {{- '<|im_start|>assistant\n<think>\n' }}
+{%- endif %}
index e4fd104fc9fe6d0c17293cfa98a0cdf9c44f45af..b8655be9fce95dfe26cb83eb8d0128d42c713270 100644 (file)
@@ -19,4 +19,5 @@ These templates can be updated with the following commands:
 ./scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B tool_use > models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja
 ./scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use   > models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja
 ./scripts/get_chat_template.py Qwen/Qwen2.5-7B-Instruct                      > models/templates/Qwen-Qwen2.5-7B-Instruct.jinja
+./scripts/get_chat_template.py Qwen/QwQ-32B                                  > models/templates/Qwen-QwQ-32B.jinja
 ```
\ No newline at end of file
index a2f2a2eb020048700431aceacd0d80bdf82a8cca..d8018e2e23c0dfe6360a08c4b36c01cbd65a1790 100755 (executable)
@@ -12,6 +12,7 @@
         export LLAMA_SERVER_BIN_PATH=$PWD/build/bin/llama-server
         export LLAMA_CACHE=${LLAMA_CACHE:-$HOME/Library/Caches/llama.cpp}
 
+        ./scripts/tool_bench.py run --n 10 --temp -1 --temp 0 --temp 1 --temp 2 --temp 5 --llama-baseline $PWD/buildMaster/bin/llama-server --output qwen14b.jsonl --hf bartowski/Qwen2.5-14B-Instruct-GGUF:Q4_K_L
         ./scripts/tool_bench.py run --n 30 --temp -1 --temp 0 --temp 1 --model "Qwen 2.5 1.5B Q4_K_M"      --output qwen1.5b.jsonl  --hf bartowski/Qwen2.5-1.5B-Instruct-GGUF      --ollama qwen2.5:1.5b-instruct-q4_K_M
         ./scripts/tool_bench.py run --n 30 --temp -1 --temp 0 --temp 1 --model "Qwen 2.5 Coder 7B Q4_K_M"  --output qwenc7b.jsonl   --hf bartowski/Qwen2.5-Coder-7B-Instruct-GGUF  --ollama qwen2.5-coder:7b
 
@@ -205,6 +206,7 @@ def run(
     model: Annotated[Optional[str], typer.Option(help="Name of the model to test (server agnostic)")] = None,
     hf: Annotated[Optional[str], typer.Option(help="GGUF huggingface model repo id (+ optional quant) to test w/ llama-server")] = None,
     chat_template: Annotated[Optional[str], typer.Option(help="Chat template override for llama-server")] = None,
+    chat_template_file: Annotated[Optional[str], typer.Option(help="Chat template file override for llama-server")] = None,
     ollama: Annotated[Optional[str], typer.Option(help="Ollama model tag to test")] = None,
     llama_baseline: Annotated[Optional[str], typer.Option(help="llama-server baseline binary path to use as baseline")] = None,
     n: Annotated[int, typer.Option(help="Number of times to run each test")] = 10,
@@ -229,6 +231,12 @@ def run(
     # n_ctx = 8192
     n_ctx = 2048
 
+    if model is None:
+        if hf is not None:
+            model = hf.split("/")[-1]
+        elif ollama is not None:
+            model = ollama
+
     assert force or append or not output.exists(), f"Output file already exists: {output}; use --force to overwrite"
 
     with output.open('a' if append else 'w') as output_file:
@@ -320,6 +328,7 @@ def run(
                     server.model_hf_repo = hf
                     server.model_hf_file = None
                     server.chat_template = chat_template
+                    server.chat_template_file = chat_template_file
                     server.server_path = server_path
                     if port is not None:
                         server.server_port = port
@@ -335,6 +344,7 @@ def run(
                                 temp=t,
                                 output_kwargs=dict(
                                     chat_template=chat_template,
+                                    chat_template_file=chat_template_file,
                                 ),
                                 request_kwargs=dict(
                                     ignore_chat_grammar=ignore_chat_grammar,
@@ -355,6 +365,7 @@ def run(
                         temp=t,
                         output_kwargs=dict(
                             chat_template=None,
+                            chat_template_file=None,
                         ),
                         request_kwargs=dict(
                             model=ollama,
index 973b47ae063b08a6faec73294a6ab05d78e4ac56..bed706bb248d139664d8024726948e9fb1ba4cb5 100644 (file)
@@ -1177,8 +1177,18 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
             for (const auto & trigger_pattern : grammar.trigger_patterns) {
                 if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) {
                     grammar.awaiting_trigger = false;
-                    // get from the first match to the end of the string
-                    auto constrained_str = grammar.trigger_buffer.substr(match.position(1));
+                    // get from the first matched capturing group to the end of the string
+                    size_t start = std::string::npos;
+                    for (auto i = 1u; i < match.size(); i++) {
+                        if (match.length(i) > 0) {
+                            start = match.position(i);
+                            break;
+                        }
+                    }
+                    if (start == std::string::npos) {
+                        start = match.position(0);
+                    }
+                    auto constrained_str = grammar.trigger_buffer.substr(start);
                     // std::string constrained_str(match[1].first, grammar.trigger_buffer.end());
                     grammar.trigger_buffer.clear();
                     llama_grammar_accept_str(grammar, constrained_str);
index 083347d188880f0aea3c85592e3f4e808fd49f74..00466b9ba02ec3604d19de6cc6ec020a7845eff9 100644 (file)
@@ -142,8 +142,10 @@ if (NOT WIN32)
     # llama_build_and_test(test-double-float.cpp) # SLOW
 endif()
 
-llama_build_and_test(test-log.cpp)
+llama_build_and_test(test-chat-parser.cpp)
 llama_build_and_test(test-chat-template.cpp)
+llama_build_and_test(test-json-partial.cpp)
+llama_build_and_test(test-log.cpp)
 llama_build_and_test(test-regex-partial.cpp)
 
 # this fails on windows (github hosted runner) due to curl DLL not found (exit code 0xc0000135)
diff --git a/tests/test-chat-parser.cpp b/tests/test-chat-parser.cpp
new file mode 100644 (file)
index 0000000..2113a12
--- /dev/null
@@ -0,0 +1,355 @@
+//  Tests chat handling, including grammar generation and parsing for tool calling, for various templates.
+//
+//  Also acts as a CLI to generate a Markdown summary of the formats of Jinja templates,
+//  e.g. given Minja (http://github.com/google/minja) checked out in parent dir:
+//
+//    cmake -B build && cmake --build build --parallel && ./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null
+//
+#include <exception>
+#include <iostream>
+#include <json.hpp>
+#include <string>
+
+#include "chat-parser.h"
+#include "common.h"
+#include "log.h"
+#include "regex-partial.h"
+
+using json = nlohmann::ordered_json;
+
+template <class T>
+static void assert_equals(const T & expected, const T & actual) {
+    if (expected != actual) {
+        std::cerr << "Expected: " << expected << std::endl;
+        std::cerr << "Actual: " << actual << std::endl;
+        std::cerr << std::flush;
+        throw std::runtime_error("Test failed");
+    }
+}
+static void assert_equals(const char * expected, const std::string & actual) {
+  return assert_equals<std::string>(expected, actual);
+}
+
+static void assert_throws(const std::function<void()> & fn, const std::string & expected_exception_pattern = "") {
+    try {
+        fn();
+    } catch (const std::exception & e) {
+      if (expected_exception_pattern.empty()) {
+          return;
+        }
+        std::regex expected_exception_regex(expected_exception_pattern);
+        std::string actual_message = e.what();
+        if (std::regex_search(actual_message, expected_exception_regex)) {
+            return;
+        }
+        throw std::runtime_error("Exception doesn't match expected pattern: " + actual_message + " (pattern: " + expected_exception_pattern + ")");
+        throw std::runtime_error("Exception of unexpected type: " + std::string(e.what()));
+    }
+    throw std::runtime_error("Exception was expected but not thrown");
+}
+
+static void test_reasoning() {
+  {
+    common_chat_msg_parser builder("<tnk>Cogito</tnk>Ergo sum", /* is_partial= */ false, {
+        /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
+        /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
+        /* .reasoning_in_content = */ false,
+        /* .thinking_forced_open = */ false,
+    });
+    assert_equals(false, builder.try_parse_reasoning("<tnk>", "</tnk>"));
+    assert_equals("<tnk>Cogito</tnk>Ergo sum", builder.consume_rest());
+  }
+  {
+    common_chat_msg_parser builder("<tnk>Cogito</tnk>Ergo sum", /* is_partial= */ false, {
+        /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
+        /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+        /* .reasoning_in_content = */ false,
+        /* .thinking_forced_open = */ false,
+    });
+    assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
+    assert_equals(std::string("Cogito"), builder.result().reasoning_content);
+    assert_equals("Ergo sum", builder.consume_rest());
+  }
+  {
+    common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, {
+        /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
+        /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
+        /* .reasoning_in_content = */ false,
+        /* .thinking_forced_open = */ false,
+    });
+    assert_equals(false, builder.try_parse_reasoning("<tnk>", "</tnk>"));
+    assert_equals("Cogito</tnk>Ergo sum", builder.consume_rest());
+  }
+  {
+    common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, {
+        /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
+        /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+        /* .reasoning_in_content = */ false,
+        /* .thinking_forced_open = */ true,
+    });
+    assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
+    assert_equals(std::string("Cogito"), builder.result().reasoning_content);
+    assert_equals("Ergo sum", builder.consume_rest());
+  }
+  {
+    common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, {
+        /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
+        /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+        /* .reasoning_in_content = */ true,
+        /* .thinking_forced_open = */ true,
+    });
+    assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
+    assert_equals("<think>Cogito</think>", builder.result().content);
+    assert_equals("Ergo sum", builder.consume_rest());
+  }
+}
+
+static void test_regex() {
+  auto test_throws = [](const std::string & input, const std::string & regex, const std::string & expected_exception_pattern = "") {
+    common_chat_msg_parser builder(input, /* is_partial= */ false, {});
+    assert_throws([&]() { builder.consume_regex(common_regex(regex)); }, expected_exception_pattern);
+  };
+
+  test_throws("Hello, world!", "abc", "^abc$");
+  test_throws("Hello, world!", "e", "^e$");
+
+  {
+    common_chat_msg_parser builder("Hello, world!", /* is_partial= */ false, {});
+    builder.consume_regex(common_regex("Hello"));
+    assert_equals(", world!", builder.consume_rest());
+  }
+
+  {
+    // When in non partial mode, we can say whether the regex was consumed or not.
+    common_chat_msg_parser builder("Hello,", /* is_partial= */ false, {});
+    assert_equals(false, builder.try_consume_regex(common_regex("Hello, world!")).has_value());
+  }
+  {
+    common_chat_msg_parser builder("Hello,", /* is_partial= */ false, {});
+    auto res = builder.try_consume_regex(common_regex("H(el)l(?:o, world!)?"));
+    assert_equals(true, res.has_value());
+    // Verify captures
+    assert_equals<size_t>(2, res->groups.size());
+    assert_equals("Hell", builder.str(res->groups[0]));
+    assert_equals("el", builder.str(res->groups[1]));
+    // Verify position is after the match
+    assert_equals<size_t>(4, builder.pos());
+    assert_equals("o,", builder.consume_rest());
+  }
+  {
+    // But in partial mode, we have a partial final match / can't decide, so we throw a partial exception.
+    common_chat_msg_parser builder("Hello,", /* is_partial= */ true, {});
+    assert_throws([&]() {
+      builder.try_consume_regex(common_regex("Hello, world!"));
+    }, "^Hello, world!$");
+  }
+
+  // Now regardless of the mode, we can tell these aren't a match.
+  for (const auto is_partial : {false, true}) {
+    common_chat_msg_parser builder("Hello,", is_partial, {});
+    assert_equals(false, builder.try_consume_regex(common_regex("a(b|c)(d|e)f")).has_value());
+  }
+  for (const auto is_partial : {false, true}) {
+    common_chat_msg_parser builder("Hello,", is_partial, {});
+    assert_equals(false, builder.try_consume_literal("Oh"));
+  }
+}
+
+const std::vector<std::string> barely_healable_jsons = {
+  "{",
+  "{\"",
+  "{\"\\",
+  "{\"n",
+  "{\"name\"",
+  "{\"name\":",
+  "{\"name\":\"",
+  "{\"name\":\"\\",
+  "{\"name\":\"python",
+  "{\"name\":\"python\\",
+  "{\",",
+  "{\":",
+  "{\"[",
+  "{\"]",
+  "{\"{",
+  "{\"}",
+  "{\"1",
+  "{\"name\":\",",
+  "{\"name\":\":",
+  "{\"name\":\"[",
+  "{\"name\":\"]",
+  "{\"name\":\"{",
+  "{\"name\":\"}",
+  "{\"name\":\"1",
+};
+
+static void test(const std::string & input, bool is_partial, const std::vector<std::vector<std::string>> & args_paths, const std::vector<std::vector<std::string>> & content_paths, const std::string & expected) {
+  common_chat_msg_parser builder(input, is_partial, {});
+  auto js = builder.try_consume_json_with_dumped_args(args_paths, content_paths);
+  assert_equals(true, js.has_value());
+  assert_equals(is_partial, js->is_partial);
+  assert_equals(expected, args_paths.size() == 1 && args_paths[0].empty() ? js->value.get<std::string>() : js->value.dump());
+}
+static void test_with_args(const std::string & input, const std::string & expected, bool parse_as_partial = true, bool is_partial = true) {
+  common_chat_msg_parser builder(input, parse_as_partial, {});
+  auto js = builder.try_consume_json_with_dumped_args({{"args"}}, {});
+  assert_equals(true, js.has_value());
+  assert_equals(is_partial, js->is_partial);
+  assert_equals(expected, js->value.dump());
+}
+
+static void test_json_with_dumped_args_no_args() {
+  // Normal JSON, nothing to heal, nothing to dump
+  test("{\"name\": \"python\"}", false, {}, {}, "{\"name\":\"python\"}");
+  // Full json is args
+  test("{\"name\": \"python\"}", false, {{}}, {}, "{\"name\":\"python\"}");
+
+  // If the arguments are further down, don't heal partial content.
+  for (const auto & src : barely_healable_jsons) {
+    test(src, true, {{"arguments"}}, {}, "{}");
+  }
+  // But heal content that isn't partial.
+  test("{\"name\": \"python\"", true, {{"arguments"}}, {}, "{\"name\":\"python\"}");
+}
+
+static void test_json_with_dumped_args() {
+
+  // Partial content.
+  test("{\"content\": \"t", true, {}, {{"content"}}, "{\"content\":\"t\"}");
+  test("{\"content\": \"", true, {}, {{"content"}}, "{\"content\":\"\"}");
+  test("{\"content\": ", true, {}, {{"content"}}, "{}");
+
+  // If the entire JSON is the arguments, healing it them dumping it produces the same output as the input (just reformatted).
+  test("{\"name\": \"python", true, {{}}, {}, "{\"name\":\"python");
+  for (const auto & src : barely_healable_jsons) {
+    test(src, true, {{}}, {}, src);
+  }
+
+  // Full JSON w/ args
+  for (auto parse_as_partial : {true, false}) {
+    test_with_args(
+      R"({"name": "python", "args": {"arg1": 1}})",
+      R"({"name":"python","args":"{\"arg1\":1}"})",
+      parse_as_partial,
+      /* is_partial= */ false
+    );
+  }
+
+  // Partial JSON w/ partial args
+  test_with_args(
+    R"({"foo": "bar", "args": {")",
+    R"({"foo":"bar","args":"{\""})"
+  );
+  // Partial args broken in object key
+  test_with_args(
+    R"({"foo": "bar", "args": {"ar)",
+    R"({"foo":"bar","args":"{\"ar"})"
+  );
+  // Partial args broken after object key
+  test_with_args(
+    R"({"foo": "bar", "args": {"arg1")",
+    R"({"foo":"bar","args":"{\"arg1\""})"
+  );
+  // Partial args broken before object value
+  test_with_args(
+    R"({"foo": "bar", "args": {"arg1":)",
+    R"({"foo":"bar","args":"{\"arg1\":"})"
+  );
+  // Partial args broken before object value (space)
+  test_with_args(
+    R"({"foo": "bar", "args": {"arg1": )",
+    R"({"foo":"bar","args":"{\"arg1\":"})"
+  );
+  // Partial args broken in object value that may not be complete (int)
+  test_with_args(
+    R"({"foo": "bar", "args": {"arg1": 1)",
+    R"({"foo":"bar","args":"{\"arg1\":"})"
+  );
+  // Partial args broken in object value that is complete (int)
+  test_with_args(
+    R"({"foo": "bar", "args": {"arg1": 1 )",
+    R"({"foo":"bar","args":"{\"arg1\":1"})"
+  );
+  // Partial args broken in object value that is incomplete (string)
+  test_with_args(
+    R"({"foo": "bar", "args": {"arg1": ")",
+    R"({"foo":"bar","args":"{\"arg1\":\""})"
+  );
+  // Partial args broken in object value that is complete (string)
+  test_with_args(
+    R"({"foo": "bar", "args": {"arg1": "1")",
+    R"({"foo":"bar","args":"{\"arg1\":\"1\""})"
+  );
+  // Partial args broken on array opening
+  test_with_args(
+    R"({"foo": "bar", "args": [)",
+    R"({"foo":"bar","args":"["})"
+  );
+  // Partial args broken on array value that is incomplete (int)
+  test_with_args(
+    R"({"foo": "bar", "args": [1)",
+    R"({"foo":"bar","args":"["})"
+  );
+  // Partial args broken on array value that is complete (int)
+  test_with_args(
+    R"({"foo": "bar", "args": [1 )",
+    R"({"foo":"bar","args":"[1"})"
+  );
+  // Partial args broken on array value that is complete (string)
+  test_with_args(
+    R"({"foo": "bar", "args": ["1")",
+    R"({"foo":"bar","args":"[\"1\""})"
+  );
+  // Partial args broken after array value
+  test_with_args(
+    R"({"foo": "bar", "args": [1,)",
+    R"({"foo":"bar","args":"[1,"})"
+  );
+  // Partial args broken on nested array
+  test_with_args(
+    R"({"foo": "bar", "args": {"arg1": [)",
+    R"({"foo":"bar","args":"{\"arg1\":["})"
+  );
+}
+
+static void test_positions() {
+  {
+    common_chat_msg_parser builder("Hello, world!", /* is_partial= */ false, {});
+    assert_equals<size_t>(0, builder.pos());
+    assert_throws([&]() { builder.move_to(100); });
+    assert_equals<size_t>(0, builder.pos());
+    assert_throws([&]() { builder.move_back(1); });
+    assert_equals<size_t>(0, builder.pos());
+
+    builder.move_to(8);
+    assert_equals<size_t>(8, builder.pos());
+    builder.move_back(1);
+    assert_equals<size_t>(7, builder.pos());
+    assert_equals("world!", builder.consume_rest());
+
+    builder.move_to(0);
+    assert_equals<size_t>(0, builder.pos());
+
+    assert_throws([&]() { builder.finish(); });
+    assert_equals<size_t>(0, builder.pos());
+
+    builder.move_to(builder.input().size());
+    builder.finish();
+  }
+  {
+    common_chat_msg_parser builder("Hello, world!", /* is_partial= */ true, {});
+
+    builder.move_to(builder.input().size());
+    assert_equals<size_t>(builder.input().size(), builder.pos());
+    builder.finish();
+  }
+}
+
+int main() {
+    test_positions();
+    test_json_with_dumped_args_no_args();
+    test_json_with_dumped_args();
+    test_reasoning();
+    test_regex();
+    std::cout << "All tests passed!\n";
+    return 0;
+}
index 4d70da8c32c91024f5fb52ace7b40f99249dec9e..dfcdce350ba86611d7bb44e4a18a946265a71a3b 100644 (file)
 
 using json = nlohmann::ordered_json;
 
+static std::ostream & operator<<(std::ostream & os, const common_chat_msg_diff & diff) {
+    // os << "reasoning_content_delta: " << diff.reasoning_content_delta << '\n';
+    os << "{ content_delta: " << diff.content_delta << "; ";
+    if (diff.tool_call_index != std::string::npos) {
+        os << "tool_call_index: " << diff.tool_call_index << "; ";
+        os << "tool_call_delta.name: " << diff.tool_call_delta.name << "; ";
+        os << "tool_call_delta.id: " << diff.tool_call_delta.id << "; ";
+        os << "tool_call_delta.arguments: " << diff.tool_call_delta.arguments << "; ";
+    }
+    os << "}";
+    return os;
+}
+// operator<< for vector<common_chat_msg_diff>:
+static std::ostream & operator<<(std::ostream & os, const std::vector<common_chat_msg_diff> & diffs) {
+    os << "[\n";
+    for (const auto & diff : diffs) {
+        os << "  " << diff << ",\n";
+    }
+    os << "]";
+    return os;
+}
+static std::ostream & operator<<(std::ostream & os, const common_chat_msg & msg) {
+    os << "{ role: " << msg.role << "; ";
+    os << "content: " << msg.content << "; ";
+    os << "content_parts: [\n";
+    for (const auto & part : msg.content_parts) {
+        os << "  { type: " << part.type << "; text: " << part.text << " },\n";
+    }
+    os << "]; ";
+    os << "reasoning_content: " << msg.reasoning_content << "; ";
+    os << "tool_calls: [\n";
+    for (const auto & tool_call : msg.tool_calls) {
+        os << "  { name: " << tool_call.name << "; arguments: " << tool_call.arguments << "; id: " << tool_call.id << " },\n";
+    }
+    os << "]";
+    os << "}";
+    return os;
+}
+
+template <class T> static bool equals(const T & expected, const T & actual) {
+    return expected == actual;
+}
+
+static common_chat_msg normalize(const common_chat_msg & msg) {
+    common_chat_msg normalized = msg;
+    for (auto & tool_call : normalized.tool_calls) {
+        try {
+            tool_call.arguments = json::parse(tool_call.arguments).dump();
+        } catch (const std::exception &) {
+            // Do nothing
+        }
+    }
+    return normalized;
+}
+template <>
+bool equals(const common_chat_msg & expected, const common_chat_msg & actual) {
+    return normalize(expected) == normalize(actual);
+}
 
 template <class T> static void assert_equals(const T & expected, const T & actual) {
-    if (expected != actual) {
+    if (!equals(expected, actual)) {
         std::cerr << "Expected: " << expected << std::endl;
         std::cerr << "Actual: " << actual << std::endl;
         std::cerr << std::flush;
@@ -77,6 +135,15 @@ static bool match_string(const std::string & input, llama_grammar * grammar) {
     return false;
 }
 
+static std::string renormalize_json(const std::string & json_str) {
+    try {
+        auto json_obj = json::parse(json_str);
+        return json_obj.dump();
+    } catch (const std::exception & e) {
+        std::cerr << "Failed to parse JSON: " << e.what() << '\n';
+        return json_str;
+    }
+}
 static void assert_msg_equals(const common_chat_msg & expected, const common_chat_msg & actual) {
     assert_equals(expected.role, actual.role);
     assert_equals(expected.content, actual.content);
@@ -93,7 +160,7 @@ static void assert_msg_equals(const common_chat_msg & expected, const common_cha
         const auto & expected_tool_call = expected.tool_calls[i];
         const auto & actual_tool_call   = actual.tool_calls[i];
         assert_equals(expected_tool_call.name, actual_tool_call.name);
-        assert_equals(json::parse(expected_tool_call.arguments).dump(), json::parse(actual_tool_call.arguments).dump());
+        assert_equals(renormalize_json(expected_tool_call.arguments), renormalize_json(actual_tool_call.arguments));
         assert_equals(expected_tool_call.id, actual_tool_call.id);
     }
 }
@@ -152,14 +219,12 @@ static delta_data init_delta(const struct common_chat_templates * tmpls, const s
                              const common_chat_msg & user_message,
                              const common_chat_msg & delta_message,
                              const std::vector<common_chat_tool> & tools,
-                             const common_chat_tool_choice & tool_choice,
-                             bool think = false) {
+                             const common_chat_tool_choice & tool_choice) {
     common_chat_templates_inputs inputs;
     inputs.parallel_tool_calls = true;
     inputs.messages.push_back(user_message);
     inputs.tools       = tools;
     inputs.tool_choice = tool_choice;
-    inputs.extract_reasoning = think;
     auto params_prefix = common_chat_templates_apply(tmpls, inputs);
 
     inputs.messages.push_back(delta_message);
@@ -211,19 +276,22 @@ static void test_templates(const struct common_chat_templates * tmpls, const std
                           const std::string & expected_delta = "",
                           bool expect_grammar_triggered = true,
                           bool test_grammar_if_triggered = true,
-                          bool think = false) {
+                          common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE) {
     common_chat_msg user_message;
     user_message.role = "user";
     user_message.content = "Hello, world!";
 
     for (const auto & tool_choice : std::vector<common_chat_tool_choice> {COMMON_CHAT_TOOL_CHOICE_AUTO, COMMON_CHAT_TOOL_CHOICE_REQUIRED}) {
-        auto data = init_delta(tmpls, end_tokens, user_message, test_message, tools, tool_choice, think);
+        auto data = init_delta(tmpls, end_tokens, user_message, test_message, tools, tool_choice);
         if (!expected_delta.empty()) {
             assert_equals(expected_delta, data.delta);
         }
 
         if (expect_grammar_triggered) {
-            const auto msg = common_chat_parse(data.delta, data.params.format);
+            common_chat_syntax syntax;
+            syntax.format = data.params.format;
+            syntax.reasoning_format = reasoning_format;
+            const auto msg = common_chat_parse(data.delta, /* is_partial= */ false, syntax);
             assert_msg_equals(test_message, msg);
         }
 
@@ -251,15 +319,25 @@ static void test_templates(const struct common_chat_templates * tmpls, const std
                     {
                         const auto & pattern = trigger.value;
                         if (std::regex_search(constrained, match, std::regex(pattern))) {
-                            pos = match.position();
+                            pos = match.position(1);
                         }
                         break;
                     }
-                    case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
+                    case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
                     {
                         const auto & pattern = trigger.value;
-                        if (std::regex_search(constrained, match, std::regex(pattern)) && match.position() == 0) {
-                            pos = 0;
+                        if (std::regex_match(constrained, match, std::regex(pattern))) {
+                            auto mpos = std::string::npos;
+                            for (size_t i = 1; i < match.size(); ++i) {
+                                if (match[i].length() > 0) {
+                                    mpos = match.position(i);
+                                    break;
+                                }
+                            }
+                            if (mpos == std::string::npos) {
+                                mpos = match.position(0);
+                            }
+                            pos = mpos;
                         }
                         break;
                     }
@@ -313,117 +391,39 @@ const common_chat_msg message_user_parts {
     /* .tool_name = */ "",
     /* .tool_call_id = */ "",
 };
-const common_chat_msg message_assist {
-    "assistant",
-    "Hello, world!\nWhat's up?",
-    /* .content_parts = */ {},
-    /* .tool_calls = */ {},
-    /* .reasoning_content = */ "",
-    /* .tool_name = */ "",
-    /* .tool_call_id = */ "",
-};
-const common_chat_msg message_assist_thoughts_unparsed_think {
-    "assistant",
-    "<think>I'm thinking</think>Hello, world!\nWhat's up?",
-    /* .content_parts = */ {},
-    /* .tool_calls = */ {},
-    /* .reasoning_content = */ "",
-    /* .tool_name = */ "",
-    /* .tool_call_id = */ "",
-};
-const common_chat_msg message_assist_thoughts_unparsed_r7b {
-    "assistant",
-    "<|START_THINKING|>I'm thinking<|END_THINKING|>Hello, world!\nWhat's up?",
-    /* .content_parts = */ {},
-    /* .tool_calls = */ {},
-    /* .reasoning_content = */ "",
-    /* .tool_name = */ "",
-    /* .tool_call_id = */ "",
-};
-const common_chat_msg message_assist_thoughts {
-    "assistant",
-    "Hello, world!\nWhat's up?",
-    /* .content_parts = */ {},
-    /* .tool_calls = */ {},
-    /* .reasoning_content = */ "I'm thinking",
-    /* .tool_name = */ "",
-    /* .tool_call_id = */ "",
-};
-const std::vector<common_chat_tool_call> tool_calls {
-    { "special_function", "{\"arg1\": 1}", /* .id = */ "" },
-};
-const std::vector<common_chat_tool_call> tool_calls_idx {
-    { "special_function", "{\"arg1\": 1}", /* .id = */ "0" },
-};
-const std::vector<common_chat_tool_call> tool_calls_id {
-    { "special_function", "{\"arg1\": 1}", /* .id = */ "123456789" },
-};
-
-const common_chat_msg message_assist_call {
-    "assistant",
-    "",
-    /* .content_parts = */ {},
-    tool_calls,
-    /* .reasoning_content = */ "",
-    /* .tool_name = */ "",
-    /* .tool_call_id = */ "",
-};
-const common_chat_msg message_assist_call_thoughts = {
-    "assistant",
-    /* .content = */ "",
-    /* .content_parts = */ {},
-    tool_calls,
-    /* .reasoning_content = */ "I'm\nthinking",
-    /* .tool_name = */ "",
-    /* .tool_call_id = */ "",
-};
-const common_chat_msg message_assist_call_thoughts_unparsed = {
-    "assistant",
-    /* .content = */ "<think>I'm\nthinking</think>",
-    /* .content_parts = */ {},
-    tool_calls,
-    /* .reasoning_content = */ "",
-    /* .tool_name = */ "",
-    /* .tool_call_id = */ "",
-};
-const common_chat_msg message_assist_call_id {
-    "assistant",
-    "",
-    /* .content_parts = */ {},
-    tool_calls_id,
-    /* .reasoning_content = */ "",
-    /* .tool_name = */ "",
-    /* .tool_call_id = */ "",
-};
-const common_chat_msg message_assist_call_idx {
-    "assistant",
-    "",
-    /* .content_parts = */ {},
-    tool_calls_idx,
-    /* .reasoning_content = */ "",
-    /* .tool_name = */ "",
-    /* .tool_call_id = */ "",
-};
-const common_chat_msg message_assist_call_python {
-    "assistant",
-    "",
-    /* .content_parts = */ {},
-    { { "python", "{\"code\": \"print('hey')\"}", /* .id = */ "" } },
-    /* .reasoning_content = */ "",
-    /* .tool_name = */ "",
-    /* .tool_call_id = */ "",
-};
-const common_chat_msg message_assist_call_code_interpreter {
-    "assistant",
-    "",
-    /* .content_parts = */ {},
-    { { "code_interpreter", "{\"code\": \"print('hey')\"}", /* .id = */ "" } },
-    /* .reasoning_content = */ "",
-    /* .tool_name = */ "",
-    /* .tool_call_id = */ "",
-};
+static common_chat_msg simple_assist_msg(const std::string & content, const std::string & reasoning_content = "", const std::string & tool_name = "", const std::string & arguments = "", const std::string & id = "") {
+    common_chat_msg msg;
+    msg.role = "assistant";
+    msg.content = content;
+    msg.reasoning_content = reasoning_content;
+    if (!tool_name.empty()) {
+        msg.tool_calls.push_back({ tool_name, arguments, id });
+    }
+    return msg;
+}
+const common_chat_msg message_assist                             = simple_assist_msg("Hello, world!\nWhat's up?");
+const common_chat_msg message_assist_empty                       = simple_assist_msg("");
+const common_chat_msg message_assist_thoughts_unparsed_deepseek  = simple_assist_msg("<think>I'm\nthinking</think>Hello, world!\nWhat's up?");
+const common_chat_msg message_assist_thoughts_unparsed_r7b       = simple_assist_msg("<|START_THINKING|>I'm\nthinking<|END_THINKING|>Hello, world!\nWhat's up?");
+const common_chat_msg message_assist_thoughts                    = simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking");
+const common_chat_msg message_assist_thoughts_unopened_unparsed  = simple_assist_msg("I'm\nthinking</think>Hello, world!\nWhat's up?");
+const common_chat_msg message_assist_thoughts_no_content         = simple_assist_msg("", "I'm\nthinking");
+const common_chat_msg message_assist_call                        = simple_assist_msg("", "", "special_function", "{\"arg1\": 1}");
+const common_chat_msg message_assist_call_content                = simple_assist_msg("Hello, world!\nWhat's up?", "", "special_function", "{\"arg1\":1}");
+const common_chat_msg message_assist_call_empty_args             = simple_assist_msg("", "", "special_function");
+const common_chat_msg message_assist_call_cutoff_args            = simple_assist_msg("", "", "special_function", "{\"arg");
+const common_chat_msg message_assist_call_thoughts               = simple_assist_msg("", "I'm\nthinking", "special_function", "{\"arg1\":1}");
+const common_chat_msg message_assist_call_thoughts_unparsed      = simple_assist_msg("<think>I'm\nthinking</think>\n\n", "", "special_function", "{\"arg1\": 1}");
+const common_chat_msg message_assist_call_id                     = simple_assist_msg("", "", "special_function", "{\"arg1\":1}", /* .id = */ "123456789");
+const common_chat_msg message_assist_call_idx                    = simple_assist_msg("", "", "special_function", "{\"arg1\":1}", /* .id = */ "0");
+const common_chat_msg message_assist_thoughts_call_idx           = simple_assist_msg("", "I'm\nthinking", "special_function", "{\"arg1\": 1}", /* id = */ "0");
+const common_chat_msg message_assist_call_python                 = simple_assist_msg("", "", "python", "{\"code\":\"print('hey')\"}");
+const common_chat_msg message_assist_call_python_lines           = simple_assist_msg("", "", "python", "{\"code\":\"# This is a program:\\nprint('hey')\"}");
+const common_chat_msg message_assist_call_python_lines_unclosed  = simple_assist_msg("", "", "python", "{\"code\":\"# This is a program:\\nprint('hey')");
+const common_chat_msg message_assist_call_code_interpreter       = simple_assist_msg("", "", "code_interpreter", "{\"code\":\"print('hey')\"}");
 
 static void test_msgs_oaicompat_json_conversion() {
+    printf("[%s]\n", __func__);
     std::vector<common_chat_msg> msgs{
         message_user,
         message_user_parts,
@@ -473,7 +473,7 @@ static void test_msgs_oaicompat_json_conversion() {
             "        \"type\": \"function\",\n"
             "        \"function\": {\n"
             "          \"name\": \"python\",\n"
-            "          \"arguments\": \"{\\\"code\\\": \\\"print('hey')\\\"}\"\n"
+            "          \"arguments\": \"{\\\"code\\\":\\\"print('hey')\\\"}\"\n"
             "        }\n"
             "      }\n"
             "    ]\n"
@@ -499,6 +499,7 @@ static void test_msgs_oaicompat_json_conversion() {
 }
 
 static void test_tools_oaicompat_json_conversion() {
+    printf("[%s]\n", __func__);
     std::vector<common_chat_tool> tools{
         special_function_tool,
         python_tool,
@@ -543,29 +544,18 @@ static void test_tools_oaicompat_json_conversion() {
 }
 
 static void test_template_output_parsers() {
+    printf("[%s]\n", __func__);
 
     common_chat_templates_inputs inputs_no_tools;
     inputs_no_tools.messages                = {message_user};
-    inputs_no_tools.extract_reasoning       = false;
-
-    common_chat_templates_inputs inputs_no_tools_think;
-    inputs_no_tools_think.messages          = {message_user};
-    inputs_no_tools_think.extract_reasoning = true;
 
     common_chat_templates_inputs inputs_tools;
     inputs_tools.messages                   = {message_user};
     inputs_tools.tools                      = {special_function_tool};
-    inputs_tools.extract_reasoning          = false;
-
-    common_chat_templates_inputs inputs_tools_think;
-    inputs_tools_think.messages             = {message_user};
-    inputs_tools_think.tools                = {special_function_tool};
-    inputs_tools_think.extract_reasoning    = true;
 
     common_chat_templates_inputs inputs_tools_builtin;
     inputs_tools_builtin.messages           = {message_user};
     inputs_tools_builtin.tools              = {python_tool};
-    inputs_tools_builtin.extract_reasoning  = false;
 
     {
         // Not supported yet
@@ -577,44 +567,95 @@ static void test_template_output_parsers() {
         auto tmpls = read_templates("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja");
         std::vector<std::string>   end_tokens{ "<|END_OF_TURN_TOKEN|>" };
 
-        assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B,                   common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
-        assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B,                   common_chat_templates_apply(tmpls.get(), inputs_tools).format);
-        assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format);
+        for (const auto & inputs : { inputs_no_tools, inputs_tools }) {
+            auto params = common_chat_templates_apply(tmpls.get(), inputs);
+            assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, params.format);
+            assert_equals(false, params.thinking_forced_open);
+        }
 
         assert_msg_equals(message_assist,
             common_chat_parse(
                 "Hello, world!\nWhat's up?",
-                COMMON_CHAT_FORMAT_COMMAND_R7B));
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_COMMAND_R7B}));
         assert_msg_equals(message_assist,
             common_chat_parse(
-                "Hello, world!\nWhat's up?<|END_RESPONSE|>",
-                COMMON_CHAT_FORMAT_COMMAND_R7B));
-        assert_msg_equals(message_assist,
+                "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_COMMAND_R7B}));
+        assert_msg_equals(message_assist_thoughts,
             common_chat_parse(
+                "<|START_THINKING|>I'm\nthinking<|END_THINKING|>"
                 "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
-                COMMON_CHAT_FORMAT_COMMAND_R7B));
-        assert_msg_equals(message_assist_thoughts_unparsed_r7b,
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_COMMAND_R7B,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ false,
+                    /* .thinking_forced_open = */ false,
+                }));
+        assert_msg_equals(message_assist_thoughts_unparsed_deepseek,
             common_chat_parse(
-                "<|START_THINKING|>I'm thinking<|END_THINKING|>"
+                "<|START_THINKING|>I'm\nthinking<|END_THINKING|>"
                 "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
-                COMMON_CHAT_FORMAT_COMMAND_R7B));
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_COMMAND_R7B,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ true,
+                    /* .thinking_forced_open = */ false,
+                }));
         assert_msg_equals(message_assist_thoughts_unparsed_r7b,
             common_chat_parse(
-                "<|START_THINKING|>I'm thinking<|END_THINKING|>"
-                "Hello, world!\nWhat's up?<|END_RESPONSE|>",
-                COMMON_CHAT_FORMAT_COMMAND_R7B));
-
+                "<|START_THINKING|>I'm\nthinking<|END_THINKING|>"
+                "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_COMMAND_R7B}));
         assert_msg_equals(message_assist_thoughts,
             common_chat_parse(
-                "<|START_THINKING|>I'm thinking<|END_THINKING|>"
+                "<|START_THINKING|>I'm\nthinking<|END_THINKING|>"
                 "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
-                COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING));
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_COMMAND_R7B,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ false,
+                    /* .thinking_forced_open = */ false,
+                }));
+        assert_msg_equals(message_assist_thoughts_call_idx,
+            common_chat_parse(
+                "<|START_THINKING|>I'm\nthinking<|END_THINKING|>"
+                "<|START_ACTION|>[\n"
+                "    {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n"
+                "]<|END_ACTION|>",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_COMMAND_R7B,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ false,
+                    /* .thinking_forced_open = */ false,
+                }));
+        assert_msg_equals(message_assist_thoughts_no_content,
+            common_chat_parse(
+                "<|START_THINKING|>I'm\nthinking<|END_THINKING|>"
+                "<|START_ACTION|>[\n"
+                "    {\"tool_call_id\": \"0\", \"tool_name\": \"special",
+                /* is_partial= */ true,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_COMMAND_R7B,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ false,
+                    /* .thinking_forced_open = */ false,
+                }));
 
         test_templates(tmpls.get(), end_tokens, message_assist_call_idx, tools,
                       "<|START_THINKING|><|END_THINKING|>"
                       "<|START_ACTION|>[\n"
                       "    {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n"
-                      "]<|END_ACTION|>");
+                      "]<|END_ACTION|>",
+                      /* expect_grammar_triggered= */ true,
+                      /* test_grammar_if_triggered= */ true,
+                      COMMON_REASONING_FORMAT_DEEPSEEK);
         test_templates(tmpls.get(), end_tokens, message_assist, tools,
                       "<|START_RESPONSE|>Hello, world!\n"
                       "What's up?<|END_RESPONSE|>",
@@ -634,11 +675,40 @@ static void test_template_output_parsers() {
 
         // Generic tool calls doesn't generate / parse content-only messages symmetrically.
 
+        assert_equals(
+            message_assist_empty,
+            common_chat_parse(
+                "{ \"tool_call\" : { \"name\" : \"t",
+                /* is_partial= */ true,
+                {COMMON_CHAT_FORMAT_GENERIC}));
+
+        assert_equals(
+            simple_assist_msg("", "", "puppeteer_screenshot", "{\"name\":\"servethehome_homepage\","),
+            common_chat_parse(
+                R"({"tool_call": {"name": "puppeteer_screenshot", "arguments": {"name": "servethehome_homepage",)",
+                /* is_partial= */ true,
+                {COMMON_CHAT_FORMAT_GENERIC}));
+
+        assert_equals(
+            message_assist_call_empty_args,
+            common_chat_parse(
+                "{ \"tool_call\" : { \"name\" : \"special_function\"",
+                /* is_partial= */ true,
+                {COMMON_CHAT_FORMAT_GENERIC}));
+        assert_equals(
+            message_assist_call_cutoff_args,
+            common_chat_parse(
+                "{ \"tool_call\" : { \"name\" : \"special_function\", \"arguments\" : { \"arg",
+                /* is_partial= */ true,
+                {COMMON_CHAT_FORMAT_GENERIC}));
+
         assert_msg_equals(message_assist,
-                          common_chat_parse("{\n"
-                                            "  \"response\": \"Hello, world!\\nWhat's up?\"\n"
-                                            "}",
-                                            common_chat_templates_apply(tmpls.get(), inputs_tools).format));
+            common_chat_parse(
+                "{\n"
+                "  \"response\": \"Hello, world!\\nWhat's up?\"\n"
+                "}",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_GENERIC}));
         test_templates(tmpls.get(), end_tokens, message_assist_call_id, tools,
                       "{\n"
                       "  \"tool_calls\": [\n"
@@ -663,6 +733,13 @@ static void test_template_output_parsers() {
             tmpls.get(), end_tokens, message_assist_call_id, tools,
             "[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]");
     }
+    {
+        auto tmpls = read_templates("models/templates/Qwen-QwQ-32B.jinja");
+        std::vector<std::string> end_tokens{ "<|im_end|>" };
+
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+    }
     {
         auto tmpls = read_templates("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja");
         std::vector<std::string> end_tokens{ "<|im_end|>" };
@@ -683,113 +760,257 @@ static void test_template_output_parsers() {
                 .format);
 
         // Test parsing
-        assert_msg_equals(message_assist_call, common_chat_parse(
-            "<tool_call>\n"
-            "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
-            "</tool_call>",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO));
-        assert_msg_equals(message_assist_call, common_chat_parse(
-            "<function=special_function>{\"arg1\": 1}</function>",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO));
-        assert_msg_equals(message_assist_call, common_chat_parse(
-            "<function name=\"special_function\">\n"
-            "{\"arg1\": 1}\n"
-            "</function>",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO));
-        assert_msg_equals(message_assist_call, common_chat_parse(
-            "<tool>\n"
-            "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
-            "</tool>",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO));
-        assert_msg_equals(message_assist_call, common_chat_parse(
-            "<tools>\n"
-            "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
-            "</tools>",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO));
-        assert_msg_equals(message_assist_call, common_chat_parse(
-            "<response>\n"
-            "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
-            "</response>",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO));
-        assert_msg_equals(message_assist_call, common_chat_parse(
-            "```xml\n"
-            "<response>\n"
-            "    {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
-            "</response>\n"
-            "```",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO));
-        assert_msg_equals(message_assist_call, common_chat_parse(
-            "```xml\n"
-            "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
-            "```",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO));
-        assert_msg_equals(message_assist_call, common_chat_parse(
-            "```\n"
-            "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
-            "```",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO));
-        assert_msg_equals(message_assist_call, common_chat_parse(
-            "```\n"
-            "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
-            "```",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO));
-        assert_msg_equals(message_assist_call, common_chat_parse(
-            "```json\n"
-            "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
-            "```",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO));
-        assert_msg_equals(message_assist_call, common_chat_parse(
-            "```json\n"
-            "\n"
-            "                    <function_call> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}} \n"
-            "                    </function_call> \n"
-            "``` ",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO));
-        assert_msg_equals(message_assist_call, common_chat_parse(
-            "<json>\n"
-            "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
-            "</json>",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO));
-        assert_msg_equals(message_assist_call, common_chat_parse(
-            "<xml>\n"
-            "  {\n"
-            "    \"name\": \"special_function\", \"arguments\": {\"arg1\": 1}\n"
-            "  }\n"
-            "</xml>",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO));
-        assert_msg_equals(message_assist_call, common_chat_parse(
-            "<JSON>\n"
-            "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
-            "</JSON>",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO));
-        assert_msg_equals(message_assist_call, common_chat_parse(
-            "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO));
-        assert_msg_equals(message_assist_call, common_chat_parse(
-            "{\n  \"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO));
-
-        assert_msg_equals(message_assist_thoughts_unparsed_think,
-            common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO));
-        assert_msg_equals(message_assist_thoughts_unparsed_think,
-            common_chat_parse("I'm thinking</think>Hello, world!\nWhat's up?",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+        assert_msg_equals(
+            simple_assist_msg("", "", "python", ""),
+            common_chat_parse(
+                "```json\n"
+                "<function_call> { \"name\" : \"python\"",
+                /* is_partial= */ true,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            simple_assist_msg("Let's call something\n"),
+            common_chat_parse(
+                "Let's call something\n"
+                "<tool_call>{\"name\"",
+                /* is_partial= */ true,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ false,
+                    /* .thinking_forced_open = */ false,
+                }));
+        assert_msg_equals(
+            simple_assist_msg(""),
+            common_chat_parse(
+                "Let's call something\n"
+                "<tool_call>{\"name",
+                /* is_partial= */ true,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ false,
+                    /* .thinking_forced_open = */ false,
+                }));
+        assert_msg_equals(message_assist_call_thoughts,
+            common_chat_parse(
+                // QwQ-32B's template adds a trailing <think> if add_generation_prompt
+                "I'm\nthinking</think>\n"
+                "<tool_call>{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}</tool_call>",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ false,
+                    /* .thinking_forced_open = */ true,
+                }));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "<tool_call>\n"
+                "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                "</tool_call>",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(message_assist_call_content,
+            common_chat_parse(
+                "Hello, world!\nWhat's up?<tool_call>\n"
+                "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                "</tool_call>",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "<function=special_function>{\"arg1\": 1}</function>",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "<function name=\"special_function\">\n"
+                "{\"arg1\": 1}\n"
+                "</function>",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "<tool>\n"
+                "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                "</tool>",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "<tools>\n"
+                "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                "</tools>",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "<response>\n"
+                "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                "</response>",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "```xml\n"
+                "<response>\n"
+                "    {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                "</response>\n"
+                "```",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "```xml\n"
+                "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                "```",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "```\n"
+                "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                "```",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "```\n"
+                "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                "```",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "```json\n"
+                "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                "```",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "```json\n"
+                "\n"
+                "                    <function_call> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}} \n"
+                "                    </function_call> \n"
+                "``` ",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "<json>\n"
+                "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                "</json>",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "<xml>\n"
+                "  {\n"
+                "    \"name\": \"special_function\", \"arguments\": {\"arg1\": 1}\n"
+                "  }\n"
+                "</xml>",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "<JSON>\n"
+                "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                "</JSON>",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "{\n  \"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+
+        assert_msg_equals(
+            simple_assist_msg(
+                "This is not a tool call:",
+                "",
+                "special_function",
+                "{\"arg1\": 1}"),
+            common_chat_parse(
+                "This is not a tool call:\n"
+                "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(message_assist,
+            common_chat_parse(
+                "Hello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(message_assist_thoughts_unparsed_deepseek,
+            common_chat_parse(
+                "<think>I'm\nthinking</think>Hello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        // assert_msg_equals(message_assist_thoughts_unparsed_deepseek,
+        //     common_chat_parse(
+        //         "I'm\nthinking</think>Hello, world!\nWhat's up?",
+        //         COMMON_CHAT_FORMAT_HERMES_2_PRO));
         assert_msg_equals(message_assist_thoughts,
-            common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING));
+            common_chat_parse(
+                "<think>I'm\nthinking</think>Hello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ false,
+                    /* .thinking_forced_open = */ false,
+                }));
+        assert_msg_equals(message_assist_thoughts_unopened_unparsed,
+            common_chat_parse(
+                "I'm\nthinking</think>Hello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ false,
+                    /* .thinking_forced_open = */ false,
+                }));
         assert_msg_equals(message_assist_thoughts,
-            common_chat_parse("I'm thinking</think>Hello, world!\nWhat's up?",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING));
+            common_chat_parse(
+                "I'm\nthinking</think>Hello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ false,
+                    /* .thinking_forced_open = */ true,
+                }));
 
         test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
         test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
                       "<tool_call>\n"
                       "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
                       "</tool_call>");
-        test_templates(tmpls.get(), end_tokens, message_assist_call_python, tools,
+        test_templates(tmpls.get(), end_tokens, message_assist_call_python_lines, tools,
                       "<tool_call>\n"
-                      "{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n"
+                      "{\"name\": \"python\", \"arguments\": {\"code\":\"# This is a program:\\nprint('hey')\"}}\n"
                       "</tool_call>");
     }
     {
@@ -806,6 +1027,13 @@ static void test_template_output_parsers() {
                           inputs_tools_builtin)
                           .format);
 
+        assert_equals(
+            message_assist_call,
+            common_chat_parse(
+                "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_LLAMA_3_X}));
+
         // test_templates(tmpls.get(), end_tokens, message_assist, tools, R"(?)", /* expect_grammar_triggered= */ false);
         test_templates(tmpls.get(), end_tokens, message_assist_call_code_interpreter, llama_3_1_tools,
                       "<|python_tag|>code_interpreter.call(code=\"print('hey')\")");
@@ -836,6 +1064,15 @@ static void test_template_output_parsers() {
         assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY,
                         common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
 
+        for (auto is_partial : { false, true }) {
+            assert_equals(
+                message_assist_call,
+                common_chat_parse(
+                    "<function=special_function>{\"arg1\": 1}</function>",
+                    is_partial,
+                    {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1}));
+        }
+
         test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
         test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
                       "<function=special_function>{\"arg1\": 1}</function>");
@@ -847,6 +1084,47 @@ static void test_template_output_parsers() {
         assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
 
+        assert_msg_equals(
+            simple_assist_msg(
+                "Hello, world!\nnono\nWhat's up?",
+                "",
+                "special_function",
+                "{\"arg1\": 1}"),
+            common_chat_parse(
+                "all\n"
+                "Hello, world!\n"
+                "nono\n"
+                "What's up?>>>special_function\n"
+                "{\"arg1\": 1}\n",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2}));
+        assert_msg_equals(message_assist_call_python_lines,
+            common_chat_parse(
+                "python\n"
+                "# This is a program:\n"
+                "print('hey')",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2}));
+        assert_msg_equals(message_assist_call_python_lines_unclosed,
+            common_chat_parse(
+                "python\n"
+                "# This is a program:\n"
+                "print('hey')",
+                /* is_partial= */ true,
+                {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2}));
+        assert_msg_equals(message_assist_call,
+            common_chat_parse(
+                "special_function\n"
+                "{\"arg1\": 1} \n                    ",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2}));
+        assert_msg_equals(message_assist,
+            common_chat_parse(
+                "all\n"
+                "Hello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2}));
+
         test_templates(tmpls.get(), end_tokens, message_assist, {},
                       "all\n"
                       "Hello, world!\n"
@@ -872,22 +1150,77 @@ static void test_template_output_parsers() {
         auto tmpls = read_templates("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja");
         std::vector<std::string>   end_tokens{ "<|end▁of▁sentence|>" };
 
-        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1,                   common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
-        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1,                   common_chat_templates_apply(tmpls.get(), inputs_tools).format);
-        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format);
+        for (const auto & inputs : { inputs_no_tools, inputs_tools }) {
+            auto params = common_chat_templates_apply(tmpls.get(), inputs);
+            assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, params.format);
+            assert_equals(true, params.thinking_forced_open);
+        }
 
         test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
         test_templates(tmpls.get(), end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
-        assert_msg_equals(message_assist_thoughts_unparsed_think,
-            common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
-            COMMON_CHAT_FORMAT_DEEPSEEK_R1));
+        assert_msg_equals(
+            simple_assist_msg("Hello, world!\nWhat's up?", "<think>I'm\nthinking"),
+            common_chat_parse(
+                "<think>I'm\nthinking</think>Hello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {
+                    COMMON_CHAT_FORMAT_DEEPSEEK_R1,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ false,
+                    /* .thinking_forced_open = */ true,
+                }));
+        assert_msg_equals(
+            simple_assist_msg("", "I need to remember the correct syntax. It starts with <|tool▁calls▁begin|> and ends with"),
+            common_chat_parse(
+                "I need to remember the correct syntax. It starts with <|tool▁calls▁begin|> and ends with",
+                /* is_partial= */ true,
+                {
+                    COMMON_CHAT_FORMAT_DEEPSEEK_R1,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ false,
+                    /* .thinking_forced_open = */ true,
+                }));
+        assert_msg_equals(message_assist_thoughts,
+            common_chat_parse(
+                "<think>I'm\nthinking</think>Hello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_R1,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ false,
+                    /* .thinking_forced_open = */ false,
+                }));
+        assert_msg_equals(message_assist_thoughts_unopened_unparsed,
+            common_chat_parse(
+                "I'm\nthinking</think>Hello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_R1,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ false,
+                    /* .thinking_forced_open = */ false,
+                }));
         assert_msg_equals(message_assist_thoughts,
-            common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
-            COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
+            common_chat_parse(
+                "I'm\nthinking</think>Hello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_R1,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ false,
+                    /* .thinking_forced_open = */ true,
+                }));
         assert_msg_equals(message_assist_thoughts,
             // Latest template update (ast of 20250209) adds a trailing <think>\n if add_generation_prompt is true.
-            common_chat_parse("I'm thinking</think>Hello, world!\nWhat's up?",
-            COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
+            common_chat_parse(
+                "I'm\nthinking</think>Hello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_R1,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ false,
+                    /* .thinking_forced_open = */ true,
+                }));
         // test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
         //               "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
         //               "```json\n"
@@ -904,16 +1237,34 @@ static void test_template_output_parsers() {
 
         assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1,                   common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1,                   common_chat_templates_apply(tmpls.get(), inputs_tools).format);
-        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format);
 
         test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
         test_templates(tmpls.get(), end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
-        assert_msg_equals(message_assist_thoughts_unparsed_think,
-            common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
-            COMMON_CHAT_FORMAT_DEEPSEEK_R1));
+        assert_msg_equals(message_assist_thoughts_unparsed_deepseek,
+            common_chat_parse(
+                "<think>I'm\nthinking</think>Hello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_DEEPSEEK_R1}));
         assert_msg_equals(message_assist_thoughts,
-            common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
-            COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
+            common_chat_parse(
+                "<think>I'm\nthinking</think>Hello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_R1,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ false,
+                    /* .thinking_forced_open = */ false,
+                }));
+        assert_msg_equals(message_assist_thoughts,
+            common_chat_parse(
+                "I'm\nthinking</think>Hello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_R1,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ false,
+                    /* .thinking_forced_open = */ true,
+                }));
 
         assert_msg_equals(message_assist_call_thoughts_unparsed,
             common_chat_parse(
@@ -922,7 +1273,17 @@ static void test_template_output_parsers() {
                 "```json\n"
                 "{\"arg1\": 1}\n"
                 "```<|tool▁call▁end|><|tool▁calls▁end|>",
-                COMMON_CHAT_FORMAT_DEEPSEEK_R1));
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_DEEPSEEK_R1}));
+        assert_msg_equals(message_assist_call,
+            common_chat_parse(
+                "<|tool▁calls|>function<|tool▁sep|>special_function\n"
+                "```json\n"
+                "{\"arg1\": 1}\n"
+                "```<|tool▁call▁end|><|tool▁calls▁end|>",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_DEEPSEEK_R1}));
+
         assert_msg_equals(message_assist_call_thoughts,
             common_chat_parse(
                 "<think>I'm\nthinking</think>\n\n"
@@ -930,7 +1291,13 @@ static void test_template_output_parsers() {
                 "```json\n"
                 "{\"arg1\": 1}\n"
                 "```<|tool▁call▁end|><|tool▁calls▁end|>",
-                COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_R1,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ false,
+                    /* .thinking_forced_open = */ false,
+                }));
         test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
                 "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
                 "```json\n"
@@ -939,6 +1306,91 @@ static void test_template_output_parsers() {
     }
 }
 
+static void test_msg_diffs_compute() {
+    printf("[%s]\n", __func__);
+    {
+        common_chat_msg msg1;
+
+        common_chat_msg msg2;
+        msg2.content = "Hello, world!";
+
+        common_chat_msg_diff diff;
+        diff.content_delta = "Hello, world!";
+
+        assert_equals(
+            {diff},
+            common_chat_msg_diff::compute_diffs(msg1, msg2));
+    }
+    {
+        common_chat_msg msg1;
+        msg1.content = "Hello,";
+
+        common_chat_msg msg2;
+        msg2.content = "Hello, world!";
+
+        common_chat_msg_diff diff;
+        diff.content_delta = " world!";
+
+        assert_equals(
+            {diff},
+            common_chat_msg_diff::compute_diffs(msg1, msg2));
+    }
+    {
+        common_chat_msg msg0;
+
+        common_chat_msg msg1;
+        msg1.tool_calls = { { "special_function", "{\"ar", /* .id = */ "123" } };
+
+        common_chat_msg msg2;
+        msg2.tool_calls = { { "special_function", "{\"arg1\": 1}", /* .id = */ "123" } };
+
+        common_chat_msg_diff diff01;
+        diff01.tool_call_index = 0;
+        diff01.tool_call_delta.name = "special_function";
+        diff01.tool_call_delta.id = "123";
+        diff01.tool_call_delta.arguments = "{\"ar";
+
+        assert_equals(
+            {diff01},
+            common_chat_msg_diff::compute_diffs(msg0, msg1));
+
+        common_chat_msg_diff diff12;
+        diff12.tool_call_index = 0;
+        diff12.tool_call_delta.name = "special_function";
+        // Note: id doesnt change here.
+        diff12.tool_call_delta.arguments = "g1\": 1}";
+
+        assert_equals(
+            {diff12},
+            common_chat_msg_diff::compute_diffs(msg1, msg2));
+    }
+    {
+        common_chat_msg msg0;
+
+        common_chat_msg msg2;
+        msg2.tool_calls = {
+            { "f1", "{\"arg1\": 1}", /* .id = */ "123" },
+            { "f2", "{\"arg2\": 2}", /* .id = */ "222" },
+        };
+
+        common_chat_msg_diff diff1;
+        diff1.tool_call_index = 0;
+        diff1.tool_call_delta.name = "f1";
+        diff1.tool_call_delta.id = "123";
+        diff1.tool_call_delta.arguments = "{\"arg1\": 1}";
+
+        common_chat_msg_diff diff2;
+        diff2.tool_call_index = 1;
+        diff2.tool_call_delta.name = "f2";
+        diff2.tool_call_delta.id = "222";
+        diff2.tool_call_delta.arguments = "{\"arg2\": 2}";
+
+        assert_equals(
+            {diff1, diff2},
+            common_chat_msg_diff::compute_diffs(msg0, msg2));
+    }
+}
+
 int main(int argc, char ** argv) {
     // try {
 #ifndef _WIN32
@@ -972,6 +1424,7 @@ int main(int argc, char ** argv) {
         } else
 #endif
         {
+            test_msg_diffs_compute();
             test_msgs_oaicompat_json_conversion();
             test_tools_oaicompat_json_conversion();
             test_template_output_parsers();
diff --git a/tests/test-json-partial.cpp b/tests/test-json-partial.cpp
new file mode 100644 (file)
index 0000000..bc136be
--- /dev/null
@@ -0,0 +1,237 @@
+#include "common.h"
+#include "json-partial.h"
+#include <exception>
+#include <iostream>
+#include <stdexcept>
+
+template <class T> static void assert_equals(const T & expected, const T & actual) {
+  if (expected != actual) {
+      std::cerr << "Expected: " << expected << std::endl;
+      std::cerr << "Actual: " << actual << std::endl;
+      std::cerr << std::flush;
+      throw std::runtime_error("Test failed");
+  }
+}
+
+static void test_json_healing() {
+  auto parse = [](const std::string & str) {
+      std::cerr << "# Parsing: " << str << '\n';
+      std::string::const_iterator it = str.begin();
+      const auto end = str.end();
+      common_json out;
+      std::string healing_marker = "$llama.cpp.json$";
+      if (common_json_parse(it, end, healing_marker, out)) {
+          auto dump = out.json.dump();
+          std::cerr << "Parsed: " << dump << '\n';
+          std::cerr << "Magic: " << out.healing_marker.json_dump_marker << '\n';
+          std::string result;
+          if (!out.healing_marker.json_dump_marker.empty()) {
+              auto i = dump.find(out.healing_marker.json_dump_marker);
+              if (i == std::string::npos) {
+                  throw std::runtime_error("Failed to find magic in dump " + dump + " (magic: " + out.healing_marker.json_dump_marker + ")");
+              }
+              result = dump.substr(0, i);
+          } else {
+            result = dump;
+          }
+          std::cerr << "Result: " << result << '\n';
+          if (string_starts_with(str, result)) {
+            std::cerr << "Failure!\n";
+          }
+        //   return dump;
+      } else {
+        throw std::runtime_error("Failed to parse: " + str);
+      }
+
+  };
+  auto parse_all = [&](const std::string & str) {
+      for (size_t i = 1; i < str.size(); i++) {
+          parse(str.substr(0, i));
+      }
+  };
+  parse_all("{\"a\": \"b\"}");
+  parse_all("{\"hey\": 1, \"ho\\\"ha\": [1]}");
+
+  parse_all("[{\"a\": \"b\"}]");
+
+  auto test = [&](const std::vector<std::string> & inputs, const std::string & expected, const std::string & expected_marker) {
+      for (const auto & input : inputs) {
+        common_json out;
+        assert_equals(true, common_json_parse(input, "$foo", out));
+        assert_equals<std::string>(expected, out.json.dump());
+        assert_equals<std::string>(expected_marker, out.healing_marker.json_dump_marker);
+      }
+  };
+  // No healing needed:
+  test(
+    {
+      R"([{"a":"b"}, "y"])",
+    },
+    R"([{"a":"b"},"y"])",
+    ""
+  );
+  // Partial literals can't be healed:
+  test(
+    {
+      R"([1)",
+      R"([tru)",
+      R"([n)",
+      R"([nul)",
+      R"([23.2)",
+    },
+    R"(["$foo"])",
+    R"("$foo)"
+  );
+  test(
+    {
+      R"({"a": 1)",
+      R"({"a": tru)",
+      R"({"a": n)",
+      R"({"a": nul)",
+      R"({"a": 23.2)",
+    },
+    R"({"a":"$foo"})",
+    R"("$foo)"
+  );
+  test(
+    {
+      R"({)",
+    },
+    R"({"$foo":1})",
+    R"("$foo)"
+  );
+  test(
+    {
+      R"([)",
+    },
+    R"(["$foo"])",
+    R"("$foo)"
+  );
+  // Healing right after a full literal
+  test(
+    {
+      R"(1 )",
+    },
+    R"(1)",
+    ""
+  );
+  test(
+    {
+      R"(true)",
+      R"(true )",
+    },
+    R"(true)",
+    ""
+  );
+  test(
+    {
+      R"(null)",
+      R"(null )",
+    },
+    R"(null)",
+    ""
+  );
+  test(
+    {
+      R"([1 )",
+    },
+    R"([1,"$foo"])",
+    R"(,"$foo)"
+  );
+  test(
+    {
+      R"([{})",
+      R"([{} )",
+    },
+    R"([{},"$foo"])",
+    R"(,"$foo)"
+  );
+  test(
+    {
+      R"([true)",
+    },
+    // TODO: detect the true/false/null literal was complete
+    R"(["$foo"])",
+    R"("$foo)"
+  );
+  test(
+    {
+      R"([true )",
+    },
+    R"([true,"$foo"])",
+    R"(,"$foo)"
+  );
+  test(
+    {
+      R"([true,)",
+    },
+    R"([true,"$foo"])",
+    R"("$foo)"
+  );
+  // Test nesting
+  test(
+    {
+      R"([{"a": [{"b": [{)",
+    },
+    R"([{"a":[{"b":[{"$foo":1}]}]}])",
+    R"("$foo)"
+  );
+  test(
+    {
+      R"([{"a": [{"b": [)",
+    },
+    R"([{"a":[{"b":["$foo"]}]}])",
+    R"("$foo)"
+  );
+
+  test(
+    {
+      R"([{"a": "b"})",
+      R"([{"a": "b"} )",
+    },
+    R"([{"a":"b"},"$foo"])",
+    R"(,"$foo)"
+  );
+  test(
+    {
+      R"([{"a": "b"},)",
+      R"([{"a": "b"}, )",
+    },
+    R"([{"a":"b"},"$foo"])",
+    R"("$foo)"
+  );
+  test(
+    {
+      R"({ "code)",
+    },
+    R"({"code$foo":1})",
+    R"($foo)"
+  );
+  test(
+    {
+      R"({ "code\)",
+    },
+    R"({"code\\$foo":1})",
+    R"(\$foo)"
+  );
+  test(
+    {
+      R"({ "code")",
+    },
+    R"({"code":"$foo"})",
+    R"(:"$foo)"
+  );
+  test(
+    {
+      R"({ "key")",
+    },
+    R"({"key":"$foo"})",
+    R"(:"$foo)"
+  );
+}
+
+int main() {
+    test_json_healing();
+    std::cerr << "All tests passed.\n";
+    return 0;
+}
index 01afeafa0ff57bd3615ebd984d38fa92ec09b31d..9f0b0ffaa6e1ec9bb37030c89e2dbfd28a3f9c97 100644 (file)
@@ -1,3 +1,4 @@
+#include "chat.h"
 #include "utils.hpp"
 
 #include "arg.h"
@@ -114,11 +115,11 @@ struct slot_params {
     struct common_params_speculative speculative;
 
     // OAI-compat fields
-    bool                  verbose                   = false;
-    oaicompat_type        oaicompat                 = OAICOMPAT_TYPE_NONE;
-    std::string           oaicompat_model;
-    std::string           oaicompat_cmpl_id;
-    common_chat_format    oaicompat_chat_format     = COMMON_CHAT_FORMAT_CONTENT_ONLY;
+    bool                         verbose                   = false;
+    oaicompat_type               oaicompat                 = OAICOMPAT_TYPE_NONE;
+    std::string                  oaicompat_model;
+    std::string                  oaicompat_cmpl_id;
+    common_chat_syntax           oaicompat_chat_syntax;
 
     json to_json() const {
         std::vector<std::string> samplers;
@@ -176,7 +177,10 @@ struct slot_params {
             {"grammar_lazy",              sampling.grammar_lazy},
             {"grammar_triggers",          grammar_triggers},
             {"preserved_tokens",          sampling.preserved_tokens},
-            {"chat_format",               common_chat_format_name(oaicompat_chat_format)},
+            {"chat_format",               common_chat_format_name(oaicompat_chat_syntax.format)},
+            {"reasoning_format",          (oaicompat_chat_syntax.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "deepseek" : "none")},
+            {"reasoning_in_content",      oaicompat_chat_syntax.reasoning_in_content},
+            {"thinking_forced_open",      oaicompat_chat_syntax.thinking_forced_open},
             {"samplers",                  samplers},
             {"speculative.n_max",         speculative.n_max},
             {"speculative.n_min",         speculative.n_min},
@@ -352,11 +356,14 @@ struct server_task {
         {
             auto it = data.find("chat_format");
             if (it != data.end()) {
-                params.oaicompat_chat_format = static_cast<common_chat_format>(it->get<int>());
-                SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_format).c_str());
+                params.oaicompat_chat_syntax.format = static_cast<common_chat_format>(it->get<int>());
+                SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_syntax.format).c_str());
             } else {
-                params.oaicompat_chat_format = defaults.oaicompat_chat_format;
+                params.oaicompat_chat_syntax.format = defaults.oaicompat_chat_syntax.format;
             }
+            params.oaicompat_chat_syntax.reasoning_format = params_base.reasoning_format;
+            params.oaicompat_chat_syntax.reasoning_in_content = params.stream;
+            params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false);
         }
 
         {
@@ -396,7 +403,14 @@ struct server_task {
                             params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
                         }
                     } else {
-                        params.sampling.grammar_triggers.push_back(std::move(ct.value));
+                        if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN) {
+                            SRV_DBG("Grammar trigger pattern: `%s`\n", ct.value.value.c_str());
+                        } else if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL) {
+                            SRV_DBG("Grammar trigger pattern full: `%s`\n", ct.value.value.c_str());
+                        } else {
+                            throw std::runtime_error("Unknown grammar trigger type");
+                        }
+                        params.sampling.grammar_triggers.emplace_back(std::move(ct.value));
                     }
                 }
             }
@@ -639,11 +653,12 @@ struct server_task_result_cmpl_final : server_task_result {
     slot_params generation_params;
 
     // OAI-compat fields
-    bool                  verbose                  = false;
-    oaicompat_type        oaicompat                = OAICOMPAT_TYPE_NONE;
-    std::string           oaicompat_model;
-    std::string           oaicompat_cmpl_id;
-    common_chat_format    oaicompat_chat_format    = COMMON_CHAT_FORMAT_CONTENT_ONLY;
+    bool               verbose                  = false;
+    oaicompat_type     oaicompat                = OAICOMPAT_TYPE_NONE;
+    std::string        oaicompat_model;
+    std::string        oaicompat_cmpl_id;
+    common_chat_msg    oaicompat_msg;
+    std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
 
     virtual int get_index() override {
         return index;
@@ -738,47 +753,20 @@ struct server_task_result_cmpl_final : server_task_result {
     json to_json_oaicompat_chat() {
         std::string finish_reason = "length";
         common_chat_msg msg;
-        if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
-            SRV_DBG("Parsing chat message: %s\n", content.c_str());
-            msg = common_chat_parse(content, oaicompat_chat_format);
-            finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
+        if (!oaicompat_msg.empty()) {
+            msg = oaicompat_msg;
         } else {
+            msg.role = "assistant";
             msg.content = content;
         }
-
-        json message {
-            {"role", "assistant"},
-        };
-        if (!msg.reasoning_content.empty()) {
-            message["reasoning_content"] = msg.reasoning_content;
-        }
-        if (msg.content.empty() && !msg.tool_calls.empty()) {
-            message["content"] = json();
-        } else {
-            message["content"] = msg.content;
-        }
-        if (!msg.tool_calls.empty()) {
-            auto tool_calls = json::array();
-            for (const auto & tc : msg.tool_calls) {
-                tool_calls.push_back({
-                    {"type", "function"},
-                    {"function", {
-                        {"name", tc.name},
-                        {"arguments", tc.arguments},
-                    }},
-                    // Some templates generate and require an id (sometimes in a very specific format, e.g. Mistral Nemo).
-                    // We only generate a random id for the ones that don't generate one by themselves
-                    // (they also won't get to see it as their template likely doesn't use it, so it's all for the client)
-                    {"id", tc.id.empty() ? gen_tool_call_id() : tc.id},
-                });
-            }
-            message["tool_calls"] = tool_calls;
+        if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
+            finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
         }
 
         json choice {
             {"finish_reason", finish_reason},
             {"index", 0},
-            {"message", message},
+            {"message", msg.to_json_oaicompat<json>()},
         };
 
         if (!stream && probs_output.size() > 0) {
@@ -818,17 +806,35 @@ struct server_task_result_cmpl_final : server_task_result {
         std::time_t t = std::time(0);
         std::string finish_reason = "length";
         if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
-            finish_reason = "stop";
+            finish_reason = oaicompat_msg.tool_calls.empty() ? "stop" : "tool_calls";
+        }
+
+        json deltas = json::array();
+        for (const auto & diff : oaicompat_msg_diffs) {
+            deltas.push_back({
+                {"choices", json::array({
+                    json {
+                        {"finish_reason", nullptr},
+                        {"index", 0},
+                        {"delta", common_chat_msg_diff_to_json_oaicompat<json>(diff)},
+                    },
+                })},
+                {"created", t},
+                {"id", oaicompat_cmpl_id},
+                {"model", oaicompat_model},
+                {"system_fingerprint", build_info},
+                {"object", "chat.completion.chunk"},
+            });
         }
 
-        json choice = json {
-            {"finish_reason", finish_reason},
-            {"index", 0},
-            {"delta", json::object()}
-        };
-
-        json ret = json {
-            {"choices",            json::array({choice})},
+        deltas.push_back({
+            {"choices", json::array({
+                json {
+                    {"finish_reason", finish_reason},
+                    {"index", 0},
+                    {"delta", json::object()},
+                },
+            })},
             {"created",            t},
             {"id",                 oaicompat_cmpl_id},
             {"model",              oaicompat_model},
@@ -839,18 +845,18 @@ struct server_task_result_cmpl_final : server_task_result {
                 {"prompt_tokens",     n_prompt_tokens},
                 {"total_tokens",      n_decoded + n_prompt_tokens},
             }},
-        };
+        });
 
         if (timings.prompt_n >= 0) {
-            ret.push_back({"timings", timings.to_json()});
+            deltas.back().push_back({"timings", timings.to_json()});
         }
 
         // extra fields for debugging purposes
-        if (verbose) {
-            ret["__verbose"] = to_json_non_oaicompat();
+        if (verbose && !deltas.empty()) {
+            deltas.front()["__verbose"] = to_json_non_oaicompat();
         }
 
-        return ret;
+        return deltas;
     }
 };
 
@@ -868,10 +874,11 @@ struct server_task_result_cmpl_partial : server_task_result {
     result_timings timings;
 
     // OAI-compat fields
-    bool           verbose   = false;
-    oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
-    std::string    oaicompat_model;
-    std::string    oaicompat_cmpl_id;
+    bool            verbose   = false;
+    oaicompat_type  oaicompat = OAICOMPAT_TYPE_NONE;
+    std::string     oaicompat_model;
+    std::string     oaicompat_cmpl_id;
+    std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
 
     virtual int get_index() override {
         return index;
@@ -955,84 +962,50 @@ struct server_task_result_cmpl_partial : server_task_result {
         std::time_t t = std::time(0);
         json choices;
 
+        std::vector<json> deltas;
+        auto add_delta = [&](const json & delta) {
+            deltas.push_back({
+                {"choices", json::array({
+                    json {
+                        {"finish_reason", nullptr},
+                        {"index", 0},
+                        {"delta", delta},
+                    },
+                })},
+                {"created", t},
+                {"id", oaicompat_cmpl_id},
+                {"model", oaicompat_model},
+                {"system_fingerprint", build_info},
+                {"object", "chat.completion.chunk"},
+            });
+        };
+        // We have to send an initial update to conform to openai behavior
         if (first) {
-            if (content.empty()) {
-                choices = json::array({json{{"finish_reason", nullptr},
-                                            {"index", 0},
-                                            {"delta", json{{"role", "assistant"}}}}});
-            } else {
-                // We have to send this as two updates to conform to openai behavior
-                // initial_ret is the role message for stream=True
-                json initial_ret = json{{"choices", json::array({json{
-                                        {"finish_reason", nullptr},
-                                        {"index", 0},
-                                        {"delta", json{
-                                            {"role", "assistant"},
-                                            {"content", ""}
-                                        }}}})},
-                            {"created", t},
-                            {"id", oaicompat_cmpl_id},
-                            {"model", oaicompat_model},
-                            {"system_fingerprint", build_info},
-                            {"object", "chat.completion.chunk"}};
-
-                json second_ret = json{
-                            {"choices", json::array({json{{"finish_reason", nullptr},
-                                                            {"index", 0},
-                                                            {"delta", json {
-                                                            {"content", content}}}
-                                                            }})},
-                            {"created", t},
-                            {"id", oaicompat_cmpl_id},
-                            {"model", oaicompat_model},
-                            {"system_fingerprint", build_info},
-                            {"object", "chat.completion.chunk"}};
-
-                if (prob_output.probs.size() > 0) {
-                    second_ret["choices"][0]["logprobs"] = json{
-                        {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
-                    };
-                }
-
-                if (timings.prompt_n >= 0) {
-                    second_ret.push_back({"timings", timings.to_json()});
-                }
-
-                return std::vector<json>({initial_ret, second_ret});
-            }
-        } else {
-            choices = json::array({json{
-                {"finish_reason", nullptr},
-                {"index", 0},
-                {"delta",
-                json {
-                    {"content", content},
-                }},
-            }});
+            add_delta({
+                {"role", "assistant"},
+                {"content", nullptr},
+            });
         }
 
-        GGML_ASSERT(choices.size() >= 1);
-
-        if (prob_output.probs.size() > 0) {
-            choices[0]["logprobs"] = json{
-                {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
-            };
+        for (const auto & diff : oaicompat_msg_diffs) {
+            add_delta(common_chat_msg_diff_to_json_oaicompat<json>(diff));
         }
 
-        json ret = json {
-            {"choices",            choices},
-            {"created",            t},
-            {"id",                 oaicompat_cmpl_id},
-            {"model",              oaicompat_model},
-            {"system_fingerprint", build_info},
-            {"object",             "chat.completion.chunk"}
-        };
+        if (!deltas.empty()) {
+            GGML_ASSERT(deltas[deltas.size() - 1].at("choices").size() >= 1);
 
-        if (timings.prompt_n >= 0) {
-            ret.push_back({"timings", timings.to_json()});
+            if (prob_output.probs.size() > 0) {
+                deltas[deltas.size() - 1].at("choices").at(0)["logprobs"] = json {
+                    {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
+                };
+            }
+
+            if (timings.prompt_n >= 0) {
+                deltas[deltas.size() - 1].push_back({"timings", timings.to_json()});
+            }
         }
 
-        return std::vector<json>({ret});
+        return deltas;
     }
 };
 
@@ -1293,6 +1266,7 @@ struct server_slot {
 
     std::string  generated_text;
     llama_tokens generated_tokens;
+    common_chat_msg chat_msg;
 
     server_tokens cache_tokens;
 
@@ -1313,6 +1287,7 @@ struct server_slot {
     llama_token sampled;
 
     common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
+    std::vector<std::string> generated_tool_call_ids;
 
     // stats
     size_t n_sent_text        = 0; // number of sent text character
@@ -1342,9 +1317,13 @@ struct server_slot {
         n_past             = 0;
         n_sent_text        = 0;
         task_type          = SERVER_TASK_TYPE_COMPLETION;
+        chat_format        = COMMON_CHAT_FORMAT_CONTENT_ONLY;
 
         generated_tokens.clear();
         generated_token_probs.clear();
+        chat_msg = {};
+        json_schema = json();
+        generated_tool_call_ids.clear();
 
         // clear speculative decoding stats
         n_draft_total = 0;
@@ -1424,6 +1403,21 @@ struct server_slot {
         return timings;
     }
 
+    const common_chat_msg & update_chat_msg(std::vector<common_chat_msg_diff> & diffs) {
+        auto previous_msg = chat_msg;
+        SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
+        auto new_msg = common_chat_parse(
+            generated_text,
+            /* is_partial= */ stop != STOP_TYPE_EOS,
+            params.oaicompat_chat_syntax);
+        if (!new_msg.empty()) {
+            new_msg.ensure_tool_call_ids_set(generated_tool_call_ids, gen_tool_call_id);
+            chat_msg = new_msg;
+            diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg);
+        }
+        return chat_msg;
+    }
+
     size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
         size_t stop_pos = std::string::npos;
 
@@ -2475,10 +2469,12 @@ struct server_context {
         res->n_prompt_tokens     = slot.n_prompt_tokens;
         res->post_sampling_probs = slot.params.post_sampling_probs;
 
-        res->verbose           = slot.params.verbose;
-        res->oaicompat         = slot.params.oaicompat;
-        res->oaicompat_model   = slot.params.oaicompat_model;
-        res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
+        res->verbose               = slot.params.verbose;
+        res->oaicompat             = slot.params.oaicompat;
+        res->oaicompat_model       = slot.params.oaicompat_model;
+        res->oaicompat_cmpl_id     = slot.params.oaicompat_cmpl_id;
+
+        slot.update_chat_msg(res->oaicompat_msg_diffs);
 
         // populate res.probs_output
         if (slot.params.sampling.n_probs > 0) {
@@ -2499,7 +2495,7 @@ struct server_context {
         res->id_slot         = slot.id;
 
         res->index           = slot.index;
-        res->content         = std::move(slot.generated_text);
+        res->content         = slot.generated_text;
         res->tokens          = std::move(slot.generated_tokens);
         res->timings         = slot.get_timings();
         res->prompt          = slot.prompt_tokens.detokenize(ctx, true);
@@ -2519,7 +2515,8 @@ struct server_context {
         res->oaicompat             = slot.params.oaicompat;
         res->oaicompat_model       = slot.params.oaicompat_model;
         res->oaicompat_cmpl_id     = slot.params.oaicompat_cmpl_id;
-        res->oaicompat_chat_format = slot.params.oaicompat_chat_format;
+        res->oaicompat_msg         = slot.update_chat_msg(res->oaicompat_msg_diffs);
+
         // populate res.probs_output
         if (slot.params.sampling.n_probs > 0) {
             if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) {
index bab5d005d96c29b28325a414d7102e580353ff97..1b5205f79d610b7206b9319188be1bd6c1424123 100644 (file)
@@ -75,7 +75,7 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
         choice = data["choices"][0]
         if i == 0:
             # Check first role message for stream=True
-            assert choice["delta"]["content"] == ""
+            assert choice["delta"]["content"] is None
             assert choice["delta"]["role"] == "assistant"
         else:
             assert "role" not in choice["delta"]
@@ -92,7 +92,7 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
             assert choice["finish_reason"] == finish_reason
         else:
             assert choice["finish_reason"] is None
-            content += choice["delta"]["content"]
+            content += choice["delta"]["content"] or ''
 
 
 def test_chat_completion_with_openai_library():
@@ -251,8 +251,9 @@ def test_chat_completion_with_timings_per_token():
     for i, data in enumerate(res):
         if i == 0:
             # Check first role message for stream=True
-            assert data["choices"][0]["delta"]["content"] == ""
+            assert data["choices"][0]["delta"]["content"] is None
             assert data["choices"][0]["delta"]["role"] == "assistant"
+            assert "timings" not in data, f'First event should not have timings: {data}'
         else:
             assert "role" not in data["choices"][0]["delta"]
             assert "timings" in data
@@ -311,7 +312,7 @@ def test_logprobs_stream():
         choice = data.choices[0]
         if i == 0:
             # Check first role message for stream=True
-            assert choice.delta.content == ""
+            assert choice.delta.content is None
             assert choice.delta.role == "assistant"
         else:
             assert choice.delta.role is None
index 1f2c151c1a0fa6817920c60af88ac50c040b5a6c..610610749bd3499b3e169966ac05545231f78f6b 100755 (executable)
@@ -8,6 +8,7 @@ path = Path(__file__).resolve().parents[1]
 sys.path.insert(0, str(path))
 
 from utils import *
+from enum import Enum
 
 server: ServerProcess
 
@@ -20,7 +21,11 @@ def create_server():
     server = ServerPreset.tinyllama2()
     server.model_alias = "tinyllama-2-tool-call"
     server.server_port = 8081
+    server.n_slots = 1
 
+class CompletionMode(Enum):
+    NORMAL = "normal"
+    STREAMED = "streamed"
 
 TEST_TOOL = {
     "type":"function",
@@ -73,9 +78,8 @@ WEATHER_TOOL = {
   }
 }
 
-
 def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict, argument_key: str | None, n_predict, **kwargs):
-    res = server.make_request("POST", "/v1/chat/completions", data={
+    body = server.make_any_request("POST", "/v1/chat/completions", data={
         "max_tokens": n_predict,
         "messages": [
             {"role": "system", "content": "You are a coding assistant."},
@@ -86,13 +90,13 @@ def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict
         "parallel_tool_calls": False,
         **kwargs,
     })
-    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
-    choice = res.body["choices"][0]
+    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
+    choice = body["choices"][0]
     tool_calls = choice["message"].get("tool_calls")
     assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
     tool_call = tool_calls[0]
     assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
-    assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
+    assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
     expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
     assert expected_function_name == tool_call["function"]["name"]
     actual_arguments = tool_call["function"]["arguments"]
@@ -102,12 +106,16 @@ def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict
         assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}"
 
 
+@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
 @pytest.mark.parametrize("template_name,tool,argument_key", [
     ("google-gemma-2-2b-it",                          TEST_TOOL,            "success"),
+    ("google-gemma-2-2b-it",                          TEST_TOOL,            "success"),
+    ("meta-llama-Llama-3.3-70B-Instruct",             TEST_TOOL,            "success"),
     ("meta-llama-Llama-3.3-70B-Instruct",             TEST_TOOL,            "success"),
     ("meta-llama-Llama-3.3-70B-Instruct",             PYTHON_TOOL,          "code"),
+    ("meta-llama-Llama-3.3-70B-Instruct",             PYTHON_TOOL,          "code"),
 ])
-def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None):
+def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None, stream: CompletionMode):
     global server
     n_predict = 1024
     # server = ServerPreset.stories15m_moe()
@@ -115,31 +123,43 @@ def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict,
     server.n_predict = n_predict
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
     server.start(timeout_seconds=TIMEOUT_SERVER_START)
-    do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, temperature=0.0, top_k=1, top_p=1.0)
+    do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED, temperature=0.0, top_k=1, top_p=1.0)
 
 
 @pytest.mark.slow
+@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
 @pytest.mark.parametrize("template_name,tool,argument_key", [
     ("meta-llama-Llama-3.1-8B-Instruct",              TEST_TOOL,            "success"),
     ("meta-llama-Llama-3.1-8B-Instruct",              PYTHON_TOOL,          "code"),
+
     ("meetkai-functionary-medium-v3.1",               TEST_TOOL,            "success"),
     ("meetkai-functionary-medium-v3.1",               PYTHON_TOOL,          "code"),
+
     ("meetkai-functionary-medium-v3.2",               TEST_TOOL,            "success"),
-    ("meetkai-functionary-medium-v3.2",               PYTHON_TOOL,          "code"),
+    # Functionary v3.2 format supports raw python content, which w/ a dummy stories model will never end on its own.
+    # ("meetkai-functionary-medium-v3.2",               PYTHON_TOOL,          "code"),
+
     ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", TEST_TOOL,            "success"),
     ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", PYTHON_TOOL,          "code"),
+
     ("meta-llama-Llama-3.2-3B-Instruct",              TEST_TOOL,            "success"),
     ("meta-llama-Llama-3.2-3B-Instruct",              PYTHON_TOOL,          "code"),
+
     ("mistralai-Mistral-Nemo-Instruct-2407",          TEST_TOOL,            "success"),
     ("mistralai-Mistral-Nemo-Instruct-2407",          PYTHON_TOOL,          "code"),
+
     ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use",   TEST_TOOL,            "success"),
     ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use",   PYTHON_TOOL,          "code"),
+
     ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B",      TEST_TOOL,            "success"),
     ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B",      PYTHON_TOOL,          "code"),
+
     ("fireworks-ai-llama-3-firefunction-v2",          TEST_TOOL,            "success"),
+    # ("fireworks-ai-llama-3-firefunction-v2",          PYTHON_TOOL,          "codeFalse), True),
     # ("fireworks-ai-llama-3-firefunction-v2",          PYTHON_TOOL,          "code"),
+
 ])
-def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, argument_key: str | None):
+def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, argument_key: str | None, stream: CompletionMode):
     global server
     n_predict = 512
     # server = ServerPreset.stories15m_moe()
@@ -147,10 +167,11 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
     server.n_predict = n_predict
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
     server.start(timeout_seconds=TIMEOUT_SERVER_START)
-    do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict)
+    do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED)
 
 
 @pytest.mark.slow
+@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
 @pytest.mark.parametrize("tool,argument_key,hf_repo,template_override", [
     (TEST_TOOL,    "success",  "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
     (PYTHON_TOOL,  "code",     "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
@@ -184,9 +205,9 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
     (PYTHON_TOOL,  "code",     "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M",   ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
     (PYTHON_TOOL,  "code",     "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M",   "chatml"),
 
-    (TEST_TOOL,    "success",  "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
-    (PYTHON_TOOL,  "code",     "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
-    (PYTHON_TOOL,  "code",     "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
+    (TEST_TOOL,    "success",  "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
+    (PYTHON_TOOL,  "code",     "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
+    (PYTHON_TOOL,  "code",     "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
 
     (TEST_TOOL,    "success",  "bartowski/functionary-small-v3.2-GGUF:Q4_K_M",       ("meetkai/functionary-medium-v3.2", None)),
     (PYTHON_TOOL,  "code",     "bartowski/functionary-small-v3.2-GGUF:Q4_K_M",       ("meetkai/functionary-medium-v3.2", None)),
@@ -203,10 +224,9 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
     (TEST_TOOL,    "success",  "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
     (PYTHON_TOOL,  "code",     "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
 ])
-def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
+def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
     global server
     n_predict = 512
-    server.n_slots = 1
     server.jinja = True
     server.n_ctx = 8192
     server.n_predict = n_predict
@@ -219,7 +239,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
     elif isinstance(template_override, str):
         server.chat_template = template_override
     server.start(timeout_seconds=TIMEOUT_SERVER_START)
-    res = server.make_request("POST", "/v1/chat/completions", data={
+    body = server.make_any_request("POST", "/v1/chat/completions", data={
         "max_tokens": n_predict,
         "messages": [
             {"role": "system", "content": "You are a coding assistant."},
@@ -228,12 +248,12 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
         "tool_choice": "required",
         "tools": [tool],
         "parallel_tool_calls": False,
+        "stream": stream == CompletionMode.STREAMED,
         "temperature": 0.0,
         "top_k": 1,
         "top_p": 1.0,
     }, timeout=TIMEOUT_HTTP_REQUEST)
-    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
-    choice = res.body["choices"][0]
+    choice = body["choices"][0]
     tool_calls = choice["message"].get("tool_calls")
     assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
     tool_call = tool_calls[0]
@@ -248,7 +268,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
 
 
 def do_test_completion_without_tool_call(server: ServerProcess, n_predict: int, tools: list[dict], tool_choice: str | None, **kwargs):
-    res = server.make_request("POST", "/v1/chat/completions", data={
+    body = server.make_any_request("POST", "/v1/chat/completions", data={
         "max_tokens": n_predict,
         "messages": [
             {"role": "system", "content": "You are a coding assistant."},
@@ -258,26 +278,27 @@ def do_test_completion_without_tool_call(server: ServerProcess, n_predict: int,
         "tool_choice": tool_choice,
         **kwargs,
     }, timeout=TIMEOUT_HTTP_REQUEST)
-    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
-    choice = res.body["choices"][0]
+    choice = body["choices"][0]
     assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
 
 
+@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
 @pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
     ("meta-llama-Llama-3.3-70B-Instruct",         128, [],            None),
     ("meta-llama-Llama-3.3-70B-Instruct",         128, [TEST_TOOL],   None),
     ("meta-llama-Llama-3.3-70B-Instruct",         128, [PYTHON_TOOL], 'none'),
 ])
-def test_completion_without_tool_call_fast(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
+def test_completion_without_tool_call_fast(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None, stream: CompletionMode):
     global server
-    server.jinja = True
     server.n_predict = n_predict
+    server.jinja = True
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
     server.start(timeout_seconds=TIMEOUT_SERVER_START)
-    do_test_completion_without_tool_call(server, n_predict, tools, tool_choice)
+    do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED)
 
 
 @pytest.mark.slow
+@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
 @pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
     ("meetkai-functionary-medium-v3.2",               256, [],            None),
     ("meetkai-functionary-medium-v3.2",               256, [TEST_TOOL],   None),
@@ -289,16 +310,17 @@ def test_completion_without_tool_call_fast(template_name: str, n_predict: int, t
     ("meta-llama-Llama-3.2-3B-Instruct",              256, [TEST_TOOL],   None),
     ("meta-llama-Llama-3.2-3B-Instruct",              256, [PYTHON_TOOL], 'none'),
 ])
-def test_completion_without_tool_call_slow(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
+def test_completion_without_tool_call_slow(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None, stream: CompletionMode):
     global server
-    server.jinja = True
     server.n_predict = n_predict
+    server.jinja = True
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
     server.start(timeout_seconds=TIMEOUT_SERVER_START)
-    do_test_completion_without_tool_call(server, n_predict, tools, tool_choice)
+    do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED)
 
 
 @pytest.mark.slow
+@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
 @pytest.mark.parametrize("hf_repo,template_override", [
     ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
     ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
@@ -321,11 +343,11 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
     ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M",      ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
     ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M",      "chatml"),
 
-    ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
-    ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
+    ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
+    ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
 
-    ("bartowski/functionary-small-v3.2-GGUF:Q8_0",       ("meetkai/functionary-medium-v3.2", None)),
-    ("bartowski/functionary-small-v3.2-GGUF:Q8_0",       "chatml"),
+    ("bartowski/functionary-small-v3.2-GGUF:Q8_0",       ("meetkai/functionary-medium-v3.2", None)),
+    ("bartowski/functionary-small-v3.2-GGUF:Q8_0",       "chatml"),
 
     ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M",      ("meta-llama/Llama-3.2-3B-Instruct", None)),
     ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M",      "chatml"),
@@ -339,10 +361,9 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
 
     # ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
 ])
-def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] | None):
+def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
     global server
     n_predict = 512
-    server.n_slots = 1
     server.jinja = True
     server.n_ctx = 8192
     server.n_predict = n_predict
@@ -355,11 +376,11 @@ def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] |
     elif isinstance(template_override, str):
         server.chat_template = template_override
     server.start(timeout_seconds=TIMEOUT_SERVER_START)
-    do_test_weather(server, max_tokens=n_predict)
+    do_test_weather(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict)
 
 
 def do_test_weather(server: ServerProcess, **kwargs):
-    res = server.make_request("POST", "/v1/chat/completions", data={
+    body = server.make_any_request("POST", "/v1/chat/completions", data={
         "messages": [
             {"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things."},
             {"role": "user", "content": "What is the weather in Istanbul?"},
@@ -367,14 +388,13 @@ def do_test_weather(server: ServerProcess, **kwargs):
         "tools": [WEATHER_TOOL],
         **kwargs,
     }, timeout=TIMEOUT_HTTP_REQUEST)
-    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
-    choice = res.body["choices"][0]
+    choice = body["choices"][0]
     tool_calls = choice["message"].get("tool_calls")
     assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
     tool_call = tool_calls[0]
     # assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
     assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"], f'Expected weather tool call, got {tool_call["function"]["name"]}'
-    assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
+    assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
     actual_arguments = json.loads(tool_call["function"]["arguments"])
     assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}"
     location = actual_arguments["location"]
@@ -383,6 +403,7 @@ def do_test_weather(server: ServerProcess, **kwargs):
 
 
 @pytest.mark.slow
+@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
 @pytest.mark.parametrize("result_override,n_predict,hf_repo,template_override", [
     (None,                                           128,  "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M",       "chatml"),
     (None,                                           128,  "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None),
@@ -400,9 +421,8 @@ def do_test_weather(server: ServerProcess, **kwargs):
     # (None,                                           128,  "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M",  None),
     # ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)",            8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
 ])
-def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
+def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
     global server
-    server.n_slots = 1
     server.jinja = True
     server.n_ctx = 8192 * 2
     server.n_predict = n_predict
@@ -415,11 +435,11 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
     elif isinstance(template_override, str):
         server.chat_template = template_override
     server.start(timeout_seconds=TIMEOUT_SERVER_START)
-    do_test_calc_result(server, result_override, n_predict)
+    do_test_calc_result(server, result_override, n_predict, stream=stream == CompletionMode.STREAMED)
 
 
 def do_test_calc_result(server: ServerProcess, result_override: str | None, n_predict: int, **kwargs):
-    res = server.make_request("POST", "/v1/chat/completions", data={
+    body = server.make_any_request("POST", "/v1/chat/completions", data={
         "max_tokens": n_predict,
         "messages": [
             {"role": "system", "content": "You are a tools-calling assistant. You express numerical values with at most two decimals."},
@@ -466,8 +486,7 @@ def do_test_calc_result(server: ServerProcess, result_override: str | None, n_pr
         ],
         **kwargs,
     }, timeout=TIMEOUT_HTTP_REQUEST)
-    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
-    choice = res.body["choices"][0]
+    choice = body["choices"][0]
     tool_calls = choice["message"].get("tool_calls")
     assert tool_calls is None, f'Expected no tool call in {choice["message"]}'
     content = choice["message"].get("content")
@@ -480,18 +499,18 @@ def do_test_calc_result(server: ServerProcess, result_override: str | None, n_pr
 
 
 @pytest.mark.slow
-@pytest.mark.parametrize("n_predict,reasoning_format,expect_content,expect_reasoning_content,hf_repo,template_override", [
-    (128, 'deepseek',  "^The sum of 102 and 7 is 109[\\s\\S]*",                        None,                                          "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M",       None),
-    (128,  None,        "^The sum of 102 and 7 is 109[\\s\\S]*",                       None,                                          "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M",       None),
-
-    (1024, 'deepseek',  "To find the sum of[\\s\\S]*",                                 "I need to calculate the sum of 102 and 7[\\s\\S]*",  "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
-    (1024, 'none',      "^(<think>\\s*)?I need[\\s\\S]*?</think>\\s*To find[\\s\\S]*",                None,                                          "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
-
-    (1024, 'deepseek',  "To find the sum of[\\s\\S]*",                                 "First, I [\\s\\S]*",                          "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
+@pytest.mark.parametrize("n_predict,reasoning_format,stream,expect_reasoning_content,expect_content,hf_repo,template_override", [
+    (128, 'deepseek',   CompletionMode.NORMAL,   None, "^The sum of 102 and 7 is 109[\\s\\S]*",                                       "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M",       None),
+    (128,  None,        CompletionMode.NORMAL,   None, "^The sum of 102 and 7 is 109[\\s\\S]*",                                       "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M",       None),
+    (1024, 'deepseek',  CompletionMode.NORMAL,   "I need to calculate the sum of 102 and 7[\\s\\S]*", "To find the sum of[\\s\\S]*",  "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
+    (1024, 'deepseek',  CompletionMode.STREAMED, None, "^<think>I need to calculate [\\s\\S]*?</think>To find the sum of [\\s\\S]*",  "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
+    (1024, 'deepseek',  CompletionMode.NORMAL,   "First, I [\\s\\S]*", "To find the sum of[\\s\\S]*",                                 "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
+    (1024, 'deepseek',  CompletionMode.STREAMED, None, "^<think>First, I [\\s\\S]*?</think>To find the sum of[\\s\\S]*",              "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
+    # (1024, 'none',      CompletionMode.NORMAL,   None, "^(<think>\\s*)?I need[\\s\\S]*?</think>\\s*To find[\\s\\S]*",                 "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
+    # (128,  'deepseek',  None, "^Okay, let me figure out the sum of 102 and 7[\\s\\S]*",                      "bartowski/Qwen_QwQ-32B-GGUF:Q4_K_M",                None),
 ])
-def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] | None, expect_content: str | None, expect_reasoning_content: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
+def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] | None, expect_content: str | None, expect_reasoning_content: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
     global server
-    server.n_slots = 1
     server.reasoning_format = reasoning_format
     server.jinja = True
     server.n_ctx = 8192 * 2
@@ -505,14 +524,14 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none']
     elif isinstance(template_override, str):
         server.chat_template = template_override
     server.start(timeout_seconds=TIMEOUT_SERVER_START)
-    res = server.make_request("POST", "/v1/chat/completions", data={
+    body = server.make_any_request("POST", "/v1/chat/completions", data={
         "max_tokens": n_predict,
         "messages": [
             {"role": "user", "content": "What's the sum of 102 and 7?"},
-        ]
+        ],
+        "stream": stream == CompletionMode.STREAMED,
     }, timeout=TIMEOUT_HTTP_REQUEST)
-    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
-    choice = res.body["choices"][0]
+    choice = body["choices"][0]
     assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
 
     content = choice["message"].get("content")
@@ -529,6 +548,7 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none']
 
 
 @pytest.mark.slow
+@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
 @pytest.mark.parametrize("hf_repo,template_override", [
     ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
 
@@ -562,10 +582,9 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none']
     ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M",              None),
     ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M",              "chatml"),
 ])
-def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | None] | None):
+def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
     global server
     n_predict = 512 # High because of DeepSeek R1
-    server.n_slots = 1
     server.jinja = True
     server.n_ctx = 8192
     server.n_predict = n_predict
@@ -579,11 +598,11 @@ def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | Non
         server.chat_template = template_override
     server.start(timeout_seconds=TIMEOUT_SERVER_START)
 
-    do_test_hello_world(server, max_tokens=n_predict)
+    do_test_hello_world(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict)
 
 
 def do_test_hello_world(server: ServerProcess, **kwargs):
-    res = server.make_request("POST", "/v1/chat/completions", data={
+    body = server.make_any_request("POST", "/v1/chat/completions", data={
         "messages": [
             {"role": "system", "content": "You are a tool-calling agent."},
             {"role": "user", "content": "say hello world with python"},
@@ -591,16 +610,15 @@ def do_test_hello_world(server: ServerProcess, **kwargs):
         "tools": [PYTHON_TOOL],
         **kwargs,
     }, timeout=TIMEOUT_HTTP_REQUEST)
-    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
-    choice = res.body["choices"][0]
+    choice = body["choices"][0]
     tool_calls = choice["message"].get("tool_calls")
     assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
     tool_call = tool_calls[0]
     # assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
     assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"]
-    assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
+    assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
     actual_arguments = json.loads(tool_call["function"]["arguments"])
     assert 'code' in actual_arguments, f"code not found in {json.dumps(actual_arguments)}"
     code = actual_arguments["code"]
     assert isinstance(code, str), f"Expected code to be a string, got {type(code)}: {json.dumps(code)}"
-    assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', code), f'Expected hello world, got {code}'
+    assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', re.sub(r'#.*\n?', '', code)), f'Expected hello world, got {code}'
index 27a0f0356aae1cbd47c4658195d25c3a92b09b88..b480801b17abbaebf585edbcd85eb601c9d13f41 100644 (file)
@@ -294,6 +294,77 @@ class ServerProcess:
                 print("Partial response from server", json.dumps(data, indent=2))
                 yield data
 
+    def make_any_request(
+        self,
+        method: str,
+        path: str,
+        data: dict | None = None,
+        headers: dict | None = None,
+        timeout: float | None = None,
+    ) -> dict:
+        stream = data.get('stream', False)
+        if stream:
+            content: list[str] = []
+            tool_calls: list[dict] = []
+            finish_reason: Optional[str] = None
+
+            content_parts = 0
+            tool_call_parts = 0
+            arguments_parts = 0
+
+            for chunk in self.make_stream_request(method, path, data, headers):
+                assert len(chunk['choices']) == 1, f'Expected 1 choice, got {len(chunk["choices"])}'
+                choice = chunk['choices'][0]
+                if choice['delta'].get('content') is not None:
+                    assert len(choice['delta']['content']) > 0, f'Expected non empty content delta!'
+                    content.append(choice['delta']['content'])
+                    content_parts += 1
+                if choice['delta'].get('finish_reason') is not None:
+                    finish_reason = choice['delta']['finish_reason']
+                for tc in choice['delta'].get('tool_calls', []):
+                    if 'function' not in tc:
+                        raise ValueError(f"Expected function type, got {tc['type']}")
+                    if tc['index'] >= len(tool_calls):
+                        tool_calls.append(dict(
+                            id="",
+                            type="function",
+                            function=dict(
+                                name="",
+                                arguments="",
+                            )
+                        ))
+                    tool_call = tool_calls[tc['index']]
+                    if tc.get('id') is not None:
+                        tool_call['id'] = tc['id']
+                    fct = tc['function']
+                    if fct.get('name') is not None:
+                        tool_call['function']['name'] = fct['name']
+                    if fct.get('arguments') is not None:
+                        assert len(fct['arguments']) > 0, f'Expected non empty arguments delta!'
+                        tool_call['function']['arguments'] += fct['arguments']
+
+            print(f'Streamed response had {content_parts} content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts')
+            result = dict(
+                choices=[
+                    dict(
+                        index=0,
+                        finish_reason=finish_reason,
+                        message=dict(
+                            role='assistant',
+                            content=''.join(content) if content else None,
+                            tool_calls=tool_calls if tool_calls else None,
+                        ),
+                    )
+                ],
+            )
+            print("Final response from server", json.dumps(result, indent=2))
+            return result
+        else:
+            response = self.make_request(method, path, data, headers, timeout=timeout)
+            assert response.status_code == 200, f"Server returned error: {response.status_code}"
+            return response.body
+
+
 
 server_instances: Set[ServerProcess] = set()
 
index bb27b366ea2d673b4833ea958f33db13824c3bc1..91efcfef067726a4dd95c831a686e97b86214924 100644 (file)
@@ -474,26 +474,6 @@ static std::string gen_tool_call_id() {
 // other common utils
 //
 
-static bool ends_with(const std::string & str, const std::string & suffix) {
-    return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
-}
-
-static size_t find_partial_stop_string(const std::string &stop, const std::string &text) {
-    if (!text.empty() && !stop.empty()) {
-        const char text_last_char = text.back();
-        for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
-            if (stop[char_index] == text_last_char) {
-                const std::string current_partial = stop.substr(0, char_index + 1);
-                if (ends_with(text, current_partial)) {
-                    return text.size() - char_index - 1;
-                }
-            }
-        }
-    }
-
-    return std::string::npos;
-}
-
 // TODO: reuse llama_detokenize
 template <class Iter>
 static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
@@ -599,19 +579,16 @@ static json oaicompat_chat_params_parse(
     json llama_params;
 
     auto tools = json_value(body, "tools", json());
+    auto has_tools = tools.is_array() && !tools.empty();
     auto stream = json_value(body, "stream", false);
+    auto tool_choice = json_value(body, "tool_choice", std::string("auto"));
 
-    if (tools.is_array() && !tools.empty()) {
-        if (stream) {
-            throw std::runtime_error("Cannot use tools with stream");
-        }
-        if (!opt.use_jinja) {
+    if (!opt.use_jinja) {
+        if (has_tools) {
             throw std::runtime_error("tools param requires --jinja flag");
         }
-    }
-    if (!opt.use_jinja) {
-        if (body.contains("tool_choice") && !body.at("tool_choice").is_null()) {
-            throw std::runtime_error("Unsupported param: tool_choice");
+        if (tool_choice != "auto") {
+            throw std::runtime_error("tool_choice param requires --jinja flag");
         }
     }
 
@@ -749,14 +726,12 @@ static json oaicompat_chat_params_parse(
     common_chat_templates_inputs inputs;
     inputs.messages              = common_chat_msgs_parse_oaicompat(messages);
     inputs.tools                 = common_chat_tools_parse_oaicompat(tools);
-    inputs.tool_choice           = common_chat_tool_choice_parse_oaicompat(json_value(body, "tool_choice", std::string("auto")));
+    inputs.tool_choice           = common_chat_tool_choice_parse_oaicompat(tool_choice);
     inputs.json_schema           = json_schema.is_null() ? "" : json_schema.dump();
     inputs.grammar               = grammar;
-    inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
     inputs.use_jinja             = opt.use_jinja;
     inputs.parallel_tool_calls   = json_value(body, "parallel_tool_calls", false);
-    inputs.extract_reasoning     = opt.reasoning_format != COMMON_REASONING_FORMAT_NONE;
-    inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
+    inputs.reasoning_format      = opt.reasoning_format;
     if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) {
         throw std::runtime_error("Cannot use custom grammar constraints with tools.");
     }
@@ -774,7 +749,8 @@ static json oaicompat_chat_params_parse(
             throw std::runtime_error("Cannot have 2 or more assistant messages at the end of the list.");
         }
 
-        inputs.extract_reasoning = false;
+        /* TODO: test this properly */
+        inputs.reasoning_format = COMMON_REASONING_FORMAT_NONE;
         inputs.add_generation_prompt = true;
     }
 
@@ -799,6 +775,7 @@ static json oaicompat_chat_params_parse(
     }
     llama_params["grammar_triggers"] = grammar_triggers;
     llama_params["preserved_tokens"] = chat_params.preserved_tokens;
+    llama_params["thinking_forced_open"]     = chat_params.thinking_forced_open;
     for (const auto & stop : chat_params.additional_stops) {
         llama_params["stop"].push_back(stop);
     }
@@ -812,6 +789,9 @@ static json oaicompat_chat_params_parse(
     // Handle "logprobs" field
     // TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future
     if (json_value(body, "logprobs", false)) {
+        if (has_tools && stream) {
+            throw std::runtime_error("logprobs is not supported with tools + stream");
+        }
         llama_params["n_probs"] = json_value(body, "top_logprobs", 20);
     } else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) {
         throw std::runtime_error("top_logprobs requires logprobs to be set to true");