From: Olivier Chafik Date: Tue, 18 Feb 2025 18:03:23 +0000 (+0000) Subject: tool-call: refactor common chat / tool-call api (+ tests / fixes) (#11900) X-Git-Tag: upstream/0.0.4853~114 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=63e489c025d61c7ca5ec06c5d10f36e2b76aaa1d;p=pkg%2Fggml%2Fsources%2Fllama.cpp tool-call: refactor common chat / tool-call api (+ tests / fixes) (#11900) * tool-call refactoring: moved common_chat_* to chat.h, common_chat_templates_init return a unique_ptr to opaque type * addressed clang-tidy lints in [test-]chat.* * rm minja deps from util & common & move it to common/minja/ * add name & tool_call_id to common_chat_msg * add common_chat_tool * added json <-> tools, msgs conversions to chat.h * fix double bos/eos jinja avoidance hack (was preventing inner bos/eos tokens) * fix deepseek r1 slow test (no longer opening w/ new template) * allow empty tools w/ auto + grammar * fix & test server grammar & json_schema params w/ & w/o --jinja --- diff --git a/Makefile b/Makefile index 66219408..fb9a3b44 100644 --- a/Makefile +++ b/Makefile @@ -1364,7 +1364,7 @@ llama-server: \ examples/server/index.html.hpp \ examples/server/loading.html.hpp \ common/chat.cpp \ - common/chat.hpp \ + common/chat.h \ common/chat-template.hpp \ common/json.hpp \ common/minja.hpp \ diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index c2b4aa7d..17146fff 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -57,8 +57,7 @@ add_library(${TARGET} STATIC arg.h base64.hpp chat.cpp - chat.hpp - chat-template.hpp + chat.h common.cpp common.h console.cpp @@ -68,7 +67,8 @@ add_library(${TARGET} STATIC llguidance.cpp log.cpp log.h - minja.hpp + minja/chat-template.hpp + minja/minja.hpp ngram-cache.cpp ngram-cache.h sampling.cpp diff --git a/common/arg.cpp b/common/arg.cpp index f06aa107..eb8becca 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2,6 +2,7 @@ #include "log.h" #include "sampling.h" +#include "chat.h" #include #include diff --git a/common/chat-template.hpp b/common/chat-template.hpp deleted file mode 100644 index 882ba41b..00000000 --- a/common/chat-template.hpp +++ /dev/null @@ -1,529 +0,0 @@ -/* - Copyright 2024 Google LLC - - Use of this source code is governed by an MIT-style - license that can be found in the LICENSE file or at - https://opensource.org/licenses/MIT. -*/ -// SPDX-License-Identifier: MIT -#pragma once - -#include "minja.hpp" -#include -#include -#include - -using json = nlohmann::ordered_json; - -namespace minja { - -struct chat_template_caps { - bool supports_tools = false; - bool supports_tool_calls = false; - bool supports_tool_responses = false; - bool supports_system_role = false; - bool supports_parallel_tool_calls = false; - bool supports_tool_call_id = false; - // meta-llama/Llama-3.1-8B-Instruct expects arguments to be an object. - // Most other templates (and OpenAI's API) expect the arguments object to be stringified. - bool requires_object_arguments = false; - // CohereForAI/c4ai-command-r-plus simple variant - bool requires_non_null_content = false; - // MiniMaxAI/MiniMax-Text-01 special - bool requires_typed_content = false; -}; - -struct chat_template_inputs { - nlohmann::ordered_json messages; - nlohmann::ordered_json tools; - bool add_generation_prompt = true; - nlohmann::ordered_json extra_context; - std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); -}; - -struct chat_template_options { - bool apply_polyfills = true; - bool use_bos_token = true; - bool use_eos_token = true; - bool define_strftime_now = true; - - bool polyfill_tools = true; - bool polyfill_tool_call_examples = true; - bool polyfill_tool_calls = true; - bool polyfill_tool_responses = true; - bool polyfill_system_role = true; - bool polyfill_object_arguments = true; - bool polyfill_typed_content = true; -}; - -class chat_template { - - private: - chat_template_caps caps_; - std::string source_; - std::string bos_token_; - std::string eos_token_; - std::shared_ptr template_root_; - std::string tool_call_example_; - - std::string try_raw_render( - const nlohmann::ordered_json & messages, - const nlohmann::ordered_json & tools, - bool add_generation_prompt, - const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const - { - try { - chat_template_inputs inputs; - inputs.messages = messages; - inputs.tools = tools; - inputs.add_generation_prompt = add_generation_prompt; - inputs.extra_context = extra_context; - // Use fixed date for tests - inputs.now = std::chrono::system_clock::from_time_t(0); - - chat_template_options opts; - opts.apply_polyfills = false; - - auto prompt = apply(inputs, opts); - // fprintf(stderr, "try_raw_render: %s\n", prompt.c_str()); - return prompt; - } catch (const std::exception & e) { - // fprintf(stderr, "try_raw_render error: %s\n", e.what()); - return ""; - } - } - - public: - - chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token) - : source_(source), bos_token_(bos_token), eos_token_(eos_token) - { - template_root_ = minja::Parser::parse(source_, { - /* .trim_blocks = */ true, - /* .lstrip_blocks = */ true, - /* .keep_trailing_newline = */ false, - }); - - auto contains = [](const std::string & haystack, const std::string & needle) { - return haystack.find(needle) != std::string::npos; - }; - - const std::string user_needle = ""; - const std::string sys_needle = ""; - const json dummy_str_user_msg = {{"role", "user"}, {"content", user_needle}}; - const json dummy_typed_user_msg = {{"role", "user"}, {"content", json::array({{{"type", "text"}, {"text", user_needle}}})}}; - - caps_.requires_typed_content = - !contains(try_raw_render(json::array({dummy_str_user_msg}), {}, false), user_needle) - && contains(try_raw_render(json::array({dummy_typed_user_msg}), {}, false), user_needle); - - const auto dummy_user_msg = caps_.requires_typed_content - ? dummy_typed_user_msg - : dummy_str_user_msg; - const json needle_system_msg = { - {"role", "system"}, - {"content", caps_.requires_typed_content ? json::array({{{"type", "text"}, {"text", sys_needle}}}) : json(sys_needle)}, - }; - - caps_.supports_system_role = contains(try_raw_render({needle_system_msg, dummy_user_msg,}, {}, false), sys_needle); - - auto out = try_raw_render(json::array({ - dummy_user_msg - }), json::array({ - { - {"name", "some_tool"}, - {"type", "function"}, - {"function", { - {"name", "some_tool"}, - {"description", "Some tool."}, - {"parameters", { - {"type", "object"}, - {"properties", { - {"arg", { - {"type", "string"}, - {"description", "Some argument."}, - }}, - }}, - {"required", json::array({ "arg" })}, - }}, - }}, - }, - }), false); - caps_.supports_tools = contains(out, "some_tool"); - - auto make_tool_calls_msg = [&](const json & tool_calls) { - return json { - {"role", "assistant"}, - {"content", nullptr}, - {"tool_calls", tool_calls}, - }; - }; - auto make_tool_call = [](const std::string & tool_name, const json & arguments) { - return json { - {"id", "call_1___"}, - {"type", "function"}, - {"function", { - {"arguments", arguments}, - {"name", tool_name}, - }}, - }; - }; - const json dummy_args_obj {{"argument_needle", "print('Hello, World!')"}}; - - // Note: the arguments are rendered in both cases, but may be double-escaped, which we don't want. - out = try_raw_render(json::array({ - dummy_user_msg, - make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj.dump())})), - }), {}, false); - auto tool_call_renders_str_arguments = contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':"); - out = try_raw_render(json::array({ - dummy_user_msg, - make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj)})), - }), {}, false); - auto tool_call_renders_obj_arguments = contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':"); - - caps_.supports_tool_calls = tool_call_renders_str_arguments || tool_call_renders_obj_arguments; - caps_.requires_object_arguments = !tool_call_renders_str_arguments && tool_call_renders_obj_arguments; - auto out_empty = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", ""}}}), {}, false); - auto out_null = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", nullptr}}}), {}, false); - caps_.requires_non_null_content = contains(out_empty, user_needle) && !contains(out_null, user_needle); - - if (caps_.supports_tool_calls) { - auto dummy_args = caps_.requires_object_arguments ? dummy_args_obj : json(dummy_args_obj.dump()); - auto tc1 = make_tool_call("test_tool1", dummy_args); - auto tc2 = make_tool_call("test_tool2", dummy_args); - auto out = try_raw_render(json::array({ - dummy_user_msg, - make_tool_calls_msg(json::array({tc1, tc2})), - }), {}, false); - caps_.supports_parallel_tool_calls = contains(out, "test_tool1") && contains(out, "test_tool2"); - - out = try_raw_render(json::array({ - dummy_user_msg, - make_tool_calls_msg(json::array({tc1})), - { - {"role", "tool"}, - {"name", "test_tool1"}, - {"content", "Some response!"}, - {"tool_call_id", "call_911_"}, - } - }), {}, false); - caps_.supports_tool_responses = contains(out, "Some response!"); - caps_.supports_tool_call_id = contains(out, "call_911_"); - } - - try { - if (!caps_.supports_tools) { - const json user_msg { - {"role", "user"}, - {"content", "Hey"}, - }; - const json args { - {"arg1", "some_value"}, - }; - const json tool_call_msg { - {"role", "assistant"}, - {"content", nullptr}, - {"tool_calls", json::array({ - { - // TODO: detect if requires numerical id or fixed length == 6 like Nemo - {"id", "call_1___"}, - {"type", "function"}, - {"function", { - {"name", "tool_name"}, - {"arguments", (caps_.requires_object_arguments ? args : json(minja::Value(args).dump(-1, /* to_json= */ true)))}, - }}, - }, - })}, - }; - std::string prefix, full; - { - chat_template_inputs inputs; - inputs.messages = json::array({user_msg}); - inputs.add_generation_prompt = true; - prefix = apply(inputs); - } - { - chat_template_inputs inputs; - inputs.messages = json::array({user_msg, tool_call_msg}); - inputs.add_generation_prompt = false; - full = apply(inputs); - } - auto eos_pos_last = full.rfind(eos_token_); - if (eos_pos_last == prefix.size() - eos_token_.size() || - (full[full.size() - 1] == '\n' && (eos_pos_last == full.size() - eos_token_.size() - 1))) { - full = full.substr(0, eos_pos_last); - } - size_t common_prefix_length = 0; - for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) { - if (prefix[i] != full[i]) { - break; - } - if (prefix[i] == '<') { - // DeepSeek R1's template (as of 20250209) adds a trailing if add_generation_prompt, - // but it removes thinking tags for past messages. - // The prefix and full strings diverge at vs. <|tool▁calls▁begin|>, we avoid consuming the leading <. - continue; - } - common_prefix_length = i + 1; - } - auto example = full.substr(common_prefix_length); - if (example.find("tool_name") == std::string::npos && example.find("some_value") == std::string::npos) { - fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n"); - } else { - tool_call_example_ = example; - } - } - } catch (const std::exception & e) { - fprintf(stderr, "Failed to generate tool call example: %s\n", e.what()); - } - } - - const std::string & source() const { return source_; } - const std::string & bos_token() const { return bos_token_; } - const std::string & eos_token() const { return eos_token_; } - const chat_template_caps & original_caps() const { return caps_; } - - // Deprecated, please use the form with chat_template_inputs and chat_template_options - std::string apply( - const nlohmann::ordered_json & messages, - const nlohmann::ordered_json & tools, - bool add_generation_prompt, - const nlohmann::ordered_json & extra_context = nlohmann::ordered_json(), - bool apply_polyfills = true) - { - fprintf(stderr, "[%s] Deprecated!\n", __func__); - chat_template_inputs inputs; - inputs.messages = messages; - inputs.tools = tools; - inputs.add_generation_prompt = add_generation_prompt; - inputs.extra_context = extra_context; - inputs.now = std::chrono::system_clock::now(); - - chat_template_options opts; - opts.apply_polyfills = apply_polyfills; - - return apply(inputs, opts); - } - - std::string apply( - const chat_template_inputs & inputs, - const chat_template_options & opts = chat_template_options()) const - { - json actual_messages; - - auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); - auto has_tool_calls = false; - auto has_tool_responses = false; - auto has_string_content = false; - for (const auto & message : inputs.messages) { - if (message.contains("tool_calls") && !message["tool_calls"].is_null()) { - has_tool_calls = true; - } - if (message.contains("role") && message["role"] == "tool") { - has_tool_responses = true; - } - if (message.contains("content") && message["content"].is_string()) { - has_string_content = true; - } - } - - auto polyfill_system_role = opts.polyfill_system_role && !caps_.supports_system_role; - auto polyfill_tools = opts.polyfill_tools && has_tools && !caps_.supports_tools; - auto polyfill_tool_call_example = polyfill_tools && opts.polyfill_tool_call_examples; - auto polyfill_tool_calls = opts.polyfill_tool_calls && has_tool_calls && !caps_.supports_tool_calls; - auto polyfill_tool_responses = opts.polyfill_tool_responses && has_tool_responses && !caps_.supports_tool_responses; - auto polyfill_object_arguments = opts.polyfill_object_arguments && has_tool_calls && caps_.requires_object_arguments; - auto polyfill_typed_content = opts.polyfill_typed_content && has_string_content && caps_.requires_typed_content; - - auto needs_polyfills = opts.apply_polyfills && (false - || polyfill_system_role - || polyfill_tools - || polyfill_tool_calls - || polyfill_tool_responses - || polyfill_object_arguments - || polyfill_typed_content - ); - - if (needs_polyfills) { - actual_messages = json::array(); - - auto add_message = [&](const json & msg) { - if (polyfill_typed_content && msg.contains("content") && !msg.at("content").is_null() && msg.at("content").is_string()) { - actual_messages.push_back({ - {"role", msg.at("role")}, - {"content", {{ - {"type", "text"}, - {"text", msg.at("content")}, - }}}, - }); - } else { - actual_messages.push_back(msg); - } - }; - - std::string pending_system; - auto flush_sys = [&]() { - if (!pending_system.empty()) { - add_message({ - {"role", "user"}, - {"content", pending_system}, - }); - pending_system.clear(); - } - }; - - json adjusted_messages; - if (polyfill_tools) { - adjusted_messages = add_system(inputs.messages, - "You can call any of the following tools to satisfy the user's requests: " + minja::Value(inputs.tools).dump(2, /* to_json= */ true) + - (!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_ + "\n\n")); - } else { - adjusted_messages = inputs.messages; - } - - for (const auto & message_ : adjusted_messages) { - auto message = message_; - if (!message.contains("role") || !message.contains("content")) { - throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump()); - } - std::string role = message.at("role"); - - if (message.contains("tool_calls")) { - if (polyfill_object_arguments || polyfill_tool_calls) { - for (auto & tool_call : message.at("tool_calls")) { - if (tool_call["type"] == "function") { - auto & function = tool_call.at("function"); - auto & arguments = function.at("arguments"); - if (arguments.is_string()) { - try { - arguments = json::parse(arguments.get()); - } catch (const std::exception & ecvt) { - fprintf(stderr, "Failed to parse arguments: %s\n", ecvt.what()); - } - } - } - } - } - if (polyfill_tool_calls) { - auto content = message.at("content"); - auto tool_calls = json::array(); - for (const auto & tool_call : message.at("tool_calls")) { - if (tool_call.at("type") != "function") { - continue; - } - const auto & function = tool_call.at("function"); - auto tc = json { - {"name", function.at("name")}, - {"arguments", function.at("arguments")}, - }; - if (tool_call.contains("id")) { - tc["id"] = tool_call["id"]; - } - tool_calls.push_back(tc); - } - auto obj = json { - {"tool_calls", tool_calls}, - }; - if (!content.is_null() && content != "") { - obj["content"] = content; - } - message["content"] = obj.dump(2); - message.erase("tool_calls"); - } - } - if (polyfill_tool_responses && role == "tool") { - message["role"] = "user"; - auto obj = json { - {"tool_response", { - {"content", message.at("content")}, - }}, - }; - if (message.contains("name")) { - obj["tool_response"]["name"] = message.at("name"); - } - if (message.contains("tool_call_id")) { - obj["tool_response"]["tool_call_id"] = message.at("tool_call_id"); - } - message["content"] = obj.dump(2); - message.erase("name"); - } - - if (!message["content"].is_null() && polyfill_system_role) { - std::string content = message.at("content"); - if (role == "system") { - if (!pending_system.empty()) pending_system += "\n"; - pending_system += content; - continue; - } else { - if (role == "user") { - if (!pending_system.empty()) { - message["content"] = pending_system + (content.empty() ? "" : "\n" + content); - pending_system.clear(); - } - } else { - flush_sys(); - } - } - } - add_message(message); - } - flush_sys(); - } else { - actual_messages = inputs.messages; - } - - auto context = minja::Context::make(json({ - {"messages", actual_messages}, - {"add_generation_prompt", inputs.add_generation_prompt}, - })); - context->set("bos_token", opts.use_bos_token ? bos_token_ : ""); - context->set("eos_token", opts.use_eos_token ? eos_token_ : ""); - if (opts.define_strftime_now) { - auto now = inputs.now; - context->set("strftime_now", Value::callable([now](const std::shared_ptr &, minja::ArgumentsValue & args) { - args.expectArgs("strftime_now", {1, 1}, {0, 0}); - auto format = args.args[0].get(); - - auto time = std::chrono::system_clock::to_time_t(now); - auto local_time = *std::localtime(&time); - std::ostringstream ss; - ss << std::put_time(&local_time, format.c_str()); - return ss.str(); - })); - } - if (!inputs.tools.is_null()) { - context->set("tools", minja::Value(inputs.tools)); - } - if (!inputs.extra_context.is_null()) { - for (auto & kv : inputs.extra_context.items()) { - context->set(kv.key(), minja::Value(kv.value())); - } - } - - auto ret = template_root_->render(context); - // fprintf(stderr, "actual_messages: %s\n", actual_messages.dump(2).c_str()); - // fprintf(stderr, "apply: %s\n\n", ret.c_str()); - return ret; - } - - static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages, const std::string & system_prompt) { - json messages_with_system = messages; - - if (messages_with_system.size() > 0 && messages_with_system[0].at("role") == "system") { - std::string existing_system = messages_with_system.at(0).at("content"); - messages_with_system[0] = json { - {"role", "system"}, - {"content", existing_system + "\n\n" + system_prompt}, - }; - } else { - messages_with_system.insert(messages_with_system.begin(), json { - {"role", "system"}, - {"content", system_prompt}, - }); - } - return messages_with_system; - } -}; - -} // namespace minja diff --git a/common/chat.cpp b/common/chat.cpp index f21a9d2a..9ebe4c57 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -1,8 +1,433 @@ -#include "chat.hpp" -#include "chat-template.hpp" +#include "chat.h" #include "json-schema-to-grammar.h" #include "log.h" -#include "minja.hpp" +#include "minja/chat-template.hpp" +#include "minja/minja.hpp" + +#include + +typedef minja::chat_template common_chat_template; + +struct common_chat_templates { + bool has_explicit_template; // Model had builtin template or template overridde was specified. + std::unique_ptr template_default; // always set (defaults to chatml) + std::unique_ptr template_tool_use; +}; + +struct templates_params { + json messages; + json tools; + common_chat_tool_choice tool_choice; + json json_schema; + bool parallel_tool_calls; + bool stream; + std::string grammar; + bool add_generation_prompt = true; + bool extract_reasoning = true; +}; + +common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) { + if (tool_choice == "auto") { + return COMMON_CHAT_TOOL_CHOICE_AUTO; + } + if (tool_choice == "none") { + return COMMON_CHAT_TOOL_CHOICE_NONE; + } + if (tool_choice == "required") { + return COMMON_CHAT_TOOL_CHOICE_REQUIRED; + } + throw std::runtime_error("Invalid tool_choice: " + tool_choice); +} + +template <> +std::vector common_chat_msgs_parse_oaicompat(const json & messages) { + std::vector msgs; + + try { + + if (!messages.is_array()) { + throw std::runtime_error("Expected 'messages' to be an array, got " + messages.dump()); + } + + for (const auto & message : messages) { + if (!message.is_object()) { + throw std::runtime_error("Expected 'message' to be an object, got " + message.dump()); + } + + common_chat_msg msg; + if (!message.contains("role")) { + throw std::runtime_error("Missing 'role' in message: " + message.dump()); + } + msg.role = message.at("role"); + + if (message.contains("content")) { + const auto & content = message.at("content"); + if (content.is_string()) { + msg.content = content; + } else if (content.is_array()) { + for (const auto & part : content) { + if (!part.contains("type")) { + throw std::runtime_error("Missing content part type: " + part.dump()); + } + const auto & type = part.at("type"); + if (type != "text") { + throw std::runtime_error("Unsupported content part type: " + type.dump()); + } + common_chat_msg_content_part msg_part; + msg_part.type = type; + msg_part.text = part.at("text"); + msg.content_parts.push_back(msg_part); + } + } else if (!content.is_null()) { + throw std::runtime_error("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: https://github.com/ggml-org/llama.cpp/issues/8367)"); + } + } else { + throw std::runtime_error("Expected 'content' (ref: https://github.com/ggml-org/llama.cpp/issues/8367)"); + } + if (message.contains("reasoning_content")) { + msg.reasoning_content = message.at("reasoning_content"); + } + if (message.contains("name")) { + msg.tool_name = message.at("name"); + } + if (message.contains("tool_call_id")) { + msg.tool_call_id = message.at("tool_call_id"); + } + if (message.contains("tool_calls")) { + for (const auto & tool_call : message.at("tool_calls")) { + common_chat_tool_call tc; + if (!tool_call.contains("type")) { + throw std::runtime_error("Missing tool call type: " + tool_call.dump()); + } + const auto & type = tool_call.at("type"); + if (type != "function") { + throw std::runtime_error("Unsupported tool call type: " + tool_call.dump()); + } + if (!tool_call.contains("function")) { + throw std::runtime_error("Missing tool call function: " + tool_call.dump()); + } + const auto & fc = tool_call.at("function"); + if (!fc.contains("name")) { + throw std::runtime_error("Missing tool call name: " + tool_call.dump()); + } + tc.name = fc.at("name"); + tc.arguments = fc.at("arguments"); + if (tool_call.contains("id")) { + tc.id = tool_call.at("id"); + } + msg.tool_calls.push_back(tc); + } + } + + msgs.push_back(msg); + } + } catch (const std::exception & e) { + throw std::runtime_error("Failed to parse messages: " + std::string(e.what()) + "; messages = " + messages.dump(2)); + } + + return msgs; +} + +template <> +json common_chat_msgs_to_json_oaicompat(const std::vector & msgs, bool concat_typed_text) { + json messages = json::array(); + for (const auto & msg : msgs) { + if (!msg.content.empty() && !msg.content_parts.empty()) { + throw std::runtime_error("Cannot specify both content and content_parts"); + } + json jmsg { + {"role", msg.role}, + }; + if (!msg.content.empty()) { + jmsg["content"] = msg.content; + } else if (!msg.content_parts.empty()) { + if (concat_typed_text) { + std::string text; + for (const auto & part : msg.content_parts) { + if (part.type != "text") { + LOG_WRN("Ignoring content part type: %s\n", part.type.c_str()); + continue; + } + if (!text.empty()) { + text += '\n'; + } + text += part.text; + } + jmsg["content"] = text; + } else { + auto & parts = jmsg["content"] = json::array(); + for (const auto & part : msg.content_parts) { + parts.push_back({ + {"type", part.type}, + {"text", part.text}, + }); + } + } + } else { + jmsg["content"] = json(); // null + } + if (!msg.reasoning_content.empty()) { + jmsg["reasoning_content"] = msg.reasoning_content; + } + if (!msg.tool_name.empty()) { + jmsg["name"] = msg.tool_name; + } + if (!msg.tool_call_id.empty()) { + jmsg["tool_call_id"] = msg.tool_call_id; + } + if (!msg.tool_calls.empty()) { + auto & tool_calls = jmsg["tool_calls"] = json::array(); + for (const auto & tool_call : msg.tool_calls) { + json tc { + {"type", "function"}, + {"function", { + {"name", tool_call.name}, + {"arguments", tool_call.arguments}, + }}, + }; + if (!tool_call.id.empty()) { + tc["id"] = tool_call.id; + } + tool_calls.push_back(tc); + } + } + messages.push_back(jmsg); + } + return messages; +} + +template <> +std::vector common_chat_msgs_parse_oaicompat(const std::string & messages) { + return common_chat_msgs_parse_oaicompat(json::parse(messages)); +} + +template <> +std::vector common_chat_tools_parse_oaicompat(const json & tools) { + std::vector result; + + try { + if (!tools.is_null()) { + if (!tools.is_array()) { + throw std::runtime_error("Expected 'tools' to be an array, got " + tools.dump()); + } + for (const auto & tool : tools) { + if (!tool.contains("type")) { + throw std::runtime_error("Missing tool type: " + tool.dump()); + } + const auto & type = tool.at("type"); + if (!type.is_string() || type != "function") { + throw std::runtime_error("Unsupported tool type: " + tool.dump()); + } + if (!tool.contains("function")) { + throw std::runtime_error("Missing tool function: " + tool.dump()); + } + + const auto & function = tool.at("function"); + result.push_back({ + /* .name = */ function.at("name"), + /* .description = */ function.at("description"), + /* .parameters = */ function.at("parameters").dump(), + }); + } + } + } catch (const std::exception & e) { + throw std::runtime_error("Failed to parse tools: " + std::string(e.what()) + "; tools = " + tools.dump(2)); + } + + return result; +} + +template <> +std::vector common_chat_tools_parse_oaicompat(const std::string & tools) { + return common_chat_tools_parse_oaicompat(json::parse(tools)); +} + +template <> +json common_chat_tools_to_json_oaicompat(const std::vector & tools) { + if (tools.empty()) { + return json(); + } + + auto result = json::array(); + for (const auto & tool : tools) { + result.push_back({ + {"type", "function"}, + {"function", { + {"name", tool.name}, + {"description", tool.description}, + {"parameters", json::parse(tool.parameters)}, + }}, + }); + } + return result; +} + +bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { + if (use_jinja) { + try { + common_chat_msg msg; + msg.role = "user"; + msg.content = "test"; + + auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl); + + common_chat_templates_inputs inputs; + inputs.messages = {msg}; + + common_chat_templates_apply(tmpls.get(), inputs); + return true; + } catch (const std::exception & e) { + LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what()); + return false; + } + } + llama_chat_message chat[] = {{"user", "test"}}; + const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0); + return res >= 0; +} + +std::string common_chat_format_single( + const struct common_chat_templates * tmpls, + const std::vector & past_msg, + const common_chat_msg & new_msg, + bool add_ass, + bool use_jinja) { + + common_chat_templates_inputs inputs; + inputs.use_jinja = use_jinja; + + std::string fmt_past_msg; + if (!past_msg.empty()) { + inputs.messages = past_msg; + inputs.add_generation_prompt = false; + fmt_past_msg = common_chat_templates_apply(tmpls, inputs).prompt; + } + std::ostringstream ss; + // if the past_msg ends with a newline, we must preserve it in the formatted version + if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') { + ss << "\n"; + }; + // format chat with new_msg + inputs.messages.push_back(new_msg); + inputs.add_generation_prompt = add_ass; + auto fmt_new_msg = common_chat_templates_apply(tmpls, inputs).prompt; + // get the diff part + ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size()); + return ss.str(); +} + +std::string common_chat_format_example(const struct common_chat_templates * tmpls, bool use_jinja) { + common_chat_templates_inputs inputs; + inputs.use_jinja = use_jinja; + auto add_simple_msg = [&](auto role, auto content) { + common_chat_msg msg; + msg.role = role; + msg.content = content; + inputs.messages.push_back(msg); + }; + add_simple_msg("system", "You are a helpful assistant"); + add_simple_msg("user", "Hello"); + add_simple_msg("assistant", "Hi there"); + add_simple_msg("user", "How are you?"); + return common_chat_templates_apply(tmpls, inputs).prompt; +} + +#define CHATML_TEMPLATE_SRC \ + "{%- for message in messages -%}\n" \ + " {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \ + "{%- endfor -%}\n" \ + "{%- if add_generation_prompt -%}\n" \ + " {{- '<|im_start|>assistant\n' -}}\n" \ + "{%- endif -%}" + +void common_chat_templates_free(struct common_chat_templates * tmpls) { + delete tmpls; +} + +bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls) { + return tmpls->has_explicit_template; +} + +const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant) { + if (variant != nullptr) { + if (strcmp(variant, "tool_use") == 0) { + if (tmpls->template_tool_use) { + return tmpls->template_tool_use->source().c_str(); + } + return nullptr; + } else { + LOG_DBG("%s: unknown template variant: %s\n", __func__, variant); + } + } + return tmpls->template_default->source().c_str(); +} + +common_chat_templates_ptr common_chat_templates_init( + const struct llama_model * model, + const std::string & chat_template_override, + const std::string & bos_token_override, + const std::string & eos_token_override) +{ + std::string default_template_src; + std::string template_tool_use_src; + + bool has_explicit_template = !chat_template_override.empty(); + if (chat_template_override.empty()) { + GGML_ASSERT(model != nullptr); + const auto * str = llama_model_chat_template(model, /* name */ nullptr); + if (str) { + default_template_src = str; + has_explicit_template = true; + } + str = llama_model_chat_template(model, /* name */ "tool_use"); + if (str) { + template_tool_use_src = str; + has_explicit_template = true; + } + } else { + default_template_src = chat_template_override; + } + if (default_template_src.empty() || default_template_src == "chatml") { + if (!template_tool_use_src.empty()) { + default_template_src = template_tool_use_src; + } else { + default_template_src = CHATML_TEMPLATE_SRC; + } + } + std::string token_bos = bos_token_override; + std::string token_eos = eos_token_override; + if (model) { + const auto * vocab = llama_model_get_vocab(model); + const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) { + if (token == LLAMA_TOKEN_NULL) { + if (default_template_src.find(jinja_variable_name) != std::string::npos + || template_tool_use_src.find(jinja_variable_name) != std::string::npos) { + LOG_WRN("common_chat_templates_init: warning: vocab does not have a %s token, jinja template won't work as intended.\n", name); + } + return std::string(); + } + return common_token_to_piece(vocab, token, true); + }; + token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token"); + token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token"); + } + common_chat_templates_ptr tmpls(new common_chat_templates()); + tmpls->has_explicit_template = has_explicit_template; + try { + tmpls->template_default = std::make_unique(default_template_src, token_bos, token_eos); + } catch (const std::exception & e) { + LOG_ERR("%s: failed to parse chat template (defaulting to chatml): %s \n", __func__, e.what()); + tmpls->template_default = std::make_unique(CHATML_TEMPLATE_SRC, token_bos, token_eos); + } + if (!template_tool_use_src.empty()) { + try { + tmpls->template_tool_use = std::make_unique(template_tool_use_src, token_bos, token_eos); + } catch (const std::exception & e) { + LOG_ERR("%s: failed to parse tool use chat template (ignoring it): %s\n", __func__, e.what()); + } + } + return tmpls; +} std::string common_chat_format_name(common_chat_format format) { switch (format) { @@ -38,22 +463,22 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons json_error_locator() : position(0), found_error(false) {} - bool parse_error(std::size_t position, const std::string &, const json::exception &) override { + bool parse_error(std::size_t position, const std::string &, const json::exception &) override { // NOLINT this->position = position - 1; this->found_error = true; return false; } - bool null() override { return true; } - bool boolean(bool) override { return true; } - bool number_integer(number_integer_t) override { return true; } - bool number_unsigned(number_unsigned_t) override { return true; } - bool number_float(number_float_t, const string_t &) override { return true; } - bool string(string_t &) override { return true; } - bool binary(binary_t &) override { return true; } - bool start_object(std::size_t) override { return true; } - bool key(string_t &) override { return true; } + bool null() override { return true; } // NOLINT + bool boolean(bool) override { return true; } // NOLINT + bool number_integer(number_integer_t) override { return true; } // NOLINT + bool number_unsigned(number_unsigned_t) override { return true; } // NOLINT + bool number_float(number_float_t, const string_t &) override { return true; } // NOLINT + bool string(string_t &) override { return true; } // NOLINT + bool binary(binary_t &) override { return true; } // NOLINT + bool start_object(std::size_t) override { return true; } // NOLINT + bool key(string_t &) override { return true; } // NOLINT bool end_object() override { return true; } - bool start_array(std::size_t) override { return true; } + bool start_array(std::size_t) override { return true; } // NOLINT bool end_array() override { return true; } }; json_error_locator err_loc; @@ -187,13 +612,20 @@ static std::string apply( // tmpl_inputs.now = std::chrono::system_clock::now(); minja::chat_template_options tmpl_opts; - tmpl_opts.use_bos_token = false; - tmpl_opts.use_eos_token = false; - - return tmpl.apply(tmpl_inputs, tmpl_opts); + // To avoid double BOS / EOS tokens, we're manually removing begining / trailing tokens + // instead of using `chat_template_options.use_bos_token = false`, since these tokens + // may be needed inside the template / between messages too. + auto result = tmpl.apply(tmpl_inputs, tmpl_opts); + if (string_starts_with(result, tmpl.bos_token())) { + result = result.substr(tmpl.bos_token().size()); + } + if (string_ends_with(result, tmpl.eos_token())) { + result = result.substr(0, result.size() - tmpl.eos_token().size()); + } + return result; } -static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { +static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; auto tool_call_schemas = json::array(); @@ -247,7 +679,7 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp {"required", json::array({"tool_call"})}, }; const auto schema = - inputs.tool_choice != "required" + inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED ? json { {"anyOf", json::array({ tool_call, @@ -303,9 +735,9 @@ static common_chat_msg common_chat_parse_generic(const std::string & input) { return result; } -static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { +static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; - data.grammar_lazy = inputs.tool_choice != "required"; + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; data.grammar = build_grammar([&](const common_grammar_builder & builder) { auto schemas = json::array(); foreach_function(inputs.tools, [&](const json & tool) { @@ -348,9 +780,9 @@ static common_chat_msg common_chat_parse_mistral_nemo(const std::string & input) return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]"); } -static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { +static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; - data.grammar_lazy = inputs.tool_choice != "required"; + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; data.grammar = build_grammar([&](const common_grammar_builder & builder) { auto schemas = json::array(); foreach_function(inputs.tools, [&](const json & tool) { @@ -455,10 +887,10 @@ static void expect_tool_parameters(const std::string & name, const json & parame const auto & parameters_required = parameters.at("required"); for (const auto & prop : expected_properties) { if (!parameters_properties.contains(prop)) { - throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop); + throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop); // NOLINT } if (std::find(parameters_required.begin(), parameters_required.end(), json(prop)) == parameters_required.end()) { - throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop); + throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop); // NOLINT } } if (parameters_properties.size() != expected_properties.size()) { @@ -466,18 +898,16 @@ static void expect_tool_parameters(const std::string & name, const json & parame } } -static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct common_chat_inputs & inputs, bool allow_python_tag_builtin_tools) { +static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) { auto builtin_tools = json::array(); common_chat_params data; - data.grammar_lazy = inputs.tool_choice != "required"; + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector tool_rules; auto handle_builtin_tool = [&](const std::string & name, const json & parameters) { - if (name == "wolfram_alpha") { + if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") { // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py - expect_tool_parameters(name, parameters, {"query"}); - } else if (name == "web_search" || name == "brave_search") { // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py expect_tool_parameters(name, parameters, {"query"}); } else if (name == "python" || name == "code_interpreter") { @@ -489,7 +919,7 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com std::vector kvs; for (const auto & [key, value] : parameters.at("properties").items()) { - kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); + kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT } tool_rules.push_back( @@ -560,34 +990,33 @@ static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bo auto arg_value_str = raw_args.substr(it_eq + 1); auto arg_value = json::parse(arg_value_str); - return { - /* .role = */ "assistant", - /* .content = */ match.prefix().str(), - /* .tool_calls = */ { - { - /* .name = */ match[1], - /* .arguments = */ (json { - {arg_name, arg_value}, - }).dump(), - /* .id = */ "", - }, - }, - }; + common_chat_msg msg; + msg.role = "assistant"; + msg.content = match.prefix().str(); + msg.tool_calls.push_back({ + /* .name = */ name, + /* .arguments = */ (json { + {arg_name, arg_value}, + }).dump(), + /* .id = */ "", + }); + return msg; } } return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex); } -static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { +static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; if (inputs.tools.is_array() && !inputs.tools.empty()) { - data.grammar_lazy = inputs.tool_choice != "required" && inputs.json_schema.is_null(); + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED && inputs.json_schema.is_null(); data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector tool_rules; foreach_function(inputs.tools, [&](const json & tool) { const auto & function = tool.at("function"); std::string name = function.at("name"); auto parameters = function.at("parameters"); + builder.resolve_refs(parameters); auto args_rule = builder.add_schema(name + "-args", parameters); tool_rules.push_back(builder.add_rule(name + "-call", "\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n" @@ -666,15 +1095,15 @@ static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input, return msg; } -static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { - fprintf(stderr, "%s\n", __func__); +static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) { + LOG_DBG("%s\n", __func__); common_chat_params data; data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, { {"datetime", "Jan 29 2025 13:00:00 GMT"}, {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))}, }); if (inputs.tools.is_array() && !inputs.tools.empty()) { - data.grammar_lazy = inputs.tool_choice != "required"; + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; data.grammar = build_grammar([&](const common_grammar_builder & builder) { auto schemas = json::array(); foreach_function(inputs.tools, [&](const json & tool) { @@ -712,14 +1141,14 @@ static common_chat_msg common_chat_parse_firefunction_v2(const std::string & inp return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1); } -static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { +static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct templates_params & inputs) { // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}... // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar common_chat_params data; data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2; if (inputs.tools.is_array() && !inputs.tools.empty()) { - data.grammar_lazy = inputs.tool_choice != "required"; + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector first_tool_rules; std::vector subsequent_tool_rules; @@ -727,6 +1156,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_ const auto & function = tool.at("function"); std::string name = function.at("name"); auto parameters = function.at("parameters"); + builder.resolve_refs(parameters); auto args_rule = builder.add_schema(name + "-args", parameters); first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule)); subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule)); @@ -795,14 +1225,14 @@ static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & in } } -static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { +static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) { // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt common_chat_params data; json tools = inputs.tools.is_null() ? inputs.tools : json::array(); std::string python_code_argument_name; auto has_raw_python = false; - data.grammar_lazy = inputs.tool_choice != "required"; + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector tool_rules; foreach_function(inputs.tools, [&](const json & tool) { @@ -814,7 +1244,7 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con throw std::runtime_error("Missing type in python tool"); } has_raw_python = true; - auto type = parameters.at("type"); + const auto & type = parameters.at("type"); if (type == "object") { auto properties = parameters.at("properties"); for (auto it = properties.begin(); it != properties.end(); ++it) { @@ -854,17 +1284,15 @@ static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::s std::smatch match; if (std::regex_search(input, match, python_tag_regex)) { auto code = match[1].str(); - return { - /* .role = */ "assistant", - /* .content = */ match.prefix().str(), - /* .tool_calls = */ { - { - /* .name = */ "python", - /* .arguments = */ (json {{"code", code}}).dump(), - /* .id = */ "", - }, - } - }; + common_chat_msg msg; + msg.role = "assistant"; + msg.content = match.prefix().str(); + msg.tool_calls.push_back({ + /* .name = */ "python", + /* .arguments = */ (json {{"code", code}}).dump(), + /* .id = */ "", + }); + return msg; } static std::regex function_regex(R"()"); static std::regex close_regex(R"()"); @@ -872,10 +1300,10 @@ static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::s return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex); } -static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { +static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; // (content)?({"name": "foo", "arguments": {"a": 1}})* - data.grammar_lazy = inputs.tool_choice != "required"; + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector tool_rules; foreach_function(inputs.tools, [&](const json & tool) { @@ -908,20 +1336,18 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input) std::regex middle_pattern(R"([\n\s]*[\n\s]*)"); std::regex end_pattern(R"([\n\s]*[\n\s]*$)"); + common_chat_msg msg; + msg.role = "assistant"; + auto end = input.end(); std::sregex_iterator rend; std::sregex_iterator rit(input.begin(), end, start_pattern); if (rit == rend) { - return { - /* .role = */ "assistant", - /* .content = */ input, - /* .tool_calls = */ {}, - }; + msg.content = input; + return msg; } - common_chat_msg result; - result.role = "assistant"; - result.content = rit->prefix(); + msg.content = rit->prefix(); auto it = rit->suffix().first; while (it != end) { @@ -930,7 +1356,7 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input) throw std::runtime_error("Failed to parse json tool call"); } const auto & arguments = call.at("arguments"); - result.tool_calls.push_back({ + msg.tool_calls.push_back({ call.at("name"), arguments.dump(), // arguments.is_string() ? arguments.get() : arguments.dump(), @@ -947,17 +1373,17 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input) break; } } - return result; + return msg; } catch (const std::exception & e) { - return { - /* .role = */ "assistant", - /* .content = */ input, - /* .tool_calls = */ {}, - }; + LOG_ERR("Failed to parse hermes 2 pro input: %s\n", e.what()); + common_chat_msg msg; + msg.role = "assistant"; + msg.content = input; + return msg; } } -static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { +static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; @@ -973,12 +1399,35 @@ static common_chat_params common_chat_params_init_without_tools(const common_cha return data; } -common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { +static common_chat_params common_chat_templates_apply_jinja( + const struct common_chat_templates * tmpls, + const struct common_chat_templates_inputs & inputs) +{ + templates_params params; + params.tools = common_chat_tools_to_json_oaicompat(inputs.tools); + const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use + ? *tmpls->template_tool_use + : *tmpls->template_default; const auto & src = tmpl.source(); const auto & caps = tmpl.original_caps(); + params.messages = common_chat_msgs_to_json_oaicompat(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content); + params.add_generation_prompt = inputs.add_generation_prompt; + params.extract_reasoning = inputs.extract_reasoning; + params.tool_choice = inputs.tool_choice; + params.grammar = inputs.grammar; + if (!inputs.json_schema.empty()) { + params.json_schema = json::parse(inputs.json_schema); + } - if (inputs.tools.is_array()) { - if (inputs.tool_choice != "none" && !inputs.grammar.empty()) { + if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) { + LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n"); + params.parallel_tool_calls = false; + } else { + params.parallel_tool_calls = inputs.parallel_tool_calls; + } + + if (params.tools.is_array()) { + if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) { throw std::runtime_error("Cannot specify grammar with tools"); } if (caps.supports_tool_calls && !caps.supports_tools) { @@ -987,68 +1436,135 @@ common_chat_params common_chat_params_init(const common_chat_template & tmpl, co } // DeepSeek R1: use handler in all cases except json schema (thinking / tools). - if (src.find("<|tool▁calls▁begin|>") != std::string::npos && inputs.json_schema.is_null()) { - return common_chat_params_init_deepseek_r1(tmpl, inputs); + if (src.find("<|tool▁calls▁begin|>") != std::string::npos && params.json_schema.is_null()) { + return common_chat_params_init_deepseek_r1(tmpl, params); } // Command R7B: : use handler in all cases except json schema (thinking / tools). - if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && inputs.json_schema.is_null()) { - return common_chat_params_init_command_r7b(tmpl, inputs); + if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && params.json_schema.is_null()) { + return common_chat_params_init_command_r7b(tmpl, params); } // Use generic handler when mixing tools + JSON schema. // TODO: support that mix in handlers below. - if ((!inputs.tools.is_array() && inputs.json_schema.is_object())) { - return common_chat_params_init_generic(tmpl, inputs); + if ((params.tools.is_array() && params.json_schema.is_object())) { + return common_chat_params_init_generic(tmpl, params); } // Functionary prepends "all\n" to plain content outputs, so we use its handler in all cases. if (src.find(">>>all") != std::string::npos) { - return common_chat_params_init_functionary_v3_2(tmpl, inputs); + return common_chat_params_init_functionary_v3_2(tmpl, params); } // Firefunction v2 requires datetime and functions in the context even w/o tools, so we also use its handler in all cases. if (src.find(" functools[") != std::string::npos) { - return common_chat_params_init_firefunction_v2(tmpl, inputs); + return common_chat_params_init_firefunction_v2(tmpl, params); } // Plain handler (no tools) - if (inputs.tools.is_null() || inputs.tool_choice == "none") { - return common_chat_params_init_without_tools(tmpl, inputs); + if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) { + return common_chat_params_init_without_tools(tmpl, params); } // Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools) if (src.find("") != std::string::npos) { - return common_chat_params_init_hermes_2_pro(tmpl, inputs); + return common_chat_params_init_hermes_2_pro(tmpl, params); } // Functionary v3.1 (w/ tools) if (src.find("<|start_header_id|>") != std::string::npos && src.find("ipython<|end_header_id|>") != std::string::npos) { auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos; - return common_chat_params_init_llama_3_1_tool_calls(tmpl, inputs, allow_python_tag_builtin_tools); + return common_chat_params_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools); } // Mistral Nemo (w/ tools) if (src.find("[TOOL_CALLS]") != std::string::npos) { - return common_chat_params_init_mistral_nemo(tmpl, inputs); + return common_chat_params_init_mistral_nemo(tmpl, params); } // Generic fallback - return common_chat_params_init_generic(tmpl, inputs); + return common_chat_params_init_generic(tmpl, params); +} + +// Legacy template route (adhoc C++ implementation of known templates), forward to llama_chat_apply_template. +static common_chat_params common_chat_templates_apply_legacy( + const struct common_chat_templates * tmpls, + const struct common_chat_templates_inputs & inputs) +{ + int alloc_size = 0; + std::vector chat; + std::vector contents; + for (const auto & msg : inputs.messages) { + auto content = msg.content; + for (const auto & part : msg.content_parts) { + if (part.type != "text") { + LOG_WRN("Ignoring non-text content part: %s\n", part.type.c_str()); + continue; + } + if (!content.empty()) { + content += "\n";; + } + content += part.text; + } + contents.emplace_back(std::move(content)); + } + for (size_t i = 0; i < contents.size(); ++i) { + const auto & msg = inputs.messages[i]; + const auto & content = contents[i]; + chat.push_back({msg.role.c_str(), content.c_str()}); + alloc_size += (msg.role.size() + content.size()) * 1.25; + } + + std::vector buf(alloc_size); + + // run the first time to get the total output length + const auto & src = tmpls->template_default->source(); + int32_t res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size()); + + // error: chat template is not supported + if (res < 0) { + // if the custom "tmpl" is not supported, we throw an error + // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template() + throw std::runtime_error("this custom template is not supported"); + } + + // if it turns out that our buffer is too small, we resize it + if ((size_t) res > buf.size()) { + buf.resize(res); + res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size()); + } + + common_chat_params params; + params.prompt = std::string(buf.data(), res); + if (!inputs.json_schema.empty()) { + params.grammar = json_schema_to_grammar(json::parse(inputs.json_schema)); + } else { + params.grammar = inputs.grammar; + } + return params; +} + +common_chat_params common_chat_templates_apply( + const struct common_chat_templates * tmpls, + const struct common_chat_templates_inputs & inputs) +{ + GGML_ASSERT(tmpls != nullptr); + return inputs.use_jinja + ? common_chat_templates_apply_jinja(tmpls, inputs) + : common_chat_templates_apply_legacy(tmpls, inputs); } static common_chat_msg common_chat_parse_content_only(const std::string & input) { - return { - /* .role = */ "assistant", - /* .content = */ input, - /* .tool_calls = */ {}, - }; + common_chat_msg msg; + msg.role = "assistant"; + msg.content = input; + return msg; } common_chat_msg common_chat_parse(const std::string & input, common_chat_format format) { diff --git a/common/chat.h b/common/chat.h new file mode 100644 index 00000000..e77bef82 --- /dev/null +++ b/common/chat.h @@ -0,0 +1,134 @@ +// Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers. + +#pragma once + +#include "common.h" +#include +#include + +struct common_chat_templates; + +struct common_chat_tool_call { + std::string name; + std::string arguments; + std::string id; +}; + +struct common_chat_msg_content_part { + std::string type; + std::string text; +}; + +struct common_chat_msg { + std::string role; + std::string content; + std::vector content_parts = {}; + std::vector tool_calls = {}; + std::string reasoning_content; + std::string tool_name; + std::string tool_call_id; +}; + +struct common_chat_tool { + std::string name; + std::string description; + std::string parameters; +}; + +enum common_chat_tool_choice { + COMMON_CHAT_TOOL_CHOICE_AUTO, + COMMON_CHAT_TOOL_CHOICE_REQUIRED, + COMMON_CHAT_TOOL_CHOICE_NONE, +}; + +enum common_chat_format { + COMMON_CHAT_FORMAT_CONTENT_ONLY, + COMMON_CHAT_FORMAT_GENERIC, + COMMON_CHAT_FORMAT_MISTRAL_NEMO, + COMMON_CHAT_FORMAT_LLAMA_3_X, + COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS, + COMMON_CHAT_FORMAT_DEEPSEEK_R1, + COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, + COMMON_CHAT_FORMAT_FIREFUNCTION_V2, + COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, + COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1, + COMMON_CHAT_FORMAT_HERMES_2_PRO, + COMMON_CHAT_FORMAT_COMMAND_R7B, + COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING, + + COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats +}; + +struct common_chat_templates_inputs { + std::vector messages; + std::string grammar; + std::string json_schema; + bool add_generation_prompt = true; + bool use_jinja = true; + // Parameters below only supported when use_jinja is true + std::vector tools; + common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO; + bool parallel_tool_calls = false; + bool extract_reasoning = true; +}; + +struct common_chat_params { + common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + std::string prompt; + std::string grammar; + bool grammar_lazy = false; + std::vector grammar_triggers; + std::vector preserved_tokens; + std::vector additional_stops; +}; + +// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid +bool common_chat_verify_template(const std::string & tmpl, bool use_jinja); + +void common_chat_templates_free(struct common_chat_templates * tmpls); + +struct common_chat_templates_deleter { void operator()(common_chat_templates * tmpls) { common_chat_templates_free(tmpls); } }; + +typedef std::unique_ptr common_chat_templates_ptr; + +common_chat_templates_ptr common_chat_templates_init( + const struct llama_model * model, + const std::string & chat_template_override, + const std::string & bos_token_override = "", + const std::string & eos_token_override = ""); + +bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls); +const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant = nullptr); + + +struct common_chat_params common_chat_templates_apply( + const struct common_chat_templates * tmpls, + const struct common_chat_templates_inputs & inputs); + +// Format single message, while taking into account the position of that message in chat history +std::string common_chat_format_single( + const struct common_chat_templates * tmpls, + const std::vector & past_msg, + const common_chat_msg & new_msg, + bool add_ass, + bool use_jinja); + +// Returns an example of formatted chat +std::string common_chat_format_example( + const struct common_chat_templates * tmpls, + bool use_jinja); + +std::string common_chat_format_name(common_chat_format format); +common_chat_msg common_chat_parse( const std::string & input, common_chat_format format); + +common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice); + +// Parses a JSON array of messages in OpenAI's chat completion API format. +// T can be std::string containing JSON or nlohmann::ordered_json +template std::vector common_chat_msgs_parse_oaicompat(const T & messages); +template T common_chat_msgs_to_json_oaicompat(const std::vector & msgs, bool concat_typed_text = false); + +// Parses a JSON array of tools in OpenAI's chat completion tool call API format. +// T can be std::string containing JSON or nlohmann::ordered_json +template std::vector common_chat_tools_parse_oaicompat(const T & tools); +template T common_chat_tools_to_json_oaicompat(const std::vector & tools); diff --git a/common/chat.hpp b/common/chat.hpp deleted file mode 100644 index ba1632f6..00000000 --- a/common/chat.hpp +++ /dev/null @@ -1,55 +0,0 @@ -// Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers. - -#pragma once - -#include "common.h" -#include -#include -#include -#include - -using json = nlohmann::ordered_json; - -struct common_chat_inputs { - json messages; - json tools; - json tool_choice; - json json_schema; - bool parallel_tool_calls; - bool stream; - std::string grammar; - bool add_generation_prompt = true; - bool extract_reasoning = true; -}; - -enum common_chat_format { - COMMON_CHAT_FORMAT_CONTENT_ONLY, - COMMON_CHAT_FORMAT_GENERIC, - COMMON_CHAT_FORMAT_MISTRAL_NEMO, - COMMON_CHAT_FORMAT_LLAMA_3_X, - COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS, - COMMON_CHAT_FORMAT_DEEPSEEK_R1, - COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, - COMMON_CHAT_FORMAT_FIREFUNCTION_V2, - COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, - COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1, - COMMON_CHAT_FORMAT_HERMES_2_PRO, - COMMON_CHAT_FORMAT_COMMAND_R7B, - COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING, - - COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats -}; - -struct common_chat_params { - common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY; - json prompt; - std::string grammar; - bool grammar_lazy = false; - std::vector grammar_triggers; - std::vector preserved_tokens; - std::vector additional_stops; -}; - -struct common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & params); -std::string common_chat_format_name(common_chat_format format); -common_chat_msg common_chat_parse( const std::string & input, common_chat_format format); diff --git a/common/common.cpp b/common/common.cpp index 8661e164..d2b0d50e 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -12,8 +12,6 @@ #include "json.hpp" #include "json-schema-to-grammar.h" #include "llama.h" -#include "chat.hpp" -#include "chat-template.hpp" #include #include @@ -1768,174 +1766,6 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto return text; } -// -// Chat template utils -// - -bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { - if (use_jinja) { - try { - auto chat_template = common_chat_template(tmpl, "", ""); - common_chat_inputs inputs; - inputs.messages = json::array({{ - {"role", "user"}, - {"content", "test"}, - }}); - common_chat_params_init(chat_template, inputs); - return true; - } catch (const std::exception & e) { - LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what()); - return false; - } - } - llama_chat_message chat[] = {{"user", "test"}}; - const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0); - return res >= 0; -} - -std::string common_chat_apply_template( - const common_chat_template & tmpl, - const std::vector & msgs, - bool add_ass, - bool use_jinja) { - if (use_jinja) { - auto messages = json::array(); - for (const auto & msg : msgs) { - messages.push_back({{"role", msg.role}, {"content", msg.content}}); - } - common_chat_inputs inputs; - inputs.messages = messages; - inputs.add_generation_prompt = add_ass; - return common_chat_params_init(tmpl, inputs).prompt; - } - - int alloc_size = 0; - std::vector chat; - for (const auto & msg : msgs) { - chat.push_back({msg.role.c_str(), msg.content.c_str()}); - alloc_size += (msg.role.size() + msg.content.size()) * 1.25; - } - - std::vector buf(alloc_size); - - // run the first time to get the total output length - int32_t res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size()); - - // error: chat template is not supported - if (res < 0) { - // if the custom "tmpl" is not supported, we throw an error - // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template() - throw std::runtime_error("this custom template is not supported"); - } - - // if it turns out that our buffer is too small, we resize it - if ((size_t) res > buf.size()) { - buf.resize(res); - res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size()); - } - - std::string formatted_chat(buf.data(), res); - return formatted_chat; -} - -std::string common_chat_format_single( - const common_chat_template & tmpl, - const std::vector & past_msg, - const common_chat_msg & new_msg, - bool add_ass, - bool use_jinja) { - std::ostringstream ss; - auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(tmpl, past_msg, false, use_jinja); - std::vector chat_new(past_msg); - // if the past_msg ends with a newline, we must preserve it in the formatted version - if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') { - ss << "\n"; - }; - // format chat with new_msg - chat_new.push_back(new_msg); - auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja); - // get the diff part - ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size()); - return ss.str(); -} - -std::string common_chat_format_example(const common_chat_template & tmpl, bool use_jinja) { - std::vector msgs = { - {"system", "You are a helpful assistant", {}}, - {"user", "Hello", {}}, - {"assistant", "Hi there", {}}, - {"user", "How are you?", {}}, - }; - return common_chat_apply_template(tmpl, msgs, true, use_jinja); -} - -#define CHATML_TEMPLATE_SRC \ - "{%- for message in messages -%}\n" \ - " {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \ - "{%- endfor -%}\n" \ - "{%- if add_generation_prompt -%}\n" \ - " {{- '<|im_start|>assistant\n' -}}\n" \ - "{%- endif -%}" - -common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override) -{ - std::string default_template_src; - std::string template_tool_use_src; - - bool has_explicit_template = !chat_template_override.empty(); - if (chat_template_override.empty()) { - auto str = llama_model_chat_template(model, /* name */ nullptr); - if (str) { - default_template_src = str; - has_explicit_template = true; - } - str = llama_model_chat_template(model, /* name */ "tool_use"); - if (str) { - template_tool_use_src = str; - has_explicit_template = true; - } - } else { - default_template_src = chat_template_override; - } - if (default_template_src.empty() || default_template_src == "chatml") { - if (!template_tool_use_src.empty()) { - default_template_src = template_tool_use_src; - } else { - default_template_src = CHATML_TEMPLATE_SRC; - } - } - auto vocab = llama_model_get_vocab(model); - const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) { - if (token == LLAMA_TOKEN_NULL) { - if (default_template_src.find(jinja_variable_name) != std::string::npos - || template_tool_use_src.find(jinja_variable_name) != std::string::npos) { - LOG_WRN("%s: warning: vocab does not have a %s token, jinja template won't work as intended.\n", __func__, name); - } - return std::string(); - } else { - return common_token_to_piece(vocab, token, true); - } - }; - auto token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token"); - auto token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token"); - try { - return { - has_explicit_template, - std::make_unique(default_template_src, token_bos, token_eos), - template_tool_use_src.empty() - ? nullptr - : std::make_unique(template_tool_use_src, token_bos, token_eos), - }; - } catch (const std::exception & e) { - LOG_ERR("%s: failed to parse chat template: %s\n", __func__, e.what()); - return { - has_explicit_template, - std::make_unique(CHATML_TEMPLATE_SRC, token_bos, token_eos), - nullptr, - }; - } -} - // // KV cache utils // diff --git a/common/common.h b/common/common.h index 98b9a446..10bcc10d 100644 --- a/common/common.h +++ b/common/common.h @@ -616,62 +616,6 @@ std::string common_detokenize( const std::vector & tokens, bool special = true); -// -// Chat template utils -// - -struct common_tool_call { - std::string name; - std::string arguments; - std::string id; -}; - -// same with llama_chat_message, but uses std::string -struct common_chat_msg { - std::string role; - std::string content; - std::vector tool_calls; - std::string reasoning_content = ""; -}; - -// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid -bool common_chat_verify_template(const std::string & tmpl, bool use_jinja); - -namespace minja { - class chat_template; -} - -typedef minja::chat_template common_chat_template; - -struct common_chat_templates { - bool has_explicit_template; // Model had builtin template or template overridde was specified. - std::unique_ptr template_default; // always set (defaults to chatml) - std::unique_ptr template_tool_use; -}; - -// CPP wrapper for llama_chat_apply_template -// If the built-in template is not supported, we default to chatml -// If the custom "tmpl" is not supported, we throw an error -std::string common_chat_apply_template( - const common_chat_template & tmpl, - const std::vector & chat, - bool add_ass, - bool use_jinja); - -// Format single message, while taking into account the position of that message in chat history -std::string common_chat_format_single( - const common_chat_template & tmpl, - const std::vector & past_msg, - const common_chat_msg & new_msg, - bool add_ass, - bool use_jinja); - -// Returns an example of formatted chat -std::string common_chat_format_example( - const common_chat_template & tmpl, bool use_jinja); - -common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override); - // // KV cache utils // diff --git a/common/minja.hpp b/common/minja.hpp deleted file mode 100644 index c58dd66e..00000000 --- a/common/minja.hpp +++ /dev/null @@ -1,2883 +0,0 @@ -/* - Copyright 2024 Google LLC - - Use of this source code is governed by an MIT-style - license that can be found in the LICENSE file or at - https://opensource.org/licenses/MIT. -*/ -// SPDX-License-Identifier: MIT -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -using json = nlohmann::ordered_json; - -namespace minja { - -class Context; - -struct Options { - bool trim_blocks; // removes the first newline after a block - bool lstrip_blocks; // removes leading whitespace on the line of the block - bool keep_trailing_newline; // don't remove last newline -}; - -struct ArgumentsValue; - -inline std::string normalize_newlines(const std::string & s) { -#ifdef _WIN32 - static const std::regex nl_regex("\r\n"); - return std::regex_replace(s, nl_regex, "\n"); -#else - return s; -#endif -} - -/* Values that behave roughly like in Python. */ -class Value : public std::enable_shared_from_this { -public: - using CallableType = std::function &, ArgumentsValue &)>; - using FilterType = std::function &, ArgumentsValue &)>; - -private: - using ObjectType = nlohmann::ordered_map; // Only contains primitive keys - using ArrayType = std::vector; - - std::shared_ptr array_; - std::shared_ptr object_; - std::shared_ptr callable_; - json primitive_; - - Value(const std::shared_ptr & array) : array_(array) {} - Value(const std::shared_ptr & object) : object_(object) {} - Value(const std::shared_ptr & callable) : object_(std::make_shared()), callable_(callable) {} - - /* Python-style string repr */ - static void dump_string(const json & primitive, std::ostringstream & out, char string_quote = '\'') { - if (!primitive.is_string()) throw std::runtime_error("Value is not a string: " + primitive.dump()); - auto s = primitive.dump(); - if (string_quote == '"' || s.find('\'') != std::string::npos) { - out << s; - return; - } - // Reuse json dump, just changing string quotes - out << string_quote; - for (size_t i = 1, n = s.size() - 1; i < n; ++i) { - if (s[i] == '\\' && s[i + 1] == '"') { - out << '"'; - i++; - } else if (s[i] == string_quote) { - out << '\\' << string_quote; - } else { - out << s[i]; - } - } - out << string_quote; - } - void dump(std::ostringstream & out, int indent = -1, int level = 0, bool to_json = false) const { - auto print_indent = [&](int level) { - if (indent > 0) { - out << "\n"; - for (int i = 0, n = level * indent; i < n; ++i) out << ' '; - } - }; - auto print_sub_sep = [&]() { - out << ','; - if (indent < 0) out << ' '; - else print_indent(level + 1); - }; - - auto string_quote = to_json ? '"' : '\''; - - if (is_null()) out << "null"; - else if (array_) { - out << "["; - print_indent(level + 1); - for (size_t i = 0; i < array_->size(); ++i) { - if (i) print_sub_sep(); - (*array_)[i].dump(out, indent, level + 1, to_json); - } - print_indent(level); - out << "]"; - } else if (object_) { - out << "{"; - print_indent(level + 1); - for (auto begin = object_->begin(), it = begin; it != object_->end(); ++it) { - if (it != begin) print_sub_sep(); - if (it->first.is_string()) { - dump_string(it->first, out, string_quote); - } else { - out << string_quote << it->first.dump() << string_quote; - } - out << ": "; - it->second.dump(out, indent, level + 1, to_json); - } - print_indent(level); - out << "}"; - } else if (callable_) { - throw std::runtime_error("Cannot dump callable to JSON"); - } else if (is_boolean() && !to_json) { - out << (this->to_bool() ? "True" : "False"); - } else if (is_string() && !to_json) { - dump_string(primitive_, out, string_quote); - } else { - out << primitive_.dump(); - } - } - -public: - Value() {} - Value(const bool& v) : primitive_(v) {} - Value(const int64_t & v) : primitive_(v) {} - Value(const double& v) : primitive_(v) {} - Value(const std::nullptr_t &) {} - Value(const std::string & v) : primitive_(v) {} - Value(const char * v) : primitive_(std::string(v)) {} - - Value(const json & v) { - if (v.is_object()) { - auto object = std::make_shared(); - for (auto it = v.begin(); it != v.end(); ++it) { - (*object)[it.key()] = it.value(); - } - object_ = std::move(object); - } else if (v.is_array()) { - auto array = std::make_shared(); - for (const auto& item : v) { - array->push_back(Value(item)); - } - array_ = array; - } else { - primitive_ = v; - } - } - - std::vector keys() { - if (!object_) throw std::runtime_error("Value is not an object: " + dump()); - std::vector res; - for (const auto& item : *object_) { - res.push_back(item.first); - } - return res; - } - - size_t size() const { - if (is_object()) return object_->size(); - if (is_array()) return array_->size(); - if (is_string()) return primitive_.get().length(); - throw std::runtime_error("Value is not an array or object: " + dump()); - } - - static Value array(const std::vector values = {}) { - auto array = std::make_shared(); - for (const auto& item : values) { - array->push_back(item); - } - return Value(array); - } - static Value object(const std::shared_ptr object = std::make_shared()) { - return Value(object); - } - static Value callable(const CallableType & callable) { - return Value(std::make_shared(callable)); - } - - void insert(size_t index, const Value& v) { - if (!array_) - throw std::runtime_error("Value is not an array: " + dump()); - array_->insert(array_->begin() + index, v); - } - void push_back(const Value& v) { - if (!array_) - throw std::runtime_error("Value is not an array: " + dump()); - array_->push_back(v); - } - Value pop(const Value& index) { - if (is_array()) { - if (array_->empty()) - throw std::runtime_error("pop from empty list"); - if (index.is_null()) { - auto ret = array_->back(); - array_->pop_back(); - return ret; - } else if (!index.is_number_integer()) { - throw std::runtime_error("pop index must be an integer: " + index.dump()); - } else { - auto i = index.get(); - if (i < 0 || i >= static_cast(array_->size())) - throw std::runtime_error("pop index out of range: " + index.dump()); - auto it = array_->begin() + (i < 0 ? array_->size() + i : i); - auto ret = *it; - array_->erase(it); - return ret; - } - } else if (is_object()) { - if (!index.is_hashable()) - throw std::runtime_error("Unashable type: " + index.dump()); - auto it = object_->find(index.primitive_); - if (it == object_->end()) - throw std::runtime_error("Key not found: " + index.dump()); - auto ret = it->second; - object_->erase(it); - return ret; - } else { - throw std::runtime_error("Value is not an array or object: " + dump()); - } - } - Value get(const Value& key) { - if (array_) { - if (!key.is_number_integer()) { - return Value(); - } - auto index = key.get(); - return array_->at(index < 0 ? array_->size() + index : index); - } else if (object_) { - if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump()); - auto it = object_->find(key.primitive_); - if (it == object_->end()) return Value(); - return it->second; - } - return Value(); - } - void set(const Value& key, const Value& value) { - if (!object_) throw std::runtime_error("Value is not an object: " + dump()); - if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump()); - (*object_)[key.primitive_] = value; - } - Value call(const std::shared_ptr & context, ArgumentsValue & args) const { - if (!callable_) throw std::runtime_error("Value is not callable: " + dump()); - return (*callable_)(context, args); - } - - bool is_object() const { return !!object_; } - bool is_array() const { return !!array_; } - bool is_callable() const { return !!callable_; } - bool is_null() const { return !object_ && !array_ && primitive_.is_null() && !callable_; } - bool is_boolean() const { return primitive_.is_boolean(); } - bool is_number_integer() const { return primitive_.is_number_integer(); } - bool is_number_float() const { return primitive_.is_number_float(); } - bool is_number() const { return primitive_.is_number(); } - bool is_string() const { return primitive_.is_string(); } - bool is_iterable() const { return is_array() || is_object() || is_string(); } - - bool is_primitive() const { return !array_ && !object_ && !callable_; } - bool is_hashable() const { return is_primitive(); } - - bool empty() const { - if (is_null()) - throw std::runtime_error("Undefined value or reference"); - if (is_string()) return primitive_.empty(); - if (is_array()) return array_->empty(); - if (is_object()) return object_->empty(); - return false; - } - - void for_each(const std::function & callback) const { - if (is_null()) - throw std::runtime_error("Undefined value or reference"); - if (array_) { - for (auto& item : *array_) { - callback(item); - } - } else if (object_) { - for (auto & item : *object_) { - Value key(item.first); - callback(key); - } - } else if (is_string()) { - for (char c : primitive_.get()) { - auto val = Value(std::string(1, c)); - callback(val); - } - } else { - throw std::runtime_error("Value is not iterable: " + dump()); - } - } - - bool to_bool() const { - if (is_null()) return false; - if (is_boolean()) return get(); - if (is_number()) return get() != 0; - if (is_string()) return !get().empty(); - if (is_array()) return !empty(); - return true; - } - - int64_t to_int() const { - if (is_null()) return 0; - if (is_boolean()) return get() ? 1 : 0; - if (is_number()) return static_cast(get()); - if (is_string()) { - try { - return std::stol(get()); - } catch (const std::exception &) { - return 0; - } - } - return 0; - } - - bool operator<(const Value & other) const { - if (is_null()) - throw std::runtime_error("Undefined value or reference"); - if (is_number() && other.is_number()) return get() < other.get(); - if (is_string() && other.is_string()) return get() < other.get(); - throw std::runtime_error("Cannot compare values: " + dump() + " < " + other.dump()); - } - bool operator>=(const Value & other) const { return !(*this < other); } - - bool operator>(const Value & other) const { - if (is_null()) - throw std::runtime_error("Undefined value or reference"); - if (is_number() && other.is_number()) return get() > other.get(); - if (is_string() && other.is_string()) return get() > other.get(); - throw std::runtime_error("Cannot compare values: " + dump() + " > " + other.dump()); - } - bool operator<=(const Value & other) const { return !(*this > other); } - - bool operator==(const Value & other) const { - if (callable_ || other.callable_) { - if (callable_.get() != other.callable_.get()) return false; - } - if (array_) { - if (!other.array_) return false; - if (array_->size() != other.array_->size()) return false; - for (size_t i = 0; i < array_->size(); ++i) { - if (!(*array_)[i].to_bool() || !(*other.array_)[i].to_bool() || (*array_)[i] != (*other.array_)[i]) return false; - } - return true; - } else if (object_) { - if (!other.object_) return false; - if (object_->size() != other.object_->size()) return false; - for (const auto& item : *object_) { - if (!item.second.to_bool() || !other.object_->count(item.first) || item.second != other.object_->at(item.first)) return false; - } - return true; - } else { - return primitive_ == other.primitive_; - } - } - bool operator!=(const Value & other) const { return !(*this == other); } - - bool contains(const char * key) const { return contains(std::string(key)); } - bool contains(const std::string & key) const { - if (array_) { - return false; - } else if (object_) { - return object_->find(key) != object_->end(); - } else { - throw std::runtime_error("contains can only be called on arrays and objects: " + dump()); - } - } - bool contains(const Value & value) const { - if (is_null()) - throw std::runtime_error("Undefined value or reference"); - if (array_) { - for (const auto& item : *array_) { - if (item.to_bool() && item == value) return true; - } - return false; - } else if (object_) { - if (!value.is_hashable()) throw std::runtime_error("Unashable type: " + value.dump()); - return object_->find(value.primitive_) != object_->end(); - } else { - throw std::runtime_error("contains can only be called on arrays and objects: " + dump()); - } - } - void erase(size_t index) { - if (!array_) throw std::runtime_error("Value is not an array: " + dump()); - array_->erase(array_->begin() + index); - } - void erase(const std::string & key) { - if (!object_) throw std::runtime_error("Value is not an object: " + dump()); - object_->erase(key); - } - const Value& at(const Value & index) const { - return const_cast(this)->at(index); - } - Value& at(const Value & index) { - if (!index.is_hashable()) throw std::runtime_error("Unashable type: " + dump()); - if (is_array()) return array_->at(index.get()); - if (is_object()) return object_->at(index.primitive_); - throw std::runtime_error("Value is not an array or object: " + dump()); - } - const Value& at(size_t index) const { - return const_cast(this)->at(index); - } - Value& at(size_t index) { - if (is_null()) - throw std::runtime_error("Undefined value or reference"); - if (is_array()) return array_->at(index); - if (is_object()) return object_->at(index); - throw std::runtime_error("Value is not an array or object: " + dump()); - } - - template - T get(const std::string & key, T default_value) const { - if (!contains(key)) return default_value; - return at(key).get(); - } - - template - T get() const { - if (is_primitive()) return primitive_.get(); - throw std::runtime_error("get not defined for this value type: " + dump()); - } - - std::string dump(int indent=-1, bool to_json=false) const { - std::ostringstream out; - dump(out, indent, 0, to_json); - return out.str(); - } - - Value operator-() const { - if (is_number_integer()) - return -get(); - else - return -get(); - } - std::string to_str() const { - if (is_string()) return get(); - if (is_number_integer()) return std::to_string(get()); - if (is_number_float()) return std::to_string(get()); - if (is_boolean()) return get() ? "True" : "False"; - if (is_null()) return "None"; - return dump(); - } - Value operator+(const Value& rhs) const { - if (is_string() || rhs.is_string()) { - return to_str() + rhs.to_str(); - } else if (is_number_integer() && rhs.is_number_integer()) { - return get() + rhs.get(); - } else if (is_array() && rhs.is_array()) { - auto res = Value::array(); - for (const auto& item : *array_) res.push_back(item); - for (const auto& item : *rhs.array_) res.push_back(item); - return res; - } else { - return get() + rhs.get(); - } - } - Value operator-(const Value& rhs) const { - if (is_number_integer() && rhs.is_number_integer()) - return get() - rhs.get(); - else - return get() - rhs.get(); - } - Value operator*(const Value& rhs) const { - if (is_string() && rhs.is_number_integer()) { - std::ostringstream out; - for (int64_t i = 0, n = rhs.get(); i < n; ++i) { - out << to_str(); - } - return out.str(); - } - else if (is_number_integer() && rhs.is_number_integer()) - return get() * rhs.get(); - else - return get() * rhs.get(); - } - Value operator/(const Value& rhs) const { - if (is_number_integer() && rhs.is_number_integer()) - return get() / rhs.get(); - else - return get() / rhs.get(); - } - Value operator%(const Value& rhs) const { - return get() % rhs.get(); - } -}; - -struct ArgumentsValue { - std::vector args; - std::vector> kwargs; - - bool has_named(const std::string & name) { - for (const auto & p : kwargs) { - if (p.first == name) return true; - } - return false; - } - - Value get_named(const std::string & name) { - for (const auto & [key, value] : kwargs) { - if (key == name) return value; - } - return Value(); - } - - bool empty() { - return args.empty() && kwargs.empty(); - } - - void expectArgs(const std::string & method_name, const std::pair & pos_count, const std::pair & kw_count) { - if (args.size() < pos_count.first || args.size() > pos_count.second || kwargs.size() < kw_count.first || kwargs.size() > kw_count.second) { - std::ostringstream out; - out << method_name << " must have between " << pos_count.first << " and " << pos_count.second << " positional arguments and between " << kw_count.first << " and " << kw_count.second << " keyword arguments"; - throw std::runtime_error(out.str()); - } - } -}; - -template <> -inline json Value::get() const { - if (is_primitive()) return primitive_; - if (is_null()) return json(); - if (array_) { - std::vector res; - for (const auto& item : *array_) { - res.push_back(item.get()); - } - return res; - } - if (object_) { - json res = json::object(); - for (const auto& [key, value] : *object_) { - if (key.is_string()) { - res[key.get()] = value.get(); - } else if (key.is_primitive()) { - res[key.dump()] = value.get(); - } else { - throw std::runtime_error("Invalid key type for conversion to JSON: " + key.dump()); - } - } - if (is_callable()) { - res["__callable__"] = true; - } - return res; - } - throw std::runtime_error("get not defined for this value type: " + dump()); -} - -} // namespace minja - -namespace std { - template <> - struct hash { - size_t operator()(const minja::Value & v) const { - if (!v.is_hashable()) - throw std::runtime_error("Unsupported type for hashing: " + v.dump()); - return std::hash()(v.get()); - } - }; -} // namespace std - -namespace minja { - -static std::string error_location_suffix(const std::string & source, size_t pos) { - auto get_line = [&](size_t line) { - auto start = source.begin(); - for (size_t i = 1; i < line; ++i) { - start = std::find(start, source.end(), '\n') + 1; - } - auto end = std::find(start, source.end(), '\n'); - return std::string(start, end); - }; - auto start = source.begin(); - auto end = source.end(); - auto it = start + pos; - auto line = std::count(start, it, '\n') + 1; - auto max_line = std::count(start, end, '\n') + 1; - auto col = pos - std::string(start, it).rfind('\n'); - std::ostringstream out; - out << " at row " << line << ", column " << col << ":\n"; - if (line > 1) out << get_line(line - 1) << "\n"; - out << get_line(line) << "\n"; - out << std::string(col - 1, ' ') << "^\n"; - if (line < max_line) out << get_line(line + 1) << "\n"; - - return out.str(); -} - -class Context : public std::enable_shared_from_this { - protected: - Value values_; - std::shared_ptr parent_; - public: - Context(Value && values, const std::shared_ptr & parent = nullptr) : values_(std::move(values)), parent_(parent) { - if (!values_.is_object()) throw std::runtime_error("Context values must be an object: " + values_.dump()); - } - virtual ~Context() {} - - static std::shared_ptr builtins(); - static std::shared_ptr make(Value && values, const std::shared_ptr & parent = builtins()); - - std::vector keys() { - return values_.keys(); - } - virtual Value get(const Value & key) { - if (values_.contains(key)) return values_.at(key); - if (parent_) return parent_->get(key); - return Value(); - } - virtual Value & at(const Value & key) { - if (values_.contains(key)) return values_.at(key); - if (parent_) return parent_->at(key); - throw std::runtime_error("Undefined variable: " + key.dump()); - } - virtual bool contains(const Value & key) { - if (values_.contains(key)) return true; - if (parent_) return parent_->contains(key); - return false; - } - virtual void set(const Value & key, const Value & value) { - values_.set(key, value); - } -}; - -struct Location { - std::shared_ptr source; - size_t pos; -}; - -class Expression { -protected: - virtual Value do_evaluate(const std::shared_ptr & context) const = 0; -public: - using Parameters = std::vector>>; - - Location location; - - Expression(const Location & location) : location(location) {} - virtual ~Expression() = default; - - Value evaluate(const std::shared_ptr & context) const { - try { - return do_evaluate(context); - } catch (const std::exception & e) { - std::ostringstream out; - out << e.what(); - if (location.source) out << error_location_suffix(*location.source, location.pos); - throw std::runtime_error(out.str()); - } - } -}; - -class VariableExpr : public Expression { - std::string name; -public: - VariableExpr(const Location & location, const std::string& n) - : Expression(location), name(n) {} - std::string get_name() const { return name; } - Value do_evaluate(const std::shared_ptr & context) const override { - if (!context->contains(name)) { - return Value(); - } - return context->at(name); - } -}; - -static void destructuring_assign(const std::vector & var_names, const std::shared_ptr & context, Value& item) { - if (var_names.size() == 1) { - Value name(var_names[0]); - context->set(name, item); - } else { - if (!item.is_array() || item.size() != var_names.size()) { - throw std::runtime_error("Mismatched number of variables and items in destructuring assignment"); - } - for (size_t i = 0; i < var_names.size(); ++i) { - context->set(var_names[i], item.at(i)); - } - } -} - -enum SpaceHandling { Keep, Strip, StripSpaces, StripNewline }; - -class TemplateToken { -public: - enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter, Break, Continue }; - - static std::string typeToString(Type t) { - switch (t) { - case Type::Text: return "text"; - case Type::Expression: return "expression"; - case Type::If: return "if"; - case Type::Else: return "else"; - case Type::Elif: return "elif"; - case Type::EndIf: return "endif"; - case Type::For: return "for"; - case Type::EndFor: return "endfor"; - case Type::Set: return "set"; - case Type::EndSet: return "endset"; - case Type::Comment: return "comment"; - case Type::Macro: return "macro"; - case Type::EndMacro: return "endmacro"; - case Type::Filter: return "filter"; - case Type::EndFilter: return "endfilter"; - case Type::Generation: return "generation"; - case Type::EndGeneration: return "endgeneration"; - case Type::Break: return "break"; - case Type::Continue: return "continue"; - } - return "Unknown"; - } - - TemplateToken(Type type, const Location & location, SpaceHandling pre, SpaceHandling post) : type(type), location(location), pre_space(pre), post_space(post) {} - virtual ~TemplateToken() = default; - - Type type; - Location location; - SpaceHandling pre_space = SpaceHandling::Keep; - SpaceHandling post_space = SpaceHandling::Keep; -}; - -struct TextTemplateToken : public TemplateToken { - std::string text; - TextTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Text, location, pre, post), text(t) {} -}; - -struct ExpressionTemplateToken : public TemplateToken { - std::shared_ptr expr; - ExpressionTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr && e) : TemplateToken(Type::Expression, location, pre, post), expr(std::move(e)) {} -}; - -struct IfTemplateToken : public TemplateToken { - std::shared_ptr condition; - IfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr && c) : TemplateToken(Type::If, location, pre, post), condition(std::move(c)) {} -}; - -struct ElifTemplateToken : public TemplateToken { - std::shared_ptr condition; - ElifTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr && c) : TemplateToken(Type::Elif, location, pre, post), condition(std::move(c)) {} -}; - -struct ElseTemplateToken : public TemplateToken { - ElseTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::Else, location, pre, post) {} -}; - -struct EndIfTemplateToken : public TemplateToken { - EndIfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndIf, location, pre, post) {} -}; - -struct MacroTemplateToken : public TemplateToken { - std::shared_ptr name; - Expression::Parameters params; - MacroTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr && n, Expression::Parameters && p) - : TemplateToken(Type::Macro, location, pre, post), name(std::move(n)), params(std::move(p)) {} -}; - -struct EndMacroTemplateToken : public TemplateToken { - EndMacroTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndMacro, location, pre, post) {} -}; - -struct FilterTemplateToken : public TemplateToken { - std::shared_ptr filter; - FilterTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr && filter) - : TemplateToken(Type::Filter, location, pre, post), filter(std::move(filter)) {} -}; - -struct EndFilterTemplateToken : public TemplateToken { - EndFilterTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndFilter, location, pre, post) {} -}; - -struct ForTemplateToken : public TemplateToken { - std::vector var_names; - std::shared_ptr iterable; - std::shared_ptr condition; - bool recursive; - ForTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::vector & vns, std::shared_ptr && iter, - std::shared_ptr && c, bool r) - : TemplateToken(Type::For, location, pre, post), var_names(vns), iterable(std::move(iter)), condition(std::move(c)), recursive(r) {} -}; - -struct EndForTemplateToken : public TemplateToken { - EndForTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndFor, location, pre, post) {} -}; - -struct GenerationTemplateToken : public TemplateToken { - GenerationTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::Generation, location, pre, post) {} -}; - -struct EndGenerationTemplateToken : public TemplateToken { - EndGenerationTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndGeneration, location, pre, post) {} -}; - -struct SetTemplateToken : public TemplateToken { - std::string ns; - std::vector var_names; - std::shared_ptr value; - SetTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string & ns, const std::vector & vns, std::shared_ptr && v) - : TemplateToken(Type::Set, location, pre, post), ns(ns), var_names(vns), value(std::move(v)) {} -}; - -struct EndSetTemplateToken : public TemplateToken { - EndSetTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndSet, location, pre, post) {} -}; - -struct CommentTemplateToken : public TemplateToken { - std::string text; - CommentTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Comment, location, pre, post), text(t) {} -}; - -enum class LoopControlType { Break, Continue }; - -class LoopControlException : public std::runtime_error { -public: - LoopControlType control_type; - LoopControlException(const std::string & message, LoopControlType control_type) : std::runtime_error(message), control_type(control_type) {} - LoopControlException(LoopControlType control_type) - : std::runtime_error((control_type == LoopControlType::Continue ? "continue" : "break") + std::string(" outside of a loop")), - control_type(control_type) {} -}; - -struct LoopControlTemplateToken : public TemplateToken { - LoopControlType control_type; - LoopControlTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, LoopControlType control_type) : TemplateToken(Type::Break, location, pre, post), control_type(control_type) {} -}; - -class TemplateNode { - Location location_; -protected: - virtual void do_render(std::ostringstream & out, const std::shared_ptr & context) const = 0; - -public: - TemplateNode(const Location & location) : location_(location) {} - void render(std::ostringstream & out, const std::shared_ptr & context) const { - try { - do_render(out, context); - } catch (const LoopControlException & e) { - // TODO: make stack creation lazy. Only needed if it was thrown outside of a loop. - std::ostringstream err; - err << e.what(); - if (location_.source) err << error_location_suffix(*location_.source, location_.pos); - throw LoopControlException(err.str(), e.control_type); - } catch (const std::exception & e) { - std::ostringstream err; - err << e.what(); - if (location_.source) err << error_location_suffix(*location_.source, location_.pos); - throw std::runtime_error(err.str()); - } - } - const Location & location() const { return location_; } - virtual ~TemplateNode() = default; - std::string render(const std::shared_ptr & context) const { - std::ostringstream out; - render(out, context); - return out.str(); - } -}; - -class SequenceNode : public TemplateNode { - std::vector> children; -public: - SequenceNode(const Location & location, std::vector> && c) - : TemplateNode(location), children(std::move(c)) {} - void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { - for (const auto& child : children) child->render(out, context); - } -}; - -class TextNode : public TemplateNode { - std::string text; -public: - TextNode(const Location & location, const std::string& t) : TemplateNode(location), text(t) {} - void do_render(std::ostringstream & out, const std::shared_ptr &) const override { - out << text; - } -}; - -class ExpressionNode : public TemplateNode { - std::shared_ptr expr; -public: - ExpressionNode(const Location & location, std::shared_ptr && e) : TemplateNode(location), expr(std::move(e)) {} - void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { - if (!expr) throw std::runtime_error("ExpressionNode.expr is null"); - auto result = expr->evaluate(context); - if (result.is_string()) { - out << result.get(); - } else if (result.is_boolean()) { - out << (result.get() ? "True" : "False"); - } else if (!result.is_null()) { - out << result.dump(); - } - } -}; - -class IfNode : public TemplateNode { - std::vector, std::shared_ptr>> cascade; -public: - IfNode(const Location & location, std::vector, std::shared_ptr>> && c) - : TemplateNode(location), cascade(std::move(c)) {} - void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { - for (const auto& branch : cascade) { - auto enter_branch = true; - if (branch.first) { - enter_branch = branch.first->evaluate(context).to_bool(); - } - if (enter_branch) { - if (!branch.second) throw std::runtime_error("IfNode.cascade.second is null"); - branch.second->render(out, context); - return; - } - } - } -}; - -class LoopControlNode : public TemplateNode { - LoopControlType control_type_; - public: - LoopControlNode(const Location & location, LoopControlType control_type) : TemplateNode(location), control_type_(control_type) {} - void do_render(std::ostringstream &, const std::shared_ptr &) const override { - throw LoopControlException(control_type_); - } -}; - -class ForNode : public TemplateNode { - std::vector var_names; - std::shared_ptr iterable; - std::shared_ptr condition; - std::shared_ptr body; - bool recursive; - std::shared_ptr else_body; -public: - ForNode(const Location & location, std::vector && var_names, std::shared_ptr && iterable, - std::shared_ptr && condition, std::shared_ptr && body, bool recursive, std::shared_ptr && else_body) - : TemplateNode(location), var_names(var_names), iterable(std::move(iterable)), condition(std::move(condition)), body(std::move(body)), recursive(recursive), else_body(std::move(else_body)) {} - - void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { - // https://jinja.palletsprojects.com/en/3.0.x/templates/#for - if (!iterable) throw std::runtime_error("ForNode.iterable is null"); - if (!body) throw std::runtime_error("ForNode.body is null"); - - auto iterable_value = iterable->evaluate(context); - Value::CallableType loop_function; - - std::function visit = [&](Value& iter) { - auto filtered_items = Value::array(); - if (!iter.is_null()) { - if (!iterable_value.is_iterable()) { - throw std::runtime_error("For loop iterable must be iterable: " + iterable_value.dump()); - } - iterable_value.for_each([&](Value & item) { - destructuring_assign(var_names, context, item); - if (!condition || condition->evaluate(context).to_bool()) { - filtered_items.push_back(item); - } - }); - } - if (filtered_items.empty()) { - if (else_body) { - else_body->render(out, context); - } - } else { - auto loop = recursive ? Value::callable(loop_function) : Value::object(); - loop.set("length", (int64_t) filtered_items.size()); - - size_t cycle_index = 0; - loop.set("cycle", Value::callable([&](const std::shared_ptr &, ArgumentsValue & args) { - if (args.args.empty() || !args.kwargs.empty()) { - throw std::runtime_error("cycle() expects at least 1 positional argument and no named arg"); - } - auto item = args.args[cycle_index]; - cycle_index = (cycle_index + 1) % args.args.size(); - return item; - })); - auto loop_context = Context::make(Value::object(), context); - loop_context->set("loop", loop); - for (size_t i = 0, n = filtered_items.size(); i < n; ++i) { - auto & item = filtered_items.at(i); - destructuring_assign(var_names, loop_context, item); - loop.set("index", (int64_t) i + 1); - loop.set("index0", (int64_t) i); - loop.set("revindex", (int64_t) (n - i)); - loop.set("revindex0", (int64_t) (n - i - 1)); - loop.set("length", (int64_t) n); - loop.set("first", i == 0); - loop.set("last", i == (n - 1)); - loop.set("previtem", i > 0 ? filtered_items.at(i - 1) : Value()); - loop.set("nextitem", i < n - 1 ? filtered_items.at(i + 1) : Value()); - try { - body->render(out, loop_context); - } catch (const LoopControlException & e) { - if (e.control_type == LoopControlType::Break) break; - if (e.control_type == LoopControlType::Continue) continue; - } - } - } - }; - - if (recursive) { - loop_function = [&](const std::shared_ptr &, ArgumentsValue & args) { - if (args.args.size() != 1 || !args.kwargs.empty() || !args.args[0].is_array()) { - throw std::runtime_error("loop() expects exactly 1 positional iterable argument"); - } - auto & items = args.args[0]; - visit(items); - return Value(); - }; - } - - visit(iterable_value); - } -}; - -class MacroNode : public TemplateNode { - std::shared_ptr name; - Expression::Parameters params; - std::shared_ptr body; - std::unordered_map named_param_positions; -public: - MacroNode(const Location & location, std::shared_ptr && n, Expression::Parameters && p, std::shared_ptr && b) - : TemplateNode(location), name(std::move(n)), params(std::move(p)), body(std::move(b)) { - for (size_t i = 0; i < params.size(); ++i) { - const auto & name = params[i].first; - if (!name.empty()) { - named_param_positions[name] = i; - } - } - } - void do_render(std::ostringstream &, const std::shared_ptr & macro_context) const override { - if (!name) throw std::runtime_error("MacroNode.name is null"); - if (!body) throw std::runtime_error("MacroNode.body is null"); - auto callable = Value::callable([&](const std::shared_ptr & context, ArgumentsValue & args) { - auto call_context = macro_context; - std::vector param_set(params.size(), false); - for (size_t i = 0, n = args.args.size(); i < n; i++) { - auto & arg = args.args[i]; - if (i >= params.size()) throw std::runtime_error("Too many positional arguments for macro " + name->get_name()); - param_set[i] = true; - auto & param_name = params[i].first; - call_context->set(param_name, arg); - } - for (auto & [arg_name, value] : args.kwargs) { - auto it = named_param_positions.find(arg_name); - if (it == named_param_positions.end()) throw std::runtime_error("Unknown parameter name for macro " + name->get_name() + ": " + arg_name); - - call_context->set(arg_name, value); - param_set[it->second] = true; - } - // Set default values for parameters that were not passed - for (size_t i = 0, n = params.size(); i < n; i++) { - if (!param_set[i] && params[i].second != nullptr) { - auto val = params[i].second->evaluate(context); - call_context->set(params[i].first, val); - } - } - return body->render(call_context); - }); - macro_context->set(name->get_name(), callable); - } -}; - -class FilterNode : public TemplateNode { - std::shared_ptr filter; - std::shared_ptr body; - -public: - FilterNode(const Location & location, std::shared_ptr && f, std::shared_ptr && b) - : TemplateNode(location), filter(std::move(f)), body(std::move(b)) {} - - void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { - if (!filter) throw std::runtime_error("FilterNode.filter is null"); - if (!body) throw std::runtime_error("FilterNode.body is null"); - auto filter_value = filter->evaluate(context); - if (!filter_value.is_callable()) { - throw std::runtime_error("Filter must be a callable: " + filter_value.dump()); - } - std::string rendered_body = body->render(context); - - ArgumentsValue filter_args = {{Value(rendered_body)}, {}}; - auto result = filter_value.call(context, filter_args); - out << result.to_str(); - } -}; - -class SetNode : public TemplateNode { - std::string ns; - std::vector var_names; - std::shared_ptr value; -public: - SetNode(const Location & location, const std::string & ns, const std::vector & vns, std::shared_ptr && v) - : TemplateNode(location), ns(ns), var_names(vns), value(std::move(v)) {} - void do_render(std::ostringstream &, const std::shared_ptr & context) const override { - if (!value) throw std::runtime_error("SetNode.value is null"); - if (!ns.empty()) { - if (var_names.size() != 1) { - throw std::runtime_error("Namespaced set only supports a single variable name"); - } - auto & name = var_names[0]; - auto ns_value = context->get(ns); - if (!ns_value.is_object()) throw std::runtime_error("Namespace '" + ns + "' is not an object"); - ns_value.set(name, this->value->evaluate(context)); - } else { - auto val = value->evaluate(context); - destructuring_assign(var_names, context, val); - } - } -}; - -class SetTemplateNode : public TemplateNode { - std::string name; - std::shared_ptr template_value; -public: - SetTemplateNode(const Location & location, const std::string & name, std::shared_ptr && tv) - : TemplateNode(location), name(name), template_value(std::move(tv)) {} - void do_render(std::ostringstream &, const std::shared_ptr & context) const override { - if (!template_value) throw std::runtime_error("SetTemplateNode.template_value is null"); - Value value { template_value->render(context) }; - context->set(name, value); - } -}; - -class IfExpr : public Expression { - std::shared_ptr condition; - std::shared_ptr then_expr; - std::shared_ptr else_expr; -public: - IfExpr(const Location & location, std::shared_ptr && c, std::shared_ptr && t, std::shared_ptr && e) - : Expression(location), condition(std::move(c)), then_expr(std::move(t)), else_expr(std::move(e)) {} - Value do_evaluate(const std::shared_ptr & context) const override { - if (!condition) throw std::runtime_error("IfExpr.condition is null"); - if (!then_expr) throw std::runtime_error("IfExpr.then_expr is null"); - if (condition->evaluate(context).to_bool()) { - return then_expr->evaluate(context); - } - if (else_expr) { - return else_expr->evaluate(context); - } - return nullptr; - } -}; - -class LiteralExpr : public Expression { - Value value; -public: - LiteralExpr(const Location & location, const Value& v) - : Expression(location), value(v) {} - Value do_evaluate(const std::shared_ptr &) const override { return value; } -}; - -class ArrayExpr : public Expression { - std::vector> elements; -public: - ArrayExpr(const Location & location, std::vector> && e) - : Expression(location), elements(std::move(e)) {} - Value do_evaluate(const std::shared_ptr & context) const override { - auto result = Value::array(); - for (const auto& e : elements) { - if (!e) throw std::runtime_error("Array element is null"); - result.push_back(e->evaluate(context)); - } - return result; - } -}; - -class DictExpr : public Expression { - std::vector, std::shared_ptr>> elements; -public: - DictExpr(const Location & location, std::vector, std::shared_ptr>> && e) - : Expression(location), elements(std::move(e)) {} - Value do_evaluate(const std::shared_ptr & context) const override { - auto result = Value::object(); - for (const auto& [key, value] : elements) { - if (!key) throw std::runtime_error("Dict key is null"); - if (!value) throw std::runtime_error("Dict value is null"); - result.set(key->evaluate(context), value->evaluate(context)); - } - return result; - } -}; - -class SliceExpr : public Expression { -public: - std::shared_ptr start, end; - SliceExpr(const Location & location, std::shared_ptr && s, std::shared_ptr && e) - : Expression(location), start(std::move(s)), end(std::move(e)) {} - Value do_evaluate(const std::shared_ptr &) const override { - throw std::runtime_error("SliceExpr not implemented"); - } -}; - -class SubscriptExpr : public Expression { - std::shared_ptr base; - std::shared_ptr index; -public: - SubscriptExpr(const Location & location, std::shared_ptr && b, std::shared_ptr && i) - : Expression(location), base(std::move(b)), index(std::move(i)) {} - Value do_evaluate(const std::shared_ptr & context) const override { - if (!base) throw std::runtime_error("SubscriptExpr.base is null"); - if (!index) throw std::runtime_error("SubscriptExpr.index is null"); - auto target_value = base->evaluate(context); - if (auto slice = dynamic_cast(index.get())) { - auto start = slice->start ? slice->start->evaluate(context).get() : 0; - auto end = slice->end ? slice->end->evaluate(context).get() : (int64_t) target_value.size(); - if (target_value.is_string()) { - std::string s = target_value.get(); - if (start < 0) start = s.size() + start; - if (end < 0) end = s.size() + end; - return s.substr(start, end - start); - } else if (target_value.is_array()) { - if (start < 0) start = target_value.size() + start; - if (end < 0) end = target_value.size() + end; - auto result = Value::array(); - for (auto i = start; i < end; ++i) { - result.push_back(target_value.at(i)); - } - return result; - } else { - throw std::runtime_error(target_value.is_null() ? "Cannot subscript null" : "Subscripting only supported on arrays and strings"); - } - } else { - auto index_value = index->evaluate(context); - if (target_value.is_null()) { - if (auto t = dynamic_cast(base.get())) { - throw std::runtime_error("'" + t->get_name() + "' is " + (context->contains(t->get_name()) ? "null" : "not defined")); - } - throw std::runtime_error("Trying to access property '" + index_value.dump() + "' on null!"); - } - return target_value.get(index_value); - } - } -}; - -class UnaryOpExpr : public Expression { -public: - enum class Op { Plus, Minus, LogicalNot, Expansion, ExpansionDict }; - std::shared_ptr expr; - Op op; - UnaryOpExpr(const Location & location, std::shared_ptr && e, Op o) - : Expression(location), expr(std::move(e)), op(o) {} - Value do_evaluate(const std::shared_ptr & context) const override { - if (!expr) throw std::runtime_error("UnaryOpExpr.expr is null"); - auto e = expr->evaluate(context); - switch (op) { - case Op::Plus: return e; - case Op::Minus: return -e; - case Op::LogicalNot: return !e.to_bool(); - case Op::Expansion: - case Op::ExpansionDict: - throw std::runtime_error("Expansion operator is only supported in function calls and collections"); - - } - throw std::runtime_error("Unknown unary operator"); - } -}; - -class BinaryOpExpr : public Expression { -public: - enum class Op { StrConcat, Add, Sub, Mul, MulMul, Div, DivDiv, Mod, Eq, Ne, Lt, Gt, Le, Ge, And, Or, In, NotIn, Is, IsNot }; -private: - std::shared_ptr left; - std::shared_ptr right; - Op op; -public: - BinaryOpExpr(const Location & location, std::shared_ptr && l, std::shared_ptr && r, Op o) - : Expression(location), left(std::move(l)), right(std::move(r)), op(o) {} - Value do_evaluate(const std::shared_ptr & context) const override { - if (!left) throw std::runtime_error("BinaryOpExpr.left is null"); - if (!right) throw std::runtime_error("BinaryOpExpr.right is null"); - auto l = left->evaluate(context); - - auto do_eval = [&](const Value & l) -> Value { - if (op == Op::Is || op == Op::IsNot) { - auto t = dynamic_cast(right.get()); - if (!t) throw std::runtime_error("Right side of 'is' operator must be a variable"); - - auto eval = [&]() { - const auto & name = t->get_name(); - if (name == "none") return l.is_null(); - if (name == "boolean") return l.is_boolean(); - if (name == "integer") return l.is_number_integer(); - if (name == "float") return l.is_number_float(); - if (name == "number") return l.is_number(); - if (name == "string") return l.is_string(); - if (name == "mapping") return l.is_object(); - if (name == "iterable") return l.is_iterable(); - if (name == "sequence") return l.is_array(); - if (name == "defined") return !l.is_null(); - throw std::runtime_error("Unknown type for 'is' operator: " + name); - }; - auto value = eval(); - return Value(op == Op::Is ? value : !value); - } - - if (op == Op::And) { - if (!l.to_bool()) return Value(false); - return right->evaluate(context).to_bool(); - } else if (op == Op::Or) { - if (l.to_bool()) return l; - return right->evaluate(context); - } - - auto r = right->evaluate(context); - switch (op) { - case Op::StrConcat: return l.to_str() + r.to_str(); - case Op::Add: return l + r; - case Op::Sub: return l - r; - case Op::Mul: return l * r; - case Op::Div: return l / r; - case Op::MulMul: return std::pow(l.get(), r.get()); - case Op::DivDiv: return l.get() / r.get(); - case Op::Mod: return l.get() % r.get(); - case Op::Eq: return l == r; - case Op::Ne: return l != r; - case Op::Lt: return l < r; - case Op::Gt: return l > r; - case Op::Le: return l <= r; - case Op::Ge: return l >= r; - case Op::In: return (r.is_array() || r.is_object()) && r.contains(l); - case Op::NotIn: return !(r.is_array() && r.contains(l)); - default: break; - } - throw std::runtime_error("Unknown binary operator"); - }; - - if (l.is_callable()) { - return Value::callable([l, do_eval](const std::shared_ptr & context, ArgumentsValue & args) { - auto ll = l.call(context, args); - return do_eval(ll); //args[0].second); - }); - } else { - return do_eval(l); - } - } -}; - -struct ArgumentsExpression { - std::vector> args; - std::vector>> kwargs; - - ArgumentsValue evaluate(const std::shared_ptr & context) const { - ArgumentsValue vargs; - for (const auto& arg : this->args) { - if (auto un_expr = std::dynamic_pointer_cast(arg)) { - if (un_expr->op == UnaryOpExpr::Op::Expansion) { - auto array = un_expr->expr->evaluate(context); - if (!array.is_array()) { - throw std::runtime_error("Expansion operator only supported on arrays"); - } - array.for_each([&](Value & value) { - vargs.args.push_back(value); - }); - continue; - } else if (un_expr->op == UnaryOpExpr::Op::ExpansionDict) { - auto dict = un_expr->expr->evaluate(context); - if (!dict.is_object()) { - throw std::runtime_error("ExpansionDict operator only supported on objects"); - } - dict.for_each([&](const Value & key) { - vargs.kwargs.push_back({key.get(), dict.at(key)}); - }); - continue; - } - } - vargs.args.push_back(arg->evaluate(context)); - } - for (const auto& [name, value] : this->kwargs) { - vargs.kwargs.push_back({name, value->evaluate(context)}); - } - return vargs; - } -}; - -static std::string strip(const std::string & s) { - auto start = s.find_first_not_of(" \t\n\r"); - if (start == std::string::npos) return ""; - auto end = s.find_last_not_of(" \t\n\r"); - return s.substr(start, end - start + 1); -} - -static std::string capitalize(const std::string & s) { - if (s.empty()) return s; - auto result = s; - result[0] = std::toupper(result[0]); - return result; -} - -static std::string html_escape(const std::string & s) { - std::string result; - result.reserve(s.size()); - for (const auto & c : s) { - switch (c) { - case '&': result += "&"; break; - case '<': result += "<"; break; - case '>': result += ">"; break; - case '"': result += """; break; - case '\'': result += "'"; break; - default: result += c; break; - } - } - return result; -} - -class MethodCallExpr : public Expression { - std::shared_ptr object; - std::shared_ptr method; - ArgumentsExpression args; -public: - MethodCallExpr(const Location & location, std::shared_ptr && obj, std::shared_ptr && m, ArgumentsExpression && a) - : Expression(location), object(std::move(obj)), method(std::move(m)), args(std::move(a)) {} - Value do_evaluate(const std::shared_ptr & context) const override { - if (!object) throw std::runtime_error("MethodCallExpr.object is null"); - if (!method) throw std::runtime_error("MethodCallExpr.method is null"); - auto obj = object->evaluate(context); - auto vargs = args.evaluate(context); - if (obj.is_null()) { - throw std::runtime_error("Trying to call method '" + method->get_name() + "' on null"); - } - if (obj.is_array()) { - if (method->get_name() == "append") { - vargs.expectArgs("append method", {1, 1}, {0, 0}); - obj.push_back(vargs.args[0]); - return Value(); - } else if (method->get_name() == "pop") { - vargs.expectArgs("pop method", {0, 1}, {0, 0}); - return obj.pop(vargs.args.empty() ? Value() : vargs.args[0]); - } else if (method->get_name() == "insert") { - vargs.expectArgs("insert method", {2, 2}, {0, 0}); - auto index = vargs.args[0].get(); - if (index < 0 || index > (int64_t) obj.size()) throw std::runtime_error("Index out of range for insert method"); - obj.insert(index, vargs.args[1]); - return Value(); - } - } else if (obj.is_object()) { - if (method->get_name() == "items") { - vargs.expectArgs("items method", {0, 0}, {0, 0}); - auto result = Value::array(); - for (const auto& key : obj.keys()) { - result.push_back(Value::array({key, obj.at(key)})); - } - return result; - } else if (method->get_name() == "pop") { - vargs.expectArgs("pop method", {1, 1}, {0, 0}); - return obj.pop(vargs.args[0]); - } else if (method->get_name() == "get") { - vargs.expectArgs("get method", {1, 2}, {0, 0}); - auto key = vargs.args[0]; - if (vargs.args.size() == 1) { - return obj.contains(key) ? obj.at(key) : Value(); - } else { - return obj.contains(key) ? obj.at(key) : vargs.args[1]; - } - } else if (obj.contains(method->get_name())) { - auto callable = obj.at(method->get_name()); - if (!callable.is_callable()) { - throw std::runtime_error("Property '" + method->get_name() + "' is not callable"); - } - return callable.call(context, vargs); - } - } else if (obj.is_string()) { - auto str = obj.get(); - if (method->get_name() == "strip") { - vargs.expectArgs("strip method", {0, 0}, {0, 0}); - return Value(strip(str)); - } else if (method->get_name() == "capitalize") { - vargs.expectArgs("capitalize method", {0, 0}, {0, 0}); - return Value(capitalize(str)); - } else if (method->get_name() == "endswith") { - vargs.expectArgs("endswith method", {1, 1}, {0, 0}); - auto suffix = vargs.args[0].get(); - return suffix.length() <= str.length() && std::equal(suffix.rbegin(), suffix.rend(), str.rbegin()); - } else if (method->get_name() == "title") { - vargs.expectArgs("title method", {0, 0}, {0, 0}); - auto res = str; - for (size_t i = 0, n = res.size(); i < n; ++i) { - if (i == 0 || std::isspace(res[i - 1])) res[i] = std::toupper(res[i]); - else res[i] = std::tolower(res[i]); - } - return res; - } - } - throw std::runtime_error("Unknown method: " + method->get_name()); - } -}; - -class CallExpr : public Expression { -public: - std::shared_ptr object; - ArgumentsExpression args; - CallExpr(const Location & location, std::shared_ptr && obj, ArgumentsExpression && a) - : Expression(location), object(std::move(obj)), args(std::move(a)) {} - Value do_evaluate(const std::shared_ptr & context) const override { - if (!object) throw std::runtime_error("CallExpr.object is null"); - auto obj = object->evaluate(context); - if (!obj.is_callable()) { - throw std::runtime_error("Object is not callable: " + obj.dump(2)); - } - auto vargs = args.evaluate(context); - return obj.call(context, vargs); - } -}; - -class FilterExpr : public Expression { - std::vector> parts; -public: - FilterExpr(const Location & location, std::vector> && p) - : Expression(location), parts(std::move(p)) {} - Value do_evaluate(const std::shared_ptr & context) const override { - Value result; - bool first = true; - for (const auto& part : parts) { - if (!part) throw std::runtime_error("FilterExpr.part is null"); - if (first) { - first = false; - result = part->evaluate(context); - } else { - if (auto ce = dynamic_cast(part.get())) { - auto target = ce->object->evaluate(context); - ArgumentsValue args = ce->args.evaluate(context); - args.args.insert(args.args.begin(), result); - result = target.call(context, args); - } else { - auto callable = part->evaluate(context); - ArgumentsValue args; - args.args.insert(args.args.begin(), result); - result = callable.call(context, args); - } - } - } - return result; - } - - void prepend(std::shared_ptr && e) { - parts.insert(parts.begin(), std::move(e)); - } -}; - -class Parser { -private: - using CharIterator = std::string::const_iterator; - - std::shared_ptr template_str; - CharIterator start, end, it; - Options options; - - Parser(const std::shared_ptr& template_str, const Options & options) : template_str(template_str), options(options) { - if (!template_str) throw std::runtime_error("Template string is null"); - start = it = this->template_str->begin(); - end = this->template_str->end(); - } - - bool consumeSpaces(SpaceHandling space_handling = SpaceHandling::Strip) { - if (space_handling == SpaceHandling::Strip) { - while (it != end && std::isspace(*it)) ++it; - } - return true; - } - - std::unique_ptr parseString() { - auto doParse = [&](char quote) -> std::unique_ptr { - if (it == end || *it != quote) return nullptr; - std::string result; - bool escape = false; - for (++it; it != end; ++it) { - if (escape) { - escape = false; - switch (*it) { - case 'n': result += '\n'; break; - case 'r': result += '\r'; break; - case 't': result += '\t'; break; - case 'b': result += '\b'; break; - case 'f': result += '\f'; break; - case '\\': result += '\\'; break; - default: - if (*it == quote) { - result += quote; - } else { - result += *it; - } - break; - } - } else if (*it == '\\') { - escape = true; - } else if (*it == quote) { - ++it; - return std::make_unique(std::move(result)); - } else { - result += *it; - } - } - return nullptr; - }; - - consumeSpaces(); - if (it == end) return nullptr; - if (*it == '"') return doParse('"'); - if (*it == '\'') return doParse('\''); - return nullptr; - } - - json parseNumber(CharIterator& it, const CharIterator& end) { - auto before = it; - consumeSpaces(); - auto start = it; - bool hasDecimal = false; - bool hasExponent = false; - - if (it != end && (*it == '-' || *it == '+')) ++it; - - while (it != end) { - if (std::isdigit(*it)) { - ++it; - } else if (*it == '.') { - if (hasDecimal) throw std::runtime_error("Multiple decimal points"); - hasDecimal = true; - ++it; - } else if (it != start && (*it == 'e' || *it == 'E')) { - if (hasExponent) throw std::runtime_error("Multiple exponents"); - hasExponent = true; - ++it; - } else { - break; - } - } - if (start == it) { - it = before; - return json(); // No valid characters found - } - - std::string str(start, it); - try { - return json::parse(str); - } catch (json::parse_error& e) { - throw std::runtime_error("Failed to parse number: '" + str + "' (" + std::string(e.what()) + ")"); - return json(); - } - } - - /** integer, float, bool, string */ - std::shared_ptr parseConstant() { - auto start = it; - consumeSpaces(); - if (it == end) return nullptr; - if (*it == '"' || *it == '\'') { - auto str = parseString(); - if (str) return std::make_shared(*str); - } - static std::regex prim_tok(R"(true\b|True\b|false\b|False\b|None\b)"); - auto token = consumeToken(prim_tok); - if (!token.empty()) { - if (token == "true" || token == "True") return std::make_shared(true); - if (token == "false" || token == "False") return std::make_shared(false); - if (token == "None") return std::make_shared(nullptr); - throw std::runtime_error("Unknown constant token: " + token); - } - - auto number = parseNumber(it, end); - if (!number.is_null()) return std::make_shared(number); - - it = start; - return nullptr; - } - - class expression_parsing_error : public std::runtime_error { - const CharIterator it; - public: - expression_parsing_error(const std::string & message, const CharIterator & it) - : std::runtime_error(message), it(it) {} - size_t get_pos(const CharIterator & begin) const { - return std::distance(begin, it); - } - }; - - bool peekSymbols(const std::vector & symbols) const { - for (const auto & symbol : symbols) { - if (std::distance(it, end) >= (int64_t) symbol.size() && std::string(it, it + symbol.size()) == symbol) { - return true; - } - } - return false; - } - - std::vector consumeTokenGroups(const std::regex & regex, SpaceHandling space_handling = SpaceHandling::Strip) { - auto start = it; - consumeSpaces(space_handling); - std::smatch match; - if (std::regex_search(it, end, match, regex) && match.position() == 0) { - it += match[0].length(); - std::vector ret; - for (size_t i = 0, n = match.size(); i < n; ++i) { - ret.push_back(match[i].str()); - } - return ret; - } - it = start; - return {}; - } - std::string consumeToken(const std::regex & regex, SpaceHandling space_handling = SpaceHandling::Strip) { - auto start = it; - consumeSpaces(space_handling); - std::smatch match; - if (std::regex_search(it, end, match, regex) && match.position() == 0) { - it += match[0].length(); - return match[0].str(); - } - it = start; - return ""; - } - - std::string consumeToken(const std::string & token, SpaceHandling space_handling = SpaceHandling::Strip) { - auto start = it; - consumeSpaces(space_handling); - if (std::distance(it, end) >= (int64_t) token.size() && std::string(it, it + token.size()) == token) { - it += token.size(); - return token; - } - it = start; - return ""; - } - - std::shared_ptr parseExpression(bool allow_if_expr = true) { - auto left = parseLogicalOr(); - if (it == end) return left; - - if (!allow_if_expr) return left; - - static std::regex if_tok(R"(if\b)"); - if (consumeToken(if_tok).empty()) { - return left; - } - - auto location = get_location(); - auto [condition, else_expr] = parseIfExpression(); - return std::make_shared(location, std::move(condition), std::move(left), std::move(else_expr)); - } - - Location get_location() const { - return {template_str, (size_t) std::distance(start, it)}; - } - - std::pair, std::shared_ptr> parseIfExpression() { - auto condition = parseLogicalOr(); - if (!condition) throw std::runtime_error("Expected condition expression"); - - static std::regex else_tok(R"(else\b)"); - std::shared_ptr else_expr; - if (!consumeToken(else_tok).empty()) { - else_expr = parseExpression(); - if (!else_expr) throw std::runtime_error("Expected 'else' expression"); - } - return std::pair(std::move(condition), std::move(else_expr)); - } - - std::shared_ptr parseLogicalOr() { - auto left = parseLogicalAnd(); - if (!left) throw std::runtime_error("Expected left side of 'logical or' expression"); - - static std::regex or_tok(R"(or\b)"); - auto location = get_location(); - while (!consumeToken(or_tok).empty()) { - auto right = parseLogicalAnd(); - if (!right) throw std::runtime_error("Expected right side of 'or' expression"); - left = std::make_shared(location, std::move(left), std::move(right), BinaryOpExpr::Op::Or); - } - return left; - } - - std::shared_ptr parseLogicalNot() { - static std::regex not_tok(R"(not\b)"); - auto location = get_location(); - - if (!consumeToken(not_tok).empty()) { - auto sub = parseLogicalNot(); - if (!sub) throw std::runtime_error("Expected expression after 'not' keyword"); - return std::make_shared(location, std::move(sub), UnaryOpExpr::Op::LogicalNot); - } - return parseLogicalCompare(); - } - - std::shared_ptr parseLogicalAnd() { - auto left = parseLogicalNot(); - if (!left) throw std::runtime_error("Expected left side of 'logical and' expression"); - - static std::regex and_tok(R"(and\b)"); - auto location = get_location(); - while (!consumeToken(and_tok).empty()) { - auto right = parseLogicalNot(); - if (!right) throw std::runtime_error("Expected right side of 'and' expression"); - left = std::make_shared(location, std::move(left), std::move(right), BinaryOpExpr::Op::And); - } - return left; - } - - std::shared_ptr parseLogicalCompare() { - auto left = parseStringConcat(); - if (!left) throw std::runtime_error("Expected left side of 'logical compare' expression"); - - static std::regex compare_tok(R"(==|!=|<=?|>=?|in\b|is\b|not\s+in\b)"); - static std::regex not_tok(R"(not\b)"); - std::string op_str; - while (!(op_str = consumeToken(compare_tok)).empty()) { - auto location = get_location(); - if (op_str == "is") { - auto negated = !consumeToken(not_tok).empty(); - - auto identifier = parseIdentifier(); - if (!identifier) throw std::runtime_error("Expected identifier after 'is' keyword"); - - return std::make_shared( - left->location, - std::move(left), std::move(identifier), - negated ? BinaryOpExpr::Op::IsNot : BinaryOpExpr::Op::Is); - } - auto right = parseStringConcat(); - if (!right) throw std::runtime_error("Expected right side of 'logical compare' expression"); - BinaryOpExpr::Op op; - if (op_str == "==") op = BinaryOpExpr::Op::Eq; - else if (op_str == "!=") op = BinaryOpExpr::Op::Ne; - else if (op_str == "<") op = BinaryOpExpr::Op::Lt; - else if (op_str == ">") op = BinaryOpExpr::Op::Gt; - else if (op_str == "<=") op = BinaryOpExpr::Op::Le; - else if (op_str == ">=") op = BinaryOpExpr::Op::Ge; - else if (op_str == "in") op = BinaryOpExpr::Op::In; - else if (op_str.substr(0, 3) == "not") op = BinaryOpExpr::Op::NotIn; - else throw std::runtime_error("Unknown comparison operator: " + op_str); - left = std::make_shared(get_location(), std::move(left), std::move(right), op); - } - return left; - } - - Expression::Parameters parseParameters() { - consumeSpaces(); - if (consumeToken("(").empty()) throw std::runtime_error("Expected opening parenthesis in param list"); - - Expression::Parameters result; - - while (it != end) { - if (!consumeToken(")").empty()) { - return result; - } - auto expr = parseExpression(); - if (!expr) throw std::runtime_error("Expected expression in call args"); - - if (auto ident = dynamic_cast(expr.get())) { - if (!consumeToken("=").empty()) { - auto value = parseExpression(); - if (!value) throw std::runtime_error("Expected expression in for named arg"); - result.emplace_back(ident->get_name(), std::move(value)); - } else { - result.emplace_back(ident->get_name(), nullptr); - } - } else { - result.emplace_back(std::string(), std::move(expr)); - } - if (consumeToken(",").empty()) { - if (consumeToken(")").empty()) { - throw std::runtime_error("Expected closing parenthesis in call args"); - } - return result; - } - } - throw std::runtime_error("Expected closing parenthesis in call args"); - } - - ArgumentsExpression parseCallArgs() { - consumeSpaces(); - if (consumeToken("(").empty()) throw std::runtime_error("Expected opening parenthesis in call args"); - - ArgumentsExpression result; - - while (it != end) { - if (!consumeToken(")").empty()) { - return result; - } - auto expr = parseExpression(); - if (!expr) throw std::runtime_error("Expected expression in call args"); - - if (auto ident = dynamic_cast(expr.get())) { - if (!consumeToken("=").empty()) { - auto value = parseExpression(); - if (!value) throw std::runtime_error("Expected expression in for named arg"); - result.kwargs.emplace_back(ident->get_name(), std::move(value)); - } else { - result.args.emplace_back(std::move(expr)); - } - } else { - result.args.emplace_back(std::move(expr)); - } - if (consumeToken(",").empty()) { - if (consumeToken(")").empty()) { - throw std::runtime_error("Expected closing parenthesis in call args"); - } - return result; - } - } - throw std::runtime_error("Expected closing parenthesis in call args"); - } - - std::shared_ptr parseIdentifier() { - static std::regex ident_regex(R"((?!(?:not|is|and|or|del)\b)[a-zA-Z_]\w*)"); - auto location = get_location(); - auto ident = consumeToken(ident_regex); - if (ident.empty()) - return nullptr; - return std::make_shared(location, ident); - } - - std::shared_ptr parseStringConcat() { - auto left = parseMathPow(); - if (!left) throw std::runtime_error("Expected left side of 'string concat' expression"); - - static std::regex concat_tok(R"(~(?!\}))"); - if (!consumeToken(concat_tok).empty()) { - auto right = parseLogicalAnd(); - if (!right) throw std::runtime_error("Expected right side of 'string concat' expression"); - left = std::make_shared(get_location(), std::move(left), std::move(right), BinaryOpExpr::Op::StrConcat); - } - return left; - } - - std::shared_ptr parseMathPow() { - auto left = parseMathPlusMinus(); - if (!left) throw std::runtime_error("Expected left side of 'math pow' expression"); - - while (!consumeToken("**").empty()) { - auto right = parseMathPlusMinus(); - if (!right) throw std::runtime_error("Expected right side of 'math pow' expression"); - left = std::make_shared(get_location(), std::move(left), std::move(right), BinaryOpExpr::Op::MulMul); - } - return left; - } - - std::shared_ptr parseMathPlusMinus() { - static std::regex plus_minus_tok(R"(\+|-(?![}%#]\}))"); - - auto left = parseMathMulDiv(); - if (!left) throw std::runtime_error("Expected left side of 'math plus/minus' expression"); - std::string op_str; - while (!(op_str = consumeToken(plus_minus_tok)).empty()) { - auto right = parseMathMulDiv(); - if (!right) throw std::runtime_error("Expected right side of 'math plus/minus' expression"); - auto op = op_str == "+" ? BinaryOpExpr::Op::Add : BinaryOpExpr::Op::Sub; - left = std::make_shared(get_location(), std::move(left), std::move(right), op); - } - return left; - } - - std::shared_ptr parseMathMulDiv() { - auto left = parseMathUnaryPlusMinus(); - if (!left) throw std::runtime_error("Expected left side of 'math mul/div' expression"); - - static std::regex mul_div_tok(R"(\*\*?|//?|%(?!\}))"); - std::string op_str; - while (!(op_str = consumeToken(mul_div_tok)).empty()) { - auto right = parseMathUnaryPlusMinus(); - if (!right) throw std::runtime_error("Expected right side of 'math mul/div' expression"); - auto op = op_str == "*" ? BinaryOpExpr::Op::Mul - : op_str == "**" ? BinaryOpExpr::Op::MulMul - : op_str == "/" ? BinaryOpExpr::Op::Div - : op_str == "//" ? BinaryOpExpr::Op::DivDiv - : BinaryOpExpr::Op::Mod; - left = std::make_shared(get_location(), std::move(left), std::move(right), op); - } - - if (!consumeToken("|").empty()) { - auto expr = parseMathMulDiv(); - if (auto filter = dynamic_cast(expr.get())) { - filter->prepend(std::move(left)); - return expr; - } else { - std::vector> parts; - parts.emplace_back(std::move(left)); - parts.emplace_back(std::move(expr)); - return std::make_shared(get_location(), std::move(parts)); - } - } - return left; - } - - std::shared_ptr call_func(const std::string & name, ArgumentsExpression && args) const { - return std::make_shared(get_location(), std::make_shared(get_location(), name), std::move(args)); - } - - std::shared_ptr parseMathUnaryPlusMinus() { - static std::regex unary_plus_minus_tok(R"(\+|-(?![}%#]\}))"); - auto op_str = consumeToken(unary_plus_minus_tok); - auto expr = parseExpansion(); - if (!expr) throw std::runtime_error("Expected expr of 'unary plus/minus/expansion' expression"); - - if (!op_str.empty()) { - auto op = op_str == "+" ? UnaryOpExpr::Op::Plus : UnaryOpExpr::Op::Minus; - return std::make_shared(get_location(), std::move(expr), op); - } - return expr; - } - - std::shared_ptr parseExpansion() { - static std::regex expansion_tok(R"(\*\*?)"); - auto op_str = consumeToken(expansion_tok); - auto expr = parseValueExpression(); - if (op_str.empty()) return expr; - if (!expr) throw std::runtime_error("Expected expr of 'expansion' expression"); - return std::make_shared(get_location(), std::move(expr), op_str == "*" ? UnaryOpExpr::Op::Expansion : UnaryOpExpr::Op::ExpansionDict); - } - - std::shared_ptr parseValueExpression() { - auto parseValue = [&]() -> std::shared_ptr { - auto location = get_location(); - auto constant = parseConstant(); - if (constant) return std::make_shared(location, *constant); - - static std::regex null_regex(R"(null\b)"); - if (!consumeToken(null_regex).empty()) return std::make_shared(location, Value()); - - auto identifier = parseIdentifier(); - if (identifier) return identifier; - - auto braced = parseBracedExpressionOrArray(); - if (braced) return braced; - - auto array = parseArray(); - if (array) return array; - - auto dictionary = parseDictionary(); - if (dictionary) return dictionary; - - throw std::runtime_error("Expected value expression"); - }; - - auto value = parseValue(); - - while (it != end && consumeSpaces() && peekSymbols({ "[", "." })) { - if (!consumeToken("[").empty()) { - std::shared_ptr index; - if (!consumeToken(":").empty()) { - auto slice_end = parseExpression(); - index = std::make_shared(slice_end->location, nullptr, std::move(slice_end)); - } else { - auto slice_start = parseExpression(); - if (!consumeToken(":").empty()) { - consumeSpaces(); - if (peekSymbols({ "]" })) { - index = std::make_shared(slice_start->location, std::move(slice_start), nullptr); - } else { - auto slice_end = parseExpression(); - index = std::make_shared(slice_start->location, std::move(slice_start), std::move(slice_end)); - } - } else { - index = std::move(slice_start); - } - } - if (!index) throw std::runtime_error("Empty index in subscript"); - if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript"); - - value = std::make_shared(value->location, std::move(value), std::move(index)); - } else if (!consumeToken(".").empty()) { - auto identifier = parseIdentifier(); - if (!identifier) throw std::runtime_error("Expected identifier in subscript"); - - consumeSpaces(); - if (peekSymbols({ "(" })) { - auto callParams = parseCallArgs(); - value = std::make_shared(identifier->location, std::move(value), std::move(identifier), std::move(callParams)); - } else { - auto key = std::make_shared(identifier->location, Value(identifier->get_name())); - value = std::make_shared(identifier->location, std::move(value), std::move(key)); - } - } - consumeSpaces(); - } - - if (peekSymbols({ "(" })) { - auto location = get_location(); - auto callParams = parseCallArgs(); - value = std::make_shared(location, std::move(value), std::move(callParams)); - } - return value; - } - - std::shared_ptr parseBracedExpressionOrArray() { - if (consumeToken("(").empty()) return nullptr; - - auto expr = parseExpression(); - if (!expr) throw std::runtime_error("Expected expression in braced expression"); - - if (!consumeToken(")").empty()) { - return expr; // Drop the parentheses - } - - std::vector> tuple; - tuple.emplace_back(std::move(expr)); - - while (it != end) { - if (consumeToken(",").empty()) throw std::runtime_error("Expected comma in tuple"); - auto next = parseExpression(); - if (!next) throw std::runtime_error("Expected expression in tuple"); - tuple.push_back(std::move(next)); - - if (!consumeToken(")").empty()) { - return std::make_shared(get_location(), std::move(tuple)); - } - } - throw std::runtime_error("Expected closing parenthesis"); - } - - std::shared_ptr parseArray() { - if (consumeToken("[").empty()) return nullptr; - - std::vector> elements; - if (!consumeToken("]").empty()) { - return std::make_shared(get_location(), std::move(elements)); - } - auto first_expr = parseExpression(); - if (!first_expr) throw std::runtime_error("Expected first expression in array"); - elements.push_back(std::move(first_expr)); - - while (it != end) { - if (!consumeToken(",").empty()) { - auto expr = parseExpression(); - if (!expr) throw std::runtime_error("Expected expression in array"); - elements.push_back(std::move(expr)); - } else if (!consumeToken("]").empty()) { - return std::make_shared(get_location(), std::move(elements)); - } else { - throw std::runtime_error("Expected comma or closing bracket in array"); - } - } - throw std::runtime_error("Expected closing bracket"); - } - - std::shared_ptr parseDictionary() { - if (consumeToken("{").empty()) return nullptr; - - std::vector, std::shared_ptr>> elements; - if (!consumeToken("}").empty()) { - return std::make_shared(get_location(), std::move(elements)); - } - - auto parseKeyValuePair = [&]() { - auto key = parseExpression(); - if (!key) throw std::runtime_error("Expected key in dictionary"); - if (consumeToken(":").empty()) throw std::runtime_error("Expected colon betweek key & value in dictionary"); - auto value = parseExpression(); - if (!value) throw std::runtime_error("Expected value in dictionary"); - elements.emplace_back(std::pair(std::move(key), std::move(value))); - }; - - parseKeyValuePair(); - - while (it != end) { - if (!consumeToken(",").empty()) { - parseKeyValuePair(); - } else if (!consumeToken("}").empty()) { - return std::make_shared(get_location(), std::move(elements)); - } else { - throw std::runtime_error("Expected comma or closing brace in dictionary"); - } - } - throw std::runtime_error("Expected closing brace"); - } - - SpaceHandling parsePreSpace(const std::string& s) const { - if (s == "-") - return SpaceHandling::Strip; - return SpaceHandling::Keep; - } - - SpaceHandling parsePostSpace(const std::string& s) const { - if (s == "-") return SpaceHandling::Strip; - return SpaceHandling::Keep; - } - - using TemplateTokenVector = std::vector>; - using TemplateTokenIterator = TemplateTokenVector::const_iterator; - - std::vector parseVarNames() { - static std::regex varnames_regex(R"(((?:\w+)(?:\s*,\s*(?:\w+))*)\s*)"); - - std::vector group; - if ((group = consumeTokenGroups(varnames_regex)).empty()) throw std::runtime_error("Expected variable names"); - std::vector varnames; - std::istringstream iss(group[1]); - std::string varname; - while (std::getline(iss, varname, ',')) { - varnames.push_back(strip(varname)); - } - return varnames; - } - - std::runtime_error unexpected(const TemplateToken & token) const { - return std::runtime_error("Unexpected " + TemplateToken::typeToString(token.type) - + error_location_suffix(*template_str, token.location.pos)); - } - std::runtime_error unterminated(const TemplateToken & token) const { - return std::runtime_error("Unterminated " + TemplateToken::typeToString(token.type) - + error_location_suffix(*template_str, token.location.pos)); - } - - TemplateTokenVector tokenize() { - static std::regex comment_tok(R"(\{#([-~]?)([\s\S]*?)([-~]?)#\})"); - static std::regex expr_open_regex(R"(\{\{([-~])?)"); - static std::regex block_open_regex(R"(^\{%([-~])?\s*)"); - static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue)\b)"); - static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)"); - static std::regex expr_close_regex(R"(\s*([-~])?\}\})"); - static std::regex block_close_regex(R"(\s*([-~])?%\})"); - - TemplateTokenVector tokens; - std::vector group; - std::string text; - std::smatch match; - - try { - while (it != end) { - auto location = get_location(); - - if (!(group = consumeTokenGroups(comment_tok, SpaceHandling::Keep)).empty()) { - auto pre_space = parsePreSpace(group[1]); - auto content = group[2]; - auto post_space = parsePostSpace(group[3]); - tokens.push_back(std::make_unique(location, pre_space, post_space, content)); - } else if (!(group = consumeTokenGroups(expr_open_regex, SpaceHandling::Keep)).empty()) { - auto pre_space = parsePreSpace(group[1]); - auto expr = parseExpression(); - - if ((group = consumeTokenGroups(expr_close_regex)).empty()) { - throw std::runtime_error("Expected closing expression tag"); - } - - auto post_space = parsePostSpace(group[1]); - tokens.push_back(std::make_unique(location, pre_space, post_space, std::move(expr))); - } else if (!(group = consumeTokenGroups(block_open_regex, SpaceHandling::Keep)).empty()) { - auto pre_space = parsePreSpace(group[1]); - - std::string keyword; - - auto parseBlockClose = [&]() -> SpaceHandling { - if ((group = consumeTokenGroups(block_close_regex)).empty()) throw std::runtime_error("Expected closing block tag"); - return parsePostSpace(group[1]); - }; - - if ((keyword = consumeToken(block_keyword_tok)).empty()) throw std::runtime_error("Expected block keyword"); - - if (keyword == "if") { - auto condition = parseExpression(); - if (!condition) throw std::runtime_error("Expected condition in if block"); - - auto post_space = parseBlockClose(); - tokens.push_back(std::make_unique(location, pre_space, post_space, std::move(condition))); - } else if (keyword == "elif") { - auto condition = parseExpression(); - if (!condition) throw std::runtime_error("Expected condition in elif block"); - - auto post_space = parseBlockClose(); - tokens.push_back(std::make_unique(location, pre_space, post_space, std::move(condition))); - } else if (keyword == "else") { - auto post_space = parseBlockClose(); - tokens.push_back(std::make_unique(location, pre_space, post_space)); - } else if (keyword == "endif") { - auto post_space = parseBlockClose(); - tokens.push_back(std::make_unique(location, pre_space, post_space)); - } else if (keyword == "for") { - static std::regex recursive_tok(R"(recursive\b)"); - static std::regex if_tok(R"(if\b)"); - - auto varnames = parseVarNames(); - static std::regex in_tok(R"(in\b)"); - if (consumeToken(in_tok).empty()) throw std::runtime_error("Expected 'in' keyword in for block"); - auto iterable = parseExpression(/* allow_if_expr = */ false); - if (!iterable) throw std::runtime_error("Expected iterable in for block"); - - std::shared_ptr condition; - if (!consumeToken(if_tok).empty()) { - condition = parseExpression(); - } - auto recursive = !consumeToken(recursive_tok).empty(); - - auto post_space = parseBlockClose(); - tokens.push_back(std::make_unique(location, pre_space, post_space, std::move(varnames), std::move(iterable), std::move(condition), recursive)); - } else if (keyword == "endfor") { - auto post_space = parseBlockClose(); - tokens.push_back(std::make_unique(location, pre_space, post_space)); - } else if (keyword == "generation") { - auto post_space = parseBlockClose(); - tokens.push_back(std::make_unique(location, pre_space, post_space)); - } else if (keyword == "endgeneration") { - auto post_space = parseBlockClose(); - tokens.push_back(std::make_unique(location, pre_space, post_space)); - } else if (keyword == "set") { - static std::regex namespaced_var_regex(R"((\w+)\s*\.\s*(\w+))"); - - std::string ns; - std::vector var_names; - std::shared_ptr value; - if (!(group = consumeTokenGroups(namespaced_var_regex)).empty()) { - ns = group[1]; - var_names.push_back(group[2]); - - if (consumeToken("=").empty()) throw std::runtime_error("Expected equals sign in set block"); - - value = parseExpression(); - if (!value) throw std::runtime_error("Expected value in set block"); - } else { - var_names = parseVarNames(); - - if (!consumeToken("=").empty()) { - value = parseExpression(); - if (!value) throw std::runtime_error("Expected value in set block"); - } - } - auto post_space = parseBlockClose(); - tokens.push_back(std::make_unique(location, pre_space, post_space, ns, var_names, std::move(value))); - } else if (keyword == "endset") { - auto post_space = parseBlockClose(); - tokens.push_back(std::make_unique(location, pre_space, post_space)); - } else if (keyword == "macro") { - auto macroname = parseIdentifier(); - if (!macroname) throw std::runtime_error("Expected macro name in macro block"); - auto params = parseParameters(); - - auto post_space = parseBlockClose(); - tokens.push_back(std::make_unique(location, pre_space, post_space, std::move(macroname), std::move(params))); - } else if (keyword == "endmacro") { - auto post_space = parseBlockClose(); - tokens.push_back(std::make_unique(location, pre_space, post_space)); - } else if (keyword == "filter") { - auto filter = parseExpression(); - if (!filter) throw std::runtime_error("Expected expression in filter block"); - - auto post_space = parseBlockClose(); - tokens.push_back(std::make_unique(location, pre_space, post_space, std::move(filter))); - } else if (keyword == "endfilter") { - auto post_space = parseBlockClose(); - tokens.push_back(std::make_unique(location, pre_space, post_space)); - } else if (keyword == "break" || keyword == "continue") { - auto post_space = parseBlockClose(); - tokens.push_back(std::make_unique(location, pre_space, post_space, keyword == "break" ? LoopControlType::Break : LoopControlType::Continue)); - } else { - throw std::runtime_error("Unexpected block: " + keyword); - } - } else if (std::regex_search(it, end, match, non_text_open_regex)) { - if (!match.position()) { - if (match[0] != "{#") - throw std::runtime_error("Internal error: Expected a comment"); - throw std::runtime_error("Missing end of comment tag"); - } - auto text_end = it + match.position(); - text = std::string(it, text_end); - it = text_end; - tokens.push_back(std::make_unique(location, SpaceHandling::Keep, SpaceHandling::Keep, text)); - } else { - text = std::string(it, end); - it = end; - tokens.push_back(std::make_unique(location, SpaceHandling::Keep, SpaceHandling::Keep, text)); - } - } - return tokens; - } catch (const std::exception & e) { - throw std::runtime_error(e.what() + error_location_suffix(*template_str, std::distance(start, it))); - } - } - - std::shared_ptr parseTemplate( - const TemplateTokenIterator & begin, - TemplateTokenIterator & it, - const TemplateTokenIterator & end, - bool fully = false) const { - std::vector> children; - while (it != end) { - const auto start = it; - const auto & token = *(it++); - if (auto if_token = dynamic_cast(token.get())) { - std::vector, std::shared_ptr>> cascade; - cascade.emplace_back(std::move(if_token->condition), parseTemplate(begin, it, end)); - - while (it != end && (*it)->type == TemplateToken::Type::Elif) { - auto elif_token = dynamic_cast((*(it++)).get()); - cascade.emplace_back(std::move(elif_token->condition), parseTemplate(begin, it, end)); - } - - if (it != end && (*it)->type == TemplateToken::Type::Else) { - cascade.emplace_back(nullptr, parseTemplate(begin, ++it, end)); - } - if (it == end || (*(it++))->type != TemplateToken::Type::EndIf) { - throw unterminated(**start); - } - children.emplace_back(std::make_shared(token->location, std::move(cascade))); - } else if (auto for_token = dynamic_cast(token.get())) { - auto body = parseTemplate(begin, it, end); - auto else_body = std::shared_ptr(); - if (it != end && (*it)->type == TemplateToken::Type::Else) { - else_body = parseTemplate(begin, ++it, end); - } - if (it == end || (*(it++))->type != TemplateToken::Type::EndFor) { - throw unterminated(**start); - } - children.emplace_back(std::make_shared(token->location, std::move(for_token->var_names), std::move(for_token->iterable), std::move(for_token->condition), std::move(body), for_token->recursive, std::move(else_body))); - } else if (dynamic_cast(token.get())) { - auto body = parseTemplate(begin, it, end); - if (it == end || (*(it++))->type != TemplateToken::Type::EndGeneration) { - throw unterminated(**start); - } - // Treat as a no-op, as our scope is templates for inference, not training (`{% generation %}` wraps generated tokens for masking). - children.emplace_back(std::move(body)); - } else if (auto text_token = dynamic_cast(token.get())) { - SpaceHandling pre_space = (it - 1) != begin ? (*(it - 2))->post_space : SpaceHandling::Keep; - SpaceHandling post_space = it != end ? (*it)->pre_space : SpaceHandling::Keep; - - auto text = text_token->text; - if (post_space == SpaceHandling::Strip) { - static std::regex trailing_space_regex(R"(\s+$)"); - text = std::regex_replace(text, trailing_space_regex, ""); - } else if (options.lstrip_blocks && it != end) { - auto i = text.size(); - while (i > 0 && (text[i - 1] == ' ' || text[i - 1] == '\t')) i--; - if ((i == 0 && (it - 1) == begin) || (i > 0 && text[i - 1] == '\n')) { - text.resize(i); - } - } - if (pre_space == SpaceHandling::Strip) { - static std::regex leading_space_regex(R"(^\s+)"); - text = std::regex_replace(text, leading_space_regex, ""); - } else if (options.trim_blocks && (it - 1) != begin && !dynamic_cast((*(it - 2)).get())) { - if (text.length() > 0 && text[0] == '\n') { - text.erase(0, 1); - } - } - if (it == end && !options.keep_trailing_newline) { - auto i = text.size(); - if (i > 0 && text[i - 1] == '\n') { - i--; - if (i > 0 && text[i - 1] == '\r') i--; - text.resize(i); - } - } - children.emplace_back(std::make_shared(token->location, text)); - } else if (auto expr_token = dynamic_cast(token.get())) { - children.emplace_back(std::make_shared(token->location, std::move(expr_token->expr))); - } else if (auto set_token = dynamic_cast(token.get())) { - if (set_token->value) { - children.emplace_back(std::make_shared(token->location, set_token->ns, set_token->var_names, std::move(set_token->value))); - } else { - auto value_template = parseTemplate(begin, it, end); - if (it == end || (*(it++))->type != TemplateToken::Type::EndSet) { - throw unterminated(**start); - } - if (!set_token->ns.empty()) throw std::runtime_error("Namespaced set not supported in set with template value"); - if (set_token->var_names.size() != 1) throw std::runtime_error("Structural assignment not supported in set with template value"); - auto & name = set_token->var_names[0]; - children.emplace_back(std::make_shared(token->location, name, std::move(value_template))); - } - } else if (auto macro_token = dynamic_cast(token.get())) { - auto body = parseTemplate(begin, it, end); - if (it == end || (*(it++))->type != TemplateToken::Type::EndMacro) { - throw unterminated(**start); - } - children.emplace_back(std::make_shared(token->location, std::move(macro_token->name), std::move(macro_token->params), std::move(body))); - } else if (auto filter_token = dynamic_cast(token.get())) { - auto body = parseTemplate(begin, it, end); - if (it == end || (*(it++))->type != TemplateToken::Type::EndFilter) { - throw unterminated(**start); - } - children.emplace_back(std::make_shared(token->location, std::move(filter_token->filter), std::move(body))); - } else if (dynamic_cast(token.get())) { - // Ignore comments - } else if (auto ctrl_token = dynamic_cast(token.get())) { - children.emplace_back(std::make_shared(token->location, ctrl_token->control_type)); - } else if (dynamic_cast(token.get()) - || dynamic_cast(token.get()) - || dynamic_cast(token.get()) - || dynamic_cast(token.get()) - || dynamic_cast(token.get()) - || dynamic_cast(token.get()) - || dynamic_cast(token.get()) - || dynamic_cast(token.get())) { - it--; // unconsume the token - break; // exit the loop - } else { - throw unexpected(**(it-1)); - } - } - if (fully && it != end) { - throw unexpected(**it); - } - if (children.empty()) { - return std::make_shared(Location { template_str, 0 }, std::string()); - } else if (children.size() == 1) { - return std::move(children[0]); - } else { - return std::make_shared(children[0]->location(), std::move(children)); - } - } - -public: - - static std::shared_ptr parse(const std::string& template_str, const Options & options) { - Parser parser(std::make_shared(normalize_newlines(template_str)), options); - auto tokens = parser.tokenize(); - TemplateTokenIterator begin = tokens.begin(); - auto it = begin; - TemplateTokenIterator end = tokens.end(); - return parser.parseTemplate(begin, it, end, /* full= */ true); - } -}; - -static Value simple_function(const std::string & fn_name, const std::vector & params, const std::function &, Value & args)> & fn) { - std::map named_positions; - for (size_t i = 0, n = params.size(); i < n; i++) named_positions[params[i]] = i; - - return Value::callable([=](const std::shared_ptr & context, ArgumentsValue & args) -> Value { - auto args_obj = Value::object(); - std::vector provided_args(params.size()); - for (size_t i = 0, n = args.args.size(); i < n; i++) { - auto & arg = args.args[i]; - if (i < params.size()) { - args_obj.set(params[i], arg); - provided_args[i] = true; - } else { - throw std::runtime_error("Too many positional params for " + fn_name); - } - } - for (auto & [name, value] : args.kwargs) { - auto named_pos_it = named_positions.find(name); - if (named_pos_it == named_positions.end()) { - throw std::runtime_error("Unknown argument " + name + " for function " + fn_name); - } - provided_args[named_pos_it->second] = true; - args_obj.set(name, value); - } - return fn(context, args_obj); - }); -} - -inline std::shared_ptr Context::builtins() { - auto globals = Value::object(); - - globals.set("raise_exception", simple_function("raise_exception", { "message" }, [](const std::shared_ptr &, Value & args) -> Value { - throw std::runtime_error(args.at("message").get()); - })); - globals.set("tojson", simple_function("tojson", { "value", "indent" }, [](const std::shared_ptr &, Value & args) { - return Value(args.at("value").dump(args.get("indent", -1), /* tojson= */ true)); - })); - globals.set("items", simple_function("items", { "object" }, [](const std::shared_ptr &, Value & args) { - auto items = Value::array(); - if (args.contains("object")) { - auto & obj = args.at("object"); - if (obj.is_string()) { - auto json_obj = json::parse(obj.get()); - for (const auto & kv : json_obj.items()) { - items.push_back(Value::array({kv.key(), kv.value()})); - } - } else if (!obj.is_null()) { - for (auto & key : obj.keys()) { - items.push_back(Value::array({key, obj.at(key)})); - } - } - } - return items; - })); - globals.set("last", simple_function("last", { "items" }, [](const std::shared_ptr &, Value & args) { - auto items = args.at("items"); - if (!items.is_array()) throw std::runtime_error("object is not a list"); - if (items.size() == 0) return Value(); - return items.at(items.size() - 1); - })); - globals.set("trim", simple_function("trim", { "text" }, [](const std::shared_ptr &, Value & args) { - auto & text = args.at("text"); - return text.is_null() ? text : Value(strip(text.get())); - })); - globals.set("lower", simple_function("lower", { "text" }, [](const std::shared_ptr &, Value & args) { - auto text = args.at("text"); - if (text.is_null()) return text; - std::string res; - auto str = text.get(); - std::transform(str.begin(), str.end(), std::back_inserter(res), ::tolower); - return Value(res); - })); - globals.set("default", Value::callable([=](const std::shared_ptr &, ArgumentsValue & args) { - args.expectArgs("default", {2, 3}, {0, 1}); - auto & value = args.args[0]; - auto & default_value = args.args[1]; - bool boolean = false; - if (args.args.size() == 3) { - boolean = args.args[2].get(); - } else { - Value bv = args.get_named("boolean"); - if (!bv.is_null()) { - boolean = bv.get(); - } - } - return boolean ? (value.to_bool() ? value : default_value) : value.is_null() ? default_value : value; - })); - auto escape = simple_function("escape", { "text" }, [](const std::shared_ptr &, Value & args) { - return Value(html_escape(args.at("text").get())); - }); - globals.set("e", escape); - globals.set("escape", escape); - globals.set("joiner", simple_function("joiner", { "sep" }, [](const std::shared_ptr &, Value & args) { - auto sep = args.get("sep", ""); - auto first = std::make_shared(true); - return simple_function("", {}, [sep, first](const std::shared_ptr &, const Value &) -> Value { - if (*first) { - *first = false; - return ""; - } - return sep; - }); - return Value(html_escape(args.at("text").get())); - })); - globals.set("count", simple_function("count", { "items" }, [](const std::shared_ptr &, Value & args) { - return Value((int64_t) args.at("items").size()); - })); - globals.set("dictsort", simple_function("dictsort", { "value" }, [](const std::shared_ptr &, Value & args) { - if (args.size() != 1) throw std::runtime_error("dictsort expects exactly 1 argument (TODO: fix implementation)"); - auto & value = args.at("value"); - auto keys = value.keys(); - std::sort(keys.begin(), keys.end()); - auto res = Value::array(); - for (auto & key : keys) { - res.push_back(Value::array({key, value.at(key)})); - } - return res; - })); - globals.set("join", simple_function("join", { "items", "d" }, [](const std::shared_ptr &, Value & args) { - auto do_join = [](Value & items, const std::string & sep) { - if (!items.is_array()) throw std::runtime_error("object is not iterable: " + items.dump()); - std::ostringstream oss; - auto first = true; - for (size_t i = 0, n = items.size(); i < n; ++i) { - if (first) first = false; - else oss << sep; - oss << items.at(i).to_str(); - } - return Value(oss.str()); - }; - auto sep = args.get("d", ""); - if (args.contains("items")) { - auto & items = args.at("items"); - return do_join(items, sep); - } else { - return simple_function("", {"items"}, [sep, do_join](const std::shared_ptr &, Value & args) { - auto & items = args.at("items"); - if (!items.to_bool() || !items.is_array()) throw std::runtime_error("join expects an array for items, got: " + items.dump()); - return do_join(items, sep); - }); - } - })); - globals.set("namespace", Value::callable([=](const std::shared_ptr &, ArgumentsValue & args) { - auto ns = Value::object(); - args.expectArgs("namespace", {0, 0}, {0, (std::numeric_limits::max)()}); - for (auto & [name, value] : args.kwargs) { - ns.set(name, value); - } - return ns; - })); - auto equalto = simple_function("equalto", { "expected", "actual" }, [](const std::shared_ptr &, Value & args) -> Value { - return args.at("actual") == args.at("expected"); - }); - globals.set("equalto", equalto); - globals.set("==", equalto); - globals.set("length", simple_function("length", { "items" }, [](const std::shared_ptr &, Value & args) -> Value { - auto & items = args.at("items"); - return (int64_t) items.size(); - })); - globals.set("safe", simple_function("safe", { "value" }, [](const std::shared_ptr &, Value & args) -> Value { - return args.at("value").to_str(); - })); - globals.set("string", simple_function("string", { "value" }, [](const std::shared_ptr &, Value & args) -> Value { - return args.at("value").to_str(); - })); - globals.set("int", simple_function("int", { "value" }, [](const std::shared_ptr &, Value & args) -> Value { - return args.at("value").to_int(); - })); - globals.set("list", simple_function("list", { "items" }, [](const std::shared_ptr &, Value & args) -> Value { - auto & items = args.at("items"); - if (!items.is_array()) throw std::runtime_error("object is not iterable"); - return items; - })); - globals.set("unique", simple_function("unique", { "items" }, [](const std::shared_ptr &, Value & args) -> Value { - auto & items = args.at("items"); - if (!items.is_array()) throw std::runtime_error("object is not iterable"); - std::unordered_set seen; - auto result = Value::array(); - for (size_t i = 0, n = items.size(); i < n; i++) { - auto pair = seen.insert(items.at(i)); - if (pair.second) { - result.push_back(items.at(i)); - } - } - return result; - })); - auto make_filter = [](const Value & filter, Value & extra_args) -> Value { - return simple_function("", { "value" }, [=](const std::shared_ptr & context, Value & args) { - auto & value = args.at("value"); - ArgumentsValue actual_args; - actual_args.args.emplace_back(value); - for (size_t i = 0, n = extra_args.size(); i < n; i++) { - actual_args.args.emplace_back(extra_args.at(i)); - } - return filter.call(context, actual_args); - }); - }; - auto select_or_reject = [make_filter](bool is_select) { - return Value::callable([=](const std::shared_ptr & context, ArgumentsValue & args) { - args.expectArgs(is_select ? "select" : "reject", {2, (std::numeric_limits::max)()}, {0, 0}); - auto & items = args.args[0]; - if (items.is_null()) - return Value::array(); - if (!items.is_array()) throw std::runtime_error("object is not iterable: " + items.dump()); - - auto filter_fn = context->get(args.args[1]); - if (filter_fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump()); - - auto filter_args = Value::array(); - for (size_t i = 2, n = args.args.size(); i < n; i++) { - filter_args.push_back(args.args[i]); - } - auto filter = make_filter(filter_fn, filter_args); - - auto res = Value::array(); - for (size_t i = 0, n = items.size(); i < n; i++) { - auto & item = items.at(i); - ArgumentsValue filter_args; - filter_args.args.emplace_back(item); - auto pred_res = filter.call(context, filter_args); - if (pred_res.to_bool() == (is_select ? true : false)) { - res.push_back(item); - } - } - return res; - }); - }; - globals.set("select", select_or_reject(/* is_select= */ true)); - globals.set("reject", select_or_reject(/* is_select= */ false)); - globals.set("map", Value::callable([=](const std::shared_ptr & context, ArgumentsValue & args) { - auto res = Value::array(); - if (args.args.size() == 1 && - ((args.has_named("attribute") && args.kwargs.size() == 1) || (args.has_named("default") && args.kwargs.size() == 2))) { - auto & items = args.args[0]; - auto attr_name = args.get_named("attribute"); - auto default_value = args.get_named("default"); - for (size_t i = 0, n = items.size(); i < n; i++) { - auto & item = items.at(i); - auto attr = item.get(attr_name); - res.push_back(attr.is_null() ? default_value : attr); - } - } else if (args.kwargs.empty() && args.args.size() >= 2) { - auto fn = context->get(args.args[1]); - if (fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump()); - ArgumentsValue filter_args { {Value()}, {} }; - for (size_t i = 2, n = args.args.size(); i < n; i++) { - filter_args.args.emplace_back(args.args[i]); - } - for (size_t i = 0, n = args.args[0].size(); i < n; i++) { - auto & item = args.args[0].at(i); - filter_args.args[0] = item; - res.push_back(fn.call(context, filter_args)); - } - } else { - throw std::runtime_error("Invalid or unsupported arguments for map"); - } - return res; - })); - globals.set("indent", simple_function("indent", { "text", "indent", "first" }, [](const std::shared_ptr &, Value & args) { - auto text = args.at("text").get(); - auto first = args.get("first", false); - std::string out; - std::string indent(args.get("indent", 0), ' '); - std::istringstream iss(text); - std::string line; - auto is_first = true; - while (std::getline(iss, line, '\n')) { - auto needs_indent = !is_first || first; - if (is_first) is_first = false; - else out += "\n"; - if (needs_indent) out += indent; - out += line; - } - if (!text.empty() && text.back() == '\n') out += "\n"; - return out; - })); - auto select_or_reject_attr = [](bool is_select) { - return Value::callable([=](const std::shared_ptr & context, ArgumentsValue & args) { - args.expectArgs(is_select ? "selectattr" : "rejectattr", {2, (std::numeric_limits::max)()}, {0, 0}); - auto & items = args.args[0]; - if (items.is_null()) - return Value::array(); - if (!items.is_array()) throw std::runtime_error("object is not iterable: " + items.dump()); - auto attr_name = args.args[1].get(); - - bool has_test = false; - Value test_fn; - ArgumentsValue test_args {{Value()}, {}}; - if (args.args.size() >= 3) { - has_test = true; - test_fn = context->get(args.args[2]); - if (test_fn.is_null()) throw std::runtime_error("Undefined test: " + args.args[2].dump()); - for (size_t i = 3, n = args.args.size(); i < n; i++) { - test_args.args.emplace_back(args.args[i]); - } - test_args.kwargs = args.kwargs; - } - - auto res = Value::array(); - for (size_t i = 0, n = items.size(); i < n; i++) { - auto & item = items.at(i); - auto attr = item.get(attr_name); - if (has_test) { - test_args.args[0] = attr; - if (test_fn.call(context, test_args).to_bool() == (is_select ? true : false)) { - res.push_back(item); - } - } else { - res.push_back(attr); - } - } - return res; - }); - }; - globals.set("selectattr", select_or_reject_attr(/* is_select= */ true)); - globals.set("rejectattr", select_or_reject_attr(/* is_select= */ false)); - globals.set("range", Value::callable([=](const std::shared_ptr &, ArgumentsValue & args) { - std::vector startEndStep(3); - std::vector param_set(3); - if (args.args.size() == 1) { - startEndStep[1] = args.args[0].get(); - param_set[1] = true; - } else { - for (size_t i = 0; i < args.args.size(); i++) { - auto & arg = args.args[i]; - auto v = arg.get(); - startEndStep[i] = v; - param_set[i] = true; - } - } - for (auto & [name, value] : args.kwargs) { - size_t i; - if (name == "start") i = 0; - else if (name == "end") i = 1; - else if (name == "step") i = 2; - else throw std::runtime_error("Unknown argument " + name + " for function range"); - - if (param_set[i]) { - throw std::runtime_error("Duplicate argument " + name + " for function range"); - } - startEndStep[i] = value.get(); - param_set[i] = true; - } - if (!param_set[1]) { - throw std::runtime_error("Missing required argument 'end' for function range"); - } - int64_t start = param_set[0] ? startEndStep[0] : 0; - int64_t end = startEndStep[1]; - int64_t step = param_set[2] ? startEndStep[2] : 1; - - auto res = Value::array(); - if (step > 0) { - for (int64_t i = start; i < end; i += step) { - res.push_back(Value(i)); - } - } else { - for (int64_t i = start; i > end; i += step) { - res.push_back(Value(i)); - } - } - return res; - })); - - return std::make_shared(std::move(globals)); -} - -inline std::shared_ptr Context::make(Value && values, const std::shared_ptr & parent) { - return std::make_shared(values.is_null() ? Value::object() : std::move(values), parent); -} - -} // namespace minja diff --git a/common/minja/chat-template.hpp b/common/minja/chat-template.hpp new file mode 100644 index 00000000..882ba41b --- /dev/null +++ b/common/minja/chat-template.hpp @@ -0,0 +1,529 @@ +/* + Copyright 2024 Google LLC + + Use of this source code is governed by an MIT-style + license that can be found in the LICENSE file or at + https://opensource.org/licenses/MIT. +*/ +// SPDX-License-Identifier: MIT +#pragma once + +#include "minja.hpp" +#include +#include +#include + +using json = nlohmann::ordered_json; + +namespace minja { + +struct chat_template_caps { + bool supports_tools = false; + bool supports_tool_calls = false; + bool supports_tool_responses = false; + bool supports_system_role = false; + bool supports_parallel_tool_calls = false; + bool supports_tool_call_id = false; + // meta-llama/Llama-3.1-8B-Instruct expects arguments to be an object. + // Most other templates (and OpenAI's API) expect the arguments object to be stringified. + bool requires_object_arguments = false; + // CohereForAI/c4ai-command-r-plus simple variant + bool requires_non_null_content = false; + // MiniMaxAI/MiniMax-Text-01 special + bool requires_typed_content = false; +}; + +struct chat_template_inputs { + nlohmann::ordered_json messages; + nlohmann::ordered_json tools; + bool add_generation_prompt = true; + nlohmann::ordered_json extra_context; + std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); +}; + +struct chat_template_options { + bool apply_polyfills = true; + bool use_bos_token = true; + bool use_eos_token = true; + bool define_strftime_now = true; + + bool polyfill_tools = true; + bool polyfill_tool_call_examples = true; + bool polyfill_tool_calls = true; + bool polyfill_tool_responses = true; + bool polyfill_system_role = true; + bool polyfill_object_arguments = true; + bool polyfill_typed_content = true; +}; + +class chat_template { + + private: + chat_template_caps caps_; + std::string source_; + std::string bos_token_; + std::string eos_token_; + std::shared_ptr template_root_; + std::string tool_call_example_; + + std::string try_raw_render( + const nlohmann::ordered_json & messages, + const nlohmann::ordered_json & tools, + bool add_generation_prompt, + const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const + { + try { + chat_template_inputs inputs; + inputs.messages = messages; + inputs.tools = tools; + inputs.add_generation_prompt = add_generation_prompt; + inputs.extra_context = extra_context; + // Use fixed date for tests + inputs.now = std::chrono::system_clock::from_time_t(0); + + chat_template_options opts; + opts.apply_polyfills = false; + + auto prompt = apply(inputs, opts); + // fprintf(stderr, "try_raw_render: %s\n", prompt.c_str()); + return prompt; + } catch (const std::exception & e) { + // fprintf(stderr, "try_raw_render error: %s\n", e.what()); + return ""; + } + } + + public: + + chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token) + : source_(source), bos_token_(bos_token), eos_token_(eos_token) + { + template_root_ = minja::Parser::parse(source_, { + /* .trim_blocks = */ true, + /* .lstrip_blocks = */ true, + /* .keep_trailing_newline = */ false, + }); + + auto contains = [](const std::string & haystack, const std::string & needle) { + return haystack.find(needle) != std::string::npos; + }; + + const std::string user_needle = ""; + const std::string sys_needle = ""; + const json dummy_str_user_msg = {{"role", "user"}, {"content", user_needle}}; + const json dummy_typed_user_msg = {{"role", "user"}, {"content", json::array({{{"type", "text"}, {"text", user_needle}}})}}; + + caps_.requires_typed_content = + !contains(try_raw_render(json::array({dummy_str_user_msg}), {}, false), user_needle) + && contains(try_raw_render(json::array({dummy_typed_user_msg}), {}, false), user_needle); + + const auto dummy_user_msg = caps_.requires_typed_content + ? dummy_typed_user_msg + : dummy_str_user_msg; + const json needle_system_msg = { + {"role", "system"}, + {"content", caps_.requires_typed_content ? json::array({{{"type", "text"}, {"text", sys_needle}}}) : json(sys_needle)}, + }; + + caps_.supports_system_role = contains(try_raw_render({needle_system_msg, dummy_user_msg,}, {}, false), sys_needle); + + auto out = try_raw_render(json::array({ + dummy_user_msg + }), json::array({ + { + {"name", "some_tool"}, + {"type", "function"}, + {"function", { + {"name", "some_tool"}, + {"description", "Some tool."}, + {"parameters", { + {"type", "object"}, + {"properties", { + {"arg", { + {"type", "string"}, + {"description", "Some argument."}, + }}, + }}, + {"required", json::array({ "arg" })}, + }}, + }}, + }, + }), false); + caps_.supports_tools = contains(out, "some_tool"); + + auto make_tool_calls_msg = [&](const json & tool_calls) { + return json { + {"role", "assistant"}, + {"content", nullptr}, + {"tool_calls", tool_calls}, + }; + }; + auto make_tool_call = [](const std::string & tool_name, const json & arguments) { + return json { + {"id", "call_1___"}, + {"type", "function"}, + {"function", { + {"arguments", arguments}, + {"name", tool_name}, + }}, + }; + }; + const json dummy_args_obj {{"argument_needle", "print('Hello, World!')"}}; + + // Note: the arguments are rendered in both cases, but may be double-escaped, which we don't want. + out = try_raw_render(json::array({ + dummy_user_msg, + make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj.dump())})), + }), {}, false); + auto tool_call_renders_str_arguments = contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':"); + out = try_raw_render(json::array({ + dummy_user_msg, + make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj)})), + }), {}, false); + auto tool_call_renders_obj_arguments = contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':"); + + caps_.supports_tool_calls = tool_call_renders_str_arguments || tool_call_renders_obj_arguments; + caps_.requires_object_arguments = !tool_call_renders_str_arguments && tool_call_renders_obj_arguments; + auto out_empty = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", ""}}}), {}, false); + auto out_null = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", nullptr}}}), {}, false); + caps_.requires_non_null_content = contains(out_empty, user_needle) && !contains(out_null, user_needle); + + if (caps_.supports_tool_calls) { + auto dummy_args = caps_.requires_object_arguments ? dummy_args_obj : json(dummy_args_obj.dump()); + auto tc1 = make_tool_call("test_tool1", dummy_args); + auto tc2 = make_tool_call("test_tool2", dummy_args); + auto out = try_raw_render(json::array({ + dummy_user_msg, + make_tool_calls_msg(json::array({tc1, tc2})), + }), {}, false); + caps_.supports_parallel_tool_calls = contains(out, "test_tool1") && contains(out, "test_tool2"); + + out = try_raw_render(json::array({ + dummy_user_msg, + make_tool_calls_msg(json::array({tc1})), + { + {"role", "tool"}, + {"name", "test_tool1"}, + {"content", "Some response!"}, + {"tool_call_id", "call_911_"}, + } + }), {}, false); + caps_.supports_tool_responses = contains(out, "Some response!"); + caps_.supports_tool_call_id = contains(out, "call_911_"); + } + + try { + if (!caps_.supports_tools) { + const json user_msg { + {"role", "user"}, + {"content", "Hey"}, + }; + const json args { + {"arg1", "some_value"}, + }; + const json tool_call_msg { + {"role", "assistant"}, + {"content", nullptr}, + {"tool_calls", json::array({ + { + // TODO: detect if requires numerical id or fixed length == 6 like Nemo + {"id", "call_1___"}, + {"type", "function"}, + {"function", { + {"name", "tool_name"}, + {"arguments", (caps_.requires_object_arguments ? args : json(minja::Value(args).dump(-1, /* to_json= */ true)))}, + }}, + }, + })}, + }; + std::string prefix, full; + { + chat_template_inputs inputs; + inputs.messages = json::array({user_msg}); + inputs.add_generation_prompt = true; + prefix = apply(inputs); + } + { + chat_template_inputs inputs; + inputs.messages = json::array({user_msg, tool_call_msg}); + inputs.add_generation_prompt = false; + full = apply(inputs); + } + auto eos_pos_last = full.rfind(eos_token_); + if (eos_pos_last == prefix.size() - eos_token_.size() || + (full[full.size() - 1] == '\n' && (eos_pos_last == full.size() - eos_token_.size() - 1))) { + full = full.substr(0, eos_pos_last); + } + size_t common_prefix_length = 0; + for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) { + if (prefix[i] != full[i]) { + break; + } + if (prefix[i] == '<') { + // DeepSeek R1's template (as of 20250209) adds a trailing if add_generation_prompt, + // but it removes thinking tags for past messages. + // The prefix and full strings diverge at vs. <|tool▁calls▁begin|>, we avoid consuming the leading <. + continue; + } + common_prefix_length = i + 1; + } + auto example = full.substr(common_prefix_length); + if (example.find("tool_name") == std::string::npos && example.find("some_value") == std::string::npos) { + fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n"); + } else { + tool_call_example_ = example; + } + } + } catch (const std::exception & e) { + fprintf(stderr, "Failed to generate tool call example: %s\n", e.what()); + } + } + + const std::string & source() const { return source_; } + const std::string & bos_token() const { return bos_token_; } + const std::string & eos_token() const { return eos_token_; } + const chat_template_caps & original_caps() const { return caps_; } + + // Deprecated, please use the form with chat_template_inputs and chat_template_options + std::string apply( + const nlohmann::ordered_json & messages, + const nlohmann::ordered_json & tools, + bool add_generation_prompt, + const nlohmann::ordered_json & extra_context = nlohmann::ordered_json(), + bool apply_polyfills = true) + { + fprintf(stderr, "[%s] Deprecated!\n", __func__); + chat_template_inputs inputs; + inputs.messages = messages; + inputs.tools = tools; + inputs.add_generation_prompt = add_generation_prompt; + inputs.extra_context = extra_context; + inputs.now = std::chrono::system_clock::now(); + + chat_template_options opts; + opts.apply_polyfills = apply_polyfills; + + return apply(inputs, opts); + } + + std::string apply( + const chat_template_inputs & inputs, + const chat_template_options & opts = chat_template_options()) const + { + json actual_messages; + + auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); + auto has_tool_calls = false; + auto has_tool_responses = false; + auto has_string_content = false; + for (const auto & message : inputs.messages) { + if (message.contains("tool_calls") && !message["tool_calls"].is_null()) { + has_tool_calls = true; + } + if (message.contains("role") && message["role"] == "tool") { + has_tool_responses = true; + } + if (message.contains("content") && message["content"].is_string()) { + has_string_content = true; + } + } + + auto polyfill_system_role = opts.polyfill_system_role && !caps_.supports_system_role; + auto polyfill_tools = opts.polyfill_tools && has_tools && !caps_.supports_tools; + auto polyfill_tool_call_example = polyfill_tools && opts.polyfill_tool_call_examples; + auto polyfill_tool_calls = opts.polyfill_tool_calls && has_tool_calls && !caps_.supports_tool_calls; + auto polyfill_tool_responses = opts.polyfill_tool_responses && has_tool_responses && !caps_.supports_tool_responses; + auto polyfill_object_arguments = opts.polyfill_object_arguments && has_tool_calls && caps_.requires_object_arguments; + auto polyfill_typed_content = opts.polyfill_typed_content && has_string_content && caps_.requires_typed_content; + + auto needs_polyfills = opts.apply_polyfills && (false + || polyfill_system_role + || polyfill_tools + || polyfill_tool_calls + || polyfill_tool_responses + || polyfill_object_arguments + || polyfill_typed_content + ); + + if (needs_polyfills) { + actual_messages = json::array(); + + auto add_message = [&](const json & msg) { + if (polyfill_typed_content && msg.contains("content") && !msg.at("content").is_null() && msg.at("content").is_string()) { + actual_messages.push_back({ + {"role", msg.at("role")}, + {"content", {{ + {"type", "text"}, + {"text", msg.at("content")}, + }}}, + }); + } else { + actual_messages.push_back(msg); + } + }; + + std::string pending_system; + auto flush_sys = [&]() { + if (!pending_system.empty()) { + add_message({ + {"role", "user"}, + {"content", pending_system}, + }); + pending_system.clear(); + } + }; + + json adjusted_messages; + if (polyfill_tools) { + adjusted_messages = add_system(inputs.messages, + "You can call any of the following tools to satisfy the user's requests: " + minja::Value(inputs.tools).dump(2, /* to_json= */ true) + + (!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_ + "\n\n")); + } else { + adjusted_messages = inputs.messages; + } + + for (const auto & message_ : adjusted_messages) { + auto message = message_; + if (!message.contains("role") || !message.contains("content")) { + throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump()); + } + std::string role = message.at("role"); + + if (message.contains("tool_calls")) { + if (polyfill_object_arguments || polyfill_tool_calls) { + for (auto & tool_call : message.at("tool_calls")) { + if (tool_call["type"] == "function") { + auto & function = tool_call.at("function"); + auto & arguments = function.at("arguments"); + if (arguments.is_string()) { + try { + arguments = json::parse(arguments.get()); + } catch (const std::exception & ecvt) { + fprintf(stderr, "Failed to parse arguments: %s\n", ecvt.what()); + } + } + } + } + } + if (polyfill_tool_calls) { + auto content = message.at("content"); + auto tool_calls = json::array(); + for (const auto & tool_call : message.at("tool_calls")) { + if (tool_call.at("type") != "function") { + continue; + } + const auto & function = tool_call.at("function"); + auto tc = json { + {"name", function.at("name")}, + {"arguments", function.at("arguments")}, + }; + if (tool_call.contains("id")) { + tc["id"] = tool_call["id"]; + } + tool_calls.push_back(tc); + } + auto obj = json { + {"tool_calls", tool_calls}, + }; + if (!content.is_null() && content != "") { + obj["content"] = content; + } + message["content"] = obj.dump(2); + message.erase("tool_calls"); + } + } + if (polyfill_tool_responses && role == "tool") { + message["role"] = "user"; + auto obj = json { + {"tool_response", { + {"content", message.at("content")}, + }}, + }; + if (message.contains("name")) { + obj["tool_response"]["name"] = message.at("name"); + } + if (message.contains("tool_call_id")) { + obj["tool_response"]["tool_call_id"] = message.at("tool_call_id"); + } + message["content"] = obj.dump(2); + message.erase("name"); + } + + if (!message["content"].is_null() && polyfill_system_role) { + std::string content = message.at("content"); + if (role == "system") { + if (!pending_system.empty()) pending_system += "\n"; + pending_system += content; + continue; + } else { + if (role == "user") { + if (!pending_system.empty()) { + message["content"] = pending_system + (content.empty() ? "" : "\n" + content); + pending_system.clear(); + } + } else { + flush_sys(); + } + } + } + add_message(message); + } + flush_sys(); + } else { + actual_messages = inputs.messages; + } + + auto context = minja::Context::make(json({ + {"messages", actual_messages}, + {"add_generation_prompt", inputs.add_generation_prompt}, + })); + context->set("bos_token", opts.use_bos_token ? bos_token_ : ""); + context->set("eos_token", opts.use_eos_token ? eos_token_ : ""); + if (opts.define_strftime_now) { + auto now = inputs.now; + context->set("strftime_now", Value::callable([now](const std::shared_ptr &, minja::ArgumentsValue & args) { + args.expectArgs("strftime_now", {1, 1}, {0, 0}); + auto format = args.args[0].get(); + + auto time = std::chrono::system_clock::to_time_t(now); + auto local_time = *std::localtime(&time); + std::ostringstream ss; + ss << std::put_time(&local_time, format.c_str()); + return ss.str(); + })); + } + if (!inputs.tools.is_null()) { + context->set("tools", minja::Value(inputs.tools)); + } + if (!inputs.extra_context.is_null()) { + for (auto & kv : inputs.extra_context.items()) { + context->set(kv.key(), minja::Value(kv.value())); + } + } + + auto ret = template_root_->render(context); + // fprintf(stderr, "actual_messages: %s\n", actual_messages.dump(2).c_str()); + // fprintf(stderr, "apply: %s\n\n", ret.c_str()); + return ret; + } + + static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages, const std::string & system_prompt) { + json messages_with_system = messages; + + if (messages_with_system.size() > 0 && messages_with_system[0].at("role") == "system") { + std::string existing_system = messages_with_system.at(0).at("content"); + messages_with_system[0] = json { + {"role", "system"}, + {"content", existing_system + "\n\n" + system_prompt}, + }; + } else { + messages_with_system.insert(messages_with_system.begin(), json { + {"role", "system"}, + {"content", system_prompt}, + }); + } + return messages_with_system; + } +}; + +} // namespace minja diff --git a/common/minja/minja.hpp b/common/minja/minja.hpp new file mode 100644 index 00000000..c58dd66e --- /dev/null +++ b/common/minja/minja.hpp @@ -0,0 +1,2883 @@ +/* + Copyright 2024 Google LLC + + Use of this source code is governed by an MIT-style + license that can be found in the LICENSE file or at + https://opensource.org/licenses/MIT. +*/ +// SPDX-License-Identifier: MIT +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using json = nlohmann::ordered_json; + +namespace minja { + +class Context; + +struct Options { + bool trim_blocks; // removes the first newline after a block + bool lstrip_blocks; // removes leading whitespace on the line of the block + bool keep_trailing_newline; // don't remove last newline +}; + +struct ArgumentsValue; + +inline std::string normalize_newlines(const std::string & s) { +#ifdef _WIN32 + static const std::regex nl_regex("\r\n"); + return std::regex_replace(s, nl_regex, "\n"); +#else + return s; +#endif +} + +/* Values that behave roughly like in Python. */ +class Value : public std::enable_shared_from_this { +public: + using CallableType = std::function &, ArgumentsValue &)>; + using FilterType = std::function &, ArgumentsValue &)>; + +private: + using ObjectType = nlohmann::ordered_map; // Only contains primitive keys + using ArrayType = std::vector; + + std::shared_ptr array_; + std::shared_ptr object_; + std::shared_ptr callable_; + json primitive_; + + Value(const std::shared_ptr & array) : array_(array) {} + Value(const std::shared_ptr & object) : object_(object) {} + Value(const std::shared_ptr & callable) : object_(std::make_shared()), callable_(callable) {} + + /* Python-style string repr */ + static void dump_string(const json & primitive, std::ostringstream & out, char string_quote = '\'') { + if (!primitive.is_string()) throw std::runtime_error("Value is not a string: " + primitive.dump()); + auto s = primitive.dump(); + if (string_quote == '"' || s.find('\'') != std::string::npos) { + out << s; + return; + } + // Reuse json dump, just changing string quotes + out << string_quote; + for (size_t i = 1, n = s.size() - 1; i < n; ++i) { + if (s[i] == '\\' && s[i + 1] == '"') { + out << '"'; + i++; + } else if (s[i] == string_quote) { + out << '\\' << string_quote; + } else { + out << s[i]; + } + } + out << string_quote; + } + void dump(std::ostringstream & out, int indent = -1, int level = 0, bool to_json = false) const { + auto print_indent = [&](int level) { + if (indent > 0) { + out << "\n"; + for (int i = 0, n = level * indent; i < n; ++i) out << ' '; + } + }; + auto print_sub_sep = [&]() { + out << ','; + if (indent < 0) out << ' '; + else print_indent(level + 1); + }; + + auto string_quote = to_json ? '"' : '\''; + + if (is_null()) out << "null"; + else if (array_) { + out << "["; + print_indent(level + 1); + for (size_t i = 0; i < array_->size(); ++i) { + if (i) print_sub_sep(); + (*array_)[i].dump(out, indent, level + 1, to_json); + } + print_indent(level); + out << "]"; + } else if (object_) { + out << "{"; + print_indent(level + 1); + for (auto begin = object_->begin(), it = begin; it != object_->end(); ++it) { + if (it != begin) print_sub_sep(); + if (it->first.is_string()) { + dump_string(it->first, out, string_quote); + } else { + out << string_quote << it->first.dump() << string_quote; + } + out << ": "; + it->second.dump(out, indent, level + 1, to_json); + } + print_indent(level); + out << "}"; + } else if (callable_) { + throw std::runtime_error("Cannot dump callable to JSON"); + } else if (is_boolean() && !to_json) { + out << (this->to_bool() ? "True" : "False"); + } else if (is_string() && !to_json) { + dump_string(primitive_, out, string_quote); + } else { + out << primitive_.dump(); + } + } + +public: + Value() {} + Value(const bool& v) : primitive_(v) {} + Value(const int64_t & v) : primitive_(v) {} + Value(const double& v) : primitive_(v) {} + Value(const std::nullptr_t &) {} + Value(const std::string & v) : primitive_(v) {} + Value(const char * v) : primitive_(std::string(v)) {} + + Value(const json & v) { + if (v.is_object()) { + auto object = std::make_shared(); + for (auto it = v.begin(); it != v.end(); ++it) { + (*object)[it.key()] = it.value(); + } + object_ = std::move(object); + } else if (v.is_array()) { + auto array = std::make_shared(); + for (const auto& item : v) { + array->push_back(Value(item)); + } + array_ = array; + } else { + primitive_ = v; + } + } + + std::vector keys() { + if (!object_) throw std::runtime_error("Value is not an object: " + dump()); + std::vector res; + for (const auto& item : *object_) { + res.push_back(item.first); + } + return res; + } + + size_t size() const { + if (is_object()) return object_->size(); + if (is_array()) return array_->size(); + if (is_string()) return primitive_.get().length(); + throw std::runtime_error("Value is not an array or object: " + dump()); + } + + static Value array(const std::vector values = {}) { + auto array = std::make_shared(); + for (const auto& item : values) { + array->push_back(item); + } + return Value(array); + } + static Value object(const std::shared_ptr object = std::make_shared()) { + return Value(object); + } + static Value callable(const CallableType & callable) { + return Value(std::make_shared(callable)); + } + + void insert(size_t index, const Value& v) { + if (!array_) + throw std::runtime_error("Value is not an array: " + dump()); + array_->insert(array_->begin() + index, v); + } + void push_back(const Value& v) { + if (!array_) + throw std::runtime_error("Value is not an array: " + dump()); + array_->push_back(v); + } + Value pop(const Value& index) { + if (is_array()) { + if (array_->empty()) + throw std::runtime_error("pop from empty list"); + if (index.is_null()) { + auto ret = array_->back(); + array_->pop_back(); + return ret; + } else if (!index.is_number_integer()) { + throw std::runtime_error("pop index must be an integer: " + index.dump()); + } else { + auto i = index.get(); + if (i < 0 || i >= static_cast(array_->size())) + throw std::runtime_error("pop index out of range: " + index.dump()); + auto it = array_->begin() + (i < 0 ? array_->size() + i : i); + auto ret = *it; + array_->erase(it); + return ret; + } + } else if (is_object()) { + if (!index.is_hashable()) + throw std::runtime_error("Unashable type: " + index.dump()); + auto it = object_->find(index.primitive_); + if (it == object_->end()) + throw std::runtime_error("Key not found: " + index.dump()); + auto ret = it->second; + object_->erase(it); + return ret; + } else { + throw std::runtime_error("Value is not an array or object: " + dump()); + } + } + Value get(const Value& key) { + if (array_) { + if (!key.is_number_integer()) { + return Value(); + } + auto index = key.get(); + return array_->at(index < 0 ? array_->size() + index : index); + } else if (object_) { + if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump()); + auto it = object_->find(key.primitive_); + if (it == object_->end()) return Value(); + return it->second; + } + return Value(); + } + void set(const Value& key, const Value& value) { + if (!object_) throw std::runtime_error("Value is not an object: " + dump()); + if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump()); + (*object_)[key.primitive_] = value; + } + Value call(const std::shared_ptr & context, ArgumentsValue & args) const { + if (!callable_) throw std::runtime_error("Value is not callable: " + dump()); + return (*callable_)(context, args); + } + + bool is_object() const { return !!object_; } + bool is_array() const { return !!array_; } + bool is_callable() const { return !!callable_; } + bool is_null() const { return !object_ && !array_ && primitive_.is_null() && !callable_; } + bool is_boolean() const { return primitive_.is_boolean(); } + bool is_number_integer() const { return primitive_.is_number_integer(); } + bool is_number_float() const { return primitive_.is_number_float(); } + bool is_number() const { return primitive_.is_number(); } + bool is_string() const { return primitive_.is_string(); } + bool is_iterable() const { return is_array() || is_object() || is_string(); } + + bool is_primitive() const { return !array_ && !object_ && !callable_; } + bool is_hashable() const { return is_primitive(); } + + bool empty() const { + if (is_null()) + throw std::runtime_error("Undefined value or reference"); + if (is_string()) return primitive_.empty(); + if (is_array()) return array_->empty(); + if (is_object()) return object_->empty(); + return false; + } + + void for_each(const std::function & callback) const { + if (is_null()) + throw std::runtime_error("Undefined value or reference"); + if (array_) { + for (auto& item : *array_) { + callback(item); + } + } else if (object_) { + for (auto & item : *object_) { + Value key(item.first); + callback(key); + } + } else if (is_string()) { + for (char c : primitive_.get()) { + auto val = Value(std::string(1, c)); + callback(val); + } + } else { + throw std::runtime_error("Value is not iterable: " + dump()); + } + } + + bool to_bool() const { + if (is_null()) return false; + if (is_boolean()) return get(); + if (is_number()) return get() != 0; + if (is_string()) return !get().empty(); + if (is_array()) return !empty(); + return true; + } + + int64_t to_int() const { + if (is_null()) return 0; + if (is_boolean()) return get() ? 1 : 0; + if (is_number()) return static_cast(get()); + if (is_string()) { + try { + return std::stol(get()); + } catch (const std::exception &) { + return 0; + } + } + return 0; + } + + bool operator<(const Value & other) const { + if (is_null()) + throw std::runtime_error("Undefined value or reference"); + if (is_number() && other.is_number()) return get() < other.get(); + if (is_string() && other.is_string()) return get() < other.get(); + throw std::runtime_error("Cannot compare values: " + dump() + " < " + other.dump()); + } + bool operator>=(const Value & other) const { return !(*this < other); } + + bool operator>(const Value & other) const { + if (is_null()) + throw std::runtime_error("Undefined value or reference"); + if (is_number() && other.is_number()) return get() > other.get(); + if (is_string() && other.is_string()) return get() > other.get(); + throw std::runtime_error("Cannot compare values: " + dump() + " > " + other.dump()); + } + bool operator<=(const Value & other) const { return !(*this > other); } + + bool operator==(const Value & other) const { + if (callable_ || other.callable_) { + if (callable_.get() != other.callable_.get()) return false; + } + if (array_) { + if (!other.array_) return false; + if (array_->size() != other.array_->size()) return false; + for (size_t i = 0; i < array_->size(); ++i) { + if (!(*array_)[i].to_bool() || !(*other.array_)[i].to_bool() || (*array_)[i] != (*other.array_)[i]) return false; + } + return true; + } else if (object_) { + if (!other.object_) return false; + if (object_->size() != other.object_->size()) return false; + for (const auto& item : *object_) { + if (!item.second.to_bool() || !other.object_->count(item.first) || item.second != other.object_->at(item.first)) return false; + } + return true; + } else { + return primitive_ == other.primitive_; + } + } + bool operator!=(const Value & other) const { return !(*this == other); } + + bool contains(const char * key) const { return contains(std::string(key)); } + bool contains(const std::string & key) const { + if (array_) { + return false; + } else if (object_) { + return object_->find(key) != object_->end(); + } else { + throw std::runtime_error("contains can only be called on arrays and objects: " + dump()); + } + } + bool contains(const Value & value) const { + if (is_null()) + throw std::runtime_error("Undefined value or reference"); + if (array_) { + for (const auto& item : *array_) { + if (item.to_bool() && item == value) return true; + } + return false; + } else if (object_) { + if (!value.is_hashable()) throw std::runtime_error("Unashable type: " + value.dump()); + return object_->find(value.primitive_) != object_->end(); + } else { + throw std::runtime_error("contains can only be called on arrays and objects: " + dump()); + } + } + void erase(size_t index) { + if (!array_) throw std::runtime_error("Value is not an array: " + dump()); + array_->erase(array_->begin() + index); + } + void erase(const std::string & key) { + if (!object_) throw std::runtime_error("Value is not an object: " + dump()); + object_->erase(key); + } + const Value& at(const Value & index) const { + return const_cast(this)->at(index); + } + Value& at(const Value & index) { + if (!index.is_hashable()) throw std::runtime_error("Unashable type: " + dump()); + if (is_array()) return array_->at(index.get()); + if (is_object()) return object_->at(index.primitive_); + throw std::runtime_error("Value is not an array or object: " + dump()); + } + const Value& at(size_t index) const { + return const_cast(this)->at(index); + } + Value& at(size_t index) { + if (is_null()) + throw std::runtime_error("Undefined value or reference"); + if (is_array()) return array_->at(index); + if (is_object()) return object_->at(index); + throw std::runtime_error("Value is not an array or object: " + dump()); + } + + template + T get(const std::string & key, T default_value) const { + if (!contains(key)) return default_value; + return at(key).get(); + } + + template + T get() const { + if (is_primitive()) return primitive_.get(); + throw std::runtime_error("get not defined for this value type: " + dump()); + } + + std::string dump(int indent=-1, bool to_json=false) const { + std::ostringstream out; + dump(out, indent, 0, to_json); + return out.str(); + } + + Value operator-() const { + if (is_number_integer()) + return -get(); + else + return -get(); + } + std::string to_str() const { + if (is_string()) return get(); + if (is_number_integer()) return std::to_string(get()); + if (is_number_float()) return std::to_string(get()); + if (is_boolean()) return get() ? "True" : "False"; + if (is_null()) return "None"; + return dump(); + } + Value operator+(const Value& rhs) const { + if (is_string() || rhs.is_string()) { + return to_str() + rhs.to_str(); + } else if (is_number_integer() && rhs.is_number_integer()) { + return get() + rhs.get(); + } else if (is_array() && rhs.is_array()) { + auto res = Value::array(); + for (const auto& item : *array_) res.push_back(item); + for (const auto& item : *rhs.array_) res.push_back(item); + return res; + } else { + return get() + rhs.get(); + } + } + Value operator-(const Value& rhs) const { + if (is_number_integer() && rhs.is_number_integer()) + return get() - rhs.get(); + else + return get() - rhs.get(); + } + Value operator*(const Value& rhs) const { + if (is_string() && rhs.is_number_integer()) { + std::ostringstream out; + for (int64_t i = 0, n = rhs.get(); i < n; ++i) { + out << to_str(); + } + return out.str(); + } + else if (is_number_integer() && rhs.is_number_integer()) + return get() * rhs.get(); + else + return get() * rhs.get(); + } + Value operator/(const Value& rhs) const { + if (is_number_integer() && rhs.is_number_integer()) + return get() / rhs.get(); + else + return get() / rhs.get(); + } + Value operator%(const Value& rhs) const { + return get() % rhs.get(); + } +}; + +struct ArgumentsValue { + std::vector args; + std::vector> kwargs; + + bool has_named(const std::string & name) { + for (const auto & p : kwargs) { + if (p.first == name) return true; + } + return false; + } + + Value get_named(const std::string & name) { + for (const auto & [key, value] : kwargs) { + if (key == name) return value; + } + return Value(); + } + + bool empty() { + return args.empty() && kwargs.empty(); + } + + void expectArgs(const std::string & method_name, const std::pair & pos_count, const std::pair & kw_count) { + if (args.size() < pos_count.first || args.size() > pos_count.second || kwargs.size() < kw_count.first || kwargs.size() > kw_count.second) { + std::ostringstream out; + out << method_name << " must have between " << pos_count.first << " and " << pos_count.second << " positional arguments and between " << kw_count.first << " and " << kw_count.second << " keyword arguments"; + throw std::runtime_error(out.str()); + } + } +}; + +template <> +inline json Value::get() const { + if (is_primitive()) return primitive_; + if (is_null()) return json(); + if (array_) { + std::vector res; + for (const auto& item : *array_) { + res.push_back(item.get()); + } + return res; + } + if (object_) { + json res = json::object(); + for (const auto& [key, value] : *object_) { + if (key.is_string()) { + res[key.get()] = value.get(); + } else if (key.is_primitive()) { + res[key.dump()] = value.get(); + } else { + throw std::runtime_error("Invalid key type for conversion to JSON: " + key.dump()); + } + } + if (is_callable()) { + res["__callable__"] = true; + } + return res; + } + throw std::runtime_error("get not defined for this value type: " + dump()); +} + +} // namespace minja + +namespace std { + template <> + struct hash { + size_t operator()(const minja::Value & v) const { + if (!v.is_hashable()) + throw std::runtime_error("Unsupported type for hashing: " + v.dump()); + return std::hash()(v.get()); + } + }; +} // namespace std + +namespace minja { + +static std::string error_location_suffix(const std::string & source, size_t pos) { + auto get_line = [&](size_t line) { + auto start = source.begin(); + for (size_t i = 1; i < line; ++i) { + start = std::find(start, source.end(), '\n') + 1; + } + auto end = std::find(start, source.end(), '\n'); + return std::string(start, end); + }; + auto start = source.begin(); + auto end = source.end(); + auto it = start + pos; + auto line = std::count(start, it, '\n') + 1; + auto max_line = std::count(start, end, '\n') + 1; + auto col = pos - std::string(start, it).rfind('\n'); + std::ostringstream out; + out << " at row " << line << ", column " << col << ":\n"; + if (line > 1) out << get_line(line - 1) << "\n"; + out << get_line(line) << "\n"; + out << std::string(col - 1, ' ') << "^\n"; + if (line < max_line) out << get_line(line + 1) << "\n"; + + return out.str(); +} + +class Context : public std::enable_shared_from_this { + protected: + Value values_; + std::shared_ptr parent_; + public: + Context(Value && values, const std::shared_ptr & parent = nullptr) : values_(std::move(values)), parent_(parent) { + if (!values_.is_object()) throw std::runtime_error("Context values must be an object: " + values_.dump()); + } + virtual ~Context() {} + + static std::shared_ptr builtins(); + static std::shared_ptr make(Value && values, const std::shared_ptr & parent = builtins()); + + std::vector keys() { + return values_.keys(); + } + virtual Value get(const Value & key) { + if (values_.contains(key)) return values_.at(key); + if (parent_) return parent_->get(key); + return Value(); + } + virtual Value & at(const Value & key) { + if (values_.contains(key)) return values_.at(key); + if (parent_) return parent_->at(key); + throw std::runtime_error("Undefined variable: " + key.dump()); + } + virtual bool contains(const Value & key) { + if (values_.contains(key)) return true; + if (parent_) return parent_->contains(key); + return false; + } + virtual void set(const Value & key, const Value & value) { + values_.set(key, value); + } +}; + +struct Location { + std::shared_ptr source; + size_t pos; +}; + +class Expression { +protected: + virtual Value do_evaluate(const std::shared_ptr & context) const = 0; +public: + using Parameters = std::vector>>; + + Location location; + + Expression(const Location & location) : location(location) {} + virtual ~Expression() = default; + + Value evaluate(const std::shared_ptr & context) const { + try { + return do_evaluate(context); + } catch (const std::exception & e) { + std::ostringstream out; + out << e.what(); + if (location.source) out << error_location_suffix(*location.source, location.pos); + throw std::runtime_error(out.str()); + } + } +}; + +class VariableExpr : public Expression { + std::string name; +public: + VariableExpr(const Location & location, const std::string& n) + : Expression(location), name(n) {} + std::string get_name() const { return name; } + Value do_evaluate(const std::shared_ptr & context) const override { + if (!context->contains(name)) { + return Value(); + } + return context->at(name); + } +}; + +static void destructuring_assign(const std::vector & var_names, const std::shared_ptr & context, Value& item) { + if (var_names.size() == 1) { + Value name(var_names[0]); + context->set(name, item); + } else { + if (!item.is_array() || item.size() != var_names.size()) { + throw std::runtime_error("Mismatched number of variables and items in destructuring assignment"); + } + for (size_t i = 0; i < var_names.size(); ++i) { + context->set(var_names[i], item.at(i)); + } + } +} + +enum SpaceHandling { Keep, Strip, StripSpaces, StripNewline }; + +class TemplateToken { +public: + enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter, Break, Continue }; + + static std::string typeToString(Type t) { + switch (t) { + case Type::Text: return "text"; + case Type::Expression: return "expression"; + case Type::If: return "if"; + case Type::Else: return "else"; + case Type::Elif: return "elif"; + case Type::EndIf: return "endif"; + case Type::For: return "for"; + case Type::EndFor: return "endfor"; + case Type::Set: return "set"; + case Type::EndSet: return "endset"; + case Type::Comment: return "comment"; + case Type::Macro: return "macro"; + case Type::EndMacro: return "endmacro"; + case Type::Filter: return "filter"; + case Type::EndFilter: return "endfilter"; + case Type::Generation: return "generation"; + case Type::EndGeneration: return "endgeneration"; + case Type::Break: return "break"; + case Type::Continue: return "continue"; + } + return "Unknown"; + } + + TemplateToken(Type type, const Location & location, SpaceHandling pre, SpaceHandling post) : type(type), location(location), pre_space(pre), post_space(post) {} + virtual ~TemplateToken() = default; + + Type type; + Location location; + SpaceHandling pre_space = SpaceHandling::Keep; + SpaceHandling post_space = SpaceHandling::Keep; +}; + +struct TextTemplateToken : public TemplateToken { + std::string text; + TextTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Text, location, pre, post), text(t) {} +}; + +struct ExpressionTemplateToken : public TemplateToken { + std::shared_ptr expr; + ExpressionTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr && e) : TemplateToken(Type::Expression, location, pre, post), expr(std::move(e)) {} +}; + +struct IfTemplateToken : public TemplateToken { + std::shared_ptr condition; + IfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr && c) : TemplateToken(Type::If, location, pre, post), condition(std::move(c)) {} +}; + +struct ElifTemplateToken : public TemplateToken { + std::shared_ptr condition; + ElifTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr && c) : TemplateToken(Type::Elif, location, pre, post), condition(std::move(c)) {} +}; + +struct ElseTemplateToken : public TemplateToken { + ElseTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::Else, location, pre, post) {} +}; + +struct EndIfTemplateToken : public TemplateToken { + EndIfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndIf, location, pre, post) {} +}; + +struct MacroTemplateToken : public TemplateToken { + std::shared_ptr name; + Expression::Parameters params; + MacroTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr && n, Expression::Parameters && p) + : TemplateToken(Type::Macro, location, pre, post), name(std::move(n)), params(std::move(p)) {} +}; + +struct EndMacroTemplateToken : public TemplateToken { + EndMacroTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndMacro, location, pre, post) {} +}; + +struct FilterTemplateToken : public TemplateToken { + std::shared_ptr filter; + FilterTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr && filter) + : TemplateToken(Type::Filter, location, pre, post), filter(std::move(filter)) {} +}; + +struct EndFilterTemplateToken : public TemplateToken { + EndFilterTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndFilter, location, pre, post) {} +}; + +struct ForTemplateToken : public TemplateToken { + std::vector var_names; + std::shared_ptr iterable; + std::shared_ptr condition; + bool recursive; + ForTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::vector & vns, std::shared_ptr && iter, + std::shared_ptr && c, bool r) + : TemplateToken(Type::For, location, pre, post), var_names(vns), iterable(std::move(iter)), condition(std::move(c)), recursive(r) {} +}; + +struct EndForTemplateToken : public TemplateToken { + EndForTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndFor, location, pre, post) {} +}; + +struct GenerationTemplateToken : public TemplateToken { + GenerationTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::Generation, location, pre, post) {} +}; + +struct EndGenerationTemplateToken : public TemplateToken { + EndGenerationTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndGeneration, location, pre, post) {} +}; + +struct SetTemplateToken : public TemplateToken { + std::string ns; + std::vector var_names; + std::shared_ptr value; + SetTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string & ns, const std::vector & vns, std::shared_ptr && v) + : TemplateToken(Type::Set, location, pre, post), ns(ns), var_names(vns), value(std::move(v)) {} +}; + +struct EndSetTemplateToken : public TemplateToken { + EndSetTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndSet, location, pre, post) {} +}; + +struct CommentTemplateToken : public TemplateToken { + std::string text; + CommentTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Comment, location, pre, post), text(t) {} +}; + +enum class LoopControlType { Break, Continue }; + +class LoopControlException : public std::runtime_error { +public: + LoopControlType control_type; + LoopControlException(const std::string & message, LoopControlType control_type) : std::runtime_error(message), control_type(control_type) {} + LoopControlException(LoopControlType control_type) + : std::runtime_error((control_type == LoopControlType::Continue ? "continue" : "break") + std::string(" outside of a loop")), + control_type(control_type) {} +}; + +struct LoopControlTemplateToken : public TemplateToken { + LoopControlType control_type; + LoopControlTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, LoopControlType control_type) : TemplateToken(Type::Break, location, pre, post), control_type(control_type) {} +}; + +class TemplateNode { + Location location_; +protected: + virtual void do_render(std::ostringstream & out, const std::shared_ptr & context) const = 0; + +public: + TemplateNode(const Location & location) : location_(location) {} + void render(std::ostringstream & out, const std::shared_ptr & context) const { + try { + do_render(out, context); + } catch (const LoopControlException & e) { + // TODO: make stack creation lazy. Only needed if it was thrown outside of a loop. + std::ostringstream err; + err << e.what(); + if (location_.source) err << error_location_suffix(*location_.source, location_.pos); + throw LoopControlException(err.str(), e.control_type); + } catch (const std::exception & e) { + std::ostringstream err; + err << e.what(); + if (location_.source) err << error_location_suffix(*location_.source, location_.pos); + throw std::runtime_error(err.str()); + } + } + const Location & location() const { return location_; } + virtual ~TemplateNode() = default; + std::string render(const std::shared_ptr & context) const { + std::ostringstream out; + render(out, context); + return out.str(); + } +}; + +class SequenceNode : public TemplateNode { + std::vector> children; +public: + SequenceNode(const Location & location, std::vector> && c) + : TemplateNode(location), children(std::move(c)) {} + void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { + for (const auto& child : children) child->render(out, context); + } +}; + +class TextNode : public TemplateNode { + std::string text; +public: + TextNode(const Location & location, const std::string& t) : TemplateNode(location), text(t) {} + void do_render(std::ostringstream & out, const std::shared_ptr &) const override { + out << text; + } +}; + +class ExpressionNode : public TemplateNode { + std::shared_ptr expr; +public: + ExpressionNode(const Location & location, std::shared_ptr && e) : TemplateNode(location), expr(std::move(e)) {} + void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { + if (!expr) throw std::runtime_error("ExpressionNode.expr is null"); + auto result = expr->evaluate(context); + if (result.is_string()) { + out << result.get(); + } else if (result.is_boolean()) { + out << (result.get() ? "True" : "False"); + } else if (!result.is_null()) { + out << result.dump(); + } + } +}; + +class IfNode : public TemplateNode { + std::vector, std::shared_ptr>> cascade; +public: + IfNode(const Location & location, std::vector, std::shared_ptr>> && c) + : TemplateNode(location), cascade(std::move(c)) {} + void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { + for (const auto& branch : cascade) { + auto enter_branch = true; + if (branch.first) { + enter_branch = branch.first->evaluate(context).to_bool(); + } + if (enter_branch) { + if (!branch.second) throw std::runtime_error("IfNode.cascade.second is null"); + branch.second->render(out, context); + return; + } + } + } +}; + +class LoopControlNode : public TemplateNode { + LoopControlType control_type_; + public: + LoopControlNode(const Location & location, LoopControlType control_type) : TemplateNode(location), control_type_(control_type) {} + void do_render(std::ostringstream &, const std::shared_ptr &) const override { + throw LoopControlException(control_type_); + } +}; + +class ForNode : public TemplateNode { + std::vector var_names; + std::shared_ptr iterable; + std::shared_ptr condition; + std::shared_ptr body; + bool recursive; + std::shared_ptr else_body; +public: + ForNode(const Location & location, std::vector && var_names, std::shared_ptr && iterable, + std::shared_ptr && condition, std::shared_ptr && body, bool recursive, std::shared_ptr && else_body) + : TemplateNode(location), var_names(var_names), iterable(std::move(iterable)), condition(std::move(condition)), body(std::move(body)), recursive(recursive), else_body(std::move(else_body)) {} + + void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { + // https://jinja.palletsprojects.com/en/3.0.x/templates/#for + if (!iterable) throw std::runtime_error("ForNode.iterable is null"); + if (!body) throw std::runtime_error("ForNode.body is null"); + + auto iterable_value = iterable->evaluate(context); + Value::CallableType loop_function; + + std::function visit = [&](Value& iter) { + auto filtered_items = Value::array(); + if (!iter.is_null()) { + if (!iterable_value.is_iterable()) { + throw std::runtime_error("For loop iterable must be iterable: " + iterable_value.dump()); + } + iterable_value.for_each([&](Value & item) { + destructuring_assign(var_names, context, item); + if (!condition || condition->evaluate(context).to_bool()) { + filtered_items.push_back(item); + } + }); + } + if (filtered_items.empty()) { + if (else_body) { + else_body->render(out, context); + } + } else { + auto loop = recursive ? Value::callable(loop_function) : Value::object(); + loop.set("length", (int64_t) filtered_items.size()); + + size_t cycle_index = 0; + loop.set("cycle", Value::callable([&](const std::shared_ptr &, ArgumentsValue & args) { + if (args.args.empty() || !args.kwargs.empty()) { + throw std::runtime_error("cycle() expects at least 1 positional argument and no named arg"); + } + auto item = args.args[cycle_index]; + cycle_index = (cycle_index + 1) % args.args.size(); + return item; + })); + auto loop_context = Context::make(Value::object(), context); + loop_context->set("loop", loop); + for (size_t i = 0, n = filtered_items.size(); i < n; ++i) { + auto & item = filtered_items.at(i); + destructuring_assign(var_names, loop_context, item); + loop.set("index", (int64_t) i + 1); + loop.set("index0", (int64_t) i); + loop.set("revindex", (int64_t) (n - i)); + loop.set("revindex0", (int64_t) (n - i - 1)); + loop.set("length", (int64_t) n); + loop.set("first", i == 0); + loop.set("last", i == (n - 1)); + loop.set("previtem", i > 0 ? filtered_items.at(i - 1) : Value()); + loop.set("nextitem", i < n - 1 ? filtered_items.at(i + 1) : Value()); + try { + body->render(out, loop_context); + } catch (const LoopControlException & e) { + if (e.control_type == LoopControlType::Break) break; + if (e.control_type == LoopControlType::Continue) continue; + } + } + } + }; + + if (recursive) { + loop_function = [&](const std::shared_ptr &, ArgumentsValue & args) { + if (args.args.size() != 1 || !args.kwargs.empty() || !args.args[0].is_array()) { + throw std::runtime_error("loop() expects exactly 1 positional iterable argument"); + } + auto & items = args.args[0]; + visit(items); + return Value(); + }; + } + + visit(iterable_value); + } +}; + +class MacroNode : public TemplateNode { + std::shared_ptr name; + Expression::Parameters params; + std::shared_ptr body; + std::unordered_map named_param_positions; +public: + MacroNode(const Location & location, std::shared_ptr && n, Expression::Parameters && p, std::shared_ptr && b) + : TemplateNode(location), name(std::move(n)), params(std::move(p)), body(std::move(b)) { + for (size_t i = 0; i < params.size(); ++i) { + const auto & name = params[i].first; + if (!name.empty()) { + named_param_positions[name] = i; + } + } + } + void do_render(std::ostringstream &, const std::shared_ptr & macro_context) const override { + if (!name) throw std::runtime_error("MacroNode.name is null"); + if (!body) throw std::runtime_error("MacroNode.body is null"); + auto callable = Value::callable([&](const std::shared_ptr & context, ArgumentsValue & args) { + auto call_context = macro_context; + std::vector param_set(params.size(), false); + for (size_t i = 0, n = args.args.size(); i < n; i++) { + auto & arg = args.args[i]; + if (i >= params.size()) throw std::runtime_error("Too many positional arguments for macro " + name->get_name()); + param_set[i] = true; + auto & param_name = params[i].first; + call_context->set(param_name, arg); + } + for (auto & [arg_name, value] : args.kwargs) { + auto it = named_param_positions.find(arg_name); + if (it == named_param_positions.end()) throw std::runtime_error("Unknown parameter name for macro " + name->get_name() + ": " + arg_name); + + call_context->set(arg_name, value); + param_set[it->second] = true; + } + // Set default values for parameters that were not passed + for (size_t i = 0, n = params.size(); i < n; i++) { + if (!param_set[i] && params[i].second != nullptr) { + auto val = params[i].second->evaluate(context); + call_context->set(params[i].first, val); + } + } + return body->render(call_context); + }); + macro_context->set(name->get_name(), callable); + } +}; + +class FilterNode : public TemplateNode { + std::shared_ptr filter; + std::shared_ptr body; + +public: + FilterNode(const Location & location, std::shared_ptr && f, std::shared_ptr && b) + : TemplateNode(location), filter(std::move(f)), body(std::move(b)) {} + + void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { + if (!filter) throw std::runtime_error("FilterNode.filter is null"); + if (!body) throw std::runtime_error("FilterNode.body is null"); + auto filter_value = filter->evaluate(context); + if (!filter_value.is_callable()) { + throw std::runtime_error("Filter must be a callable: " + filter_value.dump()); + } + std::string rendered_body = body->render(context); + + ArgumentsValue filter_args = {{Value(rendered_body)}, {}}; + auto result = filter_value.call(context, filter_args); + out << result.to_str(); + } +}; + +class SetNode : public TemplateNode { + std::string ns; + std::vector var_names; + std::shared_ptr value; +public: + SetNode(const Location & location, const std::string & ns, const std::vector & vns, std::shared_ptr && v) + : TemplateNode(location), ns(ns), var_names(vns), value(std::move(v)) {} + void do_render(std::ostringstream &, const std::shared_ptr & context) const override { + if (!value) throw std::runtime_error("SetNode.value is null"); + if (!ns.empty()) { + if (var_names.size() != 1) { + throw std::runtime_error("Namespaced set only supports a single variable name"); + } + auto & name = var_names[0]; + auto ns_value = context->get(ns); + if (!ns_value.is_object()) throw std::runtime_error("Namespace '" + ns + "' is not an object"); + ns_value.set(name, this->value->evaluate(context)); + } else { + auto val = value->evaluate(context); + destructuring_assign(var_names, context, val); + } + } +}; + +class SetTemplateNode : public TemplateNode { + std::string name; + std::shared_ptr template_value; +public: + SetTemplateNode(const Location & location, const std::string & name, std::shared_ptr && tv) + : TemplateNode(location), name(name), template_value(std::move(tv)) {} + void do_render(std::ostringstream &, const std::shared_ptr & context) const override { + if (!template_value) throw std::runtime_error("SetTemplateNode.template_value is null"); + Value value { template_value->render(context) }; + context->set(name, value); + } +}; + +class IfExpr : public Expression { + std::shared_ptr condition; + std::shared_ptr then_expr; + std::shared_ptr else_expr; +public: + IfExpr(const Location & location, std::shared_ptr && c, std::shared_ptr && t, std::shared_ptr && e) + : Expression(location), condition(std::move(c)), then_expr(std::move(t)), else_expr(std::move(e)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + if (!condition) throw std::runtime_error("IfExpr.condition is null"); + if (!then_expr) throw std::runtime_error("IfExpr.then_expr is null"); + if (condition->evaluate(context).to_bool()) { + return then_expr->evaluate(context); + } + if (else_expr) { + return else_expr->evaluate(context); + } + return nullptr; + } +}; + +class LiteralExpr : public Expression { + Value value; +public: + LiteralExpr(const Location & location, const Value& v) + : Expression(location), value(v) {} + Value do_evaluate(const std::shared_ptr &) const override { return value; } +}; + +class ArrayExpr : public Expression { + std::vector> elements; +public: + ArrayExpr(const Location & location, std::vector> && e) + : Expression(location), elements(std::move(e)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + auto result = Value::array(); + for (const auto& e : elements) { + if (!e) throw std::runtime_error("Array element is null"); + result.push_back(e->evaluate(context)); + } + return result; + } +}; + +class DictExpr : public Expression { + std::vector, std::shared_ptr>> elements; +public: + DictExpr(const Location & location, std::vector, std::shared_ptr>> && e) + : Expression(location), elements(std::move(e)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + auto result = Value::object(); + for (const auto& [key, value] : elements) { + if (!key) throw std::runtime_error("Dict key is null"); + if (!value) throw std::runtime_error("Dict value is null"); + result.set(key->evaluate(context), value->evaluate(context)); + } + return result; + } +}; + +class SliceExpr : public Expression { +public: + std::shared_ptr start, end; + SliceExpr(const Location & location, std::shared_ptr && s, std::shared_ptr && e) + : Expression(location), start(std::move(s)), end(std::move(e)) {} + Value do_evaluate(const std::shared_ptr &) const override { + throw std::runtime_error("SliceExpr not implemented"); + } +}; + +class SubscriptExpr : public Expression { + std::shared_ptr base; + std::shared_ptr index; +public: + SubscriptExpr(const Location & location, std::shared_ptr && b, std::shared_ptr && i) + : Expression(location), base(std::move(b)), index(std::move(i)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + if (!base) throw std::runtime_error("SubscriptExpr.base is null"); + if (!index) throw std::runtime_error("SubscriptExpr.index is null"); + auto target_value = base->evaluate(context); + if (auto slice = dynamic_cast(index.get())) { + auto start = slice->start ? slice->start->evaluate(context).get() : 0; + auto end = slice->end ? slice->end->evaluate(context).get() : (int64_t) target_value.size(); + if (target_value.is_string()) { + std::string s = target_value.get(); + if (start < 0) start = s.size() + start; + if (end < 0) end = s.size() + end; + return s.substr(start, end - start); + } else if (target_value.is_array()) { + if (start < 0) start = target_value.size() + start; + if (end < 0) end = target_value.size() + end; + auto result = Value::array(); + for (auto i = start; i < end; ++i) { + result.push_back(target_value.at(i)); + } + return result; + } else { + throw std::runtime_error(target_value.is_null() ? "Cannot subscript null" : "Subscripting only supported on arrays and strings"); + } + } else { + auto index_value = index->evaluate(context); + if (target_value.is_null()) { + if (auto t = dynamic_cast(base.get())) { + throw std::runtime_error("'" + t->get_name() + "' is " + (context->contains(t->get_name()) ? "null" : "not defined")); + } + throw std::runtime_error("Trying to access property '" + index_value.dump() + "' on null!"); + } + return target_value.get(index_value); + } + } +}; + +class UnaryOpExpr : public Expression { +public: + enum class Op { Plus, Minus, LogicalNot, Expansion, ExpansionDict }; + std::shared_ptr expr; + Op op; + UnaryOpExpr(const Location & location, std::shared_ptr && e, Op o) + : Expression(location), expr(std::move(e)), op(o) {} + Value do_evaluate(const std::shared_ptr & context) const override { + if (!expr) throw std::runtime_error("UnaryOpExpr.expr is null"); + auto e = expr->evaluate(context); + switch (op) { + case Op::Plus: return e; + case Op::Minus: return -e; + case Op::LogicalNot: return !e.to_bool(); + case Op::Expansion: + case Op::ExpansionDict: + throw std::runtime_error("Expansion operator is only supported in function calls and collections"); + + } + throw std::runtime_error("Unknown unary operator"); + } +}; + +class BinaryOpExpr : public Expression { +public: + enum class Op { StrConcat, Add, Sub, Mul, MulMul, Div, DivDiv, Mod, Eq, Ne, Lt, Gt, Le, Ge, And, Or, In, NotIn, Is, IsNot }; +private: + std::shared_ptr left; + std::shared_ptr right; + Op op; +public: + BinaryOpExpr(const Location & location, std::shared_ptr && l, std::shared_ptr && r, Op o) + : Expression(location), left(std::move(l)), right(std::move(r)), op(o) {} + Value do_evaluate(const std::shared_ptr & context) const override { + if (!left) throw std::runtime_error("BinaryOpExpr.left is null"); + if (!right) throw std::runtime_error("BinaryOpExpr.right is null"); + auto l = left->evaluate(context); + + auto do_eval = [&](const Value & l) -> Value { + if (op == Op::Is || op == Op::IsNot) { + auto t = dynamic_cast(right.get()); + if (!t) throw std::runtime_error("Right side of 'is' operator must be a variable"); + + auto eval = [&]() { + const auto & name = t->get_name(); + if (name == "none") return l.is_null(); + if (name == "boolean") return l.is_boolean(); + if (name == "integer") return l.is_number_integer(); + if (name == "float") return l.is_number_float(); + if (name == "number") return l.is_number(); + if (name == "string") return l.is_string(); + if (name == "mapping") return l.is_object(); + if (name == "iterable") return l.is_iterable(); + if (name == "sequence") return l.is_array(); + if (name == "defined") return !l.is_null(); + throw std::runtime_error("Unknown type for 'is' operator: " + name); + }; + auto value = eval(); + return Value(op == Op::Is ? value : !value); + } + + if (op == Op::And) { + if (!l.to_bool()) return Value(false); + return right->evaluate(context).to_bool(); + } else if (op == Op::Or) { + if (l.to_bool()) return l; + return right->evaluate(context); + } + + auto r = right->evaluate(context); + switch (op) { + case Op::StrConcat: return l.to_str() + r.to_str(); + case Op::Add: return l + r; + case Op::Sub: return l - r; + case Op::Mul: return l * r; + case Op::Div: return l / r; + case Op::MulMul: return std::pow(l.get(), r.get()); + case Op::DivDiv: return l.get() / r.get(); + case Op::Mod: return l.get() % r.get(); + case Op::Eq: return l == r; + case Op::Ne: return l != r; + case Op::Lt: return l < r; + case Op::Gt: return l > r; + case Op::Le: return l <= r; + case Op::Ge: return l >= r; + case Op::In: return (r.is_array() || r.is_object()) && r.contains(l); + case Op::NotIn: return !(r.is_array() && r.contains(l)); + default: break; + } + throw std::runtime_error("Unknown binary operator"); + }; + + if (l.is_callable()) { + return Value::callable([l, do_eval](const std::shared_ptr & context, ArgumentsValue & args) { + auto ll = l.call(context, args); + return do_eval(ll); //args[0].second); + }); + } else { + return do_eval(l); + } + } +}; + +struct ArgumentsExpression { + std::vector> args; + std::vector>> kwargs; + + ArgumentsValue evaluate(const std::shared_ptr & context) const { + ArgumentsValue vargs; + for (const auto& arg : this->args) { + if (auto un_expr = std::dynamic_pointer_cast(arg)) { + if (un_expr->op == UnaryOpExpr::Op::Expansion) { + auto array = un_expr->expr->evaluate(context); + if (!array.is_array()) { + throw std::runtime_error("Expansion operator only supported on arrays"); + } + array.for_each([&](Value & value) { + vargs.args.push_back(value); + }); + continue; + } else if (un_expr->op == UnaryOpExpr::Op::ExpansionDict) { + auto dict = un_expr->expr->evaluate(context); + if (!dict.is_object()) { + throw std::runtime_error("ExpansionDict operator only supported on objects"); + } + dict.for_each([&](const Value & key) { + vargs.kwargs.push_back({key.get(), dict.at(key)}); + }); + continue; + } + } + vargs.args.push_back(arg->evaluate(context)); + } + for (const auto& [name, value] : this->kwargs) { + vargs.kwargs.push_back({name, value->evaluate(context)}); + } + return vargs; + } +}; + +static std::string strip(const std::string & s) { + auto start = s.find_first_not_of(" \t\n\r"); + if (start == std::string::npos) return ""; + auto end = s.find_last_not_of(" \t\n\r"); + return s.substr(start, end - start + 1); +} + +static std::string capitalize(const std::string & s) { + if (s.empty()) return s; + auto result = s; + result[0] = std::toupper(result[0]); + return result; +} + +static std::string html_escape(const std::string & s) { + std::string result; + result.reserve(s.size()); + for (const auto & c : s) { + switch (c) { + case '&': result += "&"; break; + case '<': result += "<"; break; + case '>': result += ">"; break; + case '"': result += """; break; + case '\'': result += "'"; break; + default: result += c; break; + } + } + return result; +} + +class MethodCallExpr : public Expression { + std::shared_ptr object; + std::shared_ptr method; + ArgumentsExpression args; +public: + MethodCallExpr(const Location & location, std::shared_ptr && obj, std::shared_ptr && m, ArgumentsExpression && a) + : Expression(location), object(std::move(obj)), method(std::move(m)), args(std::move(a)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + if (!object) throw std::runtime_error("MethodCallExpr.object is null"); + if (!method) throw std::runtime_error("MethodCallExpr.method is null"); + auto obj = object->evaluate(context); + auto vargs = args.evaluate(context); + if (obj.is_null()) { + throw std::runtime_error("Trying to call method '" + method->get_name() + "' on null"); + } + if (obj.is_array()) { + if (method->get_name() == "append") { + vargs.expectArgs("append method", {1, 1}, {0, 0}); + obj.push_back(vargs.args[0]); + return Value(); + } else if (method->get_name() == "pop") { + vargs.expectArgs("pop method", {0, 1}, {0, 0}); + return obj.pop(vargs.args.empty() ? Value() : vargs.args[0]); + } else if (method->get_name() == "insert") { + vargs.expectArgs("insert method", {2, 2}, {0, 0}); + auto index = vargs.args[0].get(); + if (index < 0 || index > (int64_t) obj.size()) throw std::runtime_error("Index out of range for insert method"); + obj.insert(index, vargs.args[1]); + return Value(); + } + } else if (obj.is_object()) { + if (method->get_name() == "items") { + vargs.expectArgs("items method", {0, 0}, {0, 0}); + auto result = Value::array(); + for (const auto& key : obj.keys()) { + result.push_back(Value::array({key, obj.at(key)})); + } + return result; + } else if (method->get_name() == "pop") { + vargs.expectArgs("pop method", {1, 1}, {0, 0}); + return obj.pop(vargs.args[0]); + } else if (method->get_name() == "get") { + vargs.expectArgs("get method", {1, 2}, {0, 0}); + auto key = vargs.args[0]; + if (vargs.args.size() == 1) { + return obj.contains(key) ? obj.at(key) : Value(); + } else { + return obj.contains(key) ? obj.at(key) : vargs.args[1]; + } + } else if (obj.contains(method->get_name())) { + auto callable = obj.at(method->get_name()); + if (!callable.is_callable()) { + throw std::runtime_error("Property '" + method->get_name() + "' is not callable"); + } + return callable.call(context, vargs); + } + } else if (obj.is_string()) { + auto str = obj.get(); + if (method->get_name() == "strip") { + vargs.expectArgs("strip method", {0, 0}, {0, 0}); + return Value(strip(str)); + } else if (method->get_name() == "capitalize") { + vargs.expectArgs("capitalize method", {0, 0}, {0, 0}); + return Value(capitalize(str)); + } else if (method->get_name() == "endswith") { + vargs.expectArgs("endswith method", {1, 1}, {0, 0}); + auto suffix = vargs.args[0].get(); + return suffix.length() <= str.length() && std::equal(suffix.rbegin(), suffix.rend(), str.rbegin()); + } else if (method->get_name() == "title") { + vargs.expectArgs("title method", {0, 0}, {0, 0}); + auto res = str; + for (size_t i = 0, n = res.size(); i < n; ++i) { + if (i == 0 || std::isspace(res[i - 1])) res[i] = std::toupper(res[i]); + else res[i] = std::tolower(res[i]); + } + return res; + } + } + throw std::runtime_error("Unknown method: " + method->get_name()); + } +}; + +class CallExpr : public Expression { +public: + std::shared_ptr object; + ArgumentsExpression args; + CallExpr(const Location & location, std::shared_ptr && obj, ArgumentsExpression && a) + : Expression(location), object(std::move(obj)), args(std::move(a)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + if (!object) throw std::runtime_error("CallExpr.object is null"); + auto obj = object->evaluate(context); + if (!obj.is_callable()) { + throw std::runtime_error("Object is not callable: " + obj.dump(2)); + } + auto vargs = args.evaluate(context); + return obj.call(context, vargs); + } +}; + +class FilterExpr : public Expression { + std::vector> parts; +public: + FilterExpr(const Location & location, std::vector> && p) + : Expression(location), parts(std::move(p)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + Value result; + bool first = true; + for (const auto& part : parts) { + if (!part) throw std::runtime_error("FilterExpr.part is null"); + if (first) { + first = false; + result = part->evaluate(context); + } else { + if (auto ce = dynamic_cast(part.get())) { + auto target = ce->object->evaluate(context); + ArgumentsValue args = ce->args.evaluate(context); + args.args.insert(args.args.begin(), result); + result = target.call(context, args); + } else { + auto callable = part->evaluate(context); + ArgumentsValue args; + args.args.insert(args.args.begin(), result); + result = callable.call(context, args); + } + } + } + return result; + } + + void prepend(std::shared_ptr && e) { + parts.insert(parts.begin(), std::move(e)); + } +}; + +class Parser { +private: + using CharIterator = std::string::const_iterator; + + std::shared_ptr template_str; + CharIterator start, end, it; + Options options; + + Parser(const std::shared_ptr& template_str, const Options & options) : template_str(template_str), options(options) { + if (!template_str) throw std::runtime_error("Template string is null"); + start = it = this->template_str->begin(); + end = this->template_str->end(); + } + + bool consumeSpaces(SpaceHandling space_handling = SpaceHandling::Strip) { + if (space_handling == SpaceHandling::Strip) { + while (it != end && std::isspace(*it)) ++it; + } + return true; + } + + std::unique_ptr parseString() { + auto doParse = [&](char quote) -> std::unique_ptr { + if (it == end || *it != quote) return nullptr; + std::string result; + bool escape = false; + for (++it; it != end; ++it) { + if (escape) { + escape = false; + switch (*it) { + case 'n': result += '\n'; break; + case 'r': result += '\r'; break; + case 't': result += '\t'; break; + case 'b': result += '\b'; break; + case 'f': result += '\f'; break; + case '\\': result += '\\'; break; + default: + if (*it == quote) { + result += quote; + } else { + result += *it; + } + break; + } + } else if (*it == '\\') { + escape = true; + } else if (*it == quote) { + ++it; + return std::make_unique(std::move(result)); + } else { + result += *it; + } + } + return nullptr; + }; + + consumeSpaces(); + if (it == end) return nullptr; + if (*it == '"') return doParse('"'); + if (*it == '\'') return doParse('\''); + return nullptr; + } + + json parseNumber(CharIterator& it, const CharIterator& end) { + auto before = it; + consumeSpaces(); + auto start = it; + bool hasDecimal = false; + bool hasExponent = false; + + if (it != end && (*it == '-' || *it == '+')) ++it; + + while (it != end) { + if (std::isdigit(*it)) { + ++it; + } else if (*it == '.') { + if (hasDecimal) throw std::runtime_error("Multiple decimal points"); + hasDecimal = true; + ++it; + } else if (it != start && (*it == 'e' || *it == 'E')) { + if (hasExponent) throw std::runtime_error("Multiple exponents"); + hasExponent = true; + ++it; + } else { + break; + } + } + if (start == it) { + it = before; + return json(); // No valid characters found + } + + std::string str(start, it); + try { + return json::parse(str); + } catch (json::parse_error& e) { + throw std::runtime_error("Failed to parse number: '" + str + "' (" + std::string(e.what()) + ")"); + return json(); + } + } + + /** integer, float, bool, string */ + std::shared_ptr parseConstant() { + auto start = it; + consumeSpaces(); + if (it == end) return nullptr; + if (*it == '"' || *it == '\'') { + auto str = parseString(); + if (str) return std::make_shared(*str); + } + static std::regex prim_tok(R"(true\b|True\b|false\b|False\b|None\b)"); + auto token = consumeToken(prim_tok); + if (!token.empty()) { + if (token == "true" || token == "True") return std::make_shared(true); + if (token == "false" || token == "False") return std::make_shared(false); + if (token == "None") return std::make_shared(nullptr); + throw std::runtime_error("Unknown constant token: " + token); + } + + auto number = parseNumber(it, end); + if (!number.is_null()) return std::make_shared(number); + + it = start; + return nullptr; + } + + class expression_parsing_error : public std::runtime_error { + const CharIterator it; + public: + expression_parsing_error(const std::string & message, const CharIterator & it) + : std::runtime_error(message), it(it) {} + size_t get_pos(const CharIterator & begin) const { + return std::distance(begin, it); + } + }; + + bool peekSymbols(const std::vector & symbols) const { + for (const auto & symbol : symbols) { + if (std::distance(it, end) >= (int64_t) symbol.size() && std::string(it, it + symbol.size()) == symbol) { + return true; + } + } + return false; + } + + std::vector consumeTokenGroups(const std::regex & regex, SpaceHandling space_handling = SpaceHandling::Strip) { + auto start = it; + consumeSpaces(space_handling); + std::smatch match; + if (std::regex_search(it, end, match, regex) && match.position() == 0) { + it += match[0].length(); + std::vector ret; + for (size_t i = 0, n = match.size(); i < n; ++i) { + ret.push_back(match[i].str()); + } + return ret; + } + it = start; + return {}; + } + std::string consumeToken(const std::regex & regex, SpaceHandling space_handling = SpaceHandling::Strip) { + auto start = it; + consumeSpaces(space_handling); + std::smatch match; + if (std::regex_search(it, end, match, regex) && match.position() == 0) { + it += match[0].length(); + return match[0].str(); + } + it = start; + return ""; + } + + std::string consumeToken(const std::string & token, SpaceHandling space_handling = SpaceHandling::Strip) { + auto start = it; + consumeSpaces(space_handling); + if (std::distance(it, end) >= (int64_t) token.size() && std::string(it, it + token.size()) == token) { + it += token.size(); + return token; + } + it = start; + return ""; + } + + std::shared_ptr parseExpression(bool allow_if_expr = true) { + auto left = parseLogicalOr(); + if (it == end) return left; + + if (!allow_if_expr) return left; + + static std::regex if_tok(R"(if\b)"); + if (consumeToken(if_tok).empty()) { + return left; + } + + auto location = get_location(); + auto [condition, else_expr] = parseIfExpression(); + return std::make_shared(location, std::move(condition), std::move(left), std::move(else_expr)); + } + + Location get_location() const { + return {template_str, (size_t) std::distance(start, it)}; + } + + std::pair, std::shared_ptr> parseIfExpression() { + auto condition = parseLogicalOr(); + if (!condition) throw std::runtime_error("Expected condition expression"); + + static std::regex else_tok(R"(else\b)"); + std::shared_ptr else_expr; + if (!consumeToken(else_tok).empty()) { + else_expr = parseExpression(); + if (!else_expr) throw std::runtime_error("Expected 'else' expression"); + } + return std::pair(std::move(condition), std::move(else_expr)); + } + + std::shared_ptr parseLogicalOr() { + auto left = parseLogicalAnd(); + if (!left) throw std::runtime_error("Expected left side of 'logical or' expression"); + + static std::regex or_tok(R"(or\b)"); + auto location = get_location(); + while (!consumeToken(or_tok).empty()) { + auto right = parseLogicalAnd(); + if (!right) throw std::runtime_error("Expected right side of 'or' expression"); + left = std::make_shared(location, std::move(left), std::move(right), BinaryOpExpr::Op::Or); + } + return left; + } + + std::shared_ptr parseLogicalNot() { + static std::regex not_tok(R"(not\b)"); + auto location = get_location(); + + if (!consumeToken(not_tok).empty()) { + auto sub = parseLogicalNot(); + if (!sub) throw std::runtime_error("Expected expression after 'not' keyword"); + return std::make_shared(location, std::move(sub), UnaryOpExpr::Op::LogicalNot); + } + return parseLogicalCompare(); + } + + std::shared_ptr parseLogicalAnd() { + auto left = parseLogicalNot(); + if (!left) throw std::runtime_error("Expected left side of 'logical and' expression"); + + static std::regex and_tok(R"(and\b)"); + auto location = get_location(); + while (!consumeToken(and_tok).empty()) { + auto right = parseLogicalNot(); + if (!right) throw std::runtime_error("Expected right side of 'and' expression"); + left = std::make_shared(location, std::move(left), std::move(right), BinaryOpExpr::Op::And); + } + return left; + } + + std::shared_ptr parseLogicalCompare() { + auto left = parseStringConcat(); + if (!left) throw std::runtime_error("Expected left side of 'logical compare' expression"); + + static std::regex compare_tok(R"(==|!=|<=?|>=?|in\b|is\b|not\s+in\b)"); + static std::regex not_tok(R"(not\b)"); + std::string op_str; + while (!(op_str = consumeToken(compare_tok)).empty()) { + auto location = get_location(); + if (op_str == "is") { + auto negated = !consumeToken(not_tok).empty(); + + auto identifier = parseIdentifier(); + if (!identifier) throw std::runtime_error("Expected identifier after 'is' keyword"); + + return std::make_shared( + left->location, + std::move(left), std::move(identifier), + negated ? BinaryOpExpr::Op::IsNot : BinaryOpExpr::Op::Is); + } + auto right = parseStringConcat(); + if (!right) throw std::runtime_error("Expected right side of 'logical compare' expression"); + BinaryOpExpr::Op op; + if (op_str == "==") op = BinaryOpExpr::Op::Eq; + else if (op_str == "!=") op = BinaryOpExpr::Op::Ne; + else if (op_str == "<") op = BinaryOpExpr::Op::Lt; + else if (op_str == ">") op = BinaryOpExpr::Op::Gt; + else if (op_str == "<=") op = BinaryOpExpr::Op::Le; + else if (op_str == ">=") op = BinaryOpExpr::Op::Ge; + else if (op_str == "in") op = BinaryOpExpr::Op::In; + else if (op_str.substr(0, 3) == "not") op = BinaryOpExpr::Op::NotIn; + else throw std::runtime_error("Unknown comparison operator: " + op_str); + left = std::make_shared(get_location(), std::move(left), std::move(right), op); + } + return left; + } + + Expression::Parameters parseParameters() { + consumeSpaces(); + if (consumeToken("(").empty()) throw std::runtime_error("Expected opening parenthesis in param list"); + + Expression::Parameters result; + + while (it != end) { + if (!consumeToken(")").empty()) { + return result; + } + auto expr = parseExpression(); + if (!expr) throw std::runtime_error("Expected expression in call args"); + + if (auto ident = dynamic_cast(expr.get())) { + if (!consumeToken("=").empty()) { + auto value = parseExpression(); + if (!value) throw std::runtime_error("Expected expression in for named arg"); + result.emplace_back(ident->get_name(), std::move(value)); + } else { + result.emplace_back(ident->get_name(), nullptr); + } + } else { + result.emplace_back(std::string(), std::move(expr)); + } + if (consumeToken(",").empty()) { + if (consumeToken(")").empty()) { + throw std::runtime_error("Expected closing parenthesis in call args"); + } + return result; + } + } + throw std::runtime_error("Expected closing parenthesis in call args"); + } + + ArgumentsExpression parseCallArgs() { + consumeSpaces(); + if (consumeToken("(").empty()) throw std::runtime_error("Expected opening parenthesis in call args"); + + ArgumentsExpression result; + + while (it != end) { + if (!consumeToken(")").empty()) { + return result; + } + auto expr = parseExpression(); + if (!expr) throw std::runtime_error("Expected expression in call args"); + + if (auto ident = dynamic_cast(expr.get())) { + if (!consumeToken("=").empty()) { + auto value = parseExpression(); + if (!value) throw std::runtime_error("Expected expression in for named arg"); + result.kwargs.emplace_back(ident->get_name(), std::move(value)); + } else { + result.args.emplace_back(std::move(expr)); + } + } else { + result.args.emplace_back(std::move(expr)); + } + if (consumeToken(",").empty()) { + if (consumeToken(")").empty()) { + throw std::runtime_error("Expected closing parenthesis in call args"); + } + return result; + } + } + throw std::runtime_error("Expected closing parenthesis in call args"); + } + + std::shared_ptr parseIdentifier() { + static std::regex ident_regex(R"((?!(?:not|is|and|or|del)\b)[a-zA-Z_]\w*)"); + auto location = get_location(); + auto ident = consumeToken(ident_regex); + if (ident.empty()) + return nullptr; + return std::make_shared(location, ident); + } + + std::shared_ptr parseStringConcat() { + auto left = parseMathPow(); + if (!left) throw std::runtime_error("Expected left side of 'string concat' expression"); + + static std::regex concat_tok(R"(~(?!\}))"); + if (!consumeToken(concat_tok).empty()) { + auto right = parseLogicalAnd(); + if (!right) throw std::runtime_error("Expected right side of 'string concat' expression"); + left = std::make_shared(get_location(), std::move(left), std::move(right), BinaryOpExpr::Op::StrConcat); + } + return left; + } + + std::shared_ptr parseMathPow() { + auto left = parseMathPlusMinus(); + if (!left) throw std::runtime_error("Expected left side of 'math pow' expression"); + + while (!consumeToken("**").empty()) { + auto right = parseMathPlusMinus(); + if (!right) throw std::runtime_error("Expected right side of 'math pow' expression"); + left = std::make_shared(get_location(), std::move(left), std::move(right), BinaryOpExpr::Op::MulMul); + } + return left; + } + + std::shared_ptr parseMathPlusMinus() { + static std::regex plus_minus_tok(R"(\+|-(?![}%#]\}))"); + + auto left = parseMathMulDiv(); + if (!left) throw std::runtime_error("Expected left side of 'math plus/minus' expression"); + std::string op_str; + while (!(op_str = consumeToken(plus_minus_tok)).empty()) { + auto right = parseMathMulDiv(); + if (!right) throw std::runtime_error("Expected right side of 'math plus/minus' expression"); + auto op = op_str == "+" ? BinaryOpExpr::Op::Add : BinaryOpExpr::Op::Sub; + left = std::make_shared(get_location(), std::move(left), std::move(right), op); + } + return left; + } + + std::shared_ptr parseMathMulDiv() { + auto left = parseMathUnaryPlusMinus(); + if (!left) throw std::runtime_error("Expected left side of 'math mul/div' expression"); + + static std::regex mul_div_tok(R"(\*\*?|//?|%(?!\}))"); + std::string op_str; + while (!(op_str = consumeToken(mul_div_tok)).empty()) { + auto right = parseMathUnaryPlusMinus(); + if (!right) throw std::runtime_error("Expected right side of 'math mul/div' expression"); + auto op = op_str == "*" ? BinaryOpExpr::Op::Mul + : op_str == "**" ? BinaryOpExpr::Op::MulMul + : op_str == "/" ? BinaryOpExpr::Op::Div + : op_str == "//" ? BinaryOpExpr::Op::DivDiv + : BinaryOpExpr::Op::Mod; + left = std::make_shared(get_location(), std::move(left), std::move(right), op); + } + + if (!consumeToken("|").empty()) { + auto expr = parseMathMulDiv(); + if (auto filter = dynamic_cast(expr.get())) { + filter->prepend(std::move(left)); + return expr; + } else { + std::vector> parts; + parts.emplace_back(std::move(left)); + parts.emplace_back(std::move(expr)); + return std::make_shared(get_location(), std::move(parts)); + } + } + return left; + } + + std::shared_ptr call_func(const std::string & name, ArgumentsExpression && args) const { + return std::make_shared(get_location(), std::make_shared(get_location(), name), std::move(args)); + } + + std::shared_ptr parseMathUnaryPlusMinus() { + static std::regex unary_plus_minus_tok(R"(\+|-(?![}%#]\}))"); + auto op_str = consumeToken(unary_plus_minus_tok); + auto expr = parseExpansion(); + if (!expr) throw std::runtime_error("Expected expr of 'unary plus/minus/expansion' expression"); + + if (!op_str.empty()) { + auto op = op_str == "+" ? UnaryOpExpr::Op::Plus : UnaryOpExpr::Op::Minus; + return std::make_shared(get_location(), std::move(expr), op); + } + return expr; + } + + std::shared_ptr parseExpansion() { + static std::regex expansion_tok(R"(\*\*?)"); + auto op_str = consumeToken(expansion_tok); + auto expr = parseValueExpression(); + if (op_str.empty()) return expr; + if (!expr) throw std::runtime_error("Expected expr of 'expansion' expression"); + return std::make_shared(get_location(), std::move(expr), op_str == "*" ? UnaryOpExpr::Op::Expansion : UnaryOpExpr::Op::ExpansionDict); + } + + std::shared_ptr parseValueExpression() { + auto parseValue = [&]() -> std::shared_ptr { + auto location = get_location(); + auto constant = parseConstant(); + if (constant) return std::make_shared(location, *constant); + + static std::regex null_regex(R"(null\b)"); + if (!consumeToken(null_regex).empty()) return std::make_shared(location, Value()); + + auto identifier = parseIdentifier(); + if (identifier) return identifier; + + auto braced = parseBracedExpressionOrArray(); + if (braced) return braced; + + auto array = parseArray(); + if (array) return array; + + auto dictionary = parseDictionary(); + if (dictionary) return dictionary; + + throw std::runtime_error("Expected value expression"); + }; + + auto value = parseValue(); + + while (it != end && consumeSpaces() && peekSymbols({ "[", "." })) { + if (!consumeToken("[").empty()) { + std::shared_ptr index; + if (!consumeToken(":").empty()) { + auto slice_end = parseExpression(); + index = std::make_shared(slice_end->location, nullptr, std::move(slice_end)); + } else { + auto slice_start = parseExpression(); + if (!consumeToken(":").empty()) { + consumeSpaces(); + if (peekSymbols({ "]" })) { + index = std::make_shared(slice_start->location, std::move(slice_start), nullptr); + } else { + auto slice_end = parseExpression(); + index = std::make_shared(slice_start->location, std::move(slice_start), std::move(slice_end)); + } + } else { + index = std::move(slice_start); + } + } + if (!index) throw std::runtime_error("Empty index in subscript"); + if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript"); + + value = std::make_shared(value->location, std::move(value), std::move(index)); + } else if (!consumeToken(".").empty()) { + auto identifier = parseIdentifier(); + if (!identifier) throw std::runtime_error("Expected identifier in subscript"); + + consumeSpaces(); + if (peekSymbols({ "(" })) { + auto callParams = parseCallArgs(); + value = std::make_shared(identifier->location, std::move(value), std::move(identifier), std::move(callParams)); + } else { + auto key = std::make_shared(identifier->location, Value(identifier->get_name())); + value = std::make_shared(identifier->location, std::move(value), std::move(key)); + } + } + consumeSpaces(); + } + + if (peekSymbols({ "(" })) { + auto location = get_location(); + auto callParams = parseCallArgs(); + value = std::make_shared(location, std::move(value), std::move(callParams)); + } + return value; + } + + std::shared_ptr parseBracedExpressionOrArray() { + if (consumeToken("(").empty()) return nullptr; + + auto expr = parseExpression(); + if (!expr) throw std::runtime_error("Expected expression in braced expression"); + + if (!consumeToken(")").empty()) { + return expr; // Drop the parentheses + } + + std::vector> tuple; + tuple.emplace_back(std::move(expr)); + + while (it != end) { + if (consumeToken(",").empty()) throw std::runtime_error("Expected comma in tuple"); + auto next = parseExpression(); + if (!next) throw std::runtime_error("Expected expression in tuple"); + tuple.push_back(std::move(next)); + + if (!consumeToken(")").empty()) { + return std::make_shared(get_location(), std::move(tuple)); + } + } + throw std::runtime_error("Expected closing parenthesis"); + } + + std::shared_ptr parseArray() { + if (consumeToken("[").empty()) return nullptr; + + std::vector> elements; + if (!consumeToken("]").empty()) { + return std::make_shared(get_location(), std::move(elements)); + } + auto first_expr = parseExpression(); + if (!first_expr) throw std::runtime_error("Expected first expression in array"); + elements.push_back(std::move(first_expr)); + + while (it != end) { + if (!consumeToken(",").empty()) { + auto expr = parseExpression(); + if (!expr) throw std::runtime_error("Expected expression in array"); + elements.push_back(std::move(expr)); + } else if (!consumeToken("]").empty()) { + return std::make_shared(get_location(), std::move(elements)); + } else { + throw std::runtime_error("Expected comma or closing bracket in array"); + } + } + throw std::runtime_error("Expected closing bracket"); + } + + std::shared_ptr parseDictionary() { + if (consumeToken("{").empty()) return nullptr; + + std::vector, std::shared_ptr>> elements; + if (!consumeToken("}").empty()) { + return std::make_shared(get_location(), std::move(elements)); + } + + auto parseKeyValuePair = [&]() { + auto key = parseExpression(); + if (!key) throw std::runtime_error("Expected key in dictionary"); + if (consumeToken(":").empty()) throw std::runtime_error("Expected colon betweek key & value in dictionary"); + auto value = parseExpression(); + if (!value) throw std::runtime_error("Expected value in dictionary"); + elements.emplace_back(std::pair(std::move(key), std::move(value))); + }; + + parseKeyValuePair(); + + while (it != end) { + if (!consumeToken(",").empty()) { + parseKeyValuePair(); + } else if (!consumeToken("}").empty()) { + return std::make_shared(get_location(), std::move(elements)); + } else { + throw std::runtime_error("Expected comma or closing brace in dictionary"); + } + } + throw std::runtime_error("Expected closing brace"); + } + + SpaceHandling parsePreSpace(const std::string& s) const { + if (s == "-") + return SpaceHandling::Strip; + return SpaceHandling::Keep; + } + + SpaceHandling parsePostSpace(const std::string& s) const { + if (s == "-") return SpaceHandling::Strip; + return SpaceHandling::Keep; + } + + using TemplateTokenVector = std::vector>; + using TemplateTokenIterator = TemplateTokenVector::const_iterator; + + std::vector parseVarNames() { + static std::regex varnames_regex(R"(((?:\w+)(?:\s*,\s*(?:\w+))*)\s*)"); + + std::vector group; + if ((group = consumeTokenGroups(varnames_regex)).empty()) throw std::runtime_error("Expected variable names"); + std::vector varnames; + std::istringstream iss(group[1]); + std::string varname; + while (std::getline(iss, varname, ',')) { + varnames.push_back(strip(varname)); + } + return varnames; + } + + std::runtime_error unexpected(const TemplateToken & token) const { + return std::runtime_error("Unexpected " + TemplateToken::typeToString(token.type) + + error_location_suffix(*template_str, token.location.pos)); + } + std::runtime_error unterminated(const TemplateToken & token) const { + return std::runtime_error("Unterminated " + TemplateToken::typeToString(token.type) + + error_location_suffix(*template_str, token.location.pos)); + } + + TemplateTokenVector tokenize() { + static std::regex comment_tok(R"(\{#([-~]?)([\s\S]*?)([-~]?)#\})"); + static std::regex expr_open_regex(R"(\{\{([-~])?)"); + static std::regex block_open_regex(R"(^\{%([-~])?\s*)"); + static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue)\b)"); + static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)"); + static std::regex expr_close_regex(R"(\s*([-~])?\}\})"); + static std::regex block_close_regex(R"(\s*([-~])?%\})"); + + TemplateTokenVector tokens; + std::vector group; + std::string text; + std::smatch match; + + try { + while (it != end) { + auto location = get_location(); + + if (!(group = consumeTokenGroups(comment_tok, SpaceHandling::Keep)).empty()) { + auto pre_space = parsePreSpace(group[1]); + auto content = group[2]; + auto post_space = parsePostSpace(group[3]); + tokens.push_back(std::make_unique(location, pre_space, post_space, content)); + } else if (!(group = consumeTokenGroups(expr_open_regex, SpaceHandling::Keep)).empty()) { + auto pre_space = parsePreSpace(group[1]); + auto expr = parseExpression(); + + if ((group = consumeTokenGroups(expr_close_regex)).empty()) { + throw std::runtime_error("Expected closing expression tag"); + } + + auto post_space = parsePostSpace(group[1]); + tokens.push_back(std::make_unique(location, pre_space, post_space, std::move(expr))); + } else if (!(group = consumeTokenGroups(block_open_regex, SpaceHandling::Keep)).empty()) { + auto pre_space = parsePreSpace(group[1]); + + std::string keyword; + + auto parseBlockClose = [&]() -> SpaceHandling { + if ((group = consumeTokenGroups(block_close_regex)).empty()) throw std::runtime_error("Expected closing block tag"); + return parsePostSpace(group[1]); + }; + + if ((keyword = consumeToken(block_keyword_tok)).empty()) throw std::runtime_error("Expected block keyword"); + + if (keyword == "if") { + auto condition = parseExpression(); + if (!condition) throw std::runtime_error("Expected condition in if block"); + + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space, std::move(condition))); + } else if (keyword == "elif") { + auto condition = parseExpression(); + if (!condition) throw std::runtime_error("Expected condition in elif block"); + + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space, std::move(condition))); + } else if (keyword == "else") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space)); + } else if (keyword == "endif") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space)); + } else if (keyword == "for") { + static std::regex recursive_tok(R"(recursive\b)"); + static std::regex if_tok(R"(if\b)"); + + auto varnames = parseVarNames(); + static std::regex in_tok(R"(in\b)"); + if (consumeToken(in_tok).empty()) throw std::runtime_error("Expected 'in' keyword in for block"); + auto iterable = parseExpression(/* allow_if_expr = */ false); + if (!iterable) throw std::runtime_error("Expected iterable in for block"); + + std::shared_ptr condition; + if (!consumeToken(if_tok).empty()) { + condition = parseExpression(); + } + auto recursive = !consumeToken(recursive_tok).empty(); + + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space, std::move(varnames), std::move(iterable), std::move(condition), recursive)); + } else if (keyword == "endfor") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space)); + } else if (keyword == "generation") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space)); + } else if (keyword == "endgeneration") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space)); + } else if (keyword == "set") { + static std::regex namespaced_var_regex(R"((\w+)\s*\.\s*(\w+))"); + + std::string ns; + std::vector var_names; + std::shared_ptr value; + if (!(group = consumeTokenGroups(namespaced_var_regex)).empty()) { + ns = group[1]; + var_names.push_back(group[2]); + + if (consumeToken("=").empty()) throw std::runtime_error("Expected equals sign in set block"); + + value = parseExpression(); + if (!value) throw std::runtime_error("Expected value in set block"); + } else { + var_names = parseVarNames(); + + if (!consumeToken("=").empty()) { + value = parseExpression(); + if (!value) throw std::runtime_error("Expected value in set block"); + } + } + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space, ns, var_names, std::move(value))); + } else if (keyword == "endset") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space)); + } else if (keyword == "macro") { + auto macroname = parseIdentifier(); + if (!macroname) throw std::runtime_error("Expected macro name in macro block"); + auto params = parseParameters(); + + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space, std::move(macroname), std::move(params))); + } else if (keyword == "endmacro") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space)); + } else if (keyword == "filter") { + auto filter = parseExpression(); + if (!filter) throw std::runtime_error("Expected expression in filter block"); + + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space, std::move(filter))); + } else if (keyword == "endfilter") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space)); + } else if (keyword == "break" || keyword == "continue") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space, keyword == "break" ? LoopControlType::Break : LoopControlType::Continue)); + } else { + throw std::runtime_error("Unexpected block: " + keyword); + } + } else if (std::regex_search(it, end, match, non_text_open_regex)) { + if (!match.position()) { + if (match[0] != "{#") + throw std::runtime_error("Internal error: Expected a comment"); + throw std::runtime_error("Missing end of comment tag"); + } + auto text_end = it + match.position(); + text = std::string(it, text_end); + it = text_end; + tokens.push_back(std::make_unique(location, SpaceHandling::Keep, SpaceHandling::Keep, text)); + } else { + text = std::string(it, end); + it = end; + tokens.push_back(std::make_unique(location, SpaceHandling::Keep, SpaceHandling::Keep, text)); + } + } + return tokens; + } catch (const std::exception & e) { + throw std::runtime_error(e.what() + error_location_suffix(*template_str, std::distance(start, it))); + } + } + + std::shared_ptr parseTemplate( + const TemplateTokenIterator & begin, + TemplateTokenIterator & it, + const TemplateTokenIterator & end, + bool fully = false) const { + std::vector> children; + while (it != end) { + const auto start = it; + const auto & token = *(it++); + if (auto if_token = dynamic_cast(token.get())) { + std::vector, std::shared_ptr>> cascade; + cascade.emplace_back(std::move(if_token->condition), parseTemplate(begin, it, end)); + + while (it != end && (*it)->type == TemplateToken::Type::Elif) { + auto elif_token = dynamic_cast((*(it++)).get()); + cascade.emplace_back(std::move(elif_token->condition), parseTemplate(begin, it, end)); + } + + if (it != end && (*it)->type == TemplateToken::Type::Else) { + cascade.emplace_back(nullptr, parseTemplate(begin, ++it, end)); + } + if (it == end || (*(it++))->type != TemplateToken::Type::EndIf) { + throw unterminated(**start); + } + children.emplace_back(std::make_shared(token->location, std::move(cascade))); + } else if (auto for_token = dynamic_cast(token.get())) { + auto body = parseTemplate(begin, it, end); + auto else_body = std::shared_ptr(); + if (it != end && (*it)->type == TemplateToken::Type::Else) { + else_body = parseTemplate(begin, ++it, end); + } + if (it == end || (*(it++))->type != TemplateToken::Type::EndFor) { + throw unterminated(**start); + } + children.emplace_back(std::make_shared(token->location, std::move(for_token->var_names), std::move(for_token->iterable), std::move(for_token->condition), std::move(body), for_token->recursive, std::move(else_body))); + } else if (dynamic_cast(token.get())) { + auto body = parseTemplate(begin, it, end); + if (it == end || (*(it++))->type != TemplateToken::Type::EndGeneration) { + throw unterminated(**start); + } + // Treat as a no-op, as our scope is templates for inference, not training (`{% generation %}` wraps generated tokens for masking). + children.emplace_back(std::move(body)); + } else if (auto text_token = dynamic_cast(token.get())) { + SpaceHandling pre_space = (it - 1) != begin ? (*(it - 2))->post_space : SpaceHandling::Keep; + SpaceHandling post_space = it != end ? (*it)->pre_space : SpaceHandling::Keep; + + auto text = text_token->text; + if (post_space == SpaceHandling::Strip) { + static std::regex trailing_space_regex(R"(\s+$)"); + text = std::regex_replace(text, trailing_space_regex, ""); + } else if (options.lstrip_blocks && it != end) { + auto i = text.size(); + while (i > 0 && (text[i - 1] == ' ' || text[i - 1] == '\t')) i--; + if ((i == 0 && (it - 1) == begin) || (i > 0 && text[i - 1] == '\n')) { + text.resize(i); + } + } + if (pre_space == SpaceHandling::Strip) { + static std::regex leading_space_regex(R"(^\s+)"); + text = std::regex_replace(text, leading_space_regex, ""); + } else if (options.trim_blocks && (it - 1) != begin && !dynamic_cast((*(it - 2)).get())) { + if (text.length() > 0 && text[0] == '\n') { + text.erase(0, 1); + } + } + if (it == end && !options.keep_trailing_newline) { + auto i = text.size(); + if (i > 0 && text[i - 1] == '\n') { + i--; + if (i > 0 && text[i - 1] == '\r') i--; + text.resize(i); + } + } + children.emplace_back(std::make_shared(token->location, text)); + } else if (auto expr_token = dynamic_cast(token.get())) { + children.emplace_back(std::make_shared(token->location, std::move(expr_token->expr))); + } else if (auto set_token = dynamic_cast(token.get())) { + if (set_token->value) { + children.emplace_back(std::make_shared(token->location, set_token->ns, set_token->var_names, std::move(set_token->value))); + } else { + auto value_template = parseTemplate(begin, it, end); + if (it == end || (*(it++))->type != TemplateToken::Type::EndSet) { + throw unterminated(**start); + } + if (!set_token->ns.empty()) throw std::runtime_error("Namespaced set not supported in set with template value"); + if (set_token->var_names.size() != 1) throw std::runtime_error("Structural assignment not supported in set with template value"); + auto & name = set_token->var_names[0]; + children.emplace_back(std::make_shared(token->location, name, std::move(value_template))); + } + } else if (auto macro_token = dynamic_cast(token.get())) { + auto body = parseTemplate(begin, it, end); + if (it == end || (*(it++))->type != TemplateToken::Type::EndMacro) { + throw unterminated(**start); + } + children.emplace_back(std::make_shared(token->location, std::move(macro_token->name), std::move(macro_token->params), std::move(body))); + } else if (auto filter_token = dynamic_cast(token.get())) { + auto body = parseTemplate(begin, it, end); + if (it == end || (*(it++))->type != TemplateToken::Type::EndFilter) { + throw unterminated(**start); + } + children.emplace_back(std::make_shared(token->location, std::move(filter_token->filter), std::move(body))); + } else if (dynamic_cast(token.get())) { + // Ignore comments + } else if (auto ctrl_token = dynamic_cast(token.get())) { + children.emplace_back(std::make_shared(token->location, ctrl_token->control_type)); + } else if (dynamic_cast(token.get()) + || dynamic_cast(token.get()) + || dynamic_cast(token.get()) + || dynamic_cast(token.get()) + || dynamic_cast(token.get()) + || dynamic_cast(token.get()) + || dynamic_cast(token.get()) + || dynamic_cast(token.get())) { + it--; // unconsume the token + break; // exit the loop + } else { + throw unexpected(**(it-1)); + } + } + if (fully && it != end) { + throw unexpected(**it); + } + if (children.empty()) { + return std::make_shared(Location { template_str, 0 }, std::string()); + } else if (children.size() == 1) { + return std::move(children[0]); + } else { + return std::make_shared(children[0]->location(), std::move(children)); + } + } + +public: + + static std::shared_ptr parse(const std::string& template_str, const Options & options) { + Parser parser(std::make_shared(normalize_newlines(template_str)), options); + auto tokens = parser.tokenize(); + TemplateTokenIterator begin = tokens.begin(); + auto it = begin; + TemplateTokenIterator end = tokens.end(); + return parser.parseTemplate(begin, it, end, /* full= */ true); + } +}; + +static Value simple_function(const std::string & fn_name, const std::vector & params, const std::function &, Value & args)> & fn) { + std::map named_positions; + for (size_t i = 0, n = params.size(); i < n; i++) named_positions[params[i]] = i; + + return Value::callable([=](const std::shared_ptr & context, ArgumentsValue & args) -> Value { + auto args_obj = Value::object(); + std::vector provided_args(params.size()); + for (size_t i = 0, n = args.args.size(); i < n; i++) { + auto & arg = args.args[i]; + if (i < params.size()) { + args_obj.set(params[i], arg); + provided_args[i] = true; + } else { + throw std::runtime_error("Too many positional params for " + fn_name); + } + } + for (auto & [name, value] : args.kwargs) { + auto named_pos_it = named_positions.find(name); + if (named_pos_it == named_positions.end()) { + throw std::runtime_error("Unknown argument " + name + " for function " + fn_name); + } + provided_args[named_pos_it->second] = true; + args_obj.set(name, value); + } + return fn(context, args_obj); + }); +} + +inline std::shared_ptr Context::builtins() { + auto globals = Value::object(); + + globals.set("raise_exception", simple_function("raise_exception", { "message" }, [](const std::shared_ptr &, Value & args) -> Value { + throw std::runtime_error(args.at("message").get()); + })); + globals.set("tojson", simple_function("tojson", { "value", "indent" }, [](const std::shared_ptr &, Value & args) { + return Value(args.at("value").dump(args.get("indent", -1), /* tojson= */ true)); + })); + globals.set("items", simple_function("items", { "object" }, [](const std::shared_ptr &, Value & args) { + auto items = Value::array(); + if (args.contains("object")) { + auto & obj = args.at("object"); + if (obj.is_string()) { + auto json_obj = json::parse(obj.get()); + for (const auto & kv : json_obj.items()) { + items.push_back(Value::array({kv.key(), kv.value()})); + } + } else if (!obj.is_null()) { + for (auto & key : obj.keys()) { + items.push_back(Value::array({key, obj.at(key)})); + } + } + } + return items; + })); + globals.set("last", simple_function("last", { "items" }, [](const std::shared_ptr &, Value & args) { + auto items = args.at("items"); + if (!items.is_array()) throw std::runtime_error("object is not a list"); + if (items.size() == 0) return Value(); + return items.at(items.size() - 1); + })); + globals.set("trim", simple_function("trim", { "text" }, [](const std::shared_ptr &, Value & args) { + auto & text = args.at("text"); + return text.is_null() ? text : Value(strip(text.get())); + })); + globals.set("lower", simple_function("lower", { "text" }, [](const std::shared_ptr &, Value & args) { + auto text = args.at("text"); + if (text.is_null()) return text; + std::string res; + auto str = text.get(); + std::transform(str.begin(), str.end(), std::back_inserter(res), ::tolower); + return Value(res); + })); + globals.set("default", Value::callable([=](const std::shared_ptr &, ArgumentsValue & args) { + args.expectArgs("default", {2, 3}, {0, 1}); + auto & value = args.args[0]; + auto & default_value = args.args[1]; + bool boolean = false; + if (args.args.size() == 3) { + boolean = args.args[2].get(); + } else { + Value bv = args.get_named("boolean"); + if (!bv.is_null()) { + boolean = bv.get(); + } + } + return boolean ? (value.to_bool() ? value : default_value) : value.is_null() ? default_value : value; + })); + auto escape = simple_function("escape", { "text" }, [](const std::shared_ptr &, Value & args) { + return Value(html_escape(args.at("text").get())); + }); + globals.set("e", escape); + globals.set("escape", escape); + globals.set("joiner", simple_function("joiner", { "sep" }, [](const std::shared_ptr &, Value & args) { + auto sep = args.get("sep", ""); + auto first = std::make_shared(true); + return simple_function("", {}, [sep, first](const std::shared_ptr &, const Value &) -> Value { + if (*first) { + *first = false; + return ""; + } + return sep; + }); + return Value(html_escape(args.at("text").get())); + })); + globals.set("count", simple_function("count", { "items" }, [](const std::shared_ptr &, Value & args) { + return Value((int64_t) args.at("items").size()); + })); + globals.set("dictsort", simple_function("dictsort", { "value" }, [](const std::shared_ptr &, Value & args) { + if (args.size() != 1) throw std::runtime_error("dictsort expects exactly 1 argument (TODO: fix implementation)"); + auto & value = args.at("value"); + auto keys = value.keys(); + std::sort(keys.begin(), keys.end()); + auto res = Value::array(); + for (auto & key : keys) { + res.push_back(Value::array({key, value.at(key)})); + } + return res; + })); + globals.set("join", simple_function("join", { "items", "d" }, [](const std::shared_ptr &, Value & args) { + auto do_join = [](Value & items, const std::string & sep) { + if (!items.is_array()) throw std::runtime_error("object is not iterable: " + items.dump()); + std::ostringstream oss; + auto first = true; + for (size_t i = 0, n = items.size(); i < n; ++i) { + if (first) first = false; + else oss << sep; + oss << items.at(i).to_str(); + } + return Value(oss.str()); + }; + auto sep = args.get("d", ""); + if (args.contains("items")) { + auto & items = args.at("items"); + return do_join(items, sep); + } else { + return simple_function("", {"items"}, [sep, do_join](const std::shared_ptr &, Value & args) { + auto & items = args.at("items"); + if (!items.to_bool() || !items.is_array()) throw std::runtime_error("join expects an array for items, got: " + items.dump()); + return do_join(items, sep); + }); + } + })); + globals.set("namespace", Value::callable([=](const std::shared_ptr &, ArgumentsValue & args) { + auto ns = Value::object(); + args.expectArgs("namespace", {0, 0}, {0, (std::numeric_limits::max)()}); + for (auto & [name, value] : args.kwargs) { + ns.set(name, value); + } + return ns; + })); + auto equalto = simple_function("equalto", { "expected", "actual" }, [](const std::shared_ptr &, Value & args) -> Value { + return args.at("actual") == args.at("expected"); + }); + globals.set("equalto", equalto); + globals.set("==", equalto); + globals.set("length", simple_function("length", { "items" }, [](const std::shared_ptr &, Value & args) -> Value { + auto & items = args.at("items"); + return (int64_t) items.size(); + })); + globals.set("safe", simple_function("safe", { "value" }, [](const std::shared_ptr &, Value & args) -> Value { + return args.at("value").to_str(); + })); + globals.set("string", simple_function("string", { "value" }, [](const std::shared_ptr &, Value & args) -> Value { + return args.at("value").to_str(); + })); + globals.set("int", simple_function("int", { "value" }, [](const std::shared_ptr &, Value & args) -> Value { + return args.at("value").to_int(); + })); + globals.set("list", simple_function("list", { "items" }, [](const std::shared_ptr &, Value & args) -> Value { + auto & items = args.at("items"); + if (!items.is_array()) throw std::runtime_error("object is not iterable"); + return items; + })); + globals.set("unique", simple_function("unique", { "items" }, [](const std::shared_ptr &, Value & args) -> Value { + auto & items = args.at("items"); + if (!items.is_array()) throw std::runtime_error("object is not iterable"); + std::unordered_set seen; + auto result = Value::array(); + for (size_t i = 0, n = items.size(); i < n; i++) { + auto pair = seen.insert(items.at(i)); + if (pair.second) { + result.push_back(items.at(i)); + } + } + return result; + })); + auto make_filter = [](const Value & filter, Value & extra_args) -> Value { + return simple_function("", { "value" }, [=](const std::shared_ptr & context, Value & args) { + auto & value = args.at("value"); + ArgumentsValue actual_args; + actual_args.args.emplace_back(value); + for (size_t i = 0, n = extra_args.size(); i < n; i++) { + actual_args.args.emplace_back(extra_args.at(i)); + } + return filter.call(context, actual_args); + }); + }; + auto select_or_reject = [make_filter](bool is_select) { + return Value::callable([=](const std::shared_ptr & context, ArgumentsValue & args) { + args.expectArgs(is_select ? "select" : "reject", {2, (std::numeric_limits::max)()}, {0, 0}); + auto & items = args.args[0]; + if (items.is_null()) + return Value::array(); + if (!items.is_array()) throw std::runtime_error("object is not iterable: " + items.dump()); + + auto filter_fn = context->get(args.args[1]); + if (filter_fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump()); + + auto filter_args = Value::array(); + for (size_t i = 2, n = args.args.size(); i < n; i++) { + filter_args.push_back(args.args[i]); + } + auto filter = make_filter(filter_fn, filter_args); + + auto res = Value::array(); + for (size_t i = 0, n = items.size(); i < n; i++) { + auto & item = items.at(i); + ArgumentsValue filter_args; + filter_args.args.emplace_back(item); + auto pred_res = filter.call(context, filter_args); + if (pred_res.to_bool() == (is_select ? true : false)) { + res.push_back(item); + } + } + return res; + }); + }; + globals.set("select", select_or_reject(/* is_select= */ true)); + globals.set("reject", select_or_reject(/* is_select= */ false)); + globals.set("map", Value::callable([=](const std::shared_ptr & context, ArgumentsValue & args) { + auto res = Value::array(); + if (args.args.size() == 1 && + ((args.has_named("attribute") && args.kwargs.size() == 1) || (args.has_named("default") && args.kwargs.size() == 2))) { + auto & items = args.args[0]; + auto attr_name = args.get_named("attribute"); + auto default_value = args.get_named("default"); + for (size_t i = 0, n = items.size(); i < n; i++) { + auto & item = items.at(i); + auto attr = item.get(attr_name); + res.push_back(attr.is_null() ? default_value : attr); + } + } else if (args.kwargs.empty() && args.args.size() >= 2) { + auto fn = context->get(args.args[1]); + if (fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump()); + ArgumentsValue filter_args { {Value()}, {} }; + for (size_t i = 2, n = args.args.size(); i < n; i++) { + filter_args.args.emplace_back(args.args[i]); + } + for (size_t i = 0, n = args.args[0].size(); i < n; i++) { + auto & item = args.args[0].at(i); + filter_args.args[0] = item; + res.push_back(fn.call(context, filter_args)); + } + } else { + throw std::runtime_error("Invalid or unsupported arguments for map"); + } + return res; + })); + globals.set("indent", simple_function("indent", { "text", "indent", "first" }, [](const std::shared_ptr &, Value & args) { + auto text = args.at("text").get(); + auto first = args.get("first", false); + std::string out; + std::string indent(args.get("indent", 0), ' '); + std::istringstream iss(text); + std::string line; + auto is_first = true; + while (std::getline(iss, line, '\n')) { + auto needs_indent = !is_first || first; + if (is_first) is_first = false; + else out += "\n"; + if (needs_indent) out += indent; + out += line; + } + if (!text.empty() && text.back() == '\n') out += "\n"; + return out; + })); + auto select_or_reject_attr = [](bool is_select) { + return Value::callable([=](const std::shared_ptr & context, ArgumentsValue & args) { + args.expectArgs(is_select ? "selectattr" : "rejectattr", {2, (std::numeric_limits::max)()}, {0, 0}); + auto & items = args.args[0]; + if (items.is_null()) + return Value::array(); + if (!items.is_array()) throw std::runtime_error("object is not iterable: " + items.dump()); + auto attr_name = args.args[1].get(); + + bool has_test = false; + Value test_fn; + ArgumentsValue test_args {{Value()}, {}}; + if (args.args.size() >= 3) { + has_test = true; + test_fn = context->get(args.args[2]); + if (test_fn.is_null()) throw std::runtime_error("Undefined test: " + args.args[2].dump()); + for (size_t i = 3, n = args.args.size(); i < n; i++) { + test_args.args.emplace_back(args.args[i]); + } + test_args.kwargs = args.kwargs; + } + + auto res = Value::array(); + for (size_t i = 0, n = items.size(); i < n; i++) { + auto & item = items.at(i); + auto attr = item.get(attr_name); + if (has_test) { + test_args.args[0] = attr; + if (test_fn.call(context, test_args).to_bool() == (is_select ? true : false)) { + res.push_back(item); + } + } else { + res.push_back(attr); + } + } + return res; + }); + }; + globals.set("selectattr", select_or_reject_attr(/* is_select= */ true)); + globals.set("rejectattr", select_or_reject_attr(/* is_select= */ false)); + globals.set("range", Value::callable([=](const std::shared_ptr &, ArgumentsValue & args) { + std::vector startEndStep(3); + std::vector param_set(3); + if (args.args.size() == 1) { + startEndStep[1] = args.args[0].get(); + param_set[1] = true; + } else { + for (size_t i = 0; i < args.args.size(); i++) { + auto & arg = args.args[i]; + auto v = arg.get(); + startEndStep[i] = v; + param_set[i] = true; + } + } + for (auto & [name, value] : args.kwargs) { + size_t i; + if (name == "start") i = 0; + else if (name == "end") i = 1; + else if (name == "step") i = 2; + else throw std::runtime_error("Unknown argument " + name + " for function range"); + + if (param_set[i]) { + throw std::runtime_error("Duplicate argument " + name + " for function range"); + } + startEndStep[i] = value.get(); + param_set[i] = true; + } + if (!param_set[1]) { + throw std::runtime_error("Missing required argument 'end' for function range"); + } + int64_t start = param_set[0] ? startEndStep[0] : 0; + int64_t end = startEndStep[1]; + int64_t step = param_set[2] ? startEndStep[2] : 1; + + auto res = Value::array(); + if (step > 0) { + for (int64_t i = start; i < end; i += step) { + res.push_back(Value(i)); + } + } else { + for (int64_t i = start; i > end; i += step) { + res.push_back(Value(i)); + } + } + return res; + })); + + return std::make_shared(std::move(globals)); +} + +inline std::shared_ptr Context::make(Value && values, const std::shared_ptr & parent) { + return std::make_shared(values.is_null() ? Value::object() : std::move(values), parent); +} + +} // namespace minja diff --git a/examples/main/main.cpp b/examples/main/main.cpp index e654d354..cf8659b0 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -4,7 +4,7 @@ #include "log.h" #include "sampling.h" #include "llama.h" -#include "chat-template.hpp" +#include "chat.h" #include #include @@ -158,7 +158,7 @@ int main(int argc, char ** argv) { } const llama_vocab * vocab = llama_model_get_vocab(model); - auto chat_templates = common_chat_templates_from_model(model, params.chat_template); + auto chat_templates = common_chat_templates_init(model, params.chat_template); LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads); @@ -201,7 +201,7 @@ int main(int argc, char ** argv) { } // auto enable conversation mode if chat template is available - const bool has_chat_template = chat_templates.has_explicit_template && chat_templates.template_default; + const bool has_chat_template = common_chat_templates_was_explicit(chat_templates.get()); if (params.conversation_mode == COMMON_CONVERSATION_MODE_AUTO) { if (has_chat_template) { LOG_INF("%s: chat template is available, enabling conversation mode (disable it with -no-cnv)\n", __func__); @@ -219,7 +219,7 @@ int main(int argc, char ** argv) { // print chat template example in conversation mode if (params.conversation_mode) { if (params.enable_chat_template) { - LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(*chat_templates.template_default, params.use_jinja).c_str()); + LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(chat_templates.get(), params.use_jinja).c_str()); } else { LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__); } @@ -264,9 +264,11 @@ int main(int argc, char ** argv) { std::vector embd_inp; auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) { - common_chat_msg new_msg{role, content, {}}; - auto formatted = common_chat_format_single(*chat_templates.template_default, chat_msgs, new_msg, role == "user", g_params->use_jinja); - chat_msgs.push_back({role, content, {}}); + common_chat_msg new_msg; + new_msg.role = role; + new_msg.content = content; + auto formatted = common_chat_format_single(chat_templates.get(), chat_msgs, new_msg, role == "user", g_params->use_jinja); + chat_msgs.push_back(new_msg); LOG_DBG("formatted: '%s'\n", formatted.c_str()); return formatted; }; @@ -755,11 +757,14 @@ int main(int argc, char ** argv) { // check for reverse prompt using special tokens llama_token last_token = common_sampler_last(smpl); - if (std::find(antiprompt_token.begin(), antiprompt_token.end(), last_token) != antiprompt_token.end()) { - if (params.interactive) { - is_interacting = true; + for (auto token : antiprompt_token) { + if (token == last_token) { + if (params.interactive) { + is_interacting = true; + } + is_antiprompt = true; + break; } - is_antiprompt = true; } if (is_antiprompt) { diff --git a/examples/run/run.cpp b/examples/run/run.cpp index 9362da22..ed8644ef 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -24,7 +24,7 @@ #include #include -#include "chat-template.hpp" +#include "chat.h" #include "common.h" #include "json.hpp" #include "linenoise.cpp/linenoise.h" @@ -557,7 +557,7 @@ class LlamaData { llama_model_ptr model; llama_sampler_ptr sampler; llama_context_ptr context; - std::vector messages; + std::vector messages; // TODO: switch to common_chat_msg std::list msg_strs; std::vector fmtted; @@ -834,44 +834,23 @@ static void add_message(const char * role, const std::string & text, LlamaData & } // Function to apply the chat template and resize `formatted` if needed -static int apply_chat_template(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) { - if (use_jinja) { - json messages = json::array(); - for (const auto & msg : llama_data.messages) { - messages.push_back({ - {"role", msg.role}, - {"content", msg.content}, - }); - } - try { - minja::chat_template_inputs tmpl_inputs; - tmpl_inputs.messages = messages; - tmpl_inputs.add_generation_prompt = append; - - minja::chat_template_options tmpl_opts; - tmpl_opts.use_bos_token = false; - tmpl_opts.use_eos_token = false; - - auto result = tmpl.apply(tmpl_inputs, tmpl_opts); - llama_data.fmtted.resize(result.size() + 1); - memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1); - return result.size(); - } catch (const std::exception & e) { - printe("failed to render the chat template: %s\n", e.what()); - return -1; - } - } - int result = llama_chat_apply_template( - tmpl.source().c_str(), llama_data.messages.data(), llama_data.messages.size(), append, - append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0); - if (append && result > static_cast(llama_data.fmtted.size())) { - llama_data.fmtted.resize(result); - result = llama_chat_apply_template(tmpl.source().c_str(), llama_data.messages.data(), - llama_data.messages.size(), append, llama_data.fmtted.data(), - llama_data.fmtted.size()); - } - - return result; +static int apply_chat_template(const struct common_chat_templates * tmpls, LlamaData & llama_data, const bool append, bool use_jinja) { + common_chat_templates_inputs inputs; + for (const auto & msg : llama_data.messages) { + common_chat_msg cmsg; + cmsg.role = msg.role; + cmsg.content = msg.content; + inputs.messages.push_back(cmsg); + } + inputs.add_generation_prompt = append; + inputs.use_jinja = use_jinja; + + auto chat_params = common_chat_templates_apply(tmpls, inputs); + // TODO: use other params for tool calls. + auto result = chat_params.prompt; + llama_data.fmtted.resize(result.size() + 1); + memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1); + return result.size(); } // Function to tokenize the prompt @@ -1015,8 +994,8 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt, } // Helper function to apply the chat template and handle errors -static int apply_chat_template_with_error_handling(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) { - const int new_len = apply_chat_template(tmpl, llama_data, append, use_jinja); +static int apply_chat_template_with_error_handling(const common_chat_templates * tmpls, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) { + const int new_len = apply_chat_template(tmpls, llama_data, append, use_jinja); if (new_len < 0) { printe("failed to apply the chat template\n"); return -1; @@ -1078,8 +1057,7 @@ static int get_user_input(std::string & user_input, const std::string & user) { static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_jinja) { int prev_len = 0; llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get())); - auto chat_templates = common_chat_templates_from_model(llama_data.model.get(), ""); - GGML_ASSERT(chat_templates.template_default); + auto chat_templates = common_chat_templates_init(llama_data.model.get(), ""); static const bool stdout_a_terminal = is_stdout_a_terminal(); while (true) { // Get user input @@ -1090,7 +1068,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_ add_message("user", user.empty() ? user_input : user, llama_data); int new_len; - if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, true, new_len, use_jinja) < 0) { + if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, use_jinja) < 0) { return 1; } @@ -1105,7 +1083,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_ } add_message("assistant", response, llama_data); - if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, false, prev_len, use_jinja) < 0) { + if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, use_jinja) < 0) { return 1; } } diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 5707c766..809bfe0e 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -329,9 +329,6 @@ struct server_task { } // process "json_schema" and "grammar" - if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) { - throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both"); - } if (data.contains("json_schema") && !data.contains("grammar")) { try { auto schema = json_value(data, "json_schema", json::object()); @@ -1807,7 +1804,7 @@ struct server_context { // Necessary similarity of prompt for slot selection float slot_prompt_similarity = 0.0f; - common_chat_templates chat_templates; + common_chat_templates_ptr chat_templates; ~server_context() { // Clear any sampling context @@ -1891,45 +1888,17 @@ struct server_context { llama_init_dft.context.reset(); } - if (params_base.chat_template.empty() && !validate_builtin_chat_template(params.use_jinja)) { + chat_templates = common_chat_templates_init(model, params_base.chat_template); + try { + common_chat_format_example(chat_templates.get(), params.use_jinja); + } catch (const std::exception & e) { SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__); - chat_templates = common_chat_templates_from_model(model, "chatml"); - } else { - chat_templates = common_chat_templates_from_model(model, params_base.chat_template); + chat_templates = common_chat_templates_init(model, "chatml"); } - GGML_ASSERT(chat_templates.template_default.get() != nullptr); return true; } - bool validate_builtin_chat_template(bool use_jinja) const { - llama_chat_message chat[] = {{"user", "test"}}; - - if (use_jinja) { - auto templates = common_chat_templates_from_model(model, ""); - common_chat_inputs inputs; - inputs.messages = json::array({{ - {"role", "user"}, - {"content", "test"}, - }}); - GGML_ASSERT(templates.template_default); - try { - common_chat_params_init(*templates.template_default, inputs); - if (templates.template_tool_use) { - common_chat_params_init(*templates.template_tool_use, inputs); - } - return true; - } catch (const std::exception & e) { - SRV_ERR("failed to apply template: %s\n", e.what()); - return false; - } - } else { - const char * tmpl = llama_model_chat_template(model, /* name */ nullptr); - const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0); - return chat_res > 0; - } - } - void init() { const int32_t n_ctx_slot = n_ctx / params_base.n_parallel; @@ -3822,13 +3791,15 @@ int main(int argc, char ** argv) { { "default_generation_settings", ctx_server.default_generation_settings_for_props }, { "total_slots", ctx_server.params_base.n_parallel }, { "model_path", ctx_server.params_base.model }, - { "chat_template", ctx_server.chat_templates.template_default->source() }, - { "bos_token", ctx_server.chat_templates.template_default->bos_token() }, - { "eos_token", ctx_server.chat_templates.template_default->eos_token() }, + { "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) }, + { "bos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)}, + { "eos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)}, { "build_info", build_info }, }; - if (ctx_server.params_base.use_jinja && ctx_server.chat_templates.template_tool_use) { - data["chat_template_tool_use"] = ctx_server.chat_templates.template_tool_use->source(); + if (ctx_server.params_base.use_jinja) { + if (auto tool_use_src = common_chat_templates_source(ctx_server.chat_templates.get(), "tool_use")) { + data["chat_template_tool_use"] = tool_use_src; + } } res_ok(res, data); @@ -4063,7 +4034,7 @@ int main(int argc, char ** argv) { } auto body = json::parse(req.body); - json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates); + json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get()); return handle_completions_impl( SERVER_TASK_TYPE_COMPLETION, @@ -4076,7 +4047,7 @@ int main(int argc, char ** argv) { // same with handle_chat_completions, but without inference part const auto handle_apply_template = [&ctx_server, ¶ms, &res_ok](const httplib::Request & req, httplib::Response & res) { auto body = json::parse(req.body); - json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates); + json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get()); res_ok(res, {{ "prompt", std::move(data.at("prompt")) }}); }; @@ -4493,8 +4464,8 @@ int main(int argc, char ** argv) { // print sample chat example to make it clear which template is used LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, - ctx_server.chat_templates.template_default->source().c_str(), - common_chat_format_example(*ctx_server.chat_templates.template_default, ctx_server.params_base.use_jinja).c_str()); + common_chat_templates_source(ctx_server.chat_templates.get()), + common_chat_format_example(ctx_server.chat_templates.get(), ctx_server.params_base.use_jinja).c_str()); ctx_server.queue_tasks.on_new_task([&ctx_server](const server_task & task) { ctx_server.process_single_task(task); diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py index f23d5cff..af1dcb5b 100644 --- a/examples/server/tests/unit/test_chat_completion.py +++ b/examples/server/tests/unit/test_chat_completion.py @@ -21,6 +21,8 @@ def create_server(): (None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"), ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None), ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None), + (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", False, None), + (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", True, None), ] ) def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template): @@ -44,7 +46,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte assert res.body["usage"]["completion_tokens"] == n_predicted choice = res.body["choices"][0] assert "assistant" == choice["message"]["role"] - assert match_regex(re_content, choice["message"]["content"]) + assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}' assert choice["finish_reason"] == finish_reason @@ -169,6 +171,47 @@ def test_completion_with_response_format(response_format: dict, n_predicted: int assert "error" in res.body +@pytest.mark.parametrize("jinja,json_schema,n_predicted,re_content", [ + (False, {"const": "42"}, 6, "\"42\""), + (True, {"const": "42"}, 6, "\"42\""), +]) +def test_completion_with_json_schema(jinja: bool, json_schema: dict, n_predicted: int, re_content: str): + global server + server.jinja = jinja + server.start() + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": n_predicted, + "messages": [ + {"role": "system", "content": "You are a coding assistant."}, + {"role": "user", "content": "Write an example"}, + ], + "json_schema": json_schema, + }) + assert res.status_code == 200, f'Expected 200, got {res.status_code}' + choice = res.body["choices"][0] + assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}' + + +@pytest.mark.parametrize("jinja,grammar,n_predicted,re_content", [ + (False, 'root ::= "a"{5,5}', 6, "a{5,5}"), + (True, 'root ::= "a"{5,5}', 6, "a{5,5}"), +]) +def test_completion_with_grammar(jinja: bool, grammar: str, n_predicted: int, re_content: str): + global server + server.jinja = jinja + server.start() + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": n_predicted, + "messages": [ + {"role": "user", "content": "Does not matter what I say, does it?"}, + ], + "grammar": grammar, + }) + assert res.status_code == 200, res.body + choice = res.body["choices"][0] + assert match_regex(re_content, choice["message"]["content"]), choice["message"]["content"] + + @pytest.mark.parametrize("messages", [ None, "string", diff --git a/examples/server/tests/unit/test_tool_call.py b/examples/server/tests/unit/test_tool_call.py index ba3367b4..a91a2f33 100644 --- a/examples/server/tests/unit/test_tool_call.py +++ b/examples/server/tests/unit/test_tool_call.py @@ -356,12 +356,12 @@ def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] | (None, 128, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)), (None, 128, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), (None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), - ("^> 0.56$", 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"), + (None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"), (None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), # TODO: fix these (wrong results, either didn't respect decimal instruction or got wrong value) - ("^The y-coordinate [\\s\\S]*?\\*\\*0.5\\*\\*", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), - ("[\\s\\S]*?\\*\\*0\\.5\\*\\*", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)), + ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), + # ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)), ]) def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None): global server @@ -401,7 +401,7 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, { "role": "tool", "name": "calculate", - "content": 0.55644242476, + "content": "0.55644242476", "tool_call_id": "call_6789" } ], @@ -444,7 +444,7 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, (128, None, "^The sum of 102 and 7 is 109.*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), (1024, 'deepseek', "To find the sum of.*", "I need to calculate the sum of 102 and 7.*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), - (1024, 'none', "\n?I need[\\s\\S]*?\n?To find.*", None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), + (1024, 'none', "^I need[\\s\\S]*?\n?To find.*", None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), (1024, 'deepseek', "To find the sum of.*", "First, I [\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)), ]) diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 60cb2673..6f8ab2b9 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -12,9 +12,7 @@ // Change JSON_ASSERT from assert() to GGML_ASSERT: #define JSON_ASSERT GGML_ASSERT #include "json.hpp" -#include "minja.hpp" -#include "chat.hpp" -#include "chat-template.hpp" +#include "chat.h" #include #include @@ -347,41 +345,6 @@ static llama_tokens format_infill( return embd_inp; } -// Format given chat. If tmpl is empty, we take the template from model metadata -inline std::string format_chat(const common_chat_template & tmpl, const std::vector & messages) { - std::vector chat; - - for (size_t i = 0; i < messages.size(); ++i) { - const auto & curr_msg = messages[i]; - - std::string role = json_value(curr_msg, "role", std::string("")); - - std::string content; - if (curr_msg.contains("content")) { - if (curr_msg["content"].is_string()) { - content = curr_msg["content"].get(); - } else if (curr_msg["content"].is_array()) { - for (const auto & part : curr_msg["content"]) { - if (part.contains("text")) { - content += "\n" + part["text"].get(); - } - } - } else { - throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggml-org/llama.cpp/issues/8367)"); - } - } else { - throw std::runtime_error("Missing 'content' (ref: https://github.com/ggml-org/llama.cpp/issues/8367)"); - } - - chat.push_back({role, content, /* tool_calls= */ {}}); - } - - const auto formatted_chat = common_chat_apply_template(tmpl, chat, true, /* use_jinja= */ false); - LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str()); - - return formatted_chat; -} - // // base64 utils (TODO: move to common in the future) // @@ -579,12 +542,9 @@ static json oaicompat_completion_params_parse( const json & body, /* openai api json semantics */ bool use_jinja, common_reasoning_format reasoning_format, - const common_chat_templates & chat_templates) + const struct common_chat_templates * tmpls) { json llama_params; - const auto & tmpl = body.contains("tools") && chat_templates.template_tool_use - ? *chat_templates.template_tool_use - : *chat_templates.template_default; auto tools = json_value(body, "tools", json()); auto stream = json_value(body, "stream", false); @@ -610,62 +570,58 @@ static json oaicompat_completion_params_parse( llama_params["stop"] = json_value(body, "stop", json::array()); } + auto json_schema = json_value(body, "json_schema", json()); + auto grammar = json_value(body, "grammar", std::string()); + if (!json_schema.is_null() && !grammar.empty()) { + throw std::runtime_error("Cannot use both json_schema and grammar"); + } + // Handle "response_format" field if (body.contains("response_format")) { json response_format = json_value(body, "response_format", json::object()); std::string response_type = json_value(response_format, "type", std::string()); if (response_type == "json_object") { - llama_params["json_schema"] = json_value(response_format, "schema", json::object()); + json_schema = json_value(response_format, "schema", json::object()); } else if (response_type == "json_schema") { json json_schema = json_value(response_format, "json_schema", json::object()); - llama_params["json_schema"] = json_value(json_schema, "schema", json::object()); + json_schema = json_value(json_schema, "schema", json::object()); } else if (!response_type.empty() && response_type != "text") { throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type); } } + common_chat_templates_inputs inputs; + inputs.messages = common_chat_msgs_parse_oaicompat(body.at("messages")); + inputs.tools = common_chat_tools_parse_oaicompat(tools); + inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(json_value(body, "tool_choice", std::string("auto"))); + inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump(); + inputs.grammar = grammar; + inputs.add_generation_prompt = true; + inputs.use_jinja = use_jinja; + inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false); + inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE; + if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) { + throw std::runtime_error("Cannot use custom grammar constraints with tools."); + } + // Apply chat template to the list of messages - if (use_jinja) { - auto tool_choice = json_value(body, "tool_choice", std::string("auto")); - if (tool_choice != "none" && tool_choice != "auto" && tool_choice != "required") { - throw std::runtime_error("Invalid tool_choice: " + tool_choice); - } - if (tool_choice != "none" && llama_params.contains("grammar")) { - throw std::runtime_error("Cannot use custom grammar constraints with tools."); - } - common_chat_inputs inputs; - inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE; - inputs.messages = body.at("messages"); - inputs.tools = tools; - inputs.tool_choice = tool_choice; - inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false); - if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) { - LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n"); - inputs.parallel_tool_calls = false; - } - inputs.stream = stream; - // TODO: support mixing schema w/ tools beyond generic format. - inputs.json_schema = json_value(llama_params, "json_schema", json()); - auto chat_params = common_chat_params_init(tmpl, inputs); - - llama_params["chat_format"] = static_cast(chat_params.format); - llama_params["prompt"] = chat_params.prompt; - llama_params["grammar"] = chat_params.grammar; - llama_params["grammar_lazy"] = chat_params.grammar_lazy; - auto grammar_triggers = json::array(); - for (const auto & trigger : chat_params.grammar_triggers) { - grammar_triggers.push_back({ - {"word", trigger.word}, - {"at_start", trigger.at_start}, - }); - } - llama_params["grammar_triggers"] = grammar_triggers; - llama_params["preserved_tokens"] = chat_params.preserved_tokens; - for (const auto & stop : chat_params.additional_stops) { - llama_params["stop"].push_back(stop); - } - } else { - llama_params["prompt"] = format_chat(tmpl, body.at("messages")); + auto chat_params = common_chat_templates_apply(tmpls, inputs); + + llama_params["chat_format"] = static_cast(chat_params.format); + llama_params["prompt"] = chat_params.prompt; + llama_params["grammar"] = chat_params.grammar; + llama_params["grammar_lazy"] = chat_params.grammar_lazy; + auto grammar_triggers = json::array(); + for (const auto & trigger : chat_params.grammar_triggers) { + grammar_triggers.push_back({ + {"word", trigger.word}, + {"at_start", trigger.at_start}, + }); + } + llama_params["grammar_triggers"] = grammar_triggers; + llama_params["preserved_tokens"] = chat_params.preserved_tokens; + for (const auto & stop : chat_params.additional_stops) { + llama_params["stop"].push_back(stop); } // Handle "n" field diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index e0314ae1..9231c517 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -1,13 +1,14 @@ #include #include #include +#include #undef NDEBUG #include #include "llama.h" #include "common.h" -#include "chat-template.hpp" +#include "chat.h" static std::string normalize_newlines(const std::string & s) { #ifdef _WIN32 @@ -18,6 +19,13 @@ static std::string normalize_newlines(const std::string & s) { #endif } +static common_chat_msg simple_msg(const std::string & role, const std::string & content) { + common_chat_msg msg; + msg.role = role; + msg.content = content; + return msg; +} + int main(void) { std::vector conversation { {"system", "You are a helpful assistant"}, @@ -50,7 +58,7 @@ int main(void) { /* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", /* .expected_output= */ "[INST] You are a helpful assistant\nHello [/INST]Hi there[INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", /* .expected_output_jinja= */ "", - /* .bos_token= */ "", + /* .bos_token= */ "", /* .eos_token= */ "", }, { @@ -72,8 +80,8 @@ int main(void) { { /* .name= */ "mlabonne/AlphaMonarch-7B", /* .template_str= */ "{% for message in messages %}{{bos_token + message['role'] + '\\n' + message['content'] + eos_token + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ bos_token + 'assistant\\n' }}{% endif %}", - /* .expected_output= */ "system\nYou are a helpful assistant\nuser\nHello\nassistant\nHi there\nuser\nWho are you\nassistant\n I am an assistant \nuser\nAnother question\nassistant\n", - /* .expected_output_jinja= */ "system\nYou are a helpful assistant\nuser\nHello\nassistant\nHi there\nuser\nWho are you\nassistant\n I am an assistant \nuser\nAnother question\nassistant\n", + /* .expected_output= */ "system\nYou are a helpful assistant\nuser\nHello\nassistant\nHi there\nuser\nWho are you\nassistant\n I am an assistant \nuser\nAnother question\nassistant\n", + /* .expected_output_jinja= */ "", /* .bos_token= */ "", /* .eos_token= */ "", }, @@ -87,7 +95,7 @@ int main(void) { /* .name= */ "OrionStarAI/Orion-14B-Chat", /* .template_str= */ "{% for message in messages %}{% if loop.first %}{{ bos_token }}{% endif %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'] + '\\n\\nAssistant: ' + eos_token }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% endif %}{% endfor %}", /* .expected_output= */ "Human: You are a helpful assistant\n\nHello\n\nAssistant: Hi thereHuman: Who are you\n\nAssistant: I am an assistant Human: Another question\n\nAssistant: ", - /* .expected_output_jinja= */ "Human: You are a helpful assistant\nHello\n\nAssistant: Hi thereHuman: Who are you\n\nAssistant: I am an assistant Human: Another question\n\nAssistant: ", + /* .expected_output_jinja= */ "Human: You are a helpful assistant\nHello\n\nAssistant: Hi thereHuman: Who are you\n\nAssistant: I am an assistant Human: Another question\n\nAssistant: ", /* .bos_token= */ "", /* .eos_token= */ "", }, @@ -304,12 +312,9 @@ int main(void) { } } - json messages = json::array(); + std::vector messages; for (const auto & msg : conversation) { - messages.push_back({ - {"role", msg.role}, - {"content", msg.content}, - }); + messages.push_back(simple_msg(msg.role, msg.content)); } for (const auto & test_case : test_cases) { if (!test_case.supported_with_jinja) { @@ -317,8 +322,13 @@ int main(void) { } printf("\n\n=== %s (jinja) ===\n\n", test_case.name.c_str()); try { - minja::chat_template tmpl(test_case.template_str, test_case.bos_token, test_case.eos_token); - auto output = normalize_newlines(tmpl.apply(messages, json(), add_generation_prompt)); + auto tmpls = common_chat_templates_init(/* model= */ nullptr, test_case.template_str.c_str(), test_case.bos_token, test_case.eos_token); + common_chat_templates_inputs inputs; + inputs.use_jinja = true; + inputs.messages = messages; + inputs.add_generation_prompt = add_generation_prompt; + auto output = common_chat_templates_apply(tmpls.get(), inputs).prompt; + output = normalize_newlines(output); auto expected_output = normalize_newlines(test_case.expected_output_jinja.empty() ? test_case.expected_output : test_case.expected_output_jinja); if (output != expected_output) { printf("Expected:\n%s\n", expected_output.c_str()); @@ -336,11 +346,11 @@ int main(void) { // test llama_chat_format_single for system message printf("\n\n=== llama_chat_format_single (system message) ===\n\n"); std::vector chat2; - common_chat_msg sys_msg{"system", "You are a helpful assistant", {}}; + auto sys_msg = simple_msg("system", "You are a helpful assistant"); auto fmt_sys = [&](std::string tmpl_str) { - minja::chat_template tmpl(tmpl_str, "", ""); - auto output = common_chat_format_single(tmpl, chat2, sys_msg, false, /* use_jinja= */ false); + auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl_str); + auto output = common_chat_format_single(tmpls.get(), chat2, sys_msg, false, /* use_jinja= */ false); printf("fmt_sys(%s) : %s\n", tmpl_str.c_str(), output.c_str()); printf("-------------------------\n"); return output; @@ -360,14 +370,14 @@ int main(void) { // test llama_chat_format_single for user message printf("\n\n=== llama_chat_format_single (user message) ===\n\n"); - chat2.push_back({"system", "You are a helpful assistant", {}}); - chat2.push_back({"user", "Hello", {}}); - chat2.push_back({"assistant", "I am assistant", {}}); - common_chat_msg new_msg{"user", "How are you", {}}; + chat2.push_back(simple_msg("system", "You are a helpful assistant")); + chat2.push_back(simple_msg("user", "Hello")); + chat2.push_back(simple_msg("assistant", "I am assistant")); + auto new_msg = simple_msg("user", "How are you"); - auto fmt_single = [&](std::string tmpl_str) { - minja::chat_template tmpl(tmpl_str, "", ""); - auto output = common_chat_format_single(tmpl, chat2, new_msg, true, /* use_jinja= */ false); + auto fmt_single = [&](const std::string & tmpl_str) { + auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl_str.c_str()); + auto output = common_chat_format_single(tmpls.get(), chat2, new_msg, true, /* use_jinja= */ false); printf("fmt_single(%s) : %s\n", tmpl_str.c_str(), output.c_str()); printf("-------------------------\n"); return output; diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 2836caf6..64359230 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -10,38 +10,12 @@ #include #include -#include "chat-template.hpp" -#include "chat.hpp" +#include "chat.h" #include "llama-grammar.h" #include "unicode.h" using json = nlohmann::ordered_json; -static common_chat_msg msg_from_json(const json & message) { - common_chat_msg ret; - ret.role = "assistant"; - if (message.contains("content") && !message.at("content").is_null()) { - ret.content = message.at("content"); - } - if (message.contains("tool_plan")) { - ret.reasoning_content = message.at("tool_plan"); - } - if (message.contains("reasoning_content")) { - ret.reasoning_content = message.at("reasoning_content"); - } - auto has_tool_calls = message.contains("tool_calls"); - if (has_tool_calls) { - for (const auto & tc : message.at("tool_calls")) { - const auto & arguments = tc.at("function").at("arguments"); - ret.tool_calls.push_back({ - tc.at("function").at("name").get(), - arguments.is_string() ? arguments.get() : arguments.dump(), - tc.contains("id") ? tc.at("id").get() : "", - }); - } - } - return ret; -} template static void assert_equals(const T & expected, const T & actual) { if (expected != actual) { @@ -53,7 +27,7 @@ template static void assert_equals(const T & expected, const T & actua } static std::string read_file(const std::string & path) { - std::cerr << "# Reading: " << path << std::endl << std::flush; + std::cerr << "# Reading: " << path << '\n' << std::flush; std::ifstream fs(path, std::ios_base::binary); if (!fs.is_open()) { fs = std::ifstream("../" + path, std::ios_base::binary); @@ -66,10 +40,14 @@ static std::string read_file(const std::string & path) { fs.seekg(0); std::string out; out.resize(static_cast(size)); - fs.read(&out[0], static_cast(size)); + fs.read(out.data(), static_cast(size)); return out; } +static common_chat_templates_ptr read_templates(const std::string & path) { + return common_chat_templates_ptr(common_chat_templates_init(/* model= */ nullptr, read_file(path))); +} + static std::unique_ptr build_grammar(const std::string & grammar_str) { return std::unique_ptr( llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0)); @@ -90,110 +68,102 @@ static bool match_string(const std::string & input, llama_grammar * grammar) { } } - for (const auto & stack : stacks_cur) { - if (stack.empty()) { - // An empty stack means that the grammar has been completed - return true; - } + if (std::any_of(stacks_cur.begin(), stacks_cur.end(), [](const auto & stack) { return stack.empty(); })) { + // An empty stack means that the grammar has been completed + return true; } return false; } -// Dumps `{"a": 1}` as `"{\"a\": 1}"`, unlike nlohmann::json::dump which would dump it as `"{\"a\":1}"`. -static std::string dump(const json & j) { - return minja::Value(j).dump(-1, /* to_json= */ true); -} - static void assert_msg_equals(const common_chat_msg & expected, const common_chat_msg & actual) { assert_equals(expected.role, actual.role); assert_equals(expected.content, actual.content); + assert_equals(expected.content_parts.size(), actual.content_parts.size()); + for (size_t i = 0; i < expected.content_parts.size(); i++) { + const auto & expected_part = expected.content_parts[i]; + const auto & actual_part = actual.content_parts[i]; + assert_equals(expected_part.type, actual_part.type); + assert_equals(expected_part.text, actual_part.text); + } assert_equals(expected.reasoning_content, actual.reasoning_content); assert_equals(expected.tool_calls.size(), actual.tool_calls.size()); for (size_t i = 0; i < expected.tool_calls.size(); i++) { const auto & expected_tool_call = expected.tool_calls[i]; const auto & actual_tool_call = actual.tool_calls[i]; assert_equals(expected_tool_call.name, actual_tool_call.name); - assert_equals(dump(json::parse(expected_tool_call.arguments)), dump(json::parse(actual_tool_call.arguments))); + assert_equals(json::parse(expected_tool_call.arguments).dump(), json::parse(actual_tool_call.arguments).dump()); assert_equals(expected_tool_call.id, actual_tool_call.id); } } -const auto special_function_tool = json::parse(R"({ - "type": "function", - "function": { - "name": "special_function", - "description": "I'm special", - "parameters": { - "type": "object", - "properties": { - "arg1": { - "type": "integer", - "description": "The arg." - } - }, - "required": ["arg1"] - } - } -})"); -const auto python_tool = json::parse(R"({ - "type": "function", - "function": { - "name": "python", - "description": "an ipython interpreter", - "parameters": { - "type": "object", - "properties": { - "code": { - "type": "string", - "description": "Python code to execute." - } - }, - "required": ["code"] - } - } -})"); -const auto code_interpreter_tool = json::parse(R"({ - "type": "function", - "function": { - "name": "code_interpreter", - "description": "an ipython interpreter", - "parameters": { - "type": "object", - "properties": { - "code": { - "type": "string", - "description": "Python code to execute." - } - }, - "required": ["code"] - } - } -})"); -const json tools = { special_function_tool, python_tool }; -const json llama_3_1_tools = { special_function_tool, code_interpreter_tool }; +common_chat_tool special_function_tool { + /* .name = */ "special_function", + /* .description = */ "I'm special", + /* .parameters = */ R"({ + "type": "object", + "properties": { + "arg1": { + "type": "integer", + "description": "The arg." + } + }, + "required": ["arg1"] + })", +}; +common_chat_tool python_tool { + /* .name = */ "python", + /* .description = */ "an ipython interpreter", + /* .parameters = */ R"({ + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "Python code to execute." + } + }, + "required": ["code"] + })", +}; +common_chat_tool code_interpreter_tool { + /* .name = */ "code_interpreter", + /* .description = */ "an ipython interpreter", + /* .parameters = */ R"({ + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "Python code to execute." + } + }, + "required": ["code"] + })", +}; +std::vector tools { special_function_tool, python_tool }; +std::vector llama_3_1_tools { special_function_tool, code_interpreter_tool }; struct delta_data { std::string delta; common_chat_params params; }; -static delta_data init_delta(const common_chat_template & tmpl, const std::vector & end_tokens, - const json & user_message, const json & delta_message, const json & tools, - const json & tool_choice, +static delta_data init_delta(const struct common_chat_templates * tmpls, const std::vector & end_tokens, + const common_chat_msg & user_message, + const common_chat_msg & delta_message, + const std::vector & tools, + const common_chat_tool_choice & tool_choice, bool think = false) { - common_chat_inputs inputs; + common_chat_templates_inputs inputs; inputs.parallel_tool_calls = true; - inputs.messages = json::array(); inputs.messages.push_back(user_message); inputs.tools = tools; inputs.tool_choice = tool_choice; inputs.extract_reasoning = think; - auto params_prefix = common_chat_params_init(tmpl, inputs); + auto params_prefix = common_chat_templates_apply(tmpls, inputs); inputs.messages.push_back(delta_message); inputs.add_generation_prompt = false; - auto params_full = common_chat_params_init(tmpl, inputs); + auto params_full = common_chat_templates_apply(tmpls, inputs); std::string prefix = params_prefix.prompt; std::string full = params_full.prompt; @@ -234,30 +204,29 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto gets the diff, removes any end tokens and parses the result w/ the grammar, checking that the parsed message is the same as the test_message */ -static void test_template(const common_chat_template & tmpl, const std::vector & end_tokens, - const json & test_message, const json & tools = {}, const std::string & expected_delta = "", +static void test_templates(const struct common_chat_templates * tmpls, const std::vector & end_tokens, + const common_chat_msg & test_message, + const std::vector & tools = {}, + const std::string & expected_delta = "", bool expect_grammar_triggered = true, bool test_grammar_if_triggered = true, bool think = false) { - common_chat_msg expected_msg = msg_from_json(test_message); - - auto user_message = json{ - { "role", "user" }, - { "content", "Hello, world!" } - }; + common_chat_msg user_message; + user_message.role = "user"; + user_message.content = "Hello, world!"; - for (const auto & tool_choice : json({ "auto", "required" })) { - auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools, tool_choice, think); + for (const auto & tool_choice : std::vector {COMMON_CHAT_TOOL_CHOICE_AUTO, COMMON_CHAT_TOOL_CHOICE_REQUIRED}) { + auto data = init_delta(tmpls, end_tokens, user_message, test_message, tools, tool_choice, think); if (!expected_delta.empty()) { assert_equals(expected_delta, data.delta); } if (expect_grammar_triggered) { const auto msg = common_chat_parse(data.delta, data.params.format); - assert_msg_equals(expected_msg, msg); + assert_msg_equals(test_message, msg); } - if (!expected_msg.tool_calls.empty()) { + if (!test_message.tool_calls.empty()) { GGML_ASSERT(!data.params.grammar.empty()); } if (!data.params.grammar.empty()) { @@ -297,246 +266,339 @@ static void test_template(const common_chat_template & tmpl, const std::vectorI'm thinkingHello, world!\nWhat's up?" }, - }; - json message_assist_thoughts_unparsed_r7b { - { "role", "assistant" }, - { "content", "<|START_THINKING|>I'm thinking<|END_THINKING|>Hello, world!\nWhat's up?" }, - }; - json message_assist_thoughts { - { "role", "assistant" }, - { "content", "Hello, world!\nWhat's up?" }, - { "reasoning_content", "I'm thinking" }, - }; - json tool_calls = json::array({{ - { "type", "function" }, - { "function", { { "name", "special_function" }, { "arguments", "{\"arg1\": 1}" } } }, - }}); - - json message_assist_call { - { "role", "assistant"}, - { "content", {}}, - { "tool_calls", { - { - { "type", "function" }, - { "function", { - { "name", "special_function" }, - { "arguments", "{\"arg1\": 1}" }, - }}, - }, - }}, - }; - json message_assist_call_thoughts = { - { "role", "assistant" }, - { "content", nullptr }, - { "reasoning_content", "I'm\nthinking" }, - { "tool_calls", { - { - { "type", "function" }, - { "function", { - { "name", "special_function" }, - { "arguments", "{\"arg1\": 1}" }, - }}, - }, - }}, - }; - json message_assist_call_thoughts_unparsed = { - { "role", "assistant" }, - { "content", "I'm\nthinking" }, - { "tool_calls", { - { - { "type", "function" }, - { "function", { - { "name", "special_function" }, - { "arguments", "{\"arg1\": 1}" }, - }}, - }, - }}, - }; - json message_assist_call_id { - { "role", "assistant"}, - { "content", {}}, - { "tool_calls", { - { - { "type", "function" }, - { "function", { - { "name", "special_function" }, - { "arguments", "{\"arg1\": 1}" }, - }}, - {"id", "123456789"}, - }, - }}, - { "role", "assistant" }, - { "content", {} }, - { "tool_calls", tool_calls } - }; - json message_assist_call_idx { - { "role", "assistant"}, - { "content", {}}, - { "tool_calls", { - { - { "type", "function" }, - { "function", { - { "name", "special_function" }, - { "arguments", "{\"arg1\": 1}" }, - }}, - // Index of the tool call in the tool_calls array - {"id", "0"}, - }, - }}, - { "role", "assistant" }, - { "content", {} }, - { "tool_calls", tool_calls } - }; - json message_assist_call_tool_plan_idx = message_assist_call_idx; - message_assist_call_tool_plan_idx["tool_plan"] = "I'm thinking"; - - auto python_message_assist_call = json{ - { "role", "assistant" }, - { "content", {} }, - { "tool_calls", json{ { - { "type", "function" }, - { "function", - { - { "name", "python" }, - { "arguments", - { - { "code", "print('hey')" }, - } }, - } }, - } } } +const common_chat_msg message_user { + "user", + "Hey there!", + /* .content_parts = */ {}, + /* .tool_calls = */ {}, + /* .reasoning_content = */ "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", +}; + +const common_chat_msg message_user_parts { + "user", + /* .content = */ "", + /* .content_parts = */ { + { "text", "Hey" }, + { "text", "there" }, + }, + /* .tool_calls = */ {}, + /* .reasoning_content = */ "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", +}; +const common_chat_msg message_assist { + "assistant", + "Hello, world!\nWhat's up?", + /* .content_parts = */ {}, + /* .tool_calls = */ {}, + /* .reasoning_content = */ "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", +}; +const common_chat_msg message_assist_thoughts_unparsed_think { + "assistant", + "I'm thinkingHello, world!\nWhat's up?", + /* .content_parts = */ {}, + /* .tool_calls = */ {}, + /* .reasoning_content = */ "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", +}; +const common_chat_msg message_assist_thoughts_unparsed_r7b { + "assistant", + "<|START_THINKING|>I'm thinking<|END_THINKING|>Hello, world!\nWhat's up?", + /* .content_parts = */ {}, + /* .tool_calls = */ {}, + /* .reasoning_content = */ "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", +}; +const common_chat_msg message_assist_thoughts { + "assistant", + "Hello, world!\nWhat's up?", + /* .content_parts = */ {}, + /* .tool_calls = */ {}, + /* .reasoning_content = */ "I'm thinking", + /* .tool_name = */ "", + /* .tool_call_id = */ "", +}; +const std::vector tool_calls { + { "special_function", "{\"arg1\": 1}", /* .id = */ "" }, +}; +const std::vector tool_calls_idx { + { "special_function", "{\"arg1\": 1}", /* .id = */ "0" }, +}; +const std::vector tool_calls_id { + { "special_function", "{\"arg1\": 1}", /* .id = */ "123456789" }, +}; + +const common_chat_msg message_assist_call { + "assistant", + "", + /* .content_parts = */ {}, + tool_calls, + /* .reasoning_content = */ "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", +}; +const common_chat_msg message_assist_call_thoughts = { + "assistant", + /* .content = */ "", + /* .content_parts = */ {}, + tool_calls, + /* .reasoning_content = */ "I'm\nthinking", + /* .tool_name = */ "", + /* .tool_call_id = */ "", +}; +const common_chat_msg message_assist_call_thoughts_unparsed = { + "assistant", + /* .content = */ "I'm\nthinking", + /* .content_parts = */ {}, + tool_calls, + /* .reasoning_content = */ "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", +}; +const common_chat_msg message_assist_call_id { + "assistant", + "", + /* .content_parts = */ {}, + tool_calls_id, + /* .reasoning_content = */ "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", +}; +const common_chat_msg message_assist_call_idx { + "assistant", + "", + /* .content_parts = */ {}, + tool_calls_idx, + /* .reasoning_content = */ "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", +}; +const common_chat_msg message_assist_call_python { + "assistant", + "", + /* .content_parts = */ {}, + { { "python", "{\"code\": \"print('hey')\"}", /* .id = */ "" } }, + /* .reasoning_content = */ "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", +}; +const common_chat_msg message_assist_call_code_interpreter { + "assistant", + "", + /* .content_parts = */ {}, + { { "code_interpreter", "{\"code\": \"print('hey')\"}", /* .id = */ "" } }, + /* .reasoning_content = */ "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", +}; + +static void test_msgs_oaicompat_json_conversion() { + std::vector msgs{ + message_user, + message_user_parts, + message_assist_call, + message_assist_call_thoughts, + message_assist_call_thoughts_unparsed, + message_assist_call_id, + message_assist_call_idx, + message_assist_call_python, + message_assist_call_code_interpreter, }; - auto code_interpreter_message_assist_call = json{ - { "role", "assistant" }, - { "content", {} }, - { "tool_calls", json{ { - { "type", "function" }, - { "function", - { - { "name", "code_interpreter" }, - { "arguments", - { - { "code", "print('hey')" }, - } }, - } }, - } } } + for (const auto & msg : msgs) { + auto oai_json = common_chat_msgs_to_json_oaicompat({msg}); + auto msgs2 = common_chat_msgs_parse_oaicompat(oai_json); + assert_equals((size_t) 1, msgs2.size()); + auto msg2 = msgs2[0]; + assert_msg_equals(msg, msg2); + } + assert_equals( + std::string( + "[\n" + " {\n" + " \"role\": \"user\",\n" + " \"content\": [\n" + " {\n" + " \"type\": \"text\",\n" + " \"text\": \"Hey\"\n" + " },\n" + " {\n" + " \"type\": \"text\",\n" + " \"text\": \"there\"\n" + " }\n" + " ]\n" + " }\n" + "]" + ), + common_chat_msgs_to_json_oaicompat({message_user_parts}).dump(2)); + + assert_equals( + std::string( + "[\n" + " {\n" + " \"role\": \"assistant\",\n" + " \"content\": null,\n" + " \"tool_calls\": [\n" + " {\n" + " \"type\": \"function\",\n" + " \"function\": {\n" + " \"name\": \"python\",\n" + " \"arguments\": \"{\\\"code\\\": \\\"print('hey')\\\"}\"\n" + " }\n" + " }\n" + " ]\n" + " }\n" + "]" + ), + common_chat_msgs_to_json_oaicompat({message_assist_call_python}).dump(2)); +} + +static void test_tools_oaicompat_json_conversion() { + std::vector tools{ + special_function_tool, + python_tool, + code_interpreter_tool, }; - common_chat_inputs inputs_no_tools; - inputs_no_tools.messages = json::array({message_user}); + for (const auto & tool : tools) { + auto oai_json = common_chat_tools_to_json_oaicompat({tool}); + auto tools2 = common_chat_tools_parse_oaicompat(oai_json); + assert_equals((size_t) 1, tools2.size()); + auto tool2 = tools2[0]; + assert_equals(tool.name, tool2.name); + assert_equals(tool.description, tool2.description); + assert_equals(json::parse(tool.parameters).dump(2), json::parse(tool2.parameters).dump(2)); + } + + assert_equals( + std::string( + "[\n" + " {\n" + " \"type\": \"function\",\n" + " \"function\": {\n" + " \"name\": \"special_function\",\n" + " \"description\": \"I'm special\",\n" + " \"parameters\": {\n" + " \"type\": \"object\",\n" + " \"properties\": {\n" + " \"arg1\": {\n" + " \"type\": \"integer\",\n" + " \"description\": \"The arg.\"\n" + " }\n" + " },\n" + " \"required\": [\n" + " \"arg1\"\n" + " ]\n" + " }\n" + " }\n" + " }\n" + "]" + ), + common_chat_tools_to_json_oaicompat({special_function_tool}).dump(2)); +} + +static void test_template_output_parsers() { + + common_chat_templates_inputs inputs_no_tools; + inputs_no_tools.messages = {message_user}; inputs_no_tools.extract_reasoning = false; - common_chat_inputs inputs_no_tools_think; - inputs_no_tools_think.messages = json::array({message_user}); + common_chat_templates_inputs inputs_no_tools_think; + inputs_no_tools_think.messages = {message_user}; inputs_no_tools_think.extract_reasoning = true; - common_chat_inputs inputs_tools; - inputs_tools.messages = json::array({message_user}); - inputs_tools.tools = json::array({special_function_tool}); + common_chat_templates_inputs inputs_tools; + inputs_tools.messages = {message_user}; + inputs_tools.tools = {special_function_tool}; inputs_tools.extract_reasoning = false; - common_chat_inputs inputs_tools_think; - inputs_tools_think.messages = json::array({message_user}); - inputs_tools_think.tools = json::array({special_function_tool}); + common_chat_templates_inputs inputs_tools_think; + inputs_tools_think.messages = {message_user}; + inputs_tools_think.tools = {special_function_tool}; inputs_tools_think.extract_reasoning = true; - common_chat_inputs inputs_tools_builtin; - inputs_tools_builtin.messages = json::array({message_user}); - inputs_tools_builtin.tools = json::array({python_tool}); + common_chat_templates_inputs inputs_tools_builtin; + inputs_tools_builtin.messages = {message_user}; + inputs_tools_builtin.tools = {python_tool}; inputs_tools_builtin.extract_reasoning = false; { // Not supported yet - const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja"), "", ""); - assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_params_init(tmpl, inputs_tools).format); + auto tmpls = read_templates("models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja"); + assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_templates_apply(tmpls.get(), inputs_tools).format); } { - const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja"), "", ""); + auto tmpls = read_templates("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja"); std::vector end_tokens{ "<|END_OF_TURN_TOKEN|>" }; - assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_no_tools).format); - assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_tools).format); - assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING, common_chat_params_init(tmpl, inputs_tools_think).format); + assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); + assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_templates_apply(tmpls.get(), inputs_tools).format); + assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format); - assert_msg_equals(msg_from_json(message_assist), + assert_msg_equals(message_assist, common_chat_parse( "Hello, world!\nWhat's up?", COMMON_CHAT_FORMAT_COMMAND_R7B)); - assert_msg_equals(msg_from_json(message_assist), + assert_msg_equals(message_assist, common_chat_parse( "Hello, world!\nWhat's up?<|END_RESPONSE|>", COMMON_CHAT_FORMAT_COMMAND_R7B)); - assert_msg_equals(msg_from_json(message_assist), + assert_msg_equals(message_assist, common_chat_parse( "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>", COMMON_CHAT_FORMAT_COMMAND_R7B)); - assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_r7b), + assert_msg_equals(message_assist_thoughts_unparsed_r7b, common_chat_parse( "<|START_THINKING|>I'm thinking<|END_THINKING|>" "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>", COMMON_CHAT_FORMAT_COMMAND_R7B)); - assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_r7b), + assert_msg_equals(message_assist_thoughts_unparsed_r7b, common_chat_parse( "<|START_THINKING|>I'm thinking<|END_THINKING|>" "Hello, world!\nWhat's up?<|END_RESPONSE|>", COMMON_CHAT_FORMAT_COMMAND_R7B)); - assert_msg_equals(msg_from_json(message_assist_thoughts), + assert_msg_equals(message_assist_thoughts, common_chat_parse( "<|START_THINKING|>I'm thinking<|END_THINKING|>" "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>", COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING)); - test_template(tmpl, end_tokens, message_assist_call_idx, tools, + test_templates(tmpls.get(), end_tokens, message_assist_call_idx, tools, "<|START_THINKING|><|END_THINKING|>" "<|START_ACTION|>[\n" " {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n" "]<|END_ACTION|>"); - test_template(tmpl, end_tokens, message_assist_call_tool_plan_idx, tools, - "<|START_THINKING|>I'm thinking<|END_THINKING|>" - "<|START_ACTION|>[\n" - " {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n" - "]<|END_ACTION|>", - /* expect_grammar_triggered= */ true, - /* test_grammar_if_triggered= */ true, - /* think= */ true); - test_template(tmpl, end_tokens, message_assist, tools, + test_templates(tmpls.get(), end_tokens, message_assist, tools, "<|START_RESPONSE|>Hello, world!\n" "What's up?<|END_RESPONSE|>", /* expect_grammar_triggered= */ false); } { - const common_chat_template tmpl(read_file("models/templates/google-gemma-2-2b-it.jinja"), "", ""); + auto tmpls = read_templates("models/templates/google-gemma-2-2b-it.jinja"); std::vector end_tokens{ "" }; - assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_params_init(tmpl, inputs_no_tools).format); - assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_params_init(tmpl, inputs_tools).format); + assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); + assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_templates_apply(tmpls.get(), inputs_tools).format); assert_equals(COMMON_CHAT_FORMAT_GENERIC, - common_chat_params_init( - common_chat_template(read_file("models/templates/microsoft-Phi-3.5-mini-instruct.jinja"), - "", ""), + common_chat_templates_apply( + read_templates("models/templates/microsoft-Phi-3.5-mini-instruct.jinja").get(), inputs_tools) .format); // Generic tool calls doesn't generate / parse content-only messages symmetrically. - assert_msg_equals(msg_from_json(message_assist), + assert_msg_equals(message_assist, common_chat_parse("{\n" " \"response\": \"Hello, world!\\nWhat's up?\"\n" "}", - common_chat_params_init(tmpl, inputs_tools).format)); - test_template(tmpl, end_tokens, message_assist_call_id, tools, + common_chat_templates_apply(tmpls.get(), inputs_tools).format)); + test_templates(tmpls.get(), end_tokens, message_assist_call_id, tools, "{\n" " \"tool_calls\": [\n" " {\n" @@ -550,143 +612,133 @@ static void test_template_output_parsers() { "}"); } { - const common_chat_template tmpl(read_file("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "", - ""); + auto tmpls = read_templates("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"); std::vector end_tokens{ "" }; - assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_params_init(tmpl, inputs_tools).format); + assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_templates_apply(tmpls.get(), inputs_tools).format); - test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - test_template( - tmpl, end_tokens, message_assist_call_id, tools, + test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + test_templates( + tmpls.get(), end_tokens, message_assist_call_id, tools, "[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]"); } { - const common_chat_template tmpl( - read_file("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "", ""); + auto tmpls = read_templates("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"); std::vector end_tokens{ "<|im_end|>" }; - assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_params_init(tmpl, inputs_tools).format); + assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_tools).format); assert_equals( COMMON_CHAT_FORMAT_HERMES_2_PRO, - common_chat_params_init( - common_chat_template(read_file("models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), - "", ""), + common_chat_templates_apply( + read_templates("models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja").get(), inputs_tools) .format); assert_equals( COMMON_CHAT_FORMAT_HERMES_2_PRO, - common_chat_params_init( - common_chat_template(read_file("models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "", ""), + common_chat_templates_apply( + read_templates("models/templates/Qwen-Qwen2.5-7B-Instruct.jinja").get(), inputs_tools) .format); - test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - test_template(tmpl, end_tokens, message_assist_call, tools, + test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + test_templates(tmpls.get(), end_tokens, message_assist_call, tools, "\n" "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" ""); - test_template(tmpl, end_tokens, python_message_assist_call, tools, + test_templates(tmpls.get(), end_tokens, message_assist_call_python, tools, "\n" "{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n" ""); } { - const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "", - ""); + auto tmpls = read_templates("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"); std::vector end_tokens{ "<|eom_id|>", "<|eot_id|>" }; - assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format); + assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_templates_apply(tmpls.get(), inputs_tools).format); assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS, - common_chat_params_init(tmpl, inputs_tools_builtin).format); + common_chat_templates_apply(tmpls.get(), inputs_tools_builtin).format); assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS, - common_chat_params_init( - common_chat_template(read_file("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), - "", ""), + common_chat_templates_apply( + read_templates("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja").get(), inputs_tools_builtin) .format); - // test_template(tmpl, end_tokens, message_assist, tools, R"(?)", /* expect_grammar_triggered= */ false); - test_template(tmpl, end_tokens, code_interpreter_message_assist_call, llama_3_1_tools, + // test_templates(tmpls.get(), end_tokens, message_assist, tools, R"(?)", /* expect_grammar_triggered= */ false); + test_templates(tmpls.get(), end_tokens, message_assist_call_code_interpreter, llama_3_1_tools, "<|python_tag|>code_interpreter.call(code=\"print('hey')\")"); - test_template(tmpl, end_tokens, python_message_assist_call, tools, + test_templates(tmpls.get(), end_tokens, message_assist_call_python, tools, "<|python_tag|>python.call(code=\"print('hey')\")"); - test_template(tmpl, end_tokens, message_assist_call, tools, + test_templates(tmpls.get(), end_tokens, message_assist_call, tools, "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}"); } { - const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "", - ""); + auto tmpls = read_templates("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"); std::vector end_tokens{ "<|eom_id|>", "<|eot_id|>" }; - assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format); + assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_templates_apply(tmpls.get(), inputs_tools).format); - test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - test_template(tmpl, end_tokens, message_assist_call, tools, + test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + test_templates(tmpls.get(), end_tokens, message_assist_call, tools, "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}"); } { - const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.1.jinja"), "", - ""); + auto tmpls = read_templates("models/templates/meetkai-functionary-medium-v3.1.jinja"); std::vector end_tokens{ "<|eom_id|>", "<|eot_id|>" }; assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1, - common_chat_params_init(tmpl, inputs_tools).format); + common_chat_templates_apply(tmpls.get(), inputs_tools).format); - test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - test_template(tmpl, end_tokens, message_assist_call, tools, + test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + test_templates(tmpls.get(), end_tokens, message_assist_call, tools, "{\"arg1\": 1}"); } { - const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.2.jinja"), "", - ""); + auto tmpls = read_templates("models/templates/meetkai-functionary-medium-v3.2.jinja"); std::vector end_tokens{ "<|eom_id|>", "<|eot_id|>" }; - assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_no_tools).format); - assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_tools).format); + assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); + assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_templates_apply(tmpls.get(), inputs_tools).format); - test_template(tmpl, end_tokens, message_assist, {}, + test_templates(tmpls.get(), end_tokens, message_assist, {}, "all\n" "Hello, world!\n" "What's up?", /* expect_grammar_triggered= */ false); - test_template(tmpl, end_tokens, message_assist_call, tools, + test_templates(tmpls.get(), end_tokens, message_assist_call, tools, "special_function\n" "{\"arg1\": 1}"); } { - const common_chat_template tmpl(read_file("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "", - ""); + auto tmpls = read_templates("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja"); std::vector end_tokens{ "<|eot_id|>" }; - assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_params_init(tmpl, inputs_tools).format); + assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_templates_apply(tmpls.get(), inputs_tools).format); - test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - test_template(tmpl, end_tokens, message_assist_call, tools, + test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + test_templates(tmpls.get(), end_tokens, message_assist_call, tools, " functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]"); } { // Original DeepSeek R1 template. Leaves <|tool▁calls▁begin|> and others unclosed. Our logic fixes the prompt. - const common_chat_template tmpl(read_file("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), - "", ""); + auto tmpls = read_templates("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"); std::vector end_tokens{ "<|end▁of▁sentence|>" }; - assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format); - assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_params_init(tmpl, inputs_tools_think).format); + assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_tools).format); + assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format); - test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - test_template(tmpl, end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_think), + test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + test_templates(tmpls.get(), end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + assert_msg_equals(message_assist_thoughts_unparsed_think, common_chat_parse("I'm thinkingHello, world!\nWhat's up?", COMMON_CHAT_FORMAT_DEEPSEEK_R1)); - assert_msg_equals(msg_from_json(message_assist_thoughts), + assert_msg_equals(message_assist_thoughts, common_chat_parse("I'm thinkingHello, world!\nWhat's up?", COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING)); - assert_msg_equals(msg_from_json(message_assist_thoughts), + assert_msg_equals(message_assist_thoughts, // Latest template update (ast of 20250209) adds a trailing \n if add_generation_prompt is true. common_chat_parse("I'm thinkingHello, world!\nWhat's up?", COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING)); - // test_template(tmpl, end_tokens, message_assist_call, tools, + // test_templates(tmpls.get(), end_tokens, message_assist_call, tools, // "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n" // "```json\n" // "{\"arg1\": 1}\n" @@ -697,23 +749,22 @@ static void test_template_output_parsers() { } { // Replacement DeepSeek R1 template. Makes the Distill Qwen 7B/32B models happy to call tools and all. - const common_chat_template tmpl(read_file("models/templates/llama-cpp-deepseek-r1.jinja"), - "", ""); + auto tmpls = read_templates("models/templates/llama-cpp-deepseek-r1.jinja"); std::vector end_tokens{ "<|end▁of▁sentence|>" }; - assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format); - assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_params_init(tmpl, inputs_tools_think).format); + assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_tools).format); + assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format); - test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - test_template(tmpl, end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_think), + test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + test_templates(tmpls.get(), end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + assert_msg_equals(message_assist_thoughts_unparsed_think, common_chat_parse("I'm thinkingHello, world!\nWhat's up?", COMMON_CHAT_FORMAT_DEEPSEEK_R1)); - assert_msg_equals(msg_from_json(message_assist_thoughts), + assert_msg_equals(message_assist_thoughts, common_chat_parse("I'm thinkingHello, world!\nWhat's up?", COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING)); - assert_msg_equals(msg_from_json(message_assist_call_thoughts_unparsed), + assert_msg_equals(message_assist_call_thoughts_unparsed, common_chat_parse( "I'm\nthinking\n\n" "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n" @@ -721,7 +772,7 @@ static void test_template_output_parsers() { "{\"arg1\": 1}\n" "```<|tool▁call▁end|><|tool▁calls▁end|>", COMMON_CHAT_FORMAT_DEEPSEEK_R1)); - assert_msg_equals(msg_from_json(message_assist_call_thoughts), + assert_msg_equals(message_assist_call_thoughts, common_chat_parse( "I'm\nthinking\n\n" "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n" @@ -729,7 +780,7 @@ static void test_template_output_parsers() { "{\"arg1\": 1}\n" "```<|tool▁call▁end|><|tool▁calls▁end|>", COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING)); - test_template(tmpl, end_tokens, message_assist_call, tools, + test_templates(tmpls.get(), end_tokens, message_assist_call, tools, "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n" "```json\n" "{\"arg1\": 1}\n" @@ -738,38 +789,46 @@ static void test_template_output_parsers() { } int main(int argc, char ** argv) { + try { #ifndef _WIN32 - if (argc > 1) { - common_chat_inputs inputs; - inputs.messages = { - { { "role", "user" }, { "content", "Hey" } } - }; - inputs.tools = json::array({ special_function_tool }); - - std::cout << "| Template | Format |\n"; - std::cout << "|----------|--------|\n"; - - for (int i = 1; i < argc; i++) { - try { - std::string path = argv[i]; - if (path.rfind(".jinja") != path.size() - 6) { - std::cerr << "Skipping non-jinja file: " << path << std::endl; - continue; + if (argc > 1) { + common_chat_templates_inputs inputs; + common_chat_msg msg; + msg.role = "user"; + msg.content = "Hey"; + inputs.messages = {msg}; + inputs.tools = { special_function_tool }; + + std::cout << "| Template | Format |\n"; + std::cout << "|----------|--------|\n"; + + for (int i = 1; i < argc; i++) { + try { + std::string path = argv[i]; + if (path.rfind(".jinja") != path.size() - 6) { + std::cerr << "Skipping non-jinja file: " << path << '\n'; + continue; + } + auto tmpls = read_templates(path); + auto parts = string_split(path, "/"); + auto name = parts[parts.size() - 1]; + auto format = common_chat_format_name(common_chat_templates_apply(tmpls.get(), inputs).format); + std::cout << "| " << name << " | " << format << " |\n"; + } catch (const std::exception & e) { + std::cerr << "Failed to process " << argv[i] << ": " << e.what() << '\n'; } - common_chat_template tmpl(read_file(path), "", ""); - auto parts = string_split(path, "/"); - auto name = parts[parts.size() - 1]; - auto format = common_chat_format_name(common_chat_params_init(tmpl, inputs).format); - std::cout << "| " << name << " | " << format << " |\n"; - } catch (const std::exception & e) { - std::cerr << "Failed to process " << argv[i] << ": " << e.what() << std::endl; } - } - } else + } else #endif - { - test_template_output_parsers(); - std::cout << "\n[chat] All tests passed!" << std::endl; + { + test_msgs_oaicompat_json_conversion(); + test_tools_oaicompat_json_conversion(); + test_template_output_parsers(); + std::cout << "\n[chat] All tests passed!" << '\n'; + } + return 0; + } catch (const std::exception & e) { + std::cerr << "Error: " << e.what() << '\n'; + return 1; } - return 0; }