]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
sync: minja (https://github.com/google/minja/commit/a72057e5190de2c612d4598bb10b4bfd0...
authorOlivier Chafik <redacted>
Mon, 10 Feb 2025 09:34:09 +0000 (09:34 +0000)
committerGitHub <redacted>
Mon, 10 Feb 2025 09:34:09 +0000 (09:34 +0000)
common/chat-template.hpp
common/minja.hpp

index 0e88fb3617e9b26187cb4a180117b54eb17e61d0..882ba41bd1bb0fc5cc22bdf5c258b2b68f60a23e 100644 (file)
@@ -249,16 +249,30 @@ class chat_template {
                     inputs.add_generation_prompt = false;
                     full = apply(inputs);
                 }
-
-                if (full.find(prefix) != 0) {
-                    if (prefix.rfind(eos_token_) == prefix.size() - eos_token_.size()) {
-                        prefix = prefix.substr(0, prefix.size() - eos_token_.size());
+                auto eos_pos_last = full.rfind(eos_token_);
+                if (eos_pos_last == prefix.size() - eos_token_.size() ||
+                      (full[full.size() - 1] == '\n' && (eos_pos_last == full.size() - eos_token_.size() - 1))) {
+                    full = full.substr(0, eos_pos_last);
+                }
+                size_t common_prefix_length = 0;
+                for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) {
+                    if (prefix[i] != full[i]) {
+                        break;
                     }
+                    if (prefix[i] == '<') {
+                        // DeepSeek R1's template (as of 20250209) adds a trailing <think> if add_generation_prompt,
+                        // but it removes thinking tags for past messages.
+                        // The prefix and full strings diverge at <think> vs. <|tool▁calls▁begin|>, we avoid consuming the leading <.
+                        continue;
+                    }
+                    common_prefix_length = i + 1;
                 }
-                if (full.find(prefix) != 0) {
+                auto example = full.substr(common_prefix_length);
+                if (example.find("tool_name") == std::string::npos && example.find("some_value") == std::string::npos) {
                     fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n");
+                } else {
+                    tool_call_example_ = example;
                 }
-                tool_call_example_ = full.substr(prefix.size());
             }
         } catch (const std::exception & e) {
             fprintf(stderr, "Failed to generate tool call example: %s\n", e.what());
@@ -363,7 +377,7 @@ class chat_template {
             if (polyfill_tools) {
                 adjusted_messages = add_system(inputs.messages,
                     "You can call any of the following tools to satisfy the user's requests: " + minja::Value(inputs.tools).dump(2, /* to_json= */ true) +
-                    (!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_));
+                    (!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_ + "\n\n"));
             } else {
                 adjusted_messages = inputs.messages;
             }
index c304b5c66a0921f43100bb03609e515a034d3b5a..c58dd66e067b1c24b930bda2f59d168bbb1ce85b 100644 (file)
@@ -1385,6 +1385,13 @@ static std::string strip(const std::string & s) {
   return s.substr(start, end - start + 1);
 }
 
+static std::string capitalize(const std::string & s) {
+  if (s.empty()) return s;
+  auto result = s;
+  result[0] = std::toupper(result[0]);
+  return result;
+}
+
 static std::string html_escape(const std::string & s) {
   std::string result;
   result.reserve(s.size());
@@ -1462,6 +1469,9 @@ public:
           if (method->get_name() == "strip") {
             vargs.expectArgs("strip method", {0, 0}, {0, 0});
             return Value(strip(str));
+          } else if (method->get_name() == "capitalize") {
+            vargs.expectArgs("capitalize method", {0, 0}, {0, 0});
+            return Value(capitalize(str));
           } else if (method->get_name() == "endswith") {
             vargs.expectArgs("endswith method", {1, 1}, {0, 0});
             auto suffix = vargs.args[0].get<std::string>();
@@ -1792,7 +1802,7 @@ private:
         auto left = parseStringConcat();
         if (!left) throw std::runtime_error("Expected left side of 'logical compare' expression");
 
-        static std::regex compare_tok(R"(==|!=|<=?|>=?|in\b|is\b|not[\r\n\s]+in\b)");
+        static std::regex compare_tok(R"(==|!=|<=?|>=?|in\b|is\b|not\s+in\b)");
         static std::regex not_tok(R"(not\b)");
         std::string op_str;
         while (!(op_str = consumeToken(compare_tok)).empty()) {
@@ -2171,7 +2181,7 @@ private:
     using TemplateTokenIterator = TemplateTokenVector::const_iterator;
 
     std::vector<std::string> parseVarNames() {
-      static std::regex varnames_regex(R"(((?:\w+)(?:[\r\n\s]*,[\r\n\s]*(?:\w+))*)[\r\n\s]*)");
+      static std::regex varnames_regex(R"(((?:\w+)(?:\s*,\s*(?:\w+))*)\s*)");
 
       std::vector<std::string> group;
       if ((group = consumeTokenGroups(varnames_regex)).empty()) throw std::runtime_error("Expected variable names");
@@ -2194,13 +2204,13 @@ private:
     }
 
     TemplateTokenVector tokenize() {
-      static std::regex comment_tok(R"(\{#([-~]?)([\s\S\r\n]*?)([-~]?)#\})");
+      static std::regex comment_tok(R"(\{#([-~]?)([\s\S]*?)([-~]?)#\})");
       static std::regex expr_open_regex(R"(\{\{([-~])?)");
-      static std::regex block_open_regex(R"(^\{%([-~])?[\s\n\r]*)");
+      static std::regex block_open_regex(R"(^\{%([-~])?\s*)");
       static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue)\b)");
       static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)");
-      static std::regex expr_close_regex(R"([\s\n\r]*([-~])?\}\})");
-      static std::regex block_close_regex(R"([\s\n\r]*([-~])?%\})");
+      static std::regex expr_close_regex(R"(\s*([-~])?\}\})");
+      static std::regex block_close_regex(R"(\s*([-~])?%\})");
 
       TemplateTokenVector tokens;
       std::vector<std::string> group;
@@ -2284,7 +2294,7 @@ private:
               auto post_space = parseBlockClose();
               tokens.push_back(std::make_unique<EndGenerationTemplateToken>(location, pre_space, post_space));
             } else if (keyword == "set") {
-              static std::regex namespaced_var_regex(R"((\w+)[\s\n\r]*\.[\s\n\r]*(\w+))");
+              static std::regex namespaced_var_regex(R"((\w+)\s*\.\s*(\w+))");
 
               std::string ns;
               std::vector<std::string> var_names;
@@ -2336,6 +2346,11 @@ private:
               throw std::runtime_error("Unexpected block: " + keyword);
             }
           } else if (std::regex_search(it, end, match, non_text_open_regex)) {
+            if (!match.position()) {
+                if (match[0] != "{#")
+                    throw std::runtime_error("Internal error: Expected a comment");
+                throw std::runtime_error("Missing end of comment tag");
+            }
             auto text_end = it + match.position();
             text = std::string(it, text_end);
             it = text_end;
@@ -2400,7 +2415,7 @@ private:
 
               auto text = text_token->text;
               if (post_space == SpaceHandling::Strip) {
-                static std::regex trailing_space_regex(R"((\s|\r|\n)+$)");
+                static std::regex trailing_space_regex(R"(\s+$)");
                 text = std::regex_replace(text, trailing_space_regex, "");
               } else if (options.lstrip_blocks && it != end) {
                 auto i = text.size();
@@ -2410,7 +2425,7 @@ private:
                 }
               }
               if (pre_space == SpaceHandling::Strip) {
-                static std::regex leading_space_regex(R"(^(\s|\r|\n)+)");
+                static std::regex leading_space_regex(R"(^\s+)");
                 text = std::regex_replace(text, leading_space_regex, "");
               } else if (options.trim_blocks && (it - 1) != begin && !dynamic_cast<ExpressionTemplateToken*>((*(it - 2)).get())) {
                 if (text.length() > 0 && text[0] == '\n') {