]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vendor : sync minja (#16500)
authorSigbjørn Skjæret <redacted>
Wed, 29 Oct 2025 13:09:50 +0000 (14:09 +0100)
committerGitHub <redacted>
Wed, 29 Oct 2025 13:09:50 +0000 (14:09 +0100)
* sync minja.hpp

Adds Call/EndCall support, used in MiniCPM3 and MiniCPM4-MCP.

* remove spurious semicolon

* sync from ochafik/minja

vendor/minja/minja.hpp

index dad75efbba5f0c78b94e0181bb7b8a5a6f69e995..57b138add6adb12a956c4044b269966456b89848 100644 (file)
@@ -55,7 +55,7 @@ inline std::string normalize_newlines(const std::string & s) {
 }
 
 /* Values that behave roughly like in Python. */
-class Value : public std::enable_shared_from_this<Value> {
+class Value {
 public:
   using CallableType = std::function<Value(const std::shared_ptr<Context> &, ArgumentsValue &)>;
   using FilterType = std::function<Value(const std::shared_ptr<Context> &, ArgumentsValue &)>;
@@ -158,12 +158,14 @@ public:
   Value(const json & v) {
     if (v.is_object()) {
       auto object = std::make_shared<ObjectType>();
+      object->reserve(v.size());
       for (auto it = v.begin(); it != v.end(); ++it) {
-        (*object)[it.key()] = it.value();
+        object->emplace_back(it.key(), Value(it.value()));
       }
       object_ = std::move(object);
     } else if (v.is_array()) {
       auto array = std::make_shared<ArrayType>();
+      array->reserve(v.size());
       for (const auto& item : v) {
         array->push_back(Value(item));
       }
@@ -610,7 +612,7 @@ static std::string error_location_suffix(const std::string & source, size_t pos)
   return out.str();
 }
 
-class Context : public std::enable_shared_from_this<Context> {
+class Context {
   protected:
     Value values_;
     std::shared_ptr<Context> parent_;
@@ -706,7 +708,7 @@ enum SpaceHandling { Keep, Strip, StripSpaces, StripNewline };
 
 class TemplateToken {
 public:
-    enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter, Break, Continue };
+    enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter, Break, Continue, Call, EndCall };
 
     static std::string typeToString(Type t) {
         switch (t) {
@@ -729,6 +731,8 @@ public:
             case Type::EndGeneration: return "endgeneration";
             case Type::Break: return "break";
             case Type::Continue: return "continue";
+            case Type::Call: return "call";
+            case Type::EndCall: return "endcall";
         }
         return "Unknown";
     }
@@ -846,6 +850,17 @@ struct LoopControlTemplateToken : public TemplateToken {
     LoopControlTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, LoopControlType control_type) : TemplateToken(Type::Break, loc, pre, post), control_type(control_type) {}
 };
 
+struct CallTemplateToken : public TemplateToken {
+    std::shared_ptr<Expression> expr;
+    CallTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && e)
+        : TemplateToken(Type::Call, loc, pre, post), expr(std::move(e)) {}
+};
+
+struct EndCallTemplateToken : public TemplateToken {
+    EndCallTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post)
+        : TemplateToken(Type::EndCall, loc, pre, post) {}
+};
+
 class TemplateNode {
     Location location_;
 protected:
@@ -1047,36 +1062,48 @@ public:
           }
         }
     }
-    void do_render(std::ostringstream &, const std::shared_ptr<Context> & macro_context) const override {
+    void do_render(std::ostringstream &, const std::shared_ptr<Context> & context) const override {
         if (!name) throw std::runtime_error("MacroNode.name is null");
         if (!body) throw std::runtime_error("MacroNode.body is null");
-        auto callable = Value::callable([&](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
-            auto call_context = macro_context;
+
+        // Use init-capture to avoid dangling 'this' pointer and circular references
+        auto callable = Value::callable([weak_context = std::weak_ptr<Context>(context),
+                                         name = name, params = params, body = body,
+                                         named_param_positions = named_param_positions]
+                                        (const std::shared_ptr<Context> & call_context, ArgumentsValue & args) {
+            auto context_locked = weak_context.lock();
+            if (!context_locked) throw std::runtime_error("Macro context no longer valid");
+            auto execution_context = Context::make(Value::object(), context_locked);
+
+            if (call_context->contains("caller")) {
+                execution_context->set("caller", call_context->get("caller"));
+            }
+
             std::vector<bool> param_set(params.size(), false);
             for (size_t i = 0, n = args.args.size(); i < n; i++) {
                 auto & arg = args.args[i];
                 if (i >= params.size()) throw std::runtime_error("Too many positional arguments for macro " + name->get_name());
                 param_set[i] = true;
-                auto & param_name = params[i].first;
-                call_context->set(param_name, arg);
+                const auto & param_name = params[i].first;
+                execution_context->set(param_name, arg);
             }
             for (auto & [arg_name, value] : args.kwargs) {
                 auto it = named_param_positions.find(arg_name);
                 if (it == named_param_positions.end()) throw std::runtime_error("Unknown parameter name for macro " + name->get_name() + ": " + arg_name);
 
-                call_context->set(arg_name, value);
+                execution_context->set(arg_name, value);
                 param_set[it->second] = true;
             }
             // Set default values for parameters that were not passed
             for (size_t i = 0, n = params.size(); i < n; i++) {
                 if (!param_set[i] && params[i].second != nullptr) {
-                    auto val = params[i].second->evaluate(context);
-                    call_context->set(params[i].first, val);
+                    auto val = params[i].second->evaluate(call_context);
+                    execution_context->set(params[i].first, val);
                 }
             }
-            return body->render(call_context);
+            return body->render(execution_context);
         });
-        macro_context->set(name->get_name(), callable);
+        context->set(name->get_name(), callable);
     }
 };
 
@@ -1611,6 +1638,44 @@ public:
     }
 };
 
+class CallNode : public TemplateNode {
+    std::shared_ptr<Expression> expr;
+    std::shared_ptr<TemplateNode> body;
+
+public:
+    CallNode(const Location & loc, std::shared_ptr<Expression> && e, std::shared_ptr<TemplateNode> && b)
+        : TemplateNode(loc), expr(std::move(e)), body(std::move(b)) {}
+
+    void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
+        if (!expr) throw std::runtime_error("CallNode.expr is null");
+        if (!body) throw std::runtime_error("CallNode.body is null");
+
+        // Use init-capture to avoid dangling 'this' pointer and circular references
+        auto caller = Value::callable([weak_context = std::weak_ptr<Context>(context), body=body]
+                                      (const std::shared_ptr<Context> &, ArgumentsValue &) -> Value {
+            auto context_locked = weak_context.lock();
+            if (!context_locked) throw std::runtime_error("Caller context no longer valid");
+            return Value(body->render(context_locked));
+        });
+
+        context->set("caller", caller);
+
+        auto call_expr = dynamic_cast<CallExpr*>(expr.get());
+        if (!call_expr) {
+            throw std::runtime_error("Invalid call block syntax - expected function call");
+        }
+
+        Value function = call_expr->object->evaluate(context);
+        if (!function.is_callable()) {
+            throw std::runtime_error("Call target must be callable: " + function.dump());
+        }
+        ArgumentsValue args = call_expr->args.evaluate(context);
+
+        Value result = function.call(context, args);
+        out << result.to_str();
+    }
+};
+
 class FilterExpr : public Expression {
     std::vector<std::shared_ptr<Expression>> parts;
 public:
@@ -2320,7 +2385,7 @@ private:
       static std::regex comment_tok(R"(\{#([-~]?)([\s\S]*?)([-~]?)#\})");
       static std::regex expr_open_regex(R"(\{\{([-~])?)");
       static std::regex block_open_regex(R"(^\{%([-~])?\s*)");
-      static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue)\b)");
+      static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue|call|endcall)\b)");
       static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)");
       static std::regex expr_close_regex(R"(\s*([-~])?\}\})");
       static std::regex block_close_regex(R"(\s*([-~])?%\})");
@@ -2443,6 +2508,15 @@ private:
             } else if (keyword == "endmacro") {
               auto post_space = parseBlockClose();
               tokens.push_back(std::make_unique<EndMacroTemplateToken>(location, pre_space, post_space));
+            } else if (keyword == "call") {
+              auto expr = parseExpression();
+              if (!expr) throw std::runtime_error("Expected expression in call block");
+
+              auto post_space = parseBlockClose();
+              tokens.push_back(std::make_unique<CallTemplateToken>(location, pre_space, post_space, std::move(expr)));
+            } else if (keyword == "endcall") {
+              auto post_space = parseBlockClose();
+              tokens.push_back(std::make_unique<EndCallTemplateToken>(location, pre_space, post_space));
             } else if (keyword == "filter") {
               auto filter = parseExpression();
               if (!filter) throw std::runtime_error("Expected expression in filter block");
@@ -2575,6 +2649,12 @@ private:
                   throw unterminated(**start);
               }
               children.emplace_back(std::make_shared<MacroNode>(token->location, std::move(macro_token->name), std::move(macro_token->params), std::move(body)));
+          } else if (auto call_token = dynamic_cast<CallTemplateToken*>(token.get())) {
+            auto body = parseTemplate(begin, it, end);
+            if (it == end || (*(it++))->type != TemplateToken::Type::EndCall) {
+                throw unterminated(**start);
+            }
+            children.emplace_back(std::make_shared<CallNode>(token->location, std::move(call_token->expr), std::move(body)));
           } else if (auto filter_token = dynamic_cast<FilterTemplateToken*>(token.get())) {
               auto body = parseTemplate(begin, it, end);
               if (it == end || (*(it++))->type != TemplateToken::Type::EndFilter) {
@@ -2588,6 +2668,7 @@ private:
           } else if (dynamic_cast<EndForTemplateToken*>(token.get())
                   || dynamic_cast<EndSetTemplateToken*>(token.get())
                   || dynamic_cast<EndMacroTemplateToken*>(token.get())
+                  || dynamic_cast<EndCallTemplateToken*>(token.get())
                   || dynamic_cast<EndFilterTemplateToken*>(token.get())
                   || dynamic_cast<EndIfTemplateToken*>(token.get())
                   || dynamic_cast<ElseTemplateToken*>(token.get())