]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
tool-call: refactor common chat / tool-call api (+ tests / fixes) (#11900)
authorOlivier Chafik <redacted>
Tue, 18 Feb 2025 18:03:23 +0000 (18:03 +0000)
committerGitHub <redacted>
Tue, 18 Feb 2025 18:03:23 +0000 (18:03 +0000)
* tool-call refactoring: moved common_chat_* to chat.h, common_chat_templates_init return a unique_ptr to opaque type

* addressed clang-tidy lints in [test-]chat.*

* rm minja deps from util & common & move it to common/minja/

* add name & tool_call_id to common_chat_msg

* add common_chat_tool

* added json <-> tools, msgs conversions to chat.h

* fix double bos/eos jinja avoidance hack (was preventing inner bos/eos tokens)

* fix deepseek r1 slow test (no longer <think> opening w/ new template)

* allow empty tools w/ auto + grammar

* fix & test server grammar & json_schema params w/ & w/o --jinja

20 files changed:
Makefile
common/CMakeLists.txt
common/arg.cpp
common/chat-template.hpp [deleted file]
common/chat.cpp
common/chat.h [new file with mode: 0644]
common/chat.hpp [deleted file]
common/common.cpp
common/common.h
common/minja.hpp [deleted file]
common/minja/chat-template.hpp [new file with mode: 0644]
common/minja/minja.hpp [new file with mode: 0644]
examples/main/main.cpp
examples/run/run.cpp
examples/server/server.cpp
examples/server/tests/unit/test_chat_completion.py
examples/server/tests/unit/test_tool_call.py
examples/server/utils.hpp
tests/test-chat-template.cpp
tests/test-chat.cpp

index 662194086eaaf881e4f2a12286941f4251eb00b3..fb9a3b44890a08d6f34f33a80705073d41561f22 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -1364,7 +1364,7 @@ llama-server: \
        examples/server/index.html.hpp \
        examples/server/loading.html.hpp \
        common/chat.cpp \
-       common/chat.hpp \
+       common/chat.h \
        common/chat-template.hpp \
        common/json.hpp \
        common/minja.hpp \
index c2b4aa7d09f1c78d78863d7dc406520a33707c18..17146fffc11685d6cfde7be474a66004673e4e36 100644 (file)
@@ -57,8 +57,7 @@ add_library(${TARGET} STATIC
     arg.h
     base64.hpp
     chat.cpp
-    chat.hpp
-    chat-template.hpp
+    chat.h
     common.cpp
     common.h
     console.cpp
@@ -68,7 +67,8 @@ add_library(${TARGET} STATIC
     llguidance.cpp
     log.cpp
     log.h
-    minja.hpp
+    minja/chat-template.hpp
+    minja/minja.hpp
     ngram-cache.cpp
     ngram-cache.h
     sampling.cpp
index f06aa1076cca70e01f507b49c116424801fa8ce2..eb8beccac2ee7e93bb6558d934a0b9cbde15bb1c 100644 (file)
@@ -2,6 +2,7 @@
 
 #include "log.h"
 #include "sampling.h"
+#include "chat.h"
 
 #include <algorithm>
 #include <climits>
diff --git a/common/chat-template.hpp b/common/chat-template.hpp
deleted file mode 100644 (file)
index 882ba41..0000000
+++ /dev/null
@@ -1,529 +0,0 @@
-/*
-    Copyright 2024 Google LLC
-
-    Use of this source code is governed by an MIT-style
-    license that can be found in the LICENSE file or at
-    https://opensource.org/licenses/MIT.
-*/
-// SPDX-License-Identifier: MIT
-#pragma once
-
-#include "minja.hpp"
-#include <json.hpp>
-#include <string>
-#include <vector>
-
-using json = nlohmann::ordered_json;
-
-namespace minja {
-
-struct chat_template_caps {
-    bool supports_tools = false;
-    bool supports_tool_calls = false;
-    bool supports_tool_responses = false;
-    bool supports_system_role = false;
-    bool supports_parallel_tool_calls = false;
-    bool supports_tool_call_id = false;
-    // meta-llama/Llama-3.1-8B-Instruct expects arguments to be an object.
-    // Most other templates (and OpenAI's API) expect the arguments object to be stringified.
-    bool requires_object_arguments = false;
-    // CohereForAI/c4ai-command-r-plus simple variant
-    bool requires_non_null_content = false;
-    // MiniMaxAI/MiniMax-Text-01 special
-    bool requires_typed_content = false;
-};
-
-struct chat_template_inputs {
-    nlohmann::ordered_json messages;
-    nlohmann::ordered_json tools;
-    bool add_generation_prompt = true;
-    nlohmann::ordered_json extra_context;
-    std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
-};
-
-struct chat_template_options {
-    bool apply_polyfills = true;
-    bool use_bos_token = true;
-    bool use_eos_token = true;
-    bool define_strftime_now = true;
-
-    bool polyfill_tools = true;
-    bool polyfill_tool_call_examples = true;
-    bool polyfill_tool_calls = true;
-    bool polyfill_tool_responses = true;
-    bool polyfill_system_role = true;
-    bool polyfill_object_arguments = true;
-    bool polyfill_typed_content = true;
-};
-
-class chat_template {
-
-  private:
-    chat_template_caps caps_;
-    std::string source_;
-    std::string bos_token_;
-    std::string eos_token_;
-    std::shared_ptr<minja::TemplateNode> template_root_;
-    std::string tool_call_example_;
-
-    std::string try_raw_render(
-        const nlohmann::ordered_json & messages,
-        const nlohmann::ordered_json & tools,
-        bool add_generation_prompt,
-        const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
-    {
-        try {
-            chat_template_inputs inputs;
-            inputs.messages = messages;
-            inputs.tools = tools;
-            inputs.add_generation_prompt = add_generation_prompt;
-            inputs.extra_context = extra_context;
-            // Use fixed date for tests
-            inputs.now = std::chrono::system_clock::from_time_t(0);
-
-            chat_template_options opts;
-            opts.apply_polyfills = false;
-
-            auto prompt = apply(inputs, opts);
-            // fprintf(stderr, "try_raw_render: %s\n", prompt.c_str());
-            return prompt;
-        } catch (const std::exception & e) {
-            // fprintf(stderr, "try_raw_render error: %s\n", e.what());
-            return "";
-        }
-    }
-
-  public:
-
-    chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token)
-        : source_(source), bos_token_(bos_token), eos_token_(eos_token)
-    {
-        template_root_ = minja::Parser::parse(source_, {
-            /* .trim_blocks = */ true,
-            /* .lstrip_blocks = */ true,
-            /* .keep_trailing_newline = */ false,
-        });
-
-        auto contains = [](const std::string & haystack, const std::string & needle) {
-            return haystack.find(needle) != std::string::npos;
-        };
-
-        const std::string user_needle = "<User Needle>";
-        const std::string sys_needle = "<System Needle>";
-        const json dummy_str_user_msg = {{"role", "user"}, {"content", user_needle}};
-        const json dummy_typed_user_msg = {{"role", "user"}, {"content", json::array({{{"type", "text"}, {"text", user_needle}}})}};
-
-        caps_.requires_typed_content =
-            !contains(try_raw_render(json::array({dummy_str_user_msg}), {}, false), user_needle)
-            && contains(try_raw_render(json::array({dummy_typed_user_msg}), {}, false), user_needle);
-
-        const auto dummy_user_msg = caps_.requires_typed_content
-            ? dummy_typed_user_msg
-            : dummy_str_user_msg;
-        const json needle_system_msg = {
-            {"role", "system"},
-            {"content", caps_.requires_typed_content ? json::array({{{"type", "text"}, {"text", sys_needle}}}) : json(sys_needle)},
-        };
-
-        caps_.supports_system_role = contains(try_raw_render({needle_system_msg, dummy_user_msg,}, {}, false), sys_needle);
-
-        auto out = try_raw_render(json::array({
-            dummy_user_msg
-        }), json::array({
-            {
-                {"name", "some_tool"},
-                {"type", "function"},
-                {"function", {
-                    {"name", "some_tool"},
-                    {"description", "Some tool."},
-                    {"parameters", {
-                        {"type", "object"},
-                        {"properties", {
-                            {"arg", {
-                                {"type", "string"},
-                                {"description", "Some argument."},
-                            }},
-                        }},
-                        {"required", json::array({ "arg" })},
-                    }},
-                }},
-            },
-        }), false);
-        caps_.supports_tools = contains(out, "some_tool");
-
-        auto make_tool_calls_msg = [&](const json & tool_calls) {
-            return json {
-                {"role", "assistant"},
-                {"content", nullptr},
-                {"tool_calls", tool_calls},
-            };
-        };
-        auto make_tool_call = [](const std::string & tool_name, const json & arguments) {
-            return json {
-                {"id", "call_1___"},
-                {"type", "function"},
-                {"function", {
-                    {"arguments", arguments},
-                    {"name", tool_name},
-                }},
-            };
-        };
-        const json dummy_args_obj {{"argument_needle", "print('Hello, World!')"}};
-
-        // Note: the arguments are rendered in both cases, but may be double-escaped, which we don't want.
-        out = try_raw_render(json::array({
-            dummy_user_msg,
-            make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj.dump())})),
-        }), {}, false);
-        auto tool_call_renders_str_arguments = contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
-        out = try_raw_render(json::array({
-            dummy_user_msg,
-            make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj)})),
-        }), {}, false);
-        auto tool_call_renders_obj_arguments = contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
-
-        caps_.supports_tool_calls = tool_call_renders_str_arguments || tool_call_renders_obj_arguments;
-        caps_.requires_object_arguments = !tool_call_renders_str_arguments && tool_call_renders_obj_arguments;
-        auto out_empty = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", ""}}}), {}, false);
-        auto out_null = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", nullptr}}}), {}, false);
-        caps_.requires_non_null_content = contains(out_empty, user_needle) && !contains(out_null, user_needle);
-
-        if (caps_.supports_tool_calls) {
-            auto dummy_args = caps_.requires_object_arguments ? dummy_args_obj : json(dummy_args_obj.dump());
-            auto tc1 = make_tool_call("test_tool1", dummy_args);
-            auto tc2 = make_tool_call("test_tool2", dummy_args);
-            auto out = try_raw_render(json::array({
-                dummy_user_msg,
-                make_tool_calls_msg(json::array({tc1, tc2})),
-            }), {}, false);
-            caps_.supports_parallel_tool_calls = contains(out, "test_tool1") && contains(out, "test_tool2");
-
-            out = try_raw_render(json::array({
-                dummy_user_msg,
-                make_tool_calls_msg(json::array({tc1})),
-                {
-                    {"role", "tool"},
-                    {"name", "test_tool1"},
-                    {"content", "Some response!"},
-                    {"tool_call_id", "call_911_"},
-                }
-            }), {}, false);
-            caps_.supports_tool_responses = contains(out, "Some response!");
-            caps_.supports_tool_call_id = contains(out, "call_911_");
-        }
-
-        try {
-            if (!caps_.supports_tools) {
-                const json user_msg {
-                    {"role", "user"},
-                    {"content", "Hey"},
-                };
-                const json args {
-                    {"arg1", "some_value"},
-                };
-                const json tool_call_msg {
-                    {"role", "assistant"},
-                    {"content", nullptr},
-                    {"tool_calls", json::array({
-                        {
-                            // TODO: detect if requires numerical id or fixed length == 6 like Nemo
-                            {"id", "call_1___"},
-                            {"type", "function"},
-                            {"function", {
-                                {"name", "tool_name"},
-                                {"arguments", (caps_.requires_object_arguments ? args : json(minja::Value(args).dump(-1, /* to_json= */ true)))},
-                            }},
-                        },
-                    })},
-                };
-                std::string prefix, full;
-                {
-                    chat_template_inputs inputs;
-                    inputs.messages = json::array({user_msg});
-                    inputs.add_generation_prompt = true;
-                    prefix = apply(inputs);
-                }
-                {
-                    chat_template_inputs inputs;
-                    inputs.messages = json::array({user_msg, tool_call_msg});
-                    inputs.add_generation_prompt = false;
-                    full = apply(inputs);
-                }
-                auto eos_pos_last = full.rfind(eos_token_);
-                if (eos_pos_last == prefix.size() - eos_token_.size() ||
-                      (full[full.size() - 1] == '\n' && (eos_pos_last == full.size() - eos_token_.size() - 1))) {
-                    full = full.substr(0, eos_pos_last);
-                }
-                size_t common_prefix_length = 0;
-                for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) {
-                    if (prefix[i] != full[i]) {
-                        break;
-                    }
-                    if (prefix[i] == '<') {
-                        // DeepSeek R1's template (as of 20250209) adds a trailing <think> if add_generation_prompt,
-                        // but it removes thinking tags for past messages.
-                        // The prefix and full strings diverge at <think> vs. <|tool▁calls▁begin|>, we avoid consuming the leading <.
-                        continue;
-                    }
-                    common_prefix_length = i + 1;
-                }
-                auto example = full.substr(common_prefix_length);
-                if (example.find("tool_name") == std::string::npos && example.find("some_value") == std::string::npos) {
-                    fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n");
-                } else {
-                    tool_call_example_ = example;
-                }
-            }
-        } catch (const std::exception & e) {
-            fprintf(stderr, "Failed to generate tool call example: %s\n", e.what());
-        }
-    }
-
-    const std::string & source() const { return source_; }
-    const std::string & bos_token() const { return bos_token_; }
-    const std::string & eos_token() const { return eos_token_; }
-    const chat_template_caps & original_caps() const { return caps_; }
-
-    // Deprecated, please use the form with chat_template_inputs and chat_template_options
-    std::string apply(
-        const nlohmann::ordered_json & messages,
-        const nlohmann::ordered_json & tools,
-        bool add_generation_prompt,
-        const nlohmann::ordered_json & extra_context = nlohmann::ordered_json(),
-        bool apply_polyfills = true)
-    {
-        fprintf(stderr, "[%s] Deprecated!\n", __func__);
-        chat_template_inputs inputs;
-        inputs.messages = messages;
-        inputs.tools = tools;
-        inputs.add_generation_prompt = add_generation_prompt;
-        inputs.extra_context = extra_context;
-        inputs.now = std::chrono::system_clock::now();
-
-        chat_template_options opts;
-        opts.apply_polyfills = apply_polyfills;
-
-        return apply(inputs, opts);
-    }
-
-    std::string apply(
-        const chat_template_inputs & inputs,
-        const chat_template_options & opts = chat_template_options()) const
-    {
-        json actual_messages;
-
-        auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
-        auto has_tool_calls = false;
-        auto has_tool_responses = false;
-        auto has_string_content = false;
-        for (const auto & message : inputs.messages) {
-            if (message.contains("tool_calls") && !message["tool_calls"].is_null()) {
-                has_tool_calls = true;
-            }
-            if (message.contains("role") && message["role"] == "tool") {
-                has_tool_responses = true;
-            }
-            if (message.contains("content") && message["content"].is_string()) {
-                has_string_content = true;
-            }
-        }
-
-        auto polyfill_system_role = opts.polyfill_system_role && !caps_.supports_system_role;
-        auto polyfill_tools = opts.polyfill_tools && has_tools && !caps_.supports_tools;
-        auto polyfill_tool_call_example = polyfill_tools && opts.polyfill_tool_call_examples;
-        auto polyfill_tool_calls = opts.polyfill_tool_calls && has_tool_calls && !caps_.supports_tool_calls;
-        auto polyfill_tool_responses = opts.polyfill_tool_responses && has_tool_responses && !caps_.supports_tool_responses;
-        auto polyfill_object_arguments = opts.polyfill_object_arguments && has_tool_calls && caps_.requires_object_arguments;
-        auto polyfill_typed_content = opts.polyfill_typed_content && has_string_content && caps_.requires_typed_content;
-
-        auto needs_polyfills = opts.apply_polyfills && (false
-            || polyfill_system_role
-            || polyfill_tools
-            || polyfill_tool_calls
-            || polyfill_tool_responses
-            || polyfill_object_arguments
-            || polyfill_typed_content
-        );
-
-        if (needs_polyfills) {
-            actual_messages = json::array();
-
-            auto add_message = [&](const json & msg) {
-                if (polyfill_typed_content && msg.contains("content") && !msg.at("content").is_null() && msg.at("content").is_string()) {
-                    actual_messages.push_back({
-                        {"role", msg.at("role")},
-                        {"content", {{
-                            {"type", "text"},
-                            {"text", msg.at("content")},
-                        }}},
-                    });
-                } else {
-                    actual_messages.push_back(msg);
-                }
-            };
-
-            std::string pending_system;
-            auto flush_sys = [&]() {
-                if (!pending_system.empty()) {
-                    add_message({
-                        {"role", "user"},
-                        {"content", pending_system},
-                    });
-                    pending_system.clear();
-                }
-            };
-
-            json adjusted_messages;
-            if (polyfill_tools) {
-                adjusted_messages = add_system(inputs.messages,
-                    "You can call any of the following tools to satisfy the user's requests: " + minja::Value(inputs.tools).dump(2, /* to_json= */ true) +
-                    (!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_ + "\n\n"));
-            } else {
-                adjusted_messages = inputs.messages;
-            }
-
-            for (const auto & message_ : adjusted_messages) {
-                auto message = message_;
-                if (!message.contains("role") || !message.contains("content")) {
-                    throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
-                }
-                std::string role = message.at("role");
-
-                if (message.contains("tool_calls")) {
-                    if (polyfill_object_arguments || polyfill_tool_calls) {
-                        for (auto & tool_call : message.at("tool_calls")) {
-                            if (tool_call["type"] == "function") {
-                                auto & function = tool_call.at("function");
-                                auto & arguments = function.at("arguments");
-                                if (arguments.is_string()) {
-                                    try {
-                                        arguments = json::parse(arguments.get<std::string>());
-                                    } catch (const std::exception & ecvt) {
-                                        fprintf(stderr, "Failed to parse arguments: %s\n", ecvt.what());
-                                    }
-                                }
-                            }
-                        }
-                    }
-                    if (polyfill_tool_calls) {
-                        auto content = message.at("content");
-                        auto tool_calls = json::array();
-                        for (const auto & tool_call : message.at("tool_calls")) {
-                            if (tool_call.at("type") != "function") {
-                                continue;
-                            }
-                            const auto & function = tool_call.at("function");
-                            auto tc = json {
-                                {"name", function.at("name")},
-                                {"arguments", function.at("arguments")},
-                            };
-                            if (tool_call.contains("id")) {
-                                tc["id"] = tool_call["id"];
-                            }
-                            tool_calls.push_back(tc);
-                        }
-                        auto obj = json {
-                            {"tool_calls", tool_calls},
-                        };
-                        if (!content.is_null() && content != "") {
-                            obj["content"] = content;
-                        }
-                        message["content"] = obj.dump(2);
-                        message.erase("tool_calls");
-                    }
-                }
-                if (polyfill_tool_responses && role == "tool") {
-                    message["role"] = "user";
-                    auto obj = json {
-                        {"tool_response", {
-                            {"content", message.at("content")},
-                        }},
-                    };
-                    if (message.contains("name")) {
-                        obj["tool_response"]["name"] = message.at("name");
-                    }
-                    if (message.contains("tool_call_id")) {
-                        obj["tool_response"]["tool_call_id"] = message.at("tool_call_id");
-                    }
-                    message["content"] = obj.dump(2);
-                    message.erase("name");
-                }
-
-                if (!message["content"].is_null() && polyfill_system_role) {
-                    std::string content = message.at("content");
-                    if (role == "system") {
-                        if (!pending_system.empty()) pending_system += "\n";
-                        pending_system += content;
-                        continue;
-                    } else {
-                        if (role == "user") {
-                            if (!pending_system.empty()) {
-                                message["content"] = pending_system + (content.empty() ? "" : "\n" + content);
-                                pending_system.clear();
-                            }
-                        } else {
-                            flush_sys();
-                        }
-                    }
-                }
-                add_message(message);
-            }
-            flush_sys();
-        } else {
-            actual_messages = inputs.messages;
-        }
-
-        auto context = minja::Context::make(json({
-            {"messages", actual_messages},
-            {"add_generation_prompt", inputs.add_generation_prompt},
-        }));
-        context->set("bos_token", opts.use_bos_token ? bos_token_ : "");
-        context->set("eos_token", opts.use_eos_token ? eos_token_ : "");
-        if (opts.define_strftime_now) {
-            auto now = inputs.now;
-            context->set("strftime_now", Value::callable([now](const std::shared_ptr<minja::Context> &, minja::ArgumentsValue & args) {
-                args.expectArgs("strftime_now", {1, 1}, {0, 0});
-                auto format = args.args[0].get<std::string>();
-
-                auto time = std::chrono::system_clock::to_time_t(now);
-                auto local_time = *std::localtime(&time);
-                std::ostringstream ss;
-                ss << std::put_time(&local_time, format.c_str());
-                return ss.str();
-            }));
-        }
-        if (!inputs.tools.is_null()) {
-            context->set("tools", minja::Value(inputs.tools));
-        }
-        if (!inputs.extra_context.is_null()) {
-            for (auto & kv : inputs.extra_context.items()) {
-                context->set(kv.key(), minja::Value(kv.value()));
-            }
-        }
-
-        auto ret = template_root_->render(context);
-        // fprintf(stderr, "actual_messages: %s\n", actual_messages.dump(2).c_str());
-        // fprintf(stderr, "apply: %s\n\n", ret.c_str());
-        return ret;
-    }
-
-    static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages, const std::string & system_prompt) {
-        json messages_with_system = messages;
-
-        if (messages_with_system.size() > 0 && messages_with_system[0].at("role") == "system") {
-            std::string existing_system = messages_with_system.at(0).at("content");
-            messages_with_system[0] = json {
-                {"role", "system"},
-                {"content", existing_system + "\n\n" + system_prompt},
-            };
-        } else {
-            messages_with_system.insert(messages_with_system.begin(), json {
-                {"role", "system"},
-                {"content", system_prompt},
-            });
-        }
-        return messages_with_system;
-    }
-};
-
-}  // namespace minja
index f21a9d2a63a4b06b1690d04c6479fe3001a1bb0c..9ebe4c5784cbcf98014cde575fc9295b6fa08cbc 100644 (file)
@@ -1,8 +1,433 @@
-#include "chat.hpp"
-#include "chat-template.hpp"
+#include "chat.h"
 #include "json-schema-to-grammar.h"
 #include "log.h"
-#include "minja.hpp"
+#include "minja/chat-template.hpp"
+#include "minja/minja.hpp"
+
+#include <optional>
+
+typedef minja::chat_template common_chat_template;
+
+struct common_chat_templates {
+    bool has_explicit_template; // Model had builtin template or template overridde was specified.
+    std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
+    std::unique_ptr<common_chat_template> template_tool_use;
+};
+
+struct templates_params {
+    json messages;
+    json tools;
+    common_chat_tool_choice tool_choice;
+    json json_schema;
+    bool parallel_tool_calls;
+    bool stream;
+    std::string grammar;
+    bool add_generation_prompt = true;
+    bool extract_reasoning     = true;
+};
+
+common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
+    if (tool_choice == "auto") {
+        return COMMON_CHAT_TOOL_CHOICE_AUTO;
+    }
+    if (tool_choice == "none") {
+        return COMMON_CHAT_TOOL_CHOICE_NONE;
+    }
+    if (tool_choice == "required") {
+        return COMMON_CHAT_TOOL_CHOICE_REQUIRED;
+    }
+    throw std::runtime_error("Invalid tool_choice: " + tool_choice);
+}
+
+template <>
+std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messages) {
+    std::vector<common_chat_msg> msgs;
+
+    try {
+
+        if (!messages.is_array()) {
+            throw std::runtime_error("Expected 'messages' to be an array, got " + messages.dump());
+        }
+
+        for (const auto & message : messages) {
+            if (!message.is_object()) {
+                throw std::runtime_error("Expected 'message' to be an object, got " + message.dump());
+            }
+
+            common_chat_msg msg;
+            if (!message.contains("role")) {
+                throw std::runtime_error("Missing 'role' in message: " + message.dump());
+            }
+            msg.role = message.at("role");
+
+            if (message.contains("content")) {
+                const auto & content = message.at("content");
+                if (content.is_string()) {
+                    msg.content = content;
+                } else if (content.is_array()) {
+                    for (const auto & part : content) {
+                        if (!part.contains("type")) {
+                            throw std::runtime_error("Missing content part type: " + part.dump());
+                        }
+                        const auto & type = part.at("type");
+                        if (type != "text") {
+                            throw std::runtime_error("Unsupported content part type: " + type.dump());
+                        }
+                        common_chat_msg_content_part msg_part;
+                        msg_part.type = type;
+                        msg_part.text = part.at("text");
+                        msg.content_parts.push_back(msg_part);
+                    }
+                } else if (!content.is_null()) {
+                    throw std::runtime_error("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
+                }
+            } else {
+                throw std::runtime_error("Expected 'content' (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
+            }
+            if (message.contains("reasoning_content")) {
+                msg.reasoning_content = message.at("reasoning_content");
+            }
+            if (message.contains("name")) {
+                msg.tool_name = message.at("name");
+            }
+            if (message.contains("tool_call_id")) {
+                msg.tool_call_id = message.at("tool_call_id");
+            }
+            if (message.contains("tool_calls")) {
+                for (const auto & tool_call : message.at("tool_calls")) {
+                    common_chat_tool_call tc;
+                    if (!tool_call.contains("type")) {
+                        throw std::runtime_error("Missing tool call type: " + tool_call.dump());
+                    }
+                    const auto & type = tool_call.at("type");
+                    if (type != "function") {
+                        throw std::runtime_error("Unsupported tool call type: " + tool_call.dump());
+                    }
+                    if (!tool_call.contains("function")) {
+                        throw std::runtime_error("Missing tool call function: " + tool_call.dump());
+                    }
+                    const auto & fc = tool_call.at("function");
+                    if (!fc.contains("name")) {
+                        throw std::runtime_error("Missing tool call name: " + tool_call.dump());
+                    }
+                    tc.name = fc.at("name");
+                    tc.arguments = fc.at("arguments");
+                    if (tool_call.contains("id")) {
+                        tc.id = tool_call.at("id");
+                    }
+                    msg.tool_calls.push_back(tc);
+                }
+            }
+
+            msgs.push_back(msg);
+        }
+    } catch (const std::exception & e) {
+        throw std::runtime_error("Failed to parse messages: " + std::string(e.what()) + "; messages = " + messages.dump(2));
+    }
+
+    return msgs;
+}
+
+template <>
+json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text) {
+    json messages = json::array();
+    for (const auto & msg : msgs) {
+        if (!msg.content.empty() && !msg.content_parts.empty()) {
+            throw std::runtime_error("Cannot specify both content and content_parts");
+        }
+        json jmsg {
+            {"role", msg.role},
+        };
+        if (!msg.content.empty()) {
+            jmsg["content"] = msg.content;
+        } else if (!msg.content_parts.empty()) {
+            if (concat_typed_text) {
+                std::string text;
+                for (const auto & part : msg.content_parts) {
+                    if (part.type != "text") {
+                        LOG_WRN("Ignoring content part type: %s\n", part.type.c_str());
+                        continue;
+                    }
+                    if (!text.empty()) {
+                        text += '\n';
+                    }
+                    text += part.text;
+                }
+                jmsg["content"] = text;
+            } else {
+                auto & parts = jmsg["content"] = json::array();
+                for (const auto & part : msg.content_parts) {
+                    parts.push_back({
+                        {"type", part.type},
+                        {"text", part.text},
+                    });
+                }
+            }
+        } else {
+            jmsg["content"] = json(); // null
+        }
+        if (!msg.reasoning_content.empty()) {
+            jmsg["reasoning_content"] = msg.reasoning_content;
+        }
+        if (!msg.tool_name.empty()) {
+            jmsg["name"] = msg.tool_name;
+        }
+        if (!msg.tool_call_id.empty()) {
+            jmsg["tool_call_id"] = msg.tool_call_id;
+        }
+        if (!msg.tool_calls.empty()) {
+            auto & tool_calls = jmsg["tool_calls"] = json::array();
+            for (const auto & tool_call : msg.tool_calls) {
+                json tc {
+                    {"type", "function"},
+                    {"function", {
+                        {"name", tool_call.name},
+                        {"arguments", tool_call.arguments},
+                    }},
+                };
+                if (!tool_call.id.empty()) {
+                    tc["id"] = tool_call.id;
+                }
+                tool_calls.push_back(tc);
+            }
+        }
+        messages.push_back(jmsg);
+    }
+    return messages;
+}
+
+template <>
+std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const std::string & messages) {
+    return common_chat_msgs_parse_oaicompat(json::parse(messages));
+}
+
+template <>
+std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & tools) {
+    std::vector<common_chat_tool> result;
+
+    try {
+        if (!tools.is_null()) {
+            if (!tools.is_array()) {
+                throw std::runtime_error("Expected 'tools' to be an array, got " + tools.dump());
+            }
+            for (const auto & tool : tools) {
+                if (!tool.contains("type")) {
+                    throw std::runtime_error("Missing tool type: " + tool.dump());
+                }
+                const auto & type = tool.at("type");
+                if (!type.is_string() || type != "function") {
+                    throw std::runtime_error("Unsupported tool type: " + tool.dump());
+                }
+                if (!tool.contains("function")) {
+                    throw std::runtime_error("Missing tool function: " + tool.dump());
+                }
+
+                const auto & function = tool.at("function");
+                result.push_back({
+                    /* .name = */ function.at("name"),
+                    /* .description = */ function.at("description"),
+                    /* .parameters = */ function.at("parameters").dump(),
+                });
+            }
+        }
+    } catch (const std::exception & e) {
+        throw std::runtime_error("Failed to parse tools: " + std::string(e.what()) + "; tools = " + tools.dump(2));
+    }
+
+    return result;
+}
+
+template <>
+std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const std::string & tools) {
+    return common_chat_tools_parse_oaicompat(json::parse(tools));
+}
+
+template <>
+json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools) {
+    if (tools.empty()) {
+        return json();
+    }
+
+    auto result = json::array();
+    for (const auto & tool : tools) {
+        result.push_back({
+            {"type", "function"},
+            {"function", {
+                {"name", tool.name},
+                {"description", tool.description},
+                {"parameters", json::parse(tool.parameters)},
+            }},
+        });
+    }
+    return result;
+}
+
+bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
+    if (use_jinja) {
+        try {
+            common_chat_msg msg;
+            msg.role = "user";
+            msg.content = "test";
+
+            auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl);
+
+            common_chat_templates_inputs inputs;
+            inputs.messages = {msg};
+
+            common_chat_templates_apply(tmpls.get(), inputs);
+            return true;
+        } catch (const std::exception & e) {
+            LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
+            return false;
+        }
+    }
+    llama_chat_message chat[] = {{"user", "test"}};
+    const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
+    return res >= 0;
+}
+
+std::string common_chat_format_single(
+        const struct common_chat_templates * tmpls,
+        const std::vector<common_chat_msg> & past_msg,
+        const common_chat_msg & new_msg,
+        bool add_ass,
+        bool use_jinja) {
+
+    common_chat_templates_inputs inputs;
+    inputs.use_jinja = use_jinja;
+
+    std::string fmt_past_msg;
+    if (!past_msg.empty()) {
+        inputs.messages = past_msg;
+        inputs.add_generation_prompt = false;
+        fmt_past_msg = common_chat_templates_apply(tmpls, inputs).prompt;
+    }
+    std::ostringstream ss;
+    // if the past_msg ends with a newline, we must preserve it in the formatted version
+    if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
+        ss << "\n";
+    };
+    // format chat with new_msg
+    inputs.messages.push_back(new_msg);
+    inputs.add_generation_prompt = add_ass;
+    auto fmt_new_msg = common_chat_templates_apply(tmpls, inputs).prompt;
+    // get the diff part
+    ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
+    return ss.str();
+}
+
+std::string common_chat_format_example(const struct common_chat_templates * tmpls, bool use_jinja) {
+    common_chat_templates_inputs inputs;
+    inputs.use_jinja = use_jinja;
+    auto add_simple_msg = [&](auto role, auto content) {
+        common_chat_msg msg;
+        msg.role = role;
+        msg.content = content;
+        inputs.messages.push_back(msg);
+    };
+    add_simple_msg("system",    "You are a helpful assistant");
+    add_simple_msg("user",      "Hello");
+    add_simple_msg("assistant", "Hi there");
+    add_simple_msg("user",      "How are you?");
+    return common_chat_templates_apply(tmpls, inputs).prompt;
+}
+
+#define CHATML_TEMPLATE_SRC \
+    "{%- for message in messages -%}\n" \
+    "  {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \
+    "{%- endfor -%}\n" \
+    "{%- if add_generation_prompt -%}\n" \
+    "  {{- '<|im_start|>assistant\n' -}}\n" \
+    "{%- endif -%}"
+
+void common_chat_templates_free(struct common_chat_templates * tmpls) {
+    delete tmpls;
+}
+
+bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls) {
+    return tmpls->has_explicit_template;
+}
+
+const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant) {
+    if (variant != nullptr) {
+        if (strcmp(variant, "tool_use") == 0) {
+            if (tmpls->template_tool_use) {
+                return tmpls->template_tool_use->source().c_str();
+            }
+            return nullptr;
+        } else {
+            LOG_DBG("%s: unknown template variant: %s\n", __func__, variant);
+        }
+    }
+    return tmpls->template_default->source().c_str();
+}
+
+common_chat_templates_ptr common_chat_templates_init(
+    const struct llama_model * model,
+    const std::string & chat_template_override,
+    const std::string & bos_token_override,
+    const std::string & eos_token_override)
+{
+    std::string default_template_src;
+    std::string template_tool_use_src;
+
+    bool has_explicit_template = !chat_template_override.empty();
+    if (chat_template_override.empty()) {
+        GGML_ASSERT(model != nullptr);
+        const auto * str = llama_model_chat_template(model, /* name */ nullptr);
+        if (str) {
+            default_template_src = str;
+            has_explicit_template = true;
+        }
+        str = llama_model_chat_template(model, /* name */ "tool_use");
+        if (str) {
+            template_tool_use_src = str;
+            has_explicit_template = true;
+        }
+    } else {
+        default_template_src = chat_template_override;
+    }
+    if (default_template_src.empty() || default_template_src == "chatml") {
+        if (!template_tool_use_src.empty()) {
+            default_template_src = template_tool_use_src;
+        } else {
+            default_template_src = CHATML_TEMPLATE_SRC;
+        }
+    }
+    std::string token_bos = bos_token_override;
+    std::string token_eos = eos_token_override;
+    if (model) {
+        const auto * vocab = llama_model_get_vocab(model);
+        const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {
+            if (token == LLAMA_TOKEN_NULL) {
+                if (default_template_src.find(jinja_variable_name) != std::string::npos
+                    || template_tool_use_src.find(jinja_variable_name) != std::string::npos) {
+                    LOG_WRN("common_chat_templates_init: warning: vocab does not have a %s token, jinja template won't work as intended.\n", name);
+                }
+                return std::string();
+            }
+            return common_token_to_piece(vocab, token, true);
+        };
+        token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
+        token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
+    }
+    common_chat_templates_ptr tmpls(new common_chat_templates());
+    tmpls->has_explicit_template = has_explicit_template;
+    try {
+        tmpls->template_default = std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos);
+    } catch (const std::exception & e) {
+        LOG_ERR("%s: failed to parse chat template (defaulting to chatml): %s \n", __func__, e.what());
+        tmpls->template_default = std::make_unique<minja::chat_template>(CHATML_TEMPLATE_SRC, token_bos, token_eos);
+    }
+    if (!template_tool_use_src.empty()) {
+        try {
+            tmpls->template_tool_use = std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos);
+        } catch (const std::exception & e) {
+            LOG_ERR("%s: failed to parse tool use chat template (ignoring it): %s\n", __func__, e.what());
+        }
+    }
+    return tmpls;
+}
 
 std::string common_chat_format_name(common_chat_format format) {
     switch (format) {
@@ -38,22 +463,22 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons
 
         json_error_locator() : position(0), found_error(false) {}
 
-        bool parse_error(std::size_t position, const std::string &, const json::exception &) override {
+        bool parse_error(std::size_t position, const std::string &, const json::exception &) override { // NOLINT
             this->position = position - 1;
             this->found_error = true;
             return false;
         }
-        bool null() override { return true; }
-        bool boolean(bool) override { return true; }
-        bool number_integer(number_integer_t) override { return true; }
-        bool number_unsigned(number_unsigned_t) override { return true; }
-        bool number_float(number_float_t, const string_t &) override { return true; }
-        bool string(string_t &) override { return true; }
-        bool binary(binary_t &) override { return true; }
-        bool start_object(std::size_t) override { return true; }
-        bool key(string_t &) override { return true; }
+        bool null() override { return true; } // NOLINT
+        bool boolean(bool) override { return true; } // NOLINT
+        bool number_integer(number_integer_t) override { return true; } // NOLINT
+        bool number_unsigned(number_unsigned_t) override { return true; } // NOLINT
+        bool number_float(number_float_t, const string_t &) override { return true; } // NOLINT
+        bool string(string_t &) override { return true; } // NOLINT
+        bool binary(binary_t &) override { return true; } // NOLINT
+        bool start_object(std::size_t) override { return true; } // NOLINT
+        bool key(string_t &) override { return true; } // NOLINT
         bool end_object() override { return true; }
-        bool start_array(std::size_t) override { return true; }
+        bool start_array(std::size_t) override { return true; } // NOLINT
         bool end_array() override { return true; }
     };
     json_error_locator err_loc;
@@ -187,13 +612,20 @@ static std::string apply(
     // tmpl_inputs.now = std::chrono::system_clock::now();
 
     minja::chat_template_options tmpl_opts;
-    tmpl_opts.use_bos_token = false;
-    tmpl_opts.use_eos_token = false;
-
-    return tmpl.apply(tmpl_inputs, tmpl_opts);
+    // To avoid double BOS / EOS tokens, we're manually removing begining / trailing tokens
+    // instead of using `chat_template_options.use_bos_token = false`, since these tokens
+    // may be needed inside the template / between messages too.
+    auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
+    if (string_starts_with(result, tmpl.bos_token())) {
+        result = result.substr(tmpl.bos_token().size());
+    }
+    if (string_ends_with(result, tmpl.eos_token())) {
+        result = result.substr(0, result.size() - tmpl.eos_token().size());
+    }
+    return result;
 }
 
-static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
 
     auto tool_call_schemas = json::array();
@@ -247,7 +679,7 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
                 {"required", json::array({"tool_call"})},
             };
     const auto schema =
-        inputs.tool_choice != "required"
+        inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED
             ? json {
                 {"anyOf", json::array({
                     tool_call,
@@ -303,9 +735,9 @@ static common_chat_msg common_chat_parse_generic(const std::string & input) {
     return result;
 }
 
-static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
-    data.grammar_lazy = inputs.tool_choice != "required";
+    data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
         auto schemas = json::array();
         foreach_function(inputs.tools, [&](const json & tool) {
@@ -348,9 +780,9 @@ static common_chat_msg common_chat_parse_mistral_nemo(const std::string & input)
     return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
 }
 
-static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
-    data.grammar_lazy = inputs.tool_choice != "required";
+    data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
         auto schemas = json::array();
         foreach_function(inputs.tools, [&](const json & tool) {
@@ -455,10 +887,10 @@ static void expect_tool_parameters(const std::string & name, const json & parame
     const auto & parameters_required = parameters.at("required");
     for (const auto & prop : expected_properties) {
         if (!parameters_properties.contains(prop)) {
-            throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop);
+            throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop); // NOLINT
         }
         if (std::find(parameters_required.begin(), parameters_required.end(), json(prop)) == parameters_required.end()) {
-            throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop);
+            throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop); // NOLINT
         }
     }
     if (parameters_properties.size() != expected_properties.size()) {
@@ -466,18 +898,16 @@ static void expect_tool_parameters(const std::string & name, const json & parame
     }
 }
 
-static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct common_chat_inputs & inputs, bool allow_python_tag_builtin_tools) {
+static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
     auto builtin_tools = json::array();
     common_chat_params data;
-    data.grammar_lazy = inputs.tool_choice != "required";
+    data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
         std::vector<std::string> tool_rules;
 
         auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
-            if (name == "wolfram_alpha") {
+            if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") {
                 // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
-                expect_tool_parameters(name, parameters, {"query"});
-            } else if (name == "web_search" || name == "brave_search") {
                 // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
                 expect_tool_parameters(name, parameters, {"query"});
             } else if (name == "python" || name == "code_interpreter") {
@@ -489,7 +919,7 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
 
             std::vector<std::string> kvs;
             for (const auto & [key, value] : parameters.at("properties").items()) {
-                kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value));
+                kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT
             }
 
             tool_rules.push_back(
@@ -560,34 +990,33 @@ static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bo
             auto arg_value_str = raw_args.substr(it_eq + 1);
             auto arg_value = json::parse(arg_value_str);
 
-            return {
-                /* .role = */ "assistant",
-                /* .content = */ match.prefix().str(),
-                /* .tool_calls = */ {
-                    {
-                        /* .name = */ match[1],
-                        /* .arguments = */ (json {
-                            {arg_name, arg_value},
-                        }).dump(),
-                        /* .id = */ "",
-                    },
-                },
-            };
+            common_chat_msg msg;
+            msg.role = "assistant";
+            msg.content = match.prefix().str();
+            msg.tool_calls.push_back({
+                /* .name = */ name,
+                /* .arguments = */ (json {
+                    {arg_name, arg_value},
+                }).dump(),
+                /* .id = */ "",
+            });
+            return msg;
         }
     }
     return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
 }
 
-static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
     if (inputs.tools.is_array() && !inputs.tools.empty()) {
-        data.grammar_lazy = inputs.tool_choice != "required" && inputs.json_schema.is_null();
+        data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED && inputs.json_schema.is_null();
         data.grammar = build_grammar([&](const common_grammar_builder & builder) {
             std::vector<std::string> tool_rules;
             foreach_function(inputs.tools, [&](const json & tool) {
                 const auto & function = tool.at("function");
                 std::string name = function.at("name");
                 auto parameters = function.at("parameters");
+                builder.resolve_refs(parameters);
                 auto args_rule = builder.add_schema(name + "-args", parameters);
                 tool_rules.push_back(builder.add_rule(name + "-call",
                     "\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n"
@@ -666,15 +1095,15 @@ static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input,
     return msg;
 }
 
-static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
-    fprintf(stderr, "%s\n", __func__);
+static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
+    LOG_DBG("%s\n", __func__);
     common_chat_params data;
     data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
         {"datetime", "Jan 29 2025 13:00:00 GMT"},
         {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
     });
     if (inputs.tools.is_array() && !inputs.tools.empty()) {
-        data.grammar_lazy = inputs.tool_choice != "required";
+        data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
         data.grammar = build_grammar([&](const common_grammar_builder & builder) {
             auto schemas = json::array();
             foreach_function(inputs.tools, [&](const json & tool) {
@@ -712,14 +1141,14 @@ static common_chat_msg common_chat_parse_firefunction_v2(const std::string & inp
     return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
 }
 
-static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct templates_params & inputs) {
     // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
     // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
     common_chat_params data;
     data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
     data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
     if (inputs.tools.is_array() && !inputs.tools.empty()) {
-        data.grammar_lazy = inputs.tool_choice != "required";
+        data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
         data.grammar = build_grammar([&](const common_grammar_builder & builder) {
             std::vector<std::string> first_tool_rules;
             std::vector<std::string> subsequent_tool_rules;
@@ -727,6 +1156,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
                 const auto & function = tool.at("function");
                 std::string name = function.at("name");
                 auto parameters = function.at("parameters");
+                builder.resolve_refs(parameters);
                 auto args_rule = builder.add_schema(name + "-args", parameters);
                 first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
                 subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
@@ -795,14 +1225,14 @@ static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & in
     }
 }
 
-static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
     // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
     common_chat_params data;
     json tools = inputs.tools.is_null() ? inputs.tools : json::array();
     std::string python_code_argument_name;
     auto has_raw_python = false;
 
-    data.grammar_lazy = inputs.tool_choice != "required";
+    data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
         std::vector<std::string> tool_rules;
         foreach_function(inputs.tools, [&](const json & tool) {
@@ -814,7 +1244,7 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
                     throw std::runtime_error("Missing type in python tool");
                 }
                 has_raw_python = true;
-                auto type = parameters.at("type");
+                const auto & type = parameters.at("type");
                 if (type == "object") {
                     auto properties = parameters.at("properties");
                     for (auto it = properties.begin(); it != properties.end(); ++it) {
@@ -854,17 +1284,15 @@ static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::s
     std::smatch match;
     if (std::regex_search(input, match, python_tag_regex)) {
         auto code = match[1].str();
-        return {
-            /* .role = */ "assistant",
-            /* .content = */ match.prefix().str(),
-            /* .tool_calls = */ {
-                {
-                    /* .name = */ "python",
-                    /* .arguments = */ (json {{"code", code}}).dump(),
-                    /* .id = */ "",
-                },
-            }
-        };
+        common_chat_msg msg;
+        msg.role = "assistant";
+        msg.content = match.prefix().str();
+        msg.tool_calls.push_back({
+            /* .name = */ "python",
+            /* .arguments = */ (json {{"code", code}}).dump(),
+            /* .id = */ "",
+        });
+        return msg;
     }
     static std::regex function_regex(R"(<function=(\w+)>)");
     static std::regex close_regex(R"(</function>)");
@@ -872,10 +1300,10 @@ static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::s
     return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
 }
 
-static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
     // (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
-    data.grammar_lazy = inputs.tool_choice != "required";
+    data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
         std::vector<std::string> tool_rules;
         foreach_function(inputs.tools, [&](const json & tool) {
@@ -908,20 +1336,18 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input)
         std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
         std::regex end_pattern(R"([\n\s]*</tool_call>[\n\s]*$)");
 
+        common_chat_msg msg;
+        msg.role = "assistant";
+
         auto end = input.end();
         std::sregex_iterator rend;
         std::sregex_iterator rit(input.begin(), end, start_pattern);
         if (rit == rend) {
-            return {
-                /* .role = */ "assistant",
-                /* .content = */ input,
-                /* .tool_calls = */ {},
-            };
+            msg.content = input;
+            return msg;
         }
 
-        common_chat_msg result;
-        result.role = "assistant";
-        result.content = rit->prefix();
+        msg.content = rit->prefix();
 
         auto it = rit->suffix().first;
         while (it != end) {
@@ -930,7 +1356,7 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input)
                 throw std::runtime_error("Failed to parse json tool call");
             }
             const auto & arguments = call.at("arguments");
-            result.tool_calls.push_back({
+            msg.tool_calls.push_back({
                 call.at("name"),
                 arguments.dump(),
                 // arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
@@ -947,17 +1373,17 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input)
                 break;
             }
         }
-        return result;
+        return msg;
     } catch (const std::exception & e) {
-        return {
-            /* .role = */ "assistant",
-            /* .content = */ input,
-            /* .tool_calls = */ {},
-        };
+        LOG_ERR("Failed to parse hermes 2 pro input: %s\n", e.what());
+        common_chat_msg msg;
+        msg.role = "assistant";
+        msg.content = input;
+        return msg;
     }
 }
 
-static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
     data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
     data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
@@ -973,12 +1399,35 @@ static common_chat_params common_chat_params_init_without_tools(const common_cha
     return data;
 }
 
-common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_templates_apply_jinja(
+    const struct common_chat_templates * tmpls,
+    const struct common_chat_templates_inputs & inputs)
+{
+    templates_params params;
+    params.tools = common_chat_tools_to_json_oaicompat<json>(inputs.tools);
+    const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use
+        ? *tmpls->template_tool_use
+        : *tmpls->template_default;
     const auto & src = tmpl.source();
     const auto & caps = tmpl.original_caps();
+    params.messages = common_chat_msgs_to_json_oaicompat<json>(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content);
+    params.add_generation_prompt = inputs.add_generation_prompt;
+    params.extract_reasoning = inputs.extract_reasoning;
+    params.tool_choice = inputs.tool_choice;
+    params.grammar = inputs.grammar;
+    if (!inputs.json_schema.empty()) {
+        params.json_schema = json::parse(inputs.json_schema);
+    }
 
-    if (inputs.tools.is_array()) {
-        if (inputs.tool_choice != "none" && !inputs.grammar.empty()) {
+    if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) {
+        LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n");
+        params.parallel_tool_calls = false;
+    } else {
+        params.parallel_tool_calls = inputs.parallel_tool_calls;
+    }
+
+    if (params.tools.is_array()) {
+        if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) {
             throw std::runtime_error("Cannot specify grammar with tools");
         }
         if (caps.supports_tool_calls && !caps.supports_tools) {
@@ -987,68 +1436,135 @@ common_chat_params common_chat_params_init(const common_chat_template & tmpl, co
     }
 
     // DeepSeek R1: use handler in all cases except json schema (thinking / tools).
-    if (src.find("<|tool▁calls▁begin|>") != std::string::npos && inputs.json_schema.is_null()) {
-        return common_chat_params_init_deepseek_r1(tmpl, inputs);
+    if (src.find("<|tool▁calls▁begin|>") != std::string::npos && params.json_schema.is_null()) {
+        return common_chat_params_init_deepseek_r1(tmpl, params);
     }
 
     // Command R7B: : use handler in all cases except json schema (thinking / tools).
-    if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && inputs.json_schema.is_null()) {
-        return common_chat_params_init_command_r7b(tmpl, inputs);
+    if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && params.json_schema.is_null()) {
+        return common_chat_params_init_command_r7b(tmpl, params);
     }
 
     // Use generic handler when mixing tools + JSON schema.
     // TODO: support that mix in handlers below.
-    if ((!inputs.tools.is_array() && inputs.json_schema.is_object())) {
-        return common_chat_params_init_generic(tmpl, inputs);
+    if ((params.tools.is_array() && params.json_schema.is_object())) {
+        return common_chat_params_init_generic(tmpl, params);
     }
 
     // Functionary prepends "all\n" to plain content outputs, so we use its handler in all cases.
     if (src.find(">>>all") != std::string::npos) {
-        return common_chat_params_init_functionary_v3_2(tmpl, inputs);
+        return common_chat_params_init_functionary_v3_2(tmpl, params);
     }
 
     // Firefunction v2 requires datetime and functions in the context even w/o tools, so we also use its handler in all cases.
     if (src.find(" functools[") != std::string::npos) {
-        return common_chat_params_init_firefunction_v2(tmpl, inputs);
+        return common_chat_params_init_firefunction_v2(tmpl, params);
     }
 
     // Plain handler (no tools)
-    if (inputs.tools.is_null() || inputs.tool_choice == "none") {
-        return common_chat_params_init_without_tools(tmpl, inputs);
+    if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
+        return common_chat_params_init_without_tools(tmpl, params);
     }
 
     // Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
     if (src.find("<tool_call>") != std::string::npos) {
-        return common_chat_params_init_hermes_2_pro(tmpl, inputs);
+        return common_chat_params_init_hermes_2_pro(tmpl, params);
     }
 
     // Functionary v3.1 (w/ tools)
     if (src.find("<|start_header_id|>") != std::string::npos
         && src.find("<function=") != std::string::npos) {
-        return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, inputs);
+        return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, params);
     }
 
     // Llama 3.1, 3.2, 3.3 (w/ tools)
     if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
         auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
-        return common_chat_params_init_llama_3_1_tool_calls(tmpl, inputs, allow_python_tag_builtin_tools);
+        return common_chat_params_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools);
     }
 
     // Mistral Nemo (w/ tools)
     if (src.find("[TOOL_CALLS]") != std::string::npos) {
-        return common_chat_params_init_mistral_nemo(tmpl, inputs);
+        return common_chat_params_init_mistral_nemo(tmpl, params);
     }
 
     // Generic fallback
-    return common_chat_params_init_generic(tmpl, inputs);
+    return common_chat_params_init_generic(tmpl, params);
+}
+
+// Legacy template route (adhoc C++ implementation of known templates), forward to llama_chat_apply_template.
+static common_chat_params common_chat_templates_apply_legacy(
+    const struct common_chat_templates * tmpls,
+    const struct common_chat_templates_inputs & inputs)
+{
+    int alloc_size = 0;
+    std::vector<llama_chat_message> chat;
+    std::vector<std::string> contents;
+    for (const auto & msg : inputs.messages) {
+        auto content = msg.content;
+        for (const auto & part : msg.content_parts) {
+            if (part.type != "text") {
+                LOG_WRN("Ignoring non-text content part: %s\n", part.type.c_str());
+                continue;
+            }
+            if (!content.empty()) {
+                content += "\n";;
+            }
+            content += part.text;
+        }
+        contents.emplace_back(std::move(content));
+    }
+    for (size_t i = 0; i < contents.size(); ++i) {
+        const auto & msg = inputs.messages[i];
+        const auto & content = contents[i];
+        chat.push_back({msg.role.c_str(), content.c_str()});
+        alloc_size += (msg.role.size() + content.size()) * 1.25;
+    }
+
+    std::vector<char> buf(alloc_size);
+
+    // run the first time to get the total output length
+    const auto & src = tmpls->template_default->source();
+    int32_t res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size());
+
+    // error: chat template is not supported
+    if (res < 0) {
+        // if the custom "tmpl" is not supported, we throw an error
+        // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
+        throw std::runtime_error("this custom template is not supported");
+    }
+
+    // if it turns out that our buffer is too small, we resize it
+    if ((size_t) res > buf.size()) {
+        buf.resize(res);
+        res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size());
+    }
+
+    common_chat_params params;
+    params.prompt = std::string(buf.data(), res);
+    if (!inputs.json_schema.empty()) {
+        params.grammar = json_schema_to_grammar(json::parse(inputs.json_schema));
+    } else {
+        params.grammar = inputs.grammar;
+    }
+    return params;
+}
+
+common_chat_params common_chat_templates_apply(
+    const struct common_chat_templates * tmpls,
+    const struct common_chat_templates_inputs & inputs)
+{
+    GGML_ASSERT(tmpls != nullptr);
+    return inputs.use_jinja
+        ? common_chat_templates_apply_jinja(tmpls, inputs)
+        : common_chat_templates_apply_legacy(tmpls, inputs);
 }
 
 static common_chat_msg common_chat_parse_content_only(const std::string & input) {
-    return {
-        /* .role = */ "assistant",
-        /* .content = */ input,
-        /* .tool_calls = */ {},
-    };
+    common_chat_msg msg;
+    msg.role = "assistant";
+    msg.content = input;
+    return msg;
 }
 
 common_chat_msg common_chat_parse(const std::string & input, common_chat_format format) {
diff --git a/common/chat.h b/common/chat.h
new file mode 100644 (file)
index 0000000..e77bef8
--- /dev/null
@@ -0,0 +1,134 @@
+// Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers.
+
+#pragma once
+
+#include "common.h"
+#include <string>
+#include <vector>
+
+struct common_chat_templates;
+
+struct common_chat_tool_call {
+    std::string name;
+    std::string arguments;
+    std::string id;
+};
+
+struct common_chat_msg_content_part {
+    std::string type;
+    std::string text;
+};
+
+struct common_chat_msg {
+    std::string role;
+    std::string content;
+    std::vector<common_chat_msg_content_part> content_parts = {};
+    std::vector<common_chat_tool_call> tool_calls = {};
+    std::string reasoning_content;
+    std::string tool_name;
+    std::string tool_call_id;
+};
+
+struct common_chat_tool {
+    std::string name;
+    std::string description;
+    std::string parameters;
+};
+
+enum common_chat_tool_choice {
+    COMMON_CHAT_TOOL_CHOICE_AUTO,
+    COMMON_CHAT_TOOL_CHOICE_REQUIRED,
+    COMMON_CHAT_TOOL_CHOICE_NONE,
+};
+
+enum common_chat_format {
+    COMMON_CHAT_FORMAT_CONTENT_ONLY,
+    COMMON_CHAT_FORMAT_GENERIC,
+    COMMON_CHAT_FORMAT_MISTRAL_NEMO,
+    COMMON_CHAT_FORMAT_LLAMA_3_X,
+    COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
+    COMMON_CHAT_FORMAT_DEEPSEEK_R1,
+    COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING,
+    COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
+    COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
+    COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
+    COMMON_CHAT_FORMAT_HERMES_2_PRO,
+    COMMON_CHAT_FORMAT_COMMAND_R7B,
+    COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING,
+
+    COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
+};
+
+struct common_chat_templates_inputs {
+    std::vector<common_chat_msg> messages;
+    std::string grammar;
+    std::string json_schema;
+    bool add_generation_prompt = true;
+    bool use_jinja = true;
+    // Parameters below only supported when use_jinja is true
+    std::vector<common_chat_tool> tools;
+    common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
+    bool parallel_tool_calls = false;
+    bool extract_reasoning     = true;
+};
+
+struct common_chat_params {
+    common_chat_format                  format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
+    std::string                         prompt;
+    std::string                         grammar;
+    bool                                grammar_lazy = false;
+    std::vector<common_grammar_trigger> grammar_triggers;
+    std::vector<std::string>            preserved_tokens;
+    std::vector<std::string>            additional_stops;
+};
+
+// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
+bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
+
+void common_chat_templates_free(struct common_chat_templates * tmpls);
+
+struct common_chat_templates_deleter { void operator()(common_chat_templates * tmpls) { common_chat_templates_free(tmpls); } };
+
+typedef std::unique_ptr<struct common_chat_templates, common_chat_templates_deleter> common_chat_templates_ptr;
+
+common_chat_templates_ptr common_chat_templates_init(
+                                    const struct llama_model * model,
+                                           const std::string & chat_template_override,
+                                           const std::string & bos_token_override = "",
+                                           const std::string & eos_token_override = "");
+
+bool         common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
+const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant = nullptr);
+
+
+struct common_chat_params      common_chat_templates_apply(
+    const struct common_chat_templates * tmpls,
+    const struct common_chat_templates_inputs & inputs);
+
+// Format single message, while taking into account the position of that message in chat history
+std::string common_chat_format_single(
+        const struct common_chat_templates * tmpls,
+        const std::vector<common_chat_msg> & past_msg,
+        const common_chat_msg & new_msg,
+        bool add_ass,
+        bool use_jinja);
+
+// Returns an example of formatted chat
+std::string common_chat_format_example(
+    const struct common_chat_templates * tmpls,
+    bool use_jinja);
+
+std::string               common_chat_format_name(common_chat_format format);
+common_chat_msg           common_chat_parse(      const std::string & input, common_chat_format format);
+
+common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
+
+// Parses a JSON array of messages in OpenAI's chat completion API format.
+// T can be std::string containing JSON or nlohmann::ordered_json
+template <class T> std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const T & messages);
+template <class T> T common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false);
+
+// Parses a JSON array of tools in OpenAI's chat completion tool call API format.
+// T can be std::string containing JSON or nlohmann::ordered_json
+template <class T> std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const T & tools);
+template <class T> T common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);
diff --git a/common/chat.hpp b/common/chat.hpp
deleted file mode 100644 (file)
index ba1632f..0000000
+++ /dev/null
@@ -1,55 +0,0 @@
-// Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers.
-
-#pragma once
-
-#include "common.h"
-#include <json.hpp>
-#include <optional>
-#include <string>
-#include <vector>
-
-using json = nlohmann::ordered_json;
-
-struct common_chat_inputs {
-    json messages;
-    json tools;
-    json tool_choice;
-    json json_schema;
-    bool parallel_tool_calls;
-    bool stream;
-    std::string grammar;
-    bool add_generation_prompt = true;
-    bool extract_reasoning     = true;
-};
-
-enum common_chat_format {
-    COMMON_CHAT_FORMAT_CONTENT_ONLY,
-    COMMON_CHAT_FORMAT_GENERIC,
-    COMMON_CHAT_FORMAT_MISTRAL_NEMO,
-    COMMON_CHAT_FORMAT_LLAMA_3_X,
-    COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
-    COMMON_CHAT_FORMAT_DEEPSEEK_R1,
-    COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING,
-    COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
-    COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
-    COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
-    COMMON_CHAT_FORMAT_HERMES_2_PRO,
-    COMMON_CHAT_FORMAT_COMMAND_R7B,
-    COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING,
-
-    COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
-};
-
-struct common_chat_params {
-    common_chat_format                  format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
-    json                                prompt;
-    std::string                         grammar;
-    bool                                grammar_lazy = false;
-    std::vector<common_grammar_trigger> grammar_triggers;
-    std::vector<std::string>            preserved_tokens;
-    std::vector<std::string>            additional_stops;
-};
-
-struct common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & params);
-std::string               common_chat_format_name(common_chat_format format);
-common_chat_msg           common_chat_parse(      const std::string & input, common_chat_format format);
index 8661e164ada6bdba72b9f5222444aa0bb3b7701b..d2b0d50e3ee3975e90ff12abcb2bedaef041ef21 100644 (file)
@@ -12,8 +12,6 @@
 #include "json.hpp"
 #include "json-schema-to-grammar.h"
 #include "llama.h"
-#include "chat.hpp"
-#include "chat-template.hpp"
 
 #include <algorithm>
 #include <cinttypes>
@@ -1768,174 +1766,6 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto
     return text;
 }
 
-//
-// Chat template utils
-//
-
-bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
-    if (use_jinja) {
-        try {
-            auto chat_template = common_chat_template(tmpl, "<s>", "</s>");
-            common_chat_inputs inputs;
-            inputs.messages = json::array({{
-                {"role", "user"},
-                {"content", "test"},
-            }});
-            common_chat_params_init(chat_template, inputs);
-            return true;
-        } catch (const std::exception & e) {
-            LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
-            return false;
-        }
-    }
-    llama_chat_message chat[] = {{"user", "test"}};
-    const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
-    return res >= 0;
-}
-
-std::string common_chat_apply_template(
-        const common_chat_template & tmpl,
-        const std::vector<common_chat_msg> & msgs,
-        bool add_ass,
-        bool use_jinja) {
-    if (use_jinja) {
-        auto messages = json::array();
-        for (const auto & msg : msgs) {
-            messages.push_back({{"role", msg.role}, {"content", msg.content}});
-        }
-        common_chat_inputs inputs;
-        inputs.messages = messages;
-        inputs.add_generation_prompt = add_ass;
-        return common_chat_params_init(tmpl, inputs).prompt;
-    }
-
-    int alloc_size = 0;
-    std::vector<llama_chat_message> chat;
-    for (const auto & msg : msgs) {
-        chat.push_back({msg.role.c_str(), msg.content.c_str()});
-        alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
-    }
-
-    std::vector<char> buf(alloc_size);
-
-    // run the first time to get the total output length
-    int32_t res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
-
-    // error: chat template is not supported
-    if (res < 0) {
-        // if the custom "tmpl" is not supported, we throw an error
-        // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
-        throw std::runtime_error("this custom template is not supported");
-    }
-
-    // if it turns out that our buffer is too small, we resize it
-    if ((size_t) res > buf.size()) {
-        buf.resize(res);
-        res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
-    }
-
-    std::string formatted_chat(buf.data(), res);
-    return formatted_chat;
-}
-
-std::string common_chat_format_single(
-        const common_chat_template & tmpl,
-        const std::vector<common_chat_msg> & past_msg,
-        const common_chat_msg & new_msg,
-        bool add_ass,
-        bool use_jinja) {
-    std::ostringstream ss;
-    auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(tmpl, past_msg, false, use_jinja);
-    std::vector<common_chat_msg> chat_new(past_msg);
-    // if the past_msg ends with a newline, we must preserve it in the formatted version
-    if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
-        ss << "\n";
-    };
-    // format chat with new_msg
-    chat_new.push_back(new_msg);
-    auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja);
-    // get the diff part
-    ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
-    return ss.str();
-}
-
-std::string common_chat_format_example(const common_chat_template & tmpl, bool use_jinja) {
-    std::vector<common_chat_msg> msgs = {
-        {"system",    "You are a helpful assistant", {}},
-        {"user",      "Hello", {}},
-        {"assistant", "Hi there", {}},
-        {"user",      "How are you?", {}},
-    };
-    return common_chat_apply_template(tmpl, msgs, true, use_jinja);
-}
-
-#define CHATML_TEMPLATE_SRC \
-    "{%- for message in messages -%}\n" \
-    "  {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \
-    "{%- endfor -%}\n" \
-    "{%- if add_generation_prompt -%}\n" \
-    "  {{- '<|im_start|>assistant\n' -}}\n" \
-    "{%- endif -%}"
-
-common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override)
-{
-    std::string default_template_src;
-    std::string template_tool_use_src;
-
-    bool has_explicit_template = !chat_template_override.empty();
-    if (chat_template_override.empty()) {
-        auto str = llama_model_chat_template(model, /* name */ nullptr);
-        if (str) {
-            default_template_src = str;
-            has_explicit_template = true;
-        }
-        str = llama_model_chat_template(model, /* name */ "tool_use");
-        if (str) {
-            template_tool_use_src = str;
-            has_explicit_template = true;
-        }
-    } else {
-        default_template_src = chat_template_override;
-    }
-    if (default_template_src.empty() || default_template_src == "chatml") {
-        if (!template_tool_use_src.empty()) {
-            default_template_src = template_tool_use_src;
-        } else {
-            default_template_src = CHATML_TEMPLATE_SRC;
-        }
-    }
-    auto vocab = llama_model_get_vocab(model);
-    const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {
-        if (token == LLAMA_TOKEN_NULL) {
-            if (default_template_src.find(jinja_variable_name) != std::string::npos
-                || template_tool_use_src.find(jinja_variable_name) != std::string::npos) {
-                LOG_WRN("%s: warning: vocab does not have a %s token, jinja template won't work as intended.\n", __func__, name);
-            }
-            return std::string();
-        } else {
-            return common_token_to_piece(vocab, token, true);
-        }
-    };
-    auto token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
-    auto token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
-    try {
-        return {
-            has_explicit_template,
-            std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos),
-            template_tool_use_src.empty()
-                ? nullptr
-                : std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos),
-        };
-    } catch (const std::exception & e) {
-        LOG_ERR("%s: failed to parse chat template: %s\n", __func__, e.what());
-        return {
-            has_explicit_template,
-            std::make_unique<minja::chat_template>(CHATML_TEMPLATE_SRC, token_bos, token_eos),
-            nullptr,
-        };
-    }
-}
-
 //
 // KV cache utils
 //
index 98b9a4464787a93cb476abed1f332909dcb7fe5a..10bcc10d51bb510633a0449ae48ea7f134ffe500 100644 (file)
@@ -616,62 +616,6 @@ std::string common_detokenize(
         const std::vector<llama_token> & tokens,
                                   bool   special = true);
 
-//
-// Chat template utils
-//
-
-struct common_tool_call {
-    std::string name;
-    std::string arguments;
-    std::string id;
-};
-
-// same with llama_chat_message, but uses std::string
-struct common_chat_msg {
-    std::string role;
-    std::string content;
-    std::vector<common_tool_call> tool_calls;
-    std::string reasoning_content = "";
-};
-
-// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
-bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
-
-namespace minja {
-    class chat_template;
-}
-
-typedef minja::chat_template common_chat_template;
-
-struct common_chat_templates {
-    bool has_explicit_template; // Model had builtin template or template overridde was specified.
-    std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
-    std::unique_ptr<common_chat_template> template_tool_use;
-};
-
-// CPP wrapper for llama_chat_apply_template
-// If the built-in template is not supported, we default to chatml
-// If the custom "tmpl" is not supported, we throw an error
-std::string common_chat_apply_template(
-        const common_chat_template & tmpl,
-        const std::vector<common_chat_msg> & chat,
-        bool add_ass,
-        bool use_jinja);
-
-// Format single message, while taking into account the position of that message in chat history
-std::string common_chat_format_single(
-        const common_chat_template & tmpl,
-        const std::vector<common_chat_msg> & past_msg,
-        const common_chat_msg & new_msg,
-        bool add_ass,
-        bool use_jinja);
-
-// Returns an example of formatted chat
-std::string common_chat_format_example(
-    const common_chat_template & tmpl, bool use_jinja);
-
-common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override);
-
 //
 // KV cache utils
 //
diff --git a/common/minja.hpp b/common/minja.hpp
deleted file mode 100644 (file)
index c58dd66..0000000
+++ /dev/null
@@ -1,2883 +0,0 @@
-/*
-    Copyright 2024 Google LLC
-
-    Use of this source code is governed by an MIT-style
-    license that can be found in the LICENSE file or at
-    https://opensource.org/licenses/MIT.
-*/
-// SPDX-License-Identifier: MIT
-#pragma once
-
-#include <iostream>
-#include <string>
-#include <vector>
-#include <regex>
-#include <memory>
-#include <stdexcept>
-#include <sstream>
-#include <unordered_set>
-#include <json.hpp>
-
-using json = nlohmann::ordered_json;
-
-namespace minja {
-
-class Context;
-
-struct Options {
-    bool trim_blocks;  // removes the first newline after a block
-    bool lstrip_blocks;  // removes leading whitespace on the line of the block
-    bool keep_trailing_newline;  // don't remove last newline
-};
-
-struct ArgumentsValue;
-
-inline std::string normalize_newlines(const std::string & s) {
-#ifdef _WIN32
-  static const std::regex nl_regex("\r\n");
-  return std::regex_replace(s, nl_regex, "\n");
-#else
-  return s;
-#endif
-}
-
-/* Values that behave roughly like in Python. */
-class Value : public std::enable_shared_from_this<Value> {
-public:
-  using CallableType = std::function<Value(const std::shared_ptr<Context> &, ArgumentsValue &)>;
-  using FilterType = std::function<Value(const std::shared_ptr<Context> &, ArgumentsValue &)>;
-
-private:
-  using ObjectType = nlohmann::ordered_map<json, Value>;  // Only contains primitive keys
-  using ArrayType = std::vector<Value>;
-
-  std::shared_ptr<ArrayType> array_;
-  std::shared_ptr<ObjectType> object_;
-  std::shared_ptr<CallableType> callable_;
-  json primitive_;
-
-  Value(const std::shared_ptr<ArrayType> & array) : array_(array) {}
-  Value(const std::shared_ptr<ObjectType> & object) : object_(object) {}
-  Value(const std::shared_ptr<CallableType> & callable) : object_(std::make_shared<ObjectType>()), callable_(callable) {}
-
-  /* Python-style string repr */
-  static void dump_string(const json & primitive, std::ostringstream & out, char string_quote = '\'') {
-    if (!primitive.is_string()) throw std::runtime_error("Value is not a string: " + primitive.dump());
-    auto s = primitive.dump();
-    if (string_quote == '"' || s.find('\'') != std::string::npos) {
-      out << s;
-      return;
-    }
-    // Reuse json dump, just changing string quotes
-    out << string_quote;
-    for (size_t i = 1, n = s.size() - 1; i < n; ++i) {
-      if (s[i] == '\\' && s[i + 1] == '"') {
-        out << '"';
-        i++;
-      } else if (s[i] == string_quote) {
-        out << '\\' << string_quote;
-      } else {
-        out << s[i];
-      }
-    }
-    out << string_quote;
-  }
-  void dump(std::ostringstream & out, int indent = -1, int level = 0, bool to_json = false) const {
-    auto print_indent = [&](int level) {
-      if (indent > 0) {
-          out << "\n";
-          for (int i = 0, n = level * indent; i < n; ++i) out << ' ';
-      }
-    };
-    auto print_sub_sep = [&]() {
-      out << ',';
-      if (indent < 0) out << ' ';
-      else print_indent(level + 1);
-    };
-
-    auto string_quote = to_json ? '"' : '\'';
-
-    if (is_null()) out << "null";
-    else if (array_) {
-      out << "[";
-      print_indent(level + 1);
-      for (size_t i = 0; i < array_->size(); ++i) {
-        if (i) print_sub_sep();
-        (*array_)[i].dump(out, indent, level + 1, to_json);
-      }
-      print_indent(level);
-      out << "]";
-    } else if (object_) {
-      out << "{";
-      print_indent(level + 1);
-      for (auto begin = object_->begin(), it = begin; it != object_->end(); ++it) {
-        if (it != begin) print_sub_sep();
-        if (it->first.is_string()) {
-          dump_string(it->first, out, string_quote);
-        } else {
-          out << string_quote << it->first.dump() << string_quote;
-        }
-        out << ": ";
-        it->second.dump(out, indent, level + 1, to_json);
-      }
-      print_indent(level);
-      out << "}";
-    } else if (callable_) {
-      throw std::runtime_error("Cannot dump callable to JSON");
-    } else if (is_boolean() && !to_json) {
-      out << (this->to_bool() ? "True" : "False");
-    } else if (is_string() && !to_json) {
-      dump_string(primitive_, out, string_quote);
-    } else {
-      out << primitive_.dump();
-    }
-  }
-
-public:
-  Value() {}
-  Value(const bool& v) : primitive_(v) {}
-  Value(const int64_t & v) : primitive_(v) {}
-  Value(const double& v) : primitive_(v) {}
-  Value(const std::nullptr_t &) {}
-  Value(const std::string & v) : primitive_(v) {}
-  Value(const char * v) : primitive_(std::string(v)) {}
-
-  Value(const json & v) {
-    if (v.is_object()) {
-      auto object = std::make_shared<ObjectType>();
-      for (auto it = v.begin(); it != v.end(); ++it) {
-        (*object)[it.key()] = it.value();
-      }
-      object_ = std::move(object);
-    } else if (v.is_array()) {
-      auto array = std::make_shared<ArrayType>();
-      for (const auto& item : v) {
-        array->push_back(Value(item));
-      }
-      array_ = array;
-    } else {
-      primitive_ = v;
-    }
-  }
-
-  std::vector<Value> keys() {
-    if (!object_) throw std::runtime_error("Value is not an object: " + dump());
-    std::vector<Value> res;
-    for (const auto& item : *object_) {
-      res.push_back(item.first);
-    }
-    return res;
-  }
-
-  size_t size() const {
-    if (is_object()) return object_->size();
-    if (is_array()) return array_->size();
-    if (is_string()) return primitive_.get<std::string>().length();
-    throw std::runtime_error("Value is not an array or object: " + dump());
-  }
-
-  static Value array(const std::vector<Value> values = {}) {
-    auto array = std::make_shared<ArrayType>();
-    for (const auto& item : values) {
-      array->push_back(item);
-    }
-    return Value(array);
-  }
-  static Value object(const std::shared_ptr<ObjectType> object = std::make_shared<ObjectType>()) {
-    return Value(object);
-  }
-  static Value callable(const CallableType & callable) {
-    return Value(std::make_shared<CallableType>(callable));
-  }
-
-  void insert(size_t index, const Value& v) {
-    if (!array_)
-      throw std::runtime_error("Value is not an array: " + dump());
-    array_->insert(array_->begin() + index, v);
-  }
-  void push_back(const Value& v) {
-    if (!array_)
-      throw std::runtime_error("Value is not an array: " + dump());
-    array_->push_back(v);
-  }
-  Value pop(const Value& index) {
-    if (is_array()) {
-      if (array_->empty())
-        throw std::runtime_error("pop from empty list");
-      if (index.is_null()) {
-        auto ret = array_->back();
-        array_->pop_back();
-        return ret;
-      } else if (!index.is_number_integer()) {
-        throw std::runtime_error("pop index must be an integer: " + index.dump());
-      } else {
-        auto i = index.get<int>();
-        if (i < 0 || i >= static_cast<int>(array_->size()))
-          throw std::runtime_error("pop index out of range: " + index.dump());
-        auto it = array_->begin() + (i < 0 ? array_->size() + i : i);
-        auto ret = *it;
-        array_->erase(it);
-        return ret;
-      }
-    } else if (is_object()) {
-      if (!index.is_hashable())
-        throw std::runtime_error("Unashable type: " + index.dump());
-      auto it = object_->find(index.primitive_);
-      if (it == object_->end())
-        throw std::runtime_error("Key not found: " + index.dump());
-      auto ret = it->second;
-      object_->erase(it);
-      return ret;
-    } else {
-      throw std::runtime_error("Value is not an array or object: " + dump());
-    }
-  }
-  Value get(const Value& key) {
-    if (array_) {
-      if (!key.is_number_integer()) {
-        return Value();
-      }
-      auto index = key.get<int>();
-      return array_->at(index < 0 ? array_->size() + index : index);
-    } else if (object_) {
-      if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
-      auto it = object_->find(key.primitive_);
-      if (it == object_->end()) return Value();
-      return it->second;
-    }
-    return Value();
-  }
-  void set(const Value& key, const Value& value) {
-    if (!object_) throw std::runtime_error("Value is not an object: " + dump());
-    if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
-    (*object_)[key.primitive_] = value;
-  }
-  Value call(const std::shared_ptr<Context> & context, ArgumentsValue & args) const {
-    if (!callable_) throw std::runtime_error("Value is not callable: " + dump());
-    return (*callable_)(context, args);
-  }
-
-  bool is_object() const { return !!object_; }
-  bool is_array() const { return !!array_; }
-  bool is_callable() const { return !!callable_; }
-  bool is_null() const { return !object_ && !array_ && primitive_.is_null() && !callable_; }
-  bool is_boolean() const { return primitive_.is_boolean(); }
-  bool is_number_integer() const { return primitive_.is_number_integer(); }
-  bool is_number_float() const { return primitive_.is_number_float(); }
-  bool is_number() const { return primitive_.is_number(); }
-  bool is_string() const { return primitive_.is_string(); }
-  bool is_iterable() const { return is_array() || is_object() || is_string(); }
-
-  bool is_primitive() const { return !array_ && !object_ && !callable_; }
-  bool is_hashable() const { return is_primitive(); }
-
-  bool empty() const {
-    if (is_null())
-      throw std::runtime_error("Undefined value or reference");
-    if (is_string()) return primitive_.empty();
-    if (is_array()) return array_->empty();
-    if (is_object()) return object_->empty();
-    return false;
-  }
-
-  void for_each(const std::function<void(Value &)> & callback) const {
-    if (is_null())
-      throw std::runtime_error("Undefined value or reference");
-    if (array_) {
-      for (auto& item : *array_) {
-        callback(item);
-      }
-    } else if (object_) {
-      for (auto & item : *object_) {
-        Value key(item.first);
-        callback(key);
-      }
-    } else if (is_string()) {
-      for (char c : primitive_.get<std::string>()) {
-        auto val = Value(std::string(1, c));
-        callback(val);
-      }
-    } else {
-      throw std::runtime_error("Value is not iterable: " + dump());
-    }
-  }
-
-  bool to_bool() const {
-    if (is_null()) return false;
-    if (is_boolean()) return get<bool>();
-    if (is_number()) return get<double>() != 0;
-    if (is_string()) return !get<std::string>().empty();
-    if (is_array()) return !empty();
-    return true;
-  }
-
-  int64_t to_int() const {
-    if (is_null()) return 0;
-    if (is_boolean()) return get<bool>() ? 1 : 0;
-    if (is_number()) return static_cast<int64_t>(get<double>());
-    if (is_string()) {
-      try {
-        return std::stol(get<std::string>());
-      } catch (const std::exception &) {
-        return 0;
-      }
-    }
-    return 0;
-  }
-
-  bool operator<(const Value & other) const {
-    if (is_null())
-      throw std::runtime_error("Undefined value or reference");
-    if (is_number() && other.is_number()) return get<double>() < other.get<double>();
-    if (is_string() && other.is_string()) return get<std::string>() < other.get<std::string>();
-    throw std::runtime_error("Cannot compare values: " + dump() + " < " + other.dump());
-  }
-  bool operator>=(const Value & other) const { return !(*this < other); }
-
-  bool operator>(const Value & other) const {
-    if (is_null())
-      throw std::runtime_error("Undefined value or reference");
-    if (is_number() && other.is_number()) return get<double>() > other.get<double>();
-    if (is_string() && other.is_string()) return get<std::string>() > other.get<std::string>();
-    throw std::runtime_error("Cannot compare values: " + dump() + " > " + other.dump());
-  }
-  bool operator<=(const Value & other) const { return !(*this > other); }
-
-  bool operator==(const Value & other) const {
-    if (callable_ || other.callable_) {
-      if (callable_.get() != other.callable_.get()) return false;
-    }
-    if (array_) {
-      if (!other.array_) return false;
-      if (array_->size() != other.array_->size()) return false;
-      for (size_t i = 0; i < array_->size(); ++i) {
-        if (!(*array_)[i].to_bool() || !(*other.array_)[i].to_bool() || (*array_)[i] != (*other.array_)[i]) return false;
-      }
-      return true;
-    } else if (object_) {
-      if (!other.object_) return false;
-      if (object_->size() != other.object_->size()) return false;
-      for (const auto& item : *object_) {
-        if (!item.second.to_bool() || !other.object_->count(item.first) || item.second != other.object_->at(item.first)) return false;
-      }
-      return true;
-    } else {
-      return primitive_ == other.primitive_;
-    }
-  }
-  bool operator!=(const Value & other) const { return !(*this == other); }
-
-  bool contains(const char * key) const { return contains(std::string(key)); }
-  bool contains(const std::string & key) const {
-    if (array_) {
-      return false;
-    } else if (object_) {
-      return object_->find(key) != object_->end();
-    } else {
-      throw std::runtime_error("contains can only be called on arrays and objects: " + dump());
-    }
-  }
-  bool contains(const Value & value) const {
-    if (is_null())
-      throw std::runtime_error("Undefined value or reference");
-    if (array_) {
-      for (const auto& item : *array_) {
-        if (item.to_bool() && item == value) return true;
-      }
-      return false;
-    } else if (object_) {
-      if (!value.is_hashable()) throw std::runtime_error("Unashable type: " + value.dump());
-      return object_->find(value.primitive_) != object_->end();
-    } else {
-      throw std::runtime_error("contains can only be called on arrays and objects: " + dump());
-    }
-  }
-  void erase(size_t index) {
-    if (!array_) throw std::runtime_error("Value is not an array: " + dump());
-    array_->erase(array_->begin() + index);
-  }
-  void erase(const std::string & key) {
-    if (!object_) throw std::runtime_error("Value is not an object: " + dump());
-    object_->erase(key);
-  }
-  const Value& at(const Value & index) const {
-    return const_cast<Value*>(this)->at(index);
-  }
-  Value& at(const Value & index) {
-    if (!index.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
-    if (is_array()) return array_->at(index.get<int>());
-    if (is_object()) return object_->at(index.primitive_);
-    throw std::runtime_error("Value is not an array or object: " + dump());
-  }
-  const Value& at(size_t index) const {
-    return const_cast<Value*>(this)->at(index);
-  }
-  Value& at(size_t index) {
-    if (is_null())
-      throw std::runtime_error("Undefined value or reference");
-    if (is_array()) return array_->at(index);
-    if (is_object()) return object_->at(index);
-    throw std::runtime_error("Value is not an array or object: " + dump());
-  }
-
-  template <typename T>
-  T get(const std::string & key, T default_value) const {
-    if (!contains(key)) return default_value;
-    return at(key).get<T>();
-  }
-
-  template <typename T>
-  T get() const {
-    if (is_primitive()) return primitive_.get<T>();
-    throw std::runtime_error("get<T> not defined for this value type: " + dump());
-  }
-
-  std::string dump(int indent=-1, bool to_json=false) const {
-    std::ostringstream out;
-    dump(out, indent, 0, to_json);
-    return out.str();
-  }
-
-  Value operator-() const {
-      if (is_number_integer())
-        return -get<int64_t>();
-      else
-        return -get<double>();
-  }
-  std::string to_str() const {
-    if (is_string()) return get<std::string>();
-    if (is_number_integer()) return std::to_string(get<int64_t>());
-    if (is_number_float()) return std::to_string(get<double>());
-    if (is_boolean()) return get<bool>() ? "True" : "False";
-    if (is_null()) return "None";
-    return dump();
-  }
-  Value operator+(const Value& rhs) const {
-      if (is_string() || rhs.is_string()) {
-        return to_str() + rhs.to_str();
-      } else if (is_number_integer() && rhs.is_number_integer()) {
-        return get<int64_t>() + rhs.get<int64_t>();
-      } else if (is_array() && rhs.is_array()) {
-        auto res = Value::array();
-        for (const auto& item : *array_) res.push_back(item);
-        for (const auto& item : *rhs.array_) res.push_back(item);
-        return res;
-      } else {
-        return get<double>() + rhs.get<double>();
-      }
-  }
-  Value operator-(const Value& rhs) const {
-      if (is_number_integer() && rhs.is_number_integer())
-        return get<int64_t>() - rhs.get<int64_t>();
-      else
-        return get<double>() - rhs.get<double>();
-  }
-  Value operator*(const Value& rhs) const {
-      if (is_string() && rhs.is_number_integer()) {
-        std::ostringstream out;
-        for (int64_t i = 0, n = rhs.get<int64_t>(); i < n; ++i) {
-          out << to_str();
-        }
-        return out.str();
-      }
-      else if (is_number_integer() && rhs.is_number_integer())
-        return get<int64_t>() * rhs.get<int64_t>();
-      else
-        return get<double>() * rhs.get<double>();
-  }
-  Value operator/(const Value& rhs) const {
-      if (is_number_integer() && rhs.is_number_integer())
-        return get<int64_t>() / rhs.get<int64_t>();
-      else
-        return get<double>() / rhs.get<double>();
-  }
-  Value operator%(const Value& rhs) const {
-    return get<int64_t>() % rhs.get<int64_t>();
-  }
-};
-
-struct ArgumentsValue {
-  std::vector<Value> args;
-  std::vector<std::pair<std::string, Value>> kwargs;
-
-  bool has_named(const std::string & name) {
-    for (const auto & p : kwargs) {
-      if (p.first == name) return true;
-    }
-    return false;
-  }
-
-  Value get_named(const std::string & name) {
-    for (const auto & [key, value] : kwargs) {
-      if (key == name) return value;
-    }
-    return Value();
-  }
-
-  bool empty() {
-    return args.empty() && kwargs.empty();
-  }
-
-  void expectArgs(const std::string & method_name, const std::pair<size_t, size_t> & pos_count, const std::pair<size_t, size_t> & kw_count) {
-    if (args.size() < pos_count.first || args.size() > pos_count.second || kwargs.size() < kw_count.first || kwargs.size() > kw_count.second) {
-      std::ostringstream out;
-      out << method_name << " must have between " << pos_count.first << " and " << pos_count.second << " positional arguments and between " << kw_count.first << " and " << kw_count.second << " keyword arguments";
-      throw std::runtime_error(out.str());
-    }
-  }
-};
-
-template <>
-inline json Value::get<json>() const {
-  if (is_primitive()) return primitive_;
-  if (is_null()) return json();
-  if (array_) {
-    std::vector<json> res;
-    for (const auto& item : *array_) {
-      res.push_back(item.get<json>());
-    }
-    return res;
-  }
-  if (object_) {
-    json res = json::object();
-    for (const auto& [key, value] : *object_) {
-      if (key.is_string()) {
-        res[key.get<std::string>()] = value.get<json>();
-      } else if (key.is_primitive()) {
-        res[key.dump()] = value.get<json>();
-      } else {
-        throw std::runtime_error("Invalid key type for conversion to JSON: " + key.dump());
-      }
-    }
-    if (is_callable()) {
-      res["__callable__"] = true;
-    }
-    return res;
-  }
-  throw std::runtime_error("get<json> not defined for this value type: " + dump());
-}
-
-} // namespace minja
-
-namespace std {
-  template <>
-  struct hash<minja::Value> {
-    size_t operator()(const minja::Value & v) const {
-      if (!v.is_hashable())
-        throw std::runtime_error("Unsupported type for hashing: " + v.dump());
-      return std::hash<json>()(v.get<json>());
-    }
-  };
-} // namespace std
-
-namespace minja {
-
-static std::string error_location_suffix(const std::string & source, size_t pos) {
-  auto get_line = [&](size_t line) {
-    auto start = source.begin();
-    for (size_t i = 1; i < line; ++i) {
-      start = std::find(start, source.end(), '\n') + 1;
-    }
-    auto end = std::find(start, source.end(), '\n');
-    return std::string(start, end);
-  };
-  auto start = source.begin();
-  auto end = source.end();
-  auto it = start + pos;
-  auto line = std::count(start, it, '\n') + 1;
-  auto max_line = std::count(start, end, '\n') + 1;
-  auto col = pos - std::string(start, it).rfind('\n');
-  std::ostringstream out;
-  out << " at row " << line << ", column " << col << ":\n";
-  if (line > 1) out << get_line(line - 1) << "\n";
-  out << get_line(line) << "\n";
-  out << std::string(col - 1, ' ') << "^\n";
-  if (line < max_line) out << get_line(line + 1) << "\n";
-
-  return out.str();
-}
-
-class Context : public std::enable_shared_from_this<Context> {
-  protected:
-    Value values_;
-    std::shared_ptr<Context> parent_;
-  public:
-    Context(Value && values, const std::shared_ptr<Context> & parent = nullptr) : values_(std::move(values)), parent_(parent) {
-        if (!values_.is_object()) throw std::runtime_error("Context values must be an object: " + values_.dump());
-    }
-    virtual ~Context() {}
-
-    static std::shared_ptr<Context> builtins();
-    static std::shared_ptr<Context> make(Value && values, const std::shared_ptr<Context> & parent = builtins());
-
-    std::vector<Value> keys() {
-        return values_.keys();
-    }
-    virtual Value get(const Value & key) {
-        if (values_.contains(key)) return values_.at(key);
-        if (parent_) return parent_->get(key);
-        return Value();
-    }
-    virtual Value & at(const Value & key) {
-        if (values_.contains(key)) return values_.at(key);
-        if (parent_) return parent_->at(key);
-        throw std::runtime_error("Undefined variable: " + key.dump());
-    }
-    virtual bool contains(const Value & key) {
-        if (values_.contains(key)) return true;
-        if (parent_) return parent_->contains(key);
-        return false;
-    }
-    virtual void set(const Value & key, const Value & value) {
-        values_.set(key, value);
-    }
-};
-
-struct Location {
-    std::shared_ptr<std::string> source;
-    size_t pos;
-};
-
-class Expression {
-protected:
-    virtual Value do_evaluate(const std::shared_ptr<Context> & context) const = 0;
-public:
-    using Parameters = std::vector<std::pair<std::string, std::shared_ptr<Expression>>>;
-
-    Location location;
-
-    Expression(const Location & location) : location(location) {}
-    virtual ~Expression() = default;
-
-    Value evaluate(const std::shared_ptr<Context> & context) const {
-        try {
-            return do_evaluate(context);
-        } catch (const std::exception & e) {
-            std::ostringstream out;
-            out << e.what();
-            if (location.source) out << error_location_suffix(*location.source, location.pos);
-            throw std::runtime_error(out.str());
-        }
-    }
-};
-
-class VariableExpr : public Expression {
-    std::string name;
-public:
-    VariableExpr(const Location & location, const std::string& n)
-      : Expression(location), name(n) {}
-    std::string get_name() const { return name; }
-    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
-        if (!context->contains(name)) {
-            return Value();
-        }
-        return context->at(name);
-    }
-};
-
-static void destructuring_assign(const std::vector<std::string> & var_names, const std::shared_ptr<Context> & context, Value& item) {
-  if (var_names.size() == 1) {
-      Value name(var_names[0]);
-      context->set(name, item);
-  } else {
-      if (!item.is_array() || item.size() != var_names.size()) {
-          throw std::runtime_error("Mismatched number of variables and items in destructuring assignment");
-      }
-      for (size_t i = 0; i < var_names.size(); ++i) {
-          context->set(var_names[i], item.at(i));
-      }
-  }
-}
-
-enum SpaceHandling { Keep, Strip, StripSpaces, StripNewline };
-
-class TemplateToken {
-public:
-    enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter, Break, Continue };
-
-    static std::string typeToString(Type t) {
-        switch (t) {
-            case Type::Text: return "text";
-            case Type::Expression: return "expression";
-            case Type::If: return "if";
-            case Type::Else: return "else";
-            case Type::Elif: return "elif";
-            case Type::EndIf: return "endif";
-            case Type::For: return "for";
-            case Type::EndFor: return "endfor";
-            case Type::Set: return "set";
-            case Type::EndSet: return "endset";
-            case Type::Comment: return "comment";
-            case Type::Macro: return "macro";
-            case Type::EndMacro: return "endmacro";
-            case Type::Filter: return "filter";
-            case Type::EndFilter: return "endfilter";
-            case Type::Generation: return "generation";
-            case Type::EndGeneration: return "endgeneration";
-            case Type::Break: return "break";
-            case Type::Continue: return "continue";
-        }
-        return "Unknown";
-    }
-
-    TemplateToken(Type type, const Location & location, SpaceHandling pre, SpaceHandling post) : type(type), location(location), pre_space(pre), post_space(post) {}
-    virtual ~TemplateToken() = default;
-
-    Type type;
-    Location location;
-    SpaceHandling pre_space = SpaceHandling::Keep;
-    SpaceHandling post_space = SpaceHandling::Keep;
-};
-
-struct TextTemplateToken : public TemplateToken {
-    std::string text;
-    TextTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Text, location, pre, post), text(t) {}
-};
-
-struct ExpressionTemplateToken : public TemplateToken {
-    std::shared_ptr<Expression> expr;
-    ExpressionTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && e) : TemplateToken(Type::Expression, location, pre, post), expr(std::move(e)) {}
-};
-
-struct IfTemplateToken : public TemplateToken {
-    std::shared_ptr<Expression> condition;
-    IfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && c) : TemplateToken(Type::If, location, pre, post), condition(std::move(c)) {}
-};
-
-struct ElifTemplateToken : public TemplateToken {
-    std::shared_ptr<Expression> condition;
-    ElifTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && c) : TemplateToken(Type::Elif, location, pre, post), condition(std::move(c)) {}
-};
-
-struct ElseTemplateToken : public TemplateToken {
-    ElseTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::Else, location, pre, post) {}
-};
-
-struct EndIfTemplateToken : public TemplateToken {
-    EndIfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndIf, location, pre, post) {}
-};
-
-struct MacroTemplateToken : public TemplateToken {
-    std::shared_ptr<VariableExpr> name;
-    Expression::Parameters params;
-    MacroTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<VariableExpr> && n, Expression::Parameters && p)
-      : TemplateToken(Type::Macro, location, pre, post), name(std::move(n)), params(std::move(p)) {}
-};
-
-struct EndMacroTemplateToken : public TemplateToken {
-    EndMacroTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndMacro, location, pre, post) {}
-};
-
-struct FilterTemplateToken : public TemplateToken {
-    std::shared_ptr<Expression> filter;
-    FilterTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && filter)
-      : TemplateToken(Type::Filter, location, pre, post), filter(std::move(filter)) {}
-};
-
-struct EndFilterTemplateToken : public TemplateToken {
-    EndFilterTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndFilter, location, pre, post) {}
-};
-
-struct ForTemplateToken : public TemplateToken {
-    std::vector<std::string> var_names;
-    std::shared_ptr<Expression> iterable;
-    std::shared_ptr<Expression> condition;
-    bool recursive;
-    ForTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::vector<std::string> & vns, std::shared_ptr<Expression> && iter,
-      std::shared_ptr<Expression> && c, bool r)
-      : TemplateToken(Type::For, location, pre, post), var_names(vns), iterable(std::move(iter)), condition(std::move(c)), recursive(r) {}
-};
-
-struct EndForTemplateToken : public TemplateToken {
-    EndForTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndFor, location, pre, post) {}
-};
-
-struct GenerationTemplateToken : public TemplateToken {
-    GenerationTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::Generation, location, pre, post) {}
-};
-
-struct EndGenerationTemplateToken : public TemplateToken {
-    EndGenerationTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndGeneration, location, pre, post) {}
-};
-
-struct SetTemplateToken : public TemplateToken {
-    std::string ns;
-    std::vector<std::string> var_names;
-    std::shared_ptr<Expression> value;
-    SetTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string & ns, const std::vector<std::string> & vns, std::shared_ptr<Expression> && v)
-      : TemplateToken(Type::Set, location, pre, post), ns(ns), var_names(vns), value(std::move(v)) {}
-};
-
-struct EndSetTemplateToken : public TemplateToken {
-    EndSetTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndSet, location, pre, post) {}
-};
-
-struct CommentTemplateToken : public TemplateToken {
-    std::string text;
-    CommentTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Comment, location, pre, post), text(t) {}
-};
-
-enum class LoopControlType { Break, Continue };
-
-class LoopControlException : public std::runtime_error {
-public:
-    LoopControlType control_type;
-    LoopControlException(const std::string & message, LoopControlType control_type) : std::runtime_error(message), control_type(control_type) {}
-    LoopControlException(LoopControlType control_type)
-      : std::runtime_error((control_type == LoopControlType::Continue ? "continue" : "break") + std::string(" outside of a loop")),
-        control_type(control_type) {}
-};
-
-struct LoopControlTemplateToken : public TemplateToken {
-    LoopControlType control_type;
-    LoopControlTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, LoopControlType control_type) : TemplateToken(Type::Break, location, pre, post), control_type(control_type) {}
-};
-
-class TemplateNode {
-    Location location_;
-protected:
-    virtual void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const = 0;
-
-public:
-    TemplateNode(const Location & location) : location_(location) {}
-    void render(std::ostringstream & out, const std::shared_ptr<Context> & context) const {
-        try {
-            do_render(out, context);
-        } catch (const LoopControlException & e) {
-            // TODO: make stack creation lazy. Only needed if it was thrown outside of a loop.
-            std::ostringstream err;
-            err << e.what();
-            if (location_.source) err << error_location_suffix(*location_.source, location_.pos);
-            throw LoopControlException(err.str(), e.control_type);
-        } catch (const std::exception & e) {
-            std::ostringstream err;
-            err << e.what();
-            if (location_.source) err << error_location_suffix(*location_.source, location_.pos);
-            throw std::runtime_error(err.str());
-        }
-    }
-    const Location & location() const { return location_; }
-    virtual ~TemplateNode() = default;
-    std::string render(const std::shared_ptr<Context> & context) const {
-        std::ostringstream out;
-        render(out, context);
-        return out.str();
-    }
-};
-
-class SequenceNode : public TemplateNode {
-    std::vector<std::shared_ptr<TemplateNode>> children;
-public:
-    SequenceNode(const Location & location, std::vector<std::shared_ptr<TemplateNode>> && c)
-      : TemplateNode(location), children(std::move(c)) {}
-    void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
-        for (const auto& child : children) child->render(out, context);
-    }
-};
-
-class TextNode : public TemplateNode {
-    std::string text;
-public:
-    TextNode(const Location & location, const std::string& t) : TemplateNode(location), text(t) {}
-    void do_render(std::ostringstream & out, const std::shared_ptr<Context> &) const override {
-      out << text;
-    }
-};
-
-class ExpressionNode : public TemplateNode {
-    std::shared_ptr<Expression> expr;
-public:
-    ExpressionNode(const Location & location, std::shared_ptr<Expression> && e) : TemplateNode(location), expr(std::move(e)) {}
-    void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
-      if (!expr) throw std::runtime_error("ExpressionNode.expr is null");
-      auto result = expr->evaluate(context);
-      if (result.is_string()) {
-          out << result.get<std::string>();
-      } else if (result.is_boolean()) {
-          out << (result.get<bool>() ? "True" : "False");
-      } else if (!result.is_null()) {
-          out << result.dump();
-      }
-  }
-};
-
-class IfNode : public TemplateNode {
-    std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<TemplateNode>>> cascade;
-public:
-    IfNode(const Location & location, std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<TemplateNode>>> && c)
-        : TemplateNode(location), cascade(std::move(c)) {}
-    void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
-      for (const auto& branch : cascade) {
-          auto enter_branch = true;
-          if (branch.first) {
-            enter_branch = branch.first->evaluate(context).to_bool();
-          }
-          if (enter_branch) {
-            if (!branch.second) throw std::runtime_error("IfNode.cascade.second is null");
-              branch.second->render(out, context);
-              return;
-          }
-      }
-    }
-};
-
-class LoopControlNode : public TemplateNode {
-    LoopControlType control_type_;
-  public:
-    LoopControlNode(const Location & location, LoopControlType control_type) : TemplateNode(location), control_type_(control_type) {}
-    void do_render(std::ostringstream &, const std::shared_ptr<Context> &) const override {
-      throw LoopControlException(control_type_);
-    }
-};
-
-class ForNode : public TemplateNode {
-    std::vector<std::string> var_names;
-    std::shared_ptr<Expression> iterable;
-    std::shared_ptr<Expression> condition;
-    std::shared_ptr<TemplateNode> body;
-    bool recursive;
-    std::shared_ptr<TemplateNode> else_body;
-public:
-    ForNode(const Location & location, std::vector<std::string> && var_names, std::shared_ptr<Expression> && iterable,
-      std::shared_ptr<Expression> && condition, std::shared_ptr<TemplateNode> && body, bool recursive, std::shared_ptr<TemplateNode> && else_body)
-            : TemplateNode(location), var_names(var_names), iterable(std::move(iterable)), condition(std::move(condition)), body(std::move(body)), recursive(recursive), else_body(std::move(else_body)) {}
-
-    void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
-      // https://jinja.palletsprojects.com/en/3.0.x/templates/#for
-      if (!iterable) throw std::runtime_error("ForNode.iterable is null");
-      if (!body) throw std::runtime_error("ForNode.body is null");
-
-      auto iterable_value = iterable->evaluate(context);
-      Value::CallableType loop_function;
-
-      std::function<void(Value&)> visit = [&](Value& iter) {
-          auto filtered_items = Value::array();
-          if (!iter.is_null()) {
-            if (!iterable_value.is_iterable()) {
-              throw std::runtime_error("For loop iterable must be iterable: " + iterable_value.dump());
-            }
-            iterable_value.for_each([&](Value & item) {
-                destructuring_assign(var_names, context, item);
-                if (!condition || condition->evaluate(context).to_bool()) {
-                  filtered_items.push_back(item);
-                }
-            });
-          }
-          if (filtered_items.empty()) {
-            if (else_body) {
-              else_body->render(out, context);
-            }
-          } else {
-              auto loop = recursive ? Value::callable(loop_function) : Value::object();
-              loop.set("length", (int64_t) filtered_items.size());
-
-              size_t cycle_index = 0;
-              loop.set("cycle", Value::callable([&](const std::shared_ptr<Context> &, ArgumentsValue & args) {
-                  if (args.args.empty() || !args.kwargs.empty()) {
-                      throw std::runtime_error("cycle() expects at least 1 positional argument and no named arg");
-                  }
-                  auto item = args.args[cycle_index];
-                  cycle_index = (cycle_index + 1) % args.args.size();
-                  return item;
-              }));
-              auto loop_context = Context::make(Value::object(), context);
-              loop_context->set("loop", loop);
-              for (size_t i = 0, n = filtered_items.size(); i < n; ++i) {
-                  auto & item = filtered_items.at(i);
-                  destructuring_assign(var_names, loop_context, item);
-                  loop.set("index", (int64_t) i + 1);
-                  loop.set("index0", (int64_t) i);
-                  loop.set("revindex", (int64_t) (n - i));
-                  loop.set("revindex0", (int64_t) (n - i - 1));
-                  loop.set("length", (int64_t) n);
-                  loop.set("first", i == 0);
-                  loop.set("last", i == (n - 1));
-                  loop.set("previtem", i > 0 ? filtered_items.at(i - 1) : Value());
-                  loop.set("nextitem", i < n - 1 ? filtered_items.at(i + 1) : Value());
-                  try {
-                      body->render(out, loop_context);
-                  } catch (const LoopControlException & e) {
-                      if (e.control_type == LoopControlType::Break) break;
-                      if (e.control_type == LoopControlType::Continue) continue;
-                  }
-              }
-          }
-      };
-
-      if (recursive) {
-        loop_function = [&](const std::shared_ptr<Context> &, ArgumentsValue & args) {
-            if (args.args.size() != 1 || !args.kwargs.empty() || !args.args[0].is_array()) {
-                throw std::runtime_error("loop() expects exactly 1 positional iterable argument");
-            }
-            auto & items = args.args[0];
-            visit(items);
-            return Value();
-        };
-      }
-
-      visit(iterable_value);
-  }
-};
-
-class MacroNode : public TemplateNode {
-    std::shared_ptr<VariableExpr> name;
-    Expression::Parameters params;
-    std::shared_ptr<TemplateNode> body;
-    std::unordered_map<std::string, size_t> named_param_positions;
-public:
-    MacroNode(const Location & location, std::shared_ptr<VariableExpr> && n, Expression::Parameters && p, std::shared_ptr<TemplateNode> && b)
-        : TemplateNode(location), name(std::move(n)), params(std::move(p)), body(std::move(b)) {
-        for (size_t i = 0; i < params.size(); ++i) {
-          const auto & name = params[i].first;
-          if (!name.empty()) {
-            named_param_positions[name] = i;
-          }
-        }
-    }
-    void do_render(std::ostringstream &, const std::shared_ptr<Context> & macro_context) const override {
-        if (!name) throw std::runtime_error("MacroNode.name is null");
-        if (!body) throw std::runtime_error("MacroNode.body is null");
-        auto callable = Value::callable([&](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
-            auto call_context = macro_context;
-            std::vector<bool> param_set(params.size(), false);
-            for (size_t i = 0, n = args.args.size(); i < n; i++) {
-                auto & arg = args.args[i];
-                if (i >= params.size()) throw std::runtime_error("Too many positional arguments for macro " + name->get_name());
-                param_set[i] = true;
-                auto & param_name = params[i].first;
-                call_context->set(param_name, arg);
-            }
-            for (auto & [arg_name, value] : args.kwargs) {
-                auto it = named_param_positions.find(arg_name);
-                if (it == named_param_positions.end()) throw std::runtime_error("Unknown parameter name for macro " + name->get_name() + ": " + arg_name);
-
-                call_context->set(arg_name, value);
-                param_set[it->second] = true;
-            }
-            // Set default values for parameters that were not passed
-            for (size_t i = 0, n = params.size(); i < n; i++) {
-                if (!param_set[i] && params[i].second != nullptr) {
-                    auto val = params[i].second->evaluate(context);
-                    call_context->set(params[i].first, val);
-                }
-            }
-            return body->render(call_context);
-        });
-        macro_context->set(name->get_name(), callable);
-    }
-};
-
-class FilterNode : public TemplateNode {
-    std::shared_ptr<Expression> filter;
-    std::shared_ptr<TemplateNode> body;
-
-public:
-    FilterNode(const Location & location, std::shared_ptr<Expression> && f, std::shared_ptr<TemplateNode> && b)
-        : TemplateNode(location), filter(std::move(f)), body(std::move(b)) {}
-
-    void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
-        if (!filter) throw std::runtime_error("FilterNode.filter is null");
-        if (!body) throw std::runtime_error("FilterNode.body is null");
-        auto filter_value = filter->evaluate(context);
-        if (!filter_value.is_callable()) {
-            throw std::runtime_error("Filter must be a callable: " + filter_value.dump());
-        }
-        std::string rendered_body = body->render(context);
-
-        ArgumentsValue filter_args = {{Value(rendered_body)}, {}};
-        auto result = filter_value.call(context, filter_args);
-        out << result.to_str();
-    }
-};
-
-class SetNode : public TemplateNode {
-    std::string ns;
-    std::vector<std::string> var_names;
-    std::shared_ptr<Expression> value;
-public:
-    SetNode(const Location & location, const std::string & ns, const std::vector<std::string> & vns, std::shared_ptr<Expression> && v)
-        : TemplateNode(location), ns(ns), var_names(vns), value(std::move(v)) {}
-    void do_render(std::ostringstream &, const std::shared_ptr<Context> & context) const override {
-      if (!value) throw std::runtime_error("SetNode.value is null");
-      if (!ns.empty()) {
-        if (var_names.size() != 1) {
-          throw std::runtime_error("Namespaced set only supports a single variable name");
-        }
-        auto & name = var_names[0];
-        auto ns_value = context->get(ns);
-        if (!ns_value.is_object()) throw std::runtime_error("Namespace '" + ns + "' is not an object");
-        ns_value.set(name, this->value->evaluate(context));
-      } else {
-        auto val = value->evaluate(context);
-        destructuring_assign(var_names, context, val);
-      }
-    }
-};
-
-class SetTemplateNode : public TemplateNode {
-    std::string name;
-    std::shared_ptr<TemplateNode> template_value;
-public:
-    SetTemplateNode(const Location & location, const std::string & name, std::shared_ptr<TemplateNode> && tv)
-        : TemplateNode(location), name(name), template_value(std::move(tv)) {}
-    void do_render(std::ostringstream &, const std::shared_ptr<Context> & context) const override {
-      if (!template_value) throw std::runtime_error("SetTemplateNode.template_value is null");
-      Value value { template_value->render(context) };
-      context->set(name, value);
-    }
-};
-
-class IfExpr : public Expression {
-    std::shared_ptr<Expression> condition;
-    std::shared_ptr<Expression> then_expr;
-    std::shared_ptr<Expression> else_expr;
-public:
-    IfExpr(const Location & location, std::shared_ptr<Expression> && c, std::shared_ptr<Expression> && t, std::shared_ptr<Expression> && e)
-        : Expression(location), condition(std::move(c)), then_expr(std::move(t)), else_expr(std::move(e)) {}
-    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
-      if (!condition) throw std::runtime_error("IfExpr.condition is null");
-      if (!then_expr) throw std::runtime_error("IfExpr.then_expr is null");
-      if (condition->evaluate(context).to_bool()) {
-        return then_expr->evaluate(context);
-      }
-      if (else_expr) {
-        return else_expr->evaluate(context);
-      }
-      return nullptr;
-    }
-};
-
-class LiteralExpr : public Expression {
-    Value value;
-public:
-    LiteralExpr(const Location & location, const Value& v)
-      : Expression(location), value(v) {}
-    Value do_evaluate(const std::shared_ptr<Context> &) const override { return value; }
-};
-
-class ArrayExpr : public Expression {
-    std::vector<std::shared_ptr<Expression>> elements;
-public:
-    ArrayExpr(const Location & location, std::vector<std::shared_ptr<Expression>> && e)
-      : Expression(location), elements(std::move(e)) {}
-    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
-        auto result = Value::array();
-        for (const auto& e : elements) {
-            if (!e) throw std::runtime_error("Array element is null");
-            result.push_back(e->evaluate(context));
-        }
-        return result;
-    }
-};
-
-class DictExpr : public Expression {
-    std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>>> elements;
-public:
-    DictExpr(const Location & location, std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>>> && e)
-      : Expression(location), elements(std::move(e)) {}
-    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
-        auto result = Value::object();
-        for (const auto& [key, value] : elements) {
-            if (!key) throw std::runtime_error("Dict key is null");
-            if (!value) throw std::runtime_error("Dict value is null");
-            result.set(key->evaluate(context), value->evaluate(context));
-        }
-        return result;
-    }
-};
-
-class SliceExpr : public Expression {
-public:
-    std::shared_ptr<Expression> start, end;
-    SliceExpr(const Location & location, std::shared_ptr<Expression> && s, std::shared_ptr<Expression> && e)
-      : Expression(location), start(std::move(s)), end(std::move(e)) {}
-    Value do_evaluate(const std::shared_ptr<Context> &) const override {
-        throw std::runtime_error("SliceExpr not implemented");
-    }
-};
-
-class SubscriptExpr : public Expression {
-    std::shared_ptr<Expression> base;
-    std::shared_ptr<Expression> index;
-public:
-    SubscriptExpr(const Location & location, std::shared_ptr<Expression> && b, std::shared_ptr<Expression> && i)
-        : Expression(location), base(std::move(b)), index(std::move(i)) {}
-    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
-        if (!base) throw std::runtime_error("SubscriptExpr.base is null");
-        if (!index) throw std::runtime_error("SubscriptExpr.index is null");
-        auto target_value = base->evaluate(context);
-        if (auto slice = dynamic_cast<SliceExpr*>(index.get())) {
-          auto start = slice->start ? slice->start->evaluate(context).get<int64_t>() : 0;
-          auto end = slice->end ? slice->end->evaluate(context).get<int64_t>() : (int64_t) target_value.size();
-          if (target_value.is_string()) {
-            std::string s = target_value.get<std::string>();
-            if (start < 0) start = s.size() + start;
-            if (end < 0) end = s.size() + end;
-            return s.substr(start, end - start);
-          } else if (target_value.is_array()) {
-            if (start < 0) start = target_value.size() + start;
-            if (end < 0) end = target_value.size() + end;
-            auto result = Value::array();
-            for (auto i = start; i < end; ++i) {
-              result.push_back(target_value.at(i));
-            }
-            return result;
-          } else {
-            throw std::runtime_error(target_value.is_null() ? "Cannot subscript null" : "Subscripting only supported on arrays and strings");
-          }
-        } else {
-          auto index_value = index->evaluate(context);
-          if (target_value.is_null()) {
-            if (auto t = dynamic_cast<VariableExpr*>(base.get())) {
-              throw std::runtime_error("'" + t->get_name() + "' is " + (context->contains(t->get_name()) ? "null" : "not defined"));
-            }
-            throw std::runtime_error("Trying to access property '" +  index_value.dump() + "' on null!");
-          }
-          return target_value.get(index_value);
-        }
-    }
-};
-
-class UnaryOpExpr : public Expression {
-public:
-    enum class Op { Plus, Minus, LogicalNot, Expansion, ExpansionDict };
-    std::shared_ptr<Expression> expr;
-    Op op;
-    UnaryOpExpr(const Location & location, std::shared_ptr<Expression> && e, Op o)
-      : Expression(location), expr(std::move(e)), op(o) {}
-    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
-        if (!expr) throw std::runtime_error("UnaryOpExpr.expr is null");
-        auto e = expr->evaluate(context);
-        switch (op) {
-            case Op::Plus: return e;
-            case Op::Minus: return -e;
-            case Op::LogicalNot: return !e.to_bool();
-            case Op::Expansion:
-            case Op::ExpansionDict:
-                throw std::runtime_error("Expansion operator is only supported in function calls and collections");
-
-        }
-        throw std::runtime_error("Unknown unary operator");
-    }
-};
-
-class BinaryOpExpr : public Expression {
-public:
-    enum class Op { StrConcat, Add, Sub, Mul, MulMul, Div, DivDiv, Mod, Eq, Ne, Lt, Gt, Le, Ge, And, Or, In, NotIn, Is, IsNot };
-private:
-    std::shared_ptr<Expression> left;
-    std::shared_ptr<Expression> right;
-    Op op;
-public:
-    BinaryOpExpr(const Location & location, std::shared_ptr<Expression> && l, std::shared_ptr<Expression> && r, Op o)
-        : Expression(location), left(std::move(l)), right(std::move(r)), op(o) {}
-    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
-        if (!left) throw std::runtime_error("BinaryOpExpr.left is null");
-        if (!right) throw std::runtime_error("BinaryOpExpr.right is null");
-        auto l = left->evaluate(context);
-
-        auto do_eval = [&](const Value & l) -> Value {
-          if (op == Op::Is || op == Op::IsNot) {
-            auto t = dynamic_cast<VariableExpr*>(right.get());
-            if (!t) throw std::runtime_error("Right side of 'is' operator must be a variable");
-
-            auto eval = [&]() {
-              const auto & name = t->get_name();
-              if (name == "none") return l.is_null();
-              if (name == "boolean") return l.is_boolean();
-              if (name == "integer") return l.is_number_integer();
-              if (name == "float") return l.is_number_float();
-              if (name == "number") return l.is_number();
-              if (name == "string") return l.is_string();
-              if (name == "mapping") return l.is_object();
-              if (name == "iterable") return l.is_iterable();
-              if (name == "sequence") return l.is_array();
-              if (name == "defined") return !l.is_null();
-              throw std::runtime_error("Unknown type for 'is' operator: " + name);
-            };
-            auto value = eval();
-            return Value(op == Op::Is ? value : !value);
-          }
-
-          if (op == Op::And) {
-            if (!l.to_bool()) return Value(false);
-            return right->evaluate(context).to_bool();
-          } else if (op == Op::Or) {
-            if (l.to_bool()) return l;
-            return right->evaluate(context);
-          }
-
-          auto r = right->evaluate(context);
-          switch (op) {
-              case Op::StrConcat: return l.to_str() + r.to_str();
-              case Op::Add:       return l + r;
-              case Op::Sub:       return l - r;
-              case Op::Mul:       return l * r;
-              case Op::Div:       return l / r;
-              case Op::MulMul:    return std::pow(l.get<double>(), r.get<double>());
-              case Op::DivDiv:    return l.get<int64_t>() / r.get<int64_t>();
-              case Op::Mod:       return l.get<int64_t>() % r.get<int64_t>();
-              case Op::Eq:        return l == r;
-              case Op::Ne:        return l != r;
-              case Op::Lt:        return l < r;
-              case Op::Gt:        return l > r;
-              case Op::Le:        return l <= r;
-              case Op::Ge:        return l >= r;
-              case Op::In:        return (r.is_array() || r.is_object()) && r.contains(l);
-              case Op::NotIn:     return !(r.is_array() && r.contains(l));
-              default:            break;
-          }
-          throw std::runtime_error("Unknown binary operator");
-        };
-
-        if (l.is_callable()) {
-          return Value::callable([l, do_eval](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
-            auto ll = l.call(context, args);
-            return do_eval(ll); //args[0].second);
-          });
-        } else {
-          return do_eval(l);
-        }
-    }
-};
-
-struct ArgumentsExpression {
-    std::vector<std::shared_ptr<Expression>> args;
-    std::vector<std::pair<std::string, std::shared_ptr<Expression>>> kwargs;
-
-    ArgumentsValue evaluate(const std::shared_ptr<Context> & context) const {
-        ArgumentsValue vargs;
-        for (const auto& arg : this->args) {
-            if (auto un_expr = std::dynamic_pointer_cast<UnaryOpExpr>(arg)) {
-                if (un_expr->op == UnaryOpExpr::Op::Expansion) {
-                    auto array = un_expr->expr->evaluate(context);
-                    if (!array.is_array()) {
-                        throw std::runtime_error("Expansion operator only supported on arrays");
-                    }
-                    array.for_each([&](Value & value) {
-                        vargs.args.push_back(value);
-                    });
-                    continue;
-                } else if (un_expr->op == UnaryOpExpr::Op::ExpansionDict) {
-                    auto dict = un_expr->expr->evaluate(context);
-                    if (!dict.is_object()) {
-                        throw std::runtime_error("ExpansionDict operator only supported on objects");
-                    }
-                    dict.for_each([&](const Value & key) {
-                        vargs.kwargs.push_back({key.get<std::string>(), dict.at(key)});
-                    });
-                    continue;
-                }
-            }
-            vargs.args.push_back(arg->evaluate(context));
-        }
-        for (const auto& [name, value] : this->kwargs) {
-            vargs.kwargs.push_back({name, value->evaluate(context)});
-        }
-        return vargs;
-    }
-};
-
-static std::string strip(const std::string & s) {
-  auto start = s.find_first_not_of(" \t\n\r");
-  if (start == std::string::npos) return "";
-  auto end = s.find_last_not_of(" \t\n\r");
-  return s.substr(start, end - start + 1);
-}
-
-static std::string capitalize(const std::string & s) {
-  if (s.empty()) return s;
-  auto result = s;
-  result[0] = std::toupper(result[0]);
-  return result;
-}
-
-static std::string html_escape(const std::string & s) {
-  std::string result;
-  result.reserve(s.size());
-  for (const auto & c : s) {
-    switch (c) {
-      case '&': result += "&amp;"; break;
-      case '<': result += "&lt;"; break;
-      case '>': result += "&gt;"; break;
-      case '"': result += "&#34;"; break;
-      case '\'': result += "&apos;"; break;
-      default: result += c; break;
-    }
-  }
-  return result;
-}
-
-class MethodCallExpr : public Expression {
-    std::shared_ptr<Expression> object;
-    std::shared_ptr<VariableExpr> method;
-    ArgumentsExpression args;
-public:
-    MethodCallExpr(const Location & location, std::shared_ptr<Expression> && obj, std::shared_ptr<VariableExpr> && m, ArgumentsExpression && a)
-        : Expression(location), object(std::move(obj)), method(std::move(m)), args(std::move(a)) {}
-    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
-        if (!object) throw std::runtime_error("MethodCallExpr.object is null");
-        if (!method) throw std::runtime_error("MethodCallExpr.method is null");
-        auto obj = object->evaluate(context);
-        auto vargs = args.evaluate(context);
-        if (obj.is_null()) {
-          throw std::runtime_error("Trying to call method '" + method->get_name() + "' on null");
-        }
-        if (obj.is_array()) {
-          if (method->get_name() == "append") {
-              vargs.expectArgs("append method", {1, 1}, {0, 0});
-              obj.push_back(vargs.args[0]);
-              return Value();
-          } else if (method->get_name() == "pop") {
-              vargs.expectArgs("pop method", {0, 1}, {0, 0});
-              return obj.pop(vargs.args.empty() ? Value() : vargs.args[0]);
-          } else if (method->get_name() == "insert") {
-              vargs.expectArgs("insert method", {2, 2}, {0, 0});
-              auto index = vargs.args[0].get<int64_t>();
-              if (index < 0 || index > (int64_t) obj.size()) throw std::runtime_error("Index out of range for insert method");
-              obj.insert(index, vargs.args[1]);
-              return Value();
-          }
-        } else if (obj.is_object()) {
-          if (method->get_name() == "items") {
-            vargs.expectArgs("items method", {0, 0}, {0, 0});
-            auto result = Value::array();
-            for (const auto& key : obj.keys()) {
-              result.push_back(Value::array({key, obj.at(key)}));
-            }
-            return result;
-          } else if (method->get_name() == "pop") {
-            vargs.expectArgs("pop method", {1, 1}, {0, 0});
-            return obj.pop(vargs.args[0]);
-          } else if (method->get_name() == "get") {
-            vargs.expectArgs("get method", {1, 2}, {0, 0});
-            auto key = vargs.args[0];
-            if (vargs.args.size() == 1) {
-              return obj.contains(key) ? obj.at(key) : Value();
-            } else {
-              return obj.contains(key) ? obj.at(key) : vargs.args[1];
-            }
-          } else if (obj.contains(method->get_name())) {
-            auto callable = obj.at(method->get_name());
-            if (!callable.is_callable()) {
-              throw std::runtime_error("Property '" + method->get_name() + "' is not callable");
-            }
-            return callable.call(context, vargs);
-          }
-        } else if (obj.is_string()) {
-          auto str = obj.get<std::string>();
-          if (method->get_name() == "strip") {
-            vargs.expectArgs("strip method", {0, 0}, {0, 0});
-            return Value(strip(str));
-          } else if (method->get_name() == "capitalize") {
-            vargs.expectArgs("capitalize method", {0, 0}, {0, 0});
-            return Value(capitalize(str));
-          } else if (method->get_name() == "endswith") {
-            vargs.expectArgs("endswith method", {1, 1}, {0, 0});
-            auto suffix = vargs.args[0].get<std::string>();
-            return suffix.length() <= str.length() && std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());
-          } else if (method->get_name() == "title") {
-            vargs.expectArgs("title method", {0, 0}, {0, 0});
-            auto res = str;
-            for (size_t i = 0, n = res.size(); i < n; ++i) {
-              if (i == 0 || std::isspace(res[i - 1])) res[i] = std::toupper(res[i]);
-              else res[i] = std::tolower(res[i]);
-            }
-            return res;
-          }
-        }
-        throw std::runtime_error("Unknown method: " + method->get_name());
-    }
-};
-
-class CallExpr : public Expression {
-public:
-    std::shared_ptr<Expression> object;
-    ArgumentsExpression args;
-    CallExpr(const Location & location, std::shared_ptr<Expression> && obj, ArgumentsExpression && a)
-        : Expression(location), object(std::move(obj)), args(std::move(a)) {}
-    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
-        if (!object) throw std::runtime_error("CallExpr.object is null");
-        auto obj = object->evaluate(context);
-        if (!obj.is_callable()) {
-          throw std::runtime_error("Object is not callable: " + obj.dump(2));
-        }
-        auto vargs = args.evaluate(context);
-        return obj.call(context, vargs);
-    }
-};
-
-class FilterExpr : public Expression {
-    std::vector<std::shared_ptr<Expression>> parts;
-public:
-    FilterExpr(const Location & location, std::vector<std::shared_ptr<Expression>> && p)
-      : Expression(location), parts(std::move(p)) {}
-    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
-        Value result;
-        bool first = true;
-        for (const auto& part : parts) {
-          if (!part) throw std::runtime_error("FilterExpr.part is null");
-          if (first) {
-            first = false;
-            result = part->evaluate(context);
-          } else {
-            if (auto ce = dynamic_cast<CallExpr*>(part.get())) {
-              auto target = ce->object->evaluate(context);
-              ArgumentsValue args = ce->args.evaluate(context);
-              args.args.insert(args.args.begin(), result);
-              result = target.call(context, args);
-            } else {
-              auto callable = part->evaluate(context);
-              ArgumentsValue args;
-              args.args.insert(args.args.begin(), result);
-              result = callable.call(context, args);
-            }
-          }
-        }
-        return result;
-    }
-
-    void prepend(std::shared_ptr<Expression> && e) {
-        parts.insert(parts.begin(), std::move(e));
-    }
-};
-
-class Parser {
-private:
-    using CharIterator = std::string::const_iterator;
-
-    std::shared_ptr<std::string> template_str;
-    CharIterator start, end, it;
-    Options options;
-
-    Parser(const std::shared_ptr<std::string>& template_str, const Options & options) : template_str(template_str), options(options) {
-      if (!template_str) throw std::runtime_error("Template string is null");
-      start = it = this->template_str->begin();
-      end = this->template_str->end();
-    }
-
-    bool consumeSpaces(SpaceHandling space_handling = SpaceHandling::Strip) {
-      if (space_handling == SpaceHandling::Strip) {
-        while (it != end && std::isspace(*it)) ++it;
-      }
-      return true;
-    }
-
-    std::unique_ptr<std::string> parseString() {
-      auto doParse = [&](char quote) -> std::unique_ptr<std::string> {
-        if (it == end || *it != quote) return nullptr;
-        std::string result;
-        bool escape = false;
-        for (++it; it != end; ++it) {
-          if (escape) {
-            escape = false;
-            switch (*it) {
-              case 'n': result += '\n'; break;
-              case 'r': result += '\r'; break;
-              case 't': result += '\t'; break;
-              case 'b': result += '\b'; break;
-              case 'f': result += '\f'; break;
-              case '\\': result += '\\'; break;
-              default:
-                if (*it == quote) {
-                  result += quote;
-                } else {
-                  result += *it;
-                }
-                break;
-            }
-          } else if (*it == '\\') {
-            escape = true;
-          } else if (*it == quote) {
-              ++it;
-            return std::make_unique<std::string>(std::move(result));
-          } else {
-            result += *it;
-          }
-        }
-        return nullptr;
-      };
-
-      consumeSpaces();
-      if (it == end) return nullptr;
-      if (*it == '"') return doParse('"');
-      if (*it == '\'') return doParse('\'');
-      return nullptr;
-    }
-
-    json parseNumber(CharIterator& it, const CharIterator& end) {
-        auto before = it;
-        consumeSpaces();
-        auto start = it;
-        bool hasDecimal = false;
-        bool hasExponent = false;
-
-        if (it != end && (*it == '-' || *it == '+')) ++it;
-
-        while (it != end) {
-          if (std::isdigit(*it)) {
-            ++it;
-          } else if (*it == '.') {
-            if (hasDecimal) throw std::runtime_error("Multiple decimal points");
-            hasDecimal = true;
-            ++it;
-          } else if (it != start && (*it == 'e' || *it == 'E')) {
-            if (hasExponent) throw std::runtime_error("Multiple exponents");
-            hasExponent = true;
-            ++it;
-          } else {
-            break;
-          }
-        }
-        if (start == it) {
-          it = before;
-          return json(); // No valid characters found
-        }
-
-        std::string str(start, it);
-        try {
-          return json::parse(str);
-        } catch (json::parse_error& e) {
-          throw std::runtime_error("Failed to parse number: '" + str + "' (" + std::string(e.what()) + ")");
-          return json();
-        }
-    }
-
-    /** integer, float, bool, string */
-    std::shared_ptr<Value> parseConstant() {
-      auto start = it;
-      consumeSpaces();
-      if (it == end) return nullptr;
-      if (*it == '"' || *it == '\'') {
-        auto str = parseString();
-        if (str) return std::make_shared<Value>(*str);
-      }
-      static std::regex prim_tok(R"(true\b|True\b|false\b|False\b|None\b)");
-      auto token = consumeToken(prim_tok);
-      if (!token.empty()) {
-        if (token == "true" || token == "True") return std::make_shared<Value>(true);
-        if (token == "false" || token == "False") return std::make_shared<Value>(false);
-        if (token == "None") return std::make_shared<Value>(nullptr);
-        throw std::runtime_error("Unknown constant token: " + token);
-      }
-
-      auto number = parseNumber(it, end);
-      if (!number.is_null()) return std::make_shared<Value>(number);
-
-      it = start;
-      return nullptr;
-    }
-
-    class expression_parsing_error : public std::runtime_error {
-        const CharIterator it;
-      public:
-        expression_parsing_error(const std::string & message, const CharIterator & it)
-            : std::runtime_error(message), it(it) {}
-        size_t get_pos(const CharIterator & begin) const {
-            return std::distance(begin, it);
-      }
-    };
-
-    bool peekSymbols(const std::vector<std::string> & symbols) const {
-        for (const auto & symbol : symbols) {
-            if (std::distance(it, end) >= (int64_t) symbol.size() && std::string(it, it + symbol.size()) == symbol) {
-                return true;
-            }
-        }
-        return false;
-    }
-
-    std::vector<std::string> consumeTokenGroups(const std::regex & regex, SpaceHandling space_handling = SpaceHandling::Strip) {
-        auto start = it;
-        consumeSpaces(space_handling);
-        std::smatch match;
-        if (std::regex_search(it, end, match, regex) && match.position() == 0) {
-            it += match[0].length();
-            std::vector<std::string> ret;
-            for (size_t i = 0, n = match.size(); i < n; ++i) {
-                ret.push_back(match[i].str());
-            }
-            return ret;
-        }
-        it = start;
-        return {};
-    }
-    std::string consumeToken(const std::regex & regex, SpaceHandling space_handling = SpaceHandling::Strip) {
-        auto start = it;
-        consumeSpaces(space_handling);
-        std::smatch match;
-        if (std::regex_search(it, end, match, regex) && match.position() == 0) {
-            it += match[0].length();
-            return match[0].str();
-        }
-        it = start;
-        return "";
-    }
-
-    std::string consumeToken(const std::string & token, SpaceHandling space_handling = SpaceHandling::Strip) {
-        auto start = it;
-        consumeSpaces(space_handling);
-        if (std::distance(it, end) >= (int64_t) token.size() && std::string(it, it + token.size()) == token) {
-            it += token.size();
-            return token;
-        }
-        it = start;
-        return "";
-    }
-
-    std::shared_ptr<Expression> parseExpression(bool allow_if_expr = true) {
-        auto left = parseLogicalOr();
-        if (it == end) return left;
-
-        if (!allow_if_expr) return left;
-
-        static std::regex if_tok(R"(if\b)");
-        if (consumeToken(if_tok).empty()) {
-          return left;
-        }
-
-        auto location = get_location();
-        auto [condition, else_expr] = parseIfExpression();
-        return std::make_shared<IfExpr>(location, std::move(condition), std::move(left), std::move(else_expr));
-    }
-
-    Location get_location() const {
-        return {template_str, (size_t) std::distance(start, it)};
-    }
-
-    std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>> parseIfExpression() {
-        auto condition = parseLogicalOr();
-        if (!condition) throw std::runtime_error("Expected condition expression");
-
-        static std::regex else_tok(R"(else\b)");
-        std::shared_ptr<Expression> else_expr;
-        if (!consumeToken(else_tok).empty()) {
-          else_expr = parseExpression();
-          if (!else_expr) throw std::runtime_error("Expected 'else' expression");
-        }
-        return std::pair(std::move(condition), std::move(else_expr));
-    }
-
-    std::shared_ptr<Expression> parseLogicalOr() {
-        auto left = parseLogicalAnd();
-        if (!left) throw std::runtime_error("Expected left side of 'logical or' expression");
-
-        static std::regex or_tok(R"(or\b)");
-        auto location = get_location();
-        while (!consumeToken(or_tok).empty()) {
-            auto right = parseLogicalAnd();
-            if (!right) throw std::runtime_error("Expected right side of 'or' expression");
-            left = std::make_shared<BinaryOpExpr>(location, std::move(left), std::move(right), BinaryOpExpr::Op::Or);
-        }
-        return left;
-    }
-
-    std::shared_ptr<Expression> parseLogicalNot() {
-        static std::regex not_tok(R"(not\b)");
-        auto location = get_location();
-
-        if (!consumeToken(not_tok).empty()) {
-          auto sub = parseLogicalNot();
-          if (!sub) throw std::runtime_error("Expected expression after 'not' keyword");
-          return std::make_shared<UnaryOpExpr>(location, std::move(sub), UnaryOpExpr::Op::LogicalNot);
-        }
-        return parseLogicalCompare();
-    }
-
-    std::shared_ptr<Expression> parseLogicalAnd() {
-        auto left = parseLogicalNot();
-        if (!left) throw std::runtime_error("Expected left side of 'logical and' expression");
-
-        static std::regex and_tok(R"(and\b)");
-        auto location = get_location();
-        while (!consumeToken(and_tok).empty()) {
-            auto right = parseLogicalNot();
-            if (!right) throw std::runtime_error("Expected right side of 'and' expression");
-            left = std::make_shared<BinaryOpExpr>(location, std::move(left), std::move(right), BinaryOpExpr::Op::And);
-        }
-        return left;
-    }
-
-    std::shared_ptr<Expression> parseLogicalCompare() {
-        auto left = parseStringConcat();
-        if (!left) throw std::runtime_error("Expected left side of 'logical compare' expression");
-
-        static std::regex compare_tok(R"(==|!=|<=?|>=?|in\b|is\b|not\s+in\b)");
-        static std::regex not_tok(R"(not\b)");
-        std::string op_str;
-        while (!(op_str = consumeToken(compare_tok)).empty()) {
-            auto location = get_location();
-            if (op_str == "is") {
-              auto negated = !consumeToken(not_tok).empty();
-
-              auto identifier = parseIdentifier();
-              if (!identifier) throw std::runtime_error("Expected identifier after 'is' keyword");
-
-              return std::make_shared<BinaryOpExpr>(
-                  left->location,
-                  std::move(left), std::move(identifier),
-                  negated ? BinaryOpExpr::Op::IsNot : BinaryOpExpr::Op::Is);
-            }
-            auto right = parseStringConcat();
-            if (!right) throw std::runtime_error("Expected right side of 'logical compare' expression");
-            BinaryOpExpr::Op op;
-            if (op_str == "==") op = BinaryOpExpr::Op::Eq;
-            else if (op_str == "!=") op = BinaryOpExpr::Op::Ne;
-            else if (op_str == "<") op = BinaryOpExpr::Op::Lt;
-            else if (op_str == ">") op = BinaryOpExpr::Op::Gt;
-            else if (op_str == "<=") op = BinaryOpExpr::Op::Le;
-            else if (op_str == ">=") op = BinaryOpExpr::Op::Ge;
-            else if (op_str == "in") op = BinaryOpExpr::Op::In;
-            else if (op_str.substr(0, 3) == "not") op = BinaryOpExpr::Op::NotIn;
-            else throw std::runtime_error("Unknown comparison operator: " + op_str);
-            left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), op);
-        }
-        return left;
-    }
-
-    Expression::Parameters parseParameters() {
-        consumeSpaces();
-        if (consumeToken("(").empty()) throw std::runtime_error("Expected opening parenthesis in param list");
-
-        Expression::Parameters result;
-
-        while (it != end) {
-            if (!consumeToken(")").empty()) {
-                return result;
-            }
-            auto expr = parseExpression();
-            if (!expr) throw std::runtime_error("Expected expression in call args");
-
-            if (auto ident = dynamic_cast<VariableExpr*>(expr.get())) {
-                if (!consumeToken("=").empty()) {
-                    auto value = parseExpression();
-                    if (!value) throw std::runtime_error("Expected expression in for named arg");
-                    result.emplace_back(ident->get_name(), std::move(value));
-                } else {
-                    result.emplace_back(ident->get_name(), nullptr);
-                }
-            } else {
-                result.emplace_back(std::string(), std::move(expr));
-            }
-            if (consumeToken(",").empty()) {
-              if (consumeToken(")").empty()) {
-                throw std::runtime_error("Expected closing parenthesis in call args");
-              }
-              return result;
-            }
-        }
-        throw std::runtime_error("Expected closing parenthesis in call args");
-    }
-
-    ArgumentsExpression parseCallArgs() {
-        consumeSpaces();
-        if (consumeToken("(").empty()) throw std::runtime_error("Expected opening parenthesis in call args");
-
-        ArgumentsExpression result;
-
-        while (it != end) {
-            if (!consumeToken(")").empty()) {
-                return result;
-            }
-            auto expr = parseExpression();
-            if (!expr) throw std::runtime_error("Expected expression in call args");
-
-            if (auto ident = dynamic_cast<VariableExpr*>(expr.get())) {
-                if (!consumeToken("=").empty()) {
-                    auto value = parseExpression();
-                    if (!value) throw std::runtime_error("Expected expression in for named arg");
-                    result.kwargs.emplace_back(ident->get_name(), std::move(value));
-                } else {
-                    result.args.emplace_back(std::move(expr));
-                }
-            } else {
-                result.args.emplace_back(std::move(expr));
-            }
-            if (consumeToken(",").empty()) {
-              if (consumeToken(")").empty()) {
-                throw std::runtime_error("Expected closing parenthesis in call args");
-              }
-              return result;
-            }
-        }
-        throw std::runtime_error("Expected closing parenthesis in call args");
-    }
-
-    std::shared_ptr<VariableExpr> parseIdentifier() {
-        static std::regex ident_regex(R"((?!(?:not|is|and|or|del)\b)[a-zA-Z_]\w*)");
-        auto location = get_location();
-        auto ident = consumeToken(ident_regex);
-        if (ident.empty())
-          return nullptr;
-        return std::make_shared<VariableExpr>(location, ident);
-    }
-
-    std::shared_ptr<Expression> parseStringConcat() {
-        auto left = parseMathPow();
-        if (!left) throw std::runtime_error("Expected left side of 'string concat' expression");
-
-        static std::regex concat_tok(R"(~(?!\}))");
-        if (!consumeToken(concat_tok).empty()) {
-            auto right = parseLogicalAnd();
-            if (!right) throw std::runtime_error("Expected right side of 'string concat' expression");
-            left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), BinaryOpExpr::Op::StrConcat);
-        }
-        return left;
-    }
-
-    std::shared_ptr<Expression> parseMathPow() {
-        auto left = parseMathPlusMinus();
-        if (!left) throw std::runtime_error("Expected left side of 'math pow' expression");
-
-        while (!consumeToken("**").empty()) {
-            auto right = parseMathPlusMinus();
-            if (!right) throw std::runtime_error("Expected right side of 'math pow' expression");
-            left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), BinaryOpExpr::Op::MulMul);
-        }
-        return left;
-    }
-
-    std::shared_ptr<Expression> parseMathPlusMinus() {
-        static std::regex plus_minus_tok(R"(\+|-(?![}%#]\}))");
-
-        auto left = parseMathMulDiv();
-        if (!left) throw std::runtime_error("Expected left side of 'math plus/minus' expression");
-        std::string op_str;
-        while (!(op_str = consumeToken(plus_minus_tok)).empty()) {
-            auto right = parseMathMulDiv();
-            if (!right) throw std::runtime_error("Expected right side of 'math plus/minus' expression");
-            auto op = op_str == "+" ? BinaryOpExpr::Op::Add : BinaryOpExpr::Op::Sub;
-            left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), op);
-        }
-        return left;
-    }
-
-    std::shared_ptr<Expression> parseMathMulDiv() {
-        auto left = parseMathUnaryPlusMinus();
-        if (!left) throw std::runtime_error("Expected left side of 'math mul/div' expression");
-
-        static std::regex mul_div_tok(R"(\*\*?|//?|%(?!\}))");
-        std::string op_str;
-        while (!(op_str = consumeToken(mul_div_tok)).empty()) {
-            auto right = parseMathUnaryPlusMinus();
-            if (!right) throw std::runtime_error("Expected right side of 'math mul/div' expression");
-            auto op = op_str == "*" ? BinaryOpExpr::Op::Mul
-                : op_str == "**" ? BinaryOpExpr::Op::MulMul
-                : op_str == "/" ? BinaryOpExpr::Op::Div
-                : op_str == "//" ? BinaryOpExpr::Op::DivDiv
-                : BinaryOpExpr::Op::Mod;
-            left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), op);
-        }
-
-        if (!consumeToken("|").empty()) {
-            auto expr = parseMathMulDiv();
-            if (auto filter = dynamic_cast<FilterExpr*>(expr.get())) {
-                filter->prepend(std::move(left));
-                return expr;
-            } else {
-                std::vector<std::shared_ptr<Expression>> parts;
-                parts.emplace_back(std::move(left));
-                parts.emplace_back(std::move(expr));
-                return std::make_shared<FilterExpr>(get_location(), std::move(parts));
-            }
-        }
-        return left;
-    }
-
-    std::shared_ptr<Expression> call_func(const std::string & name, ArgumentsExpression && args) const {
-        return std::make_shared<CallExpr>(get_location(), std::make_shared<VariableExpr>(get_location(), name), std::move(args));
-    }
-
-    std::shared_ptr<Expression> parseMathUnaryPlusMinus() {
-        static std::regex unary_plus_minus_tok(R"(\+|-(?![}%#]\}))");
-        auto op_str = consumeToken(unary_plus_minus_tok);
-        auto expr = parseExpansion();
-        if (!expr) throw std::runtime_error("Expected expr of 'unary plus/minus/expansion' expression");
-
-        if (!op_str.empty()) {
-            auto op = op_str == "+" ? UnaryOpExpr::Op::Plus : UnaryOpExpr::Op::Minus;
-            return std::make_shared<UnaryOpExpr>(get_location(), std::move(expr), op);
-        }
-        return expr;
-    }
-
-    std::shared_ptr<Expression> parseExpansion() {
-      static std::regex expansion_tok(R"(\*\*?)");
-      auto op_str = consumeToken(expansion_tok);
-      auto expr = parseValueExpression();
-      if (op_str.empty()) return expr;
-      if (!expr) throw std::runtime_error("Expected expr of 'expansion' expression");
-      return std::make_shared<UnaryOpExpr>(get_location(), std::move(expr), op_str == "*" ? UnaryOpExpr::Op::Expansion : UnaryOpExpr::Op::ExpansionDict);
-    }
-
-    std::shared_ptr<Expression> parseValueExpression() {
-      auto parseValue = [&]() -> std::shared_ptr<Expression> {
-        auto location = get_location();
-        auto constant = parseConstant();
-        if (constant) return std::make_shared<LiteralExpr>(location, *constant);
-
-        static std::regex null_regex(R"(null\b)");
-        if (!consumeToken(null_regex).empty()) return std::make_shared<LiteralExpr>(location, Value());
-
-        auto identifier = parseIdentifier();
-        if (identifier) return identifier;
-
-        auto braced = parseBracedExpressionOrArray();
-        if (braced) return braced;
-
-        auto array = parseArray();
-        if (array) return array;
-
-        auto dictionary = parseDictionary();
-        if (dictionary) return dictionary;
-
-        throw std::runtime_error("Expected value expression");
-      };
-
-      auto value = parseValue();
-
-      while (it != end && consumeSpaces() && peekSymbols({ "[", "." })) {
-        if (!consumeToken("[").empty()) {
-            std::shared_ptr<Expression> index;
-            if (!consumeToken(":").empty()) {
-              auto slice_end = parseExpression();
-              index = std::make_shared<SliceExpr>(slice_end->location, nullptr, std::move(slice_end));
-            } else {
-              auto slice_start = parseExpression();
-              if (!consumeToken(":").empty()) {
-                consumeSpaces();
-                if (peekSymbols({ "]" })) {
-                  index = std::make_shared<SliceExpr>(slice_start->location, std::move(slice_start), nullptr);
-                } else {
-                  auto slice_end = parseExpression();
-                  index = std::make_shared<SliceExpr>(slice_start->location, std::move(slice_start), std::move(slice_end));
-                }
-              } else {
-                index = std::move(slice_start);
-              }
-            }
-            if (!index) throw std::runtime_error("Empty index in subscript");
-            if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript");
-
-            value = std::make_shared<SubscriptExpr>(value->location, std::move(value), std::move(index));
-        } else if (!consumeToken(".").empty()) {
-            auto identifier = parseIdentifier();
-            if (!identifier) throw std::runtime_error("Expected identifier in subscript");
-
-            consumeSpaces();
-            if (peekSymbols({ "(" })) {
-              auto callParams = parseCallArgs();
-              value = std::make_shared<MethodCallExpr>(identifier->location, std::move(value), std::move(identifier), std::move(callParams));
-            } else {
-              auto key = std::make_shared<LiteralExpr>(identifier->location, Value(identifier->get_name()));
-              value = std::make_shared<SubscriptExpr>(identifier->location, std::move(value), std::move(key));
-            }
-        }
-        consumeSpaces();
-      }
-
-      if (peekSymbols({ "(" })) {
-        auto location = get_location();
-        auto callParams = parseCallArgs();
-        value = std::make_shared<CallExpr>(location, std::move(value), std::move(callParams));
-      }
-      return value;
-    }
-
-    std::shared_ptr<Expression> parseBracedExpressionOrArray() {
-        if (consumeToken("(").empty()) return nullptr;
-
-        auto expr = parseExpression();
-        if (!expr) throw std::runtime_error("Expected expression in braced expression");
-
-        if (!consumeToken(")").empty()) {
-            return expr;  // Drop the parentheses
-        }
-
-        std::vector<std::shared_ptr<Expression>> tuple;
-        tuple.emplace_back(std::move(expr));
-
-        while (it != end) {
-          if (consumeToken(",").empty()) throw std::runtime_error("Expected comma in tuple");
-          auto next = parseExpression();
-          if (!next) throw std::runtime_error("Expected expression in tuple");
-          tuple.push_back(std::move(next));
-
-          if (!consumeToken(")").empty()) {
-              return std::make_shared<ArrayExpr>(get_location(), std::move(tuple));
-          }
-        }
-        throw std::runtime_error("Expected closing parenthesis");
-    }
-
-    std::shared_ptr<Expression> parseArray() {
-        if (consumeToken("[").empty()) return nullptr;
-
-        std::vector<std::shared_ptr<Expression>> elements;
-        if (!consumeToken("]").empty()) {
-            return std::make_shared<ArrayExpr>(get_location(), std::move(elements));
-        }
-        auto first_expr = parseExpression();
-        if (!first_expr) throw std::runtime_error("Expected first expression in array");
-        elements.push_back(std::move(first_expr));
-
-        while (it != end) {
-            if (!consumeToken(",").empty()) {
-              auto expr = parseExpression();
-              if (!expr) throw std::runtime_error("Expected expression in array");
-              elements.push_back(std::move(expr));
-            } else if (!consumeToken("]").empty()) {
-                return std::make_shared<ArrayExpr>(get_location(), std::move(elements));
-            } else {
-                throw std::runtime_error("Expected comma or closing bracket in array");
-            }
-        }
-        throw std::runtime_error("Expected closing bracket");
-    }
-
-    std::shared_ptr<Expression> parseDictionary() {
-        if (consumeToken("{").empty()) return nullptr;
-
-        std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>>> elements;
-        if (!consumeToken("}").empty()) {
-            return std::make_shared<DictExpr>(get_location(), std::move(elements));
-        }
-
-        auto parseKeyValuePair = [&]() {
-            auto key = parseExpression();
-            if (!key) throw std::runtime_error("Expected key in dictionary");
-            if (consumeToken(":").empty()) throw std::runtime_error("Expected colon betweek key & value in dictionary");
-            auto value = parseExpression();
-            if (!value) throw std::runtime_error("Expected value in dictionary");
-            elements.emplace_back(std::pair(std::move(key), std::move(value)));
-        };
-
-        parseKeyValuePair();
-
-        while (it != end) {
-            if (!consumeToken(",").empty()) {
-                parseKeyValuePair();
-            } else if (!consumeToken("}").empty()) {
-                return std::make_shared<DictExpr>(get_location(), std::move(elements));
-            } else {
-                throw std::runtime_error("Expected comma or closing brace in dictionary");
-            }
-        }
-        throw std::runtime_error("Expected closing brace");
-    }
-
-    SpaceHandling parsePreSpace(const std::string& s) const {
-        if (s == "-")
-          return SpaceHandling::Strip;
-        return SpaceHandling::Keep;
-    }
-
-    SpaceHandling parsePostSpace(const std::string& s) const {
-        if (s == "-") return SpaceHandling::Strip;
-        return SpaceHandling::Keep;
-    }
-
-    using TemplateTokenVector = std::vector<std::unique_ptr<TemplateToken>>;
-    using TemplateTokenIterator = TemplateTokenVector::const_iterator;
-
-    std::vector<std::string> parseVarNames() {
-      static std::regex varnames_regex(R"(((?:\w+)(?:\s*,\s*(?:\w+))*)\s*)");
-
-      std::vector<std::string> group;
-      if ((group = consumeTokenGroups(varnames_regex)).empty()) throw std::runtime_error("Expected variable names");
-      std::vector<std::string> varnames;
-      std::istringstream iss(group[1]);
-      std::string varname;
-      while (std::getline(iss, varname, ',')) {
-        varnames.push_back(strip(varname));
-      }
-      return varnames;
-    }
-
-    std::runtime_error unexpected(const TemplateToken & token) const {
-      return std::runtime_error("Unexpected " + TemplateToken::typeToString(token.type)
-        + error_location_suffix(*template_str, token.location.pos));
-    }
-    std::runtime_error unterminated(const TemplateToken & token) const {
-      return std::runtime_error("Unterminated " + TemplateToken::typeToString(token.type)
-        + error_location_suffix(*template_str, token.location.pos));
-    }
-
-    TemplateTokenVector tokenize() {
-      static std::regex comment_tok(R"(\{#([-~]?)([\s\S]*?)([-~]?)#\})");
-      static std::regex expr_open_regex(R"(\{\{([-~])?)");
-      static std::regex block_open_regex(R"(^\{%([-~])?\s*)");
-      static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue)\b)");
-      static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)");
-      static std::regex expr_close_regex(R"(\s*([-~])?\}\})");
-      static std::regex block_close_regex(R"(\s*([-~])?%\})");
-
-      TemplateTokenVector tokens;
-      std::vector<std::string> group;
-      std::string text;
-      std::smatch match;
-
-      try {
-        while (it != end) {
-          auto location = get_location();
-
-          if (!(group = consumeTokenGroups(comment_tok, SpaceHandling::Keep)).empty()) {
-            auto pre_space = parsePreSpace(group[1]);
-            auto content = group[2];
-            auto post_space = parsePostSpace(group[3]);
-            tokens.push_back(std::make_unique<CommentTemplateToken>(location, pre_space, post_space, content));
-          } else if (!(group = consumeTokenGroups(expr_open_regex, SpaceHandling::Keep)).empty()) {
-            auto pre_space = parsePreSpace(group[1]);
-            auto expr = parseExpression();
-
-            if ((group = consumeTokenGroups(expr_close_regex)).empty()) {
-              throw std::runtime_error("Expected closing expression tag");
-            }
-
-            auto post_space = parsePostSpace(group[1]);
-            tokens.push_back(std::make_unique<ExpressionTemplateToken>(location, pre_space, post_space, std::move(expr)));
-          } else if (!(group = consumeTokenGroups(block_open_regex, SpaceHandling::Keep)).empty()) {
-            auto pre_space = parsePreSpace(group[1]);
-
-            std::string keyword;
-
-            auto parseBlockClose = [&]() -> SpaceHandling {
-              if ((group = consumeTokenGroups(block_close_regex)).empty()) throw std::runtime_error("Expected closing block tag");
-              return parsePostSpace(group[1]);
-            };
-
-            if ((keyword = consumeToken(block_keyword_tok)).empty()) throw std::runtime_error("Expected block keyword");
-
-            if (keyword == "if") {
-              auto condition = parseExpression();
-              if (!condition) throw std::runtime_error("Expected condition in if block");
-
-              auto post_space = parseBlockClose();
-              tokens.push_back(std::make_unique<IfTemplateToken>(location, pre_space, post_space, std::move(condition)));
-            } else if (keyword == "elif") {
-              auto condition = parseExpression();
-              if (!condition) throw std::runtime_error("Expected condition in elif block");
-
-              auto post_space = parseBlockClose();
-              tokens.push_back(std::make_unique<ElifTemplateToken>(location, pre_space, post_space, std::move(condition)));
-            } else if (keyword == "else") {
-              auto post_space = parseBlockClose();
-              tokens.push_back(std::make_unique<ElseTemplateToken>(location, pre_space, post_space));
-            } else if (keyword == "endif") {
-              auto post_space = parseBlockClose();
-              tokens.push_back(std::make_unique<EndIfTemplateToken>(location, pre_space, post_space));
-            } else if (keyword == "for") {
-              static std::regex recursive_tok(R"(recursive\b)");
-              static std::regex if_tok(R"(if\b)");
-
-              auto varnames = parseVarNames();
-              static std::regex in_tok(R"(in\b)");
-              if (consumeToken(in_tok).empty()) throw std::runtime_error("Expected 'in' keyword in for block");
-              auto iterable = parseExpression(/* allow_if_expr = */ false);
-              if (!iterable) throw std::runtime_error("Expected iterable in for block");
-
-              std::shared_ptr<Expression> condition;
-              if (!consumeToken(if_tok).empty()) {
-                condition = parseExpression();
-              }
-              auto recursive = !consumeToken(recursive_tok).empty();
-
-              auto post_space = parseBlockClose();
-              tokens.push_back(std::make_unique<ForTemplateToken>(location, pre_space, post_space, std::move(varnames), std::move(iterable), std::move(condition), recursive));
-            } else if (keyword == "endfor") {
-              auto post_space = parseBlockClose();
-              tokens.push_back(std::make_unique<EndForTemplateToken>(location, pre_space, post_space));
-            } else if (keyword == "generation") {
-              auto post_space = parseBlockClose();
-              tokens.push_back(std::make_unique<GenerationTemplateToken>(location, pre_space, post_space));
-            } else if (keyword == "endgeneration") {
-              auto post_space = parseBlockClose();
-              tokens.push_back(std::make_unique<EndGenerationTemplateToken>(location, pre_space, post_space));
-            } else if (keyword == "set") {
-              static std::regex namespaced_var_regex(R"((\w+)\s*\.\s*(\w+))");
-
-              std::string ns;
-              std::vector<std::string> var_names;
-              std::shared_ptr<Expression> value;
-              if (!(group = consumeTokenGroups(namespaced_var_regex)).empty()) {
-                ns = group[1];
-                var_names.push_back(group[2]);
-
-                if (consumeToken("=").empty()) throw std::runtime_error("Expected equals sign in set block");
-
-                value = parseExpression();
-                if (!value) throw std::runtime_error("Expected value in set block");
-              } else {
-                var_names = parseVarNames();
-
-                if (!consumeToken("=").empty()) {
-                  value = parseExpression();
-                  if (!value) throw std::runtime_error("Expected value in set block");
-                }
-              }
-              auto post_space = parseBlockClose();
-              tokens.push_back(std::make_unique<SetTemplateToken>(location, pre_space, post_space, ns, var_names, std::move(value)));
-            } else if (keyword == "endset") {
-              auto post_space = parseBlockClose();
-              tokens.push_back(std::make_unique<EndSetTemplateToken>(location, pre_space, post_space));
-            } else if (keyword == "macro") {
-              auto macroname = parseIdentifier();
-              if (!macroname) throw std::runtime_error("Expected macro name in macro block");
-              auto params = parseParameters();
-
-              auto post_space = parseBlockClose();
-              tokens.push_back(std::make_unique<MacroTemplateToken>(location, pre_space, post_space, std::move(macroname), std::move(params)));
-            } else if (keyword == "endmacro") {
-              auto post_space = parseBlockClose();
-              tokens.push_back(std::make_unique<EndMacroTemplateToken>(location, pre_space, post_space));
-            } else if (keyword == "filter") {
-              auto filter = parseExpression();
-              if (!filter) throw std::runtime_error("Expected expression in filter block");
-
-              auto post_space = parseBlockClose();
-              tokens.push_back(std::make_unique<FilterTemplateToken>(location, pre_space, post_space, std::move(filter)));
-            } else if (keyword == "endfilter") {
-              auto post_space = parseBlockClose();
-              tokens.push_back(std::make_unique<EndFilterTemplateToken>(location, pre_space, post_space));
-            } else if (keyword == "break" || keyword == "continue") {
-              auto post_space = parseBlockClose();
-              tokens.push_back(std::make_unique<LoopControlTemplateToken>(location, pre_space, post_space, keyword == "break" ? LoopControlType::Break : LoopControlType::Continue));
-            } else {
-              throw std::runtime_error("Unexpected block: " + keyword);
-            }
-          } else if (std::regex_search(it, end, match, non_text_open_regex)) {
-            if (!match.position()) {
-                if (match[0] != "{#")
-                    throw std::runtime_error("Internal error: Expected a comment");
-                throw std::runtime_error("Missing end of comment tag");
-            }
-            auto text_end = it + match.position();
-            text = std::string(it, text_end);
-            it = text_end;
-            tokens.push_back(std::make_unique<TextTemplateToken>(location, SpaceHandling::Keep, SpaceHandling::Keep, text));
-          } else {
-            text = std::string(it, end);
-            it = end;
-            tokens.push_back(std::make_unique<TextTemplateToken>(location, SpaceHandling::Keep, SpaceHandling::Keep, text));
-          }
-        }
-        return tokens;
-      } catch (const std::exception & e) {
-        throw std::runtime_error(e.what() + error_location_suffix(*template_str, std::distance(start, it)));
-      }
-    }
-
-    std::shared_ptr<TemplateNode> parseTemplate(
-          const TemplateTokenIterator & begin,
-          TemplateTokenIterator & it,
-          const TemplateTokenIterator & end,
-          bool fully = false) const {
-        std::vector<std::shared_ptr<TemplateNode>> children;
-        while (it != end) {
-          const auto start = it;
-          const auto & token = *(it++);
-          if (auto if_token = dynamic_cast<IfTemplateToken*>(token.get())) {
-              std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<TemplateNode>>> cascade;
-              cascade.emplace_back(std::move(if_token->condition), parseTemplate(begin, it, end));
-
-              while (it != end && (*it)->type == TemplateToken::Type::Elif) {
-                  auto elif_token = dynamic_cast<ElifTemplateToken*>((*(it++)).get());
-                  cascade.emplace_back(std::move(elif_token->condition), parseTemplate(begin, it, end));
-              }
-
-              if (it != end && (*it)->type == TemplateToken::Type::Else) {
-                cascade.emplace_back(nullptr, parseTemplate(begin, ++it, end));
-              }
-              if (it == end || (*(it++))->type != TemplateToken::Type::EndIf) {
-                  throw unterminated(**start);
-              }
-              children.emplace_back(std::make_shared<IfNode>(token->location, std::move(cascade)));
-          } else if (auto for_token = dynamic_cast<ForTemplateToken*>(token.get())) {
-              auto body = parseTemplate(begin, it, end);
-              auto else_body = std::shared_ptr<TemplateNode>();
-              if (it != end && (*it)->type == TemplateToken::Type::Else) {
-                else_body = parseTemplate(begin, ++it, end);
-              }
-              if (it == end || (*(it++))->type != TemplateToken::Type::EndFor) {
-                  throw unterminated(**start);
-              }
-              children.emplace_back(std::make_shared<ForNode>(token->location, std::move(for_token->var_names), std::move(for_token->iterable), std::move(for_token->condition), std::move(body), for_token->recursive, std::move(else_body)));
-          } else if (dynamic_cast<GenerationTemplateToken*>(token.get())) {
-              auto body = parseTemplate(begin, it, end);
-              if (it == end || (*(it++))->type != TemplateToken::Type::EndGeneration) {
-                  throw unterminated(**start);
-              }
-              // Treat as a no-op, as our scope is templates for inference, not training (`{% generation %}` wraps generated tokens for masking).
-              children.emplace_back(std::move(body));
-          } else if (auto text_token = dynamic_cast<TextTemplateToken*>(token.get())) {
-              SpaceHandling pre_space = (it - 1) != begin ? (*(it - 2))->post_space : SpaceHandling::Keep;
-              SpaceHandling post_space = it != end ? (*it)->pre_space : SpaceHandling::Keep;
-
-              auto text = text_token->text;
-              if (post_space == SpaceHandling::Strip) {
-                static std::regex trailing_space_regex(R"(\s+$)");
-                text = std::regex_replace(text, trailing_space_regex, "");
-              } else if (options.lstrip_blocks && it != end) {
-                auto i = text.size();
-                while (i > 0 && (text[i - 1] == ' ' || text[i - 1] == '\t')) i--;
-                if ((i == 0 && (it - 1) == begin) || (i > 0 && text[i - 1] == '\n')) {
-                  text.resize(i);
-                }
-              }
-              if (pre_space == SpaceHandling::Strip) {
-                static std::regex leading_space_regex(R"(^\s+)");
-                text = std::regex_replace(text, leading_space_regex, "");
-              } else if (options.trim_blocks && (it - 1) != begin && !dynamic_cast<ExpressionTemplateToken*>((*(it - 2)).get())) {
-                if (text.length() > 0 && text[0] == '\n') {
-                  text.erase(0, 1);
-                }
-              }
-              if (it == end && !options.keep_trailing_newline) {
-                auto i = text.size();
-                if (i > 0 && text[i - 1] == '\n') {
-                  i--;
-                  if (i > 0 && text[i - 1] == '\r') i--;
-                  text.resize(i);
-                }
-              }
-              children.emplace_back(std::make_shared<TextNode>(token->location, text));
-          } else if (auto expr_token = dynamic_cast<ExpressionTemplateToken*>(token.get())) {
-              children.emplace_back(std::make_shared<ExpressionNode>(token->location, std::move(expr_token->expr)));
-          } else if (auto set_token = dynamic_cast<SetTemplateToken*>(token.get())) {
-            if (set_token->value) {
-              children.emplace_back(std::make_shared<SetNode>(token->location, set_token->ns, set_token->var_names, std::move(set_token->value)));
-            } else {
-              auto value_template = parseTemplate(begin, it, end);
-              if (it == end || (*(it++))->type != TemplateToken::Type::EndSet) {
-                  throw unterminated(**start);
-              }
-              if (!set_token->ns.empty()) throw std::runtime_error("Namespaced set not supported in set with template value");
-              if (set_token->var_names.size() != 1) throw std::runtime_error("Structural assignment not supported in set with template value");
-              auto & name = set_token->var_names[0];
-              children.emplace_back(std::make_shared<SetTemplateNode>(token->location, name, std::move(value_template)));
-            }
-          } else if (auto macro_token = dynamic_cast<MacroTemplateToken*>(token.get())) {
-              auto body = parseTemplate(begin, it, end);
-              if (it == end || (*(it++))->type != TemplateToken::Type::EndMacro) {
-                  throw unterminated(**start);
-              }
-              children.emplace_back(std::make_shared<MacroNode>(token->location, std::move(macro_token->name), std::move(macro_token->params), std::move(body)));
-          } else if (auto filter_token = dynamic_cast<FilterTemplateToken*>(token.get())) {
-              auto body = parseTemplate(begin, it, end);
-              if (it == end || (*(it++))->type != TemplateToken::Type::EndFilter) {
-                  throw unterminated(**start);
-              }
-              children.emplace_back(std::make_shared<FilterNode>(token->location, std::move(filter_token->filter), std::move(body)));
-          } else if (dynamic_cast<CommentTemplateToken*>(token.get())) {
-              // Ignore comments
-          } else if (auto ctrl_token = dynamic_cast<LoopControlTemplateToken*>(token.get())) {
-              children.emplace_back(std::make_shared<LoopControlNode>(token->location, ctrl_token->control_type));
-          } else if (dynamic_cast<EndForTemplateToken*>(token.get())
-                  || dynamic_cast<EndSetTemplateToken*>(token.get())
-                  || dynamic_cast<EndMacroTemplateToken*>(token.get())
-                  || dynamic_cast<EndFilterTemplateToken*>(token.get())
-                  || dynamic_cast<EndIfTemplateToken*>(token.get())
-                  || dynamic_cast<ElseTemplateToken*>(token.get())
-                  || dynamic_cast<EndGenerationTemplateToken*>(token.get())
-                  || dynamic_cast<ElifTemplateToken*>(token.get())) {
-              it--;  // unconsume the token
-              break;  // exit the loop
-          } else {
-              throw unexpected(**(it-1));
-          }
-        }
-        if (fully && it != end) {
-            throw unexpected(**it);
-        }
-        if (children.empty()) {
-          return std::make_shared<TextNode>(Location { template_str, 0 }, std::string());
-        } else if (children.size() == 1) {
-          return std::move(children[0]);
-        } else {
-          return std::make_shared<SequenceNode>(children[0]->location(), std::move(children));
-        }
-    }
-
-public:
-
-    static std::shared_ptr<TemplateNode> parse(const std::string& template_str, const Options & options) {
-        Parser parser(std::make_shared<std::string>(normalize_newlines(template_str)), options);
-        auto tokens = parser.tokenize();
-        TemplateTokenIterator begin = tokens.begin();
-        auto it = begin;
-        TemplateTokenIterator end = tokens.end();
-        return parser.parseTemplate(begin, it, end, /* full= */ true);
-    }
-};
-
-static Value simple_function(const std::string & fn_name, const std::vector<std::string> & params, const std::function<Value(const std::shared_ptr<Context> &, Value & args)> & fn) {
-  std::map<std::string, size_t> named_positions;
-  for (size_t i = 0, n = params.size(); i < n; i++) named_positions[params[i]] = i;
-
-  return Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) -> Value {
-    auto args_obj = Value::object();
-    std::vector<bool> provided_args(params.size());
-    for (size_t i = 0, n = args.args.size(); i < n; i++) {
-      auto & arg = args.args[i];
-      if (i < params.size()) {
-        args_obj.set(params[i], arg);
-        provided_args[i] = true;
-      } else {
-        throw std::runtime_error("Too many positional params for " + fn_name);
-      }
-    }
-    for (auto & [name, value] : args.kwargs) {
-      auto named_pos_it = named_positions.find(name);
-      if (named_pos_it == named_positions.end()) {
-        throw std::runtime_error("Unknown argument " + name + " for function " + fn_name);
-      }
-      provided_args[named_pos_it->second] = true;
-      args_obj.set(name, value);
-    }
-    return fn(context, args_obj);
-  });
-}
-
-inline std::shared_ptr<Context> Context::builtins() {
-  auto globals = Value::object();
-
-  globals.set("raise_exception", simple_function("raise_exception", { "message" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
-    throw std::runtime_error(args.at("message").get<std::string>());
-  }));
-  globals.set("tojson", simple_function("tojson", { "value", "indent" }, [](const std::shared_ptr<Context> &, Value & args) {
-    return Value(args.at("value").dump(args.get<int64_t>("indent", -1), /* tojson= */ true));
-  }));
-  globals.set("items", simple_function("items", { "object" }, [](const std::shared_ptr<Context> &, Value & args) {
-    auto items = Value::array();
-    if (args.contains("object")) {
-      auto & obj = args.at("object");
-      if (obj.is_string()) {
-        auto json_obj = json::parse(obj.get<std::string>());
-        for (const auto & kv : json_obj.items()) {
-          items.push_back(Value::array({kv.key(), kv.value()}));
-        }
-      } else if (!obj.is_null()) {
-        for (auto & key : obj.keys()) {
-          items.push_back(Value::array({key, obj.at(key)}));
-        }
-      }
-    }
-    return items;
-  }));
-  globals.set("last", simple_function("last", { "items" }, [](const std::shared_ptr<Context> &, Value & args) {
-    auto items = args.at("items");
-    if (!items.is_array()) throw std::runtime_error("object is not a list");
-    if (items.size() == 0) return Value();
-    return items.at(items.size() - 1);
-  }));
-  globals.set("trim", simple_function("trim", { "text" }, [](const std::shared_ptr<Context> &, Value & args) {
-    auto & text = args.at("text");
-    return text.is_null() ? text : Value(strip(text.get<std::string>()));
-  }));
-  globals.set("lower", simple_function("lower", { "text" }, [](const std::shared_ptr<Context> &, Value & args) {
-    auto text = args.at("text");
-    if (text.is_null()) return text;
-    std::string res;
-    auto str = text.get<std::string>();
-    std::transform(str.begin(), str.end(), std::back_inserter(res), ::tolower);
-    return Value(res);
-  }));
-  globals.set("default", Value::callable([=](const std::shared_ptr<Context> &, ArgumentsValue & args) {
-    args.expectArgs("default", {2, 3}, {0, 1});
-    auto & value = args.args[0];
-    auto & default_value = args.args[1];
-    bool boolean = false;
-    if (args.args.size() == 3) {
-      boolean = args.args[2].get<bool>();
-    } else {
-      Value bv = args.get_named("boolean");
-      if (!bv.is_null()) {
-        boolean = bv.get<bool>();
-      }
-    }
-    return boolean ? (value.to_bool() ? value : default_value) : value.is_null() ? default_value : value;
-  }));
-  auto escape = simple_function("escape", { "text" }, [](const std::shared_ptr<Context> &, Value & args) {
-    return Value(html_escape(args.at("text").get<std::string>()));
-  });
-  globals.set("e", escape);
-  globals.set("escape", escape);
-  globals.set("joiner", simple_function("joiner", { "sep" }, [](const std::shared_ptr<Context> &, Value & args) {
-    auto sep = args.get<std::string>("sep", "");
-    auto first = std::make_shared<bool>(true);
-    return simple_function("", {}, [sep, first](const std::shared_ptr<Context> &, const Value &) -> Value {
-      if (*first) {
-        *first = false;
-        return "";
-      }
-      return sep;
-    });
-    return Value(html_escape(args.at("text").get<std::string>()));
-  }));
-  globals.set("count", simple_function("count", { "items" }, [](const std::shared_ptr<Context> &, Value & args) {
-    return Value((int64_t) args.at("items").size());
-  }));
-  globals.set("dictsort", simple_function("dictsort", { "value" }, [](const std::shared_ptr<Context> &, Value & args) {
-    if (args.size() != 1) throw std::runtime_error("dictsort expects exactly 1 argument (TODO: fix implementation)");
-    auto & value = args.at("value");
-    auto keys = value.keys();
-    std::sort(keys.begin(), keys.end());
-    auto res = Value::array();
-    for (auto & key : keys) {
-      res.push_back(Value::array({key, value.at(key)}));
-    }
-    return res;
-  }));
-  globals.set("join", simple_function("join", { "items", "d" }, [](const std::shared_ptr<Context> &, Value & args) {
-    auto do_join = [](Value & items, const std::string & sep) {
-      if (!items.is_array()) throw std::runtime_error("object is not iterable: " + items.dump());
-      std::ostringstream oss;
-      auto first = true;
-      for (size_t i = 0, n = items.size(); i < n; ++i) {
-        if (first) first = false;
-        else oss << sep;
-        oss << items.at(i).to_str();
-      }
-      return Value(oss.str());
-    };
-    auto sep = args.get<std::string>("d", "");
-    if (args.contains("items")) {
-        auto & items = args.at("items");
-        return do_join(items, sep);
-    } else {
-      return simple_function("", {"items"}, [sep, do_join](const std::shared_ptr<Context> &, Value & args) {
-        auto & items = args.at("items");
-        if (!items.to_bool() || !items.is_array()) throw std::runtime_error("join expects an array for items, got: " + items.dump());
-        return do_join(items, sep);
-      });
-    }
-  }));
-  globals.set("namespace", Value::callable([=](const std::shared_ptr<Context> &, ArgumentsValue & args) {
-    auto ns = Value::object();
-    args.expectArgs("namespace", {0, 0}, {0, (std::numeric_limits<size_t>::max)()});
-    for (auto & [name, value] : args.kwargs) {
-      ns.set(name, value);
-    }
-    return ns;
-  }));
-  auto equalto = simple_function("equalto", { "expected", "actual" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
-      return args.at("actual") == args.at("expected");
-  });
-  globals.set("equalto", equalto);
-  globals.set("==", equalto);
-  globals.set("length", simple_function("length", { "items" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
-      auto & items = args.at("items");
-      return (int64_t) items.size();
-  }));
-  globals.set("safe", simple_function("safe", { "value" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
-      return args.at("value").to_str();
-  }));
-  globals.set("string", simple_function("string", { "value" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
-      return args.at("value").to_str();
-  }));
-  globals.set("int", simple_function("int", { "value" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
-      return args.at("value").to_int();
-  }));
-  globals.set("list", simple_function("list", { "items" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
-      auto & items = args.at("items");
-      if (!items.is_array()) throw std::runtime_error("object is not iterable");
-      return items;
-  }));
-  globals.set("unique", simple_function("unique", { "items" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
-      auto & items = args.at("items");
-      if (!items.is_array()) throw std::runtime_error("object is not iterable");
-      std::unordered_set<Value> seen;
-      auto result = Value::array();
-      for (size_t i = 0, n = items.size(); i < n; i++) {
-        auto pair = seen.insert(items.at(i));
-        if (pair.second) {
-          result.push_back(items.at(i));
-        }
-      }
-      return result;
-  }));
-  auto make_filter = [](const Value & filter, Value & extra_args) -> Value {
-    return simple_function("", { "value" }, [=](const std::shared_ptr<Context> & context, Value & args) {
-      auto & value = args.at("value");
-      ArgumentsValue actual_args;
-      actual_args.args.emplace_back(value);
-      for (size_t i = 0, n = extra_args.size(); i < n; i++) {
-        actual_args.args.emplace_back(extra_args.at(i));
-      }
-      return filter.call(context, actual_args);
-    });
-  };
-  auto select_or_reject = [make_filter](bool is_select) {
-    return Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
-      args.expectArgs(is_select ? "select" : "reject", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
-      auto & items = args.args[0];
-      if (items.is_null())
-        return Value::array();
-      if (!items.is_array()) throw std::runtime_error("object is not iterable: " + items.dump());
-
-      auto filter_fn = context->get(args.args[1]);
-      if (filter_fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump());
-
-      auto filter_args = Value::array();
-      for (size_t i = 2, n = args.args.size(); i < n; i++) {
-        filter_args.push_back(args.args[i]);
-      }
-      auto filter = make_filter(filter_fn, filter_args);
-
-      auto res = Value::array();
-      for (size_t i = 0, n = items.size(); i < n; i++) {
-        auto & item = items.at(i);
-        ArgumentsValue filter_args;
-        filter_args.args.emplace_back(item);
-        auto pred_res = filter.call(context, filter_args);
-        if (pred_res.to_bool() == (is_select ? true : false)) {
-          res.push_back(item);
-        }
-      }
-      return res;
-    });
-  };
-  globals.set("select", select_or_reject(/* is_select= */ true));
-  globals.set("reject", select_or_reject(/* is_select= */ false));
-  globals.set("map", Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
-    auto res = Value::array();
-    if (args.args.size() == 1 &&
-      ((args.has_named("attribute") && args.kwargs.size() == 1) || (args.has_named("default") && args.kwargs.size() == 2))) {
-      auto & items = args.args[0];
-      auto attr_name = args.get_named("attribute");
-      auto default_value = args.get_named("default");
-      for (size_t i = 0, n = items.size(); i < n; i++) {
-        auto & item = items.at(i);
-        auto attr = item.get(attr_name);
-        res.push_back(attr.is_null() ? default_value : attr);
-      }
-    } else if (args.kwargs.empty() && args.args.size() >= 2) {
-      auto fn = context->get(args.args[1]);
-      if (fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump());
-      ArgumentsValue filter_args { {Value()}, {} };
-      for (size_t i = 2, n = args.args.size(); i < n; i++) {
-        filter_args.args.emplace_back(args.args[i]);
-      }
-      for (size_t i = 0, n = args.args[0].size(); i < n; i++) {
-        auto & item = args.args[0].at(i);
-        filter_args.args[0] = item;
-        res.push_back(fn.call(context, filter_args));
-      }
-    } else {
-      throw std::runtime_error("Invalid or unsupported arguments for map");
-    }
-    return res;
-  }));
-  globals.set("indent", simple_function("indent", { "text", "indent", "first" }, [](const std::shared_ptr<Context> &, Value & args) {
-    auto text = args.at("text").get<std::string>();
-    auto first = args.get<bool>("first", false);
-    std::string out;
-    std::string indent(args.get<int64_t>("indent", 0), ' ');
-    std::istringstream iss(text);
-    std::string line;
-    auto is_first = true;
-    while (std::getline(iss, line, '\n')) {
-      auto needs_indent = !is_first || first;
-      if (is_first) is_first = false;
-      else out += "\n";
-      if (needs_indent) out += indent;
-      out += line;
-    }
-    if (!text.empty() && text.back() == '\n') out += "\n";
-    return out;
-  }));
-  auto select_or_reject_attr = [](bool is_select) {
-    return Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
-      args.expectArgs(is_select ? "selectattr" : "rejectattr", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
-      auto & items = args.args[0];
-      if (items.is_null())
-        return Value::array();
-      if (!items.is_array()) throw std::runtime_error("object is not iterable: " + items.dump());
-      auto attr_name = args.args[1].get<std::string>();
-
-      bool has_test = false;
-      Value test_fn;
-      ArgumentsValue test_args {{Value()}, {}};
-      if (args.args.size() >= 3) {
-        has_test = true;
-        test_fn = context->get(args.args[2]);
-        if (test_fn.is_null()) throw std::runtime_error("Undefined test: " + args.args[2].dump());
-        for (size_t i = 3, n = args.args.size(); i < n; i++) {
-          test_args.args.emplace_back(args.args[i]);
-        }
-        test_args.kwargs = args.kwargs;
-      }
-
-      auto res = Value::array();
-      for (size_t i = 0, n = items.size(); i < n; i++) {
-        auto & item = items.at(i);
-        auto attr = item.get(attr_name);
-        if (has_test) {
-          test_args.args[0] = attr;
-          if (test_fn.call(context, test_args).to_bool() == (is_select ? true : false)) {
-            res.push_back(item);
-          }
-        } else {
-          res.push_back(attr);
-        }
-      }
-      return res;
-    });
-  };
-  globals.set("selectattr", select_or_reject_attr(/* is_select= */ true));
-  globals.set("rejectattr", select_or_reject_attr(/* is_select= */ false));
-  globals.set("range", Value::callable([=](const std::shared_ptr<Context> &, ArgumentsValue & args) {
-    std::vector<int64_t> startEndStep(3);
-    std::vector<bool> param_set(3);
-    if (args.args.size() == 1) {
-      startEndStep[1] = args.args[0].get<int64_t>();
-      param_set[1] = true;
-    } else {
-      for (size_t i = 0; i < args.args.size(); i++) {
-        auto & arg = args.args[i];
-        auto v = arg.get<int64_t>();
-        startEndStep[i] = v;
-        param_set[i] = true;
-        }
-      }
-      for (auto & [name, value] : args.kwargs) {
-        size_t i;
-        if (name == "start") i = 0;
-        else if (name == "end") i = 1;
-        else if (name == "step") i = 2;
-        else throw std::runtime_error("Unknown argument " + name + " for function range");
-
-        if (param_set[i]) {
-          throw std::runtime_error("Duplicate argument " + name + " for function range");
-        }
-        startEndStep[i] = value.get<int64_t>();
-        param_set[i] = true;
-    }
-    if (!param_set[1]) {
-      throw std::runtime_error("Missing required argument 'end' for function range");
-    }
-    int64_t start = param_set[0] ? startEndStep[0] : 0;
-    int64_t end = startEndStep[1];
-    int64_t step = param_set[2] ? startEndStep[2] : 1;
-
-    auto res = Value::array();
-    if (step > 0) {
-      for (int64_t i = start; i < end; i += step) {
-        res.push_back(Value(i));
-      }
-    } else {
-      for (int64_t i = start; i > end; i += step) {
-        res.push_back(Value(i));
-      }
-    }
-    return res;
-  }));
-
-  return std::make_shared<Context>(std::move(globals));
-}
-
-inline std::shared_ptr<Context> Context::make(Value && values, const std::shared_ptr<Context> & parent) {
-  return std::make_shared<Context>(values.is_null() ? Value::object() : std::move(values), parent);
-}
-
-}  // namespace minja
diff --git a/common/minja/chat-template.hpp b/common/minja/chat-template.hpp
new file mode 100644 (file)
index 0000000..882ba41
--- /dev/null
@@ -0,0 +1,529 @@
+/*
+    Copyright 2024 Google LLC
+
+    Use of this source code is governed by an MIT-style
+    license that can be found in the LICENSE file or at
+    https://opensource.org/licenses/MIT.
+*/
+// SPDX-License-Identifier: MIT
+#pragma once
+
+#include "minja.hpp"
+#include <json.hpp>
+#include <string>
+#include <vector>
+
+using json = nlohmann::ordered_json;
+
+namespace minja {
+
+struct chat_template_caps {
+    bool supports_tools = false;
+    bool supports_tool_calls = false;
+    bool supports_tool_responses = false;
+    bool supports_system_role = false;
+    bool supports_parallel_tool_calls = false;
+    bool supports_tool_call_id = false;
+    // meta-llama/Llama-3.1-8B-Instruct expects arguments to be an object.
+    // Most other templates (and OpenAI's API) expect the arguments object to be stringified.
+    bool requires_object_arguments = false;
+    // CohereForAI/c4ai-command-r-plus simple variant
+    bool requires_non_null_content = false;
+    // MiniMaxAI/MiniMax-Text-01 special
+    bool requires_typed_content = false;
+};
+
+struct chat_template_inputs {
+    nlohmann::ordered_json messages;
+    nlohmann::ordered_json tools;
+    bool add_generation_prompt = true;
+    nlohmann::ordered_json extra_context;
+    std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
+};
+
+struct chat_template_options {
+    bool apply_polyfills = true;
+    bool use_bos_token = true;
+    bool use_eos_token = true;
+    bool define_strftime_now = true;
+
+    bool polyfill_tools = true;
+    bool polyfill_tool_call_examples = true;
+    bool polyfill_tool_calls = true;
+    bool polyfill_tool_responses = true;
+    bool polyfill_system_role = true;
+    bool polyfill_object_arguments = true;
+    bool polyfill_typed_content = true;
+};
+
+class chat_template {
+
+  private:
+    chat_template_caps caps_;
+    std::string source_;
+    std::string bos_token_;
+    std::string eos_token_;
+    std::shared_ptr<minja::TemplateNode> template_root_;
+    std::string tool_call_example_;
+
+    std::string try_raw_render(
+        const nlohmann::ordered_json & messages,
+        const nlohmann::ordered_json & tools,
+        bool add_generation_prompt,
+        const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
+    {
+        try {
+            chat_template_inputs inputs;
+            inputs.messages = messages;
+            inputs.tools = tools;
+            inputs.add_generation_prompt = add_generation_prompt;
+            inputs.extra_context = extra_context;
+            // Use fixed date for tests
+            inputs.now = std::chrono::system_clock::from_time_t(0);
+
+            chat_template_options opts;
+            opts.apply_polyfills = false;
+
+            auto prompt = apply(inputs, opts);
+            // fprintf(stderr, "try_raw_render: %s\n", prompt.c_str());
+            return prompt;
+        } catch (const std::exception & e) {
+            // fprintf(stderr, "try_raw_render error: %s\n", e.what());
+            return "";
+        }
+    }
+
+  public:
+
+    chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token)
+        : source_(source), bos_token_(bos_token), eos_token_(eos_token)
+    {
+        template_root_ = minja::Parser::parse(source_, {
+            /* .trim_blocks = */ true,
+            /* .lstrip_blocks = */ true,
+            /* .keep_trailing_newline = */ false,
+        });
+
+        auto contains = [](const std::string & haystack, const std::string & needle) {
+            return haystack.find(needle) != std::string::npos;
+        };
+
+        const std::string user_needle = "<User Needle>";
+        const std::string sys_needle = "<System Needle>";
+        const json dummy_str_user_msg = {{"role", "user"}, {"content", user_needle}};
+        const json dummy_typed_user_msg = {{"role", "user"}, {"content", json::array({{{"type", "text"}, {"text", user_needle}}})}};
+
+        caps_.requires_typed_content =
+            !contains(try_raw_render(json::array({dummy_str_user_msg}), {}, false), user_needle)
+            && contains(try_raw_render(json::array({dummy_typed_user_msg}), {}, false), user_needle);
+
+        const auto dummy_user_msg = caps_.requires_typed_content
+            ? dummy_typed_user_msg
+            : dummy_str_user_msg;
+        const json needle_system_msg = {
+            {"role", "system"},
+            {"content", caps_.requires_typed_content ? json::array({{{"type", "text"}, {"text", sys_needle}}}) : json(sys_needle)},
+        };
+
+        caps_.supports_system_role = contains(try_raw_render({needle_system_msg, dummy_user_msg,}, {}, false), sys_needle);
+
+        auto out = try_raw_render(json::array({
+            dummy_user_msg
+        }), json::array({
+            {
+                {"name", "some_tool"},
+                {"type", "function"},
+                {"function", {
+                    {"name", "some_tool"},
+                    {"description", "Some tool."},
+                    {"parameters", {
+                        {"type", "object"},
+                        {"properties", {
+                            {"arg", {
+                                {"type", "string"},
+                                {"description", "Some argument."},
+                            }},
+                        }},
+                        {"required", json::array({ "arg" })},
+                    }},
+                }},
+            },
+        }), false);
+        caps_.supports_tools = contains(out, "some_tool");
+
+        auto make_tool_calls_msg = [&](const json & tool_calls) {
+            return json {
+                {"role", "assistant"},
+                {"content", nullptr},
+                {"tool_calls", tool_calls},
+            };
+        };
+        auto make_tool_call = [](const std::string & tool_name, const json & arguments) {
+            return json {
+                {"id", "call_1___"},
+                {"type", "function"},
+                {"function", {
+                    {"arguments", arguments},
+                    {"name", tool_name},
+                }},
+            };
+        };
+        const json dummy_args_obj {{"argument_needle", "print('Hello, World!')"}};
+
+        // Note: the arguments are rendered in both cases, but may be double-escaped, which we don't want.
+        out = try_raw_render(json::array({
+            dummy_user_msg,
+            make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj.dump())})),
+        }), {}, false);
+        auto tool_call_renders_str_arguments = contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
+        out = try_raw_render(json::array({
+            dummy_user_msg,
+            make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj)})),
+        }), {}, false);
+        auto tool_call_renders_obj_arguments = contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
+
+        caps_.supports_tool_calls = tool_call_renders_str_arguments || tool_call_renders_obj_arguments;
+        caps_.requires_object_arguments = !tool_call_renders_str_arguments && tool_call_renders_obj_arguments;
+        auto out_empty = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", ""}}}), {}, false);
+        auto out_null = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", nullptr}}}), {}, false);
+        caps_.requires_non_null_content = contains(out_empty, user_needle) && !contains(out_null, user_needle);
+
+        if (caps_.supports_tool_calls) {
+            auto dummy_args = caps_.requires_object_arguments ? dummy_args_obj : json(dummy_args_obj.dump());
+            auto tc1 = make_tool_call("test_tool1", dummy_args);
+            auto tc2 = make_tool_call("test_tool2", dummy_args);
+            auto out = try_raw_render(json::array({
+                dummy_user_msg,
+                make_tool_calls_msg(json::array({tc1, tc2})),
+            }), {}, false);
+            caps_.supports_parallel_tool_calls = contains(out, "test_tool1") && contains(out, "test_tool2");
+
+            out = try_raw_render(json::array({
+                dummy_user_msg,
+                make_tool_calls_msg(json::array({tc1})),
+                {
+                    {"role", "tool"},
+                    {"name", "test_tool1"},
+                    {"content", "Some response!"},
+                    {"tool_call_id", "call_911_"},
+                }
+            }), {}, false);
+            caps_.supports_tool_responses = contains(out, "Some response!");
+            caps_.supports_tool_call_id = contains(out, "call_911_");
+        }
+
+        try {
+            if (!caps_.supports_tools) {
+                const json user_msg {
+                    {"role", "user"},
+                    {"content", "Hey"},
+                };
+                const json args {
+                    {"arg1", "some_value"},
+                };
+                const json tool_call_msg {
+                    {"role", "assistant"},
+                    {"content", nullptr},
+                    {"tool_calls", json::array({
+                        {
+                            // TODO: detect if requires numerical id or fixed length == 6 like Nemo
+                            {"id", "call_1___"},
+                            {"type", "function"},
+                            {"function", {
+                                {"name", "tool_name"},
+                                {"arguments", (caps_.requires_object_arguments ? args : json(minja::Value(args).dump(-1, /* to_json= */ true)))},
+                            }},
+                        },
+                    })},
+                };
+                std::string prefix, full;
+                {
+                    chat_template_inputs inputs;
+                    inputs.messages = json::array({user_msg});
+                    inputs.add_generation_prompt = true;
+                    prefix = apply(inputs);
+                }
+                {
+                    chat_template_inputs inputs;
+                    inputs.messages = json::array({user_msg, tool_call_msg});
+                    inputs.add_generation_prompt = false;
+                    full = apply(inputs);
+                }
+                auto eos_pos_last = full.rfind(eos_token_);
+                if (eos_pos_last == prefix.size() - eos_token_.size() ||
+                      (full[full.size() - 1] == '\n' && (eos_pos_last == full.size() - eos_token_.size() - 1))) {
+                    full = full.substr(0, eos_pos_last);
+                }
+                size_t common_prefix_length = 0;
+                for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) {
+                    if (prefix[i] != full[i]) {
+                        break;
+                    }
+                    if (prefix[i] == '<') {
+                        // DeepSeek R1's template (as of 20250209) adds a trailing <think> if add_generation_prompt,
+                        // but it removes thinking tags for past messages.
+                        // The prefix and full strings diverge at <think> vs. <|tool▁calls▁begin|>, we avoid consuming the leading <.
+                        continue;
+                    }
+                    common_prefix_length = i + 1;
+                }
+                auto example = full.substr(common_prefix_length);
+                if (example.find("tool_name") == std::string::npos && example.find("some_value") == std::string::npos) {
+                    fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n");
+                } else {
+                    tool_call_example_ = example;
+                }
+            }
+        } catch (const std::exception & e) {
+            fprintf(stderr, "Failed to generate tool call example: %s\n", e.what());
+        }
+    }
+
+    const std::string & source() const { return source_; }
+    const std::string & bos_token() const { return bos_token_; }
+    const std::string & eos_token() const { return eos_token_; }
+    const chat_template_caps & original_caps() const { return caps_; }
+
+    // Deprecated, please use the form with chat_template_inputs and chat_template_options
+    std::string apply(
+        const nlohmann::ordered_json & messages,
+        const nlohmann::ordered_json & tools,
+        bool add_generation_prompt,
+        const nlohmann::ordered_json & extra_context = nlohmann::ordered_json(),
+        bool apply_polyfills = true)
+    {
+        fprintf(stderr, "[%s] Deprecated!\n", __func__);
+        chat_template_inputs inputs;
+        inputs.messages = messages;
+        inputs.tools = tools;
+        inputs.add_generation_prompt = add_generation_prompt;
+        inputs.extra_context = extra_context;
+        inputs.now = std::chrono::system_clock::now();
+
+        chat_template_options opts;
+        opts.apply_polyfills = apply_polyfills;
+
+        return apply(inputs, opts);
+    }
+
+    std::string apply(
+        const chat_template_inputs & inputs,
+        const chat_template_options & opts = chat_template_options()) const
+    {
+        json actual_messages;
+
+        auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
+        auto has_tool_calls = false;
+        auto has_tool_responses = false;
+        auto has_string_content = false;
+        for (const auto & message : inputs.messages) {
+            if (message.contains("tool_calls") && !message["tool_calls"].is_null()) {
+                has_tool_calls = true;
+            }
+            if (message.contains("role") && message["role"] == "tool") {
+                has_tool_responses = true;
+            }
+            if (message.contains("content") && message["content"].is_string()) {
+                has_string_content = true;
+            }
+        }
+
+        auto polyfill_system_role = opts.polyfill_system_role && !caps_.supports_system_role;
+        auto polyfill_tools = opts.polyfill_tools && has_tools && !caps_.supports_tools;
+        auto polyfill_tool_call_example = polyfill_tools && opts.polyfill_tool_call_examples;
+        auto polyfill_tool_calls = opts.polyfill_tool_calls && has_tool_calls && !caps_.supports_tool_calls;
+        auto polyfill_tool_responses = opts.polyfill_tool_responses && has_tool_responses && !caps_.supports_tool_responses;
+        auto polyfill_object_arguments = opts.polyfill_object_arguments && has_tool_calls && caps_.requires_object_arguments;
+        auto polyfill_typed_content = opts.polyfill_typed_content && has_string_content && caps_.requires_typed_content;
+
+        auto needs_polyfills = opts.apply_polyfills && (false
+            || polyfill_system_role
+            || polyfill_tools
+            || polyfill_tool_calls
+            || polyfill_tool_responses
+            || polyfill_object_arguments
+            || polyfill_typed_content
+        );
+
+        if (needs_polyfills) {
+            actual_messages = json::array();
+
+            auto add_message = [&](const json & msg) {
+                if (polyfill_typed_content && msg.contains("content") && !msg.at("content").is_null() && msg.at("content").is_string()) {
+                    actual_messages.push_back({
+                        {"role", msg.at("role")},
+                        {"content", {{
+                            {"type", "text"},
+                            {"text", msg.at("content")},
+                        }}},
+                    });
+                } else {
+                    actual_messages.push_back(msg);
+                }
+            };
+
+            std::string pending_system;
+            auto flush_sys = [&]() {
+                if (!pending_system.empty()) {
+                    add_message({
+                        {"role", "user"},
+                        {"content", pending_system},
+                    });
+                    pending_system.clear();
+                }
+            };
+
+            json adjusted_messages;
+            if (polyfill_tools) {
+                adjusted_messages = add_system(inputs.messages,
+                    "You can call any of the following tools to satisfy the user's requests: " + minja::Value(inputs.tools).dump(2, /* to_json= */ true) +
+                    (!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_ + "\n\n"));
+            } else {
+                adjusted_messages = inputs.messages;
+            }
+
+            for (const auto & message_ : adjusted_messages) {
+                auto message = message_;
+                if (!message.contains("role") || !message.contains("content")) {
+                    throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
+                }
+                std::string role = message.at("role");
+
+                if (message.contains("tool_calls")) {
+                    if (polyfill_object_arguments || polyfill_tool_calls) {
+                        for (auto & tool_call : message.at("tool_calls")) {
+                            if (tool_call["type"] == "function") {
+                                auto & function = tool_call.at("function");
+                                auto & arguments = function.at("arguments");
+                                if (arguments.is_string()) {
+                                    try {
+                                        arguments = json::parse(arguments.get<std::string>());
+                                    } catch (const std::exception & ecvt) {
+                                        fprintf(stderr, "Failed to parse arguments: %s\n", ecvt.what());
+                                    }
+                                }
+                            }
+                        }
+                    }
+                    if (polyfill_tool_calls) {
+                        auto content = message.at("content");
+                        auto tool_calls = json::array();
+                        for (const auto & tool_call : message.at("tool_calls")) {
+                            if (tool_call.at("type") != "function") {
+                                continue;
+                            }
+                            const auto & function = tool_call.at("function");
+                            auto tc = json {
+                                {"name", function.at("name")},
+                                {"arguments", function.at("arguments")},
+                            };
+                            if (tool_call.contains("id")) {
+                                tc["id"] = tool_call["id"];
+                            }
+                            tool_calls.push_back(tc);
+                        }
+                        auto obj = json {
+                            {"tool_calls", tool_calls},
+                        };
+                        if (!content.is_null() && content != "") {
+                            obj["content"] = content;
+                        }
+                        message["content"] = obj.dump(2);
+                        message.erase("tool_calls");
+                    }
+                }
+                if (polyfill_tool_responses && role == "tool") {
+                    message["role"] = "user";
+                    auto obj = json {
+                        {"tool_response", {
+                            {"content", message.at("content")},
+                        }},
+                    };
+                    if (message.contains("name")) {
+                        obj["tool_response"]["name"] = message.at("name");
+                    }
+                    if (message.contains("tool_call_id")) {
+                        obj["tool_response"]["tool_call_id"] = message.at("tool_call_id");
+                    }
+                    message["content"] = obj.dump(2);
+                    message.erase("name");
+                }
+
+                if (!message["content"].is_null() && polyfill_system_role) {
+                    std::string content = message.at("content");
+                    if (role == "system") {
+                        if (!pending_system.empty()) pending_system += "\n";
+                        pending_system += content;
+                        continue;
+                    } else {
+                        if (role == "user") {
+                            if (!pending_system.empty()) {
+                                message["content"] = pending_system + (content.empty() ? "" : "\n" + content);
+                                pending_system.clear();
+                            }
+                        } else {
+                            flush_sys();
+                        }
+                    }
+                }
+                add_message(message);
+            }
+            flush_sys();
+        } else {
+            actual_messages = inputs.messages;
+        }
+
+        auto context = minja::Context::make(json({
+            {"messages", actual_messages},
+            {"add_generation_prompt", inputs.add_generation_prompt},
+        }));
+        context->set("bos_token", opts.use_bos_token ? bos_token_ : "");
+        context->set("eos_token", opts.use_eos_token ? eos_token_ : "");
+        if (opts.define_strftime_now) {
+            auto now = inputs.now;
+            context->set("strftime_now", Value::callable([now](const std::shared_ptr<minja::Context> &, minja::ArgumentsValue & args) {
+                args.expectArgs("strftime_now", {1, 1}, {0, 0});
+                auto format = args.args[0].get<std::string>();
+
+                auto time = std::chrono::system_clock::to_time_t(now);
+                auto local_time = *std::localtime(&time);
+                std::ostringstream ss;
+                ss << std::put_time(&local_time, format.c_str());
+                return ss.str();
+            }));
+        }
+        if (!inputs.tools.is_null()) {
+            context->set("tools", minja::Value(inputs.tools));
+        }
+        if (!inputs.extra_context.is_null()) {
+            for (auto & kv : inputs.extra_context.items()) {
+                context->set(kv.key(), minja::Value(kv.value()));
+            }
+        }
+
+        auto ret = template_root_->render(context);
+        // fprintf(stderr, "actual_messages: %s\n", actual_messages.dump(2).c_str());
+        // fprintf(stderr, "apply: %s\n\n", ret.c_str());
+        return ret;
+    }
+
+    static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages, const std::string & system_prompt) {
+        json messages_with_system = messages;
+
+        if (messages_with_system.size() > 0 && messages_with_system[0].at("role") == "system") {
+            std::string existing_system = messages_with_system.at(0).at("content");
+            messages_with_system[0] = json {
+                {"role", "system"},
+                {"content", existing_system + "\n\n" + system_prompt},
+            };
+        } else {
+            messages_with_system.insert(messages_with_system.begin(), json {
+                {"role", "system"},
+                {"content", system_prompt},
+            });
+        }
+        return messages_with_system;
+    }
+};
+
+}  // namespace minja
diff --git a/common/minja/minja.hpp b/common/minja/minja.hpp
new file mode 100644 (file)
index 0000000..c58dd66
--- /dev/null
@@ -0,0 +1,2883 @@
+/*
+    Copyright 2024 Google LLC
+
+    Use of this source code is governed by an MIT-style
+    license that can be found in the LICENSE file or at
+    https://opensource.org/licenses/MIT.
+*/
+// SPDX-License-Identifier: MIT
+#pragma once
+
+#include <iostream>
+#include <string>
+#include <vector>
+#include <regex>
+#include <memory>
+#include <stdexcept>
+#include <sstream>
+#include <unordered_set>
+#include <json.hpp>
+
+using json = nlohmann::ordered_json;
+
+namespace minja {
+
+class Context;
+
+struct Options {
+    bool trim_blocks;  // removes the first newline after a block
+    bool lstrip_blocks;  // removes leading whitespace on the line of the block
+    bool keep_trailing_newline;  // don't remove last newline
+};
+
+struct ArgumentsValue;
+
+inline std::string normalize_newlines(const std::string & s) {
+#ifdef _WIN32
+  static const std::regex nl_regex("\r\n");
+  return std::regex_replace(s, nl_regex, "\n");
+#else
+  return s;
+#endif
+}
+
+/* Values that behave roughly like in Python. */
+class Value : public std::enable_shared_from_this<Value> {
+public:
+  using CallableType = std::function<Value(const std::shared_ptr<Context> &, ArgumentsValue &)>;
+  using FilterType = std::function<Value(const std::shared_ptr<Context> &, ArgumentsValue &)>;
+
+private:
+  using ObjectType = nlohmann::ordered_map<json, Value>;  // Only contains primitive keys
+  using ArrayType = std::vector<Value>;
+
+  std::shared_ptr<ArrayType> array_;
+  std::shared_ptr<ObjectType> object_;
+  std::shared_ptr<CallableType> callable_;
+  json primitive_;
+
+  Value(const std::shared_ptr<ArrayType> & array) : array_(array) {}
+  Value(const std::shared_ptr<ObjectType> & object) : object_(object) {}
+  Value(const std::shared_ptr<CallableType> & callable) : object_(std::make_shared<ObjectType>()), callable_(callable) {}
+
+  /* Python-style string repr */
+  static void dump_string(const json & primitive, std::ostringstream & out, char string_quote = '\'') {
+    if (!primitive.is_string()) throw std::runtime_error("Value is not a string: " + primitive.dump());
+    auto s = primitive.dump();
+    if (string_quote == '"' || s.find('\'') != std::string::npos) {
+      out << s;
+      return;
+    }
+    // Reuse json dump, just changing string quotes
+    out << string_quote;
+    for (size_t i = 1, n = s.size() - 1; i < n; ++i) {
+      if (s[i] == '\\' && s[i + 1] == '"') {
+        out << '"';
+        i++;
+      } else if (s[i] == string_quote) {
+        out << '\\' << string_quote;
+      } else {
+        out << s[i];
+      }
+    }
+    out << string_quote;
+  }
+  void dump(std::ostringstream & out, int indent = -1, int level = 0, bool to_json = false) const {
+    auto print_indent = [&](int level) {
+      if (indent > 0) {
+          out << "\n";
+          for (int i = 0, n = level * indent; i < n; ++i) out << ' ';
+      }
+    };
+    auto print_sub_sep = [&]() {
+      out << ',';
+      if (indent < 0) out << ' ';
+      else print_indent(level + 1);
+    };
+
+    auto string_quote = to_json ? '"' : '\'';
+
+    if (is_null()) out << "null";
+    else if (array_) {
+      out << "[";
+      print_indent(level + 1);
+      for (size_t i = 0; i < array_->size(); ++i) {
+        if (i) print_sub_sep();
+        (*array_)[i].dump(out, indent, level + 1, to_json);
+      }
+      print_indent(level);
+      out << "]";
+    } else if (object_) {
+      out << "{";
+      print_indent(level + 1);
+      for (auto begin = object_->begin(), it = begin; it != object_->end(); ++it) {
+        if (it != begin) print_sub_sep();
+        if (it->first.is_string()) {
+          dump_string(it->first, out, string_quote);
+        } else {
+          out << string_quote << it->first.dump() << string_quote;
+        }
+        out << ": ";
+        it->second.dump(out, indent, level + 1, to_json);
+      }
+      print_indent(level);
+      out << "}";
+    } else if (callable_) {
+      throw std::runtime_error("Cannot dump callable to JSON");
+    } else if (is_boolean() && !to_json) {
+      out << (this->to_bool() ? "True" : "False");
+    } else if (is_string() && !to_json) {
+      dump_string(primitive_, out, string_quote);
+    } else {
+      out << primitive_.dump();
+    }
+  }
+
+public:
+  Value() {}
+  Value(const bool& v) : primitive_(v) {}
+  Value(const int64_t & v) : primitive_(v) {}
+  Value(const double& v) : primitive_(v) {}
+  Value(const std::nullptr_t &) {}
+  Value(const std::string & v) : primitive_(v) {}
+  Value(const char * v) : primitive_(std::string(v)) {}
+
+  Value(const json & v) {
+    if (v.is_object()) {
+      auto object = std::make_shared<ObjectType>();
+      for (auto it = v.begin(); it != v.end(); ++it) {
+        (*object)[it.key()] = it.value();
+      }
+      object_ = std::move(object);
+    } else if (v.is_array()) {
+      auto array = std::make_shared<ArrayType>();
+      for (const auto& item : v) {
+        array->push_back(Value(item));
+      }
+      array_ = array;
+    } else {
+      primitive_ = v;
+    }
+  }
+
+  std::vector<Value> keys() {
+    if (!object_) throw std::runtime_error("Value is not an object: " + dump());
+    std::vector<Value> res;
+    for (const auto& item : *object_) {
+      res.push_back(item.first);
+    }
+    return res;
+  }
+
+  size_t size() const {
+    if (is_object()) return object_->size();
+    if (is_array()) return array_->size();
+    if (is_string()) return primitive_.get<std::string>().length();
+    throw std::runtime_error("Value is not an array or object: " + dump());
+  }
+
+  static Value array(const std::vector<Value> values = {}) {
+    auto array = std::make_shared<ArrayType>();
+    for (const auto& item : values) {
+      array->push_back(item);
+    }
+    return Value(array);
+  }
+  static Value object(const std::shared_ptr<ObjectType> object = std::make_shared<ObjectType>()) {
+    return Value(object);
+  }
+  static Value callable(const CallableType & callable) {
+    return Value(std::make_shared<CallableType>(callable));
+  }
+
+  void insert(size_t index, const Value& v) {
+    if (!array_)
+      throw std::runtime_error("Value is not an array: " + dump());
+    array_->insert(array_->begin() + index, v);
+  }
+  void push_back(const Value& v) {
+    if (!array_)
+      throw std::runtime_error("Value is not an array: " + dump());
+    array_->push_back(v);
+  }
+  Value pop(const Value& index) {
+    if (is_array()) {
+      if (array_->empty())
+        throw std::runtime_error("pop from empty list");
+      if (index.is_null()) {
+        auto ret = array_->back();
+        array_->pop_back();
+        return ret;
+      } else if (!index.is_number_integer()) {
+        throw std::runtime_error("pop index must be an integer: " + index.dump());
+      } else {
+        auto i = index.get<int>();
+        if (i < 0 || i >= static_cast<int>(array_->size()))
+          throw std::runtime_error("pop index out of range: " + index.dump());
+        auto it = array_->begin() + (i < 0 ? array_->size() + i : i);
+        auto ret = *it;
+        array_->erase(it);
+        return ret;
+      }
+    } else if (is_object()) {
+      if (!index.is_hashable())
+        throw std::runtime_error("Unashable type: " + index.dump());
+      auto it = object_->find(index.primitive_);
+      if (it == object_->end())
+        throw std::runtime_error("Key not found: " + index.dump());
+      auto ret = it->second;
+      object_->erase(it);
+      return ret;
+    } else {
+      throw std::runtime_error("Value is not an array or object: " + dump());
+    }
+  }
+  Value get(const Value& key) {
+    if (array_) {
+      if (!key.is_number_integer()) {
+        return Value();
+      }
+      auto index = key.get<int>();
+      return array_->at(index < 0 ? array_->size() + index : index);
+    } else if (object_) {
+      if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
+      auto it = object_->find(key.primitive_);
+      if (it == object_->end()) return Value();
+      return it->second;
+    }
+    return Value();
+  }
+  void set(const Value& key, const Value& value) {
+    if (!object_) throw std::runtime_error("Value is not an object: " + dump());
+    if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
+    (*object_)[key.primitive_] = value;
+  }
+  Value call(const std::shared_ptr<Context> & context, ArgumentsValue & args) const {
+    if (!callable_) throw std::runtime_error("Value is not callable: " + dump());
+    return (*callable_)(context, args);
+  }
+
+  bool is_object() const { return !!object_; }
+  bool is_array() const { return !!array_; }
+  bool is_callable() const { return !!callable_; }
+  bool is_null() const { return !object_ && !array_ && primitive_.is_null() && !callable_; }
+  bool is_boolean() const { return primitive_.is_boolean(); }
+  bool is_number_integer() const { return primitive_.is_number_integer(); }
+  bool is_number_float() const { return primitive_.is_number_float(); }
+  bool is_number() const { return primitive_.is_number(); }
+  bool is_string() const { return primitive_.is_string(); }
+  bool is_iterable() const { return is_array() || is_object() || is_string(); }
+
+  bool is_primitive() const { return !array_ && !object_ && !callable_; }
+  bool is_hashable() const { return is_primitive(); }
+
+  bool empty() const {
+    if (is_null())
+      throw std::runtime_error("Undefined value or reference");
+    if (is_string()) return primitive_.empty();
+    if (is_array()) return array_->empty();
+    if (is_object()) return object_->empty();
+    return false;
+  }
+
+  void for_each(const std::function<void(Value &)> & callback) const {
+    if (is_null())
+      throw std::runtime_error("Undefined value or reference");
+    if (array_) {
+      for (auto& item : *array_) {
+        callback(item);
+      }
+    } else if (object_) {
+      for (auto & item : *object_) {
+        Value key(item.first);
+        callback(key);
+      }
+    } else if (is_string()) {
+      for (char c : primitive_.get<std::string>()) {
+        auto val = Value(std::string(1, c));
+        callback(val);
+      }
+    } else {
+      throw std::runtime_error("Value is not iterable: " + dump());
+    }
+  }
+
+  bool to_bool() const {
+    if (is_null()) return false;
+    if (is_boolean()) return get<bool>();
+    if (is_number()) return get<double>() != 0;
+    if (is_string()) return !get<std::string>().empty();
+    if (is_array()) return !empty();
+    return true;
+  }
+
+  int64_t to_int() const {
+    if (is_null()) return 0;
+    if (is_boolean()) return get<bool>() ? 1 : 0;
+    if (is_number()) return static_cast<int64_t>(get<double>());
+    if (is_string()) {
+      try {
+        return std::stol(get<std::string>());
+      } catch (const std::exception &) {
+        return 0;
+      }
+    }
+    return 0;
+  }
+
+  bool operator<(const Value & other) const {
+    if (is_null())
+      throw std::runtime_error("Undefined value or reference");
+    if (is_number() && other.is_number()) return get<double>() < other.get<double>();
+    if (is_string() && other.is_string()) return get<std::string>() < other.get<std::string>();
+    throw std::runtime_error("Cannot compare values: " + dump() + " < " + other.dump());
+  }
+  bool operator>=(const Value & other) const { return !(*this < other); }
+
+  bool operator>(const Value & other) const {
+    if (is_null())
+      throw std::runtime_error("Undefined value or reference");
+    if (is_number() && other.is_number()) return get<double>() > other.get<double>();
+    if (is_string() && other.is_string()) return get<std::string>() > other.get<std::string>();
+    throw std::runtime_error("Cannot compare values: " + dump() + " > " + other.dump());
+  }
+  bool operator<=(const Value & other) const { return !(*this > other); }
+
+  bool operator==(const Value & other) const {
+    if (callable_ || other.callable_) {
+      if (callable_.get() != other.callable_.get()) return false;
+    }
+    if (array_) {
+      if (!other.array_) return false;
+      if (array_->size() != other.array_->size()) return false;
+      for (size_t i = 0; i < array_->size(); ++i) {
+        if (!(*array_)[i].to_bool() || !(*other.array_)[i].to_bool() || (*array_)[i] != (*other.array_)[i]) return false;
+      }
+      return true;
+    } else if (object_) {
+      if (!other.object_) return false;
+      if (object_->size() != other.object_->size()) return false;
+      for (const auto& item : *object_) {
+        if (!item.second.to_bool() || !other.object_->count(item.first) || item.second != other.object_->at(item.first)) return false;
+      }
+      return true;
+    } else {
+      return primitive_ == other.primitive_;
+    }
+  }
+  bool operator!=(const Value & other) const { return !(*this == other); }
+
+  bool contains(const char * key) const { return contains(std::string(key)); }
+  bool contains(const std::string & key) const {
+    if (array_) {
+      return false;
+    } else if (object_) {
+      return object_->find(key) != object_->end();
+    } else {
+      throw std::runtime_error("contains can only be called on arrays and objects: " + dump());
+    }
+  }
+  bool contains(const Value & value) const {
+    if (is_null())
+      throw std::runtime_error("Undefined value or reference");
+    if (array_) {
+      for (const auto& item : *array_) {
+        if (item.to_bool() && item == value) return true;
+      }
+      return false;
+    } else if (object_) {
+      if (!value.is_hashable()) throw std::runtime_error("Unashable type: " + value.dump());
+      return object_->find(value.primitive_) != object_->end();
+    } else {
+      throw std::runtime_error("contains can only be called on arrays and objects: " + dump());
+    }
+  }
+  void erase(size_t index) {
+    if (!array_) throw std::runtime_error("Value is not an array: " + dump());
+    array_->erase(array_->begin() + index);
+  }
+  void erase(const std::string & key) {
+    if (!object_) throw std::runtime_error("Value is not an object: " + dump());
+    object_->erase(key);
+  }
+  const Value& at(const Value & index) const {
+    return const_cast<Value*>(this)->at(index);
+  }
+  Value& at(const Value & index) {
+    if (!index.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
+    if (is_array()) return array_->at(index.get<int>());
+    if (is_object()) return object_->at(index.primitive_);
+    throw std::runtime_error("Value is not an array or object: " + dump());
+  }
+  const Value& at(size_t index) const {
+    return const_cast<Value*>(this)->at(index);
+  }
+  Value& at(size_t index) {
+    if (is_null())
+      throw std::runtime_error("Undefined value or reference");
+    if (is_array()) return array_->at(index);
+    if (is_object()) return object_->at(index);
+    throw std::runtime_error("Value is not an array or object: " + dump());
+  }
+
+  template <typename T>
+  T get(const std::string & key, T default_value) const {
+    if (!contains(key)) return default_value;
+    return at(key).get<T>();
+  }
+
+  template <typename T>
+  T get() const {
+    if (is_primitive()) return primitive_.get<T>();
+    throw std::runtime_error("get<T> not defined for this value type: " + dump());
+  }
+
+  std::string dump(int indent=-1, bool to_json=false) const {
+    std::ostringstream out;
+    dump(out, indent, 0, to_json);
+    return out.str();
+  }
+
+  Value operator-() const {
+      if (is_number_integer())
+        return -get<int64_t>();
+      else
+        return -get<double>();
+  }
+  std::string to_str() const {
+    if (is_string()) return get<std::string>();
+    if (is_number_integer()) return std::to_string(get<int64_t>());
+    if (is_number_float()) return std::to_string(get<double>());
+    if (is_boolean()) return get<bool>() ? "True" : "False";
+    if (is_null()) return "None";
+    return dump();
+  }
+  Value operator+(const Value& rhs) const {
+      if (is_string() || rhs.is_string()) {
+        return to_str() + rhs.to_str();
+      } else if (is_number_integer() && rhs.is_number_integer()) {
+        return get<int64_t>() + rhs.get<int64_t>();
+      } else if (is_array() && rhs.is_array()) {
+        auto res = Value::array();
+        for (const auto& item : *array_) res.push_back(item);
+        for (const auto& item : *rhs.array_) res.push_back(item);
+        return res;
+      } else {
+        return get<double>() + rhs.get<double>();
+      }
+  }
+  Value operator-(const Value& rhs) const {
+      if (is_number_integer() && rhs.is_number_integer())
+        return get<int64_t>() - rhs.get<int64_t>();
+      else
+        return get<double>() - rhs.get<double>();
+  }
+  Value operator*(const Value& rhs) const {
+      if (is_string() && rhs.is_number_integer()) {
+        std::ostringstream out;
+        for (int64_t i = 0, n = rhs.get<int64_t>(); i < n; ++i) {
+          out << to_str();
+        }
+        return out.str();
+      }
+      else if (is_number_integer() && rhs.is_number_integer())
+        return get<int64_t>() * rhs.get<int64_t>();
+      else
+        return get<double>() * rhs.get<double>();
+  }
+  Value operator/(const Value& rhs) const {
+      if (is_number_integer() && rhs.is_number_integer())
+        return get<int64_t>() / rhs.get<int64_t>();
+      else
+        return get<double>() / rhs.get<double>();
+  }
+  Value operator%(const Value& rhs) const {
+    return get<int64_t>() % rhs.get<int64_t>();
+  }
+};
+
+struct ArgumentsValue {
+  std::vector<Value> args;
+  std::vector<std::pair<std::string, Value>> kwargs;
+
+  bool has_named(const std::string & name) {
+    for (const auto & p : kwargs) {
+      if (p.first == name) return true;
+    }
+    return false;
+  }
+
+  Value get_named(const std::string & name) {
+    for (const auto & [key, value] : kwargs) {
+      if (key == name) return value;
+    }
+    return Value();
+  }
+
+  bool empty() {
+    return args.empty() && kwargs.empty();
+  }
+
+  void expectArgs(const std::string & method_name, const std::pair<size_t, size_t> & pos_count, const std::pair<size_t, size_t> & kw_count) {
+    if (args.size() < pos_count.first || args.size() > pos_count.second || kwargs.size() < kw_count.first || kwargs.size() > kw_count.second) {
+      std::ostringstream out;
+      out << method_name << " must have between " << pos_count.first << " and " << pos_count.second << " positional arguments and between " << kw_count.first << " and " << kw_count.second << " keyword arguments";
+      throw std::runtime_error(out.str());
+    }
+  }
+};
+
+template <>
+inline json Value::get<json>() const {
+  if (is_primitive()) return primitive_;
+  if (is_null()) return json();
+  if (array_) {
+    std::vector<json> res;
+    for (const auto& item : *array_) {
+      res.push_back(item.get<json>());
+    }
+    return res;
+  }
+  if (object_) {
+    json res = json::object();
+    for (const auto& [key, value] : *object_) {
+      if (key.is_string()) {
+        res[key.get<std::string>()] = value.get<json>();
+      } else if (key.is_primitive()) {
+        res[key.dump()] = value.get<json>();
+      } else {
+        throw std::runtime_error("Invalid key type for conversion to JSON: " + key.dump());
+      }
+    }
+    if (is_callable()) {
+      res["__callable__"] = true;
+    }
+    return res;
+  }
+  throw std::runtime_error("get<json> not defined for this value type: " + dump());
+}
+
+} // namespace minja
+
+namespace std {
+  template <>
+  struct hash<minja::Value> {
+    size_t operator()(const minja::Value & v) const {
+      if (!v.is_hashable())
+        throw std::runtime_error("Unsupported type for hashing: " + v.dump());
+      return std::hash<json>()(v.get<json>());
+    }
+  };
+} // namespace std
+
+namespace minja {
+
+static std::string error_location_suffix(const std::string & source, size_t pos) {
+  auto get_line = [&](size_t line) {
+    auto start = source.begin();
+    for (size_t i = 1; i < line; ++i) {
+      start = std::find(start, source.end(), '\n') + 1;
+    }
+    auto end = std::find(start, source.end(), '\n');
+    return std::string(start, end);
+  };
+  auto start = source.begin();
+  auto end = source.end();
+  auto it = start + pos;
+  auto line = std::count(start, it, '\n') + 1;
+  auto max_line = std::count(start, end, '\n') + 1;
+  auto col = pos - std::string(start, it).rfind('\n');
+  std::ostringstream out;
+  out << " at row " << line << ", column " << col << ":\n";
+  if (line > 1) out << get_line(line - 1) << "\n";
+  out << get_line(line) << "\n";
+  out << std::string(col - 1, ' ') << "^\n";
+  if (line < max_line) out << get_line(line + 1) << "\n";
+
+  return out.str();
+}
+
+class Context : public std::enable_shared_from_this<Context> {
+  protected:
+    Value values_;
+    std::shared_ptr<Context> parent_;
+  public:
+    Context(Value && values, const std::shared_ptr<Context> & parent = nullptr) : values_(std::move(values)), parent_(parent) {
+        if (!values_.is_object()) throw std::runtime_error("Context values must be an object: " + values_.dump());
+    }
+    virtual ~Context() {}
+
+    static std::shared_ptr<Context> builtins();
+    static std::shared_ptr<Context> make(Value && values, const std::shared_ptr<Context> & parent = builtins());
+
+    std::vector<Value> keys() {
+        return values_.keys();
+    }
+    virtual Value get(const Value & key) {
+        if (values_.contains(key)) return values_.at(key);
+        if (parent_) return parent_->get(key);
+        return Value();
+    }
+    virtual Value & at(const Value & key) {
+        if (values_.contains(key)) return values_.at(key);
+        if (parent_) return parent_->at(key);
+        throw std::runtime_error("Undefined variable: " + key.dump());
+    }
+    virtual bool contains(const Value & key) {
+        if (values_.contains(key)) return true;
+        if (parent_) return parent_->contains(key);
+        return false;
+    }
+    virtual void set(const Value & key, const Value & value) {
+        values_.set(key, value);
+    }
+};
+
+struct Location {
+    std::shared_ptr<std::string> source;
+    size_t pos;
+};
+
+class Expression {
+protected:
+    virtual Value do_evaluate(const std::shared_ptr<Context> & context) const = 0;
+public:
+    using Parameters = std::vector<std::pair<std::string, std::shared_ptr<Expression>>>;
+
+    Location location;
+
+    Expression(const Location & location) : location(location) {}
+    virtual ~Expression() = default;
+
+    Value evaluate(const std::shared_ptr<Context> & context) const {
+        try {
+            return do_evaluate(context);
+        } catch (const std::exception & e) {
+            std::ostringstream out;
+            out << e.what();
+            if (location.source) out << error_location_suffix(*location.source, location.pos);
+            throw std::runtime_error(out.str());
+        }
+    }
+};
+
+class VariableExpr : public Expression {
+    std::string name;
+public:
+    VariableExpr(const Location & location, const std::string& n)
+      : Expression(location), name(n) {}
+    std::string get_name() const { return name; }
+    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
+        if (!context->contains(name)) {
+            return Value();
+        }
+        return context->at(name);
+    }
+};
+
+static void destructuring_assign(const std::vector<std::string> & var_names, const std::shared_ptr<Context> & context, Value& item) {
+  if (var_names.size() == 1) {
+      Value name(var_names[0]);
+      context->set(name, item);
+  } else {
+      if (!item.is_array() || item.size() != var_names.size()) {
+          throw std::runtime_error("Mismatched number of variables and items in destructuring assignment");
+      }
+      for (size_t i = 0; i < var_names.size(); ++i) {
+          context->set(var_names[i], item.at(i));
+      }
+  }
+}
+
+enum SpaceHandling { Keep, Strip, StripSpaces, StripNewline };
+
+class TemplateToken {
+public:
+    enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter, Break, Continue };
+
+    static std::string typeToString(Type t) {
+        switch (t) {
+            case Type::Text: return "text";
+            case Type::Expression: return "expression";
+            case Type::If: return "if";
+            case Type::Else: return "else";
+            case Type::Elif: return "elif";
+            case Type::EndIf: return "endif";
+            case Type::For: return "for";
+            case Type::EndFor: return "endfor";
+            case Type::Set: return "set";
+            case Type::EndSet: return "endset";
+            case Type::Comment: return "comment";
+            case Type::Macro: return "macro";
+            case Type::EndMacro: return "endmacro";
+            case Type::Filter: return "filter";
+            case Type::EndFilter: return "endfilter";
+            case Type::Generation: return "generation";
+            case Type::EndGeneration: return "endgeneration";
+            case Type::Break: return "break";
+            case Type::Continue: return "continue";
+        }
+        return "Unknown";
+    }
+
+    TemplateToken(Type type, const Location & location, SpaceHandling pre, SpaceHandling post) : type(type), location(location), pre_space(pre), post_space(post) {}
+    virtual ~TemplateToken() = default;
+
+    Type type;
+    Location location;
+    SpaceHandling pre_space = SpaceHandling::Keep;
+    SpaceHandling post_space = SpaceHandling::Keep;
+};
+
+struct TextTemplateToken : public TemplateToken {
+    std::string text;
+    TextTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Text, location, pre, post), text(t) {}
+};
+
+struct ExpressionTemplateToken : public TemplateToken {
+    std::shared_ptr<Expression> expr;
+    ExpressionTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && e) : TemplateToken(Type::Expression, location, pre, post), expr(std::move(e)) {}
+};
+
+struct IfTemplateToken : public TemplateToken {
+    std::shared_ptr<Expression> condition;
+    IfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && c) : TemplateToken(Type::If, location, pre, post), condition(std::move(c)) {}
+};
+
+struct ElifTemplateToken : public TemplateToken {
+    std::shared_ptr<Expression> condition;
+    ElifTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && c) : TemplateToken(Type::Elif, location, pre, post), condition(std::move(c)) {}
+};
+
+struct ElseTemplateToken : public TemplateToken {
+    ElseTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::Else, location, pre, post) {}
+};
+
+struct EndIfTemplateToken : public TemplateToken {
+    EndIfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndIf, location, pre, post) {}
+};
+
+struct MacroTemplateToken : public TemplateToken {
+    std::shared_ptr<VariableExpr> name;
+    Expression::Parameters params;
+    MacroTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<VariableExpr> && n, Expression::Parameters && p)
+      : TemplateToken(Type::Macro, location, pre, post), name(std::move(n)), params(std::move(p)) {}
+};
+
+struct EndMacroTemplateToken : public TemplateToken {
+    EndMacroTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndMacro, location, pre, post) {}
+};
+
+struct FilterTemplateToken : public TemplateToken {
+    std::shared_ptr<Expression> filter;
+    FilterTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && filter)
+      : TemplateToken(Type::Filter, location, pre, post), filter(std::move(filter)) {}
+};
+
+struct EndFilterTemplateToken : public TemplateToken {
+    EndFilterTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndFilter, location, pre, post) {}
+};
+
+struct ForTemplateToken : public TemplateToken {
+    std::vector<std::string> var_names;
+    std::shared_ptr<Expression> iterable;
+    std::shared_ptr<Expression> condition;
+    bool recursive;
+    ForTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::vector<std::string> & vns, std::shared_ptr<Expression> && iter,
+      std::shared_ptr<Expression> && c, bool r)
+      : TemplateToken(Type::For, location, pre, post), var_names(vns), iterable(std::move(iter)), condition(std::move(c)), recursive(r) {}
+};
+
+struct EndForTemplateToken : public TemplateToken {
+    EndForTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndFor, location, pre, post) {}
+};
+
+struct GenerationTemplateToken : public TemplateToken {
+    GenerationTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::Generation, location, pre, post) {}
+};
+
+struct EndGenerationTemplateToken : public TemplateToken {
+    EndGenerationTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndGeneration, location, pre, post) {}
+};
+
+struct SetTemplateToken : public TemplateToken {
+    std::string ns;
+    std::vector<std::string> var_names;
+    std::shared_ptr<Expression> value;
+    SetTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string & ns, const std::vector<std::string> & vns, std::shared_ptr<Expression> && v)
+      : TemplateToken(Type::Set, location, pre, post), ns(ns), var_names(vns), value(std::move(v)) {}
+};
+
+struct EndSetTemplateToken : public TemplateToken {
+    EndSetTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndSet, location, pre, post) {}
+};
+
+struct CommentTemplateToken : public TemplateToken {
+    std::string text;
+    CommentTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Comment, location, pre, post), text(t) {}
+};
+
+enum class LoopControlType { Break, Continue };
+
+class LoopControlException : public std::runtime_error {
+public:
+    LoopControlType control_type;
+    LoopControlException(const std::string & message, LoopControlType control_type) : std::runtime_error(message), control_type(control_type) {}
+    LoopControlException(LoopControlType control_type)
+      : std::runtime_error((control_type == LoopControlType::Continue ? "continue" : "break") + std::string(" outside of a loop")),
+        control_type(control_type) {}
+};
+
+struct LoopControlTemplateToken : public TemplateToken {
+    LoopControlType control_type;
+    LoopControlTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, LoopControlType control_type) : TemplateToken(Type::Break, location, pre, post), control_type(control_type) {}
+};
+
+class TemplateNode {
+    Location location_;
+protected:
+    virtual void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const = 0;
+
+public:
+    TemplateNode(const Location & location) : location_(location) {}
+    void render(std::ostringstream & out, const std::shared_ptr<Context> & context) const {
+        try {
+            do_render(out, context);
+        } catch (const LoopControlException & e) {
+            // TODO: make stack creation lazy. Only needed if it was thrown outside of a loop.
+            std::ostringstream err;
+            err << e.what();
+            if (location_.source) err << error_location_suffix(*location_.source, location_.pos);
+            throw LoopControlException(err.str(), e.control_type);
+        } catch (const std::exception & e) {
+            std::ostringstream err;
+            err << e.what();
+            if (location_.source) err << error_location_suffix(*location_.source, location_.pos);
+            throw std::runtime_error(err.str());
+        }
+    }
+    const Location & location() const { return location_; }
+    virtual ~TemplateNode() = default;
+    std::string render(const std::shared_ptr<Context> & context) const {
+        std::ostringstream out;
+        render(out, context);
+        return out.str();
+    }
+};
+
+class SequenceNode : public TemplateNode {
+    std::vector<std::shared_ptr<TemplateNode>> children;
+public:
+    SequenceNode(const Location & location, std::vector<std::shared_ptr<TemplateNode>> && c)
+      : TemplateNode(location), children(std::move(c)) {}
+    void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
+        for (const auto& child : children) child->render(out, context);
+    }
+};
+
+class TextNode : public TemplateNode {
+    std::string text;
+public:
+    TextNode(const Location & location, const std::string& t) : TemplateNode(location), text(t) {}
+    void do_render(std::ostringstream & out, const std::shared_ptr<Context> &) const override {
+      out << text;
+    }
+};
+
+class ExpressionNode : public TemplateNode {
+    std::shared_ptr<Expression> expr;
+public:
+    ExpressionNode(const Location & location, std::shared_ptr<Expression> && e) : TemplateNode(location), expr(std::move(e)) {}
+    void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
+      if (!expr) throw std::runtime_error("ExpressionNode.expr is null");
+      auto result = expr->evaluate(context);
+      if (result.is_string()) {
+          out << result.get<std::string>();
+      } else if (result.is_boolean()) {
+          out << (result.get<bool>() ? "True" : "False");
+      } else if (!result.is_null()) {
+          out << result.dump();
+      }
+  }
+};
+
+class IfNode : public TemplateNode {
+    std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<TemplateNode>>> cascade;
+public:
+    IfNode(const Location & location, std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<TemplateNode>>> && c)
+        : TemplateNode(location), cascade(std::move(c)) {}
+    void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
+      for (const auto& branch : cascade) {
+          auto enter_branch = true;
+          if (branch.first) {
+            enter_branch = branch.first->evaluate(context).to_bool();
+          }
+          if (enter_branch) {
+            if (!branch.second) throw std::runtime_error("IfNode.cascade.second is null");
+              branch.second->render(out, context);
+              return;
+          }
+      }
+    }
+};
+
+class LoopControlNode : public TemplateNode {
+    LoopControlType control_type_;
+  public:
+    LoopControlNode(const Location & location, LoopControlType control_type) : TemplateNode(location), control_type_(control_type) {}
+    void do_render(std::ostringstream &, const std::shared_ptr<Context> &) const override {
+      throw LoopControlException(control_type_);
+    }
+};
+
+class ForNode : public TemplateNode {
+    std::vector<std::string> var_names;
+    std::shared_ptr<Expression> iterable;
+    std::shared_ptr<Expression> condition;
+    std::shared_ptr<TemplateNode> body;
+    bool recursive;
+    std::shared_ptr<TemplateNode> else_body;
+public:
+    ForNode(const Location & location, std::vector<std::string> && var_names, std::shared_ptr<Expression> && iterable,
+      std::shared_ptr<Expression> && condition, std::shared_ptr<TemplateNode> && body, bool recursive, std::shared_ptr<TemplateNode> && else_body)
+            : TemplateNode(location), var_names(var_names), iterable(std::move(iterable)), condition(std::move(condition)), body(std::move(body)), recursive(recursive), else_body(std::move(else_body)) {}
+
+    void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
+      // https://jinja.palletsprojects.com/en/3.0.x/templates/#for
+      if (!iterable) throw std::runtime_error("ForNode.iterable is null");
+      if (!body) throw std::runtime_error("ForNode.body is null");
+
+      auto iterable_value = iterable->evaluate(context);
+      Value::CallableType loop_function;
+
+      std::function<void(Value&)> visit = [&](Value& iter) {
+          auto filtered_items = Value::array();
+          if (!iter.is_null()) {
+            if (!iterable_value.is_iterable()) {
+              throw std::runtime_error("For loop iterable must be iterable: " + iterable_value.dump());
+            }
+            iterable_value.for_each([&](Value & item) {
+                destructuring_assign(var_names, context, item);
+                if (!condition || condition->evaluate(context).to_bool()) {
+                  filtered_items.push_back(item);
+                }
+            });
+          }
+          if (filtered_items.empty()) {
+            if (else_body) {
+              else_body->render(out, context);
+            }
+          } else {
+              auto loop = recursive ? Value::callable(loop_function) : Value::object();
+              loop.set("length", (int64_t) filtered_items.size());
+
+              size_t cycle_index = 0;
+              loop.set("cycle", Value::callable([&](const std::shared_ptr<Context> &, ArgumentsValue & args) {
+                  if (args.args.empty() || !args.kwargs.empty()) {
+                      throw std::runtime_error("cycle() expects at least 1 positional argument and no named arg");
+                  }
+                  auto item = args.args[cycle_index];
+                  cycle_index = (cycle_index + 1) % args.args.size();
+                  return item;
+              }));
+              auto loop_context = Context::make(Value::object(), context);
+              loop_context->set("loop", loop);
+              for (size_t i = 0, n = filtered_items.size(); i < n; ++i) {
+                  auto & item = filtered_items.at(i);
+                  destructuring_assign(var_names, loop_context, item);
+                  loop.set("index", (int64_t) i + 1);
+                  loop.set("index0", (int64_t) i);
+                  loop.set("revindex", (int64_t) (n - i));
+                  loop.set("revindex0", (int64_t) (n - i - 1));
+                  loop.set("length", (int64_t) n);
+                  loop.set("first", i == 0);
+                  loop.set("last", i == (n - 1));
+                  loop.set("previtem", i > 0 ? filtered_items.at(i - 1) : Value());
+                  loop.set("nextitem", i < n - 1 ? filtered_items.at(i + 1) : Value());
+                  try {
+                      body->render(out, loop_context);
+                  } catch (const LoopControlException & e) {
+                      if (e.control_type == LoopControlType::Break) break;
+                      if (e.control_type == LoopControlType::Continue) continue;
+                  }
+              }
+          }
+      };
+
+      if (recursive) {
+        loop_function = [&](const std::shared_ptr<Context> &, ArgumentsValue & args) {
+            if (args.args.size() != 1 || !args.kwargs.empty() || !args.args[0].is_array()) {
+                throw std::runtime_error("loop() expects exactly 1 positional iterable argument");
+            }
+            auto & items = args.args[0];
+            visit(items);
+            return Value();
+        };
+      }
+
+      visit(iterable_value);
+  }
+};
+
+class MacroNode : public TemplateNode {
+    std::shared_ptr<VariableExpr> name;
+    Expression::Parameters params;
+    std::shared_ptr<TemplateNode> body;
+    std::unordered_map<std::string, size_t> named_param_positions;
+public:
+    MacroNode(const Location & location, std::shared_ptr<VariableExpr> && n, Expression::Parameters && p, std::shared_ptr<TemplateNode> && b)
+        : TemplateNode(location), name(std::move(n)), params(std::move(p)), body(std::move(b)) {
+        for (size_t i = 0; i < params.size(); ++i) {
+          const auto & name = params[i].first;
+          if (!name.empty()) {
+            named_param_positions[name] = i;
+          }
+        }
+    }
+    void do_render(std::ostringstream &, const std::shared_ptr<Context> & macro_context) const override {
+        if (!name) throw std::runtime_error("MacroNode.name is null");
+        if (!body) throw std::runtime_error("MacroNode.body is null");
+        auto callable = Value::callable([&](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
+            auto call_context = macro_context;
+            std::vector<bool> param_set(params.size(), false);
+            for (size_t i = 0, n = args.args.size(); i < n; i++) {
+                auto & arg = args.args[i];
+                if (i >= params.size()) throw std::runtime_error("Too many positional arguments for macro " + name->get_name());
+                param_set[i] = true;
+                auto & param_name = params[i].first;
+                call_context->set(param_name, arg);
+            }
+            for (auto & [arg_name, value] : args.kwargs) {
+                auto it = named_param_positions.find(arg_name);
+                if (it == named_param_positions.end()) throw std::runtime_error("Unknown parameter name for macro " + name->get_name() + ": " + arg_name);
+
+                call_context->set(arg_name, value);
+                param_set[it->second] = true;
+            }
+            // Set default values for parameters that were not passed
+            for (size_t i = 0, n = params.size(); i < n; i++) {
+                if (!param_set[i] && params[i].second != nullptr) {
+                    auto val = params[i].second->evaluate(context);
+                    call_context->set(params[i].first, val);
+                }
+            }
+            return body->render(call_context);
+        });
+        macro_context->set(name->get_name(), callable);
+    }
+};
+
+class FilterNode : public TemplateNode {
+    std::shared_ptr<Expression> filter;
+    std::shared_ptr<TemplateNode> body;
+
+public:
+    FilterNode(const Location & location, std::shared_ptr<Expression> && f, std::shared_ptr<TemplateNode> && b)
+        : TemplateNode(location), filter(std::move(f)), body(std::move(b)) {}
+
+    void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
+        if (!filter) throw std::runtime_error("FilterNode.filter is null");
+        if (!body) throw std::runtime_error("FilterNode.body is null");
+        auto filter_value = filter->evaluate(context);
+        if (!filter_value.is_callable()) {
+            throw std::runtime_error("Filter must be a callable: " + filter_value.dump());
+        }
+        std::string rendered_body = body->render(context);
+
+        ArgumentsValue filter_args = {{Value(rendered_body)}, {}};
+        auto result = filter_value.call(context, filter_args);
+        out << result.to_str();
+    }
+};
+
+class SetNode : public TemplateNode {
+    std::string ns;
+    std::vector<std::string> var_names;
+    std::shared_ptr<Expression> value;
+public:
+    SetNode(const Location & location, const std::string & ns, const std::vector<std::string> & vns, std::shared_ptr<Expression> && v)
+        : TemplateNode(location), ns(ns), var_names(vns), value(std::move(v)) {}
+    void do_render(std::ostringstream &, const std::shared_ptr<Context> & context) const override {
+      if (!value) throw std::runtime_error("SetNode.value is null");
+      if (!ns.empty()) {
+        if (var_names.size() != 1) {
+          throw std::runtime_error("Namespaced set only supports a single variable name");
+        }
+        auto & name = var_names[0];
+        auto ns_value = context->get(ns);
+        if (!ns_value.is_object()) throw std::runtime_error("Namespace '" + ns + "' is not an object");
+        ns_value.set(name, this->value->evaluate(context));
+      } else {
+        auto val = value->evaluate(context);
+        destructuring_assign(var_names, context, val);
+      }
+    }
+};
+
+class SetTemplateNode : public TemplateNode {
+    std::string name;
+    std::shared_ptr<TemplateNode> template_value;
+public:
+    SetTemplateNode(const Location & location, const std::string & name, std::shared_ptr<TemplateNode> && tv)
+        : TemplateNode(location), name(name), template_value(std::move(tv)) {}
+    void do_render(std::ostringstream &, const std::shared_ptr<Context> & context) const override {
+      if (!template_value) throw std::runtime_error("SetTemplateNode.template_value is null");
+      Value value { template_value->render(context) };
+      context->set(name, value);
+    }
+};
+
+class IfExpr : public Expression {
+    std::shared_ptr<Expression> condition;
+    std::shared_ptr<Expression> then_expr;
+    std::shared_ptr<Expression> else_expr;
+public:
+    IfExpr(const Location & location, std::shared_ptr<Expression> && c, std::shared_ptr<Expression> && t, std::shared_ptr<Expression> && e)
+        : Expression(location), condition(std::move(c)), then_expr(std::move(t)), else_expr(std::move(e)) {}
+    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
+      if (!condition) throw std::runtime_error("IfExpr.condition is null");
+      if (!then_expr) throw std::runtime_error("IfExpr.then_expr is null");
+      if (condition->evaluate(context).to_bool()) {
+        return then_expr->evaluate(context);
+      }
+      if (else_expr) {
+        return else_expr->evaluate(context);
+      }
+      return nullptr;
+    }
+};
+
+class LiteralExpr : public Expression {
+    Value value;
+public:
+    LiteralExpr(const Location & location, const Value& v)
+      : Expression(location), value(v) {}
+    Value do_evaluate(const std::shared_ptr<Context> &) const override { return value; }
+};
+
+class ArrayExpr : public Expression {
+    std::vector<std::shared_ptr<Expression>> elements;
+public:
+    ArrayExpr(const Location & location, std::vector<std::shared_ptr<Expression>> && e)
+      : Expression(location), elements(std::move(e)) {}
+    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
+        auto result = Value::array();
+        for (const auto& e : elements) {
+            if (!e) throw std::runtime_error("Array element is null");
+            result.push_back(e->evaluate(context));
+        }
+        return result;
+    }
+};
+
+class DictExpr : public Expression {
+    std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>>> elements;
+public:
+    DictExpr(const Location & location, std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>>> && e)
+      : Expression(location), elements(std::move(e)) {}
+    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
+        auto result = Value::object();
+        for (const auto& [key, value] : elements) {
+            if (!key) throw std::runtime_error("Dict key is null");
+            if (!value) throw std::runtime_error("Dict value is null");
+            result.set(key->evaluate(context), value->evaluate(context));
+        }
+        return result;
+    }
+};
+
+class SliceExpr : public Expression {
+public:
+    std::shared_ptr<Expression> start, end;
+    SliceExpr(const Location & location, std::shared_ptr<Expression> && s, std::shared_ptr<Expression> && e)
+      : Expression(location), start(std::move(s)), end(std::move(e)) {}
+    Value do_evaluate(const std::shared_ptr<Context> &) const override {
+        throw std::runtime_error("SliceExpr not implemented");
+    }
+};
+
+class SubscriptExpr : public Expression {
+    std::shared_ptr<Expression> base;
+    std::shared_ptr<Expression> index;
+public:
+    SubscriptExpr(const Location & location, std::shared_ptr<Expression> && b, std::shared_ptr<Expression> && i)
+        : Expression(location), base(std::move(b)), index(std::move(i)) {}
+    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
+        if (!base) throw std::runtime_error("SubscriptExpr.base is null");
+        if (!index) throw std::runtime_error("SubscriptExpr.index is null");
+        auto target_value = base->evaluate(context);
+        if (auto slice = dynamic_cast<SliceExpr*>(index.get())) {
+          auto start = slice->start ? slice->start->evaluate(context).get<int64_t>() : 0;
+          auto end = slice->end ? slice->end->evaluate(context).get<int64_t>() : (int64_t) target_value.size();
+          if (target_value.is_string()) {
+            std::string s = target_value.get<std::string>();
+            if (start < 0) start = s.size() + start;
+            if (end < 0) end = s.size() + end;
+            return s.substr(start, end - start);
+          } else if (target_value.is_array()) {
+            if (start < 0) start = target_value.size() + start;
+            if (end < 0) end = target_value.size() + end;
+            auto result = Value::array();
+            for (auto i = start; i < end; ++i) {
+              result.push_back(target_value.at(i));
+            }
+            return result;
+          } else {
+            throw std::runtime_error(target_value.is_null() ? "Cannot subscript null" : "Subscripting only supported on arrays and strings");
+          }
+        } else {
+          auto index_value = index->evaluate(context);
+          if (target_value.is_null()) {
+            if (auto t = dynamic_cast<VariableExpr*>(base.get())) {
+              throw std::runtime_error("'" + t->get_name() + "' is " + (context->contains(t->get_name()) ? "null" : "not defined"));
+            }
+            throw std::runtime_error("Trying to access property '" +  index_value.dump() + "' on null!");
+          }
+          return target_value.get(index_value);
+        }
+    }
+};
+
+class UnaryOpExpr : public Expression {
+public:
+    enum class Op { Plus, Minus, LogicalNot, Expansion, ExpansionDict };
+    std::shared_ptr<Expression> expr;
+    Op op;
+    UnaryOpExpr(const Location & location, std::shared_ptr<Expression> && e, Op o)
+      : Expression(location), expr(std::move(e)), op(o) {}
+    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
+        if (!expr) throw std::runtime_error("UnaryOpExpr.expr is null");
+        auto e = expr->evaluate(context);
+        switch (op) {
+            case Op::Plus: return e;
+            case Op::Minus: return -e;
+            case Op::LogicalNot: return !e.to_bool();
+            case Op::Expansion:
+            case Op::ExpansionDict:
+                throw std::runtime_error("Expansion operator is only supported in function calls and collections");
+
+        }
+        throw std::runtime_error("Unknown unary operator");
+    }
+};
+
+class BinaryOpExpr : public Expression {
+public:
+    enum class Op { StrConcat, Add, Sub, Mul, MulMul, Div, DivDiv, Mod, Eq, Ne, Lt, Gt, Le, Ge, And, Or, In, NotIn, Is, IsNot };
+private:
+    std::shared_ptr<Expression> left;
+    std::shared_ptr<Expression> right;
+    Op op;
+public:
+    BinaryOpExpr(const Location & location, std::shared_ptr<Expression> && l, std::shared_ptr<Expression> && r, Op o)
+        : Expression(location), left(std::move(l)), right(std::move(r)), op(o) {}
+    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
+        if (!left) throw std::runtime_error("BinaryOpExpr.left is null");
+        if (!right) throw std::runtime_error("BinaryOpExpr.right is null");
+        auto l = left->evaluate(context);
+
+        auto do_eval = [&](const Value & l) -> Value {
+          if (op == Op::Is || op == Op::IsNot) {
+            auto t = dynamic_cast<VariableExpr*>(right.get());
+            if (!t) throw std::runtime_error("Right side of 'is' operator must be a variable");
+
+            auto eval = [&]() {
+              const auto & name = t->get_name();
+              if (name == "none") return l.is_null();
+              if (name == "boolean") return l.is_boolean();
+              if (name == "integer") return l.is_number_integer();
+              if (name == "float") return l.is_number_float();
+              if (name == "number") return l.is_number();
+              if (name == "string") return l.is_string();
+              if (name == "mapping") return l.is_object();
+              if (name == "iterable") return l.is_iterable();
+              if (name == "sequence") return l.is_array();
+              if (name == "defined") return !l.is_null();
+              throw std::runtime_error("Unknown type for 'is' operator: " + name);
+            };
+            auto value = eval();
+            return Value(op == Op::Is ? value : !value);
+          }
+
+          if (op == Op::And) {
+            if (!l.to_bool()) return Value(false);
+            return right->evaluate(context).to_bool();
+          } else if (op == Op::Or) {
+            if (l.to_bool()) return l;
+            return right->evaluate(context);
+          }
+
+          auto r = right->evaluate(context);
+          switch (op) {
+              case Op::StrConcat: return l.to_str() + r.to_str();
+              case Op::Add:       return l + r;
+              case Op::Sub:       return l - r;
+              case Op::Mul:       return l * r;
+              case Op::Div:       return l / r;
+              case Op::MulMul:    return std::pow(l.get<double>(), r.get<double>());
+              case Op::DivDiv:    return l.get<int64_t>() / r.get<int64_t>();
+              case Op::Mod:       return l.get<int64_t>() % r.get<int64_t>();
+              case Op::Eq:        return l == r;
+              case Op::Ne:        return l != r;
+              case Op::Lt:        return l < r;
+              case Op::Gt:        return l > r;
+              case Op::Le:        return l <= r;
+              case Op::Ge:        return l >= r;
+              case Op::In:        return (r.is_array() || r.is_object()) && r.contains(l);
+              case Op::NotIn:     return !(r.is_array() && r.contains(l));
+              default:            break;
+          }
+          throw std::runtime_error("Unknown binary operator");
+        };
+
+        if (l.is_callable()) {
+          return Value::callable([l, do_eval](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
+            auto ll = l.call(context, args);
+            return do_eval(ll); //args[0].second);
+          });
+        } else {
+          return do_eval(l);
+        }
+    }
+};
+
+struct ArgumentsExpression {
+    std::vector<std::shared_ptr<Expression>> args;
+    std::vector<std::pair<std::string, std::shared_ptr<Expression>>> kwargs;
+
+    ArgumentsValue evaluate(const std::shared_ptr<Context> & context) const {
+        ArgumentsValue vargs;
+        for (const auto& arg : this->args) {
+            if (auto un_expr = std::dynamic_pointer_cast<UnaryOpExpr>(arg)) {
+                if (un_expr->op == UnaryOpExpr::Op::Expansion) {
+                    auto array = un_expr->expr->evaluate(context);
+                    if (!array.is_array()) {
+                        throw std::runtime_error("Expansion operator only supported on arrays");
+                    }
+                    array.for_each([&](Value & value) {
+                        vargs.args.push_back(value);
+                    });
+                    continue;
+                } else if (un_expr->op == UnaryOpExpr::Op::ExpansionDict) {
+                    auto dict = un_expr->expr->evaluate(context);
+                    if (!dict.is_object()) {
+                        throw std::runtime_error("ExpansionDict operator only supported on objects");
+                    }
+                    dict.for_each([&](const Value & key) {
+                        vargs.kwargs.push_back({key.get<std::string>(), dict.at(key)});
+                    });
+                    continue;
+                }
+            }
+            vargs.args.push_back(arg->evaluate(context));
+        }
+        for (const auto& [name, value] : this->kwargs) {
+            vargs.kwargs.push_back({name, value->evaluate(context)});
+        }
+        return vargs;
+    }
+};
+
+static std::string strip(const std::string & s) {
+  auto start = s.find_first_not_of(" \t\n\r");
+  if (start == std::string::npos) return "";
+  auto end = s.find_last_not_of(" \t\n\r");
+  return s.substr(start, end - start + 1);
+}
+
+static std::string capitalize(const std::string & s) {
+  if (s.empty()) return s;
+  auto result = s;
+  result[0] = std::toupper(result[0]);
+  return result;
+}
+
+static std::string html_escape(const std::string & s) {
+  std::string result;
+  result.reserve(s.size());
+  for (const auto & c : s) {
+    switch (c) {
+      case '&': result += "&amp;"; break;
+      case '<': result += "&lt;"; break;
+      case '>': result += "&gt;"; break;
+      case '"': result += "&#34;"; break;
+      case '\'': result += "&apos;"; break;
+      default: result += c; break;
+    }
+  }
+  return result;
+}
+
+class MethodCallExpr : public Expression {
+    std::shared_ptr<Expression> object;
+    std::shared_ptr<VariableExpr> method;
+    ArgumentsExpression args;
+public:
+    MethodCallExpr(const Location & location, std::shared_ptr<Expression> && obj, std::shared_ptr<VariableExpr> && m, ArgumentsExpression && a)
+        : Expression(location), object(std::move(obj)), method(std::move(m)), args(std::move(a)) {}
+    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
+        if (!object) throw std::runtime_error("MethodCallExpr.object is null");
+        if (!method) throw std::runtime_error("MethodCallExpr.method is null");
+        auto obj = object->evaluate(context);
+        auto vargs = args.evaluate(context);
+        if (obj.is_null()) {
+          throw std::runtime_error("Trying to call method '" + method->get_name() + "' on null");
+        }
+        if (obj.is_array()) {
+          if (method->get_name() == "append") {
+              vargs.expectArgs("append method", {1, 1}, {0, 0});
+              obj.push_back(vargs.args[0]);
+              return Value();
+          } else if (method->get_name() == "pop") {
+              vargs.expectArgs("pop method", {0, 1}, {0, 0});
+              return obj.pop(vargs.args.empty() ? Value() : vargs.args[0]);
+          } else if (method->get_name() == "insert") {
+              vargs.expectArgs("insert method", {2, 2}, {0, 0});
+              auto index = vargs.args[0].get<int64_t>();
+              if (index < 0 || index > (int64_t) obj.size()) throw std::runtime_error("Index out of range for insert method");
+              obj.insert(index, vargs.args[1]);
+              return Value();
+          }
+        } else if (obj.is_object()) {
+          if (method->get_name() == "items") {
+            vargs.expectArgs("items method", {0, 0}, {0, 0});
+            auto result = Value::array();
+            for (const auto& key : obj.keys()) {
+              result.push_back(Value::array({key, obj.at(key)}));
+            }
+            return result;
+          } else if (method->get_name() == "pop") {
+            vargs.expectArgs("pop method", {1, 1}, {0, 0});
+            return obj.pop(vargs.args[0]);
+          } else if (method->get_name() == "get") {
+            vargs.expectArgs("get method", {1, 2}, {0, 0});
+            auto key = vargs.args[0];
+            if (vargs.args.size() == 1) {
+              return obj.contains(key) ? obj.at(key) : Value();
+            } else {
+              return obj.contains(key) ? obj.at(key) : vargs.args[1];
+            }
+          } else if (obj.contains(method->get_name())) {
+            auto callable = obj.at(method->get_name());
+            if (!callable.is_callable()) {
+              throw std::runtime_error("Property '" + method->get_name() + "' is not callable");
+            }
+            return callable.call(context, vargs);
+          }
+        } else if (obj.is_string()) {
+          auto str = obj.get<std::string>();
+          if (method->get_name() == "strip") {
+            vargs.expectArgs("strip method", {0, 0}, {0, 0});
+            return Value(strip(str));
+          } else if (method->get_name() == "capitalize") {
+            vargs.expectArgs("capitalize method", {0, 0}, {0, 0});
+            return Value(capitalize(str));
+          } else if (method->get_name() == "endswith") {
+            vargs.expectArgs("endswith method", {1, 1}, {0, 0});
+            auto suffix = vargs.args[0].get<std::string>();
+            return suffix.length() <= str.length() && std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());
+          } else if (method->get_name() == "title") {
+            vargs.expectArgs("title method", {0, 0}, {0, 0});
+            auto res = str;
+            for (size_t i = 0, n = res.size(); i < n; ++i) {
+              if (i == 0 || std::isspace(res[i - 1])) res[i] = std::toupper(res[i]);
+              else res[i] = std::tolower(res[i]);
+            }
+            return res;
+          }
+        }
+        throw std::runtime_error("Unknown method: " + method->get_name());
+    }
+};
+
+class CallExpr : public Expression {
+public:
+    std::shared_ptr<Expression> object;
+    ArgumentsExpression args;
+    CallExpr(const Location & location, std::shared_ptr<Expression> && obj, ArgumentsExpression && a)
+        : Expression(location), object(std::move(obj)), args(std::move(a)) {}
+    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
+        if (!object) throw std::runtime_error("CallExpr.object is null");
+        auto obj = object->evaluate(context);
+        if (!obj.is_callable()) {
+          throw std::runtime_error("Object is not callable: " + obj.dump(2));
+        }
+        auto vargs = args.evaluate(context);
+        return obj.call(context, vargs);
+    }
+};
+
+class FilterExpr : public Expression {
+    std::vector<std::shared_ptr<Expression>> parts;
+public:
+    FilterExpr(const Location & location, std::vector<std::shared_ptr<Expression>> && p)
+      : Expression(location), parts(std::move(p)) {}
+    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
+        Value result;
+        bool first = true;
+        for (const auto& part : parts) {
+          if (!part) throw std::runtime_error("FilterExpr.part is null");
+          if (first) {
+            first = false;
+            result = part->evaluate(context);
+          } else {
+            if (auto ce = dynamic_cast<CallExpr*>(part.get())) {
+              auto target = ce->object->evaluate(context);
+              ArgumentsValue args = ce->args.evaluate(context);
+              args.args.insert(args.args.begin(), result);
+              result = target.call(context, args);
+            } else {
+              auto callable = part->evaluate(context);
+              ArgumentsValue args;
+              args.args.insert(args.args.begin(), result);
+              result = callable.call(context, args);
+            }
+          }
+        }
+        return result;
+    }
+
+    void prepend(std::shared_ptr<Expression> && e) {
+        parts.insert(parts.begin(), std::move(e));
+    }
+};
+
+class Parser {
+private:
+    using CharIterator = std::string::const_iterator;
+
+    std::shared_ptr<std::string> template_str;
+    CharIterator start, end, it;
+    Options options;
+
+    Parser(const std::shared_ptr<std::string>& template_str, const Options & options) : template_str(template_str), options(options) {
+      if (!template_str) throw std::runtime_error("Template string is null");
+      start = it = this->template_str->begin();
+      end = this->template_str->end();
+    }
+
+    bool consumeSpaces(SpaceHandling space_handling = SpaceHandling::Strip) {
+      if (space_handling == SpaceHandling::Strip) {
+        while (it != end && std::isspace(*it)) ++it;
+      }
+      return true;
+    }
+
+    std::unique_ptr<std::string> parseString() {
+      auto doParse = [&](char quote) -> std::unique_ptr<std::string> {
+        if (it == end || *it != quote) return nullptr;
+        std::string result;
+        bool escape = false;
+        for (++it; it != end; ++it) {
+          if (escape) {
+            escape = false;
+            switch (*it) {
+              case 'n': result += '\n'; break;
+              case 'r': result += '\r'; break;
+              case 't': result += '\t'; break;
+              case 'b': result += '\b'; break;
+              case 'f': result += '\f'; break;
+              case '\\': result += '\\'; break;
+              default:
+                if (*it == quote) {
+                  result += quote;
+                } else {
+                  result += *it;
+                }
+                break;
+            }
+          } else if (*it == '\\') {
+            escape = true;
+          } else if (*it == quote) {
+              ++it;
+            return std::make_unique<std::string>(std::move(result));
+          } else {
+            result += *it;
+          }
+        }
+        return nullptr;
+      };
+
+      consumeSpaces();
+      if (it == end) return nullptr;
+      if (*it == '"') return doParse('"');
+      if (*it == '\'') return doParse('\'');
+      return nullptr;
+    }
+
+    json parseNumber(CharIterator& it, const CharIterator& end) {
+        auto before = it;
+        consumeSpaces();
+        auto start = it;
+        bool hasDecimal = false;
+        bool hasExponent = false;
+
+        if (it != end && (*it == '-' || *it == '+')) ++it;
+
+        while (it != end) {
+          if (std::isdigit(*it)) {
+            ++it;
+          } else if (*it == '.') {
+            if (hasDecimal) throw std::runtime_error("Multiple decimal points");
+            hasDecimal = true;
+            ++it;
+          } else if (it != start && (*it == 'e' || *it == 'E')) {
+            if (hasExponent) throw std::runtime_error("Multiple exponents");
+            hasExponent = true;
+            ++it;
+          } else {
+            break;
+          }
+        }
+        if (start == it) {
+          it = before;
+          return json(); // No valid characters found
+        }
+
+        std::string str(start, it);
+        try {
+          return json::parse(str);
+        } catch (json::parse_error& e) {
+          throw std::runtime_error("Failed to parse number: '" + str + "' (" + std::string(e.what()) + ")");
+          return json();
+        }
+    }
+
+    /** integer, float, bool, string */
+    std::shared_ptr<Value> parseConstant() {
+      auto start = it;
+      consumeSpaces();
+      if (it == end) return nullptr;
+      if (*it == '"' || *it == '\'') {
+        auto str = parseString();
+        if (str) return std::make_shared<Value>(*str);
+      }
+      static std::regex prim_tok(R"(true\b|True\b|false\b|False\b|None\b)");
+      auto token = consumeToken(prim_tok);
+      if (!token.empty()) {
+        if (token == "true" || token == "True") return std::make_shared<Value>(true);
+        if (token == "false" || token == "False") return std::make_shared<Value>(false);
+        if (token == "None") return std::make_shared<Value>(nullptr);
+        throw std::runtime_error("Unknown constant token: " + token);
+      }
+
+      auto number = parseNumber(it, end);
+      if (!number.is_null()) return std::make_shared<Value>(number);
+
+      it = start;
+      return nullptr;
+    }
+
+    class expression_parsing_error : public std::runtime_error {
+        const CharIterator it;
+      public:
+        expression_parsing_error(const std::string & message, const CharIterator & it)
+            : std::runtime_error(message), it(it) {}
+        size_t get_pos(const CharIterator & begin) const {
+            return std::distance(begin, it);
+      }
+    };
+
+    bool peekSymbols(const std::vector<std::string> & symbols) const {
+        for (const auto & symbol : symbols) {
+            if (std::distance(it, end) >= (int64_t) symbol.size() && std::string(it, it + symbol.size()) == symbol) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    std::vector<std::string> consumeTokenGroups(const std::regex & regex, SpaceHandling space_handling = SpaceHandling::Strip) {
+        auto start = it;
+        consumeSpaces(space_handling);
+        std::smatch match;
+        if (std::regex_search(it, end, match, regex) && match.position() == 0) {
+            it += match[0].length();
+            std::vector<std::string> ret;
+            for (size_t i = 0, n = match.size(); i < n; ++i) {
+                ret.push_back(match[i].str());
+            }
+            return ret;
+        }
+        it = start;
+        return {};
+    }
+    std::string consumeToken(const std::regex & regex, SpaceHandling space_handling = SpaceHandling::Strip) {
+        auto start = it;
+        consumeSpaces(space_handling);
+        std::smatch match;
+        if (std::regex_search(it, end, match, regex) && match.position() == 0) {
+            it += match[0].length();
+            return match[0].str();
+        }
+        it = start;
+        return "";
+    }
+
+    std::string consumeToken(const std::string & token, SpaceHandling space_handling = SpaceHandling::Strip) {
+        auto start = it;
+        consumeSpaces(space_handling);
+        if (std::distance(it, end) >= (int64_t) token.size() && std::string(it, it + token.size()) == token) {
+            it += token.size();
+            return token;
+        }
+        it = start;
+        return "";
+    }
+
+    std::shared_ptr<Expression> parseExpression(bool allow_if_expr = true) {
+        auto left = parseLogicalOr();
+        if (it == end) return left;
+
+        if (!allow_if_expr) return left;
+
+        static std::regex if_tok(R"(if\b)");
+        if (consumeToken(if_tok).empty()) {
+          return left;
+        }
+
+        auto location = get_location();
+        auto [condition, else_expr] = parseIfExpression();
+        return std::make_shared<IfExpr>(location, std::move(condition), std::move(left), std::move(else_expr));
+    }
+
+    Location get_location() const {
+        return {template_str, (size_t) std::distance(start, it)};
+    }
+
+    std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>> parseIfExpression() {
+        auto condition = parseLogicalOr();
+        if (!condition) throw std::runtime_error("Expected condition expression");
+
+        static std::regex else_tok(R"(else\b)");
+        std::shared_ptr<Expression> else_expr;
+        if (!consumeToken(else_tok).empty()) {
+          else_expr = parseExpression();
+          if (!else_expr) throw std::runtime_error("Expected 'else' expression");
+        }
+        return std::pair(std::move(condition), std::move(else_expr));
+    }
+
+    std::shared_ptr<Expression> parseLogicalOr() {
+        auto left = parseLogicalAnd();
+        if (!left) throw std::runtime_error("Expected left side of 'logical or' expression");
+
+        static std::regex or_tok(R"(or\b)");
+        auto location = get_location();
+        while (!consumeToken(or_tok).empty()) {
+            auto right = parseLogicalAnd();
+            if (!right) throw std::runtime_error("Expected right side of 'or' expression");
+            left = std::make_shared<BinaryOpExpr>(location, std::move(left), std::move(right), BinaryOpExpr::Op::Or);
+        }
+        return left;
+    }
+
+    std::shared_ptr<Expression> parseLogicalNot() {
+        static std::regex not_tok(R"(not\b)");
+        auto location = get_location();
+
+        if (!consumeToken(not_tok).empty()) {
+          auto sub = parseLogicalNot();
+          if (!sub) throw std::runtime_error("Expected expression after 'not' keyword");
+          return std::make_shared<UnaryOpExpr>(location, std::move(sub), UnaryOpExpr::Op::LogicalNot);
+        }
+        return parseLogicalCompare();
+    }
+
+    std::shared_ptr<Expression> parseLogicalAnd() {
+        auto left = parseLogicalNot();
+        if (!left) throw std::runtime_error("Expected left side of 'logical and' expression");
+
+        static std::regex and_tok(R"(and\b)");
+        auto location = get_location();
+        while (!consumeToken(and_tok).empty()) {
+            auto right = parseLogicalNot();
+            if (!right) throw std::runtime_error("Expected right side of 'and' expression");
+            left = std::make_shared<BinaryOpExpr>(location, std::move(left), std::move(right), BinaryOpExpr::Op::And);
+        }
+        return left;
+    }
+
+    std::shared_ptr<Expression> parseLogicalCompare() {
+        auto left = parseStringConcat();
+        if (!left) throw std::runtime_error("Expected left side of 'logical compare' expression");
+
+        static std::regex compare_tok(R"(==|!=|<=?|>=?|in\b|is\b|not\s+in\b)");
+        static std::regex not_tok(R"(not\b)");
+        std::string op_str;
+        while (!(op_str = consumeToken(compare_tok)).empty()) {
+            auto location = get_location();
+            if (op_str == "is") {
+              auto negated = !consumeToken(not_tok).empty();
+
+              auto identifier = parseIdentifier();
+              if (!identifier) throw std::runtime_error("Expected identifier after 'is' keyword");
+
+              return std::make_shared<BinaryOpExpr>(
+                  left->location,
+                  std::move(left), std::move(identifier),
+                  negated ? BinaryOpExpr::Op::IsNot : BinaryOpExpr::Op::Is);
+            }
+            auto right = parseStringConcat();
+            if (!right) throw std::runtime_error("Expected right side of 'logical compare' expression");
+            BinaryOpExpr::Op op;
+            if (op_str == "==") op = BinaryOpExpr::Op::Eq;
+            else if (op_str == "!=") op = BinaryOpExpr::Op::Ne;
+            else if (op_str == "<") op = BinaryOpExpr::Op::Lt;
+            else if (op_str == ">") op = BinaryOpExpr::Op::Gt;
+            else if (op_str == "<=") op = BinaryOpExpr::Op::Le;
+            else if (op_str == ">=") op = BinaryOpExpr::Op::Ge;
+            else if (op_str == "in") op = BinaryOpExpr::Op::In;
+            else if (op_str.substr(0, 3) == "not") op = BinaryOpExpr::Op::NotIn;
+            else throw std::runtime_error("Unknown comparison operator: " + op_str);
+            left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), op);
+        }
+        return left;
+    }
+
+    Expression::Parameters parseParameters() {
+        consumeSpaces();
+        if (consumeToken("(").empty()) throw std::runtime_error("Expected opening parenthesis in param list");
+
+        Expression::Parameters result;
+
+        while (it != end) {
+            if (!consumeToken(")").empty()) {
+                return result;
+            }
+            auto expr = parseExpression();
+            if (!expr) throw std::runtime_error("Expected expression in call args");
+
+            if (auto ident = dynamic_cast<VariableExpr*>(expr.get())) {
+                if (!consumeToken("=").empty()) {
+                    auto value = parseExpression();
+                    if (!value) throw std::runtime_error("Expected expression in for named arg");
+                    result.emplace_back(ident->get_name(), std::move(value));
+                } else {
+                    result.emplace_back(ident->get_name(), nullptr);
+                }
+            } else {
+                result.emplace_back(std::string(), std::move(expr));
+            }
+            if (consumeToken(",").empty()) {
+              if (consumeToken(")").empty()) {
+                throw std::runtime_error("Expected closing parenthesis in call args");
+              }
+              return result;
+            }
+        }
+        throw std::runtime_error("Expected closing parenthesis in call args");
+    }
+
+    ArgumentsExpression parseCallArgs() {
+        consumeSpaces();
+        if (consumeToken("(").empty()) throw std::runtime_error("Expected opening parenthesis in call args");
+
+        ArgumentsExpression result;
+
+        while (it != end) {
+            if (!consumeToken(")").empty()) {
+                return result;
+            }
+            auto expr = parseExpression();
+            if (!expr) throw std::runtime_error("Expected expression in call args");
+
+            if (auto ident = dynamic_cast<VariableExpr*>(expr.get())) {
+                if (!consumeToken("=").empty()) {
+                    auto value = parseExpression();
+                    if (!value) throw std::runtime_error("Expected expression in for named arg");
+                    result.kwargs.emplace_back(ident->get_name(), std::move(value));
+                } else {
+                    result.args.emplace_back(std::move(expr));
+                }
+            } else {
+                result.args.emplace_back(std::move(expr));
+            }
+            if (consumeToken(",").empty()) {
+              if (consumeToken(")").empty()) {
+                throw std::runtime_error("Expected closing parenthesis in call args");
+              }
+              return result;
+            }
+        }
+        throw std::runtime_error("Expected closing parenthesis in call args");
+    }
+
+    std::shared_ptr<VariableExpr> parseIdentifier() {
+        static std::regex ident_regex(R"((?!(?:not|is|and|or|del)\b)[a-zA-Z_]\w*)");
+        auto location = get_location();
+        auto ident = consumeToken(ident_regex);
+        if (ident.empty())
+          return nullptr;
+        return std::make_shared<VariableExpr>(location, ident);
+    }
+
+    std::shared_ptr<Expression> parseStringConcat() {
+        auto left = parseMathPow();
+        if (!left) throw std::runtime_error("Expected left side of 'string concat' expression");
+
+        static std::regex concat_tok(R"(~(?!\}))");
+        if (!consumeToken(concat_tok).empty()) {
+            auto right = parseLogicalAnd();
+            if (!right) throw std::runtime_error("Expected right side of 'string concat' expression");
+            left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), BinaryOpExpr::Op::StrConcat);
+        }
+        return left;
+    }
+
+    std::shared_ptr<Expression> parseMathPow() {
+        auto left = parseMathPlusMinus();
+        if (!left) throw std::runtime_error("Expected left side of 'math pow' expression");
+
+        while (!consumeToken("**").empty()) {
+            auto right = parseMathPlusMinus();
+            if (!right) throw std::runtime_error("Expected right side of 'math pow' expression");
+            left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), BinaryOpExpr::Op::MulMul);
+        }
+        return left;
+    }
+
+    std::shared_ptr<Expression> parseMathPlusMinus() {
+        static std::regex plus_minus_tok(R"(\+|-(?![}%#]\}))");
+
+        auto left = parseMathMulDiv();
+        if (!left) throw std::runtime_error("Expected left side of 'math plus/minus' expression");
+        std::string op_str;
+        while (!(op_str = consumeToken(plus_minus_tok)).empty()) {
+            auto right = parseMathMulDiv();
+            if (!right) throw std::runtime_error("Expected right side of 'math plus/minus' expression");
+            auto op = op_str == "+" ? BinaryOpExpr::Op::Add : BinaryOpExpr::Op::Sub;
+            left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), op);
+        }
+        return left;
+    }
+
+    std::shared_ptr<Expression> parseMathMulDiv() {
+        auto left = parseMathUnaryPlusMinus();
+        if (!left) throw std::runtime_error("Expected left side of 'math mul/div' expression");
+
+        static std::regex mul_div_tok(R"(\*\*?|//?|%(?!\}))");
+        std::string op_str;
+        while (!(op_str = consumeToken(mul_div_tok)).empty()) {
+            auto right = parseMathUnaryPlusMinus();
+            if (!right) throw std::runtime_error("Expected right side of 'math mul/div' expression");
+            auto op = op_str == "*" ? BinaryOpExpr::Op::Mul
+                : op_str == "**" ? BinaryOpExpr::Op::MulMul
+                : op_str == "/" ? BinaryOpExpr::Op::Div
+                : op_str == "//" ? BinaryOpExpr::Op::DivDiv
+                : BinaryOpExpr::Op::Mod;
+            left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), op);
+        }
+
+        if (!consumeToken("|").empty()) {
+            auto expr = parseMathMulDiv();
+            if (auto filter = dynamic_cast<FilterExpr*>(expr.get())) {
+                filter->prepend(std::move(left));
+                return expr;
+            } else {
+                std::vector<std::shared_ptr<Expression>> parts;
+                parts.emplace_back(std::move(left));
+                parts.emplace_back(std::move(expr));
+                return std::make_shared<FilterExpr>(get_location(), std::move(parts));
+            }
+        }
+        return left;
+    }
+
+    std::shared_ptr<Expression> call_func(const std::string & name, ArgumentsExpression && args) const {
+        return std::make_shared<CallExpr>(get_location(), std::make_shared<VariableExpr>(get_location(), name), std::move(args));
+    }
+
+    std::shared_ptr<Expression> parseMathUnaryPlusMinus() {
+        static std::regex unary_plus_minus_tok(R"(\+|-(?![}%#]\}))");
+        auto op_str = consumeToken(unary_plus_minus_tok);
+        auto expr = parseExpansion();
+        if (!expr) throw std::runtime_error("Expected expr of 'unary plus/minus/expansion' expression");
+
+        if (!op_str.empty()) {
+            auto op = op_str == "+" ? UnaryOpExpr::Op::Plus : UnaryOpExpr::Op::Minus;
+            return std::make_shared<UnaryOpExpr>(get_location(), std::move(expr), op);
+        }
+        return expr;
+    }
+
+    std::shared_ptr<Expression> parseExpansion() {
+      static std::regex expansion_tok(R"(\*\*?)");
+      auto op_str = consumeToken(expansion_tok);
+      auto expr = parseValueExpression();
+      if (op_str.empty()) return expr;
+      if (!expr) throw std::runtime_error("Expected expr of 'expansion' expression");
+      return std::make_shared<UnaryOpExpr>(get_location(), std::move(expr), op_str == "*" ? UnaryOpExpr::Op::Expansion : UnaryOpExpr::Op::ExpansionDict);
+    }
+
+    std::shared_ptr<Expression> parseValueExpression() {
+      auto parseValue = [&]() -> std::shared_ptr<Expression> {
+        auto location = get_location();
+        auto constant = parseConstant();
+        if (constant) return std::make_shared<LiteralExpr>(location, *constant);
+
+        static std::regex null_regex(R"(null\b)");
+        if (!consumeToken(null_regex).empty()) return std::make_shared<LiteralExpr>(location, Value());
+
+        auto identifier = parseIdentifier();
+        if (identifier) return identifier;
+
+        auto braced = parseBracedExpressionOrArray();
+        if (braced) return braced;
+
+        auto array = parseArray();
+        if (array) return array;
+
+        auto dictionary = parseDictionary();
+        if (dictionary) return dictionary;
+
+        throw std::runtime_error("Expected value expression");
+      };
+
+      auto value = parseValue();
+
+      while (it != end && consumeSpaces() && peekSymbols({ "[", "." })) {
+        if (!consumeToken("[").empty()) {
+            std::shared_ptr<Expression> index;
+            if (!consumeToken(":").empty()) {
+              auto slice_end = parseExpression();
+              index = std::make_shared<SliceExpr>(slice_end->location, nullptr, std::move(slice_end));
+            } else {
+              auto slice_start = parseExpression();
+              if (!consumeToken(":").empty()) {
+                consumeSpaces();
+                if (peekSymbols({ "]" })) {
+                  index = std::make_shared<SliceExpr>(slice_start->location, std::move(slice_start), nullptr);
+                } else {
+                  auto slice_end = parseExpression();
+                  index = std::make_shared<SliceExpr>(slice_start->location, std::move(slice_start), std::move(slice_end));
+                }
+              } else {
+                index = std::move(slice_start);
+              }
+            }
+            if (!index) throw std::runtime_error("Empty index in subscript");
+            if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript");
+
+            value = std::make_shared<SubscriptExpr>(value->location, std::move(value), std::move(index));
+        } else if (!consumeToken(".").empty()) {
+            auto identifier = parseIdentifier();
+            if (!identifier) throw std::runtime_error("Expected identifier in subscript");
+
+            consumeSpaces();
+            if (peekSymbols({ "(" })) {
+              auto callParams = parseCallArgs();
+              value = std::make_shared<MethodCallExpr>(identifier->location, std::move(value), std::move(identifier), std::move(callParams));
+            } else {
+              auto key = std::make_shared<LiteralExpr>(identifier->location, Value(identifier->get_name()));
+              value = std::make_shared<SubscriptExpr>(identifier->location, std::move(value), std::move(key));
+            }
+        }
+        consumeSpaces();
+      }
+
+      if (peekSymbols({ "(" })) {
+        auto location = get_location();
+        auto callParams = parseCallArgs();
+        value = std::make_shared<CallExpr>(location, std::move(value), std::move(callParams));
+      }
+      return value;
+    }
+
+    std::shared_ptr<Expression> parseBracedExpressionOrArray() {
+        if (consumeToken("(").empty()) return nullptr;
+
+        auto expr = parseExpression();
+        if (!expr) throw std::runtime_error("Expected expression in braced expression");
+
+        if (!consumeToken(")").empty()) {
+            return expr;  // Drop the parentheses
+        }
+
+        std::vector<std::shared_ptr<Expression>> tuple;
+        tuple.emplace_back(std::move(expr));
+
+        while (it != end) {
+          if (consumeToken(",").empty()) throw std::runtime_error("Expected comma in tuple");
+          auto next = parseExpression();
+          if (!next) throw std::runtime_error("Expected expression in tuple");
+          tuple.push_back(std::move(next));
+
+          if (!consumeToken(")").empty()) {
+              return std::make_shared<ArrayExpr>(get_location(), std::move(tuple));
+          }
+        }
+        throw std::runtime_error("Expected closing parenthesis");
+    }
+
+    std::shared_ptr<Expression> parseArray() {
+        if (consumeToken("[").empty()) return nullptr;
+
+        std::vector<std::shared_ptr<Expression>> elements;
+        if (!consumeToken("]").empty()) {
+            return std::make_shared<ArrayExpr>(get_location(), std::move(elements));
+        }
+        auto first_expr = parseExpression();
+        if (!first_expr) throw std::runtime_error("Expected first expression in array");
+        elements.push_back(std::move(first_expr));
+
+        while (it != end) {
+            if (!consumeToken(",").empty()) {
+              auto expr = parseExpression();
+              if (!expr) throw std::runtime_error("Expected expression in array");
+              elements.push_back(std::move(expr));
+            } else if (!consumeToken("]").empty()) {
+                return std::make_shared<ArrayExpr>(get_location(), std::move(elements));
+            } else {
+                throw std::runtime_error("Expected comma or closing bracket in array");
+            }
+        }
+        throw std::runtime_error("Expected closing bracket");
+    }
+
+    std::shared_ptr<Expression> parseDictionary() {
+        if (consumeToken("{").empty()) return nullptr;
+
+        std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>>> elements;
+        if (!consumeToken("}").empty()) {
+            return std::make_shared<DictExpr>(get_location(), std::move(elements));
+        }
+
+        auto parseKeyValuePair = [&]() {
+            auto key = parseExpression();
+            if (!key) throw std::runtime_error("Expected key in dictionary");
+            if (consumeToken(":").empty()) throw std::runtime_error("Expected colon betweek key & value in dictionary");
+            auto value = parseExpression();
+            if (!value) throw std::runtime_error("Expected value in dictionary");
+            elements.emplace_back(std::pair(std::move(key), std::move(value)));
+        };
+
+        parseKeyValuePair();
+
+        while (it != end) {
+            if (!consumeToken(",").empty()) {
+                parseKeyValuePair();
+            } else if (!consumeToken("}").empty()) {
+                return std::make_shared<DictExpr>(get_location(), std::move(elements));
+            } else {
+                throw std::runtime_error("Expected comma or closing brace in dictionary");
+            }
+        }
+        throw std::runtime_error("Expected closing brace");
+    }
+
+    SpaceHandling parsePreSpace(const std::string& s) const {
+        if (s == "-")
+          return SpaceHandling::Strip;
+        return SpaceHandling::Keep;
+    }
+
+    SpaceHandling parsePostSpace(const std::string& s) const {
+        if (s == "-") return SpaceHandling::Strip;
+        return SpaceHandling::Keep;
+    }
+
+    using TemplateTokenVector = std::vector<std::unique_ptr<TemplateToken>>;
+    using TemplateTokenIterator = TemplateTokenVector::const_iterator;
+
+    std::vector<std::string> parseVarNames() {
+      static std::regex varnames_regex(R"(((?:\w+)(?:\s*,\s*(?:\w+))*)\s*)");
+
+      std::vector<std::string> group;
+      if ((group = consumeTokenGroups(varnames_regex)).empty()) throw std::runtime_error("Expected variable names");
+      std::vector<std::string> varnames;
+      std::istringstream iss(group[1]);
+      std::string varname;
+      while (std::getline(iss, varname, ',')) {
+        varnames.push_back(strip(varname));
+      }
+      return varnames;
+    }
+
+    std::runtime_error unexpected(const TemplateToken & token) const {
+      return std::runtime_error("Unexpected " + TemplateToken::typeToString(token.type)
+        + error_location_suffix(*template_str, token.location.pos));
+    }
+    std::runtime_error unterminated(const TemplateToken & token) const {
+      return std::runtime_error("Unterminated " + TemplateToken::typeToString(token.type)
+        + error_location_suffix(*template_str, token.location.pos));
+    }
+
+    TemplateTokenVector tokenize() {
+      static std::regex comment_tok(R"(\{#([-~]?)([\s\S]*?)([-~]?)#\})");
+      static std::regex expr_open_regex(R"(\{\{([-~])?)");
+      static std::regex block_open_regex(R"(^\{%([-~])?\s*)");
+      static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue)\b)");
+      static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)");
+      static std::regex expr_close_regex(R"(\s*([-~])?\}\})");
+      static std::regex block_close_regex(R"(\s*([-~])?%\})");
+
+      TemplateTokenVector tokens;
+      std::vector<std::string> group;
+      std::string text;
+      std::smatch match;
+
+      try {
+        while (it != end) {
+          auto location = get_location();
+
+          if (!(group = consumeTokenGroups(comment_tok, SpaceHandling::Keep)).empty()) {
+            auto pre_space = parsePreSpace(group[1]);
+            auto content = group[2];
+            auto post_space = parsePostSpace(group[3]);
+            tokens.push_back(std::make_unique<CommentTemplateToken>(location, pre_space, post_space, content));
+          } else if (!(group = consumeTokenGroups(expr_open_regex, SpaceHandling::Keep)).empty()) {
+            auto pre_space = parsePreSpace(group[1]);
+            auto expr = parseExpression();
+
+            if ((group = consumeTokenGroups(expr_close_regex)).empty()) {
+              throw std::runtime_error("Expected closing expression tag");
+            }
+
+            auto post_space = parsePostSpace(group[1]);
+            tokens.push_back(std::make_unique<ExpressionTemplateToken>(location, pre_space, post_space, std::move(expr)));
+          } else if (!(group = consumeTokenGroups(block_open_regex, SpaceHandling::Keep)).empty()) {
+            auto pre_space = parsePreSpace(group[1]);
+
+            std::string keyword;
+
+            auto parseBlockClose = [&]() -> SpaceHandling {
+              if ((group = consumeTokenGroups(block_close_regex)).empty()) throw std::runtime_error("Expected closing block tag");
+              return parsePostSpace(group[1]);
+            };
+
+            if ((keyword = consumeToken(block_keyword_tok)).empty()) throw std::runtime_error("Expected block keyword");
+
+            if (keyword == "if") {
+              auto condition = parseExpression();
+              if (!condition) throw std::runtime_error("Expected condition in if block");
+
+              auto post_space = parseBlockClose();
+              tokens.push_back(std::make_unique<IfTemplateToken>(location, pre_space, post_space, std::move(condition)));
+            } else if (keyword == "elif") {
+              auto condition = parseExpression();
+              if (!condition) throw std::runtime_error("Expected condition in elif block");
+
+              auto post_space = parseBlockClose();
+              tokens.push_back(std::make_unique<ElifTemplateToken>(location, pre_space, post_space, std::move(condition)));
+            } else if (keyword == "else") {
+              auto post_space = parseBlockClose();
+              tokens.push_back(std::make_unique<ElseTemplateToken>(location, pre_space, post_space));
+            } else if (keyword == "endif") {
+              auto post_space = parseBlockClose();
+              tokens.push_back(std::make_unique<EndIfTemplateToken>(location, pre_space, post_space));
+            } else if (keyword == "for") {
+              static std::regex recursive_tok(R"(recursive\b)");
+              static std::regex if_tok(R"(if\b)");
+
+              auto varnames = parseVarNames();
+              static std::regex in_tok(R"(in\b)");
+              if (consumeToken(in_tok).empty()) throw std::runtime_error("Expected 'in' keyword in for block");
+              auto iterable = parseExpression(/* allow_if_expr = */ false);
+              if (!iterable) throw std::runtime_error("Expected iterable in for block");
+
+              std::shared_ptr<Expression> condition;
+              if (!consumeToken(if_tok).empty()) {
+                condition = parseExpression();
+              }
+              auto recursive = !consumeToken(recursive_tok).empty();
+
+              auto post_space = parseBlockClose();
+              tokens.push_back(std::make_unique<ForTemplateToken>(location, pre_space, post_space, std::move(varnames), std::move(iterable), std::move(condition), recursive));
+            } else if (keyword == "endfor") {
+              auto post_space = parseBlockClose();
+              tokens.push_back(std::make_unique<EndForTemplateToken>(location, pre_space, post_space));
+            } else if (keyword == "generation") {
+              auto post_space = parseBlockClose();
+              tokens.push_back(std::make_unique<GenerationTemplateToken>(location, pre_space, post_space));
+            } else if (keyword == "endgeneration") {
+              auto post_space = parseBlockClose();
+              tokens.push_back(std::make_unique<EndGenerationTemplateToken>(location, pre_space, post_space));
+            } else if (keyword == "set") {
+              static std::regex namespaced_var_regex(R"((\w+)\s*\.\s*(\w+))");
+
+              std::string ns;
+              std::vector<std::string> var_names;
+              std::shared_ptr<Expression> value;
+              if (!(group = consumeTokenGroups(namespaced_var_regex)).empty()) {
+                ns = group[1];
+                var_names.push_back(group[2]);
+
+                if (consumeToken("=").empty()) throw std::runtime_error("Expected equals sign in set block");
+
+                value = parseExpression();
+                if (!value) throw std::runtime_error("Expected value in set block");
+              } else {
+                var_names = parseVarNames();
+
+                if (!consumeToken("=").empty()) {
+                  value = parseExpression();
+                  if (!value) throw std::runtime_error("Expected value in set block");
+                }
+              }
+              auto post_space = parseBlockClose();
+              tokens.push_back(std::make_unique<SetTemplateToken>(location, pre_space, post_space, ns, var_names, std::move(value)));
+            } else if (keyword == "endset") {
+              auto post_space = parseBlockClose();
+              tokens.push_back(std::make_unique<EndSetTemplateToken>(location, pre_space, post_space));
+            } else if (keyword == "macro") {
+              auto macroname = parseIdentifier();
+              if (!macroname) throw std::runtime_error("Expected macro name in macro block");
+              auto params = parseParameters();
+
+              auto post_space = parseBlockClose();
+              tokens.push_back(std::make_unique<MacroTemplateToken>(location, pre_space, post_space, std::move(macroname), std::move(params)));
+            } else if (keyword == "endmacro") {
+              auto post_space = parseBlockClose();
+              tokens.push_back(std::make_unique<EndMacroTemplateToken>(location, pre_space, post_space));
+            } else if (keyword == "filter") {
+              auto filter = parseExpression();
+              if (!filter) throw std::runtime_error("Expected expression in filter block");
+
+              auto post_space = parseBlockClose();
+              tokens.push_back(std::make_unique<FilterTemplateToken>(location, pre_space, post_space, std::move(filter)));
+            } else if (keyword == "endfilter") {
+              auto post_space = parseBlockClose();
+              tokens.push_back(std::make_unique<EndFilterTemplateToken>(location, pre_space, post_space));
+            } else if (keyword == "break" || keyword == "continue") {
+              auto post_space = parseBlockClose();
+              tokens.push_back(std::make_unique<LoopControlTemplateToken>(location, pre_space, post_space, keyword == "break" ? LoopControlType::Break : LoopControlType::Continue));
+            } else {
+              throw std::runtime_error("Unexpected block: " + keyword);
+            }
+          } else if (std::regex_search(it, end, match, non_text_open_regex)) {
+            if (!match.position()) {
+                if (match[0] != "{#")
+                    throw std::runtime_error("Internal error: Expected a comment");
+                throw std::runtime_error("Missing end of comment tag");
+            }
+            auto text_end = it + match.position();
+            text = std::string(it, text_end);
+            it = text_end;
+            tokens.push_back(std::make_unique<TextTemplateToken>(location, SpaceHandling::Keep, SpaceHandling::Keep, text));
+          } else {
+            text = std::string(it, end);
+            it = end;
+            tokens.push_back(std::make_unique<TextTemplateToken>(location, SpaceHandling::Keep, SpaceHandling::Keep, text));
+          }
+        }
+        return tokens;
+      } catch (const std::exception & e) {
+        throw std::runtime_error(e.what() + error_location_suffix(*template_str, std::distance(start, it)));
+      }
+    }
+
+    std::shared_ptr<TemplateNode> parseTemplate(
+          const TemplateTokenIterator & begin,
+          TemplateTokenIterator & it,
+          const TemplateTokenIterator & end,
+          bool fully = false) const {
+        std::vector<std::shared_ptr<TemplateNode>> children;
+        while (it != end) {
+          const auto start = it;
+          const auto & token = *(it++);
+          if (auto if_token = dynamic_cast<IfTemplateToken*>(token.get())) {
+              std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<TemplateNode>>> cascade;
+              cascade.emplace_back(std::move(if_token->condition), parseTemplate(begin, it, end));
+
+              while (it != end && (*it)->type == TemplateToken::Type::Elif) {
+                  auto elif_token = dynamic_cast<ElifTemplateToken*>((*(it++)).get());
+                  cascade.emplace_back(std::move(elif_token->condition), parseTemplate(begin, it, end));
+              }
+
+              if (it != end && (*it)->type == TemplateToken::Type::Else) {
+                cascade.emplace_back(nullptr, parseTemplate(begin, ++it, end));
+              }
+              if (it == end || (*(it++))->type != TemplateToken::Type::EndIf) {
+                  throw unterminated(**start);
+              }
+              children.emplace_back(std::make_shared<IfNode>(token->location, std::move(cascade)));
+          } else if (auto for_token = dynamic_cast<ForTemplateToken*>(token.get())) {
+              auto body = parseTemplate(begin, it, end);
+              auto else_body = std::shared_ptr<TemplateNode>();
+              if (it != end && (*it)->type == TemplateToken::Type::Else) {
+                else_body = parseTemplate(begin, ++it, end);
+              }
+              if (it == end || (*(it++))->type != TemplateToken::Type::EndFor) {
+                  throw unterminated(**start);
+              }
+              children.emplace_back(std::make_shared<ForNode>(token->location, std::move(for_token->var_names), std::move(for_token->iterable), std::move(for_token->condition), std::move(body), for_token->recursive, std::move(else_body)));
+          } else if (dynamic_cast<GenerationTemplateToken*>(token.get())) {
+              auto body = parseTemplate(begin, it, end);
+              if (it == end || (*(it++))->type != TemplateToken::Type::EndGeneration) {
+                  throw unterminated(**start);
+              }
+              // Treat as a no-op, as our scope is templates for inference, not training (`{% generation %}` wraps generated tokens for masking).
+              children.emplace_back(std::move(body));
+          } else if (auto text_token = dynamic_cast<TextTemplateToken*>(token.get())) {
+              SpaceHandling pre_space = (it - 1) != begin ? (*(it - 2))->post_space : SpaceHandling::Keep;
+              SpaceHandling post_space = it != end ? (*it)->pre_space : SpaceHandling::Keep;
+
+              auto text = text_token->text;
+              if (post_space == SpaceHandling::Strip) {
+                static std::regex trailing_space_regex(R"(\s+$)");
+                text = std::regex_replace(text, trailing_space_regex, "");
+              } else if (options.lstrip_blocks && it != end) {
+                auto i = text.size();
+                while (i > 0 && (text[i - 1] == ' ' || text[i - 1] == '\t')) i--;
+                if ((i == 0 && (it - 1) == begin) || (i > 0 && text[i - 1] == '\n')) {
+                  text.resize(i);
+                }
+              }
+              if (pre_space == SpaceHandling::Strip) {
+                static std::regex leading_space_regex(R"(^\s+)");
+                text = std::regex_replace(text, leading_space_regex, "");
+              } else if (options.trim_blocks && (it - 1) != begin && !dynamic_cast<ExpressionTemplateToken*>((*(it - 2)).get())) {
+                if (text.length() > 0 && text[0] == '\n') {
+                  text.erase(0, 1);
+                }
+              }
+              if (it == end && !options.keep_trailing_newline) {
+                auto i = text.size();
+                if (i > 0 && text[i - 1] == '\n') {
+                  i--;
+                  if (i > 0 && text[i - 1] == '\r') i--;
+                  text.resize(i);
+                }
+              }
+              children.emplace_back(std::make_shared<TextNode>(token->location, text));
+          } else if (auto expr_token = dynamic_cast<ExpressionTemplateToken*>(token.get())) {
+              children.emplace_back(std::make_shared<ExpressionNode>(token->location, std::move(expr_token->expr)));
+          } else if (auto set_token = dynamic_cast<SetTemplateToken*>(token.get())) {
+            if (set_token->value) {
+              children.emplace_back(std::make_shared<SetNode>(token->location, set_token->ns, set_token->var_names, std::move(set_token->value)));
+            } else {
+              auto value_template = parseTemplate(begin, it, end);
+              if (it == end || (*(it++))->type != TemplateToken::Type::EndSet) {
+                  throw unterminated(**start);
+              }
+              if (!set_token->ns.empty()) throw std::runtime_error("Namespaced set not supported in set with template value");
+              if (set_token->var_names.size() != 1) throw std::runtime_error("Structural assignment not supported in set with template value");
+              auto & name = set_token->var_names[0];
+              children.emplace_back(std::make_shared<SetTemplateNode>(token->location, name, std::move(value_template)));
+            }
+          } else if (auto macro_token = dynamic_cast<MacroTemplateToken*>(token.get())) {
+              auto body = parseTemplate(begin, it, end);
+              if (it == end || (*(it++))->type != TemplateToken::Type::EndMacro) {
+                  throw unterminated(**start);
+              }
+              children.emplace_back(std::make_shared<MacroNode>(token->location, std::move(macro_token->name), std::move(macro_token->params), std::move(body)));
+          } else if (auto filter_token = dynamic_cast<FilterTemplateToken*>(token.get())) {
+              auto body = parseTemplate(begin, it, end);
+              if (it == end || (*(it++))->type != TemplateToken::Type::EndFilter) {
+                  throw unterminated(**start);
+              }
+              children.emplace_back(std::make_shared<FilterNode>(token->location, std::move(filter_token->filter), std::move(body)));
+          } else if (dynamic_cast<CommentTemplateToken*>(token.get())) {
+              // Ignore comments
+          } else if (auto ctrl_token = dynamic_cast<LoopControlTemplateToken*>(token.get())) {
+              children.emplace_back(std::make_shared<LoopControlNode>(token->location, ctrl_token->control_type));
+          } else if (dynamic_cast<EndForTemplateToken*>(token.get())
+                  || dynamic_cast<EndSetTemplateToken*>(token.get())
+                  || dynamic_cast<EndMacroTemplateToken*>(token.get())
+                  || dynamic_cast<EndFilterTemplateToken*>(token.get())
+                  || dynamic_cast<EndIfTemplateToken*>(token.get())
+                  || dynamic_cast<ElseTemplateToken*>(token.get())
+                  || dynamic_cast<EndGenerationTemplateToken*>(token.get())
+                  || dynamic_cast<ElifTemplateToken*>(token.get())) {
+              it--;  // unconsume the token
+              break;  // exit the loop
+          } else {
+              throw unexpected(**(it-1));
+          }
+        }
+        if (fully && it != end) {
+            throw unexpected(**it);
+        }
+        if (children.empty()) {
+          return std::make_shared<TextNode>(Location { template_str, 0 }, std::string());
+        } else if (children.size() == 1) {
+          return std::move(children[0]);
+        } else {
+          return std::make_shared<SequenceNode>(children[0]->location(), std::move(children));
+        }
+    }
+
+public:
+
+    static std::shared_ptr<TemplateNode> parse(const std::string& template_str, const Options & options) {
+        Parser parser(std::make_shared<std::string>(normalize_newlines(template_str)), options);
+        auto tokens = parser.tokenize();
+        TemplateTokenIterator begin = tokens.begin();
+        auto it = begin;
+        TemplateTokenIterator end = tokens.end();
+        return parser.parseTemplate(begin, it, end, /* full= */ true);
+    }
+};
+
+static Value simple_function(const std::string & fn_name, const std::vector<std::string> & params, const std::function<Value(const std::shared_ptr<Context> &, Value & args)> & fn) {
+  std::map<std::string, size_t> named_positions;
+  for (size_t i = 0, n = params.size(); i < n; i++) named_positions[params[i]] = i;
+
+  return Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) -> Value {
+    auto args_obj = Value::object();
+    std::vector<bool> provided_args(params.size());
+    for (size_t i = 0, n = args.args.size(); i < n; i++) {
+      auto & arg = args.args[i];
+      if (i < params.size()) {
+        args_obj.set(params[i], arg);
+        provided_args[i] = true;
+      } else {
+        throw std::runtime_error("Too many positional params for " + fn_name);
+      }
+    }
+    for (auto & [name, value] : args.kwargs) {
+      auto named_pos_it = named_positions.find(name);
+      if (named_pos_it == named_positions.end()) {
+        throw std::runtime_error("Unknown argument " + name + " for function " + fn_name);
+      }
+      provided_args[named_pos_it->second] = true;
+      args_obj.set(name, value);
+    }
+    return fn(context, args_obj);
+  });
+}
+
+inline std::shared_ptr<Context> Context::builtins() {
+  auto globals = Value::object();
+
+  globals.set("raise_exception", simple_function("raise_exception", { "message" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
+    throw std::runtime_error(args.at("message").get<std::string>());
+  }));
+  globals.set("tojson", simple_function("tojson", { "value", "indent" }, [](const std::shared_ptr<Context> &, Value & args) {
+    return Value(args.at("value").dump(args.get<int64_t>("indent", -1), /* tojson= */ true));
+  }));
+  globals.set("items", simple_function("items", { "object" }, [](const std::shared_ptr<Context> &, Value & args) {
+    auto items = Value::array();
+    if (args.contains("object")) {
+      auto & obj = args.at("object");
+      if (obj.is_string()) {
+        auto json_obj = json::parse(obj.get<std::string>());
+        for (const auto & kv : json_obj.items()) {
+          items.push_back(Value::array({kv.key(), kv.value()}));
+        }
+      } else if (!obj.is_null()) {
+        for (auto & key : obj.keys()) {
+          items.push_back(Value::array({key, obj.at(key)}));
+        }
+      }
+    }
+    return items;
+  }));
+  globals.set("last", simple_function("last", { "items" }, [](const std::shared_ptr<Context> &, Value & args) {
+    auto items = args.at("items");
+    if (!items.is_array()) throw std::runtime_error("object is not a list");
+    if (items.size() == 0) return Value();
+    return items.at(items.size() - 1);
+  }));
+  globals.set("trim", simple_function("trim", { "text" }, [](const std::shared_ptr<Context> &, Value & args) {
+    auto & text = args.at("text");
+    return text.is_null() ? text : Value(strip(text.get<std::string>()));
+  }));
+  globals.set("lower", simple_function("lower", { "text" }, [](const std::shared_ptr<Context> &, Value & args) {
+    auto text = args.at("text");
+    if (text.is_null()) return text;
+    std::string res;
+    auto str = text.get<std::string>();
+    std::transform(str.begin(), str.end(), std::back_inserter(res), ::tolower);
+    return Value(res);
+  }));
+  globals.set("default", Value::callable([=](const std::shared_ptr<Context> &, ArgumentsValue & args) {
+    args.expectArgs("default", {2, 3}, {0, 1});
+    auto & value = args.args[0];
+    auto & default_value = args.args[1];
+    bool boolean = false;
+    if (args.args.size() == 3) {
+      boolean = args.args[2].get<bool>();
+    } else {
+      Value bv = args.get_named("boolean");
+      if (!bv.is_null()) {
+        boolean = bv.get<bool>();
+      }
+    }
+    return boolean ? (value.to_bool() ? value : default_value) : value.is_null() ? default_value : value;
+  }));
+  auto escape = simple_function("escape", { "text" }, [](const std::shared_ptr<Context> &, Value & args) {
+    return Value(html_escape(args.at("text").get<std::string>()));
+  });
+  globals.set("e", escape);
+  globals.set("escape", escape);
+  globals.set("joiner", simple_function("joiner", { "sep" }, [](const std::shared_ptr<Context> &, Value & args) {
+    auto sep = args.get<std::string>("sep", "");
+    auto first = std::make_shared<bool>(true);
+    return simple_function("", {}, [sep, first](const std::shared_ptr<Context> &, const Value &) -> Value {
+      if (*first) {
+        *first = false;
+        return "";
+      }
+      return sep;
+    });
+    return Value(html_escape(args.at("text").get<std::string>()));
+  }));
+  globals.set("count", simple_function("count", { "items" }, [](const std::shared_ptr<Context> &, Value & args) {
+    return Value((int64_t) args.at("items").size());
+  }));
+  globals.set("dictsort", simple_function("dictsort", { "value" }, [](const std::shared_ptr<Context> &, Value & args) {
+    if (args.size() != 1) throw std::runtime_error("dictsort expects exactly 1 argument (TODO: fix implementation)");
+    auto & value = args.at("value");
+    auto keys = value.keys();
+    std::sort(keys.begin(), keys.end());
+    auto res = Value::array();
+    for (auto & key : keys) {
+      res.push_back(Value::array({key, value.at(key)}));
+    }
+    return res;
+  }));
+  globals.set("join", simple_function("join", { "items", "d" }, [](const std::shared_ptr<Context> &, Value & args) {
+    auto do_join = [](Value & items, const std::string & sep) {
+      if (!items.is_array()) throw std::runtime_error("object is not iterable: " + items.dump());
+      std::ostringstream oss;
+      auto first = true;
+      for (size_t i = 0, n = items.size(); i < n; ++i) {
+        if (first) first = false;
+        else oss << sep;
+        oss << items.at(i).to_str();
+      }
+      return Value(oss.str());
+    };
+    auto sep = args.get<std::string>("d", "");
+    if (args.contains("items")) {
+        auto & items = args.at("items");
+        return do_join(items, sep);
+    } else {
+      return simple_function("", {"items"}, [sep, do_join](const std::shared_ptr<Context> &, Value & args) {
+        auto & items = args.at("items");
+        if (!items.to_bool() || !items.is_array()) throw std::runtime_error("join expects an array for items, got: " + items.dump());
+        return do_join(items, sep);
+      });
+    }
+  }));
+  globals.set("namespace", Value::callable([=](const std::shared_ptr<Context> &, ArgumentsValue & args) {
+    auto ns = Value::object();
+    args.expectArgs("namespace", {0, 0}, {0, (std::numeric_limits<size_t>::max)()});
+    for (auto & [name, value] : args.kwargs) {
+      ns.set(name, value);
+    }
+    return ns;
+  }));
+  auto equalto = simple_function("equalto", { "expected", "actual" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
+      return args.at("actual") == args.at("expected");
+  });
+  globals.set("equalto", equalto);
+  globals.set("==", equalto);
+  globals.set("length", simple_function("length", { "items" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
+      auto & items = args.at("items");
+      return (int64_t) items.size();
+  }));
+  globals.set("safe", simple_function("safe", { "value" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
+      return args.at("value").to_str();
+  }));
+  globals.set("string", simple_function("string", { "value" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
+      return args.at("value").to_str();
+  }));
+  globals.set("int", simple_function("int", { "value" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
+      return args.at("value").to_int();
+  }));
+  globals.set("list", simple_function("list", { "items" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
+      auto & items = args.at("items");
+      if (!items.is_array()) throw std::runtime_error("object is not iterable");
+      return items;
+  }));
+  globals.set("unique", simple_function("unique", { "items" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
+      auto & items = args.at("items");
+      if (!items.is_array()) throw std::runtime_error("object is not iterable");
+      std::unordered_set<Value> seen;
+      auto result = Value::array();
+      for (size_t i = 0, n = items.size(); i < n; i++) {
+        auto pair = seen.insert(items.at(i));
+        if (pair.second) {
+          result.push_back(items.at(i));
+        }
+      }
+      return result;
+  }));
+  auto make_filter = [](const Value & filter, Value & extra_args) -> Value {
+    return simple_function("", { "value" }, [=](const std::shared_ptr<Context> & context, Value & args) {
+      auto & value = args.at("value");
+      ArgumentsValue actual_args;
+      actual_args.args.emplace_back(value);
+      for (size_t i = 0, n = extra_args.size(); i < n; i++) {
+        actual_args.args.emplace_back(extra_args.at(i));
+      }
+      return filter.call(context, actual_args);
+    });
+  };
+  auto select_or_reject = [make_filter](bool is_select) {
+    return Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
+      args.expectArgs(is_select ? "select" : "reject", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
+      auto & items = args.args[0];
+      if (items.is_null())
+        return Value::array();
+      if (!items.is_array()) throw std::runtime_error("object is not iterable: " + items.dump());
+
+      auto filter_fn = context->get(args.args[1]);
+      if (filter_fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump());
+
+      auto filter_args = Value::array();
+      for (size_t i = 2, n = args.args.size(); i < n; i++) {
+        filter_args.push_back(args.args[i]);
+      }
+      auto filter = make_filter(filter_fn, filter_args);
+
+      auto res = Value::array();
+      for (size_t i = 0, n = items.size(); i < n; i++) {
+        auto & item = items.at(i);
+        ArgumentsValue filter_args;
+        filter_args.args.emplace_back(item);
+        auto pred_res = filter.call(context, filter_args);
+        if (pred_res.to_bool() == (is_select ? true : false)) {
+          res.push_back(item);
+        }
+      }
+      return res;
+    });
+  };
+  globals.set("select", select_or_reject(/* is_select= */ true));
+  globals.set("reject", select_or_reject(/* is_select= */ false));
+  globals.set("map", Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
+    auto res = Value::array();
+    if (args.args.size() == 1 &&
+      ((args.has_named("attribute") && args.kwargs.size() == 1) || (args.has_named("default") && args.kwargs.size() == 2))) {
+      auto & items = args.args[0];
+      auto attr_name = args.get_named("attribute");
+      auto default_value = args.get_named("default");
+      for (size_t i = 0, n = items.size(); i < n; i++) {
+        auto & item = items.at(i);
+        auto attr = item.get(attr_name);
+        res.push_back(attr.is_null() ? default_value : attr);
+      }
+    } else if (args.kwargs.empty() && args.args.size() >= 2) {
+      auto fn = context->get(args.args[1]);
+      if (fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump());
+      ArgumentsValue filter_args { {Value()}, {} };
+      for (size_t i = 2, n = args.args.size(); i < n; i++) {
+        filter_args.args.emplace_back(args.args[i]);
+      }
+      for (size_t i = 0, n = args.args[0].size(); i < n; i++) {
+        auto & item = args.args[0].at(i);
+        filter_args.args[0] = item;
+        res.push_back(fn.call(context, filter_args));
+      }
+    } else {
+      throw std::runtime_error("Invalid or unsupported arguments for map");
+    }
+    return res;
+  }));
+  globals.set("indent", simple_function("indent", { "text", "indent", "first" }, [](const std::shared_ptr<Context> &, Value & args) {
+    auto text = args.at("text").get<std::string>();
+    auto first = args.get<bool>("first", false);
+    std::string out;
+    std::string indent(args.get<int64_t>("indent", 0), ' ');
+    std::istringstream iss(text);
+    std::string line;
+    auto is_first = true;
+    while (std::getline(iss, line, '\n')) {
+      auto needs_indent = !is_first || first;
+      if (is_first) is_first = false;
+      else out += "\n";
+      if (needs_indent) out += indent;
+      out += line;
+    }
+    if (!text.empty() && text.back() == '\n') out += "\n";
+    return out;
+  }));
+  auto select_or_reject_attr = [](bool is_select) {
+    return Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
+      args.expectArgs(is_select ? "selectattr" : "rejectattr", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
+      auto & items = args.args[0];
+      if (items.is_null())
+        return Value::array();
+      if (!items.is_array()) throw std::runtime_error("object is not iterable: " + items.dump());
+      auto attr_name = args.args[1].get<std::string>();
+
+      bool has_test = false;
+      Value test_fn;
+      ArgumentsValue test_args {{Value()}, {}};
+      if (args.args.size() >= 3) {
+        has_test = true;
+        test_fn = context->get(args.args[2]);
+        if (test_fn.is_null()) throw std::runtime_error("Undefined test: " + args.args[2].dump());
+        for (size_t i = 3, n = args.args.size(); i < n; i++) {
+          test_args.args.emplace_back(args.args[i]);
+        }
+        test_args.kwargs = args.kwargs;
+      }
+
+      auto res = Value::array();
+      for (size_t i = 0, n = items.size(); i < n; i++) {
+        auto & item = items.at(i);
+        auto attr = item.get(attr_name);
+        if (has_test) {
+          test_args.args[0] = attr;
+          if (test_fn.call(context, test_args).to_bool() == (is_select ? true : false)) {
+            res.push_back(item);
+          }
+        } else {
+          res.push_back(attr);
+        }
+      }
+      return res;
+    });
+  };
+  globals.set("selectattr", select_or_reject_attr(/* is_select= */ true));
+  globals.set("rejectattr", select_or_reject_attr(/* is_select= */ false));
+  globals.set("range", Value::callable([=](const std::shared_ptr<Context> &, ArgumentsValue & args) {
+    std::vector<int64_t> startEndStep(3);
+    std::vector<bool> param_set(3);
+    if (args.args.size() == 1) {
+      startEndStep[1] = args.args[0].get<int64_t>();
+      param_set[1] = true;
+    } else {
+      for (size_t i = 0; i < args.args.size(); i++) {
+        auto & arg = args.args[i];
+        auto v = arg.get<int64_t>();
+        startEndStep[i] = v;
+        param_set[i] = true;
+        }
+      }
+      for (auto & [name, value] : args.kwargs) {
+        size_t i;
+        if (name == "start") i = 0;
+        else if (name == "end") i = 1;
+        else if (name == "step") i = 2;
+        else throw std::runtime_error("Unknown argument " + name + " for function range");
+
+        if (param_set[i]) {
+          throw std::runtime_error("Duplicate argument " + name + " for function range");
+        }
+        startEndStep[i] = value.get<int64_t>();
+        param_set[i] = true;
+    }
+    if (!param_set[1]) {
+      throw std::runtime_error("Missing required argument 'end' for function range");
+    }
+    int64_t start = param_set[0] ? startEndStep[0] : 0;
+    int64_t end = startEndStep[1];
+    int64_t step = param_set[2] ? startEndStep[2] : 1;
+
+    auto res = Value::array();
+    if (step > 0) {
+      for (int64_t i = start; i < end; i += step) {
+        res.push_back(Value(i));
+      }
+    } else {
+      for (int64_t i = start; i > end; i += step) {
+        res.push_back(Value(i));
+      }
+    }
+    return res;
+  }));
+
+  return std::make_shared<Context>(std::move(globals));
+}
+
+inline std::shared_ptr<Context> Context::make(Value && values, const std::shared_ptr<Context> & parent) {
+  return std::make_shared<Context>(values.is_null() ? Value::object() : std::move(values), parent);
+}
+
+}  // namespace minja
index e654d3542c6c3827e343335d51bc9eec74e29514..cf8659b037ee3757d8027d343de92a9e5e3e8fbb 100644 (file)
@@ -4,7 +4,7 @@
 #include "log.h"
 #include "sampling.h"
 #include "llama.h"
-#include "chat-template.hpp"
+#include "chat.h"
 
 #include <cstdio>
 #include <cstring>
@@ -158,7 +158,7 @@ int main(int argc, char ** argv) {
     }
 
     const llama_vocab * vocab = llama_model_get_vocab(model);
-    auto chat_templates = common_chat_templates_from_model(model, params.chat_template);
+    auto chat_templates = common_chat_templates_init(model, params.chat_template);
 
     LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads);
 
@@ -201,7 +201,7 @@ int main(int argc, char ** argv) {
     }
 
     // auto enable conversation mode if chat template is available
-    const bool has_chat_template = chat_templates.has_explicit_template && chat_templates.template_default;
+    const bool has_chat_template = common_chat_templates_was_explicit(chat_templates.get());
     if (params.conversation_mode == COMMON_CONVERSATION_MODE_AUTO) {
         if (has_chat_template) {
             LOG_INF("%s: chat template is available, enabling conversation mode (disable it with -no-cnv)\n", __func__);
@@ -219,7 +219,7 @@ int main(int argc, char ** argv) {
     // print chat template example in conversation mode
     if (params.conversation_mode) {
         if (params.enable_chat_template) {
-            LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(*chat_templates.template_default, params.use_jinja).c_str());
+            LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(chat_templates.get(), params.use_jinja).c_str());
         } else {
             LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
         }
@@ -264,9 +264,11 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> embd_inp;
 
     auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) {
-        common_chat_msg new_msg{role, content, {}};
-        auto formatted = common_chat_format_single(*chat_templates.template_default, chat_msgs, new_msg, role == "user", g_params->use_jinja);
-        chat_msgs.push_back({role, content, {}});
+        common_chat_msg new_msg;
+        new_msg.role = role;
+        new_msg.content = content;
+        auto formatted = common_chat_format_single(chat_templates.get(), chat_msgs, new_msg, role == "user", g_params->use_jinja);
+        chat_msgs.push_back(new_msg);
         LOG_DBG("formatted: '%s'\n", formatted.c_str());
         return formatted;
     };
@@ -755,11 +757,14 @@ int main(int argc, char ** argv) {
 
                 // check for reverse prompt using special tokens
                 llama_token last_token = common_sampler_last(smpl);
-                if (std::find(antiprompt_token.begin(), antiprompt_token.end(), last_token) != antiprompt_token.end()) {
-                    if (params.interactive) {
-                        is_interacting = true;
+                for (auto token : antiprompt_token) {
+                    if (token == last_token) {
+                        if (params.interactive) {
+                            is_interacting = true;
+                        }
+                        is_antiprompt = true;
+                        break;
                     }
-                    is_antiprompt = true;
                 }
 
                 if (is_antiprompt) {
index 9362da22083d3833c0ab2ab9f19c1844e6a19e4e..ed8644ef78d97fb064e665d283db1b61dcdaeda0 100644 (file)
@@ -24,7 +24,7 @@
 #include <string>
 #include <vector>
 
-#include "chat-template.hpp"
+#include "chat.h"
 #include "common.h"
 #include "json.hpp"
 #include "linenoise.cpp/linenoise.h"
@@ -557,7 +557,7 @@ class LlamaData {
     llama_model_ptr                 model;
     llama_sampler_ptr               sampler;
     llama_context_ptr               context;
-    std::vector<llama_chat_message> messages;
+    std::vector<llama_chat_message> messages; // TODO: switch to common_chat_msg
     std::list<std::string>          msg_strs;
     std::vector<char>               fmtted;
 
@@ -834,44 +834,23 @@ static void add_message(const char * role, const std::string & text, LlamaData &
 }
 
 // Function to apply the chat template and resize `formatted` if needed
-static int apply_chat_template(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) {
-    if (use_jinja) {
-        json messages = json::array();
-        for (const auto & msg : llama_data.messages) {
-            messages.push_back({
-                {"role", msg.role},
-                {"content", msg.content},
-            });
-        }
-        try {
-            minja::chat_template_inputs tmpl_inputs;
-            tmpl_inputs.messages = messages;
-            tmpl_inputs.add_generation_prompt = append;
-
-            minja::chat_template_options tmpl_opts;
-            tmpl_opts.use_bos_token = false;
-            tmpl_opts.use_eos_token = false;
-
-            auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
-            llama_data.fmtted.resize(result.size() + 1);
-            memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
-            return result.size();
-        } catch (const std::exception & e) {
-            printe("failed to render the chat template: %s\n", e.what());
-            return -1;
-        }
-    }
-    int result = llama_chat_apply_template(
-        tmpl.source().c_str(), llama_data.messages.data(), llama_data.messages.size(), append,
-        append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
-    if (append && result > static_cast<int>(llama_data.fmtted.size())) {
-        llama_data.fmtted.resize(result);
-        result = llama_chat_apply_template(tmpl.source().c_str(), llama_data.messages.data(),
-                                           llama_data.messages.size(), append, llama_data.fmtted.data(),
-                                           llama_data.fmtted.size());
-    }
-
-    return result;
+static int apply_chat_template(const struct common_chat_templates * tmpls, LlamaData & llama_data, const bool append, bool use_jinja) {
+    common_chat_templates_inputs inputs;
+    for (const auto & msg : llama_data.messages) {
+        common_chat_msg cmsg;
+        cmsg.role    = msg.role;
+        cmsg.content = msg.content;
+        inputs.messages.push_back(cmsg);
+    }
+    inputs.add_generation_prompt = append;
+    inputs.use_jinja = use_jinja;
+
+    auto chat_params = common_chat_templates_apply(tmpls, inputs);
+    // TODO: use other params for tool calls.
+    auto result = chat_params.prompt;
+    llama_data.fmtted.resize(result.size() + 1);
+    memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
+    return result.size();
 }
 
 // Function to tokenize the prompt
@@ -1015,8 +994,8 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt,
 }
 
 // Helper function to apply the chat template and handle errors
-static int apply_chat_template_with_error_handling(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) {
-    const int new_len = apply_chat_template(tmpl, llama_data, append, use_jinja);
+static int apply_chat_template_with_error_handling(const common_chat_templates * tmpls, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) {
+    const int new_len = apply_chat_template(tmpls, llama_data, append, use_jinja);
     if (new_len < 0) {
         printe("failed to apply the chat template\n");
         return -1;
@@ -1078,8 +1057,7 @@ static int get_user_input(std::string & user_input, const std::string & user) {
 static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_jinja) {
     int prev_len = 0;
     llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
-    auto chat_templates = common_chat_templates_from_model(llama_data.model.get(), "");
-    GGML_ASSERT(chat_templates.template_default);
+    auto chat_templates = common_chat_templates_init(llama_data.model.get(), "");
     static const bool stdout_a_terminal = is_stdout_a_terminal();
     while (true) {
         // Get user input
@@ -1090,7 +1068,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_
 
         add_message("user", user.empty() ? user_input : user, llama_data);
         int new_len;
-        if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, true, new_len, use_jinja) < 0) {
+        if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, use_jinja) < 0) {
             return 1;
         }
 
@@ -1105,7 +1083,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_
         }
 
         add_message("assistant", response, llama_data);
-        if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, false, prev_len, use_jinja) < 0) {
+        if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, use_jinja) < 0) {
             return 1;
         }
     }
index 5707c766d7e05d6ab3a862bac9ea686cc9ebff9c..809bfe0e36cd7ee4d21da8fb14484b0fa701817d 100644 (file)
@@ -329,9 +329,6 @@ struct server_task {
         }
 
         // process "json_schema" and "grammar"
-        if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
-            throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
-        }
         if (data.contains("json_schema") && !data.contains("grammar")) {
             try {
                 auto schema                  = json_value(data, "json_schema", json::object());
@@ -1807,7 +1804,7 @@ struct server_context {
     // Necessary similarity of prompt for slot selection
     float slot_prompt_similarity = 0.0f;
 
-    common_chat_templates chat_templates;
+    common_chat_templates_ptr chat_templates;
 
     ~server_context() {
         // Clear any sampling context
@@ -1891,45 +1888,17 @@ struct server_context {
             llama_init_dft.context.reset();
         }
 
-        if (params_base.chat_template.empty() && !validate_builtin_chat_template(params.use_jinja)) {
+        chat_templates = common_chat_templates_init(model, params_base.chat_template);
+        try {
+            common_chat_format_example(chat_templates.get(), params.use_jinja);
+        } catch (const std::exception & e) {
             SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
-            chat_templates = common_chat_templates_from_model(model, "chatml");
-        } else {
-            chat_templates = common_chat_templates_from_model(model, params_base.chat_template);
+            chat_templates = common_chat_templates_init(model, "chatml");
         }
-        GGML_ASSERT(chat_templates.template_default.get() != nullptr);
 
         return true;
     }
 
-    bool validate_builtin_chat_template(bool use_jinja) const {
-        llama_chat_message chat[] = {{"user", "test"}};
-
-        if (use_jinja) {
-            auto templates = common_chat_templates_from_model(model, "");
-            common_chat_inputs inputs;
-            inputs.messages = json::array({{
-                {"role", "user"},
-                {"content", "test"},
-            }});
-            GGML_ASSERT(templates.template_default);
-            try {
-                common_chat_params_init(*templates.template_default, inputs);
-                if (templates.template_tool_use) {
-                    common_chat_params_init(*templates.template_tool_use, inputs);
-                }
-                return true;
-            } catch (const std::exception & e) {
-                SRV_ERR("failed to apply template: %s\n", e.what());
-                return false;
-            }
-        } else {
-            const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);
-            const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
-            return chat_res > 0;
-        }
-    }
-
     void init() {
         const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;
 
@@ -3822,13 +3791,15 @@ int main(int argc, char ** argv) {
             { "default_generation_settings", ctx_server.default_generation_settings_for_props },
             { "total_slots",                 ctx_server.params_base.n_parallel },
             { "model_path",                  ctx_server.params_base.model },
-            { "chat_template",               ctx_server.chat_templates.template_default->source() },
-            { "bos_token",                   ctx_server.chat_templates.template_default->bos_token() },
-            { "eos_token",                   ctx_server.chat_templates.template_default->eos_token() },
+            { "chat_template",               common_chat_templates_source(ctx_server.chat_templates.get()) },
+            { "bos_token",                   common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)},
+            { "eos_token",                   common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)},
             { "build_info",                  build_info },
         };
-        if (ctx_server.params_base.use_jinja && ctx_server.chat_templates.template_tool_use) {
-            data["chat_template_tool_use"] = ctx_server.chat_templates.template_tool_use->source();
+        if (ctx_server.params_base.use_jinja) {
+            if (auto tool_use_src = common_chat_templates_source(ctx_server.chat_templates.get(), "tool_use")) {
+                data["chat_template_tool_use"] = tool_use_src;
+            }
         }
 
         res_ok(res, data);
@@ -4063,7 +4034,7 @@ int main(int argc, char ** argv) {
         }
 
         auto body = json::parse(req.body);
-        json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates);
+        json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get());
 
         return handle_completions_impl(
             SERVER_TASK_TYPE_COMPLETION,
@@ -4076,7 +4047,7 @@ int main(int argc, char ** argv) {
     // same with handle_chat_completions, but without inference part
     const auto handle_apply_template = [&ctx_server, &params, &res_ok](const httplib::Request & req, httplib::Response & res) {
         auto body = json::parse(req.body);
-        json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates);
+        json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get());
         res_ok(res, {{ "prompt", std::move(data.at("prompt")) }});
     };
 
@@ -4493,8 +4464,8 @@ int main(int argc, char ** argv) {
 
     // print sample chat example to make it clear which template is used
     LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
-        ctx_server.chat_templates.template_default->source().c_str(),
-        common_chat_format_example(*ctx_server.chat_templates.template_default, ctx_server.params_base.use_jinja).c_str());
+        common_chat_templates_source(ctx_server.chat_templates.get()),
+        common_chat_format_example(ctx_server.chat_templates.get(), ctx_server.params_base.use_jinja).c_str());
 
     ctx_server.queue_tasks.on_new_task([&ctx_server](const server_task & task) {
         ctx_server.process_single_task(task);
index f23d5cff49abc9219173c9f3dacb97819ac9b2aa..af1dcb5b96554e2c650e2a024eebe708bae96f2f 100644 (file)
@@ -21,6 +21,8 @@ def create_server():
         (None, "Book", "What is the best book", 8, "^ blue",                    23, 8, "length", True, "This is not a chat template, it is"),
         ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None),
         ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None),
+        (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", False, None),
+        (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", True, None),
     ]
 )
 def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template):
@@ -44,7 +46,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
     assert res.body["usage"]["completion_tokens"] == n_predicted
     choice = res.body["choices"][0]
     assert "assistant" == choice["message"]["role"]
-    assert match_regex(re_content, choice["message"]["content"])
+    assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}'
     assert choice["finish_reason"] == finish_reason
 
 
@@ -169,6 +171,47 @@ def test_completion_with_response_format(response_format: dict, n_predicted: int
         assert "error" in res.body
 
 
+@pytest.mark.parametrize("jinja,json_schema,n_predicted,re_content", [
+    (False, {"const": "42"}, 6, "\"42\""),
+    (True, {"const": "42"}, 6, "\"42\""),
+])
+def test_completion_with_json_schema(jinja: bool, json_schema: dict, n_predicted: int, re_content: str):
+    global server
+    server.jinja = jinja
+    server.start()
+    res = server.make_request("POST", "/chat/completions", data={
+        "max_tokens": n_predicted,
+        "messages": [
+            {"role": "system", "content": "You are a coding assistant."},
+            {"role": "user", "content": "Write an example"},
+        ],
+        "json_schema": json_schema,
+    })
+    assert res.status_code == 200, f'Expected 200, got {res.status_code}'
+    choice = res.body["choices"][0]
+    assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}'
+
+
+@pytest.mark.parametrize("jinja,grammar,n_predicted,re_content", [
+    (False, 'root ::= "a"{5,5}', 6, "a{5,5}"),
+    (True, 'root ::= "a"{5,5}', 6, "a{5,5}"),
+])
+def test_completion_with_grammar(jinja: bool, grammar: str, n_predicted: int, re_content: str):
+    global server
+    server.jinja = jinja
+    server.start()
+    res = server.make_request("POST", "/chat/completions", data={
+        "max_tokens": n_predicted,
+        "messages": [
+            {"role": "user", "content": "Does not matter what I say, does it?"},
+        ],
+        "grammar": grammar,
+    })
+    assert res.status_code == 200, res.body
+    choice = res.body["choices"][0]
+    assert match_regex(re_content, choice["message"]["content"]), choice["message"]["content"]
+
+
 @pytest.mark.parametrize("messages", [
     None,
     "string",
index ba3367b4f332d1df7a051f9077f5a6801900b1d3..a91a2f3333ca33864f799e2a3e2bc6342e171b14 100644 (file)
@@ -356,12 +356,12 @@ def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] |
     (None,                                           128,  "bartowski/functionary-small-v3.2-GGUF:Q8_0",        ("meetkai/functionary-medium-v3.2", None)),
     (None,                                           128,  "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M",  None),
     (None,                                           128,  "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M",  None),
-    ("^> 0.56$",                                     128,  "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M",  "chatml"),
+    (None,                                           128,  "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M",  "chatml"),
     (None,                                           128,  "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M",       None),
 
     # TODO: fix these (wrong results, either didn't respect decimal instruction or got wrong value)
-    ("^The y-coordinate [\\s\\S]*?\\*\\*0.5\\*\\*",  8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
-    ("[\\s\\S]*?\\*\\*0\\.5\\*\\*",                  8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
+    ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)",            8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
+    # ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)",            8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
 ])
 def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
     global server
@@ -401,7 +401,7 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
             {
                 "role": "tool",
                 "name": "calculate",
-                "content": 0.55644242476,
+                "content": "0.55644242476",
                 "tool_call_id": "call_6789"
             }
         ],
@@ -444,7 +444,7 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
     (128,  None,        "^The sum of 102 and 7 is 109.*",                       None,                                          "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M",       None),
 
     (1024, 'deepseek',  "To find the sum of.*",                                 "I need to calculate the sum of 102 and 7.*",  "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
-    (1024, 'none',      "<think>\n?I need[\\s\\S]*?</think>\n?To find.*",       None,                                          "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
+    (1024, 'none',      "^I need[\\s\\S]*?</think>\n?To find.*",                None,                                          "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
 
     (1024, 'deepseek',  "To find the sum of.*",                                 "First, I [\\s\\S]*",                          "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
 ])
index 60cb2673ec2ece4931fb9bd66e582ee0205d3e7d..6f8ab2b93aac7401940f9218cde2e9ac7f42cd3f 100644 (file)
@@ -12,9 +12,7 @@
 // Change JSON_ASSERT from assert() to GGML_ASSERT:
 #define JSON_ASSERT GGML_ASSERT
 #include "json.hpp"
-#include "minja.hpp"
-#include "chat.hpp"
-#include "chat-template.hpp"
+#include "chat.h"
 
 #include <random>
 #include <sstream>
@@ -347,41 +345,6 @@ static llama_tokens format_infill(
     return embd_inp;
 }
 
-// Format given chat. If tmpl is empty, we take the template from model metadata
-inline std::string format_chat(const common_chat_template & tmpl, const std::vector<json> & messages) {
-    std::vector<common_chat_msg> chat;
-
-    for (size_t i = 0; i < messages.size(); ++i) {
-        const auto & curr_msg = messages[i];
-
-        std::string role = json_value(curr_msg, "role", std::string(""));
-
-        std::string content;
-        if (curr_msg.contains("content")) {
-            if (curr_msg["content"].is_string()) {
-                content = curr_msg["content"].get<std::string>();
-            } else if (curr_msg["content"].is_array()) {
-                for (const auto & part : curr_msg["content"]) {
-                    if (part.contains("text")) {
-                        content += "\n" + part["text"].get<std::string>();
-                    }
-                }
-            } else {
-                throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
-            }
-        } else {
-            throw std::runtime_error("Missing 'content' (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
-        }
-
-        chat.push_back({role, content, /* tool_calls= */ {}});
-    }
-
-    const auto formatted_chat = common_chat_apply_template(tmpl, chat, true, /* use_jinja= */ false);
-    LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str());
-
-    return formatted_chat;
-}
-
 //
 // base64 utils (TODO: move to common in the future)
 //
@@ -579,12 +542,9 @@ static json oaicompat_completion_params_parse(
     const json & body, /* openai api json semantics */
     bool use_jinja,
     common_reasoning_format reasoning_format,
-    const common_chat_templates & chat_templates)
+    const struct common_chat_templates * tmpls)
 {
     json llama_params;
-    const auto & tmpl = body.contains("tools") && chat_templates.template_tool_use
-        ? *chat_templates.template_tool_use
-        : *chat_templates.template_default;
 
     auto tools = json_value(body, "tools", json());
     auto stream = json_value(body, "stream", false);
@@ -610,62 +570,58 @@ static json oaicompat_completion_params_parse(
         llama_params["stop"] = json_value(body, "stop", json::array());
     }
 
+    auto json_schema = json_value(body, "json_schema", json());
+    auto grammar = json_value(body, "grammar", std::string());
+    if (!json_schema.is_null() && !grammar.empty()) {
+        throw std::runtime_error("Cannot use both json_schema and grammar");
+    }
+
     // Handle "response_format" field
     if (body.contains("response_format")) {
         json response_format      = json_value(body, "response_format", json::object());
         std::string response_type = json_value(response_format, "type", std::string());
         if (response_type == "json_object") {
-            llama_params["json_schema"] = json_value(response_format, "schema", json::object());
+            json_schema = json_value(response_format, "schema", json::object());
         } else if (response_type == "json_schema") {
             json json_schema = json_value(response_format, "json_schema", json::object());
-            llama_params["json_schema"] = json_value(json_schema, "schema", json::object());
+            json_schema = json_value(json_schema, "schema", json::object());
         } else if (!response_type.empty() && response_type != "text") {
             throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type);
         }
     }
 
+    common_chat_templates_inputs inputs;
+    inputs.messages              = common_chat_msgs_parse_oaicompat(body.at("messages"));
+    inputs.tools                 = common_chat_tools_parse_oaicompat(tools);
+    inputs.tool_choice           = common_chat_tool_choice_parse_oaicompat(json_value(body, "tool_choice", std::string("auto")));
+    inputs.json_schema           = json_schema.is_null() ? "" : json_schema.dump();
+    inputs.grammar               = grammar;
+    inputs.add_generation_prompt = true;
+    inputs.use_jinja             = use_jinja;
+    inputs.parallel_tool_calls   = json_value(body, "parallel_tool_calls", false);
+    inputs.extract_reasoning     = reasoning_format != COMMON_REASONING_FORMAT_NONE;
+    if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) {
+        throw std::runtime_error("Cannot use custom grammar constraints with tools.");
+    }
+
     // Apply chat template to the list of messages
-    if (use_jinja) {
-        auto tool_choice = json_value(body, "tool_choice", std::string("auto"));
-        if (tool_choice != "none" && tool_choice != "auto" && tool_choice != "required") {
-            throw std::runtime_error("Invalid tool_choice: " + tool_choice);
-        }
-        if (tool_choice != "none" && llama_params.contains("grammar")) {
-            throw std::runtime_error("Cannot use custom grammar constraints with tools.");
-        }
-        common_chat_inputs inputs;
-        inputs.extract_reasoning   = reasoning_format != COMMON_REASONING_FORMAT_NONE;
-        inputs.messages            = body.at("messages");
-        inputs.tools               = tools;
-        inputs.tool_choice         = tool_choice;
-        inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
-        if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) {
-            LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n");
-            inputs.parallel_tool_calls = false;
-        }
-        inputs.stream = stream;
-        // TODO: support mixing schema w/ tools beyond generic format.
-        inputs.json_schema = json_value(llama_params, "json_schema", json());
-        auto chat_params = common_chat_params_init(tmpl, inputs);
-
-        llama_params["chat_format"] = static_cast<int>(chat_params.format);
-        llama_params["prompt"] = chat_params.prompt;
-        llama_params["grammar"] = chat_params.grammar;
-        llama_params["grammar_lazy"] = chat_params.grammar_lazy;
-        auto grammar_triggers = json::array();
-        for (const auto & trigger : chat_params.grammar_triggers) {
-            grammar_triggers.push_back({
-                {"word", trigger.word},
-                {"at_start", trigger.at_start},
-            });
-        }
-        llama_params["grammar_triggers"] = grammar_triggers;
-        llama_params["preserved_tokens"] = chat_params.preserved_tokens;
-        for (const auto & stop : chat_params.additional_stops) {
-            llama_params["stop"].push_back(stop);
-        }
-    } else {
-        llama_params["prompt"] = format_chat(tmpl, body.at("messages"));
+    auto chat_params = common_chat_templates_apply(tmpls, inputs);
+
+    llama_params["chat_format"]      = static_cast<int>(chat_params.format);
+    llama_params["prompt"]           = chat_params.prompt;
+    llama_params["grammar"]          = chat_params.grammar;
+    llama_params["grammar_lazy"]     = chat_params.grammar_lazy;
+    auto grammar_triggers = json::array();
+    for (const auto & trigger : chat_params.grammar_triggers) {
+        grammar_triggers.push_back({
+            {"word", trigger.word},
+            {"at_start", trigger.at_start},
+        });
+    }
+    llama_params["grammar_triggers"] = grammar_triggers;
+    llama_params["preserved_tokens"] = chat_params.preserved_tokens;
+    for (const auto & stop : chat_params.additional_stops) {
+        llama_params["stop"].push_back(stop);
     }
 
     // Handle "n" field
index e0314ae1d62966e541ff8f2d063452822471a65f..9231c517afb0b85d0df60e3868076addd60cce64 100644 (file)
@@ -1,13 +1,14 @@
 #include <string>
 #include <vector>
 #include <sstream>
+#include <regex>
 
 #undef NDEBUG
 #include <cassert>
 
 #include "llama.h"
 #include "common.h"
-#include "chat-template.hpp"
+#include "chat.h"
 
 static std::string normalize_newlines(const std::string & s) {
 #ifdef _WIN32
@@ -18,6 +19,13 @@ static std::string normalize_newlines(const std::string & s) {
 #endif
 }
 
+static common_chat_msg simple_msg(const std::string & role, const std::string & content) {
+    common_chat_msg msg;
+    msg.role = role;
+    msg.content = content;
+    return msg;
+}
+
 int main(void) {
     std::vector<llama_chat_message> conversation {
         {"system", "You are a helpful assistant"},
@@ -50,7 +58,7 @@ int main(void) {
             /* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
             /* .expected_output= */ "[INST] You are a helpful assistant\nHello [/INST]Hi there</s>[INST] Who are you [/INST]   I am an assistant   </s>[INST] Another question [/INST]",
             /* .expected_output_jinja= */ "",
-            /* .bos_token= */ "",
+            /* .bos_token= */ "<s>",
             /* .eos_token= */ "</s>",
         },
         {
@@ -72,8 +80,8 @@ int main(void) {
         {
             /* .name= */ "mlabonne/AlphaMonarch-7B",
             /* .template_str= */ "{% for message in messages %}{{bos_token + message['role'] + '\\n' + message['content'] + eos_token + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ bos_token + 'assistant\\n' }}{% endif %}",
-            /* .expected_output= */          "system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n   I am an assistant   </s>\n<s>user\nAnother question</s>\n<s>assistant\n",
-            /* .expected_output_jinja= */ "<s>system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n   I am an assistant   </s>\n<s>user\nAnother question</s>\n<s>assistant\n",
+            /* .expected_output= */ "system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n   I am an assistant   </s>\n<s>user\nAnother question</s>\n<s>assistant\n",
+            /* .expected_output_jinja= */ "",
             /* .bos_token= */ "<s>",
             /* .eos_token= */ "</s>",
         },
@@ -87,7 +95,7 @@ int main(void) {
             /* .name= */ "OrionStarAI/Orion-14B-Chat",
             /* .template_str= */ "{% for message in messages %}{% if loop.first %}{{ bos_token }}{% endif %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'] + '\\n\\nAssistant: ' + eos_token }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% endif %}{% endfor %}",
             /* .expected_output= */       "Human: You are a helpful assistant\n\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s>   I am an assistant   </s>Human: Another question\n\nAssistant: </s>",
-            /* .expected_output_jinja= */ "Human: You are a helpful assistant\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s>   I am an assistant   </s>Human: Another question\n\nAssistant: </s>",
+            /* .expected_output_jinja= */ "Human: You are a helpful assistant\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s>   I am an assistant   </s>Human: Another question\n\nAssistant: ",
             /* .bos_token= */ "",
             /* .eos_token= */ "</s>",
         },
@@ -304,12 +312,9 @@ int main(void) {
         }
     }
 
-    json messages = json::array();
+    std::vector<common_chat_msg> messages;
     for (const auto & msg : conversation) {
-        messages.push_back({
-            {"role", msg.role},
-            {"content", msg.content},
-        });
+        messages.push_back(simple_msg(msg.role, msg.content));
     }
     for (const auto & test_case : test_cases) {
         if (!test_case.supported_with_jinja) {
@@ -317,8 +322,13 @@ int main(void) {
         }
         printf("\n\n=== %s (jinja) ===\n\n", test_case.name.c_str());
         try {
-            minja::chat_template tmpl(test_case.template_str, test_case.bos_token, test_case.eos_token);
-            auto output = normalize_newlines(tmpl.apply(messages, json(), add_generation_prompt));
+            auto tmpls = common_chat_templates_init(/* model= */ nullptr, test_case.template_str.c_str(), test_case.bos_token, test_case.eos_token);
+            common_chat_templates_inputs inputs;
+            inputs.use_jinja = true;
+            inputs.messages = messages;
+            inputs.add_generation_prompt = add_generation_prompt;
+            auto output = common_chat_templates_apply(tmpls.get(), inputs).prompt;
+            output = normalize_newlines(output);
             auto expected_output = normalize_newlines(test_case.expected_output_jinja.empty() ? test_case.expected_output : test_case.expected_output_jinja);
             if (output != expected_output) {
                 printf("Expected:\n%s\n", expected_output.c_str());
@@ -336,11 +346,11 @@ int main(void) {
     // test llama_chat_format_single for system message
     printf("\n\n=== llama_chat_format_single (system message) ===\n\n");
     std::vector<common_chat_msg> chat2;
-    common_chat_msg sys_msg{"system", "You are a helpful assistant", {}};
+    auto sys_msg = simple_msg("system", "You are a helpful assistant");
 
     auto fmt_sys = [&](std::string tmpl_str) {
-        minja::chat_template tmpl(tmpl_str, "", "");
-        auto output = common_chat_format_single(tmpl, chat2, sys_msg, false, /* use_jinja= */ false);
+        auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl_str);
+        auto output = common_chat_format_single(tmpls.get(), chat2, sys_msg, false, /* use_jinja= */ false);
         printf("fmt_sys(%s) : %s\n", tmpl_str.c_str(), output.c_str());
         printf("-------------------------\n");
         return output;
@@ -360,14 +370,14 @@ int main(void) {
 
     // test llama_chat_format_single for user message
     printf("\n\n=== llama_chat_format_single (user message) ===\n\n");
-    chat2.push_back({"system", "You are a helpful assistant", {}});
-    chat2.push_back({"user", "Hello", {}});
-    chat2.push_back({"assistant", "I am assistant", {}});
-    common_chat_msg new_msg{"user", "How are you", {}};
+    chat2.push_back(simple_msg("system", "You are a helpful assistant"));
+    chat2.push_back(simple_msg("user", "Hello"));
+    chat2.push_back(simple_msg("assistant", "I am assistant"));
+    auto new_msg = simple_msg("user", "How are you");
 
-    auto fmt_single = [&](std::string tmpl_str) {
-        minja::chat_template tmpl(tmpl_str, "", "");
-        auto output = common_chat_format_single(tmpl, chat2, new_msg, true, /* use_jinja= */ false);
+    auto fmt_single = [&](const std::string & tmpl_str) {
+        auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl_str.c_str());
+        auto output = common_chat_format_single(tmpls.get(), chat2, new_msg, true, /* use_jinja= */ false);
         printf("fmt_single(%s) : %s\n", tmpl_str.c_str(), output.c_str());
         printf("-------------------------\n");
         return output;
index 2836caf6a71a3cc70f358863000a43e5dd7a45c3..64359230548591aa9f733d3d103d8b6786775731 100644 (file)
 #include <json.hpp>
 #include <string>
 
-#include "chat-template.hpp"
-#include "chat.hpp"
+#include "chat.h"
 #include "llama-grammar.h"
 #include "unicode.h"
 
 using json = nlohmann::ordered_json;
 
-static common_chat_msg msg_from_json(const json & message) {
-    common_chat_msg ret;
-    ret.role = "assistant";
-    if (message.contains("content") && !message.at("content").is_null()) {
-        ret.content = message.at("content");
-    }
-    if (message.contains("tool_plan")) {
-        ret.reasoning_content = message.at("tool_plan");
-    }
-    if (message.contains("reasoning_content")) {
-        ret.reasoning_content = message.at("reasoning_content");
-    }
-    auto has_tool_calls = message.contains("tool_calls");
-    if (has_tool_calls) {
-        for (const auto & tc : message.at("tool_calls")) {
-            const auto & arguments = tc.at("function").at("arguments");
-            ret.tool_calls.push_back({
-                tc.at("function").at("name").get<std::string>(),
-                arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
-                tc.contains("id") ? tc.at("id").get<std::string>() : "",
-            });
-        }
-    }
-    return ret;
-}
 
 template <class T> static void assert_equals(const T & expected, const T & actual) {
     if (expected != actual) {
@@ -53,7 +27,7 @@ template <class T> static void assert_equals(const T & expected, const T & actua
 }
 
 static std::string read_file(const std::string & path) {
-    std::cerr << "# Reading: " << path << std::endl << std::flush;
+    std::cerr << "# Reading: " << path << '\n' << std::flush;
     std::ifstream fs(path, std::ios_base::binary);
     if (!fs.is_open()) {
         fs = std::ifstream("../" + path, std::ios_base::binary);
@@ -66,10 +40,14 @@ static std::string read_file(const std::string & path) {
     fs.seekg(0);
     std::string out;
     out.resize(static_cast<size_t>(size));
-    fs.read(&out[0], static_cast<std::streamsize>(size));
+    fs.read(out.data(), static_cast<std::streamsize>(size));
     return out;
 }
 
+static common_chat_templates_ptr read_templates(const std::string & path) {
+    return common_chat_templates_ptr(common_chat_templates_init(/* model= */ nullptr, read_file(path)));
+}
+
 static std::unique_ptr<llama_grammar> build_grammar(const std::string & grammar_str) {
     return std::unique_ptr<llama_grammar>(
         llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0));
@@ -90,110 +68,102 @@ static bool match_string(const std::string & input, llama_grammar * grammar) {
         }
     }
 
-    for (const auto & stack : stacks_cur) {
-        if (stack.empty()) {
-            // An empty stack means that the grammar has been completed
-            return true;
-        }
+    if (std::any_of(stacks_cur.begin(), stacks_cur.end(), [](const auto & stack) { return stack.empty(); })) {
+        // An empty stack means that the grammar has been completed
+        return true;
     }
 
     return false;
 }
 
-// Dumps `{"a": 1}` as `"{\"a\": 1}"`, unlike nlohmann::json::dump which would dump it as `"{\"a\":1}"`.
-static std::string dump(const json & j) {
-    return minja::Value(j).dump(-1, /* to_json= */ true);
-}
-
 static void assert_msg_equals(const common_chat_msg & expected, const common_chat_msg & actual) {
     assert_equals(expected.role, actual.role);
     assert_equals(expected.content, actual.content);
+    assert_equals(expected.content_parts.size(), actual.content_parts.size());
+    for (size_t i = 0; i < expected.content_parts.size(); i++) {
+        const auto & expected_part = expected.content_parts[i];
+        const auto & actual_part   = actual.content_parts[i];
+        assert_equals(expected_part.type, actual_part.type);
+        assert_equals(expected_part.text, actual_part.text);
+    }
     assert_equals(expected.reasoning_content, actual.reasoning_content);
     assert_equals(expected.tool_calls.size(), actual.tool_calls.size());
     for (size_t i = 0; i < expected.tool_calls.size(); i++) {
         const auto & expected_tool_call = expected.tool_calls[i];
         const auto & actual_tool_call   = actual.tool_calls[i];
         assert_equals(expected_tool_call.name, actual_tool_call.name);
-        assert_equals(dump(json::parse(expected_tool_call.arguments)), dump(json::parse(actual_tool_call.arguments)));
+        assert_equals(json::parse(expected_tool_call.arguments).dump(), json::parse(actual_tool_call.arguments).dump());
         assert_equals(expected_tool_call.id, actual_tool_call.id);
     }
 }
 
-const auto special_function_tool = json::parse(R"({
-  "type": "function",
-  "function": {
-    "name": "special_function",
-    "description": "I'm special",
-    "parameters": {
-      "type": "object",
-      "properties": {
-        "arg1": {
-          "type": "integer",
-          "description": "The arg."
-        }
-      },
-      "required": ["arg1"]
-    }
-  }
-})");
-const auto python_tool           = json::parse(R"({
-  "type": "function",
-  "function": {
-    "name": "python",
-    "description": "an ipython interpreter",
-    "parameters": {
-      "type": "object",
-      "properties": {
-        "code": {
-          "type": "string",
-          "description": "Python code to execute."
-        }
-      },
-      "required": ["code"]
-    }
-  }
-})");
-const auto code_interpreter_tool = json::parse(R"({
-  "type": "function",
-  "function": {
-    "name": "code_interpreter",
-    "description": "an ipython interpreter",
-    "parameters": {
-      "type": "object",
-      "properties": {
-        "code": {
-          "type": "string",
-          "description": "Python code to execute."
-        }
-      },
-      "required": ["code"]
-    }
-  }
-})");
-const json tools                 = { special_function_tool, python_tool };
-const json llama_3_1_tools       = { special_function_tool, code_interpreter_tool };
+common_chat_tool special_function_tool {
+    /* .name = */ "special_function",
+    /* .description = */ "I'm special",
+    /* .parameters = */ R"({
+        "type": "object",
+        "properties": {
+            "arg1": {
+                "type": "integer",
+                "description": "The arg."
+            }
+        },
+        "required": ["arg1"]
+    })",
+};
+common_chat_tool python_tool {
+    /* .name = */ "python",
+    /* .description = */ "an ipython interpreter",
+    /* .parameters = */ R"({
+        "type": "object",
+        "properties": {
+            "code": {
+                "type": "string",
+                "description": "Python code to execute."
+            }
+        },
+        "required": ["code"]
+    })",
+};
+common_chat_tool code_interpreter_tool {
+    /* .name = */ "code_interpreter",
+    /* .description = */ "an ipython interpreter",
+    /* .parameters = */ R"({
+        "type": "object",
+        "properties": {
+            "code": {
+                "type": "string",
+                "description": "Python code to execute."
+            }
+        },
+        "required": ["code"]
+    })",
+};
+std::vector<common_chat_tool> tools           { special_function_tool, python_tool };
+std::vector<common_chat_tool> llama_3_1_tools { special_function_tool, code_interpreter_tool };
 
 struct delta_data {
     std::string        delta;
     common_chat_params params;
 };
 
-static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
-                             const json & user_message, const json & delta_message, const json & tools,
-                             const json & tool_choice,
+static delta_data init_delta(const struct common_chat_templates * tmpls, const std::vector<std::string> & end_tokens,
+                             const common_chat_msg & user_message,
+                             const common_chat_msg & delta_message,
+                             const std::vector<common_chat_tool> & tools,
+                             const common_chat_tool_choice & tool_choice,
                              bool think = false) {
-    common_chat_inputs inputs;
+    common_chat_templates_inputs inputs;
     inputs.parallel_tool_calls = true;
-    inputs.messages            = json::array();
     inputs.messages.push_back(user_message);
     inputs.tools       = tools;
     inputs.tool_choice = tool_choice;
     inputs.extract_reasoning = think;
-    auto params_prefix = common_chat_params_init(tmpl, inputs);
+    auto params_prefix = common_chat_templates_apply(tmpls, inputs);
 
     inputs.messages.push_back(delta_message);
     inputs.add_generation_prompt = false;
-    auto params_full             = common_chat_params_init(tmpl, inputs);
+    auto params_full             = common_chat_templates_apply(tmpls, inputs);
 
     std::string prefix = params_prefix.prompt;
     std::string full   = params_full.prompt;
@@ -234,30 +204,29 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto
   gets the diff, removes any end tokens and parses the result w/ the grammar, checking that
   the parsed message is the same as the test_message
 */
-static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
-                          const json & test_message, const json & tools = {}, const std::string & expected_delta = "",
+static void test_templates(const struct common_chat_templates * tmpls, const std::vector<std::string> & end_tokens,
+                          const common_chat_msg & test_message,
+                          const std::vector<common_chat_tool> & tools = {},
+                          const std::string & expected_delta = "",
                           bool expect_grammar_triggered = true,
                           bool test_grammar_if_triggered = true,
                           bool think = false) {
-    common_chat_msg expected_msg = msg_from_json(test_message);
-
-    auto user_message = json{
-        { "role",    "user"          },
-        { "content", "Hello, world!" }
-    };
+    common_chat_msg user_message;
+    user_message.role = "user";
+    user_message.content = "Hello, world!";
 
-    for (const auto & tool_choice : json({ "auto", "required" })) {
-        auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools, tool_choice, think);
+    for (const auto & tool_choice : std::vector<common_chat_tool_choice> {COMMON_CHAT_TOOL_CHOICE_AUTO, COMMON_CHAT_TOOL_CHOICE_REQUIRED}) {
+        auto data = init_delta(tmpls, end_tokens, user_message, test_message, tools, tool_choice, think);
         if (!expected_delta.empty()) {
             assert_equals(expected_delta, data.delta);
         }
 
         if (expect_grammar_triggered) {
             const auto msg = common_chat_parse(data.delta, data.params.format);
-            assert_msg_equals(expected_msg, msg);
+            assert_msg_equals(test_message, msg);
         }
 
-        if (!expected_msg.tool_calls.empty()) {
+        if (!test_message.tool_calls.empty()) {
             GGML_ASSERT(!data.params.grammar.empty());
         }
         if (!data.params.grammar.empty()) {
@@ -297,246 +266,339 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
     }
 }
 
-static void test_template_output_parsers() {
-    json message_user {
-        { "role",    "user"     },
-        { "content", "Hey there!" },
-    };
-    json message_assist {
-        { "role",    "assistant"     },
-        { "content", "Hello, world!\nWhat's up?" },
-    };
-    json message_assist_thoughts_unparsed_think {
-        { "role",    "assistant"     },
-        { "content", "<think>I'm thinking</think>Hello, world!\nWhat's up?" },
-    };
-    json message_assist_thoughts_unparsed_r7b {
-        { "role",    "assistant"     },
-        { "content", "<|START_THINKING|>I'm thinking<|END_THINKING|>Hello, world!\nWhat's up?" },
-    };
-    json message_assist_thoughts {
-        { "role",    "assistant"     },
-        { "content", "Hello, world!\nWhat's up?" },
-        { "reasoning_content", "I'm thinking" },
-    };
-    json tool_calls = json::array({{
-        { "type", "function" },
-        { "function", { { "name", "special_function" }, { "arguments", "{\"arg1\": 1}" } } },
-    }});
-
-    json message_assist_call {
-        { "role",       "assistant"},
-        { "content",    {}},
-        { "tool_calls", {
-            {
-                { "type", "function" },
-                { "function", {
-                    { "name", "special_function" },
-                    { "arguments", "{\"arg1\": 1}" },
-                }},
-            },
-        }},
-    };
-    json message_assist_call_thoughts = {
-        { "role",       "assistant"                },
-        { "content",    nullptr                    },
-        { "reasoning_content",   "I'm\nthinking"              },
-        { "tool_calls",  {
-            {
-                { "type", "function" },
-                { "function", {
-                    { "name", "special_function" },
-                    { "arguments", "{\"arg1\": 1}" },
-                }},
-            },
-        }},
-    };
-    json message_assist_call_thoughts_unparsed = {
-        { "role",       "assistant"                },
-        { "content",    "<think>I'm\nthinking</think>" },
-        { "tool_calls",  {
-            {
-                { "type", "function" },
-                { "function", {
-                    { "name", "special_function" },
-                    { "arguments", "{\"arg1\": 1}" },
-                }},
-            },
-        }},
-    };
-    json message_assist_call_id {
-        { "role",       "assistant"},
-        { "content",    {}},
-        { "tool_calls", {
-            {
-                { "type", "function" },
-                { "function", {
-                    { "name", "special_function" },
-                    { "arguments", "{\"arg1\": 1}" },
-                }},
-                {"id", "123456789"},
-            },
-        }},
-        { "role",       "assistant"                },
-        { "content",    {}                         },
-        { "tool_calls", tool_calls                  }
-    };
-    json message_assist_call_idx {
-        { "role",       "assistant"},
-        { "content",    {}},
-        { "tool_calls", {
-            {
-                { "type", "function" },
-                { "function", {
-                    { "name", "special_function" },
-                    { "arguments", "{\"arg1\": 1}" },
-                }},
-                // Index of the tool call in the tool_calls array
-                {"id", "0"},
-            },
-        }},
-        { "role",       "assistant"                },
-        { "content",    {}                         },
-        { "tool_calls", tool_calls                  }
-    };
-    json message_assist_call_tool_plan_idx = message_assist_call_idx;
-    message_assist_call_tool_plan_idx["tool_plan"] = "I'm thinking";
-
-    auto python_message_assist_call = json{
-        { "role",       "assistant"                },
-        { "content",    {}                         },
-        { "tool_calls", json{ {
-                            { "type", "function" },
-                            { "function",
-                              {
-                                  { "name", "python" },
-                                  { "arguments",
-                                    {
-                                        { "code", "print('hey')" },
-                                    } },
-                              } },
-                        } } }
+const common_chat_msg message_user {
+    "user",
+    "Hey there!",
+    /* .content_parts = */ {},
+    /* .tool_calls = */ {},
+    /* .reasoning_content = */ "",
+    /* .tool_name = */ "",
+    /* .tool_call_id = */ "",
+};
+
+const common_chat_msg message_user_parts {
+    "user",
+    /* .content = */ "",
+    /* .content_parts = */ {
+        { "text", "Hey" },
+        { "text", "there" },
+    },
+    /* .tool_calls = */ {},
+    /* .reasoning_content = */ "",
+    /* .tool_name = */ "",
+    /* .tool_call_id = */ "",
+};
+const common_chat_msg message_assist {
+    "assistant",
+    "Hello, world!\nWhat's up?",
+    /* .content_parts = */ {},
+    /* .tool_calls = */ {},
+    /* .reasoning_content = */ "",
+    /* .tool_name = */ "",
+    /* .tool_call_id = */ "",
+};
+const common_chat_msg message_assist_thoughts_unparsed_think {
+    "assistant",
+    "<think>I'm thinking</think>Hello, world!\nWhat's up?",
+    /* .content_parts = */ {},
+    /* .tool_calls = */ {},
+    /* .reasoning_content = */ "",
+    /* .tool_name = */ "",
+    /* .tool_call_id = */ "",
+};
+const common_chat_msg message_assist_thoughts_unparsed_r7b {
+    "assistant",
+    "<|START_THINKING|>I'm thinking<|END_THINKING|>Hello, world!\nWhat's up?",
+    /* .content_parts = */ {},
+    /* .tool_calls = */ {},
+    /* .reasoning_content = */ "",
+    /* .tool_name = */ "",
+    /* .tool_call_id = */ "",
+};
+const common_chat_msg message_assist_thoughts {
+    "assistant",
+    "Hello, world!\nWhat's up?",
+    /* .content_parts = */ {},
+    /* .tool_calls = */ {},
+    /* .reasoning_content = */ "I'm thinking",
+    /* .tool_name = */ "",
+    /* .tool_call_id = */ "",
+};
+const std::vector<common_chat_tool_call> tool_calls {
+    { "special_function", "{\"arg1\": 1}", /* .id = */ "" },
+};
+const std::vector<common_chat_tool_call> tool_calls_idx {
+    { "special_function", "{\"arg1\": 1}", /* .id = */ "0" },
+};
+const std::vector<common_chat_tool_call> tool_calls_id {
+    { "special_function", "{\"arg1\": 1}", /* .id = */ "123456789" },
+};
+
+const common_chat_msg message_assist_call {
+    "assistant",
+    "",
+    /* .content_parts = */ {},
+    tool_calls,
+    /* .reasoning_content = */ "",
+    /* .tool_name = */ "",
+    /* .tool_call_id = */ "",
+};
+const common_chat_msg message_assist_call_thoughts = {
+    "assistant",
+    /* .content = */ "",
+    /* .content_parts = */ {},
+    tool_calls,
+    /* .reasoning_content = */ "I'm\nthinking",
+    /* .tool_name = */ "",
+    /* .tool_call_id = */ "",
+};
+const common_chat_msg message_assist_call_thoughts_unparsed = {
+    "assistant",
+    /* .content = */ "<think>I'm\nthinking</think>",
+    /* .content_parts = */ {},
+    tool_calls,
+    /* .reasoning_content = */ "",
+    /* .tool_name = */ "",
+    /* .tool_call_id = */ "",
+};
+const common_chat_msg message_assist_call_id {
+    "assistant",
+    "",
+    /* .content_parts = */ {},
+    tool_calls_id,
+    /* .reasoning_content = */ "",
+    /* .tool_name = */ "",
+    /* .tool_call_id = */ "",
+};
+const common_chat_msg message_assist_call_idx {
+    "assistant",
+    "",
+    /* .content_parts = */ {},
+    tool_calls_idx,
+    /* .reasoning_content = */ "",
+    /* .tool_name = */ "",
+    /* .tool_call_id = */ "",
+};
+const common_chat_msg message_assist_call_python {
+    "assistant",
+    "",
+    /* .content_parts = */ {},
+    { { "python", "{\"code\": \"print('hey')\"}", /* .id = */ "" } },
+    /* .reasoning_content = */ "",
+    /* .tool_name = */ "",
+    /* .tool_call_id = */ "",
+};
+const common_chat_msg message_assist_call_code_interpreter {
+    "assistant",
+    "",
+    /* .content_parts = */ {},
+    { { "code_interpreter", "{\"code\": \"print('hey')\"}", /* .id = */ "" } },
+    /* .reasoning_content = */ "",
+    /* .tool_name = */ "",
+    /* .tool_call_id = */ "",
+};
+
+static void test_msgs_oaicompat_json_conversion() {
+    std::vector<common_chat_msg> msgs{
+        message_user,
+        message_user_parts,
+        message_assist_call,
+        message_assist_call_thoughts,
+        message_assist_call_thoughts_unparsed,
+        message_assist_call_id,
+        message_assist_call_idx,
+        message_assist_call_python,
+        message_assist_call_code_interpreter,
     };
-    auto code_interpreter_message_assist_call = json{
-        { "role",       "assistant"                },
-        { "content",    {}                         },
-        { "tool_calls", json{ {
-                            { "type", "function" },
-                            { "function",
-                              {
-                                  { "name", "code_interpreter" },
-                                  { "arguments",
-                                    {
-                                        { "code", "print('hey')" },
-                                    } },
-                              } },
-                        } } }
+    for (const auto & msg : msgs) {
+        auto oai_json = common_chat_msgs_to_json_oaicompat<json>({msg});
+        auto msgs2 = common_chat_msgs_parse_oaicompat(oai_json);
+        assert_equals((size_t) 1, msgs2.size());
+        auto msg2 = msgs2[0];
+        assert_msg_equals(msg, msg2);
+    }
+    assert_equals(
+        std::string(
+            "[\n"
+            "  {\n"
+            "    \"role\": \"user\",\n"
+            "    \"content\": [\n"
+            "      {\n"
+            "        \"type\": \"text\",\n"
+            "        \"text\": \"Hey\"\n"
+            "      },\n"
+            "      {\n"
+            "        \"type\": \"text\",\n"
+            "        \"text\": \"there\"\n"
+            "      }\n"
+            "    ]\n"
+            "  }\n"
+            "]"
+        ),
+        common_chat_msgs_to_json_oaicompat<json>({message_user_parts}).dump(2));
+
+    assert_equals(
+        std::string(
+            "[\n"
+            "  {\n"
+            "    \"role\": \"assistant\",\n"
+            "    \"content\": null,\n"
+            "    \"tool_calls\": [\n"
+            "      {\n"
+            "        \"type\": \"function\",\n"
+            "        \"function\": {\n"
+            "          \"name\": \"python\",\n"
+            "          \"arguments\": \"{\\\"code\\\": \\\"print('hey')\\\"}\"\n"
+            "        }\n"
+            "      }\n"
+            "    ]\n"
+            "  }\n"
+            "]"
+        ),
+        common_chat_msgs_to_json_oaicompat<json>({message_assist_call_python}).dump(2));
+}
+
+static void test_tools_oaicompat_json_conversion() {
+    std::vector<common_chat_tool> tools{
+        special_function_tool,
+        python_tool,
+        code_interpreter_tool,
     };
 
-    common_chat_inputs inputs_no_tools;
-    inputs_no_tools.messages                = json::array({message_user});
+    for (const auto & tool : tools) {
+        auto oai_json = common_chat_tools_to_json_oaicompat<json>({tool});
+        auto tools2 = common_chat_tools_parse_oaicompat(oai_json);
+        assert_equals((size_t) 1, tools2.size());
+        auto tool2 = tools2[0];
+        assert_equals(tool.name, tool2.name);
+        assert_equals(tool.description, tool2.description);
+        assert_equals(json::parse(tool.parameters).dump(2), json::parse(tool2.parameters).dump(2));
+    }
+
+    assert_equals(
+        std::string(
+            "[\n"
+            "  {\n"
+            "    \"type\": \"function\",\n"
+            "    \"function\": {\n"
+            "      \"name\": \"special_function\",\n"
+            "      \"description\": \"I'm special\",\n"
+            "      \"parameters\": {\n"
+            "        \"type\": \"object\",\n"
+            "        \"properties\": {\n"
+            "          \"arg1\": {\n"
+            "            \"type\": \"integer\",\n"
+            "            \"description\": \"The arg.\"\n"
+            "          }\n"
+            "        },\n"
+            "        \"required\": [\n"
+            "          \"arg1\"\n"
+            "        ]\n"
+            "      }\n"
+            "    }\n"
+            "  }\n"
+            "]"
+        ),
+        common_chat_tools_to_json_oaicompat<json>({special_function_tool}).dump(2));
+}
+
+static void test_template_output_parsers() {
+
+    common_chat_templates_inputs inputs_no_tools;
+    inputs_no_tools.messages                = {message_user};
     inputs_no_tools.extract_reasoning       = false;
 
-    common_chat_inputs inputs_no_tools_think;
-    inputs_no_tools_think.messages          = json::array({message_user});
+    common_chat_templates_inputs inputs_no_tools_think;
+    inputs_no_tools_think.messages          = {message_user};
     inputs_no_tools_think.extract_reasoning = true;
 
-    common_chat_inputs inputs_tools;
-    inputs_tools.messages                   = json::array({message_user});
-    inputs_tools.tools                      = json::array({special_function_tool});
+    common_chat_templates_inputs inputs_tools;
+    inputs_tools.messages                   = {message_user};
+    inputs_tools.tools                      = {special_function_tool};
     inputs_tools.extract_reasoning          = false;
 
-    common_chat_inputs inputs_tools_think;
-    inputs_tools_think.messages             = json::array({message_user});
-    inputs_tools_think.tools                = json::array({special_function_tool});
+    common_chat_templates_inputs inputs_tools_think;
+    inputs_tools_think.messages             = {message_user};
+    inputs_tools_think.tools                = {special_function_tool};
     inputs_tools_think.extract_reasoning    = true;
 
-    common_chat_inputs inputs_tools_builtin;
-    inputs_tools_builtin.messages           = json::array({message_user});
-    inputs_tools_builtin.tools              = json::array({python_tool});
+    common_chat_templates_inputs inputs_tools_builtin;
+    inputs_tools_builtin.messages           = {message_user};
+    inputs_tools_builtin.tools              = {python_tool};
     inputs_tools_builtin.extract_reasoning  = false;
 
     {
         // Not supported yet
-        const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja"), "<s>", "</s>");
-        assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_params_init(tmpl, inputs_tools).format);
+        auto tmpls = read_templates("models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja");
+        assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
     }
     {
-        const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja"), "<s>", "</s>");
+        auto tmpls = read_templates("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja");
         std::vector<std::string>   end_tokens{ "<|END_OF_TURN_TOKEN|>" };
 
-        assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B,                   common_chat_params_init(tmpl, inputs_no_tools).format);
-        assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B,                   common_chat_params_init(tmpl, inputs_tools).format);
-        assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING, common_chat_params_init(tmpl, inputs_tools_think).format);
+        assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B,                   common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B,                   common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format);
 
-        assert_msg_equals(msg_from_json(message_assist),
+        assert_msg_equals(message_assist,
             common_chat_parse(
                 "Hello, world!\nWhat's up?",
                 COMMON_CHAT_FORMAT_COMMAND_R7B));
-        assert_msg_equals(msg_from_json(message_assist),
+        assert_msg_equals(message_assist,
             common_chat_parse(
                 "Hello, world!\nWhat's up?<|END_RESPONSE|>",
                 COMMON_CHAT_FORMAT_COMMAND_R7B));
-        assert_msg_equals(msg_from_json(message_assist),
+        assert_msg_equals(message_assist,
             common_chat_parse(
                 "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
                 COMMON_CHAT_FORMAT_COMMAND_R7B));
-        assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_r7b),
+        assert_msg_equals(message_assist_thoughts_unparsed_r7b,
             common_chat_parse(
                 "<|START_THINKING|>I'm thinking<|END_THINKING|>"
                 "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
                 COMMON_CHAT_FORMAT_COMMAND_R7B));
-        assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_r7b),
+        assert_msg_equals(message_assist_thoughts_unparsed_r7b,
             common_chat_parse(
                 "<|START_THINKING|>I'm thinking<|END_THINKING|>"
                 "Hello, world!\nWhat's up?<|END_RESPONSE|>",
                 COMMON_CHAT_FORMAT_COMMAND_R7B));
 
-        assert_msg_equals(msg_from_json(message_assist_thoughts),
+        assert_msg_equals(message_assist_thoughts,
             common_chat_parse(
                 "<|START_THINKING|>I'm thinking<|END_THINKING|>"
                 "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
                 COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING));
 
-        test_template(tmpl, end_tokens, message_assist_call_idx, tools,
+        test_templates(tmpls.get(), end_tokens, message_assist_call_idx, tools,
                       "<|START_THINKING|><|END_THINKING|>"
                       "<|START_ACTION|>[\n"
                       "    {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n"
                       "]<|END_ACTION|>");
-        test_template(tmpl, end_tokens, message_assist_call_tool_plan_idx, tools,
-                      "<|START_THINKING|>I'm thinking<|END_THINKING|>"
-                      "<|START_ACTION|>[\n"
-                      "    {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n"
-                      "]<|END_ACTION|>",
-                      /* expect_grammar_triggered= */ true,
-                      /* test_grammar_if_triggered= */ true,
-                      /* think= */ true);
-        test_template(tmpl, end_tokens, message_assist, tools,
+        test_templates(tmpls.get(), end_tokens, message_assist, tools,
                       "<|START_RESPONSE|>Hello, world!\n"
                       "What's up?<|END_RESPONSE|>",
                       /* expect_grammar_triggered= */ false);
     }
     {
-        const common_chat_template tmpl(read_file("models/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
+        auto tmpls = read_templates("models/templates/google-gemma-2-2b-it.jinja");
         std::vector<std::string>   end_tokens{ "<end_of_turn>" };
 
-        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_params_init(tmpl, inputs_no_tools).format);
-        assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_params_init(tmpl, inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_GENERIC,
-                      common_chat_params_init(
-                          common_chat_template(read_file("models/templates/microsoft-Phi-3.5-mini-instruct.jinja"),
-                                               "<s>", "</s>"),
+                      common_chat_templates_apply(
+                          read_templates("models/templates/microsoft-Phi-3.5-mini-instruct.jinja").get(),
                           inputs_tools)
                           .format);
 
         // Generic tool calls doesn't generate / parse content-only messages symmetrically.
 
-        assert_msg_equals(msg_from_json(message_assist),
+        assert_msg_equals(message_assist,
                           common_chat_parse("{\n"
                                             "  \"response\": \"Hello, world!\\nWhat's up?\"\n"
                                             "}",
-                                            common_chat_params_init(tmpl, inputs_tools).format));
-        test_template(tmpl, end_tokens, message_assist_call_id, tools,
+                                            common_chat_templates_apply(tmpls.get(), inputs_tools).format));
+        test_templates(tmpls.get(), end_tokens, message_assist_call_id, tools,
                       "{\n"
                       "  \"tool_calls\": [\n"
                       "    {\n"
@@ -550,143 +612,133 @@ static void test_template_output_parsers() {
                       "}");
     }
     {
-        const common_chat_template tmpl(read_file("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>",
-                                        "</s>");
+        auto tmpls = read_templates("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja");
         std::vector<std::string>   end_tokens{ "</s>" };
 
-        assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_params_init(tmpl, inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
 
-        test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
-        test_template(
-            tmpl, end_tokens, message_assist_call_id, tools,
+        test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+        test_templates(
+            tmpls.get(), end_tokens, message_assist_call_id, tools,
             "[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]");
     }
     {
-        const common_chat_template tmpl(
-            read_file("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
+        auto tmpls = read_templates("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja");
         std::vector<std::string> end_tokens{ "<|im_end|>" };
 
-        assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_params_init(tmpl, inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
         assert_equals(
             COMMON_CHAT_FORMAT_HERMES_2_PRO,
-            common_chat_params_init(
-                common_chat_template(read_file("models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"),
-                                     "<s>", "</s>"),
+            common_chat_templates_apply(
+                read_templates("models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja").get(),
                 inputs_tools)
                 .format);
         assert_equals(
             COMMON_CHAT_FORMAT_HERMES_2_PRO,
-            common_chat_params_init(
-                common_chat_template(read_file("models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"),
+            common_chat_templates_apply(
+                read_templates("models/templates/Qwen-Qwen2.5-7B-Instruct.jinja").get(),
                 inputs_tools)
                 .format);
 
-        test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
-        test_template(tmpl, end_tokens, message_assist_call, tools,
+        test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+        test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
                       "<tool_call>\n"
                       "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
                       "</tool_call>");
-        test_template(tmpl, end_tokens, python_message_assist_call, tools,
+        test_templates(tmpls.get(), end_tokens, message_assist_call_python, tools,
                       "<tool_call>\n"
                       "{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n"
                       "</tool_call>");
     }
     {
-        const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "<s>",
-                                        "</s>");
+        auto tmpls = read_templates("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja");
         std::vector<std::string>   end_tokens{ "<|eom_id|>", "<|eot_id|>" };
 
-        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
-                      common_chat_params_init(tmpl, inputs_tools_builtin).format);
+                      common_chat_templates_apply(tmpls.get(), inputs_tools_builtin).format);
         assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
-                      common_chat_params_init(
-                          common_chat_template(read_file("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"),
-                                               "<s>", "</s>"),
+                      common_chat_templates_apply(
+                          read_templates("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja").get(),
                           inputs_tools_builtin)
                           .format);
 
-        // test_template(tmpl, end_tokens, message_assist, tools, R"(?)", /* expect_grammar_triggered= */ false);
-        test_template(tmpl, end_tokens, code_interpreter_message_assist_call, llama_3_1_tools,
+        // test_templates(tmpls.get(), end_tokens, message_assist, tools, R"(?)", /* expect_grammar_triggered= */ false);
+        test_templates(tmpls.get(), end_tokens, message_assist_call_code_interpreter, llama_3_1_tools,
                       "<|python_tag|>code_interpreter.call(code=\"print('hey')\")");
-        test_template(tmpl, end_tokens, python_message_assist_call, tools,
+        test_templates(tmpls.get(), end_tokens, message_assist_call_python, tools,
                       "<|python_tag|>python.call(code=\"print('hey')\")");
-        test_template(tmpl, end_tokens, message_assist_call, tools,
+        test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
                       "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
     }
     {
-        const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>",
-                                        "</s>");
+        auto tmpls = read_templates("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja");
         std::vector<std::string>   end_tokens{ "<|eom_id|>", "<|eot_id|>" };
 
-        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
 
-        test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
-        test_template(tmpl, end_tokens, message_assist_call, tools,
+        test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+        test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
                       "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
     }
     {
-        const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>",
-                                        "</s>");
+        auto tmpls = read_templates("models/templates/meetkai-functionary-medium-v3.1.jinja");
         std::vector<std::string>   end_tokens{ "<|eom_id|>", "<|eot_id|>" };
 
         assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
-                      common_chat_params_init(tmpl, inputs_tools).format);
+                      common_chat_templates_apply(tmpls.get(), inputs_tools).format);
 
-        test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
-        test_template(tmpl, end_tokens, message_assist_call, tools,
+        test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+        test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
                       "<function=special_function>{\"arg1\": 1}</function>");
     }
     {
-        const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>",
-                                        "</s>");
+        auto tmpls = read_templates("models/templates/meetkai-functionary-medium-v3.2.jinja");
         std::vector<std::string>   end_tokens{ "<|eom_id|>", "<|eot_id|>" };
 
-        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_no_tools).format);
-        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
 
-        test_template(tmpl, end_tokens, message_assist, {},
+        test_templates(tmpls.get(), end_tokens, message_assist, {},
                       "all\n"
                       "Hello, world!\n"
                       "What's up?",
                       /* expect_grammar_triggered= */ false);
-        test_template(tmpl, end_tokens, message_assist_call, tools,
+        test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
                       "special_function\n"
                       "{\"arg1\": 1}");
     }
     {
-        const common_chat_template tmpl(read_file("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>",
-                                        "</s>");
+        auto tmpls = read_templates("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja");
         std::vector<std::string>   end_tokens{ "<|eot_id|>" };
 
-        assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_params_init(tmpl, inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
 
-        test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
-        test_template(tmpl, end_tokens, message_assist_call, tools,
+        test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+        test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
                       " functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
     }
     {
         // Original DeepSeek R1 template. Leaves <|tool▁calls▁begin|> and others unclosed. Our logic fixes the prompt.
-        const common_chat_template tmpl(read_file("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"),
-                                        "<s>", "</s>");
+        auto tmpls = read_templates("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja");
         std::vector<std::string>   end_tokens{ "<|end▁of▁sentence|>" };
 
-        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1,                   common_chat_params_init(tmpl, inputs_tools).format);
-        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_params_init(tmpl, inputs_tools_think).format);
+        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1,                   common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format);
 
-        test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
-        test_template(tmpl, end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
-        assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_think),
+        test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+        test_templates(tmpls.get(), end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+        assert_msg_equals(message_assist_thoughts_unparsed_think,
             common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
             COMMON_CHAT_FORMAT_DEEPSEEK_R1));
-        assert_msg_equals(msg_from_json(message_assist_thoughts),
+        assert_msg_equals(message_assist_thoughts,
             common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
             COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
-        assert_msg_equals(msg_from_json(message_assist_thoughts),
+        assert_msg_equals(message_assist_thoughts,
             // Latest template update (ast of 20250209) adds a trailing <think>\n if add_generation_prompt is true.
             common_chat_parse("I'm thinking</think>Hello, world!\nWhat's up?",
             COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
-        // test_template(tmpl, end_tokens, message_assist_call, tools,
+        // test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
         //               "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
         //               "```json\n"
         //               "{\"arg1\": 1}\n"
@@ -697,23 +749,22 @@ static void test_template_output_parsers() {
     }
     {
         // Replacement DeepSeek R1 template. Makes the Distill Qwen 7B/32B models happy to call tools and all.
-        const common_chat_template tmpl(read_file("models/templates/llama-cpp-deepseek-r1.jinja"),
-                                        "<s>", "</s>");
+        auto tmpls = read_templates("models/templates/llama-cpp-deepseek-r1.jinja");
         std::vector<std::string>   end_tokens{ "<|end▁of▁sentence|>" };
 
-        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1,                   common_chat_params_init(tmpl, inputs_tools).format);
-        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_params_init(tmpl, inputs_tools_think).format);
+        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1,                   common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format);
 
-        test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
-        test_template(tmpl, end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
-        assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_think),
+        test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+        test_templates(tmpls.get(), end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+        assert_msg_equals(message_assist_thoughts_unparsed_think,
             common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
             COMMON_CHAT_FORMAT_DEEPSEEK_R1));
-        assert_msg_equals(msg_from_json(message_assist_thoughts),
+        assert_msg_equals(message_assist_thoughts,
             common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
             COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
 
-        assert_msg_equals(msg_from_json(message_assist_call_thoughts_unparsed),
+        assert_msg_equals(message_assist_call_thoughts_unparsed,
             common_chat_parse(
                 "<think>I'm\nthinking</think>\n\n"
                 "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
@@ -721,7 +772,7 @@ static void test_template_output_parsers() {
                 "{\"arg1\": 1}\n"
                 "```<|tool▁call▁end|><|tool▁calls▁end|>",
                 COMMON_CHAT_FORMAT_DEEPSEEK_R1));
-        assert_msg_equals(msg_from_json(message_assist_call_thoughts),
+        assert_msg_equals(message_assist_call_thoughts,
             common_chat_parse(
                 "<think>I'm\nthinking</think>\n\n"
                 "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
@@ -729,7 +780,7 @@ static void test_template_output_parsers() {
                 "{\"arg1\": 1}\n"
                 "```<|tool▁call▁end|><|tool▁calls▁end|>",
                 COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
-        test_template(tmpl, end_tokens, message_assist_call, tools,
+        test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
                 "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
                 "```json\n"
                 "{\"arg1\": 1}\n"
@@ -738,38 +789,46 @@ static void test_template_output_parsers() {
 }
 
 int main(int argc, char ** argv) {
+    try {
 #ifndef _WIN32
-    if (argc > 1) {
-        common_chat_inputs inputs;
-        inputs.messages = {
-            { { "role", "user" }, { "content", "Hey" } }
-        };
-        inputs.tools = json::array({ special_function_tool });
-
-        std::cout << "| Template | Format |\n";
-        std::cout << "|----------|--------|\n";
-
-        for (int i = 1; i < argc; i++) {
-            try {
-                std::string path = argv[i];
-                if (path.rfind(".jinja") != path.size() - 6) {
-                    std::cerr << "Skipping non-jinja file: " << path << std::endl;
-                    continue;
+        if (argc > 1) {
+            common_chat_templates_inputs inputs;
+            common_chat_msg msg;
+            msg.role = "user";
+            msg.content = "Hey";
+            inputs.messages = {msg};
+            inputs.tools = { special_function_tool };
+
+            std::cout << "| Template | Format |\n";
+            std::cout << "|----------|--------|\n";
+
+            for (int i = 1; i < argc; i++) {
+                try {
+                    std::string path = argv[i];
+                    if (path.rfind(".jinja") != path.size() - 6) {
+                        std::cerr << "Skipping non-jinja file: " << path << '\n';
+                        continue;
+                    }
+                    auto tmpls = read_templates(path);
+                    auto parts  = string_split(path, "/");
+                    auto name   = parts[parts.size() - 1];
+                    auto format = common_chat_format_name(common_chat_templates_apply(tmpls.get(), inputs).format);
+                    std::cout << "| " << name << " | " << format << " |\n";
+                } catch (const std::exception & e) {
+                    std::cerr << "Failed to process " << argv[i] << ": " << e.what() << '\n';
                 }
-                common_chat_template tmpl(read_file(path), "", "");
-                auto parts  = string_split(path, "/");
-                auto name   = parts[parts.size() - 1];
-                auto format = common_chat_format_name(common_chat_params_init(tmpl, inputs).format);
-                std::cout << "| " << name << " | " << format << " |\n";
-            } catch (const std::exception & e) {
-                std::cerr << "Failed to process " << argv[i] << ": " << e.what() << std::endl;
             }
-        }
-    } else
+        } else
 #endif
-    {
-        test_template_output_parsers();
-        std::cout << "\n[chat] All tests passed!" << std::endl;
+        {
+            test_msgs_oaicompat_json_conversion();
+            test_tools_oaicompat_json_conversion();
+            test_template_output_parsers();
+            std::cout << "\n[chat] All tests passed!" << '\n';
+        }
+        return 0;
+    } catch (const std::exception & e) {
+        std::cerr << "Error: " << e.what() << '\n';
+        return 1;
     }
-    return 0;
 }