* tool-call refactoring: moved common_chat_* to chat.h, common_chat_templates_init returns a unique_ptr to an opaque type
* addressed clang-tidy lints in [test-]chat.*
* rm minja deps from util & common & move it to common/minja/
* add name & tool_call_id to common_chat_msg
* add common_chat_tool
* added json <-> tools, msgs conversions to chat.h
* fix double bos/eos jinja avoidance hack (was preventing inner bos/eos tokens)
* fix deepseek r1 slow test (no longer <think> opening w/ new template)
* allow empty tools w/ auto + grammar
* fix & test server grammar & json_schema params w/ & w/o --jinja
examples/server/index.html.hpp \
examples/server/loading.html.hpp \
common/chat.cpp \
- common/chat.hpp \
+ common/chat.h \
common/chat-template.hpp \
common/json.hpp \
common/minja.hpp \
arg.h
base64.hpp
chat.cpp
- chat.hpp
- chat-template.hpp
+ chat.h
common.cpp
common.h
console.cpp
llguidance.cpp
log.cpp
log.h
- minja.hpp
+ minja/chat-template.hpp
+ minja/minja.hpp
ngram-cache.cpp
ngram-cache.h
sampling.cpp
#include "log.h"
#include "sampling.h"
+#include "chat.h"
#include <algorithm>
#include <climits>
+++ /dev/null
-/*
- Copyright 2024 Google LLC
-
- Use of this source code is governed by an MIT-style
- license that can be found in the LICENSE file or at
- https://opensource.org/licenses/MIT.
-*/
-// SPDX-License-Identifier: MIT
-#pragma once
-
-#include "minja.hpp"
-#include <json.hpp>
-#include <string>
-#include <vector>
-
-using json = nlohmann::ordered_json;
-
-namespace minja {
-
-struct chat_template_caps {
- bool supports_tools = false;
- bool supports_tool_calls = false;
- bool supports_tool_responses = false;
- bool supports_system_role = false;
- bool supports_parallel_tool_calls = false;
- bool supports_tool_call_id = false;
- // meta-llama/Llama-3.1-8B-Instruct expects arguments to be an object.
- // Most other templates (and OpenAI's API) expect the arguments object to be stringified.
- bool requires_object_arguments = false;
- // CohereForAI/c4ai-command-r-plus simple variant
- bool requires_non_null_content = false;
- // MiniMaxAI/MiniMax-Text-01 special
- bool requires_typed_content = false;
-};
-
-struct chat_template_inputs {
- nlohmann::ordered_json messages;
- nlohmann::ordered_json tools;
- bool add_generation_prompt = true;
- nlohmann::ordered_json extra_context;
- std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
-};
-
-struct chat_template_options {
- bool apply_polyfills = true;
- bool use_bos_token = true;
- bool use_eos_token = true;
- bool define_strftime_now = true;
-
- bool polyfill_tools = true;
- bool polyfill_tool_call_examples = true;
- bool polyfill_tool_calls = true;
- bool polyfill_tool_responses = true;
- bool polyfill_system_role = true;
- bool polyfill_object_arguments = true;
- bool polyfill_typed_content = true;
-};
-
-class chat_template {
-
- private:
- chat_template_caps caps_;
- std::string source_;
- std::string bos_token_;
- std::string eos_token_;
- std::shared_ptr<minja::TemplateNode> template_root_;
- std::string tool_call_example_;
-
- std::string try_raw_render(
- const nlohmann::ordered_json & messages,
- const nlohmann::ordered_json & tools,
- bool add_generation_prompt,
- const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
- {
- try {
- chat_template_inputs inputs;
- inputs.messages = messages;
- inputs.tools = tools;
- inputs.add_generation_prompt = add_generation_prompt;
- inputs.extra_context = extra_context;
- // Use fixed date for tests
- inputs.now = std::chrono::system_clock::from_time_t(0);
-
- chat_template_options opts;
- opts.apply_polyfills = false;
-
- auto prompt = apply(inputs, opts);
- // fprintf(stderr, "try_raw_render: %s\n", prompt.c_str());
- return prompt;
- } catch (const std::exception & e) {
- // fprintf(stderr, "try_raw_render error: %s\n", e.what());
- return "";
- }
- }
-
- public:
-
- chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token)
- : source_(source), bos_token_(bos_token), eos_token_(eos_token)
- {
- template_root_ = minja::Parser::parse(source_, {
- /* .trim_blocks = */ true,
- /* .lstrip_blocks = */ true,
- /* .keep_trailing_newline = */ false,
- });
-
- auto contains = [](const std::string & haystack, const std::string & needle) {
- return haystack.find(needle) != std::string::npos;
- };
-
- const std::string user_needle = "<User Needle>";
- const std::string sys_needle = "<System Needle>";
- const json dummy_str_user_msg = {{"role", "user"}, {"content", user_needle}};
- const json dummy_typed_user_msg = {{"role", "user"}, {"content", json::array({{{"type", "text"}, {"text", user_needle}}})}};
-
- caps_.requires_typed_content =
- !contains(try_raw_render(json::array({dummy_str_user_msg}), {}, false), user_needle)
- && contains(try_raw_render(json::array({dummy_typed_user_msg}), {}, false), user_needle);
-
- const auto dummy_user_msg = caps_.requires_typed_content
- ? dummy_typed_user_msg
- : dummy_str_user_msg;
- const json needle_system_msg = {
- {"role", "system"},
- {"content", caps_.requires_typed_content ? json::array({{{"type", "text"}, {"text", sys_needle}}}) : json(sys_needle)},
- };
-
- caps_.supports_system_role = contains(try_raw_render({needle_system_msg, dummy_user_msg,}, {}, false), sys_needle);
-
- auto out = try_raw_render(json::array({
- dummy_user_msg
- }), json::array({
- {
- {"name", "some_tool"},
- {"type", "function"},
- {"function", {
- {"name", "some_tool"},
- {"description", "Some tool."},
- {"parameters", {
- {"type", "object"},
- {"properties", {
- {"arg", {
- {"type", "string"},
- {"description", "Some argument."},
- }},
- }},
- {"required", json::array({ "arg" })},
- }},
- }},
- },
- }), false);
- caps_.supports_tools = contains(out, "some_tool");
-
- auto make_tool_calls_msg = [&](const json & tool_calls) {
- return json {
- {"role", "assistant"},
- {"content", nullptr},
- {"tool_calls", tool_calls},
- };
- };
- auto make_tool_call = [](const std::string & tool_name, const json & arguments) {
- return json {
- {"id", "call_1___"},
- {"type", "function"},
- {"function", {
- {"arguments", arguments},
- {"name", tool_name},
- }},
- };
- };
- const json dummy_args_obj {{"argument_needle", "print('Hello, World!')"}};
-
- // Note: the arguments are rendered in both cases, but may be double-escaped, which we don't want.
- out = try_raw_render(json::array({
- dummy_user_msg,
- make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj.dump())})),
- }), {}, false);
- auto tool_call_renders_str_arguments = contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
- out = try_raw_render(json::array({
- dummy_user_msg,
- make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj)})),
- }), {}, false);
- auto tool_call_renders_obj_arguments = contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
-
- caps_.supports_tool_calls = tool_call_renders_str_arguments || tool_call_renders_obj_arguments;
- caps_.requires_object_arguments = !tool_call_renders_str_arguments && tool_call_renders_obj_arguments;
- auto out_empty = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", ""}}}), {}, false);
- auto out_null = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", nullptr}}}), {}, false);
- caps_.requires_non_null_content = contains(out_empty, user_needle) && !contains(out_null, user_needle);
-
- if (caps_.supports_tool_calls) {
- auto dummy_args = caps_.requires_object_arguments ? dummy_args_obj : json(dummy_args_obj.dump());
- auto tc1 = make_tool_call("test_tool1", dummy_args);
- auto tc2 = make_tool_call("test_tool2", dummy_args);
- auto out = try_raw_render(json::array({
- dummy_user_msg,
- make_tool_calls_msg(json::array({tc1, tc2})),
- }), {}, false);
- caps_.supports_parallel_tool_calls = contains(out, "test_tool1") && contains(out, "test_tool2");
-
- out = try_raw_render(json::array({
- dummy_user_msg,
- make_tool_calls_msg(json::array({tc1})),
- {
- {"role", "tool"},
- {"name", "test_tool1"},
- {"content", "Some response!"},
- {"tool_call_id", "call_911_"},
- }
- }), {}, false);
- caps_.supports_tool_responses = contains(out, "Some response!");
- caps_.supports_tool_call_id = contains(out, "call_911_");
- }
-
- try {
- if (!caps_.supports_tools) {
- const json user_msg {
- {"role", "user"},
- {"content", "Hey"},
- };
- const json args {
- {"arg1", "some_value"},
- };
- const json tool_call_msg {
- {"role", "assistant"},
- {"content", nullptr},
- {"tool_calls", json::array({
- {
- // TODO: detect if requires numerical id or fixed length == 6 like Nemo
- {"id", "call_1___"},
- {"type", "function"},
- {"function", {
- {"name", "tool_name"},
- {"arguments", (caps_.requires_object_arguments ? args : json(minja::Value(args).dump(-1, /* to_json= */ true)))},
- }},
- },
- })},
- };
- std::string prefix, full;
- {
- chat_template_inputs inputs;
- inputs.messages = json::array({user_msg});
- inputs.add_generation_prompt = true;
- prefix = apply(inputs);
- }
- {
- chat_template_inputs inputs;
- inputs.messages = json::array({user_msg, tool_call_msg});
- inputs.add_generation_prompt = false;
- full = apply(inputs);
- }
- auto eos_pos_last = full.rfind(eos_token_);
- if (eos_pos_last == prefix.size() - eos_token_.size() ||
- (full[full.size() - 1] == '\n' && (eos_pos_last == full.size() - eos_token_.size() - 1))) {
- full = full.substr(0, eos_pos_last);
- }
- size_t common_prefix_length = 0;
- for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) {
- if (prefix[i] != full[i]) {
- break;
- }
- if (prefix[i] == '<') {
- // DeepSeek R1's template (as of 20250209) adds a trailing <think> if add_generation_prompt,
- // but it removes thinking tags for past messages.
- // The prefix and full strings diverge at <think> vs. <|tool▁calls▁begin|>, we avoid consuming the leading <.
- continue;
- }
- common_prefix_length = i + 1;
- }
- auto example = full.substr(common_prefix_length);
- if (example.find("tool_name") == std::string::npos && example.find("some_value") == std::string::npos) {
- fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n");
- } else {
- tool_call_example_ = example;
- }
- }
- } catch (const std::exception & e) {
- fprintf(stderr, "Failed to generate tool call example: %s\n", e.what());
- }
- }
-
- const std::string & source() const { return source_; }
- const std::string & bos_token() const { return bos_token_; }
- const std::string & eos_token() const { return eos_token_; }
- const chat_template_caps & original_caps() const { return caps_; }
-
- // Deprecated, please use the form with chat_template_inputs and chat_template_options
- std::string apply(
- const nlohmann::ordered_json & messages,
- const nlohmann::ordered_json & tools,
- bool add_generation_prompt,
- const nlohmann::ordered_json & extra_context = nlohmann::ordered_json(),
- bool apply_polyfills = true)
- {
- fprintf(stderr, "[%s] Deprecated!\n", __func__);
- chat_template_inputs inputs;
- inputs.messages = messages;
- inputs.tools = tools;
- inputs.add_generation_prompt = add_generation_prompt;
- inputs.extra_context = extra_context;
- inputs.now = std::chrono::system_clock::now();
-
- chat_template_options opts;
- opts.apply_polyfills = apply_polyfills;
-
- return apply(inputs, opts);
- }
-
- std::string apply(
- const chat_template_inputs & inputs,
- const chat_template_options & opts = chat_template_options()) const
- {
- json actual_messages;
-
- auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
- auto has_tool_calls = false;
- auto has_tool_responses = false;
- auto has_string_content = false;
- for (const auto & message : inputs.messages) {
- if (message.contains("tool_calls") && !message["tool_calls"].is_null()) {
- has_tool_calls = true;
- }
- if (message.contains("role") && message["role"] == "tool") {
- has_tool_responses = true;
- }
- if (message.contains("content") && message["content"].is_string()) {
- has_string_content = true;
- }
- }
-
- auto polyfill_system_role = opts.polyfill_system_role && !caps_.supports_system_role;
- auto polyfill_tools = opts.polyfill_tools && has_tools && !caps_.supports_tools;
- auto polyfill_tool_call_example = polyfill_tools && opts.polyfill_tool_call_examples;
- auto polyfill_tool_calls = opts.polyfill_tool_calls && has_tool_calls && !caps_.supports_tool_calls;
- auto polyfill_tool_responses = opts.polyfill_tool_responses && has_tool_responses && !caps_.supports_tool_responses;
- auto polyfill_object_arguments = opts.polyfill_object_arguments && has_tool_calls && caps_.requires_object_arguments;
- auto polyfill_typed_content = opts.polyfill_typed_content && has_string_content && caps_.requires_typed_content;
-
- auto needs_polyfills = opts.apply_polyfills && (false
- || polyfill_system_role
- || polyfill_tools
- || polyfill_tool_calls
- || polyfill_tool_responses
- || polyfill_object_arguments
- || polyfill_typed_content
- );
-
- if (needs_polyfills) {
- actual_messages = json::array();
-
- auto add_message = [&](const json & msg) {
- if (polyfill_typed_content && msg.contains("content") && !msg.at("content").is_null() && msg.at("content").is_string()) {
- actual_messages.push_back({
- {"role", msg.at("role")},
- {"content", {{
- {"type", "text"},
- {"text", msg.at("content")},
- }}},
- });
- } else {
- actual_messages.push_back(msg);
- }
- };
-
- std::string pending_system;
- auto flush_sys = [&]() {
- if (!pending_system.empty()) {
- add_message({
- {"role", "user"},
- {"content", pending_system},
- });
- pending_system.clear();
- }
- };
-
- json adjusted_messages;
- if (polyfill_tools) {
- adjusted_messages = add_system(inputs.messages,
- "You can call any of the following tools to satisfy the user's requests: " + minja::Value(inputs.tools).dump(2, /* to_json= */ true) +
- (!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_ + "\n\n"));
- } else {
- adjusted_messages = inputs.messages;
- }
-
- for (const auto & message_ : adjusted_messages) {
- auto message = message_;
- if (!message.contains("role") || !message.contains("content")) {
- throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
- }
- std::string role = message.at("role");
-
- if (message.contains("tool_calls")) {
- if (polyfill_object_arguments || polyfill_tool_calls) {
- for (auto & tool_call : message.at("tool_calls")) {
- if (tool_call["type"] == "function") {
- auto & function = tool_call.at("function");
- auto & arguments = function.at("arguments");
- if (arguments.is_string()) {
- try {
- arguments = json::parse(arguments.get<std::string>());
- } catch (const std::exception & ecvt) {
- fprintf(stderr, "Failed to parse arguments: %s\n", ecvt.what());
- }
- }
- }
- }
- }
- if (polyfill_tool_calls) {
- auto content = message.at("content");
- auto tool_calls = json::array();
- for (const auto & tool_call : message.at("tool_calls")) {
- if (tool_call.at("type") != "function") {
- continue;
- }
- const auto & function = tool_call.at("function");
- auto tc = json {
- {"name", function.at("name")},
- {"arguments", function.at("arguments")},
- };
- if (tool_call.contains("id")) {
- tc["id"] = tool_call["id"];
- }
- tool_calls.push_back(tc);
- }
- auto obj = json {
- {"tool_calls", tool_calls},
- };
- if (!content.is_null() && content != "") {
- obj["content"] = content;
- }
- message["content"] = obj.dump(2);
- message.erase("tool_calls");
- }
- }
- if (polyfill_tool_responses && role == "tool") {
- message["role"] = "user";
- auto obj = json {
- {"tool_response", {
- {"content", message.at("content")},
- }},
- };
- if (message.contains("name")) {
- obj["tool_response"]["name"] = message.at("name");
- }
- if (message.contains("tool_call_id")) {
- obj["tool_response"]["tool_call_id"] = message.at("tool_call_id");
- }
- message["content"] = obj.dump(2);
- message.erase("name");
- }
-
- if (!message["content"].is_null() && polyfill_system_role) {
- std::string content = message.at("content");
- if (role == "system") {
- if (!pending_system.empty()) pending_system += "\n";
- pending_system += content;
- continue;
- } else {
- if (role == "user") {
- if (!pending_system.empty()) {
- message["content"] = pending_system + (content.empty() ? "" : "\n" + content);
- pending_system.clear();
- }
- } else {
- flush_sys();
- }
- }
- }
- add_message(message);
- }
- flush_sys();
- } else {
- actual_messages = inputs.messages;
- }
-
- auto context = minja::Context::make(json({
- {"messages", actual_messages},
- {"add_generation_prompt", inputs.add_generation_prompt},
- }));
- context->set("bos_token", opts.use_bos_token ? bos_token_ : "");
- context->set("eos_token", opts.use_eos_token ? eos_token_ : "");
- if (opts.define_strftime_now) {
- auto now = inputs.now;
- context->set("strftime_now", Value::callable([now](const std::shared_ptr<minja::Context> &, minja::ArgumentsValue & args) {
- args.expectArgs("strftime_now", {1, 1}, {0, 0});
- auto format = args.args[0].get<std::string>();
-
- auto time = std::chrono::system_clock::to_time_t(now);
- auto local_time = *std::localtime(&time);
- std::ostringstream ss;
- ss << std::put_time(&local_time, format.c_str());
- return ss.str();
- }));
- }
- if (!inputs.tools.is_null()) {
- context->set("tools", minja::Value(inputs.tools));
- }
- if (!inputs.extra_context.is_null()) {
- for (auto & kv : inputs.extra_context.items()) {
- context->set(kv.key(), minja::Value(kv.value()));
- }
- }
-
- auto ret = template_root_->render(context);
- // fprintf(stderr, "actual_messages: %s\n", actual_messages.dump(2).c_str());
- // fprintf(stderr, "apply: %s\n\n", ret.c_str());
- return ret;
- }
-
- static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages, const std::string & system_prompt) {
- json messages_with_system = messages;
-
- if (messages_with_system.size() > 0 && messages_with_system[0].at("role") == "system") {
- std::string existing_system = messages_with_system.at(0).at("content");
- messages_with_system[0] = json {
- {"role", "system"},
- {"content", existing_system + "\n\n" + system_prompt},
- };
- } else {
- messages_with_system.insert(messages_with_system.begin(), json {
- {"role", "system"},
- {"content", system_prompt},
- });
- }
- return messages_with_system;
- }
-};
-
-} // namespace minja
-#include "chat.hpp"
-#include "chat-template.hpp"
+#include "chat.h"
#include "json-schema-to-grammar.h"
#include "log.h"
-#include "minja.hpp"
+#include "minja/chat-template.hpp"
+#include "minja/minja.hpp"
+
+#include <optional>
+
+// Local alias so the rest of this file does not have to spell out the minja type.
+typedef minja::chat_template common_chat_template;
+
+// Bundle of parsed chat templates. It is created by common_chat_templates_init
+// and destroyed by common_chat_templates_free.
+struct common_chat_templates {
+ bool has_explicit_template; // Model had builtin template or template override was specified.
+ std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
+ std::unique_ptr<common_chat_template> template_tool_use; // "tool_use" variant; null when no such template source was found or it failed to parse
+};
+
+// Parameter bundle taken by the common_chat_params_init_* helpers in this file
+// (e.g. common_chat_params_init_generic, common_chat_params_init_mistral_nemo).
+struct templates_params {
+ json messages;
+ json tools;
+ common_chat_tool_choice tool_choice;
+ json json_schema;
+ bool parallel_tool_calls;
+ bool stream;
+ std::string grammar;
+ bool add_generation_prompt = true;
+ bool extract_reasoning = true;
+};
+
+// Maps a `tool_choice` string ("auto", "none" or "required") to the matching
+// common_chat_tool_choice enumerator.
+// Throws std::runtime_error for any other value; the comparison is exact and case-sensitive.
+common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
+ if (tool_choice == "auto") {
+ return COMMON_CHAT_TOOL_CHOICE_AUTO;
+ }
+ if (tool_choice == "none") {
+ return COMMON_CHAT_TOOL_CHOICE_NONE;
+ }
+ if (tool_choice == "required") {
+ return COMMON_CHAT_TOOL_CHOICE_REQUIRED;
+ }
+ throw std::runtime_error("Invalid tool_choice: " + tool_choice);
+}
+
+// Converts a JSON array of chat messages into common_chat_msg values.
+//
+// Each message must be an object with a "role" and a "content" key. "content" may be
+// a string, null, or an array of parts whose "type" is "text". The optional keys
+// "reasoning_content", "name", "tool_call_id" and "tool_calls" are copied when present.
+//
+// Any failure is rethrown as std::runtime_error prefixed with "Failed to parse messages: "
+// and followed by a dump of the whole input.
+template <>
+std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messages) {
+ std::vector<common_chat_msg> msgs;
+
+ try {
+
+ if (!messages.is_array()) {
+ throw std::runtime_error("Expected 'messages' to be an array, got " + messages.dump());
+ }
+
+ for (const auto & message : messages) {
+ if (!message.is_object()) {
+ throw std::runtime_error("Expected 'message' to be an object, got " + message.dump());
+ }
+
+ common_chat_msg msg;
+ if (!message.contains("role")) {
+ throw std::runtime_error("Missing 'role' in message: " + message.dump());
+ }
+ msg.role = message.at("role");
+
+ // "content" is mandatory, but may be null (nothing is stored in that case).
+ if (message.contains("content")) {
+ const auto & content = message.at("content");
+ if (content.is_string()) {
+ msg.content = content;
+ } else if (content.is_array()) {
+ // Typed content: only "text" parts are accepted, anything else is an error.
+ for (const auto & part : content) {
+ if (!part.contains("type")) {
+ throw std::runtime_error("Missing content part type: " + part.dump());
+ }
+ const auto & type = part.at("type");
+ if (type != "text") {
+ throw std::runtime_error("Unsupported content part type: " + type.dump());
+ }
+ common_chat_msg_content_part msg_part;
+ msg_part.type = type;
+ msg_part.text = part.at("text");
+ msg.content_parts.push_back(msg_part);
+ }
+ } else if (!content.is_null()) {
+ throw std::runtime_error("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
+ }
+ } else {
+ throw std::runtime_error("Expected 'content' (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
+ }
+ if (message.contains("reasoning_content")) {
+ msg.reasoning_content = message.at("reasoning_content");
+ }
+ if (message.contains("name")) {
+ msg.tool_name = message.at("name");
+ }
+ if (message.contains("tool_call_id")) {
+ msg.tool_call_id = message.at("tool_call_id");
+ }
+ if (message.contains("tool_calls")) {
+ // Only tool calls of type "function" are supported.
+ for (const auto & tool_call : message.at("tool_calls")) {
+ common_chat_tool_call tc;
+ if (!tool_call.contains("type")) {
+ throw std::runtime_error("Missing tool call type: " + tool_call.dump());
+ }
+ const auto & type = tool_call.at("type");
+ if (type != "function") {
+ throw std::runtime_error("Unsupported tool call type: " + tool_call.dump());
+ }
+ if (!tool_call.contains("function")) {
+ throw std::runtime_error("Missing tool call function: " + tool_call.dump());
+ }
+ const auto & fc = tool_call.at("function");
+ if (!fc.contains("name")) {
+ throw std::runtime_error("Missing tool call name: " + tool_call.dump());
+ }
+ tc.name = fc.at("name");
+ // NOTE(review): "arguments" has no explicit presence check; a missing key
+ // surfaces as the json library's own at() error, wrapped by the catch below.
+ tc.arguments = fc.at("arguments");
+ if (tool_call.contains("id")) {
+ tc.id = tool_call.at("id");
+ }
+ msg.tool_calls.push_back(tc);
+ }
+ }
+
+ msgs.push_back(msg);
+ }
+ } catch (const std::exception & e) {
+ // Re-wrap so callers get one error type plus the full offending payload.
+ throw std::runtime_error("Failed to parse messages: " + std::string(e.what()) + "; messages = " + messages.dump(2));
+ }
+
+ return msgs;
+}
+
+// Serializes common_chat_msg values into a JSON array of message objects.
+//
+// msgs:              messages to convert; a message may set `content` or `content_parts`, not both
+//                    (std::runtime_error is thrown otherwise).
+// concat_typed_text: when true, "text" parts are joined with '\n' into one string "content"
+//                    (other part types are skipped with a warning); when false, the parts are
+//                    emitted as an array of {"type", "text"} objects.
+//
+// "content" is always emitted (null when the message has neither content nor parts).
+// "reasoning_content", "name", "tool_call_id" and "tool_calls" are emitted only when non-empty.
+template <>
+json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text) {
+ json messages = json::array();
+ for (const auto & msg : msgs) {
+ if (!msg.content.empty() && !msg.content_parts.empty()) {
+ throw std::runtime_error("Cannot specify both content and content_parts");
+ }
+ json jmsg {
+ {"role", msg.role},
+ };
+ if (!msg.content.empty()) {
+ jmsg["content"] = msg.content;
+ } else if (!msg.content_parts.empty()) {
+ if (concat_typed_text) {
+ // Flatten typed parts into a single newline-separated string.
+ std::string text;
+ for (const auto & part : msg.content_parts) {
+ if (part.type != "text") {
+ LOG_WRN("Ignoring content part type: %s\n", part.type.c_str());
+ continue;
+ }
+ if (!text.empty()) {
+ text += '\n';
+ }
+ text += part.text;
+ }
+ jmsg["content"] = text;
+ } else {
+ auto & parts = jmsg["content"] = json::array();
+ for (const auto & part : msg.content_parts) {
+ parts.push_back({
+ {"type", part.type},
+ {"text", part.text},
+ });
+ }
+ }
+ } else {
+ jmsg["content"] = json(); // null
+ }
+ if (!msg.reasoning_content.empty()) {
+ jmsg["reasoning_content"] = msg.reasoning_content;
+ }
+ if (!msg.tool_name.empty()) {
+ jmsg["name"] = msg.tool_name;
+ }
+ if (!msg.tool_call_id.empty()) {
+ jmsg["tool_call_id"] = msg.tool_call_id;
+ }
+ if (!msg.tool_calls.empty()) {
+ auto & tool_calls = jmsg["tool_calls"] = json::array();
+ for (const auto & tool_call : msg.tool_calls) {
+ // Every tool call is written with type "function"; the id is optional.
+ json tc {
+ {"type", "function"},
+ {"function", {
+ {"name", tool_call.name},
+ {"arguments", tool_call.arguments},
+ }},
+ };
+ if (!tool_call.id.empty()) {
+ tc["id"] = tool_call.id;
+ }
+ tool_calls.push_back(tc);
+ }
+ }
+ messages.push_back(jmsg);
+ }
+ return messages;
+}
+
+// String overload: parses `messages` as JSON text, then delegates to the
+// common_chat_msgs_parse_oaicompat overload that takes the parsed JSON value.
+template <>
+std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const std::string & messages) {
+ return common_chat_msgs_parse_oaicompat(json::parse(messages));
+}
+
+// Converts a JSON `tools` value into common_chat_tool entries.
+//
+// tools: either null (yields an empty vector) or an array of objects of the form
+//        {"type": "function", "function": {"name": ..., "description": ..., "parameters": ...}}.
+//        "name" is required. "description" defaults to an empty string and "parameters"
+//        to an empty object when omitted.
+//
+// Returns one common_chat_tool per array element; `parameters` holds the dumped JSON text.
+// Any failure is rethrown as std::runtime_error prefixed with "Failed to parse tools: "
+// and followed by a dump of the whole input.
+template <>
+std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & tools) {
+    std::vector<common_chat_tool> result;
+
+    try {
+        if (!tools.is_null()) {
+            if (!tools.is_array()) {
+                throw std::runtime_error("Expected 'tools' to be an array, got " + tools.dump());
+            }
+            for (const auto & tool : tools) {
+                if (!tool.contains("type")) {
+                    throw std::runtime_error("Missing tool type: " + tool.dump());
+                }
+                const auto & type = tool.at("type");
+                if (!type.is_string() || type != "function") {
+                    throw std::runtime_error("Unsupported tool type: " + tool.dump());
+                }
+                if (!tool.contains("function")) {
+                    throw std::runtime_error("Missing tool function: " + tool.dump());
+                }
+
+                const auto & function = tool.at("function");
+                if (!function.contains("name")) {
+                    throw std::runtime_error("Missing tool function name: " + tool.dump());
+                }
+                // Only the name is mandatory; do not reject tools that omit
+                // "description" or "parameters".
+                result.push_back({
+                    /* .name = */ function.at("name"),
+                    /* .description = */ function.value("description", ""),
+                    /* .parameters = */ function.value("parameters", json::object()).dump(),
+                });
+            }
+        }
+    } catch (const std::exception & e) {
+        // Re-wrap so callers get one error type plus the full offending payload.
+        throw std::runtime_error("Failed to parse tools: " + std::string(e.what()) + "; tools = " + tools.dump(2));
+    }
+
+    return result;
+}
+
+// String overload: parses `tools` as JSON text, then delegates to the
+// common_chat_tools_parse_oaicompat overload that takes the parsed JSON value.
+template <>
+std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const std::string & tools) {
+ return common_chat_tools_parse_oaicompat(json::parse(tools));
+}
+
+// Serializes common_chat_tool entries into a JSON array of
+// {"type": "function", "function": {"name", "description", "parameters"}} objects.
+// Returns a null JSON value (not an empty array) when `tools` is empty.
+// Each tool's `parameters` string is parsed back into JSON here, so it is expected
+// to hold valid JSON text.
+template <>
+json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools) {
+ if (tools.empty()) {
+ return json();
+ }
+
+ auto result = json::array();
+ for (const auto & tool : tools) {
+ result.push_back({
+ {"type", "function"},
+ {"function", {
+ {"name", tool.name},
+ {"description", tool.description},
+ {"parameters", json::parse(tool.parameters)},
+ }},
+ });
+ }
+ return result;
+}
+
+// Checks whether `tmpl` can be used to format a one-message ("user", "test") chat.
+//
+// use_jinja == true:  builds templates from `tmpl` without a model and applies them;
+//                     returns false (and logs the error) if anything throws.
+// use_jinja == false: asks llama_chat_apply_template to format the same chat with no
+//                     output buffer and returns whether its result is non-negative.
+//
+// NOTE(review): with use_jinja and an empty `tmpl`, common_chat_templates_init is called
+// with a null model and hits its GGML_ASSERT(model != nullptr) rather than throwing.
+bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
+ if (use_jinja) {
+ try {
+ common_chat_msg msg;
+ msg.role = "user";
+ msg.content = "test";
+
+ auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl);
+
+ common_chat_templates_inputs inputs;
+ inputs.messages = {msg};
+
+ common_chat_templates_apply(tmpls.get(), inputs);
+ return true;
+ } catch (const std::exception & e) {
+ LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
+ return false;
+ }
+ }
+ llama_chat_message chat[] = {{"user", "test"}};
+ const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
+ return res >= 0;
+}
+
+// Formats only the text that `new_msg` adds to an existing conversation.
+//
+// The history (`past_msg`) is rendered without a generation prompt, then history + `new_msg`
+// is rendered with `add_ass` as the generation-prompt flag. The result is the part of the
+// second rendering that follows the length of the first.
+//
+// NOTE(review): the diff is taken by length only; it assumes the second rendering starts
+// with the first. If the second rendering were shorter, std::string::substr would throw
+// std::out_of_range.
+std::string common_chat_format_single(
+ const struct common_chat_templates * tmpls,
+ const std::vector<common_chat_msg> & past_msg,
+ const common_chat_msg & new_msg,
+ bool add_ass,
+ bool use_jinja) {
+
+ common_chat_templates_inputs inputs;
+ inputs.use_jinja = use_jinja;
+
+ std::string fmt_past_msg;
+ if (!past_msg.empty()) {
+ inputs.messages = past_msg;
+ inputs.add_generation_prompt = false;
+ fmt_past_msg = common_chat_templates_apply(tmpls, inputs).prompt;
+ }
+ std::ostringstream ss;
+ // if the past_msg ends with a newline, we must preserve it in the formatted version
+ if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
+ ss << "\n";
+ };
+ // format chat with new_msg
+ inputs.messages.push_back(new_msg);
+ inputs.add_generation_prompt = add_ass;
+ auto fmt_new_msg = common_chat_templates_apply(tmpls, inputs).prompt;
+ // get the diff part
+ ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
+ return ss.str();
+}
+
+// Renders a fixed four-message sample conversation (system, user, assistant, user)
+// with the given templates and returns the resulting prompt text.
+std::string common_chat_format_example(const struct common_chat_templates * tmpls, bool use_jinja) {
+ common_chat_templates_inputs inputs;
+ inputs.use_jinja = use_jinja;
+ // Appends a plain role/content message to the sample conversation.
+ auto add_simple_msg = [&](auto role, auto content) {
+ common_chat_msg msg;
+ msg.role = role;
+ msg.content = content;
+ inputs.messages.push_back(msg);
+ };
+ add_simple_msg("system", "You are a helpful assistant");
+ add_simple_msg("user", "Hello");
+ add_simple_msg("assistant", "Hi there");
+ add_simple_msg("user", "How are you?");
+ return common_chat_templates_apply(tmpls, inputs).prompt;
+}
+
+// Built-in ChatML template source. common_chat_templates_init uses it when no other
+// template source is available, or when the default template fails to parse.
+#define CHATML_TEMPLATE_SRC \
+ "{%- for message in messages -%}\n" \
+ " {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \
+ "{%- endfor -%}\n" \
+ "{%- if add_generation_prompt -%}\n" \
+ " {{- '<|im_start|>assistant\n' -}}\n" \
+ "{%- endif -%}"
+
+// Destroys a templates bundle with `delete`; a null pointer is a no-op.
+void common_chat_templates_free(struct common_chat_templates * tmpls) {
+ delete tmpls;
+}
+
+// Returns the bundle's has_explicit_template flag. `tmpls` must not be null.
+bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls) {
+ return tmpls->has_explicit_template;
+}
+
+// Returns the source text of one of the templates held by `tmpls`.
+//
+// variant: nullptr selects the default template. "tool_use" selects the tool-use
+//          template and yields nullptr when the bundle has none. Any other value is
+//          logged at debug level and falls back to the default template.
+//
+// The returned pointer refers to the string returned by the template's source()
+// accessor; it is not a copy.
+const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant) {
+ if (variant != nullptr) {
+ if (strcmp(variant, "tool_use") == 0) {
+ if (tmpls->template_tool_use) {
+ return tmpls->template_tool_use->source().c_str();
+ }
+ return nullptr;
+ } else {
+ LOG_DBG("%s: unknown template variant: %s\n", __func__, variant);
+ }
+ }
+ return tmpls->template_default->source().c_str();
+}
+
+// Builds the chat templates bundle for a model and/or an explicit template override.
+//
+// model:                  may be null only when `chat_template_override` is non-empty
+//                         (asserted otherwise).
+// chat_template_override: when non-empty, used as the default template source and the
+//                         model's own templates are not read.
+// bos_token_override /
+// eos_token_override:     BOS/EOS strings handed to the templates.
+//                         NOTE(review): when `model` is non-null these are replaced by the
+//                         model vocab's tokens (or by an empty string if the vocab has none).
+//
+// Fallbacks: if the default source is empty or equals "chatml", the model's "tool_use"
+// template is used when present, otherwise the built-in ChatML template. If the default
+// template fails to parse, the built-in ChatML template is used. If the tool-use template
+// fails to parse, it is left unset.
+common_chat_templates_ptr common_chat_templates_init(
+ const struct llama_model * model,
+ const std::string & chat_template_override,
+ const std::string & bos_token_override,
+ const std::string & eos_token_override)
+{
+ std::string default_template_src;
+ std::string template_tool_use_src;
+
+ bool has_explicit_template = !chat_template_override.empty();
+ if (chat_template_override.empty()) {
+ GGML_ASSERT(model != nullptr);
+ const auto * str = llama_model_chat_template(model, /* name */ nullptr);
+ if (str) {
+ default_template_src = str;
+ has_explicit_template = true;
+ }
+ str = llama_model_chat_template(model, /* name */ "tool_use");
+ if (str) {
+ template_tool_use_src = str;
+ has_explicit_template = true;
+ }
+ } else {
+ default_template_src = chat_template_override;
+ }
+ if (default_template_src.empty() || default_template_src == "chatml") {
+ if (!template_tool_use_src.empty()) {
+ default_template_src = template_tool_use_src;
+ } else {
+ default_template_src = CHATML_TEMPLATE_SRC;
+ }
+ }
+ std::string token_bos = bos_token_override;
+ std::string token_eos = eos_token_override;
+ if (model) {
+ const auto * vocab = llama_model_get_vocab(model);
+ // Resolves a special token to its text; warns when the vocab lacks the token
+ // but one of the template sources mentions the corresponding variable name.
+ const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {
+ if (token == LLAMA_TOKEN_NULL) {
+ if (default_template_src.find(jinja_variable_name) != std::string::npos
+ || template_tool_use_src.find(jinja_variable_name) != std::string::npos) {
+ LOG_WRN("common_chat_templates_init: warning: vocab does not have a %s token, jinja template won't work as intended.\n", name);
+ }
+ return std::string();
+ }
+ return common_token_to_piece(vocab, token, true);
+ };
+ token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
+ token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
+ }
+ common_chat_templates_ptr tmpls(new common_chat_templates());
+ tmpls->has_explicit_template = has_explicit_template;
+ try {
+ tmpls->template_default = std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos);
+ } catch (const std::exception & e) {
+ LOG_ERR("%s: failed to parse chat template (defaulting to chatml): %s \n", __func__, e.what());
+ tmpls->template_default = std::make_unique<minja::chat_template>(CHATML_TEMPLATE_SRC, token_bos, token_eos);
+ }
+ if (!template_tool_use_src.empty()) {
+ try {
+ tmpls->template_tool_use = std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos);
+ } catch (const std::exception & e) {
+ LOG_ERR("%s: failed to parse tool use chat template (ignoring it): %s\n", __func__, e.what());
+ }
+ }
+ return tmpls;
+}
std::string common_chat_format_name(common_chat_format format) {
switch (format) {
json_error_locator() : position(0), found_error(false) {}
- bool parse_error(std::size_t position, const std::string &, const json::exception &) override {
+ bool parse_error(std::size_t position, const std::string &, const json::exception &) override { // NOLINT
this->position = position - 1;
this->found_error = true;
return false;
}
- bool null() override { return true; }
- bool boolean(bool) override { return true; }
- bool number_integer(number_integer_t) override { return true; }
- bool number_unsigned(number_unsigned_t) override { return true; }
- bool number_float(number_float_t, const string_t &) override { return true; }
- bool string(string_t &) override { return true; }
- bool binary(binary_t &) override { return true; }
- bool start_object(std::size_t) override { return true; }
- bool key(string_t &) override { return true; }
+ bool null() override { return true; } // NOLINT
+ bool boolean(bool) override { return true; } // NOLINT
+ bool number_integer(number_integer_t) override { return true; } // NOLINT
+ bool number_unsigned(number_unsigned_t) override { return true; } // NOLINT
+ bool number_float(number_float_t, const string_t &) override { return true; } // NOLINT
+ bool string(string_t &) override { return true; } // NOLINT
+ bool binary(binary_t &) override { return true; } // NOLINT
+ bool start_object(std::size_t) override { return true; } // NOLINT
+ bool key(string_t &) override { return true; } // NOLINT
bool end_object() override { return true; }
- bool start_array(std::size_t) override { return true; }
+ bool start_array(std::size_t) override { return true; } // NOLINT
bool end_array() override { return true; }
};
json_error_locator err_loc;
// tmpl_inputs.now = std::chrono::system_clock::now();
minja::chat_template_options tmpl_opts;
- tmpl_opts.use_bos_token = false;
- tmpl_opts.use_eos_token = false;
-
- return tmpl.apply(tmpl_inputs, tmpl_opts);
+ // To avoid double BOS / EOS tokens, we're manually removing beginning / trailing tokens
+ // instead of using `chat_template_options.use_bos_token = false`, since these tokens
+ // may be needed inside the template / between messages too.
+ auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
+ if (string_starts_with(result, tmpl.bos_token())) {
+ result = result.substr(tmpl.bos_token().size());
+ }
+ if (string_ends_with(result, tmpl.eos_token())) {
+ result = result.substr(0, result.size() - tmpl.eos_token().size());
+ }
+ return result;
}
-static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
auto tool_call_schemas = json::array();
{"required", json::array({"tool_call"})},
};
const auto schema =
- inputs.tool_choice != "required"
+ inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED
? json {
{"anyOf", json::array({
tool_call,
return result;
}
-static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
- data.grammar_lazy = inputs.tool_choice != "required";
+ data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
foreach_function(inputs.tools, [&](const json & tool) {
return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
}
-static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
- data.grammar_lazy = inputs.tool_choice != "required";
+ data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
foreach_function(inputs.tools, [&](const json & tool) {
const auto & parameters_required = parameters.at("required");
for (const auto & prop : expected_properties) {
if (!parameters_properties.contains(prop)) {
- throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop);
+ throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop); // NOLINT
}
if (std::find(parameters_required.begin(), parameters_required.end(), json(prop)) == parameters_required.end()) {
- throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop);
+ throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop); // NOLINT
}
}
if (parameters_properties.size() != expected_properties.size()) {
}
}
-static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct common_chat_inputs & inputs, bool allow_python_tag_builtin_tools) {
+static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
auto builtin_tools = json::array();
common_chat_params data;
- data.grammar_lazy = inputs.tool_choice != "required";
+ data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
- if (name == "wolfram_alpha") {
+ if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") {
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
- expect_tool_parameters(name, parameters, {"query"});
- } else if (name == "web_search" || name == "brave_search") {
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
expect_tool_parameters(name, parameters, {"query"});
} else if (name == "python" || name == "code_interpreter") {
std::vector<std::string> kvs;
for (const auto & [key, value] : parameters.at("properties").items()) {
- kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value));
+ kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT
}
tool_rules.push_back(
auto arg_value_str = raw_args.substr(it_eq + 1);
auto arg_value = json::parse(arg_value_str);
- return {
- /* .role = */ "assistant",
- /* .content = */ match.prefix().str(),
- /* .tool_calls = */ {
- {
- /* .name = */ match[1],
- /* .arguments = */ (json {
- {arg_name, arg_value},
- }).dump(),
- /* .id = */ "",
- },
- },
- };
+ common_chat_msg msg;
+ msg.role = "assistant";
+ msg.content = match.prefix().str();
+ msg.tool_calls.push_back({
+ /* .name = */ name,
+ /* .arguments = */ (json {
+ {arg_name, arg_value},
+ }).dump(),
+ /* .id = */ "",
+ });
+ return msg;
}
}
return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
}
-static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
if (inputs.tools.is_array() && !inputs.tools.empty()) {
- data.grammar_lazy = inputs.tool_choice != "required" && inputs.json_schema.is_null();
+ data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED && inputs.json_schema.is_null();
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
std::string name = function.at("name");
auto parameters = function.at("parameters");
+ builder.resolve_refs(parameters);
auto args_rule = builder.add_schema(name + "-args", parameters);
tool_rules.push_back(builder.add_rule(name + "-call",
"\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n"
return msg;
}
-static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
- fprintf(stderr, "%s\n", __func__);
+static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
+ LOG_DBG("%s\n", __func__);
common_chat_params data;
data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
{"datetime", "Jan 29 2025 13:00:00 GMT"},
{"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
});
if (inputs.tools.is_array() && !inputs.tools.empty()) {
- data.grammar_lazy = inputs.tool_choice != "required";
+ data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
foreach_function(inputs.tools, [&](const json & tool) {
return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
}
-static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct templates_params & inputs) {
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
common_chat_params data;
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
if (inputs.tools.is_array() && !inputs.tools.empty()) {
- data.grammar_lazy = inputs.tool_choice != "required";
+ data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> first_tool_rules;
std::vector<std::string> subsequent_tool_rules;
const auto & function = tool.at("function");
std::string name = function.at("name");
auto parameters = function.at("parameters");
+ builder.resolve_refs(parameters);
auto args_rule = builder.add_schema(name + "-args", parameters);
first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
}
}
-static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
common_chat_params data;
json tools = inputs.tools.is_null() ? inputs.tools : json::array();
std::string python_code_argument_name;
auto has_raw_python = false;
- data.grammar_lazy = inputs.tool_choice != "required";
+ data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(inputs.tools, [&](const json & tool) {
throw std::runtime_error("Missing type in python tool");
}
has_raw_python = true;
- auto type = parameters.at("type");
+ const auto & type = parameters.at("type");
if (type == "object") {
auto properties = parameters.at("properties");
for (auto it = properties.begin(); it != properties.end(); ++it) {
std::smatch match;
if (std::regex_search(input, match, python_tag_regex)) {
auto code = match[1].str();
- return {
- /* .role = */ "assistant",
- /* .content = */ match.prefix().str(),
- /* .tool_calls = */ {
- {
- /* .name = */ "python",
- /* .arguments = */ (json {{"code", code}}).dump(),
- /* .id = */ "",
- },
- }
- };
+ common_chat_msg msg;
+ msg.role = "assistant";
+ msg.content = match.prefix().str();
+ msg.tool_calls.push_back({
+ /* .name = */ "python",
+ /* .arguments = */ (json {{"code", code}}).dump(),
+ /* .id = */ "",
+ });
+ return msg;
}
static std::regex function_regex(R"(<function=(\w+)>)");
static std::regex close_regex(R"(</function>)");
return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
}
-static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
- data.grammar_lazy = inputs.tool_choice != "required";
+ data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(inputs.tools, [&](const json & tool) {
std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
std::regex end_pattern(R"([\n\s]*</tool_call>[\n\s]*$)");
+ common_chat_msg msg;
+ msg.role = "assistant";
+
auto end = input.end();
std::sregex_iterator rend;
std::sregex_iterator rit(input.begin(), end, start_pattern);
if (rit == rend) {
- return {
- /* .role = */ "assistant",
- /* .content = */ input,
- /* .tool_calls = */ {},
- };
+ msg.content = input;
+ return msg;
}
- common_chat_msg result;
- result.role = "assistant";
- result.content = rit->prefix();
+ msg.content = rit->prefix();
auto it = rit->suffix().first;
while (it != end) {
throw std::runtime_error("Failed to parse json tool call");
}
const auto & arguments = call.at("arguments");
- result.tool_calls.push_back({
+ msg.tool_calls.push_back({
call.at("name"),
arguments.dump(),
// arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
break;
}
}
- return result;
+ return msg;
} catch (const std::exception & e) {
- return {
- /* .role = */ "assistant",
- /* .content = */ input,
- /* .tool_calls = */ {},
- };
+ LOG_ERR("Failed to parse hermes 2 pro input: %s\n", e.what());
+ common_chat_msg msg;
+ msg.role = "assistant";
+ msg.content = input;
+ return msg;
}
}
-static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+// Plain handler: renders the prompt with the template and tags the result as content-only.
+// Tools are still handed to the template when the list is non-empty; no grammar is set here.
+static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
return data;
}
-common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_templates_apply_jinja(
+ const struct common_chat_templates * tmpls,
+ const struct common_chat_templates_inputs & inputs)
+{
+ templates_params params;
+ params.tools = common_chat_tools_to_json_oaicompat<json>(inputs.tools);
+ const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use
+ ? *tmpls->template_tool_use
+ : *tmpls->template_default;
const auto & src = tmpl.source();
const auto & caps = tmpl.original_caps();
+ params.messages = common_chat_msgs_to_json_oaicompat<json>(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content);
+ params.add_generation_prompt = inputs.add_generation_prompt;
+ params.extract_reasoning = inputs.extract_reasoning;
+ params.tool_choice = inputs.tool_choice;
+ params.grammar = inputs.grammar;
+ if (!inputs.json_schema.empty()) {
+ params.json_schema = json::parse(inputs.json_schema);
+ }
- if (inputs.tools.is_array()) {
- if (inputs.tool_choice != "none" && !inputs.grammar.empty()) {
+ if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) {
+ LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n");
+ params.parallel_tool_calls = false;
+ } else {
+ params.parallel_tool_calls = inputs.parallel_tool_calls;
+ }
+
+ if (params.tools.is_array()) {
+ if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) {
throw std::runtime_error("Cannot specify grammar with tools");
}
if (caps.supports_tool_calls && !caps.supports_tools) {
}
// DeepSeek R1: use handler in all cases except json schema (thinking / tools).
- if (src.find("<|tool▁calls▁begin|>") != std::string::npos && inputs.json_schema.is_null()) {
- return common_chat_params_init_deepseek_r1(tmpl, inputs);
+ if (src.find("<|tool▁calls▁begin|>") != std::string::npos && params.json_schema.is_null()) {
+ return common_chat_params_init_deepseek_r1(tmpl, params);
}
// Command R7B: : use handler in all cases except json schema (thinking / tools).
- if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && inputs.json_schema.is_null()) {
- return common_chat_params_init_command_r7b(tmpl, inputs);
+ if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && params.json_schema.is_null()) {
+ return common_chat_params_init_command_r7b(tmpl, params);
}
// Use generic handler when mixing tools + JSON schema.
// TODO: support that mix in handlers below.
- if ((!inputs.tools.is_array() && inputs.json_schema.is_object())) {
- return common_chat_params_init_generic(tmpl, inputs);
+ if ((params.tools.is_array() && params.json_schema.is_object())) {
+ return common_chat_params_init_generic(tmpl, params);
}
// Functionary prepends "all\n" to plain content outputs, so we use its handler in all cases.
if (src.find(">>>all") != std::string::npos) {
- return common_chat_params_init_functionary_v3_2(tmpl, inputs);
+ return common_chat_params_init_functionary_v3_2(tmpl, params);
}
// Firefunction v2 requires datetime and functions in the context even w/o tools, so we also use its handler in all cases.
if (src.find(" functools[") != std::string::npos) {
- return common_chat_params_init_firefunction_v2(tmpl, inputs);
+ return common_chat_params_init_firefunction_v2(tmpl, params);
}
// Plain handler (no tools)
- if (inputs.tools.is_null() || inputs.tool_choice == "none") {
- return common_chat_params_init_without_tools(tmpl, inputs);
+ if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
+ return common_chat_params_init_without_tools(tmpl, params);
}
// Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
if (src.find("<tool_call>") != std::string::npos) {
- return common_chat_params_init_hermes_2_pro(tmpl, inputs);
+ return common_chat_params_init_hermes_2_pro(tmpl, params);
}
// Functionary v3.1 (w/ tools)
if (src.find("<|start_header_id|>") != std::string::npos
&& src.find("<function=") != std::string::npos) {
- return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, inputs);
+ return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, params);
}
// Llama 3.1, 3.2, 3.3 (w/ tools)
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
- return common_chat_params_init_llama_3_1_tool_calls(tmpl, inputs, allow_python_tag_builtin_tools);
+ return common_chat_params_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools);
}
// Mistral Nemo (w/ tools)
if (src.find("[TOOL_CALLS]") != std::string::npos) {
- return common_chat_params_init_mistral_nemo(tmpl, inputs);
+ return common_chat_params_init_mistral_nemo(tmpl, params);
}
// Generic fallback
- return common_chat_params_init_generic(tmpl, inputs);
+ return common_chat_params_init_generic(tmpl, params);
+}
+
+// Legacy template route (ad-hoc C++ implementation of known templates): forwards to llama_chat_apply_template.
+//
+// Each message is flattened to a single string: its `content` followed by its "text" content
+// parts, newline-separated (non-text parts are skipped with a warning). The prompt is rendered
+// from the default template's source. The grammar is generated from inputs.json_schema when it
+// is non-empty; otherwise inputs.grammar is passed through unchanged.
+//
+// Throws std::runtime_error when llama_chat_apply_template reports a negative result.
+static common_chat_params common_chat_templates_apply_legacy(
+    const struct common_chat_templates * tmpls,
+    const struct common_chat_templates_inputs & inputs)
+{
+    size_t alloc_size = 0;
+    std::vector<llama_chat_message> chat;
+    std::vector<std::string> contents;
+    chat.reserve(inputs.messages.size());
+    contents.reserve(inputs.messages.size());
+
+    // First pass: build every flattened content string. `chat` keeps raw c_str() pointers,
+    // so `contents` has to be fully populated before any pointer into it is taken.
+    for (const auto & msg : inputs.messages) {
+        auto content = msg.content;
+        for (const auto & part : msg.content_parts) {
+            if (part.type != "text") {
+                LOG_WRN("Ignoring non-text content part: %s\n", part.type.c_str());
+                continue;
+            }
+            if (!content.empty()) {
+                content += "\n";
+            }
+            content += part.text;
+        }
+        contents.emplace_back(std::move(content));
+    }
+
+    // Second pass: reference the now-stable strings and estimate the output size.
+    for (size_t i = 0; i < contents.size(); ++i) {
+        const auto & msg = inputs.messages[i];
+        const auto & content = contents[i];
+        chat.push_back({msg.role.c_str(), content.c_str()});
+        // Initial guess of 125% of the raw text length, kept in integer arithmetic
+        // (same value as truncating `len * 1.25`, without the float round-trip).
+        const size_t raw_len = msg.role.size() + content.size();
+        alloc_size += raw_len + raw_len / 4;
+    }
+
+    std::vector<char> buf(alloc_size);
+
+    // run the first time to get the total output length
+    const auto & src = tmpls->template_default->source();
+    int32_t res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size());
+
+    // error: chat template is not supported
+    if (res < 0) {
+        // if the custom "tmpl" is not supported, we throw an error
+        // this is a bit redundant (for good), since we're not sure if user validated the custom template with common_chat_verify_template()
+        throw std::runtime_error("this custom template is not supported");
+    }
+
+    // if it turns out that our buffer is too small, we resize it
+    if (static_cast<size_t>(res) > buf.size()) {
+        buf.resize(res);
+        res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size());
+    }
+
+    common_chat_params params;
+    params.prompt = std::string(buf.data(), res);
+    if (!inputs.json_schema.empty()) {
+        // A JSON schema takes precedence over an explicit grammar.
+        params.grammar = json_schema_to_grammar(json::parse(inputs.json_schema));
+    } else {
+        params.grammar = inputs.grammar;
+    }
+    return params;
+}
+
+// Public entry point: renders the chat described by `inputs`.
+// Takes the Jinja route when inputs.use_jinja is set, the legacy route otherwise.
+// `tmpls` must be non-null (asserted).
+common_chat_params common_chat_templates_apply(
+    const struct common_chat_templates * tmpls,
+    const struct common_chat_templates_inputs & inputs)
+{
+    GGML_ASSERT(tmpls != nullptr);
+    if (inputs.use_jinja) {
+        return common_chat_templates_apply_jinja(tmpls, inputs);
+    }
+    return common_chat_templates_apply_legacy(tmpls, inputs);
}
+// Parser for content-only output: wraps the raw text in an assistant message without tool calls.
static common_chat_msg common_chat_parse_content_only(const std::string & input) {
- return {
- /* .role = */ "assistant",
- /* .content = */ input,
- /* .tool_calls = */ {},
- };
+ common_chat_msg msg;
+ msg.role = "assistant";
+ msg.content = input;
+ return msg;
}
common_chat_msg common_chat_parse(const std::string & input, common_chat_format format) {
--- /dev/null
+// Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers.
+
+#pragma once
+
+#include "common.h"
+#include <string>
+#include <vector>
+
+struct common_chat_templates;
+
+// A single tool (function) call attached to a chat message.
+struct common_chat_tool_call {
+ std::string name; // name of the called tool
+ std::string arguments; // argument payload as text; the parsers in chat.cpp store dumped JSON here
+ std::string id; // call identifier; several parsers in chat.cpp leave it as ""
+};
+
+// One typed piece of a multi-part message content.
+struct common_chat_msg_content_part {
+ std::string type; // part kind; the legacy template route only keeps parts whose type is "text"
+ std::string text; // text payload of the part
+};
+
+// A chat message, either supplied as template input or produced by common_chat_parse.
+struct common_chat_msg {
+ std::string role; // the parsers in chat.cpp set "assistant" for model output
+ std::string content; // plain text content
+ std::vector<common_chat_msg_content_part> content_parts = {}; // typed parts; the legacy route appends the "text" ones to `content`, newline-separated
+ std::vector<common_chat_tool_call> tool_calls = {}; // tool calls carried by the message
+ std::string reasoning_content; // presumably the extracted reasoning ("thinking") text — TODO confirm
+ std::string tool_name; // presumably the tool a tool-result message belongs to — TODO confirm
+ std::string tool_call_id; // presumably the id of the call this message answers — TODO confirm
+};
+
+// Description of a tool that can be offered to the model.
+struct common_chat_tool {
+ std::string name;
+ std::string description;
+ std::string parameters; // presumably the parameters JSON schema serialized as text — TODO confirm
+};
+
+// How the model is allowed to use the supplied tools.
+enum common_chat_tool_choice {
+ COMMON_CHAT_TOOL_CHOICE_AUTO, // default value in common_chat_templates_inputs
+ COMMON_CHAT_TOOL_CHOICE_REQUIRED, // tool handlers in chat.cpp disable lazy grammar for this choice
+ COMMON_CHAT_TOOL_CHOICE_NONE, // the jinja route uses the plain (no tools) handler unless an earlier template-specific handler applies
+};
+
+// Output format tag: stored in common_chat_params::format by the template handlers and
+// accepted by common_chat_parse / common_chat_format_name.
+enum common_chat_format {
+ COMMON_CHAT_FORMAT_CONTENT_ONLY,
+ COMMON_CHAT_FORMAT_GENERIC,
+ COMMON_CHAT_FORMAT_MISTRAL_NEMO,
+ COMMON_CHAT_FORMAT_LLAMA_3_X,
+ COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
+ COMMON_CHAT_FORMAT_DEEPSEEK_R1,
+ COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING,
+ COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
+ COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
+ COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
+ COMMON_CHAT_FORMAT_HERMES_2_PRO,
+ COMMON_CHAT_FORMAT_COMMAND_R7B,
+ COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING,
+
+ COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
+};
+
+// Inputs consumed by common_chat_templates_apply.
+struct common_chat_templates_inputs {
+ std::vector<common_chat_msg> messages;
+ std::string grammar; // grammar text; the legacy route passes it through when json_schema is empty
+ std::string json_schema; // JSON schema as text (parsed with json::parse); empty means "none"
+ bool add_generation_prompt = true;
+ bool use_jinja = true; // true: Jinja route, false: legacy route
+ // Parameters below only supported when use_jinja is true
+ std::vector<common_chat_tool> tools;
+ common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
+ bool parallel_tool_calls = false; // the jinja route turns this off when the template caps lack support
+ bool extract_reasoning = true;
+};
+
+// Result of applying a chat template: the rendered prompt plus grammar settings.
+struct common_chat_params {
+ common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY; // which handler produced this / which parser to use
+ std::string prompt; // rendered prompt text
+ std::string grammar; // grammar text; empty when unconstrained
+ bool grammar_lazy = false; // tool handlers set this to (tool_choice != REQUIRED)
+ std::vector<common_grammar_trigger> grammar_triggers; // presumably the triggers that activate a lazy grammar — TODO confirm
+ std::vector<std::string> preserved_tokens;
+ std::vector<std::string> additional_stops;
+};
+
+// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
+bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
+
+// Invoked by common_chat_templates_deleter on the handle owned by a common_chat_templates_ptr.
+void common_chat_templates_free(struct common_chat_templates * tmpls);
+
+// Deleter functor that forwards to common_chat_templates_free.
+struct common_chat_templates_deleter { void operator()(common_chat_templates * tmpls) { common_chat_templates_free(tmpls); } };
+
+// Owning handle for the opaque common_chat_templates type.
+// NOTE(review): std::unique_ptr needs <memory>, which this header does not include directly; presumably it arrives via common.h — confirm.
+typedef std::unique_ptr<struct common_chat_templates, common_chat_templates_deleter> common_chat_templates_ptr;
+
+// Creates the templates handle for `model`; the token overrides default to "".
+// NOTE(review): the implementation in chat.cpp appears to fall back to the built-in chatml template when the
+// default template fails to parse, and to ignore a tool-use template that fails to parse — confirm.
+common_chat_templates_ptr common_chat_templates_init(
+ const struct llama_model * model,
+ const std::string & chat_template_override,
+ const std::string & bos_token_override = "",
+ const std::string & eos_token_override = "");
+
+// NOTE(review): presumably true when the model ships a template or an override was given — confirm.
+bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
+// NOTE(review): `variant` presumably selects a named template (e.g. "tool_use"), nullptr meaning the default — confirm.
+const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant = nullptr);
+
+
+// Renders the prompt and grammar settings for `inputs`: Jinja route when inputs.use_jinja is true,
+// legacy route otherwise. `tmpls` must be non-null (asserted).
+struct common_chat_params common_chat_templates_apply(
+ const struct common_chat_templates * tmpls,
+ const struct common_chat_templates_inputs & inputs);
+
+// Format single message, while taking into account the position of that message in chat history
+std::string common_chat_format_single(
+ const struct common_chat_templates * tmpls,
+ const std::vector<common_chat_msg> & past_msg,
+ const common_chat_msg & new_msg,
+ bool add_ass,
+ bool use_jinja);
+
+// Returns an example of formatted chat
+std::string common_chat_format_example(
+ const struct common_chat_templates * tmpls,
+ bool use_jinja);
+
+// Maps a format value to its name string.
+std::string common_chat_format_name(common_chat_format format);
+// Parses model output `input` into a message, using the parser that matches `format`.
+common_chat_msg common_chat_parse( const std::string & input, common_chat_format format);
+
+// Converts an OpenAI-style tool_choice string into the enum.
+// NOTE(review): accepted strings are presumably "auto" / "required" / "none" — confirm against the definition.
+common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
+
+// Parses a JSON array of messages in OpenAI's chat completion API format.
+// T can be std::string containing JSON or nlohmann::ordered_json
+template <class T> std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const T & messages);
+// Reverse direction: serializes messages to the same JSON shape.
+// NOTE(review): the jinja route passes !requires_typed_content for concat_typed_text; presumably it merges
+// typed text parts into a plain string content — confirm.
+template <class T> T common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false);
+
+// Parses a JSON array of tools in OpenAI's chat completion tool call API format.
+// T can be std::string containing JSON or nlohmann::ordered_json
+template <class T> std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const T & tools);
+// Reverse direction: serializes tools to the same JSON shape.
+template <class T> T common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);
+++ /dev/null
-// Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers.
-
-#pragma once
-
-#include "common.h"
-#include <json.hpp>
-#include <optional>
-#include <string>
-#include <vector>
-
-using json = nlohmann::ordered_json;
-
-struct common_chat_inputs {
- json messages;
- json tools;
- json tool_choice;
- json json_schema;
- bool parallel_tool_calls;
- bool stream;
- std::string grammar;
- bool add_generation_prompt = true;
- bool extract_reasoning = true;
-};
-
-enum common_chat_format {
- COMMON_CHAT_FORMAT_CONTENT_ONLY,
- COMMON_CHAT_FORMAT_GENERIC,
- COMMON_CHAT_FORMAT_MISTRAL_NEMO,
- COMMON_CHAT_FORMAT_LLAMA_3_X,
- COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
- COMMON_CHAT_FORMAT_DEEPSEEK_R1,
- COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING,
- COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
- COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
- COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
- COMMON_CHAT_FORMAT_HERMES_2_PRO,
- COMMON_CHAT_FORMAT_COMMAND_R7B,
- COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING,
-
- COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
-};
-
-struct common_chat_params {
- common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
- json prompt;
- std::string grammar;
- bool grammar_lazy = false;
- std::vector<common_grammar_trigger> grammar_triggers;
- std::vector<std::string> preserved_tokens;
- std::vector<std::string> additional_stops;
-};
-
-struct common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & params);
-std::string common_chat_format_name(common_chat_format format);
-common_chat_msg common_chat_parse( const std::string & input, common_chat_format format);
#include "json.hpp"
#include "json-schema-to-grammar.h"
#include "llama.h"
-#include "chat.hpp"
-#include "chat-template.hpp"
#include <algorithm>
#include <cinttypes>
return text;
}
-//
-// Chat template utils
-//
-
-bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
- if (use_jinja) {
- try {
- auto chat_template = common_chat_template(tmpl, "<s>", "</s>");
- common_chat_inputs inputs;
- inputs.messages = json::array({{
- {"role", "user"},
- {"content", "test"},
- }});
- common_chat_params_init(chat_template, inputs);
- return true;
- } catch (const std::exception & e) {
- LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
- return false;
- }
- }
- llama_chat_message chat[] = {{"user", "test"}};
- const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
- return res >= 0;
-}
-
-std::string common_chat_apply_template(
- const common_chat_template & tmpl,
- const std::vector<common_chat_msg> & msgs,
- bool add_ass,
- bool use_jinja) {
- if (use_jinja) {
- auto messages = json::array();
- for (const auto & msg : msgs) {
- messages.push_back({{"role", msg.role}, {"content", msg.content}});
- }
- common_chat_inputs inputs;
- inputs.messages = messages;
- inputs.add_generation_prompt = add_ass;
- return common_chat_params_init(tmpl, inputs).prompt;
- }
-
- int alloc_size = 0;
- std::vector<llama_chat_message> chat;
- for (const auto & msg : msgs) {
- chat.push_back({msg.role.c_str(), msg.content.c_str()});
- alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
- }
-
- std::vector<char> buf(alloc_size);
-
- // run the first time to get the total output length
- int32_t res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
-
- // error: chat template is not supported
- if (res < 0) {
- // if the custom "tmpl" is not supported, we throw an error
- // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
- throw std::runtime_error("this custom template is not supported");
- }
-
- // if it turns out that our buffer is too small, we resize it
- if ((size_t) res > buf.size()) {
- buf.resize(res);
- res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
- }
-
- std::string formatted_chat(buf.data(), res);
- return formatted_chat;
-}
-
-std::string common_chat_format_single(
- const common_chat_template & tmpl,
- const std::vector<common_chat_msg> & past_msg,
- const common_chat_msg & new_msg,
- bool add_ass,
- bool use_jinja) {
- std::ostringstream ss;
- auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(tmpl, past_msg, false, use_jinja);
- std::vector<common_chat_msg> chat_new(past_msg);
- // if the past_msg ends with a newline, we must preserve it in the formatted version
- if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
- ss << "\n";
- };
- // format chat with new_msg
- chat_new.push_back(new_msg);
- auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja);
- // get the diff part
- ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
- return ss.str();
-}
-
-std::string common_chat_format_example(const common_chat_template & tmpl, bool use_jinja) {
- std::vector<common_chat_msg> msgs = {
- {"system", "You are a helpful assistant", {}},
- {"user", "Hello", {}},
- {"assistant", "Hi there", {}},
- {"user", "How are you?", {}},
- };
- return common_chat_apply_template(tmpl, msgs, true, use_jinja);
-}
-
-#define CHATML_TEMPLATE_SRC \
- "{%- for message in messages -%}\n" \
- " {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \
- "{%- endfor -%}\n" \
- "{%- if add_generation_prompt -%}\n" \
- " {{- '<|im_start|>assistant\n' -}}\n" \
- "{%- endif -%}"
-
-common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override)
-{
- std::string default_template_src;
- std::string template_tool_use_src;
-
- bool has_explicit_template = !chat_template_override.empty();
- if (chat_template_override.empty()) {
- auto str = llama_model_chat_template(model, /* name */ nullptr);
- if (str) {
- default_template_src = str;
- has_explicit_template = true;
- }
- str = llama_model_chat_template(model, /* name */ "tool_use");
- if (str) {
- template_tool_use_src = str;
- has_explicit_template = true;
- }
- } else {
- default_template_src = chat_template_override;
- }
- if (default_template_src.empty() || default_template_src == "chatml") {
- if (!template_tool_use_src.empty()) {
- default_template_src = template_tool_use_src;
- } else {
- default_template_src = CHATML_TEMPLATE_SRC;
- }
- }
- auto vocab = llama_model_get_vocab(model);
- const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {
- if (token == LLAMA_TOKEN_NULL) {
- if (default_template_src.find(jinja_variable_name) != std::string::npos
- || template_tool_use_src.find(jinja_variable_name) != std::string::npos) {
- LOG_WRN("%s: warning: vocab does not have a %s token, jinja template won't work as intended.\n", __func__, name);
- }
- return std::string();
- } else {
- return common_token_to_piece(vocab, token, true);
- }
- };
- auto token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
- auto token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
- try {
- return {
- has_explicit_template,
- std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos),
- template_tool_use_src.empty()
- ? nullptr
- : std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos),
- };
- } catch (const std::exception & e) {
- LOG_ERR("%s: failed to parse chat template: %s\n", __func__, e.what());
- return {
- has_explicit_template,
- std::make_unique<minja::chat_template>(CHATML_TEMPLATE_SRC, token_bos, token_eos),
- nullptr,
- };
- }
-}
-
//
// KV cache utils
//
const std::vector<llama_token> & tokens,
bool special = true);
-//
-// Chat template utils
-//
-
-struct common_tool_call {
- std::string name;
- std::string arguments;
- std::string id;
-};
-
-// same with llama_chat_message, but uses std::string
-struct common_chat_msg {
- std::string role;
- std::string content;
- std::vector<common_tool_call> tool_calls;
- std::string reasoning_content = "";
-};
-
-// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
-bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
-
-namespace minja {
- class chat_template;
-}
-
-typedef minja::chat_template common_chat_template;
-
-struct common_chat_templates {
- bool has_explicit_template; // Model had builtin template or template overridde was specified.
- std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
- std::unique_ptr<common_chat_template> template_tool_use;
-};
-
-// CPP wrapper for llama_chat_apply_template
-// If the built-in template is not supported, we default to chatml
-// If the custom "tmpl" is not supported, we throw an error
-std::string common_chat_apply_template(
- const common_chat_template & tmpl,
- const std::vector<common_chat_msg> & chat,
- bool add_ass,
- bool use_jinja);
-
-// Format single message, while taking into account the position of that message in chat history
-std::string common_chat_format_single(
- const common_chat_template & tmpl,
- const std::vector<common_chat_msg> & past_msg,
- const common_chat_msg & new_msg,
- bool add_ass,
- bool use_jinja);
-
-// Returns an example of formatted chat
-std::string common_chat_format_example(
- const common_chat_template & tmpl, bool use_jinja);
-
-common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override);
-
//
// KV cache utils
//
+++ /dev/null
-/*
- Copyright 2024 Google LLC
-
- Use of this source code is governed by an MIT-style
- license that can be found in the LICENSE file or at
- https://opensource.org/licenses/MIT.
-*/
-// SPDX-License-Identifier: MIT
-#pragma once
-
-#include <iostream>
-#include <string>
-#include <vector>
-#include <regex>
-#include <memory>
-#include <stdexcept>
-#include <sstream>
-#include <unordered_set>
-#include <json.hpp>
-
-using json = nlohmann::ordered_json;
-
-namespace minja {
-
-class Context;
-
-struct Options {
- bool trim_blocks; // removes the first newline after a block
- bool lstrip_blocks; // removes leading whitespace on the line of the block
- bool keep_trailing_newline; // don't remove last newline
-};
-
-struct ArgumentsValue;
-
-inline std::string normalize_newlines(const std::string & s) {
-#ifdef _WIN32
- static const std::regex nl_regex("\r\n");
- return std::regex_replace(s, nl_regex, "\n");
-#else
- return s;
-#endif
-}
-
-/* Values that behave roughly like in Python. */
-class Value : public std::enable_shared_from_this<Value> {
-public:
- using CallableType = std::function<Value(const std::shared_ptr<Context> &, ArgumentsValue &)>;
- using FilterType = std::function<Value(const std::shared_ptr<Context> &, ArgumentsValue &)>;
-
-private:
- using ObjectType = nlohmann::ordered_map<json, Value>; // Only contains primitive keys
- using ArrayType = std::vector<Value>;
-
- std::shared_ptr<ArrayType> array_;
- std::shared_ptr<ObjectType> object_;
- std::shared_ptr<CallableType> callable_;
- json primitive_;
-
- Value(const std::shared_ptr<ArrayType> & array) : array_(array) {}
- Value(const std::shared_ptr<ObjectType> & object) : object_(object) {}
- Value(const std::shared_ptr<CallableType> & callable) : object_(std::make_shared<ObjectType>()), callable_(callable) {}
-
- /* Python-style string repr */
- static void dump_string(const json & primitive, std::ostringstream & out, char string_quote = '\'') {
- if (!primitive.is_string()) throw std::runtime_error("Value is not a string: " + primitive.dump());
- auto s = primitive.dump();
- if (string_quote == '"' || s.find('\'') != std::string::npos) {
- out << s;
- return;
- }
- // Reuse json dump, just changing string quotes
- out << string_quote;
- for (size_t i = 1, n = s.size() - 1; i < n; ++i) {
- if (s[i] == '\\' && s[i + 1] == '"') {
- out << '"';
- i++;
- } else if (s[i] == string_quote) {
- out << '\\' << string_quote;
- } else {
- out << s[i];
- }
- }
- out << string_quote;
- }
- void dump(std::ostringstream & out, int indent = -1, int level = 0, bool to_json = false) const {
- auto print_indent = [&](int level) {
- if (indent > 0) {
- out << "\n";
- for (int i = 0, n = level * indent; i < n; ++i) out << ' ';
- }
- };
- auto print_sub_sep = [&]() {
- out << ',';
- if (indent < 0) out << ' ';
- else print_indent(level + 1);
- };
-
- auto string_quote = to_json ? '"' : '\'';
-
- if (is_null()) out << "null";
- else if (array_) {
- out << "[";
- print_indent(level + 1);
- for (size_t i = 0; i < array_->size(); ++i) {
- if (i) print_sub_sep();
- (*array_)[i].dump(out, indent, level + 1, to_json);
- }
- print_indent(level);
- out << "]";
- } else if (object_) {
- out << "{";
- print_indent(level + 1);
- for (auto begin = object_->begin(), it = begin; it != object_->end(); ++it) {
- if (it != begin) print_sub_sep();
- if (it->first.is_string()) {
- dump_string(it->first, out, string_quote);
- } else {
- out << string_quote << it->first.dump() << string_quote;
- }
- out << ": ";
- it->second.dump(out, indent, level + 1, to_json);
- }
- print_indent(level);
- out << "}";
- } else if (callable_) {
- throw std::runtime_error("Cannot dump callable to JSON");
- } else if (is_boolean() && !to_json) {
- out << (this->to_bool() ? "True" : "False");
- } else if (is_string() && !to_json) {
- dump_string(primitive_, out, string_quote);
- } else {
- out << primitive_.dump();
- }
- }
-
-public:
- Value() {}
- Value(const bool& v) : primitive_(v) {}
- Value(const int64_t & v) : primitive_(v) {}
- Value(const double& v) : primitive_(v) {}
- Value(const std::nullptr_t &) {}
- Value(const std::string & v) : primitive_(v) {}
- Value(const char * v) : primitive_(std::string(v)) {}
-
- Value(const json & v) {
- if (v.is_object()) {
- auto object = std::make_shared<ObjectType>();
- for (auto it = v.begin(); it != v.end(); ++it) {
- (*object)[it.key()] = it.value();
- }
- object_ = std::move(object);
- } else if (v.is_array()) {
- auto array = std::make_shared<ArrayType>();
- for (const auto& item : v) {
- array->push_back(Value(item));
- }
- array_ = array;
- } else {
- primitive_ = v;
- }
- }
-
- std::vector<Value> keys() {
- if (!object_) throw std::runtime_error("Value is not an object: " + dump());
- std::vector<Value> res;
- for (const auto& item : *object_) {
- res.push_back(item.first);
- }
- return res;
- }
-
- size_t size() const {
- if (is_object()) return object_->size();
- if (is_array()) return array_->size();
- if (is_string()) return primitive_.get<std::string>().length();
- throw std::runtime_error("Value is not an array or object: " + dump());
- }
-
- static Value array(const std::vector<Value> values = {}) {
- auto array = std::make_shared<ArrayType>();
- for (const auto& item : values) {
- array->push_back(item);
- }
- return Value(array);
- }
- static Value object(const std::shared_ptr<ObjectType> object = std::make_shared<ObjectType>()) {
- return Value(object);
- }
- static Value callable(const CallableType & callable) {
- return Value(std::make_shared<CallableType>(callable));
- }
-
- void insert(size_t index, const Value& v) {
- if (!array_)
- throw std::runtime_error("Value is not an array: " + dump());
- array_->insert(array_->begin() + index, v);
- }
- void push_back(const Value& v) {
- if (!array_)
- throw std::runtime_error("Value is not an array: " + dump());
- array_->push_back(v);
- }
- Value pop(const Value& index) {
- if (is_array()) {
- if (array_->empty())
- throw std::runtime_error("pop from empty list");
- if (index.is_null()) {
- auto ret = array_->back();
- array_->pop_back();
- return ret;
- } else if (!index.is_number_integer()) {
- throw std::runtime_error("pop index must be an integer: " + index.dump());
- } else {
- auto i = index.get<int>();
- if (i < 0 || i >= static_cast<int>(array_->size()))
- throw std::runtime_error("pop index out of range: " + index.dump());
- auto it = array_->begin() + (i < 0 ? array_->size() + i : i);
- auto ret = *it;
- array_->erase(it);
- return ret;
- }
- } else if (is_object()) {
- if (!index.is_hashable())
- throw std::runtime_error("Unashable type: " + index.dump());
- auto it = object_->find(index.primitive_);
- if (it == object_->end())
- throw std::runtime_error("Key not found: " + index.dump());
- auto ret = it->second;
- object_->erase(it);
- return ret;
- } else {
- throw std::runtime_error("Value is not an array or object: " + dump());
- }
- }
- Value get(const Value& key) {
- if (array_) {
- if (!key.is_number_integer()) {
- return Value();
- }
- auto index = key.get<int>();
- return array_->at(index < 0 ? array_->size() + index : index);
- } else if (object_) {
- if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
- auto it = object_->find(key.primitive_);
- if (it == object_->end()) return Value();
- return it->second;
- }
- return Value();
- }
- void set(const Value& key, const Value& value) {
- if (!object_) throw std::runtime_error("Value is not an object: " + dump());
- if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
- (*object_)[key.primitive_] = value;
- }
- Value call(const std::shared_ptr<Context> & context, ArgumentsValue & args) const {
- if (!callable_) throw std::runtime_error("Value is not callable: " + dump());
- return (*callable_)(context, args);
- }
-
- bool is_object() const { return !!object_; }
- bool is_array() const { return !!array_; }
- bool is_callable() const { return !!callable_; }
- bool is_null() const { return !object_ && !array_ && primitive_.is_null() && !callable_; }
- bool is_boolean() const { return primitive_.is_boolean(); }
- bool is_number_integer() const { return primitive_.is_number_integer(); }
- bool is_number_float() const { return primitive_.is_number_float(); }
- bool is_number() const { return primitive_.is_number(); }
- bool is_string() const { return primitive_.is_string(); }
- bool is_iterable() const { return is_array() || is_object() || is_string(); }
-
- bool is_primitive() const { return !array_ && !object_ && !callable_; }
- bool is_hashable() const { return is_primitive(); }
-
- bool empty() const {
- if (is_null())
- throw std::runtime_error("Undefined value or reference");
- if (is_string()) return primitive_.empty();
- if (is_array()) return array_->empty();
- if (is_object()) return object_->empty();
- return false;
- }
-
- void for_each(const std::function<void(Value &)> & callback) const {
- if (is_null())
- throw std::runtime_error("Undefined value or reference");
- if (array_) {
- for (auto& item : *array_) {
- callback(item);
- }
- } else if (object_) {
- for (auto & item : *object_) {
- Value key(item.first);
- callback(key);
- }
- } else if (is_string()) {
- for (char c : primitive_.get<std::string>()) {
- auto val = Value(std::string(1, c));
- callback(val);
- }
- } else {
- throw std::runtime_error("Value is not iterable: " + dump());
- }
- }
-
- bool to_bool() const {
- if (is_null()) return false;
- if (is_boolean()) return get<bool>();
- if (is_number()) return get<double>() != 0;
- if (is_string()) return !get<std::string>().empty();
- if (is_array()) return !empty();
- return true;
- }
-
- int64_t to_int() const {
- if (is_null()) return 0;
- if (is_boolean()) return get<bool>() ? 1 : 0;
- if (is_number()) return static_cast<int64_t>(get<double>());
- if (is_string()) {
- try {
- return std::stol(get<std::string>());
- } catch (const std::exception &) {
- return 0;
- }
- }
- return 0;
- }
-
- bool operator<(const Value & other) const {
- if (is_null())
- throw std::runtime_error("Undefined value or reference");
- if (is_number() && other.is_number()) return get<double>() < other.get<double>();
- if (is_string() && other.is_string()) return get<std::string>() < other.get<std::string>();
- throw std::runtime_error("Cannot compare values: " + dump() + " < " + other.dump());
- }
- bool operator>=(const Value & other) const { return !(*this < other); }
-
- bool operator>(const Value & other) const {
- if (is_null())
- throw std::runtime_error("Undefined value or reference");
- if (is_number() && other.is_number()) return get<double>() > other.get<double>();
- if (is_string() && other.is_string()) return get<std::string>() > other.get<std::string>();
- throw std::runtime_error("Cannot compare values: " + dump() + " > " + other.dump());
- }
- bool operator<=(const Value & other) const { return !(*this > other); }
-
- bool operator==(const Value & other) const {
- if (callable_ || other.callable_) {
- if (callable_.get() != other.callable_.get()) return false;
- }
- if (array_) {
- if (!other.array_) return false;
- if (array_->size() != other.array_->size()) return false;
- for (size_t i = 0; i < array_->size(); ++i) {
- if (!(*array_)[i].to_bool() || !(*other.array_)[i].to_bool() || (*array_)[i] != (*other.array_)[i]) return false;
- }
- return true;
- } else if (object_) {
- if (!other.object_) return false;
- if (object_->size() != other.object_->size()) return false;
- for (const auto& item : *object_) {
- if (!item.second.to_bool() || !other.object_->count(item.first) || item.second != other.object_->at(item.first)) return false;
- }
- return true;
- } else {
- return primitive_ == other.primitive_;
- }
- }
- bool operator!=(const Value & other) const { return !(*this == other); }
-
- bool contains(const char * key) const { return contains(std::string(key)); }
- bool contains(const std::string & key) const {
- if (array_) {
- return false;
- } else if (object_) {
- return object_->find(key) != object_->end();
- } else {
- throw std::runtime_error("contains can only be called on arrays and objects: " + dump());
- }
- }
- bool contains(const Value & value) const {
- if (is_null())
- throw std::runtime_error("Undefined value or reference");
- if (array_) {
- for (const auto& item : *array_) {
- if (item.to_bool() && item == value) return true;
- }
- return false;
- } else if (object_) {
- if (!value.is_hashable()) throw std::runtime_error("Unashable type: " + value.dump());
- return object_->find(value.primitive_) != object_->end();
- } else {
- throw std::runtime_error("contains can only be called on arrays and objects: " + dump());
- }
- }
- void erase(size_t index) {
- if (!array_) throw std::runtime_error("Value is not an array: " + dump());
- array_->erase(array_->begin() + index);
- }
- void erase(const std::string & key) {
- if (!object_) throw std::runtime_error("Value is not an object: " + dump());
- object_->erase(key);
- }
- const Value& at(const Value & index) const {
- return const_cast<Value*>(this)->at(index);
- }
- Value& at(const Value & index) {
- if (!index.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
- if (is_array()) return array_->at(index.get<int>());
- if (is_object()) return object_->at(index.primitive_);
- throw std::runtime_error("Value is not an array or object: " + dump());
- }
- const Value& at(size_t index) const {
- return const_cast<Value*>(this)->at(index);
- }
- Value& at(size_t index) {
- if (is_null())
- throw std::runtime_error("Undefined value or reference");
- if (is_array()) return array_->at(index);
- if (is_object()) return object_->at(index);
- throw std::runtime_error("Value is not an array or object: " + dump());
- }
-
- template <typename T>
- T get(const std::string & key, T default_value) const {
- if (!contains(key)) return default_value;
- return at(key).get<T>();
- }
-
- template <typename T>
- T get() const {
- if (is_primitive()) return primitive_.get<T>();
- throw std::runtime_error("get<T> not defined for this value type: " + dump());
- }
-
- std::string dump(int indent=-1, bool to_json=false) const {
- std::ostringstream out;
- dump(out, indent, 0, to_json);
- return out.str();
- }
-
- Value operator-() const {
- if (is_number_integer())
- return -get<int64_t>();
- else
- return -get<double>();
- }
- std::string to_str() const {
- if (is_string()) return get<std::string>();
- if (is_number_integer()) return std::to_string(get<int64_t>());
- if (is_number_float()) return std::to_string(get<double>());
- if (is_boolean()) return get<bool>() ? "True" : "False";
- if (is_null()) return "None";
- return dump();
- }
- Value operator+(const Value& rhs) const {
- if (is_string() || rhs.is_string()) {
- return to_str() + rhs.to_str();
- } else if (is_number_integer() && rhs.is_number_integer()) {
- return get<int64_t>() + rhs.get<int64_t>();
- } else if (is_array() && rhs.is_array()) {
- auto res = Value::array();
- for (const auto& item : *array_) res.push_back(item);
- for (const auto& item : *rhs.array_) res.push_back(item);
- return res;
- } else {
- return get<double>() + rhs.get<double>();
- }
- }
- Value operator-(const Value& rhs) const {
- if (is_number_integer() && rhs.is_number_integer())
- return get<int64_t>() - rhs.get<int64_t>();
- else
- return get<double>() - rhs.get<double>();
- }
- Value operator*(const Value& rhs) const {
- if (is_string() && rhs.is_number_integer()) {
- std::ostringstream out;
- for (int64_t i = 0, n = rhs.get<int64_t>(); i < n; ++i) {
- out << to_str();
- }
- return out.str();
- }
- else if (is_number_integer() && rhs.is_number_integer())
- return get<int64_t>() * rhs.get<int64_t>();
- else
- return get<double>() * rhs.get<double>();
- }
- Value operator/(const Value& rhs) const {
- if (is_number_integer() && rhs.is_number_integer())
- return get<int64_t>() / rhs.get<int64_t>();
- else
- return get<double>() / rhs.get<double>();
- }
- Value operator%(const Value& rhs) const {
- return get<int64_t>() % rhs.get<int64_t>();
- }
-};
-
-struct ArgumentsValue {
- std::vector<Value> args;
- std::vector<std::pair<std::string, Value>> kwargs;
-
- bool has_named(const std::string & name) {
- for (const auto & p : kwargs) {
- if (p.first == name) return true;
- }
- return false;
- }
-
- Value get_named(const std::string & name) {
- for (const auto & [key, value] : kwargs) {
- if (key == name) return value;
- }
- return Value();
- }
-
- bool empty() {
- return args.empty() && kwargs.empty();
- }
-
- void expectArgs(const std::string & method_name, const std::pair<size_t, size_t> & pos_count, const std::pair<size_t, size_t> & kw_count) {
- if (args.size() < pos_count.first || args.size() > pos_count.second || kwargs.size() < kw_count.first || kwargs.size() > kw_count.second) {
- std::ostringstream out;
- out << method_name << " must have between " << pos_count.first << " and " << pos_count.second << " positional arguments and between " << kw_count.first << " and " << kw_count.second << " keyword arguments";
- throw std::runtime_error(out.str());
- }
- }
-};
-
-template <>
-inline json Value::get<json>() const {
- if (is_primitive()) return primitive_;
- if (is_null()) return json();
- if (array_) {
- std::vector<json> res;
- for (const auto& item : *array_) {
- res.push_back(item.get<json>());
- }
- return res;
- }
- if (object_) {
- json res = json::object();
- for (const auto& [key, value] : *object_) {
- if (key.is_string()) {
- res[key.get<std::string>()] = value.get<json>();
- } else if (key.is_primitive()) {
- res[key.dump()] = value.get<json>();
- } else {
- throw std::runtime_error("Invalid key type for conversion to JSON: " + key.dump());
- }
- }
- if (is_callable()) {
- res["__callable__"] = true;
- }
- return res;
- }
- throw std::runtime_error("get<json> not defined for this value type: " + dump());
-}
-
-} // namespace minja
-
-namespace std {
- template <>
- struct hash<minja::Value> {
- size_t operator()(const minja::Value & v) const {
- if (!v.is_hashable())
- throw std::runtime_error("Unsupported type for hashing: " + v.dump());
- return std::hash<json>()(v.get<json>());
- }
- };
-} // namespace std
-
-namespace minja {
-
-static std::string error_location_suffix(const std::string & source, size_t pos) {
- auto get_line = [&](size_t line) {
- auto start = source.begin();
- for (size_t i = 1; i < line; ++i) {
- start = std::find(start, source.end(), '\n') + 1;
- }
- auto end = std::find(start, source.end(), '\n');
- return std::string(start, end);
- };
- auto start = source.begin();
- auto end = source.end();
- auto it = start + pos;
- auto line = std::count(start, it, '\n') + 1;
- auto max_line = std::count(start, end, '\n') + 1;
- auto col = pos - std::string(start, it).rfind('\n');
- std::ostringstream out;
- out << " at row " << line << ", column " << col << ":\n";
- if (line > 1) out << get_line(line - 1) << "\n";
- out << get_line(line) << "\n";
- out << std::string(col - 1, ' ') << "^\n";
- if (line < max_line) out << get_line(line + 1) << "\n";
-
- return out.str();
-}
-
-class Context : public std::enable_shared_from_this<Context> {
- protected:
- Value values_;
- std::shared_ptr<Context> parent_;
- public:
- Context(Value && values, const std::shared_ptr<Context> & parent = nullptr) : values_(std::move(values)), parent_(parent) {
- if (!values_.is_object()) throw std::runtime_error("Context values must be an object: " + values_.dump());
- }
- virtual ~Context() {}
-
- static std::shared_ptr<Context> builtins();
- static std::shared_ptr<Context> make(Value && values, const std::shared_ptr<Context> & parent = builtins());
-
- std::vector<Value> keys() {
- return values_.keys();
- }
- virtual Value get(const Value & key) {
- if (values_.contains(key)) return values_.at(key);
- if (parent_) return parent_->get(key);
- return Value();
- }
- virtual Value & at(const Value & key) {
- if (values_.contains(key)) return values_.at(key);
- if (parent_) return parent_->at(key);
- throw std::runtime_error("Undefined variable: " + key.dump());
- }
- virtual bool contains(const Value & key) {
- if (values_.contains(key)) return true;
- if (parent_) return parent_->contains(key);
- return false;
- }
- virtual void set(const Value & key, const Value & value) {
- values_.set(key, value);
- }
-};
-
-struct Location {
- std::shared_ptr<std::string> source;
- size_t pos;
-};
-
-class Expression {
-protected:
- virtual Value do_evaluate(const std::shared_ptr<Context> & context) const = 0;
-public:
- using Parameters = std::vector<std::pair<std::string, std::shared_ptr<Expression>>>;
-
- Location location;
-
- Expression(const Location & location) : location(location) {}
- virtual ~Expression() = default;
-
- Value evaluate(const std::shared_ptr<Context> & context) const {
- try {
- return do_evaluate(context);
- } catch (const std::exception & e) {
- std::ostringstream out;
- out << e.what();
- if (location.source) out << error_location_suffix(*location.source, location.pos);
- throw std::runtime_error(out.str());
- }
- }
-};
-
-class VariableExpr : public Expression {
- std::string name;
-public:
- VariableExpr(const Location & location, const std::string& n)
- : Expression(location), name(n) {}
- std::string get_name() const { return name; }
- Value do_evaluate(const std::shared_ptr<Context> & context) const override {
- if (!context->contains(name)) {
- return Value();
- }
- return context->at(name);
- }
-};
-
-static void destructuring_assign(const std::vector<std::string> & var_names, const std::shared_ptr<Context> & context, Value& item) {
- if (var_names.size() == 1) {
- Value name(var_names[0]);
- context->set(name, item);
- } else {
- if (!item.is_array() || item.size() != var_names.size()) {
- throw std::runtime_error("Mismatched number of variables and items in destructuring assignment");
- }
- for (size_t i = 0; i < var_names.size(); ++i) {
- context->set(var_names[i], item.at(i));
- }
- }
-}
-
-enum SpaceHandling { Keep, Strip, StripSpaces, StripNewline };
-
-class TemplateToken {
-public:
- enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter, Break, Continue };
-
- static std::string typeToString(Type t) {
- switch (t) {
- case Type::Text: return "text";
- case Type::Expression: return "expression";
- case Type::If: return "if";
- case Type::Else: return "else";
- case Type::Elif: return "elif";
- case Type::EndIf: return "endif";
- case Type::For: return "for";
- case Type::EndFor: return "endfor";
- case Type::Set: return "set";
- case Type::EndSet: return "endset";
- case Type::Comment: return "comment";
- case Type::Macro: return "macro";
- case Type::EndMacro: return "endmacro";
- case Type::Filter: return "filter";
- case Type::EndFilter: return "endfilter";
- case Type::Generation: return "generation";
- case Type::EndGeneration: return "endgeneration";
- case Type::Break: return "break";
- case Type::Continue: return "continue";
- }
- return "Unknown";
- }
-
- TemplateToken(Type type, const Location & location, SpaceHandling pre, SpaceHandling post) : type(type), location(location), pre_space(pre), post_space(post) {}
- virtual ~TemplateToken() = default;
-
- Type type;
- Location location;
- SpaceHandling pre_space = SpaceHandling::Keep;
- SpaceHandling post_space = SpaceHandling::Keep;
-};
-
-struct TextTemplateToken : public TemplateToken {
- std::string text;
- TextTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Text, location, pre, post), text(t) {}
-};
-
-struct ExpressionTemplateToken : public TemplateToken {
- std::shared_ptr<Expression> expr;
- ExpressionTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && e) : TemplateToken(Type::Expression, location, pre, post), expr(std::move(e)) {}
-};
-
-struct IfTemplateToken : public TemplateToken {
- std::shared_ptr<Expression> condition;
- IfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && c) : TemplateToken(Type::If, location, pre, post), condition(std::move(c)) {}
-};
-
-struct ElifTemplateToken : public TemplateToken {
- std::shared_ptr<Expression> condition;
- ElifTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && c) : TemplateToken(Type::Elif, location, pre, post), condition(std::move(c)) {}
-};
-
-struct ElseTemplateToken : public TemplateToken {
- ElseTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::Else, location, pre, post) {}
-};
-
-struct EndIfTemplateToken : public TemplateToken {
- EndIfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndIf, location, pre, post) {}
-};
-
-struct MacroTemplateToken : public TemplateToken {
- std::shared_ptr<VariableExpr> name;
- Expression::Parameters params;
- MacroTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<VariableExpr> && n, Expression::Parameters && p)
- : TemplateToken(Type::Macro, location, pre, post), name(std::move(n)), params(std::move(p)) {}
-};
-
-struct EndMacroTemplateToken : public TemplateToken {
- EndMacroTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndMacro, location, pre, post) {}
-};
-
-struct FilterTemplateToken : public TemplateToken {
- std::shared_ptr<Expression> filter;
- FilterTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && filter)
- : TemplateToken(Type::Filter, location, pre, post), filter(std::move(filter)) {}
-};
-
-struct EndFilterTemplateToken : public TemplateToken {
- EndFilterTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndFilter, location, pre, post) {}
-};
-
-struct ForTemplateToken : public TemplateToken {
- std::vector<std::string> var_names;
- std::shared_ptr<Expression> iterable;
- std::shared_ptr<Expression> condition;
- bool recursive;
- ForTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::vector<std::string> & vns, std::shared_ptr<Expression> && iter,
- std::shared_ptr<Expression> && c, bool r)
- : TemplateToken(Type::For, location, pre, post), var_names(vns), iterable(std::move(iter)), condition(std::move(c)), recursive(r) {}
-};
-
-struct EndForTemplateToken : public TemplateToken {
- EndForTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndFor, location, pre, post) {}
-};
-
-struct GenerationTemplateToken : public TemplateToken {
- GenerationTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::Generation, location, pre, post) {}
-};
-
-struct EndGenerationTemplateToken : public TemplateToken {
- EndGenerationTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndGeneration, location, pre, post) {}
-};
-
-struct SetTemplateToken : public TemplateToken {
- std::string ns;
- std::vector<std::string> var_names;
- std::shared_ptr<Expression> value;
- SetTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string & ns, const std::vector<std::string> & vns, std::shared_ptr<Expression> && v)
- : TemplateToken(Type::Set, location, pre, post), ns(ns), var_names(vns), value(std::move(v)) {}
-};
-
-struct EndSetTemplateToken : public TemplateToken {
- EndSetTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndSet, location, pre, post) {}
-};
-
-struct CommentTemplateToken : public TemplateToken {
- std::string text;
- CommentTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Comment, location, pre, post), text(t) {}
-};
-
-enum class LoopControlType { Break, Continue };
-
-class LoopControlException : public std::runtime_error {
-public:
- LoopControlType control_type;
- LoopControlException(const std::string & message, LoopControlType control_type) : std::runtime_error(message), control_type(control_type) {}
- LoopControlException(LoopControlType control_type)
- : std::runtime_error((control_type == LoopControlType::Continue ? "continue" : "break") + std::string(" outside of a loop")),
- control_type(control_type) {}
-};
-
-struct LoopControlTemplateToken : public TemplateToken {
- LoopControlType control_type;
- LoopControlTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, LoopControlType control_type) : TemplateToken(Type::Break, location, pre, post), control_type(control_type) {}
-};
-
-class TemplateNode {
- Location location_;
-protected:
- virtual void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const = 0;
-
-public:
- TemplateNode(const Location & location) : location_(location) {}
- void render(std::ostringstream & out, const std::shared_ptr<Context> & context) const {
- try {
- do_render(out, context);
- } catch (const LoopControlException & e) {
- // TODO: make stack creation lazy. Only needed if it was thrown outside of a loop.
- std::ostringstream err;
- err << e.what();
- if (location_.source) err << error_location_suffix(*location_.source, location_.pos);
- throw LoopControlException(err.str(), e.control_type);
- } catch (const std::exception & e) {
- std::ostringstream err;
- err << e.what();
- if (location_.source) err << error_location_suffix(*location_.source, location_.pos);
- throw std::runtime_error(err.str());
- }
- }
- const Location & location() const { return location_; }
- virtual ~TemplateNode() = default;
- std::string render(const std::shared_ptr<Context> & context) const {
- std::ostringstream out;
- render(out, context);
- return out.str();
- }
-};
-
-class SequenceNode : public TemplateNode {
- std::vector<std::shared_ptr<TemplateNode>> children;
-public:
- SequenceNode(const Location & location, std::vector<std::shared_ptr<TemplateNode>> && c)
- : TemplateNode(location), children(std::move(c)) {}
- void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
- for (const auto& child : children) child->render(out, context);
- }
-};
-
-class TextNode : public TemplateNode {
- std::string text;
-public:
- TextNode(const Location & location, const std::string& t) : TemplateNode(location), text(t) {}
- void do_render(std::ostringstream & out, const std::shared_ptr<Context> &) const override {
- out << text;
- }
-};
-
-class ExpressionNode : public TemplateNode {
- std::shared_ptr<Expression> expr;
-public:
- ExpressionNode(const Location & location, std::shared_ptr<Expression> && e) : TemplateNode(location), expr(std::move(e)) {}
- void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
- if (!expr) throw std::runtime_error("ExpressionNode.expr is null");
- auto result = expr->evaluate(context);
- if (result.is_string()) {
- out << result.get<std::string>();
- } else if (result.is_boolean()) {
- out << (result.get<bool>() ? "True" : "False");
- } else if (!result.is_null()) {
- out << result.dump();
- }
- }
-};
-
-class IfNode : public TemplateNode {
- std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<TemplateNode>>> cascade;
-public:
- IfNode(const Location & location, std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<TemplateNode>>> && c)
- : TemplateNode(location), cascade(std::move(c)) {}
- void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
- for (const auto& branch : cascade) {
- auto enter_branch = true;
- if (branch.first) {
- enter_branch = branch.first->evaluate(context).to_bool();
- }
- if (enter_branch) {
- if (!branch.second) throw std::runtime_error("IfNode.cascade.second is null");
- branch.second->render(out, context);
- return;
- }
- }
- }
-};
-
-class LoopControlNode : public TemplateNode {
- LoopControlType control_type_;
- public:
- LoopControlNode(const Location & location, LoopControlType control_type) : TemplateNode(location), control_type_(control_type) {}
- void do_render(std::ostringstream &, const std::shared_ptr<Context> &) const override {
- throw LoopControlException(control_type_);
- }
-};
-
-class ForNode : public TemplateNode {
- std::vector<std::string> var_names;
- std::shared_ptr<Expression> iterable;
- std::shared_ptr<Expression> condition;
- std::shared_ptr<TemplateNode> body;
- bool recursive;
- std::shared_ptr<TemplateNode> else_body;
-public:
- ForNode(const Location & location, std::vector<std::string> && var_names, std::shared_ptr<Expression> && iterable,
- std::shared_ptr<Expression> && condition, std::shared_ptr<TemplateNode> && body, bool recursive, std::shared_ptr<TemplateNode> && else_body)
- : TemplateNode(location), var_names(var_names), iterable(std::move(iterable)), condition(std::move(condition)), body(std::move(body)), recursive(recursive), else_body(std::move(else_body)) {}
-
- void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
- // https://jinja.palletsprojects.com/en/3.0.x/templates/#for
- if (!iterable) throw std::runtime_error("ForNode.iterable is null");
- if (!body) throw std::runtime_error("ForNode.body is null");
-
- auto iterable_value = iterable->evaluate(context);
- Value::CallableType loop_function;
-
- std::function<void(Value&)> visit = [&](Value& iter) {
- auto filtered_items = Value::array();
- if (!iter.is_null()) {
- if (!iterable_value.is_iterable()) {
- throw std::runtime_error("For loop iterable must be iterable: " + iterable_value.dump());
- }
- iterable_value.for_each([&](Value & item) {
- destructuring_assign(var_names, context, item);
- if (!condition || condition->evaluate(context).to_bool()) {
- filtered_items.push_back(item);
- }
- });
- }
- if (filtered_items.empty()) {
- if (else_body) {
- else_body->render(out, context);
- }
- } else {
- auto loop = recursive ? Value::callable(loop_function) : Value::object();
- loop.set("length", (int64_t) filtered_items.size());
-
- size_t cycle_index = 0;
- loop.set("cycle", Value::callable([&](const std::shared_ptr<Context> &, ArgumentsValue & args) {
- if (args.args.empty() || !args.kwargs.empty()) {
- throw std::runtime_error("cycle() expects at least 1 positional argument and no named arg");
- }
- auto item = args.args[cycle_index];
- cycle_index = (cycle_index + 1) % args.args.size();
- return item;
- }));
- auto loop_context = Context::make(Value::object(), context);
- loop_context->set("loop", loop);
- for (size_t i = 0, n = filtered_items.size(); i < n; ++i) {
- auto & item = filtered_items.at(i);
- destructuring_assign(var_names, loop_context, item);
- loop.set("index", (int64_t) i + 1);
- loop.set("index0", (int64_t) i);
- loop.set("revindex", (int64_t) (n - i));
- loop.set("revindex0", (int64_t) (n - i - 1));
- loop.set("length", (int64_t) n);
- loop.set("first", i == 0);
- loop.set("last", i == (n - 1));
- loop.set("previtem", i > 0 ? filtered_items.at(i - 1) : Value());
- loop.set("nextitem", i < n - 1 ? filtered_items.at(i + 1) : Value());
- try {
- body->render(out, loop_context);
- } catch (const LoopControlException & e) {
- if (e.control_type == LoopControlType::Break) break;
- if (e.control_type == LoopControlType::Continue) continue;
- }
- }
- }
- };
-
- if (recursive) {
- loop_function = [&](const std::shared_ptr<Context> &, ArgumentsValue & args) {
- if (args.args.size() != 1 || !args.kwargs.empty() || !args.args[0].is_array()) {
- throw std::runtime_error("loop() expects exactly 1 positional iterable argument");
- }
- auto & items = args.args[0];
- visit(items);
- return Value();
- };
- }
-
- visit(iterable_value);
- }
-};
-
-class MacroNode : public TemplateNode {
- std::shared_ptr<VariableExpr> name;
- Expression::Parameters params;
- std::shared_ptr<TemplateNode> body;
- std::unordered_map<std::string, size_t> named_param_positions;
-public:
- MacroNode(const Location & location, std::shared_ptr<VariableExpr> && n, Expression::Parameters && p, std::shared_ptr<TemplateNode> && b)
- : TemplateNode(location), name(std::move(n)), params(std::move(p)), body(std::move(b)) {
- for (size_t i = 0; i < params.size(); ++i) {
- const auto & name = params[i].first;
- if (!name.empty()) {
- named_param_positions[name] = i;
- }
- }
- }
- void do_render(std::ostringstream &, const std::shared_ptr<Context> & macro_context) const override {
- if (!name) throw std::runtime_error("MacroNode.name is null");
- if (!body) throw std::runtime_error("MacroNode.body is null");
- auto callable = Value::callable([&](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
- auto call_context = macro_context;
- std::vector<bool> param_set(params.size(), false);
- for (size_t i = 0, n = args.args.size(); i < n; i++) {
- auto & arg = args.args[i];
- if (i >= params.size()) throw std::runtime_error("Too many positional arguments for macro " + name->get_name());
- param_set[i] = true;
- auto & param_name = params[i].first;
- call_context->set(param_name, arg);
- }
- for (auto & [arg_name, value] : args.kwargs) {
- auto it = named_param_positions.find(arg_name);
- if (it == named_param_positions.end()) throw std::runtime_error("Unknown parameter name for macro " + name->get_name() + ": " + arg_name);
-
- call_context->set(arg_name, value);
- param_set[it->second] = true;
- }
- // Set default values for parameters that were not passed
- for (size_t i = 0, n = params.size(); i < n; i++) {
- if (!param_set[i] && params[i].second != nullptr) {
- auto val = params[i].second->evaluate(context);
- call_context->set(params[i].first, val);
- }
- }
- return body->render(call_context);
- });
- macro_context->set(name->get_name(), callable);
- }
-};
-
-class FilterNode : public TemplateNode {
- std::shared_ptr<Expression> filter;
- std::shared_ptr<TemplateNode> body;
-
-public:
- FilterNode(const Location & location, std::shared_ptr<Expression> && f, std::shared_ptr<TemplateNode> && b)
- : TemplateNode(location), filter(std::move(f)), body(std::move(b)) {}
-
- void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
- if (!filter) throw std::runtime_error("FilterNode.filter is null");
- if (!body) throw std::runtime_error("FilterNode.body is null");
- auto filter_value = filter->evaluate(context);
- if (!filter_value.is_callable()) {
- throw std::runtime_error("Filter must be a callable: " + filter_value.dump());
- }
- std::string rendered_body = body->render(context);
-
- ArgumentsValue filter_args = {{Value(rendered_body)}, {}};
- auto result = filter_value.call(context, filter_args);
- out << result.to_str();
- }
-};
-
-class SetNode : public TemplateNode {
- std::string ns;
- std::vector<std::string> var_names;
- std::shared_ptr<Expression> value;
-public:
- SetNode(const Location & location, const std::string & ns, const std::vector<std::string> & vns, std::shared_ptr<Expression> && v)
- : TemplateNode(location), ns(ns), var_names(vns), value(std::move(v)) {}
- void do_render(std::ostringstream &, const std::shared_ptr<Context> & context) const override {
- if (!value) throw std::runtime_error("SetNode.value is null");
- if (!ns.empty()) {
- if (var_names.size() != 1) {
- throw std::runtime_error("Namespaced set only supports a single variable name");
- }
- auto & name = var_names[0];
- auto ns_value = context->get(ns);
- if (!ns_value.is_object()) throw std::runtime_error("Namespace '" + ns + "' is not an object");
- ns_value.set(name, this->value->evaluate(context));
- } else {
- auto val = value->evaluate(context);
- destructuring_assign(var_names, context, val);
- }
- }
-};
-
-class SetTemplateNode : public TemplateNode {
- std::string name;
- std::shared_ptr<TemplateNode> template_value;
-public:
- SetTemplateNode(const Location & location, const std::string & name, std::shared_ptr<TemplateNode> && tv)
- : TemplateNode(location), name(name), template_value(std::move(tv)) {}
- void do_render(std::ostringstream &, const std::shared_ptr<Context> & context) const override {
- if (!template_value) throw std::runtime_error("SetTemplateNode.template_value is null");
- Value value { template_value->render(context) };
- context->set(name, value);
- }
-};
-
-class IfExpr : public Expression {
- std::shared_ptr<Expression> condition;
- std::shared_ptr<Expression> then_expr;
- std::shared_ptr<Expression> else_expr;
-public:
- IfExpr(const Location & location, std::shared_ptr<Expression> && c, std::shared_ptr<Expression> && t, std::shared_ptr<Expression> && e)
- : Expression(location), condition(std::move(c)), then_expr(std::move(t)), else_expr(std::move(e)) {}
- Value do_evaluate(const std::shared_ptr<Context> & context) const override {
- if (!condition) throw std::runtime_error("IfExpr.condition is null");
- if (!then_expr) throw std::runtime_error("IfExpr.then_expr is null");
- if (condition->evaluate(context).to_bool()) {
- return then_expr->evaluate(context);
- }
- if (else_expr) {
- return else_expr->evaluate(context);
- }
- return nullptr;
- }
-};
-
-class LiteralExpr : public Expression {
- Value value;
-public:
- LiteralExpr(const Location & location, const Value& v)
- : Expression(location), value(v) {}
- Value do_evaluate(const std::shared_ptr<Context> &) const override { return value; }
-};
-
-class ArrayExpr : public Expression {
- std::vector<std::shared_ptr<Expression>> elements;
-public:
- ArrayExpr(const Location & location, std::vector<std::shared_ptr<Expression>> && e)
- : Expression(location), elements(std::move(e)) {}
- Value do_evaluate(const std::shared_ptr<Context> & context) const override {
- auto result = Value::array();
- for (const auto& e : elements) {
- if (!e) throw std::runtime_error("Array element is null");
- result.push_back(e->evaluate(context));
- }
- return result;
- }
-};
-
-class DictExpr : public Expression {
- std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>>> elements;
-public:
- DictExpr(const Location & location, std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>>> && e)
- : Expression(location), elements(std::move(e)) {}
- Value do_evaluate(const std::shared_ptr<Context> & context) const override {
- auto result = Value::object();
- for (const auto& [key, value] : elements) {
- if (!key) throw std::runtime_error("Dict key is null");
- if (!value) throw std::runtime_error("Dict value is null");
- result.set(key->evaluate(context), value->evaluate(context));
- }
- return result;
- }
-};
-
-class SliceExpr : public Expression {
-public:
- std::shared_ptr<Expression> start, end;
- SliceExpr(const Location & location, std::shared_ptr<Expression> && s, std::shared_ptr<Expression> && e)
- : Expression(location), start(std::move(s)), end(std::move(e)) {}
- Value do_evaluate(const std::shared_ptr<Context> &) const override {
- throw std::runtime_error("SliceExpr not implemented");
- }
-};
-
-class SubscriptExpr : public Expression {
- std::shared_ptr<Expression> base;
- std::shared_ptr<Expression> index;
-public:
- SubscriptExpr(const Location & location, std::shared_ptr<Expression> && b, std::shared_ptr<Expression> && i)
- : Expression(location), base(std::move(b)), index(std::move(i)) {}
- Value do_evaluate(const std::shared_ptr<Context> & context) const override {
- if (!base) throw std::runtime_error("SubscriptExpr.base is null");
- if (!index) throw std::runtime_error("SubscriptExpr.index is null");
- auto target_value = base->evaluate(context);
- if (auto slice = dynamic_cast<SliceExpr*>(index.get())) {
- auto start = slice->start ? slice->start->evaluate(context).get<int64_t>() : 0;
- auto end = slice->end ? slice->end->evaluate(context).get<int64_t>() : (int64_t) target_value.size();
- if (target_value.is_string()) {
- std::string s = target_value.get<std::string>();
- if (start < 0) start = s.size() + start;
- if (end < 0) end = s.size() + end;
- return s.substr(start, end - start);
- } else if (target_value.is_array()) {
- if (start < 0) start = target_value.size() + start;
- if (end < 0) end = target_value.size() + end;
- auto result = Value::array();
- for (auto i = start; i < end; ++i) {
- result.push_back(target_value.at(i));
- }
- return result;
- } else {
- throw std::runtime_error(target_value.is_null() ? "Cannot subscript null" : "Subscripting only supported on arrays and strings");
- }
- } else {
- auto index_value = index->evaluate(context);
- if (target_value.is_null()) {
- if (auto t = dynamic_cast<VariableExpr*>(base.get())) {
- throw std::runtime_error("'" + t->get_name() + "' is " + (context->contains(t->get_name()) ? "null" : "not defined"));
- }
- throw std::runtime_error("Trying to access property '" + index_value.dump() + "' on null!");
- }
- return target_value.get(index_value);
- }
- }
-};
-
-class UnaryOpExpr : public Expression {
-public:
- enum class Op { Plus, Minus, LogicalNot, Expansion, ExpansionDict };
- std::shared_ptr<Expression> expr;
- Op op;
- UnaryOpExpr(const Location & location, std::shared_ptr<Expression> && e, Op o)
- : Expression(location), expr(std::move(e)), op(o) {}
- Value do_evaluate(const std::shared_ptr<Context> & context) const override {
- if (!expr) throw std::runtime_error("UnaryOpExpr.expr is null");
- auto e = expr->evaluate(context);
- switch (op) {
- case Op::Plus: return e;
- case Op::Minus: return -e;
- case Op::LogicalNot: return !e.to_bool();
- case Op::Expansion:
- case Op::ExpansionDict:
- throw std::runtime_error("Expansion operator is only supported in function calls and collections");
-
- }
- throw std::runtime_error("Unknown unary operator");
- }
-};
-
-class BinaryOpExpr : public Expression {
-public:
- enum class Op { StrConcat, Add, Sub, Mul, MulMul, Div, DivDiv, Mod, Eq, Ne, Lt, Gt, Le, Ge, And, Or, In, NotIn, Is, IsNot };
-private:
- std::shared_ptr<Expression> left;
- std::shared_ptr<Expression> right;
- Op op;
-public:
- BinaryOpExpr(const Location & location, std::shared_ptr<Expression> && l, std::shared_ptr<Expression> && r, Op o)
- : Expression(location), left(std::move(l)), right(std::move(r)), op(o) {}
- Value do_evaluate(const std::shared_ptr<Context> & context) const override {
- if (!left) throw std::runtime_error("BinaryOpExpr.left is null");
- if (!right) throw std::runtime_error("BinaryOpExpr.right is null");
- auto l = left->evaluate(context);
-
- auto do_eval = [&](const Value & l) -> Value {
- if (op == Op::Is || op == Op::IsNot) {
- auto t = dynamic_cast<VariableExpr*>(right.get());
- if (!t) throw std::runtime_error("Right side of 'is' operator must be a variable");
-
- auto eval = [&]() {
- const auto & name = t->get_name();
- if (name == "none") return l.is_null();
- if (name == "boolean") return l.is_boolean();
- if (name == "integer") return l.is_number_integer();
- if (name == "float") return l.is_number_float();
- if (name == "number") return l.is_number();
- if (name == "string") return l.is_string();
- if (name == "mapping") return l.is_object();
- if (name == "iterable") return l.is_iterable();
- if (name == "sequence") return l.is_array();
- if (name == "defined") return !l.is_null();
- throw std::runtime_error("Unknown type for 'is' operator: " + name);
- };
- auto value = eval();
- return Value(op == Op::Is ? value : !value);
- }
-
- if (op == Op::And) {
- if (!l.to_bool()) return Value(false);
- return right->evaluate(context).to_bool();
- } else if (op == Op::Or) {
- if (l.to_bool()) return l;
- return right->evaluate(context);
- }
-
- auto r = right->evaluate(context);
- switch (op) {
- case Op::StrConcat: return l.to_str() + r.to_str();
- case Op::Add: return l + r;
- case Op::Sub: return l - r;
- case Op::Mul: return l * r;
- case Op::Div: return l / r;
- case Op::MulMul: return std::pow(l.get<double>(), r.get<double>());
- case Op::DivDiv: return l.get<int64_t>() / r.get<int64_t>();
- case Op::Mod: return l.get<int64_t>() % r.get<int64_t>();
- case Op::Eq: return l == r;
- case Op::Ne: return l != r;
- case Op::Lt: return l < r;
- case Op::Gt: return l > r;
- case Op::Le: return l <= r;
- case Op::Ge: return l >= r;
- case Op::In: return (r.is_array() || r.is_object()) && r.contains(l);
- case Op::NotIn: return !(r.is_array() && r.contains(l));
- default: break;
- }
- throw std::runtime_error("Unknown binary operator");
- };
-
- if (l.is_callable()) {
- return Value::callable([l, do_eval](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
- auto ll = l.call(context, args);
- return do_eval(ll); //args[0].second);
- });
- } else {
- return do_eval(l);
- }
- }
-};
-
-struct ArgumentsExpression {
- std::vector<std::shared_ptr<Expression>> args;
- std::vector<std::pair<std::string, std::shared_ptr<Expression>>> kwargs;
-
- ArgumentsValue evaluate(const std::shared_ptr<Context> & context) const {
- ArgumentsValue vargs;
- for (const auto& arg : this->args) {
- if (auto un_expr = std::dynamic_pointer_cast<UnaryOpExpr>(arg)) {
- if (un_expr->op == UnaryOpExpr::Op::Expansion) {
- auto array = un_expr->expr->evaluate(context);
- if (!array.is_array()) {
- throw std::runtime_error("Expansion operator only supported on arrays");
- }
- array.for_each([&](Value & value) {
- vargs.args.push_back(value);
- });
- continue;
- } else if (un_expr->op == UnaryOpExpr::Op::ExpansionDict) {
- auto dict = un_expr->expr->evaluate(context);
- if (!dict.is_object()) {
- throw std::runtime_error("ExpansionDict operator only supported on objects");
- }
- dict.for_each([&](const Value & key) {
- vargs.kwargs.push_back({key.get<std::string>(), dict.at(key)});
- });
- continue;
- }
- }
- vargs.args.push_back(arg->evaluate(context));
- }
- for (const auto& [name, value] : this->kwargs) {
- vargs.kwargs.push_back({name, value->evaluate(context)});
- }
- return vargs;
- }
-};
-
-static std::string strip(const std::string & s) {
- auto start = s.find_first_not_of(" \t\n\r");
- if (start == std::string::npos) return "";
- auto end = s.find_last_not_of(" \t\n\r");
- return s.substr(start, end - start + 1);
-}
-
-static std::string capitalize(const std::string & s) {
- if (s.empty()) return s;
- auto result = s;
- result[0] = std::toupper(result[0]);
- return result;
-}
-
-static std::string html_escape(const std::string & s) {
- std::string result;
- result.reserve(s.size());
- for (const auto & c : s) {
- switch (c) {
- case '&': result += "&"; break;
- case '<': result += "<"; break;
- case '>': result += ">"; break;
- case '"': result += """; break;
- case '\'': result += "'"; break;
- default: result += c; break;
- }
- }
- return result;
-}
-
-class MethodCallExpr : public Expression {
- std::shared_ptr<Expression> object;
- std::shared_ptr<VariableExpr> method;
- ArgumentsExpression args;
-public:
- MethodCallExpr(const Location & location, std::shared_ptr<Expression> && obj, std::shared_ptr<VariableExpr> && m, ArgumentsExpression && a)
- : Expression(location), object(std::move(obj)), method(std::move(m)), args(std::move(a)) {}
- Value do_evaluate(const std::shared_ptr<Context> & context) const override {
- if (!object) throw std::runtime_error("MethodCallExpr.object is null");
- if (!method) throw std::runtime_error("MethodCallExpr.method is null");
- auto obj = object->evaluate(context);
- auto vargs = args.evaluate(context);
- if (obj.is_null()) {
- throw std::runtime_error("Trying to call method '" + method->get_name() + "' on null");
- }
- if (obj.is_array()) {
- if (method->get_name() == "append") {
- vargs.expectArgs("append method", {1, 1}, {0, 0});
- obj.push_back(vargs.args[0]);
- return Value();
- } else if (method->get_name() == "pop") {
- vargs.expectArgs("pop method", {0, 1}, {0, 0});
- return obj.pop(vargs.args.empty() ? Value() : vargs.args[0]);
- } else if (method->get_name() == "insert") {
- vargs.expectArgs("insert method", {2, 2}, {0, 0});
- auto index = vargs.args[0].get<int64_t>();
- if (index < 0 || index > (int64_t) obj.size()) throw std::runtime_error("Index out of range for insert method");
- obj.insert(index, vargs.args[1]);
- return Value();
- }
- } else if (obj.is_object()) {
- if (method->get_name() == "items") {
- vargs.expectArgs("items method", {0, 0}, {0, 0});
- auto result = Value::array();
- for (const auto& key : obj.keys()) {
- result.push_back(Value::array({key, obj.at(key)}));
- }
- return result;
- } else if (method->get_name() == "pop") {
- vargs.expectArgs("pop method", {1, 1}, {0, 0});
- return obj.pop(vargs.args[0]);
- } else if (method->get_name() == "get") {
- vargs.expectArgs("get method", {1, 2}, {0, 0});
- auto key = vargs.args[0];
- if (vargs.args.size() == 1) {
- return obj.contains(key) ? obj.at(key) : Value();
- } else {
- return obj.contains(key) ? obj.at(key) : vargs.args[1];
- }
- } else if (obj.contains(method->get_name())) {
- auto callable = obj.at(method->get_name());
- if (!callable.is_callable()) {
- throw std::runtime_error("Property '" + method->get_name() + "' is not callable");
- }
- return callable.call(context, vargs);
- }
- } else if (obj.is_string()) {
- auto str = obj.get<std::string>();
- if (method->get_name() == "strip") {
- vargs.expectArgs("strip method", {0, 0}, {0, 0});
- return Value(strip(str));
- } else if (method->get_name() == "capitalize") {
- vargs.expectArgs("capitalize method", {0, 0}, {0, 0});
- return Value(capitalize(str));
- } else if (method->get_name() == "endswith") {
- vargs.expectArgs("endswith method", {1, 1}, {0, 0});
- auto suffix = vargs.args[0].get<std::string>();
- return suffix.length() <= str.length() && std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());
- } else if (method->get_name() == "title") {
- vargs.expectArgs("title method", {0, 0}, {0, 0});
- auto res = str;
- for (size_t i = 0, n = res.size(); i < n; ++i) {
- if (i == 0 || std::isspace(res[i - 1])) res[i] = std::toupper(res[i]);
- else res[i] = std::tolower(res[i]);
- }
- return res;
- }
- }
- throw std::runtime_error("Unknown method: " + method->get_name());
- }
-};
-
-class CallExpr : public Expression {
-public:
- std::shared_ptr<Expression> object;
- ArgumentsExpression args;
- CallExpr(const Location & location, std::shared_ptr<Expression> && obj, ArgumentsExpression && a)
- : Expression(location), object(std::move(obj)), args(std::move(a)) {}
- Value do_evaluate(const std::shared_ptr<Context> & context) const override {
- if (!object) throw std::runtime_error("CallExpr.object is null");
- auto obj = object->evaluate(context);
- if (!obj.is_callable()) {
- throw std::runtime_error("Object is not callable: " + obj.dump(2));
- }
- auto vargs = args.evaluate(context);
- return obj.call(context, vargs);
- }
-};
-
-class FilterExpr : public Expression {
- std::vector<std::shared_ptr<Expression>> parts;
-public:
- FilterExpr(const Location & location, std::vector<std::shared_ptr<Expression>> && p)
- : Expression(location), parts(std::move(p)) {}
- Value do_evaluate(const std::shared_ptr<Context> & context) const override {
- Value result;
- bool first = true;
- for (const auto& part : parts) {
- if (!part) throw std::runtime_error("FilterExpr.part is null");
- if (first) {
- first = false;
- result = part->evaluate(context);
- } else {
- if (auto ce = dynamic_cast<CallExpr*>(part.get())) {
- auto target = ce->object->evaluate(context);
- ArgumentsValue args = ce->args.evaluate(context);
- args.args.insert(args.args.begin(), result);
- result = target.call(context, args);
- } else {
- auto callable = part->evaluate(context);
- ArgumentsValue args;
- args.args.insert(args.args.begin(), result);
- result = callable.call(context, args);
- }
- }
- }
- return result;
- }
-
- void prepend(std::shared_ptr<Expression> && e) {
- parts.insert(parts.begin(), std::move(e));
- }
-};
-
-class Parser {
-private:
- using CharIterator = std::string::const_iterator;
-
- std::shared_ptr<std::string> template_str;
- CharIterator start, end, it;
- Options options;
-
- Parser(const std::shared_ptr<std::string>& template_str, const Options & options) : template_str(template_str), options(options) {
- if (!template_str) throw std::runtime_error("Template string is null");
- start = it = this->template_str->begin();
- end = this->template_str->end();
- }
-
- bool consumeSpaces(SpaceHandling space_handling = SpaceHandling::Strip) {
- if (space_handling == SpaceHandling::Strip) {
- while (it != end && std::isspace(*it)) ++it;
- }
- return true;
- }
-
- std::unique_ptr<std::string> parseString() {
- auto doParse = [&](char quote) -> std::unique_ptr<std::string> {
- if (it == end || *it != quote) return nullptr;
- std::string result;
- bool escape = false;
- for (++it; it != end; ++it) {
- if (escape) {
- escape = false;
- switch (*it) {
- case 'n': result += '\n'; break;
- case 'r': result += '\r'; break;
- case 't': result += '\t'; break;
- case 'b': result += '\b'; break;
- case 'f': result += '\f'; break;
- case '\\': result += '\\'; break;
- default:
- if (*it == quote) {
- result += quote;
- } else {
- result += *it;
- }
- break;
- }
- } else if (*it == '\\') {
- escape = true;
- } else if (*it == quote) {
- ++it;
- return std::make_unique<std::string>(std::move(result));
- } else {
- result += *it;
- }
- }
- return nullptr;
- };
-
- consumeSpaces();
- if (it == end) return nullptr;
- if (*it == '"') return doParse('"');
- if (*it == '\'') return doParse('\'');
- return nullptr;
- }
-
- json parseNumber(CharIterator& it, const CharIterator& end) {
- auto before = it;
- consumeSpaces();
- auto start = it;
- bool hasDecimal = false;
- bool hasExponent = false;
-
- if (it != end && (*it == '-' || *it == '+')) ++it;
-
- while (it != end) {
- if (std::isdigit(*it)) {
- ++it;
- } else if (*it == '.') {
- if (hasDecimal) throw std::runtime_error("Multiple decimal points");
- hasDecimal = true;
- ++it;
- } else if (it != start && (*it == 'e' || *it == 'E')) {
- if (hasExponent) throw std::runtime_error("Multiple exponents");
- hasExponent = true;
- ++it;
- } else {
- break;
- }
- }
- if (start == it) {
- it = before;
- return json(); // No valid characters found
- }
-
- std::string str(start, it);
- try {
- return json::parse(str);
- } catch (json::parse_error& e) {
- throw std::runtime_error("Failed to parse number: '" + str + "' (" + std::string(e.what()) + ")");
- return json();
- }
- }
-
- /** integer, float, bool, string */
- std::shared_ptr<Value> parseConstant() {
- auto start = it;
- consumeSpaces();
- if (it == end) return nullptr;
- if (*it == '"' || *it == '\'') {
- auto str = parseString();
- if (str) return std::make_shared<Value>(*str);
- }
- static std::regex prim_tok(R"(true\b|True\b|false\b|False\b|None\b)");
- auto token = consumeToken(prim_tok);
- if (!token.empty()) {
- if (token == "true" || token == "True") return std::make_shared<Value>(true);
- if (token == "false" || token == "False") return std::make_shared<Value>(false);
- if (token == "None") return std::make_shared<Value>(nullptr);
- throw std::runtime_error("Unknown constant token: " + token);
- }
-
- auto number = parseNumber(it, end);
- if (!number.is_null()) return std::make_shared<Value>(number);
-
- it = start;
- return nullptr;
- }
-
- class expression_parsing_error : public std::runtime_error {
- const CharIterator it;
- public:
- expression_parsing_error(const std::string & message, const CharIterator & it)
- : std::runtime_error(message), it(it) {}
- size_t get_pos(const CharIterator & begin) const {
- return std::distance(begin, it);
- }
- };
-
- bool peekSymbols(const std::vector<std::string> & symbols) const {
- for (const auto & symbol : symbols) {
- if (std::distance(it, end) >= (int64_t) symbol.size() && std::string(it, it + symbol.size()) == symbol) {
- return true;
- }
- }
- return false;
- }
-
- std::vector<std::string> consumeTokenGroups(const std::regex & regex, SpaceHandling space_handling = SpaceHandling::Strip) {
- auto start = it;
- consumeSpaces(space_handling);
- std::smatch match;
- if (std::regex_search(it, end, match, regex) && match.position() == 0) {
- it += match[0].length();
- std::vector<std::string> ret;
- for (size_t i = 0, n = match.size(); i < n; ++i) {
- ret.push_back(match[i].str());
- }
- return ret;
- }
- it = start;
- return {};
- }
- std::string consumeToken(const std::regex & regex, SpaceHandling space_handling = SpaceHandling::Strip) {
- auto start = it;
- consumeSpaces(space_handling);
- std::smatch match;
- if (std::regex_search(it, end, match, regex) && match.position() == 0) {
- it += match[0].length();
- return match[0].str();
- }
- it = start;
- return "";
- }
-
- std::string consumeToken(const std::string & token, SpaceHandling space_handling = SpaceHandling::Strip) {
- auto start = it;
- consumeSpaces(space_handling);
- if (std::distance(it, end) >= (int64_t) token.size() && std::string(it, it + token.size()) == token) {
- it += token.size();
- return token;
- }
- it = start;
- return "";
- }
-
- std::shared_ptr<Expression> parseExpression(bool allow_if_expr = true) {
- auto left = parseLogicalOr();
- if (it == end) return left;
-
- if (!allow_if_expr) return left;
-
- static std::regex if_tok(R"(if\b)");
- if (consumeToken(if_tok).empty()) {
- return left;
- }
-
- auto location = get_location();
- auto [condition, else_expr] = parseIfExpression();
- return std::make_shared<IfExpr>(location, std::move(condition), std::move(left), std::move(else_expr));
- }
-
- Location get_location() const {
- return {template_str, (size_t) std::distance(start, it)};
- }
-
- std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>> parseIfExpression() {
- auto condition = parseLogicalOr();
- if (!condition) throw std::runtime_error("Expected condition expression");
-
- static std::regex else_tok(R"(else\b)");
- std::shared_ptr<Expression> else_expr;
- if (!consumeToken(else_tok).empty()) {
- else_expr = parseExpression();
- if (!else_expr) throw std::runtime_error("Expected 'else' expression");
- }
- return std::pair(std::move(condition), std::move(else_expr));
- }
-
- std::shared_ptr<Expression> parseLogicalOr() {
- auto left = parseLogicalAnd();
- if (!left) throw std::runtime_error("Expected left side of 'logical or' expression");
-
- static std::regex or_tok(R"(or\b)");
- auto location = get_location();
- while (!consumeToken(or_tok).empty()) {
- auto right = parseLogicalAnd();
- if (!right) throw std::runtime_error("Expected right side of 'or' expression");
- left = std::make_shared<BinaryOpExpr>(location, std::move(left), std::move(right), BinaryOpExpr::Op::Or);
- }
- return left;
- }
-
- std::shared_ptr<Expression> parseLogicalNot() {
- static std::regex not_tok(R"(not\b)");
- auto location = get_location();
-
- if (!consumeToken(not_tok).empty()) {
- auto sub = parseLogicalNot();
- if (!sub) throw std::runtime_error("Expected expression after 'not' keyword");
- return std::make_shared<UnaryOpExpr>(location, std::move(sub), UnaryOpExpr::Op::LogicalNot);
- }
- return parseLogicalCompare();
- }
-
- std::shared_ptr<Expression> parseLogicalAnd() {
- auto left = parseLogicalNot();
- if (!left) throw std::runtime_error("Expected left side of 'logical and' expression");
-
- static std::regex and_tok(R"(and\b)");
- auto location = get_location();
- while (!consumeToken(and_tok).empty()) {
- auto right = parseLogicalNot();
- if (!right) throw std::runtime_error("Expected right side of 'and' expression");
- left = std::make_shared<BinaryOpExpr>(location, std::move(left), std::move(right), BinaryOpExpr::Op::And);
- }
- return left;
- }
-
- std::shared_ptr<Expression> parseLogicalCompare() {
- auto left = parseStringConcat();
- if (!left) throw std::runtime_error("Expected left side of 'logical compare' expression");
-
- static std::regex compare_tok(R"(==|!=|<=?|>=?|in\b|is\b|not\s+in\b)");
- static std::regex not_tok(R"(not\b)");
- std::string op_str;
- while (!(op_str = consumeToken(compare_tok)).empty()) {
- auto location = get_location();
- if (op_str == "is") {
- auto negated = !consumeToken(not_tok).empty();
-
- auto identifier = parseIdentifier();
- if (!identifier) throw std::runtime_error("Expected identifier after 'is' keyword");
-
- return std::make_shared<BinaryOpExpr>(
- left->location,
- std::move(left), std::move(identifier),
- negated ? BinaryOpExpr::Op::IsNot : BinaryOpExpr::Op::Is);
- }
- auto right = parseStringConcat();
- if (!right) throw std::runtime_error("Expected right side of 'logical compare' expression");
- BinaryOpExpr::Op op;
- if (op_str == "==") op = BinaryOpExpr::Op::Eq;
- else if (op_str == "!=") op = BinaryOpExpr::Op::Ne;
- else if (op_str == "<") op = BinaryOpExpr::Op::Lt;
- else if (op_str == ">") op = BinaryOpExpr::Op::Gt;
- else if (op_str == "<=") op = BinaryOpExpr::Op::Le;
- else if (op_str == ">=") op = BinaryOpExpr::Op::Ge;
- else if (op_str == "in") op = BinaryOpExpr::Op::In;
- else if (op_str.substr(0, 3) == "not") op = BinaryOpExpr::Op::NotIn;
- else throw std::runtime_error("Unknown comparison operator: " + op_str);
- left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), op);
- }
- return left;
- }
-
- Expression::Parameters parseParameters() {
- consumeSpaces();
- if (consumeToken("(").empty()) throw std::runtime_error("Expected opening parenthesis in param list");
-
- Expression::Parameters result;
-
- while (it != end) {
- if (!consumeToken(")").empty()) {
- return result;
- }
- auto expr = parseExpression();
- if (!expr) throw std::runtime_error("Expected expression in call args");
-
- if (auto ident = dynamic_cast<VariableExpr*>(expr.get())) {
- if (!consumeToken("=").empty()) {
- auto value = parseExpression();
- if (!value) throw std::runtime_error("Expected expression in for named arg");
- result.emplace_back(ident->get_name(), std::move(value));
- } else {
- result.emplace_back(ident->get_name(), nullptr);
- }
- } else {
- result.emplace_back(std::string(), std::move(expr));
- }
- if (consumeToken(",").empty()) {
- if (consumeToken(")").empty()) {
- throw std::runtime_error("Expected closing parenthesis in call args");
- }
- return result;
- }
- }
- throw std::runtime_error("Expected closing parenthesis in call args");
- }
-
- ArgumentsExpression parseCallArgs() {
- consumeSpaces();
- if (consumeToken("(").empty()) throw std::runtime_error("Expected opening parenthesis in call args");
-
- ArgumentsExpression result;
-
- while (it != end) {
- if (!consumeToken(")").empty()) {
- return result;
- }
- auto expr = parseExpression();
- if (!expr) throw std::runtime_error("Expected expression in call args");
-
- if (auto ident = dynamic_cast<VariableExpr*>(expr.get())) {
- if (!consumeToken("=").empty()) {
- auto value = parseExpression();
- if (!value) throw std::runtime_error("Expected expression in for named arg");
- result.kwargs.emplace_back(ident->get_name(), std::move(value));
- } else {
- result.args.emplace_back(std::move(expr));
- }
- } else {
- result.args.emplace_back(std::move(expr));
- }
- if (consumeToken(",").empty()) {
- if (consumeToken(")").empty()) {
- throw std::runtime_error("Expected closing parenthesis in call args");
- }
- return result;
- }
- }
- throw std::runtime_error("Expected closing parenthesis in call args");
- }
-
- std::shared_ptr<VariableExpr> parseIdentifier() {
- static std::regex ident_regex(R"((?!(?:not|is|and|or|del)\b)[a-zA-Z_]\w*)");
- auto location = get_location();
- auto ident = consumeToken(ident_regex);
- if (ident.empty())
- return nullptr;
- return std::make_shared<VariableExpr>(location, ident);
- }
-
- std::shared_ptr<Expression> parseStringConcat() {
- auto left = parseMathPow();
- if (!left) throw std::runtime_error("Expected left side of 'string concat' expression");
-
- static std::regex concat_tok(R"(~(?!\}))");
- if (!consumeToken(concat_tok).empty()) {
- auto right = parseLogicalAnd();
- if (!right) throw std::runtime_error("Expected right side of 'string concat' expression");
- left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), BinaryOpExpr::Op::StrConcat);
- }
- return left;
- }
-
- std::shared_ptr<Expression> parseMathPow() {
- auto left = parseMathPlusMinus();
- if (!left) throw std::runtime_error("Expected left side of 'math pow' expression");
-
- while (!consumeToken("**").empty()) {
- auto right = parseMathPlusMinus();
- if (!right) throw std::runtime_error("Expected right side of 'math pow' expression");
- left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), BinaryOpExpr::Op::MulMul);
- }
- return left;
- }
-
- std::shared_ptr<Expression> parseMathPlusMinus() {
- static std::regex plus_minus_tok(R"(\+|-(?![}%#]\}))");
-
- auto left = parseMathMulDiv();
- if (!left) throw std::runtime_error("Expected left side of 'math plus/minus' expression");
- std::string op_str;
- while (!(op_str = consumeToken(plus_minus_tok)).empty()) {
- auto right = parseMathMulDiv();
- if (!right) throw std::runtime_error("Expected right side of 'math plus/minus' expression");
- auto op = op_str == "+" ? BinaryOpExpr::Op::Add : BinaryOpExpr::Op::Sub;
- left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), op);
- }
- return left;
- }
-
- std::shared_ptr<Expression> parseMathMulDiv() {
- auto left = parseMathUnaryPlusMinus();
- if (!left) throw std::runtime_error("Expected left side of 'math mul/div' expression");
-
- static std::regex mul_div_tok(R"(\*\*?|//?|%(?!\}))");
- std::string op_str;
- while (!(op_str = consumeToken(mul_div_tok)).empty()) {
- auto right = parseMathUnaryPlusMinus();
- if (!right) throw std::runtime_error("Expected right side of 'math mul/div' expression");
- auto op = op_str == "*" ? BinaryOpExpr::Op::Mul
- : op_str == "**" ? BinaryOpExpr::Op::MulMul
- : op_str == "/" ? BinaryOpExpr::Op::Div
- : op_str == "//" ? BinaryOpExpr::Op::DivDiv
- : BinaryOpExpr::Op::Mod;
- left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), op);
- }
-
- if (!consumeToken("|").empty()) {
- auto expr = parseMathMulDiv();
- if (auto filter = dynamic_cast<FilterExpr*>(expr.get())) {
- filter->prepend(std::move(left));
- return expr;
- } else {
- std::vector<std::shared_ptr<Expression>> parts;
- parts.emplace_back(std::move(left));
- parts.emplace_back(std::move(expr));
- return std::make_shared<FilterExpr>(get_location(), std::move(parts));
- }
- }
- return left;
- }
-
- std::shared_ptr<Expression> call_func(const std::string & name, ArgumentsExpression && args) const {
- return std::make_shared<CallExpr>(get_location(), std::make_shared<VariableExpr>(get_location(), name), std::move(args));
- }
-
- std::shared_ptr<Expression> parseMathUnaryPlusMinus() {
- static std::regex unary_plus_minus_tok(R"(\+|-(?![}%#]\}))");
- auto op_str = consumeToken(unary_plus_minus_tok);
- auto expr = parseExpansion();
- if (!expr) throw std::runtime_error("Expected expr of 'unary plus/minus/expansion' expression");
-
- if (!op_str.empty()) {
- auto op = op_str == "+" ? UnaryOpExpr::Op::Plus : UnaryOpExpr::Op::Minus;
- return std::make_shared<UnaryOpExpr>(get_location(), std::move(expr), op);
- }
- return expr;
- }
-
- std::shared_ptr<Expression> parseExpansion() {
- static std::regex expansion_tok(R"(\*\*?)");
- auto op_str = consumeToken(expansion_tok);
- auto expr = parseValueExpression();
- if (op_str.empty()) return expr;
- if (!expr) throw std::runtime_error("Expected expr of 'expansion' expression");
- return std::make_shared<UnaryOpExpr>(get_location(), std::move(expr), op_str == "*" ? UnaryOpExpr::Op::Expansion : UnaryOpExpr::Op::ExpansionDict);
- }
-
- std::shared_ptr<Expression> parseValueExpression() {
- auto parseValue = [&]() -> std::shared_ptr<Expression> {
- auto location = get_location();
- auto constant = parseConstant();
- if (constant) return std::make_shared<LiteralExpr>(location, *constant);
-
- static std::regex null_regex(R"(null\b)");
- if (!consumeToken(null_regex).empty()) return std::make_shared<LiteralExpr>(location, Value());
-
- auto identifier = parseIdentifier();
- if (identifier) return identifier;
-
- auto braced = parseBracedExpressionOrArray();
- if (braced) return braced;
-
- auto array = parseArray();
- if (array) return array;
-
- auto dictionary = parseDictionary();
- if (dictionary) return dictionary;
-
- throw std::runtime_error("Expected value expression");
- };
-
- auto value = parseValue();
-
- while (it != end && consumeSpaces() && peekSymbols({ "[", "." })) {
- if (!consumeToken("[").empty()) {
- std::shared_ptr<Expression> index;
- if (!consumeToken(":").empty()) {
- auto slice_end = parseExpression();
- index = std::make_shared<SliceExpr>(slice_end->location, nullptr, std::move(slice_end));
- } else {
- auto slice_start = parseExpression();
- if (!consumeToken(":").empty()) {
- consumeSpaces();
- if (peekSymbols({ "]" })) {
- index = std::make_shared<SliceExpr>(slice_start->location, std::move(slice_start), nullptr);
- } else {
- auto slice_end = parseExpression();
- index = std::make_shared<SliceExpr>(slice_start->location, std::move(slice_start), std::move(slice_end));
- }
- } else {
- index = std::move(slice_start);
- }
- }
- if (!index) throw std::runtime_error("Empty index in subscript");
- if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript");
-
- value = std::make_shared<SubscriptExpr>(value->location, std::move(value), std::move(index));
- } else if (!consumeToken(".").empty()) {
- auto identifier = parseIdentifier();
- if (!identifier) throw std::runtime_error("Expected identifier in subscript");
-
- consumeSpaces();
- if (peekSymbols({ "(" })) {
- auto callParams = parseCallArgs();
- value = std::make_shared<MethodCallExpr>(identifier->location, std::move(value), std::move(identifier), std::move(callParams));
- } else {
- auto key = std::make_shared<LiteralExpr>(identifier->location, Value(identifier->get_name()));
- value = std::make_shared<SubscriptExpr>(identifier->location, std::move(value), std::move(key));
- }
- }
- consumeSpaces();
- }
-
- if (peekSymbols({ "(" })) {
- auto location = get_location();
- auto callParams = parseCallArgs();
- value = std::make_shared<CallExpr>(location, std::move(value), std::move(callParams));
- }
- return value;
- }
-
- std::shared_ptr<Expression> parseBracedExpressionOrArray() {
- if (consumeToken("(").empty()) return nullptr;
-
- auto expr = parseExpression();
- if (!expr) throw std::runtime_error("Expected expression in braced expression");
-
- if (!consumeToken(")").empty()) {
- return expr; // Drop the parentheses
- }
-
- std::vector<std::shared_ptr<Expression>> tuple;
- tuple.emplace_back(std::move(expr));
-
- while (it != end) {
- if (consumeToken(",").empty()) throw std::runtime_error("Expected comma in tuple");
- auto next = parseExpression();
- if (!next) throw std::runtime_error("Expected expression in tuple");
- tuple.push_back(std::move(next));
-
- if (!consumeToken(")").empty()) {
- return std::make_shared<ArrayExpr>(get_location(), std::move(tuple));
- }
- }
- throw std::runtime_error("Expected closing parenthesis");
- }
-
- std::shared_ptr<Expression> parseArray() {
- if (consumeToken("[").empty()) return nullptr;
-
- std::vector<std::shared_ptr<Expression>> elements;
- if (!consumeToken("]").empty()) {
- return std::make_shared<ArrayExpr>(get_location(), std::move(elements));
- }
- auto first_expr = parseExpression();
- if (!first_expr) throw std::runtime_error("Expected first expression in array");
- elements.push_back(std::move(first_expr));
-
- while (it != end) {
- if (!consumeToken(",").empty()) {
- auto expr = parseExpression();
- if (!expr) throw std::runtime_error("Expected expression in array");
- elements.push_back(std::move(expr));
- } else if (!consumeToken("]").empty()) {
- return std::make_shared<ArrayExpr>(get_location(), std::move(elements));
- } else {
- throw std::runtime_error("Expected comma or closing bracket in array");
- }
- }
- throw std::runtime_error("Expected closing bracket");
- }
-
- std::shared_ptr<Expression> parseDictionary() {
- if (consumeToken("{").empty()) return nullptr;
-
- std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>>> elements;
- if (!consumeToken("}").empty()) {
- return std::make_shared<DictExpr>(get_location(), std::move(elements));
- }
-
- auto parseKeyValuePair = [&]() {
- auto key = parseExpression();
- if (!key) throw std::runtime_error("Expected key in dictionary");
- if (consumeToken(":").empty()) throw std::runtime_error("Expected colon betweek key & value in dictionary");
- auto value = parseExpression();
- if (!value) throw std::runtime_error("Expected value in dictionary");
- elements.emplace_back(std::pair(std::move(key), std::move(value)));
- };
-
- parseKeyValuePair();
-
- while (it != end) {
- if (!consumeToken(",").empty()) {
- parseKeyValuePair();
- } else if (!consumeToken("}").empty()) {
- return std::make_shared<DictExpr>(get_location(), std::move(elements));
- } else {
- throw std::runtime_error("Expected comma or closing brace in dictionary");
- }
- }
- throw std::runtime_error("Expected closing brace");
- }
-
- SpaceHandling parsePreSpace(const std::string& s) const {
- if (s == "-")
- return SpaceHandling::Strip;
- return SpaceHandling::Keep;
- }
-
- SpaceHandling parsePostSpace(const std::string& s) const {
- if (s == "-") return SpaceHandling::Strip;
- return SpaceHandling::Keep;
- }
-
- using TemplateTokenVector = std::vector<std::unique_ptr<TemplateToken>>;
- using TemplateTokenIterator = TemplateTokenVector::const_iterator;
-
- std::vector<std::string> parseVarNames() {
- static std::regex varnames_regex(R"(((?:\w+)(?:\s*,\s*(?:\w+))*)\s*)");
-
- std::vector<std::string> group;
- if ((group = consumeTokenGroups(varnames_regex)).empty()) throw std::runtime_error("Expected variable names");
- std::vector<std::string> varnames;
- std::istringstream iss(group[1]);
- std::string varname;
- while (std::getline(iss, varname, ',')) {
- varnames.push_back(strip(varname));
- }
- return varnames;
- }
-
- std::runtime_error unexpected(const TemplateToken & token) const {
- return std::runtime_error("Unexpected " + TemplateToken::typeToString(token.type)
- + error_location_suffix(*template_str, token.location.pos));
- }
- std::runtime_error unterminated(const TemplateToken & token) const {
- return std::runtime_error("Unterminated " + TemplateToken::typeToString(token.type)
- + error_location_suffix(*template_str, token.location.pos));
- }
-
- TemplateTokenVector tokenize() {
- static std::regex comment_tok(R"(\{#([-~]?)([\s\S]*?)([-~]?)#\})");
- static std::regex expr_open_regex(R"(\{\{([-~])?)");
- static std::regex block_open_regex(R"(^\{%([-~])?\s*)");
- static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue)\b)");
- static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)");
- static std::regex expr_close_regex(R"(\s*([-~])?\}\})");
- static std::regex block_close_regex(R"(\s*([-~])?%\})");
-
- TemplateTokenVector tokens;
- std::vector<std::string> group;
- std::string text;
- std::smatch match;
-
- try {
- while (it != end) {
- auto location = get_location();
-
- if (!(group = consumeTokenGroups(comment_tok, SpaceHandling::Keep)).empty()) {
- auto pre_space = parsePreSpace(group[1]);
- auto content = group[2];
- auto post_space = parsePostSpace(group[3]);
- tokens.push_back(std::make_unique<CommentTemplateToken>(location, pre_space, post_space, content));
- } else if (!(group = consumeTokenGroups(expr_open_regex, SpaceHandling::Keep)).empty()) {
- auto pre_space = parsePreSpace(group[1]);
- auto expr = parseExpression();
-
- if ((group = consumeTokenGroups(expr_close_regex)).empty()) {
- throw std::runtime_error("Expected closing expression tag");
- }
-
- auto post_space = parsePostSpace(group[1]);
- tokens.push_back(std::make_unique<ExpressionTemplateToken>(location, pre_space, post_space, std::move(expr)));
- } else if (!(group = consumeTokenGroups(block_open_regex, SpaceHandling::Keep)).empty()) {
- auto pre_space = parsePreSpace(group[1]);
-
- std::string keyword;
-
- auto parseBlockClose = [&]() -> SpaceHandling {
- if ((group = consumeTokenGroups(block_close_regex)).empty()) throw std::runtime_error("Expected closing block tag");
- return parsePostSpace(group[1]);
- };
-
- if ((keyword = consumeToken(block_keyword_tok)).empty()) throw std::runtime_error("Expected block keyword");
-
- if (keyword == "if") {
- auto condition = parseExpression();
- if (!condition) throw std::runtime_error("Expected condition in if block");
-
- auto post_space = parseBlockClose();
- tokens.push_back(std::make_unique<IfTemplateToken>(location, pre_space, post_space, std::move(condition)));
- } else if (keyword == "elif") {
- auto condition = parseExpression();
- if (!condition) throw std::runtime_error("Expected condition in elif block");
-
- auto post_space = parseBlockClose();
- tokens.push_back(std::make_unique<ElifTemplateToken>(location, pre_space, post_space, std::move(condition)));
- } else if (keyword == "else") {
- auto post_space = parseBlockClose();
- tokens.push_back(std::make_unique<ElseTemplateToken>(location, pre_space, post_space));
- } else if (keyword == "endif") {
- auto post_space = parseBlockClose();
- tokens.push_back(std::make_unique<EndIfTemplateToken>(location, pre_space, post_space));
- } else if (keyword == "for") {
- static std::regex recursive_tok(R"(recursive\b)");
- static std::regex if_tok(R"(if\b)");
-
- auto varnames = parseVarNames();
- static std::regex in_tok(R"(in\b)");
- if (consumeToken(in_tok).empty()) throw std::runtime_error("Expected 'in' keyword in for block");
- auto iterable = parseExpression(/* allow_if_expr = */ false);
- if (!iterable) throw std::runtime_error("Expected iterable in for block");
-
- std::shared_ptr<Expression> condition;
- if (!consumeToken(if_tok).empty()) {
- condition = parseExpression();
- }
- auto recursive = !consumeToken(recursive_tok).empty();
-
- auto post_space = parseBlockClose();
- tokens.push_back(std::make_unique<ForTemplateToken>(location, pre_space, post_space, std::move(varnames), std::move(iterable), std::move(condition), recursive));
- } else if (keyword == "endfor") {
- auto post_space = parseBlockClose();
- tokens.push_back(std::make_unique<EndForTemplateToken>(location, pre_space, post_space));
- } else if (keyword == "generation") {
- auto post_space = parseBlockClose();
- tokens.push_back(std::make_unique<GenerationTemplateToken>(location, pre_space, post_space));
- } else if (keyword == "endgeneration") {
- auto post_space = parseBlockClose();
- tokens.push_back(std::make_unique<EndGenerationTemplateToken>(location, pre_space, post_space));
- } else if (keyword == "set") {
- static std::regex namespaced_var_regex(R"((\w+)\s*\.\s*(\w+))");
-
- std::string ns;
- std::vector<std::string> var_names;
- std::shared_ptr<Expression> value;
- if (!(group = consumeTokenGroups(namespaced_var_regex)).empty()) {
- ns = group[1];
- var_names.push_back(group[2]);
-
- if (consumeToken("=").empty()) throw std::runtime_error("Expected equals sign in set block");
-
- value = parseExpression();
- if (!value) throw std::runtime_error("Expected value in set block");
- } else {
- var_names = parseVarNames();
-
- if (!consumeToken("=").empty()) {
- value = parseExpression();
- if (!value) throw std::runtime_error("Expected value in set block");
- }
- }
- auto post_space = parseBlockClose();
- tokens.push_back(std::make_unique<SetTemplateToken>(location, pre_space, post_space, ns, var_names, std::move(value)));
- } else if (keyword == "endset") {
- auto post_space = parseBlockClose();
- tokens.push_back(std::make_unique<EndSetTemplateToken>(location, pre_space, post_space));
- } else if (keyword == "macro") {
- auto macroname = parseIdentifier();
- if (!macroname) throw std::runtime_error("Expected macro name in macro block");
- auto params = parseParameters();
-
- auto post_space = parseBlockClose();
- tokens.push_back(std::make_unique<MacroTemplateToken>(location, pre_space, post_space, std::move(macroname), std::move(params)));
- } else if (keyword == "endmacro") {
- auto post_space = parseBlockClose();
- tokens.push_back(std::make_unique<EndMacroTemplateToken>(location, pre_space, post_space));
- } else if (keyword == "filter") {
- auto filter = parseExpression();
- if (!filter) throw std::runtime_error("Expected expression in filter block");
-
- auto post_space = parseBlockClose();
- tokens.push_back(std::make_unique<FilterTemplateToken>(location, pre_space, post_space, std::move(filter)));
- } else if (keyword == "endfilter") {
- auto post_space = parseBlockClose();
- tokens.push_back(std::make_unique<EndFilterTemplateToken>(location, pre_space, post_space));
- } else if (keyword == "break" || keyword == "continue") {
- auto post_space = parseBlockClose();
- tokens.push_back(std::make_unique<LoopControlTemplateToken>(location, pre_space, post_space, keyword == "break" ? LoopControlType::Break : LoopControlType::Continue));
- } else {
- throw std::runtime_error("Unexpected block: " + keyword);
- }
- } else if (std::regex_search(it, end, match, non_text_open_regex)) {
- if (!match.position()) {
- if (match[0] != "{#")
- throw std::runtime_error("Internal error: Expected a comment");
- throw std::runtime_error("Missing end of comment tag");
- }
- auto text_end = it + match.position();
- text = std::string(it, text_end);
- it = text_end;
- tokens.push_back(std::make_unique<TextTemplateToken>(location, SpaceHandling::Keep, SpaceHandling::Keep, text));
- } else {
- text = std::string(it, end);
- it = end;
- tokens.push_back(std::make_unique<TextTemplateToken>(location, SpaceHandling::Keep, SpaceHandling::Keep, text));
- }
- }
- return tokens;
- } catch (const std::exception & e) {
- throw std::runtime_error(e.what() + error_location_suffix(*template_str, std::distance(start, it)));
- }
- }
-
- std::shared_ptr<TemplateNode> parseTemplate(
- const TemplateTokenIterator & begin,
- TemplateTokenIterator & it,
- const TemplateTokenIterator & end,
- bool fully = false) const {
- std::vector<std::shared_ptr<TemplateNode>> children;
- while (it != end) {
- const auto start = it;
- const auto & token = *(it++);
- if (auto if_token = dynamic_cast<IfTemplateToken*>(token.get())) {
- std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<TemplateNode>>> cascade;
- cascade.emplace_back(std::move(if_token->condition), parseTemplate(begin, it, end));
-
- while (it != end && (*it)->type == TemplateToken::Type::Elif) {
- auto elif_token = dynamic_cast<ElifTemplateToken*>((*(it++)).get());
- cascade.emplace_back(std::move(elif_token->condition), parseTemplate(begin, it, end));
- }
-
- if (it != end && (*it)->type == TemplateToken::Type::Else) {
- cascade.emplace_back(nullptr, parseTemplate(begin, ++it, end));
- }
- if (it == end || (*(it++))->type != TemplateToken::Type::EndIf) {
- throw unterminated(**start);
- }
- children.emplace_back(std::make_shared<IfNode>(token->location, std::move(cascade)));
- } else if (auto for_token = dynamic_cast<ForTemplateToken*>(token.get())) {
- auto body = parseTemplate(begin, it, end);
- auto else_body = std::shared_ptr<TemplateNode>();
- if (it != end && (*it)->type == TemplateToken::Type::Else) {
- else_body = parseTemplate(begin, ++it, end);
- }
- if (it == end || (*(it++))->type != TemplateToken::Type::EndFor) {
- throw unterminated(**start);
- }
- children.emplace_back(std::make_shared<ForNode>(token->location, std::move(for_token->var_names), std::move(for_token->iterable), std::move(for_token->condition), std::move(body), for_token->recursive, std::move(else_body)));
- } else if (dynamic_cast<GenerationTemplateToken*>(token.get())) {
- auto body = parseTemplate(begin, it, end);
- if (it == end || (*(it++))->type != TemplateToken::Type::EndGeneration) {
- throw unterminated(**start);
- }
- // Treat as a no-op, as our scope is templates for inference, not training (`{% generation %}` wraps generated tokens for masking).
- children.emplace_back(std::move(body));
- } else if (auto text_token = dynamic_cast<TextTemplateToken*>(token.get())) {
- SpaceHandling pre_space = (it - 1) != begin ? (*(it - 2))->post_space : SpaceHandling::Keep;
- SpaceHandling post_space = it != end ? (*it)->pre_space : SpaceHandling::Keep;
-
- auto text = text_token->text;
- if (post_space == SpaceHandling::Strip) {
- static std::regex trailing_space_regex(R"(\s+$)");
- text = std::regex_replace(text, trailing_space_regex, "");
- } else if (options.lstrip_blocks && it != end) {
- auto i = text.size();
- while (i > 0 && (text[i - 1] == ' ' || text[i - 1] == '\t')) i--;
- if ((i == 0 && (it - 1) == begin) || (i > 0 && text[i - 1] == '\n')) {
- text.resize(i);
- }
- }
- if (pre_space == SpaceHandling::Strip) {
- static std::regex leading_space_regex(R"(^\s+)");
- text = std::regex_replace(text, leading_space_regex, "");
- } else if (options.trim_blocks && (it - 1) != begin && !dynamic_cast<ExpressionTemplateToken*>((*(it - 2)).get())) {
- if (text.length() > 0 && text[0] == '\n') {
- text.erase(0, 1);
- }
- }
- if (it == end && !options.keep_trailing_newline) {
- auto i = text.size();
- if (i > 0 && text[i - 1] == '\n') {
- i--;
- if (i > 0 && text[i - 1] == '\r') i--;
- text.resize(i);
- }
- }
- children.emplace_back(std::make_shared<TextNode>(token->location, text));
- } else if (auto expr_token = dynamic_cast<ExpressionTemplateToken*>(token.get())) {
- children.emplace_back(std::make_shared<ExpressionNode>(token->location, std::move(expr_token->expr)));
- } else if (auto set_token = dynamic_cast<SetTemplateToken*>(token.get())) {
- if (set_token->value) {
- children.emplace_back(std::make_shared<SetNode>(token->location, set_token->ns, set_token->var_names, std::move(set_token->value)));
- } else {
- auto value_template = parseTemplate(begin, it, end);
- if (it == end || (*(it++))->type != TemplateToken::Type::EndSet) {
- throw unterminated(**start);
- }
- if (!set_token->ns.empty()) throw std::runtime_error("Namespaced set not supported in set with template value");
- if (set_token->var_names.size() != 1) throw std::runtime_error("Structural assignment not supported in set with template value");
- auto & name = set_token->var_names[0];
- children.emplace_back(std::make_shared<SetTemplateNode>(token->location, name, std::move(value_template)));
- }
- } else if (auto macro_token = dynamic_cast<MacroTemplateToken*>(token.get())) {
- auto body = parseTemplate(begin, it, end);
- if (it == end || (*(it++))->type != TemplateToken::Type::EndMacro) {
- throw unterminated(**start);
- }
- children.emplace_back(std::make_shared<MacroNode>(token->location, std::move(macro_token->name), std::move(macro_token->params), std::move(body)));
- } else if (auto filter_token = dynamic_cast<FilterTemplateToken*>(token.get())) {
- auto body = parseTemplate(begin, it, end);
- if (it == end || (*(it++))->type != TemplateToken::Type::EndFilter) {
- throw unterminated(**start);
- }
- children.emplace_back(std::make_shared<FilterNode>(token->location, std::move(filter_token->filter), std::move(body)));
- } else if (dynamic_cast<CommentTemplateToken*>(token.get())) {
- // Ignore comments
- } else if (auto ctrl_token = dynamic_cast<LoopControlTemplateToken*>(token.get())) {
- children.emplace_back(std::make_shared<LoopControlNode>(token->location, ctrl_token->control_type));
- } else if (dynamic_cast<EndForTemplateToken*>(token.get())
- || dynamic_cast<EndSetTemplateToken*>(token.get())
- || dynamic_cast<EndMacroTemplateToken*>(token.get())
- || dynamic_cast<EndFilterTemplateToken*>(token.get())
- || dynamic_cast<EndIfTemplateToken*>(token.get())
- || dynamic_cast<ElseTemplateToken*>(token.get())
- || dynamic_cast<EndGenerationTemplateToken*>(token.get())
- || dynamic_cast<ElifTemplateToken*>(token.get())) {
- it--; // unconsume the token
- break; // exit the loop
- } else {
- throw unexpected(**(it-1));
- }
- }
- if (fully && it != end) {
- throw unexpected(**it);
- }
- if (children.empty()) {
- return std::make_shared<TextNode>(Location { template_str, 0 }, std::string());
- } else if (children.size() == 1) {
- return std::move(children[0]);
- } else {
- return std::make_shared<SequenceNode>(children[0]->location(), std::move(children));
- }
- }
-
-public:
-
- static std::shared_ptr<TemplateNode> parse(const std::string& template_str, const Options & options) {
- Parser parser(std::make_shared<std::string>(normalize_newlines(template_str)), options);
- auto tokens = parser.tokenize();
- TemplateTokenIterator begin = tokens.begin();
- auto it = begin;
- TemplateTokenIterator end = tokens.end();
- return parser.parseTemplate(begin, it, end, /* full= */ true);
- }
-};
-
-static Value simple_function(const std::string & fn_name, const std::vector<std::string> & params, const std::function<Value(const std::shared_ptr<Context> &, Value & args)> & fn) {
- std::map<std::string, size_t> named_positions;
- for (size_t i = 0, n = params.size(); i < n; i++) named_positions[params[i]] = i;
-
- return Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) -> Value {
- auto args_obj = Value::object();
- std::vector<bool> provided_args(params.size());
- for (size_t i = 0, n = args.args.size(); i < n; i++) {
- auto & arg = args.args[i];
- if (i < params.size()) {
- args_obj.set(params[i], arg);
- provided_args[i] = true;
- } else {
- throw std::runtime_error("Too many positional params for " + fn_name);
- }
- }
- for (auto & [name, value] : args.kwargs) {
- auto named_pos_it = named_positions.find(name);
- if (named_pos_it == named_positions.end()) {
- throw std::runtime_error("Unknown argument " + name + " for function " + fn_name);
- }
- provided_args[named_pos_it->second] = true;
- args_obj.set(name, value);
- }
- return fn(context, args_obj);
- });
-}
-
-inline std::shared_ptr<Context> Context::builtins() {
- auto globals = Value::object();
-
- globals.set("raise_exception", simple_function("raise_exception", { "message" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
- throw std::runtime_error(args.at("message").get<std::string>());
- }));
- globals.set("tojson", simple_function("tojson", { "value", "indent" }, [](const std::shared_ptr<Context> &, Value & args) {
- return Value(args.at("value").dump(args.get<int64_t>("indent", -1), /* tojson= */ true));
- }));
- globals.set("items", simple_function("items", { "object" }, [](const std::shared_ptr<Context> &, Value & args) {
- auto items = Value::array();
- if (args.contains("object")) {
- auto & obj = args.at("object");
- if (obj.is_string()) {
- auto json_obj = json::parse(obj.get<std::string>());
- for (const auto & kv : json_obj.items()) {
- items.push_back(Value::array({kv.key(), kv.value()}));
- }
- } else if (!obj.is_null()) {
- for (auto & key : obj.keys()) {
- items.push_back(Value::array({key, obj.at(key)}));
- }
- }
- }
- return items;
- }));
- globals.set("last", simple_function("last", { "items" }, [](const std::shared_ptr<Context> &, Value & args) {
- auto items = args.at("items");
- if (!items.is_array()) throw std::runtime_error("object is not a list");
- if (items.size() == 0) return Value();
- return items.at(items.size() - 1);
- }));
- globals.set("trim", simple_function("trim", { "text" }, [](const std::shared_ptr<Context> &, Value & args) {
- auto & text = args.at("text");
- return text.is_null() ? text : Value(strip(text.get<std::string>()));
- }));
- globals.set("lower", simple_function("lower", { "text" }, [](const std::shared_ptr<Context> &, Value & args) {
- auto text = args.at("text");
- if (text.is_null()) return text;
- std::string res;
- auto str = text.get<std::string>();
- std::transform(str.begin(), str.end(), std::back_inserter(res), ::tolower);
- return Value(res);
- }));
- globals.set("default", Value::callable([=](const std::shared_ptr<Context> &, ArgumentsValue & args) {
- args.expectArgs("default", {2, 3}, {0, 1});
- auto & value = args.args[0];
- auto & default_value = args.args[1];
- bool boolean = false;
- if (args.args.size() == 3) {
- boolean = args.args[2].get<bool>();
- } else {
- Value bv = args.get_named("boolean");
- if (!bv.is_null()) {
- boolean = bv.get<bool>();
- }
- }
- return boolean ? (value.to_bool() ? value : default_value) : value.is_null() ? default_value : value;
- }));
- auto escape = simple_function("escape", { "text" }, [](const std::shared_ptr<Context> &, Value & args) {
- return Value(html_escape(args.at("text").get<std::string>()));
- });
- globals.set("e", escape);
- globals.set("escape", escape);
- globals.set("joiner", simple_function("joiner", { "sep" }, [](const std::shared_ptr<Context> &, Value & args) {
- auto sep = args.get<std::string>("sep", "");
- auto first = std::make_shared<bool>(true);
- return simple_function("", {}, [sep, first](const std::shared_ptr<Context> &, const Value &) -> Value {
- if (*first) {
- *first = false;
- return "";
- }
- return sep;
- });
- return Value(html_escape(args.at("text").get<std::string>()));
- }));
- globals.set("count", simple_function("count", { "items" }, [](const std::shared_ptr<Context> &, Value & args) {
- return Value((int64_t) args.at("items").size());
- }));
- globals.set("dictsort", simple_function("dictsort", { "value" }, [](const std::shared_ptr<Context> &, Value & args) {
- if (args.size() != 1) throw std::runtime_error("dictsort expects exactly 1 argument (TODO: fix implementation)");
- auto & value = args.at("value");
- auto keys = value.keys();
- std::sort(keys.begin(), keys.end());
- auto res = Value::array();
- for (auto & key : keys) {
- res.push_back(Value::array({key, value.at(key)}));
- }
- return res;
- }));
- globals.set("join", simple_function("join", { "items", "d" }, [](const std::shared_ptr<Context> &, Value & args) {
- auto do_join = [](Value & items, const std::string & sep) {
- if (!items.is_array()) throw std::runtime_error("object is not iterable: " + items.dump());
- std::ostringstream oss;
- auto first = true;
- for (size_t i = 0, n = items.size(); i < n; ++i) {
- if (first) first = false;
- else oss << sep;
- oss << items.at(i).to_str();
- }
- return Value(oss.str());
- };
- auto sep = args.get<std::string>("d", "");
- if (args.contains("items")) {
- auto & items = args.at("items");
- return do_join(items, sep);
- } else {
- return simple_function("", {"items"}, [sep, do_join](const std::shared_ptr<Context> &, Value & args) {
- auto & items = args.at("items");
- if (!items.to_bool() || !items.is_array()) throw std::runtime_error("join expects an array for items, got: " + items.dump());
- return do_join(items, sep);
- });
- }
- }));
- globals.set("namespace", Value::callable([=](const std::shared_ptr<Context> &, ArgumentsValue & args) {
- auto ns = Value::object();
- args.expectArgs("namespace", {0, 0}, {0, (std::numeric_limits<size_t>::max)()});
- for (auto & [name, value] : args.kwargs) {
- ns.set(name, value);
- }
- return ns;
- }));
- auto equalto = simple_function("equalto", { "expected", "actual" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
- return args.at("actual") == args.at("expected");
- });
- globals.set("equalto", equalto);
- globals.set("==", equalto);
- globals.set("length", simple_function("length", { "items" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
- auto & items = args.at("items");
- return (int64_t) items.size();
- }));
- globals.set("safe", simple_function("safe", { "value" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
- return args.at("value").to_str();
- }));
- globals.set("string", simple_function("string", { "value" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
- return args.at("value").to_str();
- }));
- globals.set("int", simple_function("int", { "value" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
- return args.at("value").to_int();
- }));
- globals.set("list", simple_function("list", { "items" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
- auto & items = args.at("items");
- if (!items.is_array()) throw std::runtime_error("object is not iterable");
- return items;
- }));
- globals.set("unique", simple_function("unique", { "items" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
- auto & items = args.at("items");
- if (!items.is_array()) throw std::runtime_error("object is not iterable");
- std::unordered_set<Value> seen;
- auto result = Value::array();
- for (size_t i = 0, n = items.size(); i < n; i++) {
- auto pair = seen.insert(items.at(i));
- if (pair.second) {
- result.push_back(items.at(i));
- }
- }
- return result;
- }));
- auto make_filter = [](const Value & filter, Value & extra_args) -> Value {
- return simple_function("", { "value" }, [=](const std::shared_ptr<Context> & context, Value & args) {
- auto & value = args.at("value");
- ArgumentsValue actual_args;
- actual_args.args.emplace_back(value);
- for (size_t i = 0, n = extra_args.size(); i < n; i++) {
- actual_args.args.emplace_back(extra_args.at(i));
- }
- return filter.call(context, actual_args);
- });
- };
- auto select_or_reject = [make_filter](bool is_select) {
- return Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
- args.expectArgs(is_select ? "select" : "reject", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
- auto & items = args.args[0];
- if (items.is_null())
- return Value::array();
- if (!items.is_array()) throw std::runtime_error("object is not iterable: " + items.dump());
-
- auto filter_fn = context->get(args.args[1]);
- if (filter_fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump());
-
- auto filter_args = Value::array();
- for (size_t i = 2, n = args.args.size(); i < n; i++) {
- filter_args.push_back(args.args[i]);
- }
- auto filter = make_filter(filter_fn, filter_args);
-
- auto res = Value::array();
- for (size_t i = 0, n = items.size(); i < n; i++) {
- auto & item = items.at(i);
- ArgumentsValue filter_args;
- filter_args.args.emplace_back(item);
- auto pred_res = filter.call(context, filter_args);
- if (pred_res.to_bool() == (is_select ? true : false)) {
- res.push_back(item);
- }
- }
- return res;
- });
- };
- globals.set("select", select_or_reject(/* is_select= */ true));
- globals.set("reject", select_or_reject(/* is_select= */ false));
- globals.set("map", Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
- auto res = Value::array();
- if (args.args.size() == 1 &&
- ((args.has_named("attribute") && args.kwargs.size() == 1) || (args.has_named("default") && args.kwargs.size() == 2))) {
- auto & items = args.args[0];
- auto attr_name = args.get_named("attribute");
- auto default_value = args.get_named("default");
- for (size_t i = 0, n = items.size(); i < n; i++) {
- auto & item = items.at(i);
- auto attr = item.get(attr_name);
- res.push_back(attr.is_null() ? default_value : attr);
- }
- } else if (args.kwargs.empty() && args.args.size() >= 2) {
- auto fn = context->get(args.args[1]);
- if (fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump());
- ArgumentsValue filter_args { {Value()}, {} };
- for (size_t i = 2, n = args.args.size(); i < n; i++) {
- filter_args.args.emplace_back(args.args[i]);
- }
- for (size_t i = 0, n = args.args[0].size(); i < n; i++) {
- auto & item = args.args[0].at(i);
- filter_args.args[0] = item;
- res.push_back(fn.call(context, filter_args));
- }
- } else {
- throw std::runtime_error("Invalid or unsupported arguments for map");
- }
- return res;
- }));
- globals.set("indent", simple_function("indent", { "text", "indent", "first" }, [](const std::shared_ptr<Context> &, Value & args) {
- auto text = args.at("text").get<std::string>();
- auto first = args.get<bool>("first", false);
- std::string out;
- std::string indent(args.get<int64_t>("indent", 0), ' ');
- std::istringstream iss(text);
- std::string line;
- auto is_first = true;
- while (std::getline(iss, line, '\n')) {
- auto needs_indent = !is_first || first;
- if (is_first) is_first = false;
- else out += "\n";
- if (needs_indent) out += indent;
- out += line;
- }
- if (!text.empty() && text.back() == '\n') out += "\n";
- return out;
- }));
- auto select_or_reject_attr = [](bool is_select) {
- return Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
- args.expectArgs(is_select ? "selectattr" : "rejectattr", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
- auto & items = args.args[0];
- if (items.is_null())
- return Value::array();
- if (!items.is_array()) throw std::runtime_error("object is not iterable: " + items.dump());
- auto attr_name = args.args[1].get<std::string>();
-
- bool has_test = false;
- Value test_fn;
- ArgumentsValue test_args {{Value()}, {}};
- if (args.args.size() >= 3) {
- has_test = true;
- test_fn = context->get(args.args[2]);
- if (test_fn.is_null()) throw std::runtime_error("Undefined test: " + args.args[2].dump());
- for (size_t i = 3, n = args.args.size(); i < n; i++) {
- test_args.args.emplace_back(args.args[i]);
- }
- test_args.kwargs = args.kwargs;
- }
-
- auto res = Value::array();
- for (size_t i = 0, n = items.size(); i < n; i++) {
- auto & item = items.at(i);
- auto attr = item.get(attr_name);
- if (has_test) {
- test_args.args[0] = attr;
- if (test_fn.call(context, test_args).to_bool() == (is_select ? true : false)) {
- res.push_back(item);
- }
- } else {
- res.push_back(attr);
- }
- }
- return res;
- });
- };
- globals.set("selectattr", select_or_reject_attr(/* is_select= */ true));
- globals.set("rejectattr", select_or_reject_attr(/* is_select= */ false));
- globals.set("range", Value::callable([=](const std::shared_ptr<Context> &, ArgumentsValue & args) {
- std::vector<int64_t> startEndStep(3);
- std::vector<bool> param_set(3);
- if (args.args.size() == 1) {
- startEndStep[1] = args.args[0].get<int64_t>();
- param_set[1] = true;
- } else {
- for (size_t i = 0; i < args.args.size(); i++) {
- auto & arg = args.args[i];
- auto v = arg.get<int64_t>();
- startEndStep[i] = v;
- param_set[i] = true;
- }
- }
- for (auto & [name, value] : args.kwargs) {
- size_t i;
- if (name == "start") i = 0;
- else if (name == "end") i = 1;
- else if (name == "step") i = 2;
- else throw std::runtime_error("Unknown argument " + name + " for function range");
-
- if (param_set[i]) {
- throw std::runtime_error("Duplicate argument " + name + " for function range");
- }
- startEndStep[i] = value.get<int64_t>();
- param_set[i] = true;
- }
- if (!param_set[1]) {
- throw std::runtime_error("Missing required argument 'end' for function range");
- }
- int64_t start = param_set[0] ? startEndStep[0] : 0;
- int64_t end = startEndStep[1];
- int64_t step = param_set[2] ? startEndStep[2] : 1;
-
- auto res = Value::array();
- if (step > 0) {
- for (int64_t i = start; i < end; i += step) {
- res.push_back(Value(i));
- }
- } else {
- for (int64_t i = start; i > end; i += step) {
- res.push_back(Value(i));
- }
- }
- return res;
- }));
-
- return std::make_shared<Context>(std::move(globals));
-}
-
-inline std::shared_ptr<Context> Context::make(Value && values, const std::shared_ptr<Context> & parent) {
- return std::make_shared<Context>(values.is_null() ? Value::object() : std::move(values), parent);
-}
-
-} // namespace minja
--- /dev/null
+/*
+ Copyright 2024 Google LLC
+
+ Use of this source code is governed by an MIT-style
+ license that can be found in the LICENSE file or at
+ https://opensource.org/licenses/MIT.
+*/
+// SPDX-License-Identifier: MIT
+#pragma once
+
+#include "minja.hpp"
+
+#include <chrono>
+#include <cstdio>
+#include <ctime>
+#include <exception>
+#include <iomanip>
+#include <memory>
+#include <sstream>
+#include <stdexcept>
+#include <string>
+#include <vector>
+
+#include <json.hpp>
+
+using json = nlohmann::ordered_json;
+
+namespace minja {
+
+// Capabilities of a chat template, as detected by the chat_template constructor,
+// which renders dummy conversations and looks for needles in the output.
+struct chat_template_caps {
+ // A tool passed through the `tools` input shows up (by name) in the rendered prompt.
+ bool supports_tools = false;
+ // The arguments of an assistant message's `tool_calls` show up in the rendered prompt.
+ bool supports_tool_calls = false;
+ // The content of a "tool" role message shows up in the rendered prompt.
+ bool supports_tool_responses = false;
+ // The content of a "system" role message shows up in the rendered prompt.
+ bool supports_system_role = false;
+ // Two tool calls in the same assistant message are both rendered.
+ bool supports_parallel_tool_calls = false;
+ // The `tool_call_id` of a "tool" role message shows up in the rendered prompt.
+ bool supports_tool_call_id = false;
+ // meta-llama/Llama-3.1-8B-Instruct expects arguments to be an object.
+ // Most other templates (and OpenAI's API) expect the arguments object to be stringified.
+ bool requires_object_arguments = false;
+ // CohereForAI/c4ai-command-r-plus simple variant
+ // Set when the user message is rendered next to an assistant message with "" content, but not with null content.
+ bool requires_non_null_content = false;
+ // MiniMaxAI/MiniMax-Text-01 special
+ // Set when plain string content is not rendered but [{"type": "text", "text": ...}] content is.
+ bool requires_typed_content = false;
+};
+
+// Everything a single chat_template::apply call renders from.
+struct chat_template_inputs {
+ // Conversation messages; exposed to the template as `messages` (after polyfills, if any).
+ nlohmann::ordered_json messages;
+ // Tool definitions; exposed to the template as `tools` unless null.
+ nlohmann::ordered_json tools;
+ // Exposed to the template as `add_generation_prompt`.
+ bool add_generation_prompt = true;
+ // Extra key/value pairs, each set as a variable in the template context unless the whole value is null.
+ nlohmann::ordered_json extra_context;
+ // Time point used by the `strftime_now` template function; defaults to the time of construction.
+ std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
+};
+
+// Switches controlling how chat_template::apply renders its inputs.
+struct chat_template_options {
+ // Master switch: when false, messages are passed to the template untouched, whatever the polyfill_* flags say.
+ bool apply_polyfills = true;
+ // When false, `bos_token` / `eos_token` are exposed to the template as empty strings.
+ bool use_bos_token = true;
+ bool use_eos_token = true;
+ // When true, a `strftime_now(format)` function is exposed to the template.
+ bool define_strftime_now = true;
+
+ // Individual polyfills; each one only kicks in when the inputs need it and the
+ // template's detected capabilities (chat_template_caps) call for it.
+ bool polyfill_tools = true;
+ // Appends the inferred tool call example (if any) to the tools system prompt.
+ bool polyfill_tool_call_examples = true;
+ bool polyfill_tool_calls = true;
+ bool polyfill_tool_responses = true;
+ bool polyfill_system_role = true;
+ bool polyfill_object_arguments = true;
+ bool polyfill_typed_content = true;
+};
+
+class chat_template {
+
+ private:
+ chat_template_caps caps_;
+ std::string source_;
+ std::string bos_token_;
+ std::string eos_token_;
+ std::shared_ptr<minja::TemplateNode> template_root_;
+ std::string tool_call_example_;
+
+ // Renders the given conversation with polyfills disabled and a fixed date (epoch),
+ // returning an empty string instead of propagating any std::exception.
+ // Used by the constructor to probe what the template supports natively.
+ std::string try_raw_render(
+ const nlohmann::ordered_json & messages,
+ const nlohmann::ordered_json & tools,
+ bool add_generation_prompt,
+ const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
+ {
+ try {
+ chat_template_options raw_opts;
+ raw_opts.apply_polyfills = false;
+
+ chat_template_inputs probe;
+ probe.messages = messages;
+ probe.tools = tools;
+ probe.add_generation_prompt = add_generation_prompt;
+ probe.extra_context = extra_context;
+ // Use fixed date for tests
+ probe.now = std::chrono::system_clock::from_time_t(0);
+
+ return apply(probe, raw_opts);
+ } catch (const std::exception &) {
+ // Rendering failures are expected while probing; they simply read as "nothing rendered".
+ return "";
+ }
+ }
+
+ public:
+
+ // Parses `source` as a template, then probes it with dummy conversations (rendered without
+ // polyfills) to fill in caps_. When the template has no native tools support, also tries to
+ // infer tool_call_example_ by diffing a render with and without an assistant tool call.
+ // Parsing errors propagate; errors while inferring the example are only logged to stderr.
+ chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token)
+ : source_(source), bos_token_(bos_token), eos_token_(eos_token)
+ {
+ template_root_ = minja::Parser::parse(source_, {
+ /* .trim_blocks = */ true,
+ /* .lstrip_blocks = */ true,
+ /* .keep_trailing_newline = */ false,
+ });
+
+ auto contains = [](const std::string & haystack, const std::string & needle) {
+ return haystack.find(needle) != std::string::npos;
+ };
+
+ const std::string user_needle = "<User Needle>";
+ const std::string sys_needle = "<System Needle>";
+ const json dummy_str_user_msg = {{"role", "user"}, {"content", user_needle}};
+ const json dummy_typed_user_msg = {{"role", "user"}, {"content", json::array({{{"type", "text"}, {"text", user_needle}}})}};
+
+ // Typed content is required when the string form is dropped but the typed form is rendered.
+ caps_.requires_typed_content =
+ !contains(try_raw_render(json::array({dummy_str_user_msg}), {}, false), user_needle)
+ && contains(try_raw_render(json::array({dummy_typed_user_msg}), {}, false), user_needle);
+
+ const auto dummy_user_msg = caps_.requires_typed_content
+ ? dummy_typed_user_msg
+ : dummy_str_user_msg;
+ const json needle_system_msg = {
+ {"role", "system"},
+ {"content", caps_.requires_typed_content ? json::array({{{"type", "text"}, {"text", sys_needle}}}) : json(sys_needle)},
+ };
+
+ caps_.supports_system_role = contains(try_raw_render({needle_system_msg, dummy_user_msg,}, {}, false), sys_needle);
+
+ auto out = try_raw_render(json::array({
+ dummy_user_msg
+ }), json::array({
+ {
+ {"name", "some_tool"},
+ {"type", "function"},
+ {"function", {
+ {"name", "some_tool"},
+ {"description", "Some tool."},
+ {"parameters", {
+ {"type", "object"},
+ {"properties", {
+ {"arg", {
+ {"type", "string"},
+ {"description", "Some argument."},
+ }},
+ }},
+ {"required", json::array({ "arg" })},
+ }},
+ }},
+ },
+ }), false);
+ caps_.supports_tools = contains(out, "some_tool");
+
+ auto make_tool_calls_msg = [&](const json & tool_calls) {
+ return json {
+ {"role", "assistant"},
+ {"content", nullptr},
+ {"tool_calls", tool_calls},
+ };
+ };
+ auto make_tool_call = [](const std::string & tool_name, const json & arguments) {
+ return json {
+ {"id", "call_1___"},
+ {"type", "function"},
+ {"function", {
+ {"arguments", arguments},
+ {"name", tool_name},
+ }},
+ };
+ };
+ const json dummy_args_obj {{"argument_needle", "print('Hello, World!')"}};
+
+ // Note: the arguments are rendered in both cases, but may be double-escaped, which we don't want.
+ out = try_raw_render(json::array({
+ dummy_user_msg,
+ make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj.dump())})),
+ }), {}, false);
+ auto tool_call_renders_str_arguments = contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
+ out = try_raw_render(json::array({
+ dummy_user_msg,
+ make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj)})),
+ }), {}, false);
+ auto tool_call_renders_obj_arguments = contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
+
+ caps_.supports_tool_calls = tool_call_renders_str_arguments || tool_call_renders_obj_arguments;
+ caps_.requires_object_arguments = !tool_call_renders_str_arguments && tool_call_renders_obj_arguments;
+ auto out_empty = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", ""}}}), {}, false);
+ auto out_null = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", nullptr}}}), {}, false);
+ caps_.requires_non_null_content = contains(out_empty, user_needle) && !contains(out_null, user_needle);
+
+ if (caps_.supports_tool_calls) {
+ auto dummy_args = caps_.requires_object_arguments ? dummy_args_obj : json(dummy_args_obj.dump());
+ auto tc1 = make_tool_call("test_tool1", dummy_args);
+ auto tc2 = make_tool_call("test_tool2", dummy_args);
+ // Reuses the outer `out` rather than shadowing it.
+ out = try_raw_render(json::array({
+ dummy_user_msg,
+ make_tool_calls_msg(json::array({tc1, tc2})),
+ }), {}, false);
+ caps_.supports_parallel_tool_calls = contains(out, "test_tool1") && contains(out, "test_tool2");
+
+ out = try_raw_render(json::array({
+ dummy_user_msg,
+ make_tool_calls_msg(json::array({tc1})),
+ {
+ {"role", "tool"},
+ {"name", "test_tool1"},
+ {"content", "Some response!"},
+ {"tool_call_id", "call_911_"},
+ }
+ }), {}, false);
+ caps_.supports_tool_responses = contains(out, "Some response!");
+ caps_.supports_tool_call_id = contains(out, "call_911_");
+ }
+
+ try {
+ if (!caps_.supports_tools) {
+ const json user_msg {
+ {"role", "user"},
+ {"content", "Hey"},
+ };
+ const json args {
+ {"arg1", "some_value"},
+ };
+ const json tool_call_msg {
+ {"role", "assistant"},
+ {"content", nullptr},
+ {"tool_calls", json::array({
+ {
+ // TODO: detect if requires numerical id or fixed length == 6 like Nemo
+ {"id", "call_1___"},
+ {"type", "function"},
+ {"function", {
+ {"name", "tool_name"},
+ {"arguments", (caps_.requires_object_arguments ? args : json(minja::Value(args).dump(-1, /* to_json= */ true)))},
+ }},
+ },
+ })},
+ };
+ // `prefix` is the prompt up to where the assistant would start answering;
+ // `full` additionally contains the assistant's tool call (rendered with polyfills).
+ std::string prefix, full;
+ {
+ chat_template_inputs inputs;
+ inputs.messages = json::array({user_msg});
+ inputs.add_generation_prompt = true;
+ prefix = apply(inputs);
+ }
+ {
+ chat_template_inputs inputs;
+ inputs.messages = json::array({user_msg, tool_call_msg});
+ inputs.add_generation_prompt = false;
+ full = apply(inputs);
+ }
+ // Drop a trailing EOS token (optionally followed by a single newline) from `full`, so it
+ // doesn't leak into the example. The position is checked against the end of `full` itself,
+ // and the not-found / empty cases are guarded before indexing into the string.
+ const auto eos_pos_last = full.rfind(eos_token_);
+ if (!eos_token_.empty() && eos_pos_last != std::string::npos) {
+ const auto eos_end = eos_pos_last + eos_token_.size();
+ if (eos_end == full.size() || (eos_end + 1 == full.size() && full.back() == '\n')) {
+ full = full.substr(0, eos_pos_last);
+ }
+ }
+ size_t common_prefix_length = 0;
+ for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) {
+ if (prefix[i] != full[i]) {
+ break;
+ }
+ if (prefix[i] == '<') {
+ // DeepSeek R1's template (as of 20250209) adds a trailing <think> if add_generation_prompt,
+ // but it removes thinking tags for past messages.
+ // The prefix and full strings diverge at <think> vs. <|tool▁calls▁begin|>, we avoid consuming the leading <.
+ continue;
+ }
+ common_prefix_length = i + 1;
+ }
+ // Whatever follows the common prefix is how this template spells a tool call.
+ auto example = full.substr(common_prefix_length);
+ if (example.find("tool_name") == std::string::npos && example.find("some_value") == std::string::npos) {
+ fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n");
+ } else {
+ tool_call_example_ = example;
+ }
+ }
+ } catch (const std::exception & e) {
+ fprintf(stderr, "Failed to generate tool call example: %s\n", e.what());
+ }
+ }
+
+ // Read-only accessors for the values passed to the constructor.
+ const std::string & source() const { return source_; }
+ const std::string & bos_token() const { return bos_token_; }
+ const std::string & eos_token() const { return eos_token_; }
+ // Capabilities detected on the template itself at construction time, i.e. before any polyfill.
+ const chat_template_caps & original_caps() const { return caps_; }
+
+ // Deprecated, please use the form with chat_template_inputs and chat_template_options
+ // Packs the arguments into chat_template_inputs / chat_template_options (using the current
+ // time for `now`), logs a deprecation notice to stderr and forwards to the other overload.
+ // Const-qualified like the overload it forwards to, so it can be called on a const chat_template.
+ std::string apply(
+ const nlohmann::ordered_json & messages,
+ const nlohmann::ordered_json & tools,
+ bool add_generation_prompt,
+ const nlohmann::ordered_json & extra_context = nlohmann::ordered_json(),
+ bool apply_polyfills = true) const
+ {
+ fprintf(stderr, "[%s] Deprecated!\n", __func__);
+ chat_template_inputs inputs;
+ inputs.messages = messages;
+ inputs.tools = tools;
+ inputs.add_generation_prompt = add_generation_prompt;
+ inputs.extra_context = extra_context;
+ inputs.now = std::chrono::system_clock::now();
+
+ chat_template_options opts;
+ opts.apply_polyfills = apply_polyfills;
+
+ return apply(inputs, opts);
+ }
+
+ // Renders the template for `inputs`.
+ // Unless disabled through `opts`, messages are first rewritten ("polyfilled") to work around
+ // features the template was detected not to support (see chat_template_caps).
+ // Throws std::runtime_error if polyfills are needed and a message lacks 'role' or 'content';
+ // errors raised while rendering the template are not caught here.
+ std::string apply(
+ const chat_template_inputs & inputs,
+ const chat_template_options & opts = chat_template_options()) const
+ {
+ json actual_messages;
+
+ // Figure out which features the inputs actually make use of.
+ auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
+ auto has_tool_calls = false;
+ auto has_tool_responses = false;
+ auto has_string_content = false;
+ for (const auto & message : inputs.messages) {
+ if (message.contains("tool_calls") && !message["tool_calls"].is_null()) {
+ has_tool_calls = true;
+ }
+ if (message.contains("role") && message["role"] == "tool") {
+ has_tool_responses = true;
+ }
+ if (message.contains("content") && message["content"].is_string()) {
+ has_string_content = true;
+ }
+ }
+
+ // A polyfill is active when it is enabled, the inputs need it and the template lacks the feature.
+ // (The system role polyfill doesn't check whether there are any system messages.)
+ auto polyfill_system_role = opts.polyfill_system_role && !caps_.supports_system_role;
+ auto polyfill_tools = opts.polyfill_tools && has_tools && !caps_.supports_tools;
+ auto polyfill_tool_call_example = polyfill_tools && opts.polyfill_tool_call_examples;
+ auto polyfill_tool_calls = opts.polyfill_tool_calls && has_tool_calls && !caps_.supports_tool_calls;
+ auto polyfill_tool_responses = opts.polyfill_tool_responses && has_tool_responses && !caps_.supports_tool_responses;
+ auto polyfill_object_arguments = opts.polyfill_object_arguments && has_tool_calls && caps_.requires_object_arguments;
+ auto polyfill_typed_content = opts.polyfill_typed_content && has_string_content && caps_.requires_typed_content;
+
+ auto needs_polyfills = opts.apply_polyfills && (false
+ || polyfill_system_role
+ || polyfill_tools
+ || polyfill_tool_calls
+ || polyfill_tool_responses
+ || polyfill_object_arguments
+ || polyfill_typed_content
+ );
+
+ if (needs_polyfills) {
+ actual_messages = json::array();
+
+ // Appends a message to the output, wrapping string content as [{"type": "text", "text": ...}]
+ // when typed content is required. Note the wrapped message only keeps 'role' and 'content'.
+ auto add_message = [&](const json & msg) {
+ if (polyfill_typed_content && msg.contains("content") && !msg.at("content").is_null() && msg.at("content").is_string()) {
+ actual_messages.push_back({
+ {"role", msg.at("role")},
+ {"content", {{
+ {"type", "text"},
+ {"text", msg.at("content")},
+ }}},
+ });
+ } else {
+ actual_messages.push_back(msg);
+ }
+ };
+
+ // System text waiting to be merged into the next user message (system role polyfill).
+ std::string pending_system;
+ // Emits any pending system text as a standalone user message.
+ auto flush_sys = [&]() {
+ if (!pending_system.empty()) {
+ add_message({
+ {"role", "user"},
+ {"content", pending_system},
+ });
+ pending_system.clear();
+ }
+ };
+
+ // Tools polyfill: describe the tools (and the inferred call syntax, if any) in a system prompt.
+ json adjusted_messages;
+ if (polyfill_tools) {
+ adjusted_messages = add_system(inputs.messages,
+ "You can call any of the following tools to satisfy the user's requests: " + minja::Value(inputs.tools).dump(2, /* to_json= */ true) +
+ (!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_ + "\n\n"));
+ } else {
+ adjusted_messages = inputs.messages;
+ }
+
+ for (const auto & message_ : adjusted_messages) {
+ // Work on a copy: the polyfills below rewrite the message in place.
+ auto message = message_;
+ if (!message.contains("role") || !message.contains("content")) {
+ throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
+ }
+ std::string role = message.at("role");
+
+ if (message.contains("tool_calls")) {
+ // Turn stringified arguments into objects; strings that fail to parse are left as-is.
+ if (polyfill_object_arguments || polyfill_tool_calls) {
+ for (auto & tool_call : message.at("tool_calls")) {
+ if (tool_call["type"] == "function") {
+ auto & function = tool_call.at("function");
+ auto & arguments = function.at("arguments");
+ if (arguments.is_string()) {
+ try {
+ arguments = json::parse(arguments.get<std::string>());
+ } catch (const std::exception & ecvt) {
+ fprintf(stderr, "Failed to parse arguments: %s\n", ecvt.what());
+ }
+ }
+ }
+ }
+ }
+ // Tool calls polyfill: serialize the function calls (and any non-empty content) as
+ // JSON text into the message content, then drop the 'tool_calls' field.
+ if (polyfill_tool_calls) {
+ auto content = message.at("content");
+ auto tool_calls = json::array();
+ for (const auto & tool_call : message.at("tool_calls")) {
+ if (tool_call.at("type") != "function") {
+ continue;
+ }
+ const auto & function = tool_call.at("function");
+ auto tc = json {
+ {"name", function.at("name")},
+ {"arguments", function.at("arguments")},
+ };
+ if (tool_call.contains("id")) {
+ tc["id"] = tool_call["id"];
+ }
+ tool_calls.push_back(tc);
+ }
+ auto obj = json {
+ {"tool_calls", tool_calls},
+ };
+ if (!content.is_null() && content != "") {
+ obj["content"] = content;
+ }
+ message["content"] = obj.dump(2);
+ message.erase("tool_calls");
+ }
+ }
+ // Tool responses polyfill: re-send the tool result as a user message holding a JSON "tool_response".
+ if (polyfill_tool_responses && role == "tool") {
+ message["role"] = "user";
+ auto obj = json {
+ {"tool_response", {
+ {"content", message.at("content")},
+ }},
+ };
+ if (message.contains("name")) {
+ obj["tool_response"]["name"] = message.at("name");
+ }
+ if (message.contains("tool_call_id")) {
+ obj["tool_response"]["tool_call_id"] = message.at("tool_call_id");
+ }
+ message["content"] = obj.dump(2);
+ message.erase("name");
+ }
+
+ // System role polyfill: system messages are accumulated, then prepended to the next user
+ // message, or flushed as a user message of their own before any other kind of message.
+ // `role` was read before the polyfills above, so a rewritten tool response still counts as "tool" here.
+ if (!message["content"].is_null() && polyfill_system_role) {
+ std::string content = message.at("content");
+ if (role == "system") {
+ if (!pending_system.empty()) pending_system += "\n";
+ pending_system += content;
+ continue;
+ } else {
+ if (role == "user") {
+ if (!pending_system.empty()) {
+ message["content"] = pending_system + (content.empty() ? "" : "\n" + content);
+ pending_system.clear();
+ }
+ } else {
+ flush_sys();
+ }
+ }
+ }
+ add_message(message);
+ }
+ // Trailing system messages with nothing after them still need to be emitted.
+ flush_sys();
+ } else {
+ actual_messages = inputs.messages;
+ }
+
+ // Build the rendering context: messages, generation prompt flag, special tokens,
+ // strftime_now, tools and any extra variables.
+ auto context = minja::Context::make(json({
+ {"messages", actual_messages},
+ {"add_generation_prompt", inputs.add_generation_prompt},
+ }));
+ context->set("bos_token", opts.use_bos_token ? bos_token_ : "");
+ context->set("eos_token", opts.use_eos_token ? eos_token_ : "");
+ if (opts.define_strftime_now) {
+ auto now = inputs.now;
+ context->set("strftime_now", Value::callable([now](const std::shared_ptr<minja::Context> &, minja::ArgumentsValue & args) {
+ args.expectArgs("strftime_now", {1, 1}, {0, 0});
+ auto format = args.args[0].get<std::string>();
+
+ auto time = std::chrono::system_clock::to_time_t(now);
+ // NOTE(review): std::localtime may not be thread-safe; confirm templates are never rendered concurrently.
+ auto local_time = *std::localtime(&time);
+ std::ostringstream ss;
+ ss << std::put_time(&local_time, format.c_str());
+ return ss.str();
+ }));
+ }
+ if (!inputs.tools.is_null()) {
+ context->set("tools", minja::Value(inputs.tools));
+ }
+ if (!inputs.extra_context.is_null()) {
+ for (auto & kv : inputs.extra_context.items()) {
+ context->set(kv.key(), minja::Value(kv.value()));
+ }
+ }
+
+ auto ret = template_root_->render(context);
+ // fprintf(stderr, "actual_messages: %s\n", actual_messages.dump(2).c_str());
+ // fprintf(stderr, "apply: %s\n\n", ret.c_str());
+ return ret;
+ }
+
+ // Returns a copy of `messages` that carries `system_prompt`:
+ // - if the first message has the "system" role, it is replaced by a system message whose content
+ //   is the existing content followed by a blank line and `system_prompt`;
+ // - otherwise a new system message is inserted at the front.
+ static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages, const std::string & system_prompt) {
+ json result = messages;
+
+ const bool starts_with_system = !result.empty() && result[0].at("role") == "system";
+ if (!starts_with_system) {
+ result.insert(result.begin(), json {
+ {"role", "system"},
+ {"content", system_prompt},
+ });
+ return result;
+ }
+
+ const std::string previous_system = result.at(0).at("content");
+ result[0] = json {
+ {"role", "system"},
+ {"content", previous_system + "\n\n" + system_prompt},
+ };
+ return result;
+ }
+};
+
+} // namespace minja
--- /dev/null
+/*
+ Copyright 2024 Google LLC
+
+ Use of this source code is governed by an MIT-style
+ license that can be found in the LICENSE file or at
+ https://opensource.org/licenses/MIT.
+*/
+// SPDX-License-Identifier: MIT
+#pragma once
+
+#include <iostream>
+#include <string>
+#include <vector>
+#include <regex>
+#include <memory>
+#include <stdexcept>
+#include <sstream>
+#include <unordered_set>
+#include <json.hpp>
+
+using json = nlohmann::ordered_json;
+
+namespace minja {
+
+class Context;
+
+// Whitespace-handling options passed to the template parser.
+struct Options {
+ bool trim_blocks; // removes the first newline after a block
+ bool lstrip_blocks; // removes leading whitespace on the line of the block
+ bool keep_trailing_newline; // don't remove last newline
+};
+
+struct ArgumentsValue;
+
+// Returns `s` with every "\r\n" replaced by "\n" when built for Windows;
+// on other platforms the string is returned unchanged.
+inline std::string normalize_newlines(const std::string & s) {
+#ifndef _WIN32
+ return s;
+#else
+ static const std::regex crlf_pattern("\r\n");
+ return std::regex_replace(s, crlf_pattern, "\n");
+#endif
+}
+
+ /*
+  * Values that behave roughly like in Python.
+  *
+  * A Value is exactly one of: null, a JSON primitive (bool / number / string),
+  * an array, an object keyed by primitives, or a callable. Arrays, objects and
+  * callables are held through shared_ptr, so copying a Value shares (does not
+  * clone) the underlying container.
+  */
+ class Value : public std::enable_shared_from_this<Value> {
+ public:
+   using CallableType = std::function<Value(const std::shared_ptr<Context> &, ArgumentsValue &)>;
+   using FilterType = std::function<Value(const std::shared_ptr<Context> &, ArgumentsValue &)>;
+
+ private:
+   using ObjectType = nlohmann::ordered_map<json, Value>;  // Only contains primitive keys
+   using ArrayType = std::vector<Value>;
+
+   std::shared_ptr<ArrayType> array_;
+   std::shared_ptr<ObjectType> object_;
+   std::shared_ptr<CallableType> callable_;
+   json primitive_;
+
+   Value(const std::shared_ptr<ArrayType> & array) : array_(array) {}
+   Value(const std::shared_ptr<ObjectType> & object) : object_(object) {}
+   // A callable also gets an (initially empty) object so attributes can be set on it.
+   Value(const std::shared_ptr<CallableType> & callable) : object_(std::make_shared<ObjectType>()), callable_(callable) {}
+
+   /* Python-style string repr.
+    * Writes `primitive` (which must be a string) to `out`. The JSON dump is used
+    * verbatim when double quotes are requested or when the text contains a single
+    * quote; otherwise the surrounding quotes are swapped for `string_quote`. */
+   static void dump_string(const json & primitive, std::ostringstream & out, char string_quote = '\'') {
+     if (!primitive.is_string()) throw std::runtime_error("Value is not a string: " + primitive.dump());
+     auto s = primitive.dump();
+     if (string_quote == '"' || s.find('\'') != std::string::npos) {
+       out << s;
+       return;
+     }
+     // Reuse json dump, just changing string quotes
+     out << string_quote;
+     for (size_t i = 1, n = s.size() - 1; i < n; ++i) {
+       if (s[i] == '\\' && i + 1 < n) {
+         // Consume escape sequences as a pair so that the second half of an escaped
+         // backslash is never mistaken for the start of `\"` (which used to swallow
+         // the closing quote of strings ending in a backslash).
+         if (s[i + 1] == '"') {
+           out << '"';
+         } else {
+           out << s[i] << s[i + 1];
+         }
+         i++;
+       } else if (s[i] == string_quote) {
+         out << '\\' << string_quote;
+       } else {
+         out << s[i];
+       }
+     }
+     out << string_quote;
+   }
+   /* Recursive worker behind the public dump(): `indent` > 0 pretty-prints with that
+    * many spaces per nesting `level`; `to_json` switches quotes and booleans from the
+    * Python-like form to the JSON form. */
+   void dump(std::ostringstream & out, int indent = -1, int level = 0, bool to_json = false) const {
+     auto print_indent = [&](int level) {
+       if (indent > 0) {
+         out << "\n";
+         for (int i = 0, n = level * indent; i < n; ++i) out << ' ';
+       }
+     };
+     auto print_sub_sep = [&]() {
+       out << ',';
+       if (indent < 0) out << ' ';
+       else print_indent(level + 1);
+     };
+
+     auto string_quote = to_json ? '"' : '\'';
+
+     if (is_null()) out << "null";
+     else if (array_) {
+       out << "[";
+       print_indent(level + 1);
+       for (size_t i = 0; i < array_->size(); ++i) {
+         if (i) print_sub_sep();
+         (*array_)[i].dump(out, indent, level + 1, to_json);
+       }
+       print_indent(level);
+       out << "]";
+     } else if (object_) {
+       out << "{";
+       print_indent(level + 1);
+       for (auto begin = object_->begin(), it = begin; it != object_->end(); ++it) {
+         if (it != begin) print_sub_sep();
+         if (it->first.is_string()) {
+           dump_string(it->first, out, string_quote);
+         } else {
+           out << string_quote << it->first.dump() << string_quote;
+         }
+         out << ": ";
+         it->second.dump(out, indent, level + 1, to_json);
+       }
+       print_indent(level);
+       out << "}";
+     } else if (callable_) {
+       throw std::runtime_error("Cannot dump callable to JSON");
+     } else if (is_boolean() && !to_json) {
+       out << (this->to_bool() ? "True" : "False");
+     } else if (is_string() && !to_json) {
+       dump_string(primitive_, out, string_quote);
+     } else {
+       out << primitive_.dump();
+     }
+   }
+
+ public:
+   Value() {}
+   Value(const bool& v) : primitive_(v) {}
+   Value(const int64_t & v) : primitive_(v) {}
+   Value(const double& v) : primitive_(v) {}
+   Value(const std::nullptr_t &) {}
+   Value(const std::string & v) : primitive_(v) {}
+   Value(const char * v) : primitive_(std::string(v)) {}
+
+   /* Deep-converts a json document: objects and arrays become Value containers,
+    * everything else is stored as a primitive. */
+   Value(const json & v) {
+     if (v.is_object()) {
+       auto object = std::make_shared<ObjectType>();
+       for (auto it = v.begin(); it != v.end(); ++it) {
+         (*object)[it.key()] = it.value();
+       }
+       object_ = std::move(object);
+     } else if (v.is_array()) {
+       auto array = std::make_shared<ArrayType>();
+       for (const auto& item : v) {
+         array->push_back(Value(item));
+       }
+       array_ = array;
+     } else {
+       primitive_ = v;
+     }
+   }
+
+   /* Keys of an object, in insertion order. Throws if this is not an object. */
+   std::vector<Value> keys() {
+     if (!object_) throw std::runtime_error("Value is not an object: " + dump());
+     std::vector<Value> res;
+     for (const auto& item : *object_) {
+       res.push_back(item.first);
+     }
+     return res;
+   }
+
+   /* Element count of an object/array, or length of a string; throws otherwise. */
+   size_t size() const {
+     if (is_object()) return object_->size();
+     if (is_array()) return array_->size();
+     if (is_string()) return primitive_.get<std::string>().length();
+     throw std::runtime_error("Value is not an array or object: " + dump());
+   }
+
+   static Value array(const std::vector<Value> values = {}) {
+     auto array = std::make_shared<ArrayType>();
+     for (const auto& item : values) {
+       array->push_back(item);
+     }
+     return Value(array);
+   }
+   static Value object(const std::shared_ptr<ObjectType> object = std::make_shared<ObjectType>()) {
+     return Value(object);
+   }
+   static Value callable(const CallableType & callable) {
+     return Value(std::make_shared<CallableType>(callable));
+   }
+
+   void insert(size_t index, const Value& v) {
+     if (!array_)
+       throw std::runtime_error("Value is not an array: " + dump());
+     array_->insert(array_->begin() + index, v);
+   }
+   void push_back(const Value& v) {
+     if (!array_)
+       throw std::runtime_error("Value is not an array: " + dump());
+     array_->push_back(v);
+   }
+   /* Removes and returns an element.
+    * Arrays: a null `index` pops the last element; an integer index may be negative,
+    * counting from the end as in Python. Objects: `index` is the key to remove.
+    * Throws std::runtime_error on an empty array, a bad index/key or a non-container. */
+   Value pop(const Value& index) {
+     if (is_array()) {
+       if (array_->empty())
+         throw std::runtime_error("pop from empty list");
+       if (index.is_null()) {
+         auto ret = array_->back();
+         array_->pop_back();
+         return ret;
+       } else if (!index.is_number_integer()) {
+         throw std::runtime_error("pop index must be an integer: " + index.dump());
+       } else {
+         auto i = index.get<int64_t>();
+         const auto n = static_cast<int64_t>(array_->size());
+         // Accept [-n, n): the previous check rejected every negative index, which
+         // made the negative-index normalisation below unreachable.
+         if (i < -n || i >= n)
+           throw std::runtime_error("pop index out of range: " + index.dump());
+         if (i < 0) i += n;
+         auto it = array_->begin() + i;
+         auto ret = *it;
+         array_->erase(it);
+         return ret;
+       }
+     } else if (is_object()) {
+       if (!index.is_hashable())
+         throw std::runtime_error("Unashable type: " + index.dump());
+       auto it = object_->find(index.primitive_);
+       if (it == object_->end())
+         throw std::runtime_error("Key not found: " + index.dump());
+       auto ret = it->second;
+       object_->erase(it);
+       return ret;
+     } else {
+       throw std::runtime_error("Value is not an array or object: " + dump());
+     }
+   }
+   /* Lenient lookup: returns a null Value when the key is missing, when an array is
+    * indexed with a non-integer, or when this is not a container. Negative array
+    * indices count from the end; an out-of-range index makes vector::at throw. */
+   Value get(const Value& key) {
+     if (array_) {
+       if (!key.is_number_integer()) {
+         return Value();
+       }
+       auto index = key.get<int>();
+       return array_->at(index < 0 ? array_->size() + index : index);
+     } else if (object_) {
+       // Report the offending key rather than the whole container.
+       if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + key.dump());
+       auto it = object_->find(key.primitive_);
+       if (it == object_->end()) return Value();
+       return it->second;
+     }
+     return Value();
+   }
+   void set(const Value& key, const Value& value) {
+     if (!object_) throw std::runtime_error("Value is not an object: " + dump());
+     // Report the offending key rather than the whole container.
+     if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + key.dump());
+     (*object_)[key.primitive_] = value;
+   }
+   Value call(const std::shared_ptr<Context> & context, ArgumentsValue & args) const {
+     if (!callable_) throw std::runtime_error("Value is not callable: " + dump());
+     return (*callable_)(context, args);
+   }
+
+   bool is_object() const { return !!object_; }
+   bool is_array() const { return !!array_; }
+   bool is_callable() const { return !!callable_; }
+   bool is_null() const { return !object_ && !array_ && primitive_.is_null() && !callable_; }
+   bool is_boolean() const { return primitive_.is_boolean(); }
+   bool is_number_integer() const { return primitive_.is_number_integer(); }
+   bool is_number_float() const { return primitive_.is_number_float(); }
+   bool is_number() const { return primitive_.is_number(); }
+   bool is_string() const { return primitive_.is_string(); }
+   bool is_iterable() const { return is_array() || is_object() || is_string(); }
+
+   bool is_primitive() const { return !array_ && !object_ && !callable_; }
+   bool is_hashable() const { return is_primitive(); }
+
+   /* True for an empty string, array or object; false for other non-null values.
+    * Throws on a null Value. */
+   bool empty() const {
+     if (is_null())
+       throw std::runtime_error("Undefined value or reference");
+     // json::empty() is never true for a string primitive, so test the text itself.
+     if (is_string()) return primitive_.get<std::string>().empty();
+     if (is_array()) return array_->empty();
+     if (is_object()) return object_->empty();
+     return false;
+   }
+
+   /* Invokes `callback` on each array element, each object key, or each character
+    * of a string (as a one-character string Value). Throws on null or non-iterables. */
+   void for_each(const std::function<void(Value &)> & callback) const {
+     if (is_null())
+       throw std::runtime_error("Undefined value or reference");
+     if (array_) {
+       for (auto& item : *array_) {
+         callback(item);
+       }
+     } else if (object_) {
+       for (auto & item : *object_) {
+         Value key(item.first);
+         callback(key);
+       }
+     } else if (is_string()) {
+       for (char c : primitive_.get<std::string>()) {
+         auto val = Value(std::string(1, c));
+         callback(val);
+       }
+     } else {
+       throw std::runtime_error("Value is not iterable: " + dump());
+     }
+   }
+
+   /* Truthiness: null, false, 0, "" and [] are false; everything else is true. */
+   bool to_bool() const {
+     if (is_null()) return false;
+     if (is_boolean()) return get<bool>();
+     if (is_number()) return get<double>() != 0;
+     if (is_string()) return !get<std::string>().empty();
+     if (is_array()) return !empty();
+     return true;
+   }
+
+   /* Best-effort integer conversion; unparsable strings and containers yield 0. */
+   int64_t to_int() const {
+     if (is_null()) return 0;
+     if (is_boolean()) return get<bool>() ? 1 : 0;
+     // Keep integers exact instead of round-tripping them through double.
+     if (is_number_integer()) return get<int64_t>();
+     if (is_number()) return static_cast<int64_t>(get<double>());
+     if (is_string()) {
+       try {
+         // stoll: `long` is only 32 bits wide on some platforms.
+         return std::stoll(get<std::string>());
+       } catch (const std::exception &) {
+         return 0;
+       }
+     }
+     return 0;
+   }
+
+   bool operator<(const Value & other) const {
+     if (is_null())
+       throw std::runtime_error("Undefined value or reference");
+     if (is_number() && other.is_number()) return get<double>() < other.get<double>();
+     if (is_string() && other.is_string()) return get<std::string>() < other.get<std::string>();
+     throw std::runtime_error("Cannot compare values: " + dump() + " < " + other.dump());
+   }
+   bool operator>=(const Value & other) const { return !(*this < other); }
+
+   bool operator>(const Value & other) const {
+     if (is_null())
+       throw std::runtime_error("Undefined value or reference");
+     if (is_number() && other.is_number()) return get<double>() > other.get<double>();
+     if (is_string() && other.is_string()) return get<std::string>() > other.get<std::string>();
+     throw std::runtime_error("Cannot compare values: " + dump() + " > " + other.dump());
+   }
+   bool operator<=(const Value & other) const { return !(*this > other); }
+
+   /* Structural equality for containers, json equality for primitives; two callables
+    * are equal only when they share the same underlying function object.
+    * NOTE(review): containers holding any falsy element (0, "", null, ...) never
+    * compare equal because of the to_bool() guards below - confirm this is intended. */
+   bool operator==(const Value & other) const {
+     if (callable_ || other.callable_) {
+       if (callable_.get() != other.callable_.get()) return false;
+     }
+     if (array_) {
+       if (!other.array_) return false;
+       if (array_->size() != other.array_->size()) return false;
+       for (size_t i = 0; i < array_->size(); ++i) {
+         if (!(*array_)[i].to_bool() || !(*other.array_)[i].to_bool() || (*array_)[i] != (*other.array_)[i]) return false;
+       }
+       return true;
+     } else if (object_) {
+       if (!other.object_) return false;
+       if (object_->size() != other.object_->size()) return false;
+       for (const auto& item : *object_) {
+         if (!item.second.to_bool() || !other.object_->count(item.first) || item.second != other.object_->at(item.first)) return false;
+       }
+       return true;
+     } else {
+       return primitive_ == other.primitive_;
+     }
+   }
+   bool operator!=(const Value & other) const { return !(*this == other); }
+
+   bool contains(const char * key) const { return contains(std::string(key)); }
+   /* Key lookup by name: always false for arrays, throws for non-containers. */
+   bool contains(const std::string & key) const {
+     if (array_) {
+       return false;
+     } else if (object_) {
+       return object_->find(key) != object_->end();
+     } else {
+       throw std::runtime_error("contains can only be called on arrays and objects: " + dump());
+     }
+   }
+   /* Membership test: element search for arrays, key lookup for objects. */
+   bool contains(const Value & value) const {
+     if (is_null())
+       throw std::runtime_error("Undefined value or reference");
+     if (array_) {
+       for (const auto& item : *array_) {
+         if (item.to_bool() && item == value) return true;
+       }
+       return false;
+     } else if (object_) {
+       if (!value.is_hashable()) throw std::runtime_error("Unashable type: " + value.dump());
+       return object_->find(value.primitive_) != object_->end();
+     } else {
+       throw std::runtime_error("contains can only be called on arrays and objects: " + dump());
+     }
+   }
+   void erase(size_t index) {
+     if (!array_) throw std::runtime_error("Value is not an array: " + dump());
+     array_->erase(array_->begin() + index);
+   }
+   void erase(const std::string & key) {
+     if (!object_) throw std::runtime_error("Value is not an object: " + dump());
+     object_->erase(key);
+   }
+   const Value& at(const Value & index) const {
+     return const_cast<Value*>(this)->at(index);
+   }
+   /* Strict lookup returning a reference; throws when the index/key is absent. */
+   Value& at(const Value & index) {
+     // Report the offending index rather than the whole container.
+     if (!index.is_hashable()) throw std::runtime_error("Unashable type: " + index.dump());
+     if (is_array()) return array_->at(index.get<int>());
+     if (is_object()) return object_->at(index.primitive_);
+     throw std::runtime_error("Value is not an array or object: " + dump());
+   }
+   const Value& at(size_t index) const {
+     return const_cast<Value*>(this)->at(index);
+   }
+   Value& at(size_t index) {
+     if (is_null())
+       throw std::runtime_error("Undefined value or reference");
+     if (is_array()) return array_->at(index);
+     if (is_object()) return object_->at(index);
+     throw std::runtime_error("Value is not an array or object: " + dump());
+   }
+
+   /* Typed lookup of `key`, falling back to `default_value` when it is absent. */
+   template <typename T>
+   T get(const std::string & key, T default_value) const {
+     if (!contains(key)) return default_value;
+     return at(key).get<T>();
+   }
+
+   /* Converts a primitive to T; throws for arrays, objects and callables. */
+   template <typename T>
+   T get() const {
+     if (is_primitive()) return primitive_.get<T>();
+     throw std::runtime_error("get<T> not defined for this value type: " + dump());
+   }
+
+   /* Text rendering; see the private dump() overload for `indent` and `to_json`. */
+   std::string dump(int indent=-1, bool to_json=false) const {
+     std::ostringstream out;
+     dump(out, indent, 0, to_json);
+     return out.str();
+   }
+
+   Value operator-() const {
+     if (is_number_integer())
+       return -get<int64_t>();
+     else
+       return -get<double>();
+   }
+   /* Python-like str(): strings verbatim, True/False, None, otherwise dump(). */
+   std::string to_str() const {
+     if (is_string()) return get<std::string>();
+     if (is_number_integer()) return std::to_string(get<int64_t>());
+     if (is_number_float()) return std::to_string(get<double>());
+     if (is_boolean()) return get<bool>() ? "True" : "False";
+     if (is_null()) return "None";
+     return dump();
+   }
+   /* String concatenation if either side is a string, integer or floating-point
+    * addition for numbers, concatenation for two arrays. */
+   Value operator+(const Value& rhs) const {
+     if (is_string() || rhs.is_string()) {
+       return to_str() + rhs.to_str();
+     } else if (is_number_integer() && rhs.is_number_integer()) {
+       return get<int64_t>() + rhs.get<int64_t>();
+     } else if (is_array() && rhs.is_array()) {
+       auto res = Value::array();
+       for (const auto& item : *array_) res.push_back(item);
+       for (const auto& item : *rhs.array_) res.push_back(item);
+       return res;
+     } else {
+       return get<double>() + rhs.get<double>();
+     }
+   }
+   Value operator-(const Value& rhs) const {
+     if (is_number_integer() && rhs.is_number_integer())
+       return get<int64_t>() - rhs.get<int64_t>();
+     else
+       return get<double>() - rhs.get<double>();
+   }
+   /* String repetition for (string * integer), numeric product otherwise. */
+   Value operator*(const Value& rhs) const {
+     if (is_string() && rhs.is_number_integer()) {
+       std::ostringstream out;
+       for (int64_t i = 0, n = rhs.get<int64_t>(); i < n; ++i) {
+         out << to_str();
+       }
+       return out.str();
+     }
+     else if (is_number_integer() && rhs.is_number_integer())
+       return get<int64_t>() * rhs.get<int64_t>();
+     else
+       return get<double>() * rhs.get<double>();
+   }
+   /* Integer division when both operands are integers (throws on a zero divisor
+    * instead of invoking undefined behaviour), floating-point division otherwise. */
+   Value operator/(const Value& rhs) const {
+     if (is_number_integer() && rhs.is_number_integer()) {
+       const auto divisor = rhs.get<int64_t>();
+       if (divisor == 0) throw std::runtime_error("Division by zero");
+       return get<int64_t>() / divisor;
+     } else {
+       return get<double>() / rhs.get<double>();
+     }
+   }
+   /* Integer remainder; throws on a zero divisor instead of invoking undefined behaviour. */
+   Value operator%(const Value& rhs) const {
+     const auto divisor = rhs.get<int64_t>();
+     if (divisor == 0) throw std::runtime_error("Modulo by zero");
+     return get<int64_t>() % divisor;
+   }
+ };
+
+ /* Arguments of a call: positional values plus (name, value) keyword pairs.
+  * The query helpers are const-qualified so they can be used through const references. */
+ struct ArgumentsValue {
+   std::vector<Value> args;
+   std::vector<std::pair<std::string, Value>> kwargs;
+
+   /* True when a keyword argument called `name` was supplied. */
+   bool has_named(const std::string & name) const {
+     for (const auto & p : kwargs) {
+       if (p.first == name) return true;
+     }
+     return false;
+   }
+
+   /* Value of the first keyword argument called `name`, or a null Value if absent. */
+   Value get_named(const std::string & name) const {
+     for (const auto & [key, value] : kwargs) {
+       if (key == name) return value;
+     }
+     return Value();
+   }
+
+   /* True when there are neither positional nor keyword arguments. */
+   bool empty() const {
+     return args.empty() && kwargs.empty();
+   }
+
+   /* Throws std::runtime_error unless the positional count lies in `pos_count` and the
+    * keyword count lies in `kw_count` (both given as inclusive {min, max} pairs). */
+   void expectArgs(const std::string & method_name, const std::pair<size_t, size_t> & pos_count, const std::pair<size_t, size_t> & kw_count) const {
+     if (args.size() < pos_count.first || args.size() > pos_count.second || kwargs.size() < kw_count.first || kwargs.size() > kw_count.second) {
+       std::ostringstream out;
+       out << method_name << " must have between " << pos_count.first << " and " << pos_count.second << " positional arguments and between " << kw_count.first << " and " << kw_count.second << " keyword arguments";
+       throw std::runtime_error(out.str());
+     }
+   }
+ };
+
+ /* Converts a Value back to json: primitives are returned as stored, arrays and
+  * objects are converted recursively (non-string primitive keys are stringified
+  * with dump()), and callables additionally get a "__callable__": true marker. */
+ template <>
+ inline json Value::get<json>() const {
+   if (is_primitive()) return primitive_;
+   if (is_null()) return json();
+   if (array_) {
+     std::vector<json> items;
+     for (size_t i = 0; i < array_->size(); ++i) {
+       items.push_back((*array_)[i].get<json>());
+     }
+     return items;
+   }
+   if (!object_) {
+     throw std::runtime_error("get<json> not defined for this value type: " + dump());
+   }
+   json result = json::object();
+   for (const auto & entry : *object_) {
+     const json & key = entry.first;
+     if (!key.is_string() && !key.is_primitive()) {
+       throw std::runtime_error("Invalid key type for conversion to JSON: " + key.dump());
+     }
+     // String keys are used verbatim; other primitive keys use their JSON text.
+     const std::string name = key.is_string() ? key.get<std::string>() : key.dump();
+     result[name] = entry.second.get<json>();
+   }
+   if (is_callable()) {
+     result["__callable__"] = true;
+   }
+   return result;
+ }
+
+} // namespace minja
+
+ namespace std {
+   /* Hash support for minja::Value: hashable (primitive) values hash like their
+    * json form; anything else throws std::runtime_error. */
+   template <>
+   struct hash<minja::Value> {
+     size_t operator()(const minja::Value & v) const {
+       if (v.is_hashable()) {
+         const json as_json = v.get<json>();
+         return std::hash<json>()(as_json);
+       }
+       throw std::runtime_error("Unsupported type for hashing: " + v.dump());
+     }
+   };
+ } // namespace std
+
+namespace minja {
+
+static std::string error_location_suffix(const std::string & source, size_t pos) {
+ auto get_line = [&](size_t line) {
+ auto start = source.begin();
+ for (size_t i = 1; i < line; ++i) {
+ start = std::find(start, source.end(), '\n') + 1;
+ }
+ auto end = std::find(start, source.end(), '\n');
+ return std::string(start, end);
+ };
+ auto start = source.begin();
+ auto end = source.end();
+ auto it = start + pos;
+ auto line = std::count(start, it, '\n') + 1;
+ auto max_line = std::count(start, end, '\n') + 1;
+ auto col = pos - std::string(start, it).rfind('\n');
+ std::ostringstream out;
+ out << " at row " << line << ", column " << col << ":\n";
+ if (line > 1) out << get_line(line - 1) << "\n";
+ out << get_line(line) << "\n";
+ out << std::string(col - 1, ' ') << "^\n";
+ if (line < max_line) out << get_line(line + 1) << "\n";
+
+ return out.str();
+}
+
+ /* A scope of variables: an object Value holding this scope's bindings plus an
+  * optional parent scope that lookups fall back to. */
+ class Context : public std::enable_shared_from_this<Context> {
+   protected:
+     Value values_;
+     std::shared_ptr<Context> parent_;
+   public:
+     /* `values` must be an object Value; anything else throws std::runtime_error. */
+     Context(Value && values, const std::shared_ptr<Context> & parent = nullptr) : values_(std::move(values)), parent_(parent) {
+         if (values_.is_object()) return;
+         throw std::runtime_error("Context values must be an object: " + values_.dump());
+     }
+     virtual ~Context() = default;
+
+     static std::shared_ptr<Context> builtins();
+     static std::shared_ptr<Context> make(Value && values, const std::shared_ptr<Context> & parent = builtins());
+
+     /* Keys bound directly in this scope (parents are not consulted). */
+     std::vector<Value> keys() {
+         return values_.keys();
+     }
+     /* Looks `key` up here, then in the parent chain; a null Value if nobody has it. */
+     virtual Value get(const Value & key) {
+         if (values_.contains(key)) {
+             return values_.at(key);
+         }
+         return parent_ ? parent_->get(key) : Value();
+     }
+     /* Like get(), but returns a reference and throws when the variable is undefined. */
+     virtual Value & at(const Value & key) {
+         if (values_.contains(key)) {
+             return values_.at(key);
+         }
+         if (!parent_) {
+             throw std::runtime_error("Undefined variable: " + key.dump());
+         }
+         return parent_->at(key);
+     }
+     /* True when `key` is bound in this scope or in any ancestor. */
+     virtual bool contains(const Value & key) {
+         return values_.contains(key) || (parent_ && parent_->contains(key));
+     }
+     /* Binds `key` in this scope only. */
+     virtual void set(const Value & key, const Value & value) {
+         values_.set(key, value);
+     }
+ };
+
+ /* Position inside a template: the shared source text (may be null) and a byte
+  * offset into it. `pos` defaults to 0 so a default-constructed Location is never
+  * left with an indeterminate offset. */
+ struct Location {
+   std::shared_ptr<std::string> source;
+   size_t pos = 0;
+ };
+
+ /* Base class of expression nodes. Subclasses implement do_evaluate(); callers use
+  * evaluate(), which decorates any std::exception with the node's source location. */
+ class Expression {
+ protected:
+   virtual Value do_evaluate(const std::shared_ptr<Context> & context) const = 0;
+ public:
+   using Parameters = std::vector<std::pair<std::string, std::shared_ptr<Expression>>>;
+
+   Location location;
+
+   Expression(const Location & location) : location(location) {}
+   virtual ~Expression() = default;
+
+   /* Evaluates the expression; failures are rethrown as std::runtime_error with the
+    * original message followed by the location suffix (when a source is attached). */
+   Value evaluate(const std::shared_ptr<Context> & context) const {
+     try {
+       return do_evaluate(context);
+     } catch (const std::exception & e) {
+       std::string message = e.what();
+       if (location.source) {
+         message += error_location_suffix(*location.source, location.pos);
+       }
+       throw std::runtime_error(message);
+     }
+   }
+ };
+
+ /* Expression that reads a variable by name; an unbound name evaluates to a null Value. */
+ class VariableExpr : public Expression {
+   std::string name;
+ public:
+   VariableExpr(const Location & location, const std::string& n)
+     : Expression(location), name(n) {}
+
+   std::string get_name() const { return name; }
+
+   Value do_evaluate(const std::shared_ptr<Context> & context) const override {
+     if (context->contains(name)) {
+       return context->at(name);
+     }
+     return Value();
+   }
+ };
+
+ /* Binds `item` to `var_names` in `context`. A single name receives the whole item;
+  * otherwise `item` must be an array with exactly one element per name, and each
+  * element is bound to the name at the same position. */
+ static void destructuring_assign(const std::vector<std::string> & var_names, const std::shared_ptr<Context> & context, Value& item) {
+   const size_t name_count = var_names.size();
+   if (name_count == 1) {
+     Value target(var_names[0]);
+     context->set(target, item);
+     return;
+   }
+   if (!item.is_array() || item.size() != name_count) {
+     throw std::runtime_error("Mismatched number of variables and items in destructuring assignment");
+   }
+   for (size_t idx = 0; idx != name_count; ++idx) {
+     context->set(var_names[idx], item.at(idx));
+   }
+ }
+
+ /* Whitespace policy attached to each side of a template token.
+  * NOTE(review): the exact stripping done for each mode is implemented elsewhere. */
+ enum SpaceHandling {
+   Keep,
+   Strip,
+   StripSpaces,
+   StripNewline
+ };
+
+ /* Base class of all template tokens: a type tag, the source location and the
+  * whitespace handling requested before and after the token. */
+ class TemplateToken {
+ public:
+   enum class Type {
+     Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration,
+     Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter, Break, Continue
+   };
+
+   /* Lower-case tag name for `t`, or "Unknown" for a value outside the enum. */
+   static std::string typeToString(Type t) {
+     const char * label = "Unknown";
+     switch (t) {
+       case Type::Text:          label = "text"; break;
+       case Type::Expression:    label = "expression"; break;
+       case Type::If:            label = "if"; break;
+       case Type::Else:          label = "else"; break;
+       case Type::Elif:          label = "elif"; break;
+       case Type::EndIf:         label = "endif"; break;
+       case Type::For:           label = "for"; break;
+       case Type::EndFor:        label = "endfor"; break;
+       case Type::Generation:    label = "generation"; break;
+       case Type::EndGeneration: label = "endgeneration"; break;
+       case Type::Set:           label = "set"; break;
+       case Type::EndSet:        label = "endset"; break;
+       case Type::Comment:       label = "comment"; break;
+       case Type::Macro:         label = "macro"; break;
+       case Type::EndMacro:      label = "endmacro"; break;
+       case Type::Filter:        label = "filter"; break;
+       case Type::EndFilter:     label = "endfilter"; break;
+       case Type::Break:         label = "break"; break;
+       case Type::Continue:      label = "continue"; break;
+     }
+     return label;
+   }
+
+   TemplateToken(Type type, const Location & location, SpaceHandling pre, SpaceHandling post)
+     : type(type), location(location), pre_space(pre), post_space(post) {}
+   virtual ~TemplateToken() = default;
+
+   Type type;
+   Location location;
+   SpaceHandling pre_space = SpaceHandling::Keep;
+   SpaceHandling post_space = SpaceHandling::Keep;
+ };
+
+ /* Token of type Text carrying a run of text. */
+ struct TextTemplateToken : public TemplateToken {
+   std::string text;
+
+   TextTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string& content)
+     : TemplateToken(Type::Text, location, pre, post), text(content) {}
+ };
+
+ /* Token of type Expression owning the parsed expression. */
+ struct ExpressionTemplateToken : public TemplateToken {
+   std::shared_ptr<Expression> expr;
+
+   ExpressionTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && expression)
+     : TemplateToken(Type::Expression, location, pre, post), expr(std::move(expression)) {}
+ };
+
+ /* Token of type If together with its condition expression. */
+ struct IfTemplateToken : public TemplateToken {
+   std::shared_ptr<Expression> condition;
+
+   IfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && cond)
+     : TemplateToken(Type::If, location, pre, post), condition(std::move(cond)) {}
+ };
+
+ /* Token of type Elif together with its condition expression. */
+ struct ElifTemplateToken : public TemplateToken {
+   std::shared_ptr<Expression> condition;
+
+   ElifTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && cond)
+     : TemplateToken(Type::Elif, location, pre, post), condition(std::move(cond)) {}
+ };
+
+ /* Token of type Else; carries no payload. */
+ struct ElseTemplateToken : public TemplateToken {
+   ElseTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post)
+     : TemplateToken(Type::Else, location, pre, post) {}
+ };
+
+ /* Token of type EndIf; carries no payload. */
+ struct EndIfTemplateToken : public TemplateToken {
+   EndIfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post)
+     : TemplateToken(Type::EndIf, location, pre, post) {}
+ };
+
+ /* Token of type Macro: the macro's name expression and its parameter list. */
+ struct MacroTemplateToken : public TemplateToken {
+   std::shared_ptr<VariableExpr> name;
+   Expression::Parameters params;
+
+   MacroTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post,
+                      std::shared_ptr<VariableExpr> && macro_name, Expression::Parameters && macro_params)
+     : TemplateToken(Type::Macro, location, pre, post),
+       name(std::move(macro_name)),
+       params(std::move(macro_params)) {}
+ };
+
+ /* Token of type EndMacro; carries no payload. */
+ struct EndMacroTemplateToken : public TemplateToken {
+   EndMacroTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post)
+     : TemplateToken(Type::EndMacro, location, pre, post) {}
+ };
+
+ /* Token of type Filter owning the filter expression. */
+ struct FilterTemplateToken : public TemplateToken {
+   std::shared_ptr<Expression> filter;
+
+   FilterTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && filter_expr)
+     : TemplateToken(Type::Filter, location, pre, post), filter(std::move(filter_expr)) {}
+ };
+
+ /* Token of type EndFilter; carries no payload. */
+ struct EndFilterTemplateToken : public TemplateToken {
+   EndFilterTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post)
+     : TemplateToken(Type::EndFilter, location, pre, post) {}
+ };
+
+ /* Token of type For: loop variable names, the iterable expression, an optional
+  * condition expression and the `recursive` flag. */
+ struct ForTemplateToken : public TemplateToken {
+   std::vector<std::string> var_names;
+   std::shared_ptr<Expression> iterable;
+   std::shared_ptr<Expression> condition;
+   bool recursive;
+
+   ForTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post,
+                    const std::vector<std::string> & names, std::shared_ptr<Expression> && iterable_expr,
+                    std::shared_ptr<Expression> && condition_expr, bool is_recursive)
+     : TemplateToken(Type::For, location, pre, post),
+       var_names(names),
+       iterable(std::move(iterable_expr)),
+       condition(std::move(condition_expr)),
+       recursive(is_recursive) {}
+ };
+
+ /* Token of type EndFor; carries no payload. */
+ struct EndForTemplateToken : public TemplateToken {
+   EndForTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post)
+     : TemplateToken(Type::EndFor, location, pre, post) {}
+ };
+
+ /* Token of type Generation; carries no payload. */
+ struct GenerationTemplateToken : public TemplateToken {
+   GenerationTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post)
+     : TemplateToken(Type::Generation, location, pre, post) {}
+ };
+
+ /* Token of type EndGeneration; carries no payload. */
+ struct EndGenerationTemplateToken : public TemplateToken {
+   EndGenerationTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post)
+     : TemplateToken(Type::EndGeneration, location, pre, post) {}
+ };
+
+ /* Token of type Set: a namespace name, the target variable names and the value
+  * expression. NOTE(review): presumably `ns` is empty for a plain assignment and
+  * `value` is null for the block form - confirm against the parser. */
+ struct SetTemplateToken : public TemplateToken {
+   std::string ns;
+   std::vector<std::string> var_names;
+   std::shared_ptr<Expression> value;
+
+   SetTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post,
+                    const std::string & namespace_name, const std::vector<std::string> & names,
+                    std::shared_ptr<Expression> && value_expr)
+     : TemplateToken(Type::Set, location, pre, post),
+       ns(namespace_name),
+       var_names(names),
+       value(std::move(value_expr)) {}
+ };
+
+ /* Token of type EndSet; carries no payload. */
+ struct EndSetTemplateToken : public TemplateToken {
+   EndSetTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post)
+     : TemplateToken(Type::EndSet, location, pre, post) {}
+ };
+
+ /* Token of type Comment carrying the comment text. */
+ struct CommentTemplateToken : public TemplateToken {
+   std::string text;
+
+   CommentTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string& content)
+     : TemplateToken(Type::Comment, location, pre, post), text(content) {}
+ };
+
+ /* Which loop-control statement was requested. */
+ enum class LoopControlType {
+   Break,
+   Continue
+ };
+
+ /* Exception used to signal break/continue. When built from the control type
+  * alone, the message reads "<break|continue> outside of a loop". */
+ class LoopControlException : public std::runtime_error {
+ public:
+   LoopControlType control_type;
+
+   LoopControlException(const std::string & message, LoopControlType control_type)
+     : std::runtime_error(message), control_type(control_type) {}
+
+   LoopControlException(LoopControlType control_type)
+     : std::runtime_error(std::string(control_type == LoopControlType::Continue ? "continue" : "break") + " outside of a loop"),
+       control_type(control_type) {}
+ };
+
+ /* Token for a loop-control statement. The token type is always Type::Break; the
+  * actual statement is distinguished by `control_type`.
+  * NOTE(review): Type::Continue exists but is not used here - confirm the parser
+  * relies on control_type rather than on the token type. */
+ struct LoopControlTemplateToken : public TemplateToken {
+   LoopControlType control_type;
+
+   LoopControlTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, LoopControlType control_type)
+     : TemplateToken(Type::Break, location, pre, post), control_type(control_type) {}
+ };
+
+ /* Base class of renderable nodes. Subclasses implement do_render(); render()
+  * wraps it so that every escaping exception gets this node's location appended. */
+ class TemplateNode {
+   Location location_;
+ protected:
+   virtual void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const = 0;
+
+ public:
+   TemplateNode(const Location & location) : location_(location) {}
+
+   /* Renders into `out`. A LoopControlException is rethrown as a LoopControlException
+    * (keeping its control type); any other std::exception becomes std::runtime_error. */
+   void render(std::ostringstream & out, const std::shared_ptr<Context> & context) const {
+     // Original message followed by the location suffix, when a source is attached.
+     const auto with_location = [this](const char * what) {
+       std::string message = what;
+       if (location_.source) {
+         message += error_location_suffix(*location_.source, location_.pos);
+       }
+       return message;
+     };
+     try {
+       do_render(out, context);
+     } catch (const LoopControlException & e) {
+       // TODO: make stack creation lazy. Only needed if it was thrown outside of a loop.
+       throw LoopControlException(with_location(e.what()), e.control_type);
+     } catch (const std::exception & e) {
+       throw std::runtime_error(with_location(e.what()));
+     }
+   }
+
+   const Location & location() const { return location_; }
+   virtual ~TemplateNode() = default;
+
+   /* Convenience overload: renders into a fresh buffer and returns its contents. */
+   std::string render(const std::shared_ptr<Context> & context) const {
+     std::ostringstream buffer;
+     render(buffer, context);
+     return buffer.str();
+   }
+ };
+
+ /* Node that renders its children one after another, in order. */
+ class SequenceNode : public TemplateNode {
+   std::vector<std::shared_ptr<TemplateNode>> children;
+ public:
+   SequenceNode(const Location & location, std::vector<std::shared_ptr<TemplateNode>> && nodes)
+     : TemplateNode(location), children(std::move(nodes)) {}
+
+   void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
+     for (size_t i = 0; i < children.size(); ++i) {
+       children[i]->render(out, context);
+     }
+   }
+ };
+
+ /* Node that writes a fixed piece of text; the context is not used. */
+ class TextNode : public TemplateNode {
+   std::string text;
+ public:
+   TextNode(const Location & location, const std::string& content)
+     : TemplateNode(location), text(content) {}
+
+   void do_render(std::ostringstream & out, const std::shared_ptr<Context> &) const override {
+     out << text;
+   }
+ };
+
+ /* Node that evaluates an expression and writes the result: strings verbatim,
+  * booleans as True/False, null as nothing, anything else via dump(). */
+ class ExpressionNode : public TemplateNode {
+   std::shared_ptr<Expression> expr;
+ public:
+   ExpressionNode(const Location & location, std::shared_ptr<Expression> && expression)
+     : TemplateNode(location), expr(std::move(expression)) {}
+
+   void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
+     if (!expr) throw std::runtime_error("ExpressionNode.expr is null");
+     const auto value = expr->evaluate(context);
+     if (value.is_null()) {
+       return;  // a null result produces no output
+     }
+     if (value.is_string()) {
+       out << value.get<std::string>();
+     } else if (value.is_boolean()) {
+       out << (value.get<bool>() ? "True" : "False");
+     } else {
+       out << value.dump();
+     }
+   }
+ };
+
+ /* Node for an if/elif/else chain. Each cascade entry pairs a condition with a
+  * body; a null condition always matches. Only the first matching body is rendered. */
+ class IfNode : public TemplateNode {
+   std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<TemplateNode>>> cascade;
+ public:
+   IfNode(const Location & location, std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<TemplateNode>>> && branches)
+     : TemplateNode(location), cascade(std::move(branches)) {}
+
+   void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
+     for (const auto & [condition, body] : cascade) {
+       // Skip branches whose condition exists and evaluates to false.
+       if (condition && !condition->evaluate(context).to_bool()) {
+         continue;
+       }
+       if (!body) throw std::runtime_error("IfNode.cascade.second is null");
+       body->render(out, context);
+       return;
+     }
+   }
+ };
+
+ /* Node for a loop-control statement: rendering it produces no output and always
+  * throws a LoopControlException carrying the stored control type. */
+ class LoopControlNode : public TemplateNode {
+   LoopControlType control_type_;
+ public:
+   LoopControlNode(const Location & location, LoopControlType control_type)
+     : TemplateNode(location), control_type_(control_type) {}
+
+   void do_render(std::ostringstream &, const std::shared_ptr<Context> &) const override {
+     throw LoopControlException(control_type_);
+   }
+ };
+
+// Template node for a `for` loop.
+// https://jinja.palletsprojects.com/en/3.0.x/templates/#for
+//
+// The evaluated iterable is first filtered through the optional `condition`;
+// if nothing survives, `else_body` (when present) is rendered instead. The body
+// is rendered in a child context that exposes the `loop` helper (index, first,
+// last, previtem, nextitem, cycle, ...). When `recursive` is set, `loop` is a
+// callable that renders the same loop over the array it is given.
+class ForNode : public TemplateNode {
+    std::vector<std::string> var_names;
+    std::shared_ptr<Expression> iterable;
+    std::shared_ptr<Expression> condition;
+    std::shared_ptr<TemplateNode> body;
+    bool recursive;
+    std::shared_ptr<TemplateNode> else_body;
+public:
+    ForNode(const Location & location, std::vector<std::string> && var_names, std::shared_ptr<Expression> && iterable,
+      std::shared_ptr<Expression> && condition, std::shared_ptr<TemplateNode> && body, bool recursive, std::shared_ptr<TemplateNode> && else_body)
+        // var_names is an rvalue reference parameter: move it like the other members instead of copying.
+        : TemplateNode(location), var_names(std::move(var_names)), iterable(std::move(iterable)), condition(std::move(condition)), body(std::move(body)), recursive(recursive), else_body(std::move(else_body)) {}
+
+    // Renders the loop into `out`. Throws std::runtime_error when the iterable
+    // or body expression is missing, or when the value to iterate is neither
+    // null nor iterable.
+    void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
+        if (!iterable) throw std::runtime_error("ForNode.iterable is null");
+        if (!body) throw std::runtime_error("ForNode.body is null");
+
+        auto iterable_value = iterable->evaluate(context);
+        Value::CallableType loop_function;
+
+        // Renders one level of the loop over `iter`. Called once for the
+        // top-level iterable and again by `loop(...)` for each recursive level.
+        std::function<void(Value&)> visit = [&](Value& iter) {
+            auto filtered_items = Value::array();
+            if (!iter.is_null()) {
+                // Iterate the value passed in, not the captured top-level
+                // iterable: otherwise a recursive loop(children) call would
+                // walk the outer collection again instead of `children`.
+                if (!iter.is_iterable()) {
+                    throw std::runtime_error("For loop iterable must be iterable: " + iter.dump());
+                }
+                iter.for_each([&](Value & item) {
+                    // The loop variables are bound in `context` so the filter
+                    // condition can refer to them.
+                    destructuring_assign(var_names, context, item);
+                    if (!condition || condition->evaluate(context).to_bool()) {
+                        filtered_items.push_back(item);
+                    }
+                });
+            }
+            if (filtered_items.empty()) {
+                if (else_body) {
+                    else_body->render(out, context);
+                }
+            } else {
+                auto loop = recursive ? Value::callable(loop_function) : Value::object();
+                loop.set("length", (int64_t) filtered_items.size());
+
+                size_t cycle_index = 0;
+                loop.set("cycle", Value::callable([&](const std::shared_ptr<Context> &, ArgumentsValue & args) {
+                    if (args.args.empty() || !args.kwargs.empty()) {
+                        throw std::runtime_error("cycle() expects at least 1 positional argument and no named arg");
+                    }
+                    // Wrap before reading as well: the counter was wrapped
+                    // against the previous call's argument count, which may
+                    // have been larger than this call's.
+                    cycle_index %= args.args.size();
+                    auto item = args.args[cycle_index];
+                    cycle_index = (cycle_index + 1) % args.args.size();
+                    return item;
+                }));
+                auto loop_context = Context::make(Value::object(), context);
+                loop_context->set("loop", loop);
+                for (size_t i = 0, n = filtered_items.size(); i < n; ++i) {
+                    auto & item = filtered_items.at(i);
+                    destructuring_assign(var_names, loop_context, item);
+                    loop.set("index", (int64_t) i + 1);
+                    loop.set("index0", (int64_t) i);
+                    loop.set("revindex", (int64_t) (n - i));
+                    loop.set("revindex0", (int64_t) (n - i - 1));
+                    loop.set("length", (int64_t) n);
+                    loop.set("first", i == 0);
+                    loop.set("last", i == (n - 1));
+                    loop.set("previtem", i > 0 ? filtered_items.at(i - 1) : Value());
+                    loop.set("nextitem", i < n - 1 ? filtered_items.at(i + 1) : Value());
+                    try {
+                        body->render(out, loop_context);
+                    } catch (const LoopControlException & e) {
+                        // break / continue raised by a LoopControlNode in the body
+                        if (e.control_type == LoopControlType::Break) break;
+                        if (e.control_type == LoopControlType::Continue) continue;
+                    }
+                }
+            }
+        };
+
+        if (recursive) {
+            // `loop(items)` inside the body renders this same loop over `items`.
+            loop_function = [&](const std::shared_ptr<Context> &, ArgumentsValue & args) {
+                if (args.args.size() != 1 || !args.kwargs.empty() || !args.args[0].is_array()) {
+                    throw std::runtime_error("loop() expects exactly 1 positional iterable argument");
+                }
+                auto & items = args.args[0];
+                visit(items);
+                return Value();
+            };
+        }
+
+        visit(iterable_value);
+    }
+};
+
+// Template node for a macro definition. Rendering it emits nothing; it stores
+// a callable under the macro's name in the context it is rendered in.
+class MacroNode : public TemplateNode {
+    std::shared_ptr<VariableExpr> name;
+    Expression::Parameters params;
+    std::shared_ptr<TemplateNode> body;
+    // Position of every named parameter, used to resolve keyword arguments.
+    std::unordered_map<std::string, size_t> named_param_positions;
+public:
+    MacroNode(const Location & location, std::shared_ptr<VariableExpr> && n, Expression::Parameters && p, std::shared_ptr<TemplateNode> && b)
+      : TemplateNode(location), name(std::move(n)), params(std::move(p)), body(std::move(b)) {
+        size_t position = 0;
+        for (const auto & param : params) {
+            if (!param.first.empty()) {
+                named_param_positions[param.first] = position;
+            }
+            ++position;
+        }
+    }
+
+    // Registers the macro callable. When invoked, the callable binds positional
+    // arguments, then keyword arguments, then defaults for parameters that were
+    // not passed (defaults are evaluated in the caller's context), and returns
+    // the rendered body. Arguments are set directly on `macro_context`.
+    void do_render(std::ostringstream &, const std::shared_ptr<Context> & macro_context) const override {
+        if (!name) throw std::runtime_error("MacroNode.name is null");
+        if (!body) throw std::runtime_error("MacroNode.body is null");
+        auto callable = Value::callable([&](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
+            auto call_context = macro_context;
+            const size_t param_count = params.size();
+            std::vector<bool> param_set(param_count, false);
+
+            // Positional arguments map onto parameters in declaration order.
+            size_t pos = 0;
+            for (auto & arg : args.args) {
+                if (pos >= param_count) throw std::runtime_error("Too many positional arguments for macro " + name->get_name());
+                param_set[pos] = true;
+                call_context->set(params[pos].first, arg);
+                ++pos;
+            }
+
+            // Keyword arguments must name a declared parameter.
+            for (auto & kwarg : args.kwargs) {
+                const auto & arg_name = kwarg.first;
+                auto found = named_param_positions.find(arg_name);
+                if (found == named_param_positions.end()) throw std::runtime_error("Unknown parameter name for macro " + name->get_name() + ": " + arg_name);
+
+                call_context->set(arg_name, kwarg.second);
+                param_set[found->second] = true;
+            }
+
+            // Fill in defaults for whatever was not supplied.
+            for (size_t i = 0; i < param_count; ++i) {
+                if (param_set[i] || params[i].second == nullptr) continue;
+                auto default_value = params[i].second->evaluate(context);
+                call_context->set(params[i].first, default_value);
+            }
+            return body->render(call_context);
+        });
+        macro_context->set(name->get_name(), callable);
+    }
+};
+
+// Template node for a filter block: renders `body` to a string, passes that
+// string as the only positional argument to the evaluated `filter`, and writes
+// the stringified result to the output.
+class FilterNode : public TemplateNode {
+    std::shared_ptr<Expression> filter;
+    std::shared_ptr<TemplateNode> body;
+
+public:
+    FilterNode(const Location & location, std::shared_ptr<Expression> && f, std::shared_ptr<TemplateNode> && b)
+      : TemplateNode(location), filter(std::move(f)), body(std::move(b)) {}
+
+    // Throws std::runtime_error if the filter or body is missing, or if the
+    // filter expression does not evaluate to a callable.
+    void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
+        if (!filter) throw std::runtime_error("FilterNode.filter is null");
+        if (!body) throw std::runtime_error("FilterNode.body is null");
+
+        auto filter_fn = filter->evaluate(context);
+        if (!filter_fn.is_callable()) {
+            throw std::runtime_error("Filter must be a callable: " + filter_fn.dump());
+        }
+
+        ArgumentsValue call_args;
+        call_args.args.push_back(Value(body->render(context)));
+        out << filter_fn.call(context, call_args).to_str();
+    }
+};
+
+// Template node for a `set` statement. With an empty `ns` the evaluated value
+// is destructured into `var_names` in the context; otherwise it is stored as
+// attribute `var_names[0]` of the object found under `ns`.
+class SetNode : public TemplateNode {
+    std::string ns;
+    std::vector<std::string> var_names;
+    std::shared_ptr<Expression> value;
+public:
+    SetNode(const Location & location, const std::string & ns, const std::vector<std::string> & vns, std::shared_ptr<Expression> && v)
+      : TemplateNode(location), ns(ns), var_names(vns), value(std::move(v)) {}
+
+    void do_render(std::ostringstream &, const std::shared_ptr<Context> & context) const override {
+        if (!value) throw std::runtime_error("SetNode.value is null");
+
+        if (ns.empty()) {
+            // Plain assignment, possibly to several names at once.
+            auto evaluated = value->evaluate(context);
+            destructuring_assign(var_names, context, evaluated);
+            return;
+        }
+
+        // Namespaced assignment: exactly one attribute on an existing object.
+        if (var_names.size() != 1) {
+            throw std::runtime_error("Namespaced set only supports a single variable name");
+        }
+        auto target_ns = context->get(ns);
+        if (!target_ns.is_object()) throw std::runtime_error("Namespace '" + ns + "' is not an object");
+        target_ns.set(var_names[0], value->evaluate(context));
+    }
+};
+
+// Template node for a block-style set: renders `template_value` to a string
+// and stores that string in the context under `name`. Emits no output itself.
+class SetTemplateNode : public TemplateNode {
+    std::string name;
+    std::shared_ptr<TemplateNode> template_value;
+public:
+    SetTemplateNode(const Location & location, const std::string & name, std::shared_ptr<TemplateNode> && tv)
+      : TemplateNode(location), name(name), template_value(std::move(tv)) {}
+
+    void do_render(std::ostringstream &, const std::shared_ptr<Context> & context) const override {
+        if (!template_value) {
+            throw std::runtime_error("SetTemplateNode.template_value is null");
+        }
+        Value rendered { template_value->render(context) };
+        context->set(name, rendered);
+    }
+};
+
+// Inline conditional expression (`then_expr if condition else else_expr`).
+// Evaluates to a null value when the condition is falsy and there is no else.
+class IfExpr : public Expression {
+    std::shared_ptr<Expression> condition;
+    std::shared_ptr<Expression> then_expr;
+    std::shared_ptr<Expression> else_expr;
+public:
+    IfExpr(const Location & location, std::shared_ptr<Expression> && c, std::shared_ptr<Expression> && t, std::shared_ptr<Expression> && e)
+      : Expression(location), condition(std::move(c)), then_expr(std::move(t)), else_expr(std::move(e)) {}
+
+    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
+        if (!condition) throw std::runtime_error("IfExpr.condition is null");
+        if (!then_expr) throw std::runtime_error("IfExpr.then_expr is null");
+
+        if (!condition->evaluate(context).to_bool()) {
+            if (else_expr) {
+                return else_expr->evaluate(context);
+            }
+            return nullptr;
+        }
+        return then_expr->evaluate(context);
+    }
+};
+
+// Expression wrapping a constant: evaluation returns a copy of the stored
+// value and does not consult the context.
+class LiteralExpr : public Expression {
+    Value value;
+public:
+    LiteralExpr(const Location & location, const Value& v)
+      : Expression(location), value(v) {}
+
+    Value do_evaluate(const std::shared_ptr<Context> & /*context*/) const override {
+        return value;
+    }
+};
+
+// Array literal: evaluates each element expression in order and collects the
+// results into a new array value.
+class ArrayExpr : public Expression {
+    std::vector<std::shared_ptr<Expression>> elements;
+public:
+    ArrayExpr(const Location & location, std::vector<std::shared_ptr<Expression>> && e)
+      : Expression(location), elements(std::move(e)) {}
+
+    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
+        auto items = Value::array();
+        for (size_t i = 0, n = elements.size(); i < n; ++i) {
+            const auto & element = elements[i];
+            if (!element) throw std::runtime_error("Array element is null");
+            items.push_back(element->evaluate(context));
+        }
+        return items;
+    }
+};
+
+// Dict literal: evaluates every (key, value) expression pair and stores the
+// results in a new object value.
+class DictExpr : public Expression {
+    std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>>> elements;
+public:
+    DictExpr(const Location & location, std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>>> && e)
+      : Expression(location), elements(std::move(e)) {}
+
+    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
+        auto dict = Value::object();
+        for (const auto & entry : elements) {
+            const auto & key_expr = entry.first;
+            const auto & value_expr = entry.second;
+            if (!key_expr) throw std::runtime_error("Dict key is null");
+            if (!value_expr) throw std::runtime_error("Dict value is null");
+            dict.set(key_expr->evaluate(context), value_expr->evaluate(context));
+        }
+        return dict;
+    }
+};
+
+// Holder for the optional bounds of a slice. It cannot be evaluated on its
+// own (do_evaluate always throws); SubscriptExpr reads `start` and `end`
+// directly when the subscript index is a slice.
+class SliceExpr : public Expression {
+public:
+    std::shared_ptr<Expression> start, end;
+
+    SliceExpr(const Location & location, std::shared_ptr<Expression> && s, std::shared_ptr<Expression> && e)
+      : Expression(location), start(std::move(s)), end(std::move(e)) {}
+
+    Value do_evaluate(const std::shared_ptr<Context> & /*context*/) const override {
+        throw std::runtime_error("SliceExpr not implemented");
+    }
+};
+
+// Subscript expression `base[index]`. When `index` is a SliceExpr the result
+// is a sub-string or sub-array; otherwise it is a plain element/key lookup.
+class SubscriptExpr : public Expression {
+    std::shared_ptr<Expression> base;
+    std::shared_ptr<Expression> index;
+
+    // Resolves one slice bound against a sequence of `len` elements the way
+    // Python slicing does: a negative bound counts from the end, and the
+    // result is clamped into [0, len].
+    static int64_t resolve_slice_bound(int64_t bound, int64_t len) {
+        if (bound < 0) bound += len;
+        if (bound < 0) return 0;
+        return bound > len ? len : bound;
+    }
+public:
+    SubscriptExpr(const Location & location, std::shared_ptr<Expression> && b, std::shared_ptr<Expression> && i)
+      : Expression(location), base(std::move(b)), index(std::move(i)) {}
+
+    // Throws std::runtime_error when base/index is missing, when slicing
+    // something that is neither a string nor an array, or when subscripting
+    // a null value.
+    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
+        if (!base) throw std::runtime_error("SubscriptExpr.base is null");
+        if (!index) throw std::runtime_error("SubscriptExpr.index is null");
+        auto target_value = base->evaluate(context);
+        if (auto slice = dynamic_cast<SliceExpr*>(index.get())) {
+            auto start = slice->start ? slice->start->evaluate(context).get<int64_t>() : 0;
+            auto end = slice->end ? slice->end->evaluate(context).get<int64_t>() : (int64_t) target_value.size();
+            if (target_value.is_string()) {
+                std::string s = target_value.get<std::string>();
+                const auto len = (int64_t) s.size();
+                start = resolve_slice_bound(start, len);
+                end = resolve_slice_bound(end, len);
+                // An inverted range is empty; without this the negative count
+                // would wrap to a huge size_t and substr would return the tail.
+                if (end < start) end = start;
+                return s.substr(start, end - start);
+            } else if (target_value.is_array()) {
+                const auto len = (int64_t) target_value.size();
+                start = resolve_slice_bound(start, len);
+                end = resolve_slice_bound(end, len);
+                auto result = Value::array();
+                for (auto i = start; i < end; ++i) {
+                    result.push_back(target_value.at(i));
+                }
+                return result;
+            } else {
+                throw std::runtime_error(target_value.is_null() ? "Cannot subscript null" : "Subscripting only supported on arrays and strings");
+            }
+        } else {
+            auto index_value = index->evaluate(context);
+            if (target_value.is_null()) {
+                // Prefer an error that names the variable when the base is a plain variable.
+                if (auto t = dynamic_cast<VariableExpr*>(base.get())) {
+                    throw std::runtime_error("'" + t->get_name() + "' is " + (context->contains(t->get_name()) ? "null" : "not defined"));
+                }
+                throw std::runtime_error("Trying to access property '" + index_value.dump() + "' on null!");
+            }
+            return target_value.get(index_value);
+        }
+    }
+};
+
+// Unary operator expression. `expr` and `op` are public because argument
+// evaluation inspects them to expand `*args` / `**kwargs` style operands; the
+// two expansion operators cannot be evaluated directly.
+class UnaryOpExpr : public Expression {
+public:
+    enum class Op { Plus, Minus, LogicalNot, Expansion, ExpansionDict };
+    std::shared_ptr<Expression> expr;
+    Op op;
+
+    UnaryOpExpr(const Location & location, std::shared_ptr<Expression> && e, Op o)
+      : Expression(location), expr(std::move(e)), op(o) {}
+
+    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
+        if (!expr) throw std::runtime_error("UnaryOpExpr.expr is null");
+        auto operand = expr->evaluate(context);
+
+        if (op == Op::Plus) return operand;
+        if (op == Op::Minus) return -operand;
+        if (op == Op::LogicalNot) return !operand.to_bool();
+        if (op == Op::Expansion || op == Op::ExpansionDict) {
+            throw std::runtime_error("Expansion operator is only supported in function calls and collections");
+        }
+        throw std::runtime_error("Unknown unary operator");
+    }
+};
+
+// Binary operator expression. `and` / `or` short-circuit, `is` / `is not`
+// apply a named test to the left operand, and every other operator evaluates
+// both sides. If the left operand evaluates to a callable, the result is a
+// new callable that applies the operator to whatever that callable returns.
+class BinaryOpExpr : public Expression {
+public:
+    enum class Op { StrConcat, Add, Sub, Mul, MulMul, Div, DivDiv, Mod, Eq, Ne, Lt, Gt, Le, Ge, And, Or, In, NotIn, Is, IsNot };
+private:
+    std::shared_ptr<Expression> left;
+    std::shared_ptr<Expression> right;
+    Op op;
+public:
+    BinaryOpExpr(const Location & location, std::shared_ptr<Expression> && l, std::shared_ptr<Expression> && r, Op o)
+      : Expression(location), left(std::move(l)), right(std::move(r)), op(o) {}
+
+    // Throws std::runtime_error for a missing operand, an unknown `is` test,
+    // an `is` whose right side is not a plain name, integer division or
+    // modulo by zero, or an unknown operator.
+    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
+        if (!left) throw std::runtime_error("BinaryOpExpr.left is null");
+        if (!right) throw std::runtime_error("BinaryOpExpr.right is null");
+        auto l = left->evaluate(context);
+
+        auto do_eval = [&](const Value & l) -> Value {
+            if (op == Op::Is || op == Op::IsNot) {
+                // The right side is a test name, not a value: never evaluated.
+                auto t = dynamic_cast<VariableExpr*>(right.get());
+                if (!t) throw std::runtime_error("Right side of 'is' operator must be a variable");
+
+                auto eval = [&]() {
+                    const auto & name = t->get_name();
+                    if (name == "none") return l.is_null();
+                    if (name == "boolean") return l.is_boolean();
+                    if (name == "integer") return l.is_number_integer();
+                    if (name == "float") return l.is_number_float();
+                    if (name == "number") return l.is_number();
+                    if (name == "string") return l.is_string();
+                    if (name == "mapping") return l.is_object();
+                    if (name == "iterable") return l.is_iterable();
+                    if (name == "sequence") return l.is_array();
+                    if (name == "defined") return !l.is_null();
+                    throw std::runtime_error("Unknown type for 'is' operator: " + name);
+                };
+                auto value = eval();
+                return Value(op == Op::Is ? value : !value);
+            }
+
+            // Short-circuit: the right side is only evaluated when needed.
+            if (op == Op::And) {
+                if (!l.to_bool()) return Value(false);
+                return right->evaluate(context).to_bool();
+            } else if (op == Op::Or) {
+                if (l.to_bool()) return l;
+                return right->evaluate(context);
+            }
+
+            auto r = right->evaluate(context);
+            switch (op) {
+                case Op::StrConcat: return l.to_str() + r.to_str();
+                case Op::Add: return l + r;
+                case Op::Sub: return l - r;
+                case Op::Mul: return l * r;
+                case Op::Div: return l / r;
+                case Op::MulMul: return std::pow(l.get<double>(), r.get<double>());
+                case Op::DivDiv: {
+                    // Integer division by zero is undefined behaviour in C++; report it instead.
+                    const auto divisor = r.get<int64_t>();
+                    if (divisor == 0) throw std::runtime_error("Division by zero");
+                    return l.get<int64_t>() / divisor;
+                }
+                case Op::Mod: {
+                    const auto divisor = r.get<int64_t>();
+                    if (divisor == 0) throw std::runtime_error("Modulo by zero");
+                    return l.get<int64_t>() % divisor;
+                }
+                case Op::Eq: return l == r;
+                case Op::Ne: return l != r;
+                case Op::Lt: return l < r;
+                case Op::Gt: return l > r;
+                case Op::Le: return l <= r;
+                case Op::Ge: return l >= r;
+                case Op::In: return (r.is_array() || r.is_object()) && r.contains(l);
+                // Exact negation of `in`: objects are containers too, so
+                // `key not in dict` must be false when the key is present.
+                case Op::NotIn: return !((r.is_array() || r.is_object()) && r.contains(l));
+                default: break;
+            }
+            throw std::runtime_error("Unknown binary operator");
+        };
+
+        if (l.is_callable()) {
+            return Value::callable([l, do_eval](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
+                auto ll = l.call(context, args);
+                return do_eval(ll); //args[0].second);
+            });
+        } else {
+            return do_eval(l);
+        }
+    }
+};
+
+// Unevaluated call arguments: positional expressions plus named ones.
+struct ArgumentsExpression {
+    std::vector<std::shared_ptr<Expression>> args;
+    std::vector<std::pair<std::string, std::shared_ptr<Expression>>> kwargs;
+
+    // Evaluates every argument in `context`. A positional argument wrapped in
+    // an Expansion operator must evaluate to an array, whose elements are
+    // appended as positional values; one wrapped in ExpansionDict must
+    // evaluate to an object, whose entries are appended as keyword values.
+    ArgumentsValue evaluate(const std::shared_ptr<Context> & context) const {
+        ArgumentsValue vargs;
+        for (const auto & arg : this->args) {
+            const auto unary = std::dynamic_pointer_cast<UnaryOpExpr>(arg);
+            if (unary && unary->op == UnaryOpExpr::Op::Expansion) {
+                auto array = unary->expr->evaluate(context);
+                if (!array.is_array()) {
+                    throw std::runtime_error("Expansion operator only supported on arrays");
+                }
+                array.for_each([&](Value & element) {
+                    vargs.args.push_back(element);
+                });
+                continue;
+            }
+            if (unary && unary->op == UnaryOpExpr::Op::ExpansionDict) {
+                auto dict = unary->expr->evaluate(context);
+                if (!dict.is_object()) {
+                    throw std::runtime_error("ExpansionDict operator only supported on objects");
+                }
+                dict.for_each([&](const Value & key) {
+                    vargs.kwargs.push_back({key.get<std::string>(), dict.at(key)});
+                });
+                continue;
+            }
+            vargs.args.push_back(arg->evaluate(context));
+        }
+        for (const auto & kwarg : this->kwargs) {
+            vargs.kwargs.push_back({kwarg.first, kwarg.second->evaluate(context)});
+        }
+        return vargs;
+    }
+};
+
+// Returns `s` without leading and trailing spaces, tabs, newlines and
+// carriage returns; an all-whitespace or empty input yields "".
+static std::string strip(const std::string & s) {
+    const auto first = s.find_first_not_of(" \t\n\r");
+    if (first == std::string::npos) {
+        return "";
+    }
+    const auto last = s.find_last_not_of(" \t\n\r");
+    return std::string(s, first, last - first + 1);
+}
+
+// Returns a copy of `s` with its first character upper-cased; the remaining
+// characters are left as they are. An empty string is returned unchanged.
+static std::string capitalize(const std::string & s) {
+    if (s.empty()) return s;
+    auto result = s;
+    // <cctype> functions require a value representable as unsigned char;
+    // passing a negative plain char (e.g. a UTF-8 byte) is undefined behaviour.
+    result[0] = static_cast<char>(std::toupper(static_cast<unsigned char>(result[0])));
+    return result;
+}
+
+// Returns `s` with the HTML-significant characters & < > " ' replaced by
+// character references (the same set MarkupSafe's escape() emits). All other
+// characters are copied through unchanged.
+static std::string html_escape(const std::string & s) {
+    std::string result;
+    result.reserve(s.size());
+    for (const auto & c : s) {
+        switch (c) {
+            case '&': result += "&amp;"; break;
+            case '<': result += "&lt;"; break;
+            case '>': result += "&gt;"; break;
+            case '"': result += "&#34;"; break;
+            case '\'': result += "&#39;"; break;
+            default: result += c; break;
+        }
+    }
+    return result;
+}
+
+// Method call `object.method(args)`. Implements a fixed set of built-in
+// methods per value kind:
+//   arrays:  append, pop, insert
+//   objects: items, pop, get, plus any callable stored under the method name
+//   strings: strip, capitalize, endswith, title
+// Anything else throws "Unknown method".
+class MethodCallExpr : public Expression {
+    std::shared_ptr<Expression> object;
+    std::shared_ptr<VariableExpr> method;
+    ArgumentsExpression args;
+public:
+    MethodCallExpr(const Location & location, std::shared_ptr<Expression> && obj, std::shared_ptr<VariableExpr> && m, ArgumentsExpression && a)
+      : Expression(location), object(std::move(obj)), method(std::move(m)), args(std::move(a)) {}
+
+    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
+        if (!object) throw std::runtime_error("MethodCallExpr.object is null");
+        if (!method) throw std::runtime_error("MethodCallExpr.method is null");
+        auto obj = object->evaluate(context);
+        auto vargs = args.evaluate(context);
+        if (obj.is_null()) {
+            throw std::runtime_error("Trying to call method '" + method->get_name() + "' on null");
+        }
+        if (obj.is_array()) {
+            if (method->get_name() == "append") {
+                vargs.expectArgs("append method", {1, 1}, {0, 0});
+                obj.push_back(vargs.args[0]);
+                return Value();
+            } else if (method->get_name() == "pop") {
+                vargs.expectArgs("pop method", {0, 1}, {0, 0});
+                return obj.pop(vargs.args.empty() ? Value() : vargs.args[0]);
+            } else if (method->get_name() == "insert") {
+                vargs.expectArgs("insert method", {2, 2}, {0, 0});
+                auto index = vargs.args[0].get<int64_t>();
+                if (index < 0 || index > (int64_t) obj.size()) throw std::runtime_error("Index out of range for insert method");
+                obj.insert(index, vargs.args[1]);
+                return Value();
+            }
+        } else if (obj.is_object()) {
+            if (method->get_name() == "items") {
+                vargs.expectArgs("items method", {0, 0}, {0, 0});
+                auto result = Value::array();
+                for (const auto& key : obj.keys()) {
+                    result.push_back(Value::array({key, obj.at(key)}));
+                }
+                return result;
+            } else if (method->get_name() == "pop") {
+                vargs.expectArgs("pop method", {1, 1}, {0, 0});
+                return obj.pop(vargs.args[0]);
+            } else if (method->get_name() == "get") {
+                vargs.expectArgs("get method", {1, 2}, {0, 0});
+                auto key = vargs.args[0];
+                if (vargs.args.size() == 1) {
+                    return obj.contains(key) ? obj.at(key) : Value();
+                } else {
+                    return obj.contains(key) ? obj.at(key) : vargs.args[1];
+                }
+            } else if (obj.contains(method->get_name())) {
+                // A member stored on the object itself; it must be callable.
+                auto callable = obj.at(method->get_name());
+                if (!callable.is_callable()) {
+                    throw std::runtime_error("Property '" + method->get_name() + "' is not callable");
+                }
+                return callable.call(context, vargs);
+            }
+        } else if (obj.is_string()) {
+            auto str = obj.get<std::string>();
+            if (method->get_name() == "strip") {
+                vargs.expectArgs("strip method", {0, 0}, {0, 0});
+                return Value(strip(str));
+            } else if (method->get_name() == "capitalize") {
+                vargs.expectArgs("capitalize method", {0, 0}, {0, 0});
+                return Value(capitalize(str));
+            } else if (method->get_name() == "endswith") {
+                vargs.expectArgs("endswith method", {1, 1}, {0, 0});
+                auto suffix = vargs.args[0].get<std::string>();
+                return suffix.length() <= str.length() && std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());
+            } else if (method->get_name() == "title") {
+                vargs.expectArgs("title method", {0, 0}, {0, 0});
+                auto res = str;
+                for (size_t i = 0, n = res.size(); i < n; ++i) {
+                    // <cctype> functions take an int holding an unsigned char
+                    // value; a negative plain char (UTF-8 byte) would be UB.
+                    const auto current = static_cast<unsigned char>(res[i]);
+                    if (i == 0 || std::isspace(static_cast<unsigned char>(res[i - 1]))) res[i] = static_cast<char>(std::toupper(current));
+                    else res[i] = static_cast<char>(std::tolower(current));
+                }
+                return res;
+            }
+        }
+        throw std::runtime_error("Unknown method: " + method->get_name());
+    }
+};
+
+// Call expression `object(args)`. `object` and `args` are public because
+// FilterExpr reads them to inject the piped value as first argument.
+class CallExpr : public Expression {
+public:
+    std::shared_ptr<Expression> object;
+    ArgumentsExpression args;
+
+    CallExpr(const Location & location, std::shared_ptr<Expression> && obj, ArgumentsExpression && a)
+      : Expression(location), object(std::move(obj)), args(std::move(a)) {}
+
+    // Evaluates the callee first, then the arguments, then performs the call.
+    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
+        if (!object) throw std::runtime_error("CallExpr.object is null");
+
+        auto callee = object->evaluate(context);
+        if (!callee.is_callable()) {
+            throw std::runtime_error("Object is not callable: " + callee.dump(2));
+        }
+
+        auto call_args = args.evaluate(context);
+        return callee.call(context, call_args);
+    }
+};
+
+// Filter pipeline `value | f1 | f2(args)`. The first part produces the initial
+// value; every later part is called with the running value inserted as its
+// first positional argument.
+class FilterExpr : public Expression {
+    std::vector<std::shared_ptr<Expression>> parts;
+public:
+    FilterExpr(const Location & location, std::vector<std::shared_ptr<Expression>> && p)
+      : Expression(location), parts(std::move(p)) {}
+
+    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
+        Value piped;
+        for (size_t i = 0, n = parts.size(); i < n; ++i) {
+            const auto & part = parts[i];
+            if (!part) throw std::runtime_error("FilterExpr.part is null");
+
+            if (i == 0) {
+                piped = part->evaluate(context);
+                continue;
+            }
+
+            if (auto call = dynamic_cast<CallExpr*>(part.get())) {
+                // Filter written with explicit arguments: f(a, b) -> f(piped, a, b)
+                auto filter_fn = call->object->evaluate(context);
+                ArgumentsValue filter_args = call->args.evaluate(context);
+                filter_args.args.insert(filter_args.args.begin(), piped);
+                piped = filter_fn.call(context, filter_args);
+            } else {
+                // Bare filter name: called with the piped value only.
+                auto filter_fn = part->evaluate(context);
+                ArgumentsValue filter_args;
+                filter_args.args.push_back(piped);
+                piped = filter_fn.call(context, filter_args);
+            }
+        }
+        return piped;
+    }
+
+    // Inserts `e` as the new first part of the pipeline.
+    void prepend(std::shared_ptr<Expression> && e) {
+        parts.insert(parts.begin(), std::move(e));
+    }
+};
+
+class Parser {
+private:
+ using CharIterator = std::string::const_iterator;
+
+ std::shared_ptr<std::string> template_str;
+ CharIterator start, end, it;
+ Options options;
+
+ // Binds the parser to `template_str`; `start`, `it` and `end` are iterators
+ // into that string. Throws if the pointer is null.
+ Parser(const std::shared_ptr<std::string>& template_str, const Options & options) : template_str(template_str), options(options) {
+     if (!template_str) throw std::runtime_error("Template string is null");
+     const auto & text = *this->template_str;
+     start = text.begin();
+     it = start;
+     end = text.end();
+ }
+
+ // Advances `it` past whitespace when `space_handling` is Strip; otherwise
+ // leaves the cursor alone. Always returns true.
+ bool consumeSpaces(SpaceHandling space_handling = SpaceHandling::Strip) {
+     if (space_handling == SpaceHandling::Strip) {
+         // std::isspace requires an unsigned-char value: template text may hold
+         // UTF-8 bytes, which are negative as plain char on most platforms.
+         while (it != end && std::isspace(static_cast<unsigned char>(*it))) ++it;
+     }
+     return true;
+ }
+
+ // Parses a single- or double-quoted string literal at the cursor (after
+ // skipping whitespace) and returns its unescaped content. Returns nullptr if
+ // the cursor is not at a quote or the literal is never closed; in the latter
+ // case the cursor is left at the end of the input.
+ std::unique_ptr<std::string> parseString() {
+     // Expects `it` to be on the opening `quote` character.
+     auto readQuoted = [&](char quote) -> std::unique_ptr<std::string> {
+         if (it == end || *it != quote) return nullptr;
+         std::string text;
+         ++it;  // opening quote
+         while (it != end) {
+             const char c = *it;
+             if (c == '\\') {
+                 if (++it == end) break;  // dangling backslash: unterminated
+                 switch (*it) {
+                     case 'n': text += '\n'; break;
+                     case 'r': text += '\r'; break;
+                     case 't': text += '\t'; break;
+                     case 'b': text += '\b'; break;
+                     case 'f': text += '\f'; break;
+                     // Backslash, the quote character and anything else stand for themselves.
+                     default: text += *it; break;
+                 }
+                 ++it;
+             } else if (c == quote) {
+                 ++it;  // closing quote
+                 return std::make_unique<std::string>(std::move(text));
+             } else {
+                 text += c;
+                 ++it;
+             }
+         }
+         return nullptr;
+     };
+
+     consumeSpaces();
+     if (it == end) return nullptr;
+     const char opener = *it;
+     if (opener != '"' && opener != '\'') return nullptr;
+     return readQuoted(opener);
+ }
+
+ // Scans a numeric literal at `it` (optional sign, digits, at most one decimal
+ // point, at most one exponent with optional sign) and parses it as JSON.
+ // Returns a null json and restores `it` when no numeric character is found.
+ // Throws std::runtime_error on repeated '.'/exponent markers or when the
+ // scanned text is not a valid JSON number.
+ json parseNumber(CharIterator& it, const CharIterator& end) {
+     auto before = it;
+     consumeSpaces();
+     auto start = it;
+     bool hasDecimal = false;
+     bool hasExponent = false;
+
+     if (it != end && (*it == '-' || *it == '+')) ++it;
+
+     while (it != end) {
+         // Cast: std::isdigit is undefined for negative plain-char values.
+         if (std::isdigit(static_cast<unsigned char>(*it))) {
+             ++it;
+         } else if (*it == '.') {
+             if (hasDecimal) throw std::runtime_error("Multiple decimal points");
+             hasDecimal = true;
+             ++it;
+         } else if (it != start && (*it == 'e' || *it == 'E')) {
+             if (hasExponent) throw std::runtime_error("Multiple exponents");
+             hasExponent = true;
+             ++it;
+             // Accept a signed exponent (1e-5, 2E+3); previously the scan
+             // stopped at the sign and "1e" always failed to parse.
+             if (it != end && (*it == '-' || *it == '+')) ++it;
+         } else {
+             break;
+         }
+     }
+     if (start == it) {
+         it = before;
+         return json(); // No valid characters found
+     }
+
+     std::string str(start, it);
+     try {
+         return json::parse(str);
+     } catch (const json::parse_error & e) {
+         throw std::runtime_error("Failed to parse number: '" + str + "' (" + std::string(e.what()) + ")");
+     }
+ }
+
+ /**
+  * Parses a constant: a quoted string, true/True, false/False, None, or a
+  * number. Returns nullptr at end of input; returns nullptr with the cursor
+  * restored to where it was on entry when no constant is recognised.
+  */
+ std::shared_ptr<Value> parseConstant() {
+     const auto rollback = it;
+     consumeSpaces();
+     if (it == end) return nullptr;
+
+     const bool at_quote = (*it == '"' || *it == '\'');
+     if (at_quote) {
+         if (auto text = parseString()) {
+             return std::make_shared<Value>(*text);
+         }
+     }
+
+     static std::regex prim_tok(R"(true\b|True\b|false\b|False\b|None\b)");
+     const auto keyword = consumeToken(prim_tok);
+     if (!keyword.empty()) {
+         if (keyword == "true" || keyword == "True") return std::make_shared<Value>(true);
+         if (keyword == "false" || keyword == "False") return std::make_shared<Value>(false);
+         if (keyword == "None") return std::make_shared<Value>(nullptr);
+         throw std::runtime_error("Unknown constant token: " + keyword);
+     }
+
+     const auto number = parseNumber(it, end);
+     if (!number.is_null()) {
+         return std::make_shared<Value>(number);
+     }
+
+     it = rollback;
+     return nullptr;
+ }
+
+ // Parse error that remembers the cursor position at which it was raised.
+ class expression_parsing_error : public std::runtime_error {
+     const CharIterator it;
+ public:
+     expression_parsing_error(const std::string & message, const CharIterator & it)
+         : std::runtime_error(message), it(it) {}
+
+     // Offset of the error position relative to `begin`.
+     size_t get_pos(const CharIterator & begin) const {
+         const auto offset = std::distance(begin, it);
+         return offset;
+     }
+ };
+
+ // Returns true if the input at the cursor starts with any of `symbols`.
+ // The cursor is not moved and whitespace is not skipped.
+ bool peekSymbols(const std::vector<std::string> & symbols) const {
+     const auto remaining = std::distance(it, end);
+     for (const auto & candidate : symbols) {
+         if (remaining < (int64_t) candidate.size()) continue;  // not enough input left
+         if (std::equal(candidate.begin(), candidate.end(), it)) return true;
+     }
+     return false;
+ }
+
+ // Matches `regex` exactly at the cursor (after optional whitespace skipping).
+ // On success advances past the match and returns the full match followed by
+ // each capture group; on failure restores the cursor and returns an empty vector.
+ std::vector<std::string> consumeTokenGroups(const std::regex & regex, SpaceHandling space_handling = SpaceHandling::Strip) {
+     auto start = it;
+     consumeSpaces(space_handling);
+     std::smatch match;
+     // match_continuous anchors the match at `it`, so a failed attempt no
+     // longer scans the rest of the template looking for a later match that
+     // would be discarded anyway.
+     if (std::regex_search(it, end, match, regex, std::regex_constants::match_continuous)) {
+         it += match[0].length();
+         std::vector<std::string> ret;
+         ret.reserve(match.size());
+         for (size_t i = 0, n = match.size(); i < n; ++i) {
+             ret.push_back(match[i].str());
+         }
+         return ret;
+     }
+     it = start;
+     return {};
+ }
+ // Matches `regex` exactly at the cursor (after optional whitespace skipping).
+ // On success advances past the match and returns the matched text; on
+ // failure restores the cursor and returns "".
+ std::string consumeToken(const std::regex & regex, SpaceHandling space_handling = SpaceHandling::Strip) {
+     auto start = it;
+     consumeSpaces(space_handling);
+     std::smatch match;
+     // match_continuous anchors the match at `it` instead of searching the
+     // whole remaining input and then rejecting matches that start later.
+     if (std::regex_search(it, end, match, regex, std::regex_constants::match_continuous)) {
+         it += match[0].length();
+         return match[0].str();
+     }
+     it = start;
+     return "";
+ }
+
+ // Consumes the literal `token` at the cursor (after optional whitespace
+ // skipping). Returns the token on success; otherwise restores the cursor and
+ // returns "".
+ std::string consumeToken(const std::string & token, SpaceHandling space_handling = SpaceHandling::Strip) {
+     const auto rollback = it;
+     consumeSpaces(space_handling);
+     const bool fits = std::distance(it, end) >= (int64_t) token.size();
+     if (fits && std::equal(token.begin(), token.end(), it)) {
+         it += token.size();
+         return token;
+     }
+     it = rollback;
+     return "";
+ }
+
+ // Parses a full expression. When `allow_if_expr` is true and an `if` keyword
+ // follows, the already-parsed expression becomes the "then" part of an
+ // inline conditional (`value if condition [else other]`).
+ std::shared_ptr<Expression> parseExpression(bool allow_if_expr = true) {
+     auto value_expr = parseLogicalOr();
+     if (it == end || !allow_if_expr) {
+         return value_expr;
+     }
+
+     static std::regex if_tok(R"(if\b)");
+     if (!consumeToken(if_tok).empty()) {
+         auto location = get_location();
+         auto [condition, else_expr] = parseIfExpression();
+         return std::make_shared<IfExpr>(location, std::move(condition), std::move(value_expr), std::move(else_expr));
+     }
+     return value_expr;
+ }
+
+ // Current cursor position as a Location: the template text plus the offset
+ // of `it` from the start of the template.
+ Location get_location() const {
+     const auto offset = (size_t) std::distance(start, it);
+     return {template_str, offset};
+ }
+
+ // Parses the tail of an inline conditional, i.e. what follows the `if`
+ // keyword: the condition and an optional `else` expression. The second
+ // element of the returned pair is null when there is no `else`.
+ std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>> parseIfExpression() {
+     auto test_expr = parseLogicalOr();
+     if (!test_expr) throw std::runtime_error("Expected condition expression");
+
+     static std::regex else_tok(R"(else\b)");
+     std::shared_ptr<Expression> fallback_expr;
+     const bool has_else = !consumeToken(else_tok).empty();
+     if (has_else) {
+         fallback_expr = parseExpression();
+         if (!fallback_expr) throw std::runtime_error("Expected 'else' expression");
+     }
+     return {std::move(test_expr), std::move(fallback_expr)};
+ }
+
+ // Parses a left-associative chain of `or` operators over logical-and operands.
+ std::shared_ptr<Expression> parseLogicalOr() {
+     auto lhs = parseLogicalAnd();
+     if (!lhs) throw std::runtime_error("Expected left side of 'logical or' expression");
+
+     static std::regex or_tok(R"(or\b)");
+     auto location = get_location();
+     for (;;) {
+         if (consumeToken(or_tok).empty()) break;
+         auto rhs = parseLogicalAnd();
+         if (!rhs) throw std::runtime_error("Expected right side of 'or' expression");
+         lhs = std::make_shared<BinaryOpExpr>(location, std::move(lhs), std::move(rhs), BinaryOpExpr::Op::Or);
+     }
+     return lhs;
+ }
+
+ // Parses zero or more prefix `not` keywords followed by a comparison.
+ std::shared_ptr<Expression> parseLogicalNot() {
+     static std::regex not_tok(R"(not\b)");
+     auto location = get_location();
+
+     if (consumeToken(not_tok).empty()) {
+         return parseLogicalCompare();
+     }
+     // `not` binds to whatever follows, which may itself start with `not`.
+     auto operand = parseLogicalNot();
+     if (!operand) throw std::runtime_error("Expected expression after 'not' keyword");
+     return std::make_shared<UnaryOpExpr>(location, std::move(operand), UnaryOpExpr::Op::LogicalNot);
+ }
+
+ // Parses a left-associative chain of `and` operators over logical-not operands.
+ std::shared_ptr<Expression> parseLogicalAnd() {
+     auto lhs = parseLogicalNot();
+     if (!lhs) throw std::runtime_error("Expected left side of 'logical and' expression");
+
+     static std::regex and_tok(R"(and\b)");
+     auto location = get_location();
+     for (;;) {
+         if (consumeToken(and_tok).empty()) break;
+         auto rhs = parseLogicalNot();
+         if (!rhs) throw std::runtime_error("Expected right side of 'and' expression");
+         lhs = std::make_shared<BinaryOpExpr>(location, std::move(lhs), std::move(rhs), BinaryOpExpr::Op::And);
+     }
+     return lhs;
+ }
+
+ // Parses comparison operators (==, !=, <, >, <=, >=, in, not in) as a
+ // left-associative chain over string-concat operands. An `is` / `is not`
+ // test takes an identifier on its right and ends the chain immediately.
+ std::shared_ptr<Expression> parseLogicalCompare() {
+     auto left = parseStringConcat();
+     if (!left) throw std::runtime_error("Expected left side of 'logical compare' expression");
+
+     static std::regex compare_tok(R"(==|!=|<=?|>=?|in\b|is\b|not\s+in\b)");
+     static std::regex not_tok(R"(not\b)");
+     std::string op_str;
+     while (!(op_str = consumeToken(compare_tok)).empty()) {
+         if (op_str == "is") {
+             auto negated = !consumeToken(not_tok).empty();
+
+             auto identifier = parseIdentifier();
+             if (!identifier) throw std::runtime_error("Expected identifier after 'is' keyword");
+
+             return std::make_shared<BinaryOpExpr>(
+                 left->location,
+                 std::move(left), std::move(identifier),
+                 negated ? BinaryOpExpr::Op::IsNot : BinaryOpExpr::Op::Is);
+         }
+         auto right = parseStringConcat();
+         if (!right) throw std::runtime_error("Expected right side of 'logical compare' expression");
+         BinaryOpExpr::Op op;
+         if (op_str == "==") op = BinaryOpExpr::Op::Eq;
+         else if (op_str == "!=") op = BinaryOpExpr::Op::Ne;
+         else if (op_str == "<") op = BinaryOpExpr::Op::Lt;
+         else if (op_str == ">") op = BinaryOpExpr::Op::Gt;
+         else if (op_str == "<=") op = BinaryOpExpr::Op::Le;
+         else if (op_str == ">=") op = BinaryOpExpr::Op::Ge;
+         else if (op_str == "in") op = BinaryOpExpr::Op::In;
+         // "not in" may contain any amount of whitespace, so only the prefix is compared.
+         else if (op_str.substr(0, 3) == "not") op = BinaryOpExpr::Op::NotIn;
+         else throw std::runtime_error("Unknown comparison operator: " + op_str);
+         left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), op);
+     }
+     return left;
+ }
+
+ // Parses a parenthesised, comma-separated parameter list. Each entry becomes
+ // a (name, expression) pair:
+ //   `name`          -> (name, nullptr)
+ //   `name = expr`   -> (name, expr)
+ //   any other expr  -> ("", expr)
+ // Throws std::runtime_error on a missing parenthesis or expression.
+ Expression::Parameters parseParameters() {
+     consumeSpaces();
+     if (consumeToken("(").empty()) throw std::runtime_error("Expected opening parenthesis in param list");
+
+     Expression::Parameters result;
+
+     for (;;) {
+         if (it == end) break;
+         if (!consumeToken(")").empty()) {
+             return result;
+         }
+         auto expr = parseExpression();
+         if (!expr) throw std::runtime_error("Expected expression in call args");
+
+         std::string param_name;
+         std::shared_ptr<Expression> param_expr;
+         if (auto ident = dynamic_cast<VariableExpr*>(expr.get())) {
+             param_name = ident->get_name();
+             if (!consumeToken("=").empty()) {
+                 param_expr = parseExpression();
+                 if (!param_expr) throw std::runtime_error("Expected expression in for named arg");
+             }
+         } else {
+             param_expr = std::move(expr);
+         }
+         result.emplace_back(std::move(param_name), std::move(param_expr));
+
+         // After an entry: either a comma (more entries) or the closing parenthesis.
+         if (!consumeToken(",").empty()) continue;
+         if (consumeToken(")").empty()) {
+             throw std::runtime_error("Expected closing parenthesis in call args");
+         }
+         return result;
+     }
+     throw std::runtime_error("Expected closing parenthesis in call args");
+ }
+
+ // Parses a parenthesised call-argument list into positional and keyword arguments.
+ // `name = expr` is recorded as a keyword argument; every other expression is positional.
+ // Throws std::runtime_error on malformed input.
+ ArgumentsExpression parseCallArgs() {
+     consumeSpaces();
+     if (consumeToken("(").empty()) throw std::runtime_error("Expected opening parenthesis in call args");
+
+     ArgumentsExpression call_args;
+
+     while (it != end) {
+         // A `)` here ends the list: either it is empty or the previous entry ended with a comma.
+         if (!consumeToken(")").empty()) {
+             return call_args;
+         }
+         auto parsed = parseExpression();
+         if (!parsed) throw std::runtime_error("Expected expression in call args");
+
+         // Only look for `=` when the expression is a plain identifier.
+         auto * var = dynamic_cast<VariableExpr*>(parsed.get());
+         if (var != nullptr && !consumeToken("=").empty()) {
+             auto kw_value = parseExpression();
+             if (!kw_value) throw std::runtime_error("Expected expression in for named arg");
+             call_args.kwargs.emplace_back(var->get_name(), std::move(kw_value));
+         } else {
+             call_args.args.emplace_back(std::move(parsed));
+         }
+
+         if (!consumeToken(",").empty()) {
+             continue;
+         }
+         if (consumeToken(")").empty()) {
+             throw std::runtime_error("Expected closing parenthesis in call args");
+         }
+         return call_args;
+     }
+     throw std::runtime_error("Expected closing parenthesis in call args");
+ }
+
+ // Consumes an identifier at the current position and wraps it in a VariableExpr.
+ // The words excluded by the regex's negative lookahead (not, is, and, or, del) are not accepted.
+ // Returns nullptr when nothing was consumed.
+ std::shared_ptr<VariableExpr> parseIdentifier() {
+     static std::regex ident_regex(R"((?!(?:not|is|and|or|del)\b)[a-zA-Z_]\w*)");
+     // Record the location before consuming so it refers to the start of the token.
+     auto start_location = get_location();
+     auto name = consumeToken(ident_regex);
+     if (!name.empty()) {
+         return std::make_shared<VariableExpr>(start_location, name);
+     }
+     return nullptr;
+ }
+
+ // Parses `lhs ~ rhs` string concatenation. The `~` operator is only recognised when it is not
+ // immediately followed by `}`. Without the operator, the left operand is returned unchanged.
+ std::shared_ptr<Expression> parseStringConcat() {
+     auto lhs = parseMathPow();
+     if (!lhs) throw std::runtime_error("Expected left side of 'string concat' expression");
+
+     static std::regex concat_tok(R"(~(?!\}))");
+     if (consumeToken(concat_tok).empty()) {
+         return lhs;
+     }
+
+     auto rhs = parseLogicalAnd();
+     if (!rhs) throw std::runtime_error("Expected right side of 'string concat' expression");
+     return std::make_shared<BinaryOpExpr>(get_location(), std::move(lhs), std::move(rhs), BinaryOpExpr::Op::StrConcat);
+ }
+
+ // Parses a chain of `**` operators, folding each new operand into a left-nested BinaryOpExpr.
+ std::shared_ptr<Expression> parseMathPow() {
+     auto base = parseMathPlusMinus();
+     if (!base) throw std::runtime_error("Expected left side of 'math pow' expression");
+
+     for (;;) {
+         if (consumeToken("**").empty()) {
+             break;
+         }
+         auto exponent = parseMathPlusMinus();
+         if (!exponent) throw std::runtime_error("Expected right side of 'math pow' expression");
+         base = std::make_shared<BinaryOpExpr>(get_location(), std::move(base), std::move(exponent), BinaryOpExpr::Op::MulMul);
+     }
+     return base;
+ }
+
+ // Parses a left-associative chain of binary `+` / `-` operators.
+ // A `-` that is followed by `}}`, `%}` or `#}` is not treated as an operator (see the lookahead).
+ std::shared_ptr<Expression> parseMathPlusMinus() {
+     static std::regex plus_minus_tok(R"(\+|-(?![}%#]\}))");
+
+     auto acc = parseMathMulDiv();
+     if (!acc) throw std::runtime_error("Expected left side of 'math plus/minus' expression");
+
+     for (std::string symbol; !(symbol = consumeToken(plus_minus_tok)).empty();) {
+         auto operand = parseMathMulDiv();
+         if (!operand) throw std::runtime_error("Expected right side of 'math plus/minus' expression");
+
+         BinaryOpExpr::Op op = BinaryOpExpr::Op::Sub;
+         if (symbol == "+") {
+             op = BinaryOpExpr::Op::Add;
+         }
+         acc = std::make_shared<BinaryOpExpr>(get_location(), std::move(acc), std::move(operand), op);
+     }
+     return acc;
+ }
+
+ // Parses a left-associative chain of `*`, `**`, `/`, `//` and `%` operators, followed by an
+ // optional `|` filter pipeline whose first element is the value parsed so far.
+ std::shared_ptr<Expression> parseMathMulDiv() {
+     auto lhs = parseMathUnaryPlusMinus();
+     if (!lhs) throw std::runtime_error("Expected left side of 'math mul/div' expression");
+
+     static std::regex mul_div_tok(R"(\*\*?|//?|%(?!\}))");
+     for (std::string symbol; !(symbol = consumeToken(mul_div_tok)).empty();) {
+         auto rhs = parseMathUnaryPlusMinus();
+         if (!rhs) throw std::runtime_error("Expected right side of 'math mul/div' expression");
+
+         // Anything the regex matched that is not listed explicitly is the modulo operator.
+         BinaryOpExpr::Op op = BinaryOpExpr::Op::Mod;
+         if (symbol == "*") {
+             op = BinaryOpExpr::Op::Mul;
+         } else if (symbol == "**") {
+             op = BinaryOpExpr::Op::MulMul;
+         } else if (symbol == "/") {
+             op = BinaryOpExpr::Op::Div;
+         } else if (symbol == "//") {
+             op = BinaryOpExpr::Op::DivDiv;
+         }
+         lhs = std::make_shared<BinaryOpExpr>(get_location(), std::move(lhs), std::move(rhs), op);
+     }
+
+     if (consumeToken("|").empty()) {
+         return lhs;
+     }
+
+     // The right side is parsed recursively; if it is already a filter chain, extend it in place.
+     auto piped = parseMathMulDiv();
+     if (auto * chain = dynamic_cast<FilterExpr*>(piped.get())) {
+         chain->prepend(std::move(lhs));
+         return piped;
+     }
+     std::vector<std::shared_ptr<Expression>> parts;
+     parts.emplace_back(std::move(lhs));
+     parts.emplace_back(std::move(piped));
+     return std::make_shared<FilterExpr>(get_location(), std::move(parts));
+ }
+
+ // Builds a CallExpr that invokes the variable called `name` with `args`, located at the current position.
+ std::shared_ptr<Expression> call_func(const std::string & name, ArgumentsExpression && args) const {
+     auto callee = std::make_shared<VariableExpr>(get_location(), name);
+     return std::make_shared<CallExpr>(get_location(), std::move(callee), std::move(args));
+ }
+
+ // Parses an optional unary `+` / `-` followed by an expansion expression.
+ // Without a sign the operand is returned as-is; otherwise it is wrapped in a UnaryOpExpr.
+ std::shared_ptr<Expression> parseMathUnaryPlusMinus() {
+     static std::regex unary_plus_minus_tok(R"(\+|-(?![}%#]\}))");
+     auto sign = consumeToken(unary_plus_minus_tok);
+     auto operand = parseExpansion();
+     if (!operand) throw std::runtime_error("Expected expr of 'unary plus/minus/expansion' expression");
+
+     if (sign.empty()) {
+         return operand;
+     }
+     UnaryOpExpr::Op op = UnaryOpExpr::Op::Minus;
+     if (sign == "+") {
+         op = UnaryOpExpr::Op::Plus;
+     }
+     return std::make_shared<UnaryOpExpr>(get_location(), std::move(operand), op);
+ }
+
+ // Parses an optional `*` / `**` prefix followed by a value expression.
+ // Without a prefix the value expression is returned directly (it may be null); with a prefix a
+ // missing operand is an error and the result is an Expansion / ExpansionDict UnaryOpExpr.
+ std::shared_ptr<Expression> parseExpansion() {
+     static std::regex expansion_tok(R"(\*\*?)");
+     auto stars = consumeToken(expansion_tok);
+     auto operand = parseValueExpression();
+     if (stars.empty()) {
+         return operand;
+     }
+     if (!operand) throw std::runtime_error("Expected expr of 'expansion' expression");
+
+     UnaryOpExpr::Op op = UnaryOpExpr::Op::ExpansionDict;
+     if (stars == "*") {
+         op = UnaryOpExpr::Op::Expansion;
+     }
+     return std::make_shared<UnaryOpExpr>(get_location(), std::move(operand), op);
+ }
+
+ std::shared_ptr<Expression> parseValueExpression() { // Parses one primary value, then any `[...]` / `.name` accessors, then an optional trailing call.
+ auto parseValue = [&]() -> std::shared_ptr<Expression> { // Tries each primary form in turn; throws if none matches.
+ auto location = get_location();
+ auto constant = parseConstant();
+ if (constant) return std::make_shared<LiteralExpr>(location, *constant);
+
+ static std::regex null_regex(R"(null\b)");
+ if (!consumeToken(null_regex).empty()) return std::make_shared<LiteralExpr>(location, Value()); // `null` becomes a default-constructed Value literal.
+
+ auto identifier = parseIdentifier();
+ if (identifier) return identifier;
+
+ auto braced = parseBracedExpressionOrArray();
+ if (braced) return braced;
+
+ auto array = parseArray();
+ if (array) return array;
+
+ auto dictionary = parseDictionary();
+ if (dictionary) return dictionary;
+
+ throw std::runtime_error("Expected value expression");
+ };
+
+ auto value = parseValue();
+
+ while (it != end && consumeSpaces() && peekSymbols({ "[", "." })) { // Accessor loop: continues while the next symbol is `[` or `.`.
+ if (!consumeToken("[").empty()) {
+ std::shared_ptr<Expression> index;
+ if (!consumeToken(":").empty()) { // `[:end]` form: slice with no start expression.
+ auto slice_end = parseExpression();
+ index = std::make_shared<SliceExpr>(slice_end->location, nullptr, std::move(slice_end));
+ } else {
+ auto slice_start = parseExpression();
+ if (!consumeToken(":").empty()) {
+ consumeSpaces();
+ if (peekSymbols({ "]" })) { // `[start:]` form: slice with no end expression.
+ index = std::make_shared<SliceExpr>(slice_start->location, std::move(slice_start), nullptr);
+ } else { // `[start:end]` form.
+ auto slice_end = parseExpression();
+ index = std::make_shared<SliceExpr>(slice_start->location, std::move(slice_start), std::move(slice_end));
+ }
+ } else { // No colon: the parsed expression is a plain index.
+ index = std::move(slice_start);
+ }
+ }
+ if (!index) throw std::runtime_error("Empty index in subscript");
+ if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript");
+
+ value = std::make_shared<SubscriptExpr>(value->location, std::move(value), std::move(index));
+ } else if (!consumeToken(".").empty()) {
+ auto identifier = parseIdentifier();
+ if (!identifier) throw std::runtime_error("Expected identifier in subscript");
+
+ consumeSpaces();
+ if (peekSymbols({ "(" })) { // `.name(...)` is a method call on the value parsed so far.
+ auto callParams = parseCallArgs();
+ value = std::make_shared<MethodCallExpr>(identifier->location, std::move(value), std::move(identifier), std::move(callParams));
+ } else { // `.name` is rewritten as a subscript with the name as a string-literal key.
+ auto key = std::make_shared<LiteralExpr>(identifier->location, Value(identifier->get_name()));
+ value = std::make_shared<SubscriptExpr>(identifier->location, std::move(value), std::move(key));
+ }
+ }
+ consumeSpaces();
+ }
+
+ if (peekSymbols({ "(" })) { // Checked once, after the accessor loop: wraps the value in a single CallExpr.
+ auto location = get_location();
+ auto callParams = parseCallArgs();
+ value = std::make_shared<CallExpr>(location, std::move(value), std::move(callParams));
+ }
+ return value;
+ }
+
+ // Parses `( expr )` or a tuple `( expr, expr, ... )`.
+ // Returns nullptr when the input does not start with `(`. A single parenthesised expression is
+ // returned without a wrapper; a comma-separated list is returned as an ArrayExpr.
+ std::shared_ptr<Expression> parseBracedExpressionOrArray() {
+     if (consumeToken("(").empty()) {
+         return nullptr;
+     }
+
+     auto first = parseExpression();
+     if (!first) throw std::runtime_error("Expected expression in braced expression");
+
+     if (!consumeToken(")").empty()) {
+         // Plain grouping: the parentheses carry no node of their own.
+         return first;
+     }
+
+     std::vector<std::shared_ptr<Expression>> items;
+     items.emplace_back(std::move(first));
+
+     while (it != end) {
+         if (consumeToken(",").empty()) throw std::runtime_error("Expected comma in tuple");
+         auto item = parseExpression();
+         if (!item) throw std::runtime_error("Expected expression in tuple");
+         items.push_back(std::move(item));
+
+         const bool closed = !consumeToken(")").empty();
+         if (closed) {
+             return std::make_shared<ArrayExpr>(get_location(), std::move(items));
+         }
+     }
+     throw std::runtime_error("Expected closing parenthesis");
+ }
+
+ // Parses an array literal `[a, b, ...]` (or the empty `[]`) into an ArrayExpr.
+ // Returns nullptr when the input does not start with `[`; throws on malformed input.
+ std::shared_ptr<Expression> parseArray() {
+     if (consumeToken("[").empty()) {
+         return nullptr;
+     }
+
+     std::vector<std::shared_ptr<Expression>> items;
+     if (!consumeToken("]").empty()) {
+         return std::make_shared<ArrayExpr>(get_location(), std::move(items));
+     }
+     auto head = parseExpression();
+     if (!head) throw std::runtime_error("Expected first expression in array");
+     items.push_back(std::move(head));
+
+     while (it != end) {
+         if (!consumeToken(",").empty()) {
+             auto item = parseExpression();
+             if (!item) throw std::runtime_error("Expected expression in array");
+             items.push_back(std::move(item));
+             continue;
+         }
+         if (consumeToken("]").empty()) {
+             throw std::runtime_error("Expected comma or closing bracket in array");
+         }
+         return std::make_shared<ArrayExpr>(get_location(), std::move(items));
+     }
+     throw std::runtime_error("Expected closing bracket");
+ }
+
+ // Parses a dictionary literal `{key: value, ...}` (or the empty `{}`) into a DictExpr.
+ // Returns nullptr when the input does not start with `{`.
+ // Throws std::runtime_error when a key, colon, value, separator or closing brace is missing.
+ std::shared_ptr<Expression> parseDictionary() {
+     if (consumeToken("{").empty()) return nullptr;
+
+     std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>>> elements;
+     if (!consumeToken("}").empty()) {
+         return std::make_shared<DictExpr>(get_location(), std::move(elements));
+     }
+
+     // Parses one `key : value` entry and appends it to `elements`.
+     auto parseKeyValuePair = [&]() {
+         auto key = parseExpression();
+         if (!key) throw std::runtime_error("Expected key in dictionary");
+         if (consumeToken(":").empty()) throw std::runtime_error("Expected colon between key & value in dictionary");
+         auto value = parseExpression();
+         if (!value) throw std::runtime_error("Expected value in dictionary");
+         // Construct the pair in place instead of building a temporary std::pair first.
+         elements.emplace_back(std::move(key), std::move(value));
+     };
+
+     parseKeyValuePair();
+
+     while (it != end) {
+         if (!consumeToken(",").empty()) {
+             parseKeyValuePair();
+         } else if (!consumeToken("}").empty()) {
+             return std::make_shared<DictExpr>(get_location(), std::move(elements));
+         } else {
+             throw std::runtime_error("Expected comma or closing brace in dictionary");
+         }
+     }
+     throw std::runtime_error("Expected closing brace");
+ }
+
+ // Maps the marker captured after a tag opener to a SpaceHandling value:
+ // "-" means Strip, anything else (including an empty string) means Keep.
+ SpaceHandling parsePreSpace(const std::string& s) const {
+     return s == "-" ? SpaceHandling::Strip : SpaceHandling::Keep;
+ }
+
+ // Maps the marker captured before a tag closer to a SpaceHandling value:
+ // "-" means Strip, anything else (including an empty string) means Keep.
+ SpaceHandling parsePostSpace(const std::string& s) const {
+     return s == "-" ? SpaceHandling::Strip : SpaceHandling::Keep;
+ }
+
+ using TemplateTokenVector = std::vector<std::unique_ptr<TemplateToken>>; // Token stream type returned by tokenize().
+ using TemplateTokenIterator = TemplateTokenVector::const_iterator; // Cursor type over that stream, as taken by parseTemplate().
+
+ // Consumes a comma-separated list of names (e.g. `a, b, c`) and returns them with
+ // surrounding whitespace stripped. Throws std::runtime_error when no name is present.
+ std::vector<std::string> parseVarNames() {
+     static std::regex varnames_regex(R"(((?:\w+)(?:\s*,\s*(?:\w+))*)\s*)");
+
+     std::vector<std::string> groups = consumeTokenGroups(varnames_regex);
+     if (groups.empty()) throw std::runtime_error("Expected variable names");
+
+     // groups[1] is the whole list; split it on commas and trim each piece.
+     std::vector<std::string> names;
+     std::istringstream list_stream(groups[1]);
+     for (std::string piece; std::getline(list_stream, piece, ',');) {
+         names.push_back(strip(piece));
+     }
+     return names;
+ }
+
+ std::runtime_error unexpected(const TemplateToken & token) const {
+ return std::runtime_error("Unexpected " + TemplateToken::typeToString(token.type)
+ + error_location_suffix(*template_str, token.location.pos));
+ }
+ std::runtime_error unterminated(const TemplateToken & token) const {
+ return std::runtime_error("Unterminated " + TemplateToken::typeToString(token.type)
+ + error_location_suffix(*template_str, token.location.pos));
+ }
+
+ TemplateTokenVector tokenize() { // Splits the template into comment, expression, block and text tokens; tag contents are parsed into expressions here.
+ static std::regex comment_tok(R"(\{#([-~]?)([\s\S]*?)([-~]?)#\})");
+ static std::regex expr_open_regex(R"(\{\{([-~])?)");
+ static std::regex block_open_regex(R"(^\{%([-~])?\s*)");
+ static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue)\b)");
+ static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)");
+ static std::regex expr_close_regex(R"(\s*([-~])?\}\})");
+ static std::regex block_close_regex(R"(\s*([-~])?%\})");
+
+ TemplateTokenVector tokens;
+ std::vector<std::string> group;
+ std::string text;
+ std::smatch match;
+
+ try { // Any failure is rethrown at the bottom with the current scan position appended.
+ while (it != end) {
+ auto location = get_location();
+
+ if (!(group = consumeTokenGroups(comment_tok, SpaceHandling::Keep)).empty()) { // Comment: {# ... #}
+ auto pre_space = parsePreSpace(group[1]);
+ auto content = group[2];
+ auto post_space = parsePostSpace(group[3]);
+ tokens.push_back(std::make_unique<CommentTemplateToken>(location, pre_space, post_space, content));
+ } else if (!(group = consumeTokenGroups(expr_open_regex, SpaceHandling::Keep)).empty()) { // Expression: {{ ... }}
+ auto pre_space = parsePreSpace(group[1]);
+ auto expr = parseExpression();
+
+ if ((group = consumeTokenGroups(expr_close_regex)).empty()) {
+ throw std::runtime_error("Expected closing expression tag");
+ }
+
+ auto post_space = parsePostSpace(group[1]);
+ tokens.push_back(std::make_unique<ExpressionTemplateToken>(location, pre_space, post_space, std::move(expr)));
+ } else if (!(group = consumeTokenGroups(block_open_regex, SpaceHandling::Keep)).empty()) { // Block: {% keyword ... %}
+ auto pre_space = parsePreSpace(group[1]);
+
+ std::string keyword;
+
+ auto parseBlockClose = [&]() -> SpaceHandling { // Consumes the closing %} and returns its post-space handling.
+ if ((group = consumeTokenGroups(block_close_regex)).empty()) throw std::runtime_error("Expected closing block tag");
+ return parsePostSpace(group[1]);
+ };
+
+ if ((keyword = consumeToken(block_keyword_tok)).empty()) throw std::runtime_error("Expected block keyword");
+
+ if (keyword == "if") {
+ auto condition = parseExpression();
+ if (!condition) throw std::runtime_error("Expected condition in if block");
+
+ auto post_space = parseBlockClose();
+ tokens.push_back(std::make_unique<IfTemplateToken>(location, pre_space, post_space, std::move(condition)));
+ } else if (keyword == "elif") {
+ auto condition = parseExpression();
+ if (!condition) throw std::runtime_error("Expected condition in elif block");
+
+ auto post_space = parseBlockClose();
+ tokens.push_back(std::make_unique<ElifTemplateToken>(location, pre_space, post_space, std::move(condition)));
+ } else if (keyword == "else") {
+ auto post_space = parseBlockClose();
+ tokens.push_back(std::make_unique<ElseTemplateToken>(location, pre_space, post_space));
+ } else if (keyword == "endif") {
+ auto post_space = parseBlockClose();
+ tokens.push_back(std::make_unique<EndIfTemplateToken>(location, pre_space, post_space));
+ } else if (keyword == "for") { // {% for names in iterable [if condition] [recursive] %}
+ static std::regex recursive_tok(R"(recursive\b)");
+ static std::regex if_tok(R"(if\b)");
+
+ auto varnames = parseVarNames();
+ static std::regex in_tok(R"(in\b)");
+ if (consumeToken(in_tok).empty()) throw std::runtime_error("Expected 'in' keyword in for block");
+ auto iterable = parseExpression(/* allow_if_expr = */ false); // NOTE(review): presumably leaves a following `if` for the loop filter below; confirm against parseExpression.
+ if (!iterable) throw std::runtime_error("Expected iterable in for block");
+
+ std::shared_ptr<Expression> condition;
+ if (!consumeToken(if_tok).empty()) {
+ condition = parseExpression();
+ }
+ auto recursive = !consumeToken(recursive_tok).empty();
+
+ auto post_space = parseBlockClose();
+ tokens.push_back(std::make_unique<ForTemplateToken>(location, pre_space, post_space, std::move(varnames), std::move(iterable), std::move(condition), recursive));
+ } else if (keyword == "endfor") {
+ auto post_space = parseBlockClose();
+ tokens.push_back(std::make_unique<EndForTemplateToken>(location, pre_space, post_space));
+ } else if (keyword == "generation") {
+ auto post_space = parseBlockClose();
+ tokens.push_back(std::make_unique<GenerationTemplateToken>(location, pre_space, post_space));
+ } else if (keyword == "endgeneration") {
+ auto post_space = parseBlockClose();
+ tokens.push_back(std::make_unique<EndGenerationTemplateToken>(location, pre_space, post_space));
+ } else if (keyword == "set") {
+ static std::regex namespaced_var_regex(R"((\w+)\s*\.\s*(\w+))");
+
+ std::string ns;
+ std::vector<std::string> var_names;
+ std::shared_ptr<Expression> value;
+ if (!(group = consumeTokenGroups(namespaced_var_regex)).empty()) { // {% set ns.attr = value %}: the `=` and value are mandatory.
+ ns = group[1];
+ var_names.push_back(group[2]);
+
+ if (consumeToken("=").empty()) throw std::runtime_error("Expected equals sign in set block");
+
+ value = parseExpression();
+ if (!value) throw std::runtime_error("Expected value in set block");
+ } else {
+ var_names = parseVarNames();
+
+ if (!consumeToken("=").empty()) { // Without `=`, value stays null; parseTemplate then reads the body up to endset as the value.
+ value = parseExpression();
+ if (!value) throw std::runtime_error("Expected value in set block");
+ }
+ }
+ auto post_space = parseBlockClose();
+ tokens.push_back(std::make_unique<SetTemplateToken>(location, pre_space, post_space, ns, var_names, std::move(value)));
+ } else if (keyword == "endset") {
+ auto post_space = parseBlockClose();
+ tokens.push_back(std::make_unique<EndSetTemplateToken>(location, pre_space, post_space));
+ } else if (keyword == "macro") {
+ auto macroname = parseIdentifier();
+ if (!macroname) throw std::runtime_error("Expected macro name in macro block");
+ auto params = parseParameters();
+
+ auto post_space = parseBlockClose();
+ tokens.push_back(std::make_unique<MacroTemplateToken>(location, pre_space, post_space, std::move(macroname), std::move(params)));
+ } else if (keyword == "endmacro") {
+ auto post_space = parseBlockClose();
+ tokens.push_back(std::make_unique<EndMacroTemplateToken>(location, pre_space, post_space));
+ } else if (keyword == "filter") {
+ auto filter = parseExpression();
+ if (!filter) throw std::runtime_error("Expected expression in filter block");
+
+ auto post_space = parseBlockClose();
+ tokens.push_back(std::make_unique<FilterTemplateToken>(location, pre_space, post_space, std::move(filter)));
+ } else if (keyword == "endfilter") {
+ auto post_space = parseBlockClose();
+ tokens.push_back(std::make_unique<EndFilterTemplateToken>(location, pre_space, post_space));
+ } else if (keyword == "break" || keyword == "continue") {
+ auto post_space = parseBlockClose();
+ tokens.push_back(std::make_unique<LoopControlTemplateToken>(location, pre_space, post_space, keyword == "break" ? LoopControlType::Break : LoopControlType::Continue));
+ } else { // Reached for keywords the regex accepts but no branch handles (block, endblock).
+ throw std::runtime_error("Unexpected block: " + keyword);
+ }
+ } else if (std::regex_search(it, end, match, non_text_open_regex)) { // Plain text up to the next tag opener.
+ if (!match.position()) { // An opener sits at the current position but no branch above consumed it.
+ if (match[0] != "{#")
+ throw std::runtime_error("Internal error: Expected a comment");
+ throw std::runtime_error("Missing end of comment tag");
+ }
+ auto text_end = it + match.position();
+ text = std::string(it, text_end);
+ it = text_end;
+ tokens.push_back(std::make_unique<TextTemplateToken>(location, SpaceHandling::Keep, SpaceHandling::Keep, text));
+ } else { // No further tag openers: the rest of the input is one text token.
+ text = std::string(it, end);
+ it = end;
+ tokens.push_back(std::make_unique<TextTemplateToken>(location, SpaceHandling::Keep, SpaceHandling::Keep, text));
+ }
+ }
+ return tokens;
+ } catch (const std::exception & e) {
+ throw std::runtime_error(e.what() + error_location_suffix(*template_str, std::distance(start, it))); // Rethrow as runtime_error with the scan offset appended.
+ }
+ }
+
+ std::shared_ptr<TemplateNode> parseTemplate( // Builds a node tree from the token stream, recursing for the body of each block construct.
+ const TemplateTokenIterator & begin,
+ TemplateTokenIterator & it, // In/out cursor: advanced past every token consumed here.
+ const TemplateTokenIterator & end,
+ bool fully = false) const { // fully: throw if any token is left once parsing stops.
+ std::vector<std::shared_ptr<TemplateNode>> children;
+ while (it != end) {
+ const auto start = it; // Kept so 'unterminated' errors can point at the opening token.
+ const auto & token = *(it++);
+ if (auto if_token = dynamic_cast<IfTemplateToken*>(token.get())) {
+ std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<TemplateNode>>> cascade; // (condition, body) pairs in source order.
+ cascade.emplace_back(std::move(if_token->condition), parseTemplate(begin, it, end));
+
+ while (it != end && (*it)->type == TemplateToken::Type::Elif) {
+ auto elif_token = dynamic_cast<ElifTemplateToken*>((*(it++)).get());
+ cascade.emplace_back(std::move(elif_token->condition), parseTemplate(begin, it, end));
+ }
+
+ if (it != end && (*it)->type == TemplateToken::Type::Else) {
+ cascade.emplace_back(nullptr, parseTemplate(begin, ++it, end)); // The else branch is stored with a null condition.
+ }
+ if (it == end || (*(it++))->type != TemplateToken::Type::EndIf) {
+ throw unterminated(**start);
+ }
+ children.emplace_back(std::make_shared<IfNode>(token->location, std::move(cascade)));
+ } else if (auto for_token = dynamic_cast<ForTemplateToken*>(token.get())) {
+ auto body = parseTemplate(begin, it, end);
+ auto else_body = std::shared_ptr<TemplateNode>();
+ if (it != end && (*it)->type == TemplateToken::Type::Else) { // Optional {% else %} body of the for loop.
+ else_body = parseTemplate(begin, ++it, end);
+ }
+ if (it == end || (*(it++))->type != TemplateToken::Type::EndFor) {
+ throw unterminated(**start);
+ }
+ children.emplace_back(std::make_shared<ForNode>(token->location, std::move(for_token->var_names), std::move(for_token->iterable), std::move(for_token->condition), std::move(body), for_token->recursive, std::move(else_body)));
+ } else if (dynamic_cast<GenerationTemplateToken*>(token.get())) {
+ auto body = parseTemplate(begin, it, end);
+ if (it == end || (*(it++))->type != TemplateToken::Type::EndGeneration) {
+ throw unterminated(**start);
+ }
+ // Treat as a no-op, as our scope is templates for inference, not training (`{% generation %}` wraps generated tokens for masking).
+ children.emplace_back(std::move(body));
+ } else if (auto text_token = dynamic_cast<TextTemplateToken*>(token.get())) {
+ SpaceHandling pre_space = (it - 1) != begin ? (*(it - 2))->post_space : SpaceHandling::Keep; // `it` was already advanced: (it - 1) is this token, (it - 2) the previous one.
+ SpaceHandling post_space = it != end ? (*it)->pre_space : SpaceHandling::Keep; // Strip marker on the tag that follows this text, if any.
+
+ auto text = text_token->text;
+ if (post_space == SpaceHandling::Strip) {
+ static std::regex trailing_space_regex(R"(\s+$)");
+ text = std::regex_replace(text, trailing_space_regex, "");
+ } else if (options.lstrip_blocks && it != end) { // lstrip_blocks: drop trailing spaces/tabs when they are alone on the line before the next tag.
+ auto i = text.size();
+ while (i > 0 && (text[i - 1] == ' ' || text[i - 1] == '\t')) i--;
+ if ((i == 0 && (it - 1) == begin) || (i > 0 && text[i - 1] == '\n')) {
+ text.resize(i);
+ }
+ }
+ if (pre_space == SpaceHandling::Strip) {
+ static std::regex leading_space_regex(R"(^\s+)");
+ text = std::regex_replace(text, leading_space_regex, "");
+ } else if (options.trim_blocks && (it - 1) != begin && !dynamic_cast<ExpressionTemplateToken*>((*(it - 2)).get())) { // trim_blocks: drop one leading newline after any non-expression token.
+ if (text.length() > 0 && text[0] == '\n') {
+ text.erase(0, 1);
+ }
+ }
+ if (it == end && !options.keep_trailing_newline) { // Last token: drop one trailing newline (and a preceding '\r').
+ auto i = text.size();
+ if (i > 0 && text[i - 1] == '\n') {
+ i--;
+ if (i > 0 && text[i - 1] == '\r') i--;
+ text.resize(i);
+ }
+ }
+ children.emplace_back(std::make_shared<TextNode>(token->location, text));
+ } else if (auto expr_token = dynamic_cast<ExpressionTemplateToken*>(token.get())) {
+ children.emplace_back(std::make_shared<ExpressionNode>(token->location, std::move(expr_token->expr)));
+ } else if (auto set_token = dynamic_cast<SetTemplateToken*>(token.get())) {
+ if (set_token->value) { // Inline form: the value expression was parsed by tokenize().
+ children.emplace_back(std::make_shared<SetNode>(token->location, set_token->ns, set_token->var_names, std::move(set_token->value)));
+ } else { // Block form: the body up to endset becomes the value.
+ auto value_template = parseTemplate(begin, it, end);
+ if (it == end || (*(it++))->type != TemplateToken::Type::EndSet) {
+ throw unterminated(**start);
+ }
+ if (!set_token->ns.empty()) throw std::runtime_error("Namespaced set not supported in set with template value");
+ if (set_token->var_names.size() != 1) throw std::runtime_error("Structural assignment not supported in set with template value");
+ auto & name = set_token->var_names[0];
+ children.emplace_back(std::make_shared<SetTemplateNode>(token->location, name, std::move(value_template)));
+ }
+ } else if (auto macro_token = dynamic_cast<MacroTemplateToken*>(token.get())) {
+ auto body = parseTemplate(begin, it, end);
+ if (it == end || (*(it++))->type != TemplateToken::Type::EndMacro) {
+ throw unterminated(**start);
+ }
+ children.emplace_back(std::make_shared<MacroNode>(token->location, std::move(macro_token->name), std::move(macro_token->params), std::move(body)));
+ } else if (auto filter_token = dynamic_cast<FilterTemplateToken*>(token.get())) {
+ auto body = parseTemplate(begin, it, end);
+ if (it == end || (*(it++))->type != TemplateToken::Type::EndFilter) {
+ throw unterminated(**start);
+ }
+ children.emplace_back(std::make_shared<FilterNode>(token->location, std::move(filter_token->filter), std::move(body)));
+ } else if (dynamic_cast<CommentTemplateToken*>(token.get())) {
+ // Ignore comments
+ } else if (auto ctrl_token = dynamic_cast<LoopControlTemplateToken*>(token.get())) {
+ children.emplace_back(std::make_shared<LoopControlNode>(token->location, ctrl_token->control_type));
+ } else if (dynamic_cast<EndForTemplateToken*>(token.get()) // Closing or branch token: leave it for the enclosing call to check.
+ || dynamic_cast<EndSetTemplateToken*>(token.get())
+ || dynamic_cast<EndMacroTemplateToken*>(token.get())
+ || dynamic_cast<EndFilterTemplateToken*>(token.get())
+ || dynamic_cast<EndIfTemplateToken*>(token.get())
+ || dynamic_cast<ElseTemplateToken*>(token.get())
+ || dynamic_cast<EndGenerationTemplateToken*>(token.get())
+ || dynamic_cast<ElifTemplateToken*>(token.get())) {
+ it--; // unconsume the token
+ break; // exit the loop
+ } else {
+ throw unexpected(**(it-1));
+ }
+ }
+ if (fully && it != end) { // With fully set, a leftover token means a stray closing/branch tag.
+ throw unexpected(**it);
+ }
+ if (children.empty()) { // Collapse: no children gives an empty TextNode, one child is returned as-is, several become a SequenceNode.
+ return std::make_shared<TextNode>(Location { template_str, 0 }, std::string());
+ } else if (children.size() == 1) {
+ return std::move(children[0]);
+ } else {
+ return std::make_shared<SequenceNode>(children[0]->location(), std::move(children));
+ }
+ }
+
+public:
+
+ // Parses `template_str` into a node tree: normalizes newlines, tokenizes, then builds the tree,
+ // requiring that every token is consumed.
+ static std::shared_ptr<TemplateNode> parse(const std::string& template_str, const Options & options) {
+     auto normalized = std::make_shared<std::string>(normalize_newlines(template_str));
+     Parser parser(std::move(normalized), options);
+
+     auto tokens = parser.tokenize();
+     const TemplateTokenIterator first = tokens.begin();
+     const TemplateTokenIterator last = tokens.end();
+     TemplateTokenIterator cursor = first;
+     return parser.parseTemplate(first, cursor, last, /* fully= */ true);
+ }
+};
+
+// Wraps `fn` as a callable Value that accepts positional and keyword arguments.
+//
+// fn_name: name used in error messages.
+// params:  declared parameter names, in positional order.
+// fn:      implementation; receives the context and an object Value keyed by parameter name.
+//
+// Positional arguments are bound to `params` in order; keyword arguments are bound by name and are
+// set after the positional ones. Parameters that the caller does not supply are simply absent from
+// the object passed to `fn` (no check is made for missing arguments).
+// Throws std::runtime_error when more positional arguments than parameters are given, or when a
+// keyword argument does not name a declared parameter.
+static Value simple_function(const std::string & fn_name, const std::vector<std::string> & params, const std::function<Value(const std::shared_ptr<Context> &, Value & args)> & fn) {
+    std::map<std::string, size_t> named_positions;
+    for (size_t i = 0, n = params.size(); i < n; i++) named_positions[params[i]] = i;
+
+    return Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) -> Value {
+        auto args_obj = Value::object();
+        for (size_t i = 0, n = args.args.size(); i < n; i++) {
+            auto & arg = args.args[i];
+            if (i < params.size()) {
+                args_obj.set(params[i], arg);
+            } else {
+                throw std::runtime_error("Too many positional params for " + fn_name);
+            }
+        }
+        for (auto & [name, value] : args.kwargs) {
+            // The map is only consulted to validate that `name` is a declared parameter.
+            if (named_positions.find(name) == named_positions.end()) {
+                throw std::runtime_error("Unknown argument " + name + " for function " + fn_name);
+            }
+            args_obj.set(name, value);
+        }
+        return fn(context, args_obj);
+    });
+}
+
+inline std::shared_ptr<Context> Context::builtins() {
+ auto globals = Value::object();
+
+ globals.set("raise_exception", simple_function("raise_exception", { "message" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
+ throw std::runtime_error(args.at("message").get<std::string>());
+ }));
+ globals.set("tojson", simple_function("tojson", { "value", "indent" }, [](const std::shared_ptr<Context> &, Value & args) {
+ return Value(args.at("value").dump(args.get<int64_t>("indent", -1), /* tojson= */ true));
+ }));
+ globals.set("items", simple_function("items", { "object" }, [](const std::shared_ptr<Context> &, Value & args) {
+ auto items = Value::array();
+ if (args.contains("object")) {
+ auto & obj = args.at("object");
+ if (obj.is_string()) {
+ auto json_obj = json::parse(obj.get<std::string>());
+ for (const auto & kv : json_obj.items()) {
+ items.push_back(Value::array({kv.key(), kv.value()}));
+ }
+ } else if (!obj.is_null()) {
+ for (auto & key : obj.keys()) {
+ items.push_back(Value::array({key, obj.at(key)}));
+ }
+ }
+ }
+ return items;
+ }));
+ globals.set("last", simple_function("last", { "items" }, [](const std::shared_ptr<Context> &, Value & args) {
+ auto items = args.at("items");
+ if (!items.is_array()) throw std::runtime_error("object is not a list");
+ if (items.size() == 0) return Value();
+ return items.at(items.size() - 1);
+ }));
+ globals.set("trim", simple_function("trim", { "text" }, [](const std::shared_ptr<Context> &, Value & args) {
+ auto & text = args.at("text");
+ return text.is_null() ? text : Value(strip(text.get<std::string>()));
+ }));
+ globals.set("lower", simple_function("lower", { "text" }, [](const std::shared_ptr<Context> &, Value & args) {
+ auto text = args.at("text");
+ if (text.is_null()) return text;
+ std::string res;
+ auto str = text.get<std::string>();
+ std::transform(str.begin(), str.end(), std::back_inserter(res), ::tolower);
+ return Value(res);
+ }));
+ globals.set("default", Value::callable([=](const std::shared_ptr<Context> &, ArgumentsValue & args) {
+ args.expectArgs("default", {2, 3}, {0, 1});
+ auto & value = args.args[0];
+ auto & default_value = args.args[1];
+ bool boolean = false;
+ if (args.args.size() == 3) {
+ boolean = args.args[2].get<bool>();
+ } else {
+ Value bv = args.get_named("boolean");
+ if (!bv.is_null()) {
+ boolean = bv.get<bool>();
+ }
+ }
+ return boolean ? (value.to_bool() ? value : default_value) : value.is_null() ? default_value : value;
+ }));
+ auto escape = simple_function("escape", { "text" }, [](const std::shared_ptr<Context> &, Value & args) {
+ return Value(html_escape(args.at("text").get<std::string>()));
+ });
+ globals.set("e", escape);
+ globals.set("escape", escape);
+ globals.set("joiner", simple_function("joiner", { "sep" }, [](const std::shared_ptr<Context> &, Value & args) {
+ auto sep = args.get<std::string>("sep", "");
+ auto first = std::make_shared<bool>(true);
+ return simple_function("", {}, [sep, first](const std::shared_ptr<Context> &, const Value &) -> Value {
+ if (*first) {
+ *first = false;
+ return "";
+ }
+ return sep;
+ });
+ return Value(html_escape(args.at("text").get<std::string>()));
+ }));
+ globals.set("count", simple_function("count", { "items" }, [](const std::shared_ptr<Context> &, Value & args) {
+ return Value((int64_t) args.at("items").size());
+ }));
+ globals.set("dictsort", simple_function("dictsort", { "value" }, [](const std::shared_ptr<Context> &, Value & args) {
+ if (args.size() != 1) throw std::runtime_error("dictsort expects exactly 1 argument (TODO: fix implementation)");
+ auto & value = args.at("value");
+ auto keys = value.keys();
+ std::sort(keys.begin(), keys.end());
+ auto res = Value::array();
+ for (auto & key : keys) {
+ res.push_back(Value::array({key, value.at(key)}));
+ }
+ return res;
+ }));
+ // `join(items, d="")`: concatenates the string form of each element with `d` in between.
+ // When `items` is not supplied, a callable taking `items` is returned instead.
+ globals.set("join", simple_function("join", { "items", "d" }, [](const std::shared_ptr<Context> &, Value & args) {
+     auto concat = [](Value & seq, const std::string & delim) {
+         if (!seq.is_array()) throw std::runtime_error("object is not iterable: " + seq.dump());
+         std::ostringstream buf;
+         const size_t count = seq.size();
+         for (size_t idx = 0; idx < count; ++idx) {
+             if (idx != 0) buf << delim;
+             buf << seq.at(idx).to_str();
+         }
+         return Value(buf.str());
+     };
+     auto delim = args.get<std::string>("d", "");
+     if (!args.contains("items")) {
+         // Deferred form: the separator is bound now, the sequence is provided on the later call.
+         return simple_function("", {"items"}, [delim, concat](const std::shared_ptr<Context> &, Value & inner_args) {
+             auto & seq = inner_args.at("items");
+             if (!seq.to_bool() || !seq.is_array()) throw std::runtime_error("join expects an array for items, got: " + seq.dump());
+             return concat(seq, delim);
+         });
+     }
+     auto & seq = args.at("items");
+     return concat(seq, delim);
+ }));
+ // `namespace(**kwargs)`: builds an object whose entries are the keyword arguments.
+ globals.set("namespace", Value::callable([=](const std::shared_ptr<Context> &, ArgumentsValue & args) {
+     auto result = Value::object();
+     args.expectArgs("namespace", {0, 0}, {0, (std::numeric_limits<size_t>::max)()});
+     for (auto & kwarg : args.kwargs) {
+         result.set(kwarg.first, kwarg.second);
+     }
+     return result;
+ }));
+ // Equality test, registered both as `equalto` and as `==`.
+ auto equalto = simple_function("equalto", { "expected", "actual" }, [](const std::shared_ptr<Context> &, Value & params) -> Value {
+     const bool same = params.at("actual") == params.at("expected");
+     return same;
+ });
+ globals.set("equalto", equalto);
+ globals.set("==", equalto);
+ // `length(items)`: element count as reported by Value::size().
+ globals.set("length", simple_function("length", { "items" }, [](const std::shared_ptr<Context> &, Value & params) -> Value {
+     const auto element_count = static_cast<int64_t>(params.at("items").size());
+     return element_count;
+ }));
+ // `safe(value)`: implemented as a plain to_str() conversion of the argument.
+ globals.set("safe", simple_function("safe", { "value" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
+ return args.at("value").to_str();
+ }));
+ // `string(value)`: same conversion as `safe` — returns the argument's to_str() form.
+ globals.set("string", simple_function("string", { "value" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
+ return args.at("value").to_str();
+ }));
+ // `int(value)`: returns the argument converted with to_int().
+ globals.set("int", simple_function("int", { "value" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
+ return args.at("value").to_int();
+ }));
+ // `list(items)`: returns the argument unchanged when it is an array, otherwise throws.
+ globals.set("list", simple_function("list", { "items" }, [](const std::shared_ptr<Context> &, Value & params) -> Value {
+     auto & seq = params.at("items");
+     if (seq.is_array()) {
+         return seq;
+     }
+     throw std::runtime_error("object is not iterable");
+ }));
+ // `unique(items)`: keeps the first occurrence of every element, preserving input order.
+ globals.set("unique", simple_function("unique", { "items" }, [](const std::shared_ptr<Context> &, Value & params) -> Value {
+     auto & seq = params.at("items");
+     if (!seq.is_array()) throw std::runtime_error("object is not iterable");
+     std::unordered_set<Value> already_seen;
+     auto deduped = Value::array();
+     const size_t total = seq.size();
+     for (size_t idx = 0; idx < total; ++idx) {
+         // insert() reports through `.second` whether the element was new.
+         if (already_seen.insert(seq.at(idx)).second) {
+             deduped.push_back(seq.at(idx));
+         }
+     }
+     return deduped;
+ }));
+ // Wraps `filter` into a one-argument callable: the incoming `value` becomes the first
+ // positional argument and the elements of `extra_args` are appended after it.
+ auto make_filter = [](const Value & filter, Value & extra_args) -> Value {
+     return simple_function("", { "value" }, [=](const std::shared_ptr<Context> & context, Value & args) {
+         ArgumentsValue call_args;
+         call_args.args.emplace_back(args.at("value"));
+         const size_t extra_count = extra_args.size();
+         for (size_t idx = 0; idx < extra_count; ++idx) {
+             call_args.args.emplace_back(extra_args.at(idx));
+         }
+         return filter.call(context, call_args);
+     });
+ };
+ // Builds the `select` (is_select == true) or `reject` (is_select == false) filter:
+ // an item is kept when the truthiness of the predicate result equals `is_select`.
+ auto select_or_reject = [make_filter](bool is_select) {
+     return Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
+         args.expectArgs(is_select ? "select" : "reject", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
+         auto & items = args.args[0];
+         if (items.is_null()) {
+             return Value::array();
+         }
+         if (!items.is_array()) throw std::runtime_error("object is not iterable: " + items.dump());
+
+         // The predicate is looked up in the context using the second positional argument.
+         auto predicate_fn = context->get(args.args[1]);
+         if (predicate_fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump());
+
+         // Remaining positional arguments are bound as extra arguments of the predicate.
+         auto bound_args = Value::array();
+         const size_t positional_count = args.args.size();
+         for (size_t idx = 2; idx < positional_count; ++idx) {
+             bound_args.push_back(args.args[idx]);
+         }
+         auto predicate = make_filter(predicate_fn, bound_args);
+
+         auto kept = Value::array();
+         const size_t total = items.size();
+         for (size_t idx = 0; idx < total; ++idx) {
+             auto & item = items.at(idx);
+             ArgumentsValue call_args;
+             call_args.args.emplace_back(item);
+             const bool verdict = predicate.call(context, call_args).to_bool();
+             if (verdict == is_select) {
+                 kept.push_back(item);
+             }
+         }
+         return kept;
+     });
+ };
+ globals.set("select", select_or_reject(/* is_select= */ true));
+ globals.set("reject", select_or_reject(/* is_select= */ false));
+ // `map` filter. Two call shapes are accepted:
+ //   1. one positional argument plus keyword(s): `attribute` (optionally with `default`) —
+ //      reads that attribute from every item, substituting `default` when the result is null;
+ //   2. no keywords and at least two positional arguments — the second one is resolved through
+ //      the context to a callable that is applied to every item, with any further positional
+ //      arguments passed along after the item.
+ // Any other combination throws "Invalid or unsupported arguments for map".
+ globals.set("map", Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
+ auto res = Value::array();
+ // NOTE(review): the second alternative only checks for `default` and a kwarg count of 2;
+ // it does not verify that the other keyword is `attribute` — confirm this is intended.
+ if (args.args.size() == 1 &&
+ ((args.has_named("attribute") && args.kwargs.size() == 1) || (args.has_named("default") && args.kwargs.size() == 2))) {
+ auto & items = args.args[0];
+ auto attr_name = args.get_named("attribute");
+ auto default_value = args.get_named("default");
+ for (size_t i = 0, n = items.size(); i < n; i++) {
+ auto & item = items.at(i);
+ auto attr = item.get(attr_name);
+ res.push_back(attr.is_null() ? default_value : attr);
+ }
+ } else if (args.kwargs.empty() && args.args.size() >= 2) {
+ auto fn = context->get(args.args[1]);
+ if (fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump());
+ // Slot 0 is a placeholder that is overwritten with the current item on every iteration.
+ ArgumentsValue filter_args { {Value()}, {} };
+ for (size_t i = 2, n = args.args.size(); i < n; i++) {
+ filter_args.args.emplace_back(args.args[i]);
+ }
+ for (size_t i = 0, n = args.args[0].size(); i < n; i++) {
+ auto & item = args.args[0].at(i);
+ filter_args.args[0] = item;
+ res.push_back(fn.call(context, filter_args));
+ }
+ } else {
+ throw std::runtime_error("Invalid or unsupported arguments for map");
+ }
+ return res;
+ }));
+ // `indent(text, indent=0, first=false)`: prefixes lines of `text` with `indent` spaces.
+ // The first line is only prefixed when `first` is true; a trailing newline in the input is kept.
+ globals.set("indent", simple_function("indent", { "text", "indent", "first" }, [](const std::shared_ptr<Context> &, Value & args) {
+     auto text = args.at("text").get<std::string>();
+     const auto indent_first_line = args.get<bool>("first", false);
+     std::string out;
+     const std::string pad(args.get<int64_t>("indent", 0), ' ');
+     std::istringstream reader(text);
+     std::string line;
+     for (size_t line_no = 0; std::getline(reader, line, '\n'); ++line_no) {
+         if (line_no > 0) out += "\n";
+         if (line_no > 0 || indent_first_line) out += pad;
+         out += line;
+     }
+     // getline drops the final newline, so restore it when the input ended with one.
+     if (!text.empty() && text.back() == '\n') out += "\n";
+     return out;
+ }));
+ // Builds the `selectattr` (is_select == true) or `rejectattr` (is_select == false) filter.
+ // Positional arguments: items, attribute name, then optionally a test name followed by
+ // extra positional arguments for that test. Keyword arguments are rejected by expectArgs.
+ auto select_or_reject_attr = [](bool is_select) {
+ return Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
+ args.expectArgs(is_select ? "selectattr" : "rejectattr", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
+ auto & items = args.args[0];
+ if (items.is_null())
+ return Value::array();
+ if (!items.is_array()) throw std::runtime_error("object is not iterable: " + items.dump());
+ auto attr_name = args.args[1].get<std::string>();
+
+ bool has_test = false;
+ Value test_fn;
+ // Slot 0 is a placeholder that receives the attribute value of the current item.
+ ArgumentsValue test_args {{Value()}, {}};
+ if (args.args.size() >= 3) {
+ has_test = true;
+ // The test callable is resolved through the context from the third positional argument.
+ test_fn = context->get(args.args[2]);
+ if (test_fn.is_null()) throw std::runtime_error("Undefined test: " + args.args[2].dump());
+ for (size_t i = 3, n = args.args.size(); i < n; i++) {
+ test_args.args.emplace_back(args.args[i]);
+ }
+ test_args.kwargs = args.kwargs;
+ }
+
+ auto res = Value::array();
+ for (size_t i = 0, n = items.size(); i < n; i++) {
+ auto & item = items.at(i);
+ auto attr = item.get(attr_name);
+ if (has_test) {
+ test_args.args[0] = attr;
+ // Keep the item when the truthiness of the test result matches is_select.
+ if (test_fn.call(context, test_args).to_bool() == (is_select ? true : false)) {
+ res.push_back(item);
+ }
+ } else {
+ // NOTE(review): without a test this appends the attribute value (not the item) for
+ // every item, for both selectattr and rejectattr — Jinja2 instead filters items by
+ // attribute truthiness; confirm whether this divergence is intended.
+ res.push_back(attr);
+ }
+ }
+ return res;
+ });
+ };
+ globals.set("selectattr", select_or_reject_attr(/* is_select= */ true));
+ globals.set("rejectattr", select_or_reject_attr(/* is_select= */ false));
+ // `range([start,] end[, step])`: returns an array of integers from `start` (default 0)
+ // towards `end` (exclusive) in increments of `step` (default 1).
+ // Arguments may be positional or named (`start`, `end`, `step`); a single positional
+ // argument is interpreted as `end`.
+ // Throws std::runtime_error on unknown/duplicate/missing arguments, on more than three
+ // positional arguments, and on a zero step.
+ globals.set("range", Value::callable([=](const std::shared_ptr<Context> &, ArgumentsValue & args) {
+     std::vector<int64_t> startEndStep(3);
+     std::vector<bool> param_set(3);
+     // Guard the fixed-size slots: a fourth positional argument would index past the end.
+     if (args.args.size() > 3) {
+         throw std::runtime_error("Too many arguments for function range");
+     }
+     if (args.args.size() == 1) {
+         startEndStep[1] = args.args[0].get<int64_t>();
+         param_set[1] = true;
+     } else {
+         for (size_t i = 0; i < args.args.size(); i++) {
+             auto & arg = args.args[i];
+             auto v = arg.get<int64_t>();
+             startEndStep[i] = v;
+             param_set[i] = true;
+         }
+     }
+     for (auto & [name, value] : args.kwargs) {
+         size_t i;
+         if (name == "start") i = 0;
+         else if (name == "end") i = 1;
+         else if (name == "step") i = 2;
+         else throw std::runtime_error("Unknown argument " + name + " for function range");
+
+         if (param_set[i]) {
+             throw std::runtime_error("Duplicate argument " + name + " for function range");
+         }
+         startEndStep[i] = value.get<int64_t>();
+         param_set[i] = true;
+     }
+     if (!param_set[1]) {
+         throw std::runtime_error("Missing required argument 'end' for function range");
+     }
+     int64_t start = param_set[0] ? startEndStep[0] : 0;
+     int64_t end = startEndStep[1];
+     int64_t step = param_set[2] ? startEndStep[2] : 1;
+     // A zero step never reaches `end`; the descending loop below would spin forever.
+     if (step == 0) {
+         throw std::runtime_error("Argument 'step' for function range must not be zero");
+     }
+
+     auto res = Value::array();
+     if (step > 0) {
+         for (int64_t i = start; i < end; i += step) {
+             res.push_back(Value(i));
+         }
+     } else {
+         for (int64_t i = start; i > end; i += step) {
+             res.push_back(Value(i));
+         }
+     }
+     return res;
+ }));
+
+ return std::make_shared<Context>(std::move(globals));
+}
+
+// Creates a Context holding `values` (or a fresh empty object when `values` is null),
+// linked to the given parent context.
+inline std::shared_ptr<Context> Context::make(Value && values, const std::shared_ptr<Context> & parent) {
+    if (values.is_null()) {
+        return std::make_shared<Context>(Value::object(), parent);
+    }
+    return std::make_shared<Context>(std::move(values), parent);
+}
+
+} // namespace minja
#include "log.h"
#include "sampling.h"
#include "llama.h"
-#include "chat-template.hpp"
+#include "chat.h"
#include <cstdio>
#include <cstring>
}
const llama_vocab * vocab = llama_model_get_vocab(model);
- auto chat_templates = common_chat_templates_from_model(model, params.chat_template);
+ auto chat_templates = common_chat_templates_init(model, params.chat_template);
LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads);
}
// auto enable conversation mode if chat template is available
- const bool has_chat_template = chat_templates.has_explicit_template && chat_templates.template_default;
+ const bool has_chat_template = common_chat_templates_was_explicit(chat_templates.get());
if (params.conversation_mode == COMMON_CONVERSATION_MODE_AUTO) {
if (has_chat_template) {
LOG_INF("%s: chat template is available, enabling conversation mode (disable it with -no-cnv)\n", __func__);
// print chat template example in conversation mode
if (params.conversation_mode) {
if (params.enable_chat_template) {
- LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(*chat_templates.template_default, params.use_jinja).c_str());
+ LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(chat_templates.get(), params.use_jinja).c_str());
} else {
LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
}
std::vector<llama_token> embd_inp;
auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) {
- common_chat_msg new_msg{role, content, {}};
- auto formatted = common_chat_format_single(*chat_templates.template_default, chat_msgs, new_msg, role == "user", g_params->use_jinja);
- chat_msgs.push_back({role, content, {}});
+ common_chat_msg new_msg;
+ new_msg.role = role;
+ new_msg.content = content;
+ auto formatted = common_chat_format_single(chat_templates.get(), chat_msgs, new_msg, role == "user", g_params->use_jinja);
+ chat_msgs.push_back(new_msg);
LOG_DBG("formatted: '%s'\n", formatted.c_str());
return formatted;
};
// check for reverse prompt using special tokens
llama_token last_token = common_sampler_last(smpl);
- if (std::find(antiprompt_token.begin(), antiprompt_token.end(), last_token) != antiprompt_token.end()) {
- if (params.interactive) {
- is_interacting = true;
+ for (auto token : antiprompt_token) {
+ if (token == last_token) {
+ if (params.interactive) {
+ is_interacting = true;
+ }
+ is_antiprompt = true;
+ break;
}
- is_antiprompt = true;
}
if (is_antiprompt) {
#include <string>
#include <vector>
-#include "chat-template.hpp"
+#include "chat.h"
#include "common.h"
#include "json.hpp"
#include "linenoise.cpp/linenoise.h"
llama_model_ptr model;
llama_sampler_ptr sampler;
llama_context_ptr context;
- std::vector<llama_chat_message> messages;
+ std::vector<llama_chat_message> messages; // TODO: switch to common_chat_msg
std::list<std::string> msg_strs;
std::vector<char> fmtted;
}
// Function to apply the chat template and resize `formatted` if needed
-static int apply_chat_template(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) {
- if (use_jinja) {
- json messages = json::array();
- for (const auto & msg : llama_data.messages) {
- messages.push_back({
- {"role", msg.role},
- {"content", msg.content},
- });
- }
- try {
- minja::chat_template_inputs tmpl_inputs;
- tmpl_inputs.messages = messages;
- tmpl_inputs.add_generation_prompt = append;
-
- minja::chat_template_options tmpl_opts;
- tmpl_opts.use_bos_token = false;
- tmpl_opts.use_eos_token = false;
-
- auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
- llama_data.fmtted.resize(result.size() + 1);
- memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
- return result.size();
- } catch (const std::exception & e) {
- printe("failed to render the chat template: %s\n", e.what());
- return -1;
- }
- }
- int result = llama_chat_apply_template(
- tmpl.source().c_str(), llama_data.messages.data(), llama_data.messages.size(), append,
- append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
- if (append && result > static_cast<int>(llama_data.fmtted.size())) {
- llama_data.fmtted.resize(result);
- result = llama_chat_apply_template(tmpl.source().c_str(), llama_data.messages.data(),
- llama_data.messages.size(), append, llama_data.fmtted.data(),
- llama_data.fmtted.size());
- }
-
- return result;
+// Renders the conversation stored in `llama_data.messages` through the chat templates and
+// stores the NUL-terminated result in `llama_data.fmtted`.
+//   tmpls     - chat templates handle used for rendering
+//   append    - forwarded as add_generation_prompt
+//   use_jinja - forwarded to the template inputs
+// Returns the length of the rendered prompt (excluding the terminator), or -1 when rendering
+// throws, so that the caller's `new_len < 0` check keeps working.
+static int apply_chat_template(const struct common_chat_templates * tmpls, LlamaData & llama_data, const bool append, bool use_jinja) {
+    common_chat_templates_inputs inputs;
+    for (const auto & msg : llama_data.messages) {
+        common_chat_msg cmsg;
+        cmsg.role = msg.role;
+        cmsg.content = msg.content;
+        inputs.messages.push_back(cmsg);
+    }
+    inputs.add_generation_prompt = append;
+    inputs.use_jinja = use_jinja;
+
+    try {
+        auto chat_params = common_chat_templates_apply(tmpls, inputs);
+        // TODO: use other params for tool calls.
+        const auto & result = chat_params.prompt;
+        // +1 so the terminating NUL of c_str() is copied as well.
+        llama_data.fmtted.resize(result.size() + 1);
+        memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
+        return static_cast<int>(result.size());
+    } catch (const std::exception & e) {
+        // Template rendering can throw; report the failure instead of letting it escape chat_loop.
+        printe("failed to render the chat template: %s\n", e.what());
+        return -1;
+    }
}
// Function to tokenize the prompt
}
// Helper function to apply the chat template and handle errors
-static int apply_chat_template_with_error_handling(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) {
- const int new_len = apply_chat_template(tmpl, llama_data, append, use_jinja);
+static int apply_chat_template_with_error_handling(const common_chat_templates * tmpls, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) {
+ const int new_len = apply_chat_template(tmpls, llama_data, append, use_jinja);
if (new_len < 0) {
printe("failed to apply the chat template\n");
return -1;
static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_jinja) {
int prev_len = 0;
llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
- auto chat_templates = common_chat_templates_from_model(llama_data.model.get(), "");
- GGML_ASSERT(chat_templates.template_default);
+ auto chat_templates = common_chat_templates_init(llama_data.model.get(), "");
static const bool stdout_a_terminal = is_stdout_a_terminal();
while (true) {
// Get user input
add_message("user", user.empty() ? user_input : user, llama_data);
int new_len;
- if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, true, new_len, use_jinja) < 0) {
+ if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, use_jinja) < 0) {
return 1;
}
}
add_message("assistant", response, llama_data);
- if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, false, prev_len, use_jinja) < 0) {
+ if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, use_jinja) < 0) {
return 1;
}
}
}
// process "json_schema" and "grammar"
- if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
- throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
- }
if (data.contains("json_schema") && !data.contains("grammar")) {
try {
auto schema = json_value(data, "json_schema", json::object());
// Necessary similarity of prompt for slot selection
float slot_prompt_similarity = 0.0f;
- common_chat_templates chat_templates;
+ common_chat_templates_ptr chat_templates;
~server_context() {
// Clear any sampling context
llama_init_dft.context.reset();
}
- if (params_base.chat_template.empty() && !validate_builtin_chat_template(params.use_jinja)) {
+ chat_templates = common_chat_templates_init(model, params_base.chat_template);
+ try {
+ common_chat_format_example(chat_templates.get(), params.use_jinja);
+ } catch (const std::exception & e) {
SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
- chat_templates = common_chat_templates_from_model(model, "chatml");
- } else {
- chat_templates = common_chat_templates_from_model(model, params_base.chat_template);
+ chat_templates = common_chat_templates_init(model, "chatml");
}
- GGML_ASSERT(chat_templates.template_default.get() != nullptr);
return true;
}
- bool validate_builtin_chat_template(bool use_jinja) const {
- llama_chat_message chat[] = {{"user", "test"}};
-
- if (use_jinja) {
- auto templates = common_chat_templates_from_model(model, "");
- common_chat_inputs inputs;
- inputs.messages = json::array({{
- {"role", "user"},
- {"content", "test"},
- }});
- GGML_ASSERT(templates.template_default);
- try {
- common_chat_params_init(*templates.template_default, inputs);
- if (templates.template_tool_use) {
- common_chat_params_init(*templates.template_tool_use, inputs);
- }
- return true;
- } catch (const std::exception & e) {
- SRV_ERR("failed to apply template: %s\n", e.what());
- return false;
- }
- } else {
- const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);
- const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
- return chat_res > 0;
- }
- }
-
void init() {
const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;
{ "default_generation_settings", ctx_server.default_generation_settings_for_props },
{ "total_slots", ctx_server.params_base.n_parallel },
{ "model_path", ctx_server.params_base.model },
- { "chat_template", ctx_server.chat_templates.template_default->source() },
- { "bos_token", ctx_server.chat_templates.template_default->bos_token() },
- { "eos_token", ctx_server.chat_templates.template_default->eos_token() },
+ { "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) },
+ { "bos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)},
+ { "eos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)},
{ "build_info", build_info },
};
- if (ctx_server.params_base.use_jinja && ctx_server.chat_templates.template_tool_use) {
- data["chat_template_tool_use"] = ctx_server.chat_templates.template_tool_use->source();
+ if (ctx_server.params_base.use_jinja) {
+ if (auto tool_use_src = common_chat_templates_source(ctx_server.chat_templates.get(), "tool_use")) {
+ data["chat_template_tool_use"] = tool_use_src;
+ }
}
res_ok(res, data);
}
auto body = json::parse(req.body);
- json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates);
+ json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get());
return handle_completions_impl(
SERVER_TASK_TYPE_COMPLETION,
// same with handle_chat_completions, but without inference part
const auto handle_apply_template = [&ctx_server, ¶ms, &res_ok](const httplib::Request & req, httplib::Response & res) {
auto body = json::parse(req.body);
- json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates);
+ json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get());
res_ok(res, {{ "prompt", std::move(data.at("prompt")) }});
};
// print sample chat example to make it clear which template is used
LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
- ctx_server.chat_templates.template_default->source().c_str(),
- common_chat_format_example(*ctx_server.chat_templates.template_default, ctx_server.params_base.use_jinja).c_str());
+ common_chat_templates_source(ctx_server.chat_templates.get()),
+ common_chat_format_example(ctx_server.chat_templates.get(), ctx_server.params_base.use_jinja).c_str());
ctx_server.queue_tasks.on_new_task([&ctx_server](const server_task & task) {
ctx_server.process_single_task(task);
(None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None),
+ (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", False, None),
+ (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", True, None),
]
)
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template):
assert res.body["usage"]["completion_tokens"] == n_predicted
choice = res.body["choices"][0]
assert "assistant" == choice["message"]["role"]
- assert match_regex(re_content, choice["message"]["content"])
+ assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}'
assert choice["finish_reason"] == finish_reason
assert "error" in res.body
+@pytest.mark.parametrize("jinja,json_schema,n_predicted,re_content", [
+ (False, {"const": "42"}, 6, "\"42\""),
+ (True, {"const": "42"}, 6, "\"42\""),
+])
+def test_completion_with_json_schema(jinja: bool, json_schema: dict, n_predicted: int, re_content: str):
+ """Check that a top-level `json_schema` request field constrains /chat/completions output.
+
+ Runs once with the server's jinja flag off and once with it on; the response must be
+ HTTP 200 and the message content must match `re_content`.
+ """
+ global server
+ server.jinja = jinja
+ server.start()
+ res = server.make_request("POST", "/chat/completions", data={
+ "max_tokens": n_predicted,
+ "messages": [
+ {"role": "system", "content": "You are a coding assistant."},
+ {"role": "user", "content": "Write an example"},
+ ],
+ "json_schema": json_schema,
+ })
+ assert res.status_code == 200, f'Expected 200, got {res.status_code}'
+ choice = res.body["choices"][0]
+ assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}'
+
+
+@pytest.mark.parametrize("jinja,grammar,n_predicted,re_content", [
+ (False, 'root ::= "a"{5,5}', 6, "a{5,5}"),
+ (True, 'root ::= "a"{5,5}', 6, "a{5,5}"),
+])
+def test_completion_with_grammar(jinja: bool, grammar: str, n_predicted: int, re_content: str):
+ """Check that a top-level `grammar` request field constrains /chat/completions output.
+
+ Runs once with the server's jinja flag off and once with it on; the response must be
+ HTTP 200 and the message content must match `re_content`.
+ """
+ global server
+ server.jinja = jinja
+ server.start()
+ res = server.make_request("POST", "/chat/completions", data={
+ "max_tokens": n_predicted,
+ "messages": [
+ {"role": "user", "content": "Does not matter what I say, does it?"},
+ ],
+ "grammar": grammar,
+ })
+ assert res.status_code == 200, res.body
+ choice = res.body["choices"][0]
+ assert match_regex(re_content, choice["message"]["content"]), choice["message"]["content"]
+
+
@pytest.mark.parametrize("messages", [
None,
"string",
(None, 128, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
(None, 128, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
(None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
- ("^> 0.56$", 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
+ (None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
(None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
# TODO: fix these (wrong results, either didn't respect decimal instruction or got wrong value)
- ("^The y-coordinate [\\s\\S]*?\\*\\*0.5\\*\\*", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
- ("[\\s\\S]*?\\*\\*0\\.5\\*\\*", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
+ ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
+ # ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
])
def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
global server
{
"role": "tool",
"name": "calculate",
- "content": 0.55644242476,
+ "content": "0.55644242476",
"tool_call_id": "call_6789"
}
],
(128, None, "^The sum of 102 and 7 is 109.*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
(1024, 'deepseek', "To find the sum of.*", "I need to calculate the sum of 102 and 7.*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
- (1024, 'none', "<think>\n?I need[\\s\\S]*?</think>\n?To find.*", None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
+ (1024, 'none', "^I need[\\s\\S]*?</think>\n?To find.*", None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
(1024, 'deepseek', "To find the sum of.*", "First, I [\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
])
// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"
-#include "minja.hpp"
-#include "chat.hpp"
-#include "chat-template.hpp"
+#include "chat.h"
#include <random>
#include <sstream>
return embd_inp;
}
-// Format given chat. If tmpl is empty, we take the template from model metadata
-inline std::string format_chat(const common_chat_template & tmpl, const std::vector<json> & messages) {
- std::vector<common_chat_msg> chat;
-
- for (size_t i = 0; i < messages.size(); ++i) {
- const auto & curr_msg = messages[i];
-
- std::string role = json_value(curr_msg, "role", std::string(""));
-
- std::string content;
- if (curr_msg.contains("content")) {
- if (curr_msg["content"].is_string()) {
- content = curr_msg["content"].get<std::string>();
- } else if (curr_msg["content"].is_array()) {
- for (const auto & part : curr_msg["content"]) {
- if (part.contains("text")) {
- content += "\n" + part["text"].get<std::string>();
- }
- }
- } else {
- throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
- }
- } else {
- throw std::runtime_error("Missing 'content' (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
- }
-
- chat.push_back({role, content, /* tool_calls= */ {}});
- }
-
- const auto formatted_chat = common_chat_apply_template(tmpl, chat, true, /* use_jinja= */ false);
- LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str());
-
- return formatted_chat;
-}
-
//
// base64 utils (TODO: move to common in the future)
//
const json & body, /* openai api json semantics */
bool use_jinja,
common_reasoning_format reasoning_format,
- const common_chat_templates & chat_templates)
+ const struct common_chat_templates * tmpls)
{
json llama_params;
- const auto & tmpl = body.contains("tools") && chat_templates.template_tool_use
- ? *chat_templates.template_tool_use
- : *chat_templates.template_default;
auto tools = json_value(body, "tools", json());
auto stream = json_value(body, "stream", false);
llama_params["stop"] = json_value(body, "stop", json::array());
}
+ auto json_schema = json_value(body, "json_schema", json());
+ auto grammar = json_value(body, "grammar", std::string());
+ if (!json_schema.is_null() && !grammar.empty()) {
+ throw std::runtime_error("Cannot use both json_schema and grammar");
+ }
+
// Handle "response_format" field
if (body.contains("response_format")) {
json response_format = json_value(body, "response_format", json::object());
std::string response_type = json_value(response_format, "type", std::string());
if (response_type == "json_object") {
- llama_params["json_schema"] = json_value(response_format, "schema", json::object());
+ json_schema = json_value(response_format, "schema", json::object());
} else if (response_type == "json_schema") {
- json json_schema = json_value(response_format, "json_schema", json::object());
- llama_params["json_schema"] = json_value(json_schema, "schema", json::object());
+ // Use a distinct name for the wrapper object: declaring a local `json_schema` here would
+ // shadow the outer variable, and the extracted schema would be dropped at the end of this
+ // branch instead of reaching inputs.json_schema.
+ auto schema_wrapper = json_value(response_format, "json_schema", json::object());
+ json_schema = json_value(schema_wrapper, "schema", json::object());
} else if (!response_type.empty() && response_type != "text") {
throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type);
}
}
+ common_chat_templates_inputs inputs;
+ inputs.messages = common_chat_msgs_parse_oaicompat(body.at("messages"));
+ inputs.tools = common_chat_tools_parse_oaicompat(tools);
+ inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(json_value(body, "tool_choice", std::string("auto")));
+ inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump();
+ inputs.grammar = grammar;
+ inputs.add_generation_prompt = true;
+ inputs.use_jinja = use_jinja;
+ inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
+ inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE;
+ if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) {
+ throw std::runtime_error("Cannot use custom grammar constraints with tools.");
+ }
+
// Apply chat template to the list of messages
- if (use_jinja) {
- auto tool_choice = json_value(body, "tool_choice", std::string("auto"));
- if (tool_choice != "none" && tool_choice != "auto" && tool_choice != "required") {
- throw std::runtime_error("Invalid tool_choice: " + tool_choice);
- }
- if (tool_choice != "none" && llama_params.contains("grammar")) {
- throw std::runtime_error("Cannot use custom grammar constraints with tools.");
- }
- common_chat_inputs inputs;
- inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE;
- inputs.messages = body.at("messages");
- inputs.tools = tools;
- inputs.tool_choice = tool_choice;
- inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
- if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) {
- LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n");
- inputs.parallel_tool_calls = false;
- }
- inputs.stream = stream;
- // TODO: support mixing schema w/ tools beyond generic format.
- inputs.json_schema = json_value(llama_params, "json_schema", json());
- auto chat_params = common_chat_params_init(tmpl, inputs);
-
- llama_params["chat_format"] = static_cast<int>(chat_params.format);
- llama_params["prompt"] = chat_params.prompt;
- llama_params["grammar"] = chat_params.grammar;
- llama_params["grammar_lazy"] = chat_params.grammar_lazy;
- auto grammar_triggers = json::array();
- for (const auto & trigger : chat_params.grammar_triggers) {
- grammar_triggers.push_back({
- {"word", trigger.word},
- {"at_start", trigger.at_start},
- });
- }
- llama_params["grammar_triggers"] = grammar_triggers;
- llama_params["preserved_tokens"] = chat_params.preserved_tokens;
- for (const auto & stop : chat_params.additional_stops) {
- llama_params["stop"].push_back(stop);
- }
- } else {
- llama_params["prompt"] = format_chat(tmpl, body.at("messages"));
+ auto chat_params = common_chat_templates_apply(tmpls, inputs);
+
+ llama_params["chat_format"] = static_cast<int>(chat_params.format);
+ llama_params["prompt"] = chat_params.prompt;
+ llama_params["grammar"] = chat_params.grammar;
+ llama_params["grammar_lazy"] = chat_params.grammar_lazy;
+ auto grammar_triggers = json::array();
+ for (const auto & trigger : chat_params.grammar_triggers) {
+ grammar_triggers.push_back({
+ {"word", trigger.word},
+ {"at_start", trigger.at_start},
+ });
+ }
+ llama_params["grammar_triggers"] = grammar_triggers;
+ llama_params["preserved_tokens"] = chat_params.preserved_tokens;
+ for (const auto & stop : chat_params.additional_stops) {
+ llama_params["stop"].push_back(stop);
}
// Handle "n" field
#include <string>
#include <vector>
#include <sstream>
+#include <regex>
#undef NDEBUG
#include <cassert>
#include "llama.h"
#include "common.h"
-#include "chat-template.hpp"
+#include "chat.h"
static std::string normalize_newlines(const std::string & s) {
#ifdef _WIN32
#endif
}
+static common_chat_msg simple_msg(const std::string & role, const std::string & content) { // Test helper: builds a message with only role and content assigned; all other fields stay default-constructed.
+ common_chat_msg msg;
+ msg.role = role;
+ msg.content = content;
+ return msg;
+}
+
int main(void) {
std::vector<llama_chat_message> conversation {
{"system", "You are a helpful assistant"},
/* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
/* .expected_output= */ "[INST] You are a helpful assistant\nHello [/INST]Hi there</s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
/* .expected_output_jinja= */ "",
- /* .bos_token= */ "",
+ /* .bos_token= */ "<s>",
/* .eos_token= */ "</s>",
},
{
{
/* .name= */ "mlabonne/AlphaMonarch-7B",
/* .template_str= */ "{% for message in messages %}{{bos_token + message['role'] + '\\n' + message['content'] + eos_token + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ bos_token + 'assistant\\n' }}{% endif %}",
- /* .expected_output= */ "system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n I am an assistant </s>\n<s>user\nAnother question</s>\n<s>assistant\n",
- /* .expected_output_jinja= */ "<s>system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n I am an assistant </s>\n<s>user\nAnother question</s>\n<s>assistant\n",
+ /* .expected_output= */ "system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n I am an assistant </s>\n<s>user\nAnother question</s>\n<s>assistant\n",
+ /* .expected_output_jinja= */ "",
/* .bos_token= */ "<s>",
/* .eos_token= */ "</s>",
},
/* .name= */ "OrionStarAI/Orion-14B-Chat",
/* .template_str= */ "{% for message in messages %}{% if loop.first %}{{ bos_token }}{% endif %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'] + '\\n\\nAssistant: ' + eos_token }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% endif %}{% endfor %}",
/* .expected_output= */ "Human: You are a helpful assistant\n\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s> I am an assistant </s>Human: Another question\n\nAssistant: </s>",
- /* .expected_output_jinja= */ "Human: You are a helpful assistant\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s> I am an assistant </s>Human: Another question\n\nAssistant: </s>",
+ /* .expected_output_jinja= */ "Human: You are a helpful assistant\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s> I am an assistant </s>Human: Another question\n\nAssistant: ",
/* .bos_token= */ "",
/* .eos_token= */ "</s>",
},
}
}
- json messages = json::array();
+ std::vector<common_chat_msg> messages;
for (const auto & msg : conversation) {
- messages.push_back({
- {"role", msg.role},
- {"content", msg.content},
- });
+ messages.push_back(simple_msg(msg.role, msg.content));
}
for (const auto & test_case : test_cases) {
if (!test_case.supported_with_jinja) {
}
printf("\n\n=== %s (jinja) ===\n\n", test_case.name.c_str());
try {
- minja::chat_template tmpl(test_case.template_str, test_case.bos_token, test_case.eos_token);
- auto output = normalize_newlines(tmpl.apply(messages, json(), add_generation_prompt));
+ auto tmpls = common_chat_templates_init(/* model= */ nullptr, test_case.template_str.c_str(), test_case.bos_token, test_case.eos_token);
+ common_chat_templates_inputs inputs;
+ inputs.use_jinja = true;
+ inputs.messages = messages;
+ inputs.add_generation_prompt = add_generation_prompt;
+ auto output = common_chat_templates_apply(tmpls.get(), inputs).prompt;
+ output = normalize_newlines(output);
auto expected_output = normalize_newlines(test_case.expected_output_jinja.empty() ? test_case.expected_output : test_case.expected_output_jinja);
if (output != expected_output) {
printf("Expected:\n%s\n", expected_output.c_str());
// test llama_chat_format_single for system message
printf("\n\n=== llama_chat_format_single (system message) ===\n\n");
std::vector<common_chat_msg> chat2;
- common_chat_msg sys_msg{"system", "You are a helpful assistant", {}};
+ auto sys_msg = simple_msg("system", "You are a helpful assistant");
auto fmt_sys = [&](std::string tmpl_str) {
- minja::chat_template tmpl(tmpl_str, "", "");
- auto output = common_chat_format_single(tmpl, chat2, sys_msg, false, /* use_jinja= */ false);
+ auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl_str);
+ auto output = common_chat_format_single(tmpls.get(), chat2, sys_msg, false, /* use_jinja= */ false);
printf("fmt_sys(%s) : %s\n", tmpl_str.c_str(), output.c_str());
printf("-------------------------\n");
return output;
// test llama_chat_format_single for user message
printf("\n\n=== llama_chat_format_single (user message) ===\n\n");
- chat2.push_back({"system", "You are a helpful assistant", {}});
- chat2.push_back({"user", "Hello", {}});
- chat2.push_back({"assistant", "I am assistant", {}});
- common_chat_msg new_msg{"user", "How are you", {}};
+ chat2.push_back(simple_msg("system", "You are a helpful assistant"));
+ chat2.push_back(simple_msg("user", "Hello"));
+ chat2.push_back(simple_msg("assistant", "I am assistant"));
+ auto new_msg = simple_msg("user", "How are you");
- auto fmt_single = [&](std::string tmpl_str) {
- minja::chat_template tmpl(tmpl_str, "", "");
- auto output = common_chat_format_single(tmpl, chat2, new_msg, true, /* use_jinja= */ false);
+ auto fmt_single = [&](const std::string & tmpl_str) {
+ auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl_str.c_str());
+ auto output = common_chat_format_single(tmpls.get(), chat2, new_msg, true, /* use_jinja= */ false);
printf("fmt_single(%s) : %s\n", tmpl_str.c_str(), output.c_str());
printf("-------------------------\n");
return output;
#include <json.hpp>
#include <string>
-#include "chat-template.hpp"
-#include "chat.hpp"
+#include "chat.h"
#include "llama-grammar.h"
#include "unicode.h"
using json = nlohmann::ordered_json;
-static common_chat_msg msg_from_json(const json & message) {
- common_chat_msg ret;
- ret.role = "assistant";
- if (message.contains("content") && !message.at("content").is_null()) {
- ret.content = message.at("content");
- }
- if (message.contains("tool_plan")) {
- ret.reasoning_content = message.at("tool_plan");
- }
- if (message.contains("reasoning_content")) {
- ret.reasoning_content = message.at("reasoning_content");
- }
- auto has_tool_calls = message.contains("tool_calls");
- if (has_tool_calls) {
- for (const auto & tc : message.at("tool_calls")) {
- const auto & arguments = tc.at("function").at("arguments");
- ret.tool_calls.push_back({
- tc.at("function").at("name").get<std::string>(),
- arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
- tc.contains("id") ? tc.at("id").get<std::string>() : "",
- });
- }
- }
- return ret;
-}
template <class T> static void assert_equals(const T & expected, const T & actual) {
if (expected != actual) {
}
static std::string read_file(const std::string & path) {
- std::cerr << "# Reading: " << path << std::endl << std::flush;
+ std::cerr << "# Reading: " << path << '\n' << std::flush;
std::ifstream fs(path, std::ios_base::binary);
if (!fs.is_open()) {
fs = std::ifstream("../" + path, std::ios_base::binary);
fs.seekg(0);
std::string out;
out.resize(static_cast<size_t>(size));
- fs.read(&out[0], static_cast<std::streamsize>(size));
+ fs.read(out.data(), static_cast<std::streamsize>(size));
return out;
}
+static common_chat_templates_ptr read_templates(const std::string & path) { // Test helper: reads the template file at `path` via read_file and initializes chat templates from its contents, passing no model.
+ return common_chat_templates_ptr(common_chat_templates_init(/* model= */ nullptr, read_file(path)));
+}
+
static std::unique_ptr<llama_grammar> build_grammar(const std::string & grammar_str) {
return std::unique_ptr<llama_grammar>(
llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0));
}
}
- for (const auto & stack : stacks_cur) {
- if (stack.empty()) {
- // An empty stack means that the grammar has been completed
- return true;
- }
+ if (std::any_of(stacks_cur.begin(), stacks_cur.end(), [](const auto & stack) { return stack.empty(); })) {
+ // An empty stack means that the grammar has been completed
+ return true;
}
return false;
}
-// Dumps `{"a": 1}` as `"{\"a\": 1}"`, unlike nlohmann::json::dump which would dump it as `"{\"a\":1}"`.
-static std::string dump(const json & j) {
- return minja::Value(j).dump(-1, /* to_json= */ true);
-}
-
static void assert_msg_equals(const common_chat_msg & expected, const common_chat_msg & actual) {
assert_equals(expected.role, actual.role);
assert_equals(expected.content, actual.content);
+ assert_equals(expected.content_parts.size(), actual.content_parts.size()); // sizes are compared before the loop below indexes both vectors
+ for (size_t i = 0; i < expected.content_parts.size(); i++) {
+ const auto & expected_part = expected.content_parts[i];
+ const auto & actual_part = actual.content_parts[i];
+ assert_equals(expected_part.type, actual_part.type);
+ assert_equals(expected_part.text, actual_part.text);
+ }
assert_equals(expected.reasoning_content, actual.reasoning_content);
assert_equals(expected.tool_calls.size(), actual.tool_calls.size());
for (size_t i = 0; i < expected.tool_calls.size(); i++) {
const auto & expected_tool_call = expected.tool_calls[i];
const auto & actual_tool_call = actual.tool_calls[i];
assert_equals(expected_tool_call.name, actual_tool_call.name);
- assert_equals(dump(json::parse(expected_tool_call.arguments)), dump(json::parse(actual_tool_call.arguments)));
+ assert_equals(json::parse(expected_tool_call.arguments).dump(), json::parse(actual_tool_call.arguments).dump()); // both argument strings are parsed and re-serialized, so only their JSON content is compared, not their original formatting
assert_equals(expected_tool_call.id, actual_tool_call.id);
}
}
-const auto special_function_tool = json::parse(R"({
- "type": "function",
- "function": {
- "name": "special_function",
- "description": "I'm special",
- "parameters": {
- "type": "object",
- "properties": {
- "arg1": {
- "type": "integer",
- "description": "The arg."
- }
- },
- "required": ["arg1"]
- }
- }
-})");
-const auto python_tool = json::parse(R"({
- "type": "function",
- "function": {
- "name": "python",
- "description": "an ipython interpreter",
- "parameters": {
- "type": "object",
- "properties": {
- "code": {
- "type": "string",
- "description": "Python code to execute."
- }
- },
- "required": ["code"]
- }
- }
-})");
-const auto code_interpreter_tool = json::parse(R"({
- "type": "function",
- "function": {
- "name": "code_interpreter",
- "description": "an ipython interpreter",
- "parameters": {
- "type": "object",
- "properties": {
- "code": {
- "type": "string",
- "description": "Python code to execute."
- }
- },
- "required": ["code"]
- }
- }
-})");
-const json tools = { special_function_tool, python_tool };
-const json llama_3_1_tools = { special_function_tool, code_interpreter_tool };
+common_chat_tool special_function_tool { // Tool fixture: `special_function`, whose parameters schema (a JSON string) has one required integer `arg1`.
+ /* .name = */ "special_function",
+ /* .description = */ "I'm special",
+ /* .parameters = */ R"({
+ "type": "object",
+ "properties": {
+ "arg1": {
+ "type": "integer",
+ "description": "The arg."
+ }
+ },
+ "required": ["arg1"]
+ })",
+};
+common_chat_tool python_tool { // Tool fixture: `python`, whose parameters schema has one required string `code`.
+ /* .name = */ "python",
+ /* .description = */ "an ipython interpreter",
+ /* .parameters = */ R"({
+ "type": "object",
+ "properties": {
+ "code": {
+ "type": "string",
+ "description": "Python code to execute."
+ }
+ },
+ "required": ["code"]
+ })",
+};
+common_chat_tool code_interpreter_tool { // Tool fixture: same description and schema as python_tool, but named `code_interpreter`.
+ /* .name = */ "code_interpreter",
+ /* .description = */ "an ipython interpreter",
+ /* .parameters = */ R"({
+ "type": "object",
+ "properties": {
+ "code": {
+ "type": "string",
+ "description": "Python code to execute."
+ }
+ },
+ "required": ["code"]
+ })",
+};
+std::vector<common_chat_tool> tools { special_function_tool, python_tool }; // Tool set passed to the test_templates calls in this file.
+std::vector<common_chat_tool> llama_3_1_tools { special_function_tool, code_interpreter_tool }; // Variant that offers code_interpreter in place of python.
struct delta_data {
std::string delta;
common_chat_params params;
};
-static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
- const json & user_message, const json & delta_message, const json & tools,
- const json & tool_choice,
+static delta_data init_delta(const struct common_chat_templates * tmpls, const std::vector<std::string> & end_tokens,
+ const common_chat_msg & user_message,
+ const common_chat_msg & delta_message,
+ const std::vector<common_chat_tool> & tools,
+ const common_chat_tool_choice & tool_choice,
bool think = false) {
- common_chat_inputs inputs;
+ common_chat_templates_inputs inputs;
inputs.parallel_tool_calls = true;
- inputs.messages = json::array();
inputs.messages.push_back(user_message);
inputs.tools = tools;
inputs.tool_choice = tool_choice;
inputs.extract_reasoning = think;
- auto params_prefix = common_chat_params_init(tmpl, inputs);
+ auto params_prefix = common_chat_templates_apply(tmpls, inputs);
inputs.messages.push_back(delta_message);
inputs.add_generation_prompt = false;
- auto params_full = common_chat_params_init(tmpl, inputs);
+ auto params_full = common_chat_templates_apply(tmpls, inputs);
std::string prefix = params_prefix.prompt;
std::string full = params_full.prompt;
gets the diff, removes any end tokens and parses the result w/ the grammar, checking that
the parsed message is the same as the test_message
*/
-static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
- const json & test_message, const json & tools = {}, const std::string & expected_delta = "",
+static void test_templates(const struct common_chat_templates * tmpls, const std::vector<std::string> & end_tokens,
+ const common_chat_msg & test_message,
+ const std::vector<common_chat_tool> & tools = {},
+ const std::string & expected_delta = "",
bool expect_grammar_triggered = true,
bool test_grammar_if_triggered = true,
bool think = false) {
- common_chat_msg expected_msg = msg_from_json(test_message);
-
- auto user_message = json{
- { "role", "user" },
- { "content", "Hello, world!" }
- };
+ common_chat_msg user_message;
+ user_message.role = "user";
+ user_message.content = "Hello, world!";
- for (const auto & tool_choice : json({ "auto", "required" })) {
- auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools, tool_choice, think);
+ for (const auto & tool_choice : std::vector<common_chat_tool_choice> {COMMON_CHAT_TOOL_CHOICE_AUTO, COMMON_CHAT_TOOL_CHOICE_REQUIRED}) {
+ auto data = init_delta(tmpls, end_tokens, user_message, test_message, tools, tool_choice, think);
if (!expected_delta.empty()) {
assert_equals(expected_delta, data.delta);
}
if (expect_grammar_triggered) {
const auto msg = common_chat_parse(data.delta, data.params.format);
- assert_msg_equals(expected_msg, msg);
+ assert_msg_equals(test_message, msg);
}
- if (!expected_msg.tool_calls.empty()) {
+ if (!test_message.tool_calls.empty()) {
GGML_ASSERT(!data.params.grammar.empty());
}
if (!data.params.grammar.empty()) {
}
}
-static void test_template_output_parsers() {
- json message_user {
- { "role", "user" },
- { "content", "Hey there!" },
- };
- json message_assist {
- { "role", "assistant" },
- { "content", "Hello, world!\nWhat's up?" },
- };
- json message_assist_thoughts_unparsed_think {
- { "role", "assistant" },
- { "content", "<think>I'm thinking</think>Hello, world!\nWhat's up?" },
- };
- json message_assist_thoughts_unparsed_r7b {
- { "role", "assistant" },
- { "content", "<|START_THINKING|>I'm thinking<|END_THINKING|>Hello, world!\nWhat's up?" },
- };
- json message_assist_thoughts {
- { "role", "assistant" },
- { "content", "Hello, world!\nWhat's up?" },
- { "reasoning_content", "I'm thinking" },
- };
- json tool_calls = json::array({{
- { "type", "function" },
- { "function", { { "name", "special_function" }, { "arguments", "{\"arg1\": 1}" } } },
- }});
-
- json message_assist_call {
- { "role", "assistant"},
- { "content", {}},
- { "tool_calls", {
- {
- { "type", "function" },
- { "function", {
- { "name", "special_function" },
- { "arguments", "{\"arg1\": 1}" },
- }},
- },
- }},
- };
- json message_assist_call_thoughts = {
- { "role", "assistant" },
- { "content", nullptr },
- { "reasoning_content", "I'm\nthinking" },
- { "tool_calls", {
- {
- { "type", "function" },
- { "function", {
- { "name", "special_function" },
- { "arguments", "{\"arg1\": 1}" },
- }},
- },
- }},
- };
- json message_assist_call_thoughts_unparsed = {
- { "role", "assistant" },
- { "content", "<think>I'm\nthinking</think>" },
- { "tool_calls", {
- {
- { "type", "function" },
- { "function", {
- { "name", "special_function" },
- { "arguments", "{\"arg1\": 1}" },
- }},
- },
- }},
- };
- json message_assist_call_id {
- { "role", "assistant"},
- { "content", {}},
- { "tool_calls", {
- {
- { "type", "function" },
- { "function", {
- { "name", "special_function" },
- { "arguments", "{\"arg1\": 1}" },
- }},
- {"id", "123456789"},
- },
- }},
- { "role", "assistant" },
- { "content", {} },
- { "tool_calls", tool_calls }
- };
- json message_assist_call_idx {
- { "role", "assistant"},
- { "content", {}},
- { "tool_calls", {
- {
- { "type", "function" },
- { "function", {
- { "name", "special_function" },
- { "arguments", "{\"arg1\": 1}" },
- }},
- // Index of the tool call in the tool_calls array
- {"id", "0"},
- },
- }},
- { "role", "assistant" },
- { "content", {} },
- { "tool_calls", tool_calls }
- };
- json message_assist_call_tool_plan_idx = message_assist_call_idx;
- message_assist_call_tool_plan_idx["tool_plan"] = "I'm thinking";
-
- auto python_message_assist_call = json{
- { "role", "assistant" },
- { "content", {} },
- { "tool_calls", json{ {
- { "type", "function" },
- { "function",
- {
- { "name", "python" },
- { "arguments",
- {
- { "code", "print('hey')" },
- } },
- } },
- } } }
+const common_chat_msg message_user { // Plain user turn; used as the sole message of the inputs_* fixtures.
+ "user",
+ "Hey there!",
+ /* .content_parts = */ {},
+ /* .tool_calls = */ {},
+ /* .reasoning_content = */ "",
+ /* .tool_name = */ "",
+ /* .tool_call_id = */ "",
+};
+
+const common_chat_msg message_user_parts { // User turn whose text is carried as two "text" content parts; .content itself is empty.
+ "user",
+ /* .content = */ "",
+ /* .content_parts = */ {
+ { "text", "Hey" },
+ { "text", "there" },
+ },
+ /* .tool_calls = */ {},
+ /* .reasoning_content = */ "",
+ /* .tool_name = */ "",
+ /* .tool_call_id = */ "",
+};
+const common_chat_msg message_assist { // Assistant reply with content only: no tool calls, no reasoning.
+ "assistant",
+ "Hello, world!\nWhat's up?",
+ /* .content_parts = */ {},
+ /* .tool_calls = */ {},
+ /* .reasoning_content = */ "",
+ /* .tool_name = */ "",
+ /* .tool_call_id = */ "",
+};
+const common_chat_msg message_assist_thoughts_unparsed_think { // Assistant reply with the <think>...</think> block still embedded in content; reasoning_content is empty.
+ "assistant",
+ "<think>I'm thinking</think>Hello, world!\nWhat's up?",
+ /* .content_parts = */ {},
+ /* .tool_calls = */ {},
+ /* .reasoning_content = */ "",
+ /* .tool_name = */ "",
+ /* .tool_call_id = */ "",
+};
+const common_chat_msg message_assist_thoughts_unparsed_r7b { // Same as above, but with <|START_THINKING|>/<|END_THINKING|> markers left in content.
+ "assistant",
+ "<|START_THINKING|>I'm thinking<|END_THINKING|>Hello, world!\nWhat's up?",
+ /* .content_parts = */ {},
+ /* .tool_calls = */ {},
+ /* .reasoning_content = */ "",
+ /* .tool_name = */ "",
+ /* .tool_call_id = */ "",
+};
+const common_chat_msg message_assist_thoughts { // Assistant reply with the reasoning text held in reasoning_content instead of content.
+ "assistant",
+ "Hello, world!\nWhat's up?",
+ /* .content_parts = */ {},
+ /* .tool_calls = */ {},
+ /* .reasoning_content = */ "I'm thinking",
+ /* .tool_name = */ "",
+ /* .tool_call_id = */ "",
+};
+const std::vector<common_chat_tool_call> tool_calls { // Single special_function call with arguments {"arg1": 1} and an empty id.
+ { "special_function", "{\"arg1\": 1}", /* .id = */ "" },
+};
+const std::vector<common_chat_tool_call> tool_calls_idx { // Same call, with id "0".
+ { "special_function", "{\"arg1\": 1}", /* .id = */ "0" },
+};
+const std::vector<common_chat_tool_call> tool_calls_id { // Same call, with id "123456789".
+ { "special_function", "{\"arg1\": 1}", /* .id = */ "123456789" },
+};
+
+const common_chat_msg message_assist_call { // Assistant turn with empty content and the id-less special_function call.
+ "assistant",
+ "",
+ /* .content_parts = */ {},
+ tool_calls,
+ /* .reasoning_content = */ "",
+ /* .tool_name = */ "",
+ /* .tool_call_id = */ "",
+};
+const common_chat_msg message_assist_call_thoughts = { // Same call, with the reasoning text held in reasoning_content.
+ "assistant",
+ /* .content = */ "",
+ /* .content_parts = */ {},
+ tool_calls,
+ /* .reasoning_content = */ "I'm\nthinking",
+ /* .tool_name = */ "",
+ /* .tool_call_id = */ "",
+};
+const common_chat_msg message_assist_call_thoughts_unparsed = { // Same call, with the <think> block left in content and reasoning_content empty.
+ "assistant",
+ /* .content = */ "<think>I'm\nthinking</think>",
+ /* .content_parts = */ {},
+ tool_calls,
+ /* .reasoning_content = */ "",
+ /* .tool_name = */ "",
+ /* .tool_call_id = */ "",
+};
+const common_chat_msg message_assist_call_id { // Call variant whose tool call carries id "123456789" (tool_calls_id).
+ "assistant",
+ "",
+ /* .content_parts = */ {},
+ tool_calls_id,
+ /* .reasoning_content = */ "",
+ /* .tool_name = */ "",
+ /* .tool_call_id = */ "",
+};
+const common_chat_msg message_assist_call_idx { // Call variant whose tool call carries id "0" (tool_calls_idx).
+ "assistant",
+ "",
+ /* .content_parts = */ {},
+ tool_calls_idx,
+ /* .reasoning_content = */ "",
+ /* .tool_name = */ "",
+ /* .tool_call_id = */ "",
+};
+const common_chat_msg message_assist_call_python { // Assistant turn calling `python` with code "print('hey')".
+ "assistant",
+ "",
+ /* .content_parts = */ {},
+ { { "python", "{\"code\": \"print('hey')\"}", /* .id = */ "" } },
+ /* .reasoning_content = */ "",
+ /* .tool_name = */ "",
+ /* .tool_call_id = */ "",
+};
+const common_chat_msg message_assist_call_code_interpreter { // Same arguments, but calling `code_interpreter`.
+ "assistant",
+ "",
+ /* .content_parts = */ {},
+ { { "code_interpreter", "{\"code\": \"print('hey')\"}", /* .id = */ "" } },
+ /* .reasoning_content = */ "",
+ /* .tool_name = */ "",
+ /* .tool_call_id = */ "",
+};
+
+// Round-trips each message fixture through the OpenAI-compatible JSON form and back,
+// then pins the exact pretty-printed JSON emitted for two of the fixtures.
+static void test_msgs_oaicompat_json_conversion() {
+    std::vector<common_chat_msg> msgs{
+        message_user,
+        message_user_parts,
+        message_assist_call,
+        message_assist_call_thoughts,
+        message_assist_call_thoughts_unparsed,
+        message_assist_call_id,
+        message_assist_call_idx,
+        message_assist_call_python,
+        message_assist_call_code_interpreter,
};
- auto code_interpreter_message_assist_call = json{
- { "role", "assistant" },
- { "content", {} },
- { "tool_calls", json{ {
- { "type", "function" },
- { "function",
- {
- { "name", "code_interpreter" },
- { "arguments",
- {
- { "code", "print('hey')" },
- } },
- } },
- } } }
+    for (const auto & msg : msgs) {
+        auto oai_json = common_chat_msgs_to_json_oaicompat<json>({msg});
+        auto msgs2 = common_chat_msgs_parse_oaicompat(oai_json);
+        assert_equals((size_t) 1, msgs2.size());
+        // Bind by reference: the parsed message is only read, so no copy is needed.
+        const auto & msg2 = msgs2[0];
+        assert_msg_equals(msg, msg2);
+    }
+    // dump(2) indents two spaces per nesting level; the expected literals below
+    // must reproduce that indentation exactly for the string comparison to hold.
+    assert_equals(
+        std::string(
+            "[\n"
+            "  {\n"
+            "    \"role\": \"user\",\n"
+            "    \"content\": [\n"
+            "      {\n"
+            "        \"type\": \"text\",\n"
+            "        \"text\": \"Hey\"\n"
+            "      },\n"
+            "      {\n"
+            "        \"type\": \"text\",\n"
+            "        \"text\": \"there\"\n"
+            "      }\n"
+            "    ]\n"
+            "  }\n"
+            "]"
+        ),
+        common_chat_msgs_to_json_oaicompat<json>({message_user_parts}).dump(2));
+
+    assert_equals(
+        std::string(
+            "[\n"
+            "  {\n"
+            "    \"role\": \"assistant\",\n"
+            "    \"content\": null,\n"
+            "    \"tool_calls\": [\n"
+            "      {\n"
+            "        \"type\": \"function\",\n"
+            "        \"function\": {\n"
+            "          \"name\": \"python\",\n"
+            "          \"arguments\": \"{\\\"code\\\": \\\"print('hey')\\\"}\"\n"
+            "        }\n"
+            "      }\n"
+            "    ]\n"
+            "  }\n"
+            "]"
+        ),
+        common_chat_msgs_to_json_oaicompat<json>({message_assist_call_python}).dump(2));
+}
+
+// Round-trips each tool fixture through the OpenAI-compatible JSON form and back,
+// then pins the exact pretty-printed JSON emitted for special_function_tool.
+static void test_tools_oaicompat_json_conversion() {
+    // Named all_tools so it does not shadow the file-scope `tools` fixture.
+    std::vector<common_chat_tool> all_tools{
+        special_function_tool,
+        python_tool,
+        code_interpreter_tool,
};
- common_chat_inputs inputs_no_tools;
- inputs_no_tools.messages = json::array({message_user});
+    for (const auto & tool : all_tools) {
+        auto oai_json = common_chat_tools_to_json_oaicompat<json>({tool});
+        auto tools2 = common_chat_tools_parse_oaicompat(oai_json);
+        assert_equals((size_t) 1, tools2.size());
+        // Bind by reference: the parsed tool is only read, so no copy is needed.
+        const auto & tool2 = tools2[0];
+        assert_equals(tool.name, tool2.name);
+        assert_equals(tool.description, tool2.description);
+        // Parameters are compared after parse + dump so formatting of the schema strings is ignored.
+        assert_equals(json::parse(tool.parameters).dump(2), json::parse(tool2.parameters).dump(2));
+    }
+
+    // dump(2) indents two spaces per nesting level; the expected literal below
+    // must reproduce that indentation exactly for the string comparison to hold.
+    assert_equals(
+        std::string(
+            "[\n"
+            "  {\n"
+            "    \"type\": \"function\",\n"
+            "    \"function\": {\n"
+            "      \"name\": \"special_function\",\n"
+            "      \"description\": \"I'm special\",\n"
+            "      \"parameters\": {\n"
+            "        \"type\": \"object\",\n"
+            "        \"properties\": {\n"
+            "          \"arg1\": {\n"
+            "            \"type\": \"integer\",\n"
+            "            \"description\": \"The arg.\"\n"
+            "          }\n"
+            "        },\n"
+            "        \"required\": [\n"
+            "          \"arg1\"\n"
+            "        ]\n"
+            "      }\n"
+            "    }\n"
+            "  }\n"
+            "]"
+        ),
+        common_chat_tools_to_json_oaicompat<json>({special_function_tool}).dump(2));
+}
+
+static void test_template_output_parsers() {
+
+ common_chat_templates_inputs inputs_no_tools;
+ inputs_no_tools.messages = {message_user};
inputs_no_tools.extract_reasoning = false;
- common_chat_inputs inputs_no_tools_think;
- inputs_no_tools_think.messages = json::array({message_user});
+ common_chat_templates_inputs inputs_no_tools_think;
+ inputs_no_tools_think.messages = {message_user};
inputs_no_tools_think.extract_reasoning = true;
- common_chat_inputs inputs_tools;
- inputs_tools.messages = json::array({message_user});
- inputs_tools.tools = json::array({special_function_tool});
+ common_chat_templates_inputs inputs_tools;
+ inputs_tools.messages = {message_user};
+ inputs_tools.tools = {special_function_tool};
inputs_tools.extract_reasoning = false;
- common_chat_inputs inputs_tools_think;
- inputs_tools_think.messages = json::array({message_user});
- inputs_tools_think.tools = json::array({special_function_tool});
+ common_chat_templates_inputs inputs_tools_think;
+ inputs_tools_think.messages = {message_user};
+ inputs_tools_think.tools = {special_function_tool};
inputs_tools_think.extract_reasoning = true;
- common_chat_inputs inputs_tools_builtin;
- inputs_tools_builtin.messages = json::array({message_user});
- inputs_tools_builtin.tools = json::array({python_tool});
+ common_chat_templates_inputs inputs_tools_builtin;
+ inputs_tools_builtin.messages = {message_user};
+ inputs_tools_builtin.tools = {python_tool};
inputs_tools_builtin.extract_reasoning = false;
{
// Not supported yet
- const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja"), "<s>", "</s>");
- assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_params_init(tmpl, inputs_tools).format);
+ auto tmpls = read_templates("models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja");
+ assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
}
{
- const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja"), "<s>", "</s>");
+ auto tmpls = read_templates("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja");
std::vector<std::string> end_tokens{ "<|END_OF_TURN_TOKEN|>" };
- assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_no_tools).format);
- assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_tools).format);
- assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING, common_chat_params_init(tmpl, inputs_tools_think).format);
+ assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
+ assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+ assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format);
- assert_msg_equals(msg_from_json(message_assist),
+ assert_msg_equals(message_assist,
common_chat_parse(
"Hello, world!\nWhat's up?",
COMMON_CHAT_FORMAT_COMMAND_R7B));
- assert_msg_equals(msg_from_json(message_assist),
+ assert_msg_equals(message_assist,
common_chat_parse(
"Hello, world!\nWhat's up?<|END_RESPONSE|>",
COMMON_CHAT_FORMAT_COMMAND_R7B));
- assert_msg_equals(msg_from_json(message_assist),
+ assert_msg_equals(message_assist,
common_chat_parse(
"<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
COMMON_CHAT_FORMAT_COMMAND_R7B));
- assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_r7b),
+ assert_msg_equals(message_assist_thoughts_unparsed_r7b,
common_chat_parse(
"<|START_THINKING|>I'm thinking<|END_THINKING|>"
"<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
COMMON_CHAT_FORMAT_COMMAND_R7B));
- assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_r7b),
+ assert_msg_equals(message_assist_thoughts_unparsed_r7b,
common_chat_parse(
"<|START_THINKING|>I'm thinking<|END_THINKING|>"
"Hello, world!\nWhat's up?<|END_RESPONSE|>",
COMMON_CHAT_FORMAT_COMMAND_R7B));
- assert_msg_equals(msg_from_json(message_assist_thoughts),
+ assert_msg_equals(message_assist_thoughts,
common_chat_parse(
"<|START_THINKING|>I'm thinking<|END_THINKING|>"
"<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING));
- test_template(tmpl, end_tokens, message_assist_call_idx, tools,
+ test_templates(tmpls.get(), end_tokens, message_assist_call_idx, tools,
"<|START_THINKING|><|END_THINKING|>"
"<|START_ACTION|>[\n"
" {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n"
"]<|END_ACTION|>");
- test_template(tmpl, end_tokens, message_assist_call_tool_plan_idx, tools,
- "<|START_THINKING|>I'm thinking<|END_THINKING|>"
- "<|START_ACTION|>[\n"
- " {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n"
- "]<|END_ACTION|>",
- /* expect_grammar_triggered= */ true,
- /* test_grammar_if_triggered= */ true,
- /* think= */ true);
- test_template(tmpl, end_tokens, message_assist, tools,
+ test_templates(tmpls.get(), end_tokens, message_assist, tools,
"<|START_RESPONSE|>Hello, world!\n"
"What's up?<|END_RESPONSE|>",
/* expect_grammar_triggered= */ false);
}
{
- const common_chat_template tmpl(read_file("models/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
+ auto tmpls = read_templates("models/templates/google-gemma-2-2b-it.jinja");
std::vector<std::string> end_tokens{ "<end_of_turn>" };
- assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_params_init(tmpl, inputs_no_tools).format);
- assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_params_init(tmpl, inputs_tools).format);
+ assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
+ assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
assert_equals(COMMON_CHAT_FORMAT_GENERIC,
- common_chat_params_init(
- common_chat_template(read_file("models/templates/microsoft-Phi-3.5-mini-instruct.jinja"),
- "<s>", "</s>"),
+ common_chat_templates_apply(
+ read_templates("models/templates/microsoft-Phi-3.5-mini-instruct.jinja").get(),
inputs_tools)
.format);
// The generic tool call format doesn't generate / parse content-only messages symmetrically.
- assert_msg_equals(msg_from_json(message_assist),
+ assert_msg_equals(message_assist,
common_chat_parse("{\n"
" \"response\": \"Hello, world!\\nWhat's up?\"\n"
"}",
- common_chat_params_init(tmpl, inputs_tools).format));
- test_template(tmpl, end_tokens, message_assist_call_id, tools,
+ common_chat_templates_apply(tmpls.get(), inputs_tools).format));
+ test_templates(tmpls.get(), end_tokens, message_assist_call_id, tools,
"{\n"
" \"tool_calls\": [\n"
" {\n"
"}");
}
{
- const common_chat_template tmpl(read_file("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>",
- "</s>");
+ auto tmpls = read_templates("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja");
std::vector<std::string> end_tokens{ "</s>" };
- assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_params_init(tmpl, inputs_tools).format);
+ assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
- test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
- test_template(
- tmpl, end_tokens, message_assist_call_id, tools,
+ test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+ test_templates(
+ tmpls.get(), end_tokens, message_assist_call_id, tools,
"[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]");
}
{
- const common_chat_template tmpl(
- read_file("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
+ auto tmpls = read_templates("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja");
std::vector<std::string> end_tokens{ "<|im_end|>" };
- assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_params_init(tmpl, inputs_tools).format);
+ assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
assert_equals(
COMMON_CHAT_FORMAT_HERMES_2_PRO,
- common_chat_params_init(
- common_chat_template(read_file("models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"),
- "<s>", "</s>"),
+ common_chat_templates_apply(
+ read_templates("models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja").get(),
inputs_tools)
.format);
assert_equals(
COMMON_CHAT_FORMAT_HERMES_2_PRO,
- common_chat_params_init(
- common_chat_template(read_file("models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"),
+ common_chat_templates_apply(
+ read_templates("models/templates/Qwen-Qwen2.5-7B-Instruct.jinja").get(),
inputs_tools)
.format);
- test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
- test_template(tmpl, end_tokens, message_assist_call, tools,
+ test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+ test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
"<tool_call>\n"
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
"</tool_call>");
- test_template(tmpl, end_tokens, python_message_assist_call, tools,
+ test_templates(tmpls.get(), end_tokens, message_assist_call_python, tools,
"<tool_call>\n"
"{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n"
"</tool_call>");
}
{
- const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "<s>",
- "</s>");
+ auto tmpls = read_templates("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja");
std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };
- assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format);
+ assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
- common_chat_params_init(tmpl, inputs_tools_builtin).format);
+ common_chat_templates_apply(tmpls.get(), inputs_tools_builtin).format);
assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
- common_chat_params_init(
- common_chat_template(read_file("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"),
- "<s>", "</s>"),
+ common_chat_templates_apply(
+ read_templates("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja").get(),
inputs_tools_builtin)
.format);
- // test_template(tmpl, end_tokens, message_assist, tools, R"(?)", /* expect_grammar_triggered= */ false);
- test_template(tmpl, end_tokens, code_interpreter_message_assist_call, llama_3_1_tools,
+ // test_templates(tmpls.get(), end_tokens, message_assist, tools, R"(?)", /* expect_grammar_triggered= */ false);
+ test_templates(tmpls.get(), end_tokens, message_assist_call_code_interpreter, llama_3_1_tools,
"<|python_tag|>code_interpreter.call(code=\"print('hey')\")");
- test_template(tmpl, end_tokens, python_message_assist_call, tools,
+ test_templates(tmpls.get(), end_tokens, message_assist_call_python, tools,
"<|python_tag|>python.call(code=\"print('hey')\")");
- test_template(tmpl, end_tokens, message_assist_call, tools,
+ test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
}
{
- const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>",
- "</s>");
+ auto tmpls = read_templates("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja");
std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };
- assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format);
+ assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
- test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
- test_template(tmpl, end_tokens, message_assist_call, tools,
+ test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+ test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
}
{
- const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>",
- "</s>");
+ auto tmpls = read_templates("models/templates/meetkai-functionary-medium-v3.1.jinja");
std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };
assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
- common_chat_params_init(tmpl, inputs_tools).format);
+ common_chat_templates_apply(tmpls.get(), inputs_tools).format);
- test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
- test_template(tmpl, end_tokens, message_assist_call, tools,
+ test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+ test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
"<function=special_function>{\"arg1\": 1}</function>");
}
{
- const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>",
- "</s>");
+ auto tmpls = read_templates("models/templates/meetkai-functionary-medium-v3.2.jinja");
std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };
- assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_no_tools).format);
- assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_tools).format);
+ assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
+ assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
- test_template(tmpl, end_tokens, message_assist, {},
+ test_templates(tmpls.get(), end_tokens, message_assist, {},
"all\n"
"Hello, world!\n"
"What's up?",
/* expect_grammar_triggered= */ false);
- test_template(tmpl, end_tokens, message_assist_call, tools,
+ test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
"special_function\n"
"{\"arg1\": 1}");
}
{
- const common_chat_template tmpl(read_file("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>",
- "</s>");
+ auto tmpls = read_templates("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja");
std::vector<std::string> end_tokens{ "<|eot_id|>" };
- assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_params_init(tmpl, inputs_tools).format);
+ assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
- test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
- test_template(tmpl, end_tokens, message_assist_call, tools,
+ test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+ test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
" functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
}
{
// Original DeepSeek R1 template. Leaves <|tool▁calls▁begin|> and others unclosed. Our logic fixes the prompt.
- const common_chat_template tmpl(read_file("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"),
- "<s>", "</s>");
+ auto tmpls = read_templates("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja");
std::vector<std::string> end_tokens{ "<|end▁of▁sentence|>" };
- assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);
- assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_params_init(tmpl, inputs_tools_think).format);
+ assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+ assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format);
- test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
- test_template(tmpl, end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
- assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_think),
+ test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+ test_templates(tmpls.get(), end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+ assert_msg_equals(message_assist_thoughts_unparsed_think,
common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
COMMON_CHAT_FORMAT_DEEPSEEK_R1));
- assert_msg_equals(msg_from_json(message_assist_thoughts),
+ assert_msg_equals(message_assist_thoughts,
common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
- assert_msg_equals(msg_from_json(message_assist_thoughts),
+ assert_msg_equals(message_assist_thoughts,
// Latest template update (as of 20250209) adds a trailing <think>\n if add_generation_prompt is true.
common_chat_parse("I'm thinking</think>Hello, world!\nWhat's up?",
COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
- // test_template(tmpl, end_tokens, message_assist_call, tools,
+ // test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
// "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
// "```json\n"
// "{\"arg1\": 1}\n"
}
{
// Replacement DeepSeek R1 template. Makes the Distill Qwen 7B/32B models happy to call tools and all.
- const common_chat_template tmpl(read_file("models/templates/llama-cpp-deepseek-r1.jinja"),
- "<s>", "</s>");
+ auto tmpls = read_templates("models/templates/llama-cpp-deepseek-r1.jinja");
std::vector<std::string> end_tokens{ "<|end▁of▁sentence|>" };
- assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);
- assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_params_init(tmpl, inputs_tools_think).format);
+ assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+ assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format);
- test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
- test_template(tmpl, end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
- assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_think),
+ test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+ test_templates(tmpls.get(), end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+ assert_msg_equals(message_assist_thoughts_unparsed_think,
common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
COMMON_CHAT_FORMAT_DEEPSEEK_R1));
- assert_msg_equals(msg_from_json(message_assist_thoughts),
+ assert_msg_equals(message_assist_thoughts,
common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
- assert_msg_equals(msg_from_json(message_assist_call_thoughts_unparsed),
+ assert_msg_equals(message_assist_call_thoughts_unparsed,
common_chat_parse(
"<think>I'm\nthinking</think>\n\n"
"<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
"{\"arg1\": 1}\n"
"```<|tool▁call▁end|><|tool▁calls▁end|>",
COMMON_CHAT_FORMAT_DEEPSEEK_R1));
- assert_msg_equals(msg_from_json(message_assist_call_thoughts),
+ assert_msg_equals(message_assist_call_thoughts,
common_chat_parse(
"<think>I'm\nthinking</think>\n\n"
"<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
"{\"arg1\": 1}\n"
"```<|tool▁call▁end|><|tool▁calls▁end|>",
COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
- test_template(tmpl, end_tokens, message_assist_call, tools,
+ test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
"<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
"```json\n"
"{\"arg1\": 1}\n"
}
int main(int argc, char ** argv) {
+ try {
#ifndef _WIN32
- if (argc > 1) {
- common_chat_inputs inputs;
- inputs.messages = {
- { { "role", "user" }, { "content", "Hey" } }
- };
- inputs.tools = json::array({ special_function_tool });
-
- std::cout << "| Template | Format |\n";
- std::cout << "|----------|--------|\n";
-
- for (int i = 1; i < argc; i++) {
- try {
- std::string path = argv[i];
- if (path.rfind(".jinja") != path.size() - 6) {
- std::cerr << "Skipping non-jinja file: " << path << std::endl;
- continue;
+ if (argc > 1) {
+ common_chat_templates_inputs inputs;
+ common_chat_msg msg;
+ msg.role = "user";
+ msg.content = "Hey";
+ inputs.messages = {msg};
+ inputs.tools = { special_function_tool };
+
+ std::cout << "| Template | Format |\n";
+ std::cout << "|----------|--------|\n";
+
+ for (int i = 1; i < argc; i++) {
+ try {
+ std::string path = argv[i];
+ if (path.rfind(".jinja") != path.size() - 6) {
+ std::cerr << "Skipping non-jinja file: " << path << '\n';
+ continue;
+ }
+ auto tmpls = read_templates(path);
+ auto parts = string_split(path, "/");
+ auto name = parts[parts.size() - 1];
+ auto format = common_chat_format_name(common_chat_templates_apply(tmpls.get(), inputs).format);
+ std::cout << "| " << name << " | " << format << " |\n";
+ } catch (const std::exception & e) {
+ std::cerr << "Failed to process " << argv[i] << ": " << e.what() << '\n';
}
- common_chat_template tmpl(read_file(path), "", "");
- auto parts = string_split(path, "/");
- auto name = parts[parts.size() - 1];
- auto format = common_chat_format_name(common_chat_params_init(tmpl, inputs).format);
- std::cout << "| " << name << " | " << format << " |\n";
- } catch (const std::exception & e) {
- std::cerr << "Failed to process " << argv[i] << ": " << e.what() << std::endl;
}
- }
- } else
+ } else
#endif
- {
- test_template_output_parsers();
- std::cout << "\n[chat] All tests passed!" << std::endl;
+ {
+ test_msgs_oaicompat_json_conversion();
+ test_tools_oaicompat_json_conversion();
+ test_template_output_parsers();
+ std::cout << "\n[chat] All tests passed!" << '\n';
+ }
+ return 0;
+ } catch (const std::exception & e) {
+ std::cerr << "Error: " << e.what() << '\n';
+ return 1;
}
- return 0;
}