]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
minja: sync (qwen3) (#13573)
authorOlivier Chafik <redacted>
Thu, 15 May 2025 22:29:10 +0000 (23:29 +0100)
committerGitHub <redacted>
Thu, 15 May 2025 22:29:10 +0000 (23:29 +0100)
* minja: sync https://github.com/google/minja/commit/f06140fa52fd140fe38e531ec373d8dc9c86aa06

- https://github.com/google/minja/pull/67 (@grf53)
- https://github.com/google/minja/pull/66 (@taha-yassine)
- https://github.com/google/minja/pull/63 (@grf53)
- https://github.com/google/minja/pull/58

---------

Co-authored-by: ochafik <redacted>
common/minja/chat-template.hpp
common/minja/minja.hpp

index 237a7625cddbdf0e4daf0fd6839f0304e06abb34..c930a587a89af0d9077e3df77dbe3ea63ad8b73e 100644 (file)
 #include <chrono>
 #include <cstddef>
 #include <cstdio>
+#include <ctime>
 #include <exception>
 #include <iomanip>
 #include <memory>
 #include <sstream>
+#include <stdexcept>
 #include <string>
 #include <vector>
 
@@ -393,8 +395,8 @@ class chat_template {
 
             for (const auto & message_ : adjusted_messages) {
                 auto message = message_;
-                if (!message.contains("role") || !message.contains("content")) {
-                    throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
+                if (!message.contains("role") || (!message.contains("content") && !message.contains("tool_calls"))) {
+                    throw std::runtime_error("message must have 'role' and one of 'content' or 'tool_calls' fields: " + message.dump());
                 }
                 std::string role = message.at("role");
 
@@ -415,7 +417,6 @@ class chat_template {
                         }
                     }
                     if (polyfill_tool_calls) {
-                        auto content = message.at("content");
                         auto tool_calls = json::array();
                         for (const auto & tool_call : message.at("tool_calls")) {
                             if (tool_call.at("type") != "function") {
@@ -434,8 +435,11 @@ class chat_template {
                         auto obj = json {
                             {"tool_calls", tool_calls},
                         };
-                        if (!content.is_null() && !content.empty()) {
-                            obj["content"] = content;
+                        if (message.contains("content")) {
+                            auto content = message.at("content");
+                            if (!content.is_null() && !content.empty()) {
+                                obj["content"] = content;
+                            }
                         }
                         message["content"] = obj.dump(2);
                         message.erase("tool_calls");
index e52e792d844d6b0536c70417c199dc1519338dd9..b3b00547d81422df15dc862e56e762a5485d2e76 100644 (file)
@@ -11,6 +11,7 @@
 #include <algorithm>
 #include <cctype>
 #include <cstddef>
+#include <cstdint>
 #include <cmath>
 #include <exception>
 #include <functional>
@@ -233,7 +234,7 @@ public:
       }
     } else if (is_object()) {
       if (!index.is_hashable())
-        throw std::runtime_error("Unashable type: " + index.dump());
+        throw std::runtime_error("Unhashable type: " + index.dump());
       auto it = object_->find(index.primitive_);
       if (it == object_->end())
         throw std::runtime_error("Key not found: " + index.dump());
@@ -252,7 +253,7 @@ public:
       auto index = key.get<int>();
       return array_->at(index < 0 ? array_->size() + index : index);
     } else if (object_) {
-      if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
+      if (!key.is_hashable()) throw std::runtime_error("Unhashable type: " + dump());
       auto it = object_->find(key.primitive_);
       if (it == object_->end()) return Value();
       return it->second;
@@ -261,7 +262,7 @@ public:
   }
   void set(const Value& key, const Value& value) {
     if (!object_) throw std::runtime_error("Value is not an object: " + dump());
-    if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
+    if (!key.is_hashable()) throw std::runtime_error("Unhashable type: " + dump());
     (*object_)[key.primitive_] = value;
   }
   Value call(const std::shared_ptr<Context> & context, ArgumentsValue & args) const {
@@ -398,7 +399,7 @@ public:
       }
       return false;
     } else if (object_) {
-      if (!value.is_hashable()) throw std::runtime_error("Unashable type: " + value.dump());
+      if (!value.is_hashable()) throw std::runtime_error("Unhashable type: " + value.dump());
       return object_->find(value.primitive_) != object_->end();
     } else {
       throw std::runtime_error("contains can only be called on arrays and objects: " + dump());
@@ -416,7 +417,7 @@ public:
     return const_cast<Value*>(this)->at(index);
   }
   Value& at(const Value & index) {
-    if (!index.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
+    if (!index.is_hashable()) throw std::runtime_error("Unhashable type: " + dump());
     if (is_array()) return array_->at(index.get<int>());
     if (is_object()) return object_->at(index.primitive_);
     throw std::runtime_error("Value is not an array or object: " + dump());
@@ -676,8 +677,8 @@ public:
 class VariableExpr : public Expression {
     std::string name;
 public:
-    VariableExpr(const Location & location, const std::string& n)
-      : Expression(location), name(n) {}
+    VariableExpr(const Location & loc, const std::string& n)
+      : Expression(loc), name(n) {}
     std::string get_name() const { return name; }
     Value do_evaluate(const std::shared_ptr<Context> & context) const override {
         if (!context->contains(name)) {
@@ -1200,9 +1201,9 @@ public:
 
 class SliceExpr : public Expression {
 public:
-    std::shared_ptr<Expression> start, end;
-    SliceExpr(const Location & loc, std::shared_ptr<Expression> && s, std::shared_ptr<Expression> && e)
-      : Expression(loc), start(std::move(s)), end(std::move(e)) {}
+    std::shared_ptr<Expression> start, end, step;
+    SliceExpr(const Location & loc, std::shared_ptr<Expression> && s, std::shared_ptr<Expression> && e, std::shared_ptr<Expression> && st = nullptr)
+      : Expression(loc), start(std::move(s)), end(std::move(e)), step(std::move(st)) {}
     Value do_evaluate(const std::shared_ptr<Context> &) const override {
         throw std::runtime_error("SliceExpr not implemented");
     }
@@ -1219,18 +1220,35 @@ public:
         if (!index) throw std::runtime_error("SubscriptExpr.index is null");
         auto target_value = base->evaluate(context);
         if (auto slice = dynamic_cast<SliceExpr*>(index.get())) {
-          auto start = slice->start ? slice->start->evaluate(context).get<int64_t>() : 0;
-          auto end = slice->end ? slice->end->evaluate(context).get<int64_t>() : (int64_t) target_value.size();
+          auto len = target_value.size();
+          auto wrap = [len](int64_t i) -> int64_t {
+            if (i < 0) {
+              return i + len;
+            }
+            return i;
+          };
+          int64_t step = slice->step ? slice->step->evaluate(context).get<int64_t>() : 1;
+          if (!step) {
+            throw std::runtime_error("slice step cannot be zero");
+          }
+          int64_t start = slice->start ? wrap(slice->start->evaluate(context).get<int64_t>()) : (step < 0 ? len - 1 : 0);
+          int64_t end = slice->end ? wrap(slice->end->evaluate(context).get<int64_t>()) : (step < 0 ? -1 : len);
           if (target_value.is_string()) {
             std::string s = target_value.get<std::string>();
-            if (start < 0) start = s.size() + start;
-            if (end < 0) end = s.size() + end;
-            return s.substr(start, end - start);
+
+            std::string result;
+            if (start < end && step == 1) {
+              result = s.substr(start, end - start);
+            } else {
+              for (int64_t i = start; step > 0 ? i < end : i > end; i += step) {
+                result += s[i];
+              }
+            }
+            return result;
+
           } else if (target_value.is_array()) {
-            if (start < 0) start = target_value.size() + start;
-            if (end < 0) end = target_value.size() + end;
             auto result = Value::array();
-            for (auto i = start; i < end; ++i) {
+            for (int64_t i = start; step > 0 ? i < end : i > end; i += step) {
               result.push_back(target_value.at(i));
             }
             return result;
@@ -1305,6 +1323,8 @@ public:
               if (name == "iterable") return l.is_iterable();
               if (name == "sequence") return l.is_array();
               if (name == "defined") return !l.is_null();
+              if (name == "true") return l.to_bool();
+              if (name == "false") return !l.to_bool();
               throw std::runtime_error("Unknown type for 'is' operator: " + name);
             };
             auto value = eval();
@@ -1520,6 +1540,10 @@ public:
             vargs.expectArgs("endswith method", {1, 1}, {0, 0});
             auto suffix = vargs.args[0].get<std::string>();
             return suffix.length() <= str.length() && std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());
+          } else if (method->get_name() == "startswith") {
+            vargs.expectArgs("startswith method", {1, 1}, {0, 0});
+            auto prefix = vargs.args[0].get<std::string>();
+            return prefix.length() <= str.length() && std::equal(prefix.begin(), prefix.end(), str.begin());
           } else if (method->get_name() == "title") {
             vargs.expectArgs("title method", {0, 0}, {0, 0});
             auto res = str;
@@ -2082,28 +2106,37 @@ private:
 
       while (it != end && consumeSpaces() && peekSymbols({ "[", "." })) {
         if (!consumeToken("[").empty()) {
-            std::shared_ptr<Expression> index;
+          std::shared_ptr<Expression> index;
+          auto slice_loc = get_location();
+          std::shared_ptr<Expression> start, end, step;
+          bool has_first_colon = false, has_second_colon = false;
+
+          if (!peekSymbols({ ":" })) {
+            start = parseExpression();
+          }
+
+          if (!consumeToken(":").empty()) {
+            has_first_colon = true;
+            if (!peekSymbols({ ":", "]" })) {
+              end = parseExpression();
+            }
             if (!consumeToken(":").empty()) {
-              auto slice_end = parseExpression();
-              index = std::make_shared<SliceExpr>(slice_end->location, nullptr, std::move(slice_end));
-            } else {
-              auto slice_start = parseExpression();
-              if (!consumeToken(":").empty()) {
-                consumeSpaces();
-                if (peekSymbols({ "]" })) {
-                  index = std::make_shared<SliceExpr>(slice_start->location, std::move(slice_start), nullptr);
-                } else {
-                  auto slice_end = parseExpression();
-                  index = std::make_shared<SliceExpr>(slice_start->location, std::move(slice_start), std::move(slice_end));
-                }
-              } else {
-                index = std::move(slice_start);
+              has_second_colon = true;
+              if (!peekSymbols({ "]" })) {
+                step = parseExpression();
               }
             }
-            if (!index) throw std::runtime_error("Empty index in subscript");
-            if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript");
+          }
+
+          if ((has_first_colon || has_second_colon) && (start || end || step)) {
+            index = std::make_shared<SliceExpr>(slice_loc, std::move(start), std::move(end), std::move(step));
+          } else {
+            index = std::move(start);
+          }
+          if (!index) throw std::runtime_error("Empty index in subscript");
+          if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript");
 
-            value = std::make_shared<SubscriptExpr>(value->location, std::move(value), std::move(index));
+          value = std::make_shared<SubscriptExpr>(value->location, std::move(value), std::move(index));
         } else if (!consumeToken(".").empty()) {
             auto identifier = parseIdentifier();
             if (!identifier) throw std::runtime_error("Expected identifier in subscript");