]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : fix `llama_chat_format_single` for mistral (#8657)
authorXuan Son Nguyen <redacted>
Wed, 24 Jul 2024 11:48:46 +0000 (13:48 +0200)
committerGitHub <redacted>
Wed, 24 Jul 2024 11:48:46 +0000 (13:48 +0200)
* fix `llama_chat_format_single` for mistral

* fix typo

* use printf

common/common.cpp
examples/main/main.cpp
tests/test-chat-template.cpp

index 4c19132f19832d0cf5279b0fd50606529dc715b8..ec44a05521c9d7379779239a03f09da734b77951 100644 (file)
@@ -2723,7 +2723,7 @@ std::string llama_chat_format_single(const struct llama_model * model,
         const llama_chat_msg & new_msg,
         bool add_ass) {
     std::ostringstream ss;
-    auto fmt_past_msg = llama_chat_apply_template(model, tmpl, past_msg, false);
+    auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false);
     std::vector<llama_chat_msg> chat_new(past_msg);
     // if the past_msg ends with a newline, we must preserve it in the formatted version
     if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
index a0d817b1a89d14d70333fc596b8b5024363bedb6..61e960ea2abe6a4d09b02f703ae17fb8be9b3f7c 100644 (file)
@@ -124,6 +124,7 @@ static std::string chat_add_and_format(struct llama_model * model, std::vector<l
     auto formatted = llama_chat_format_single(
         model, g_params->chat_template, chat_msgs, new_msg, role == "user");
     chat_msgs.push_back({role, content});
+    LOG("formatted: %s\n", formatted.c_str());
     return formatted;
 }
 
index 6583dd0b2b39b652216bb1f1c5eae58a6cbf698d..46a7d3aea8f67ae044a6a01c492ea3043ccd8625 100644 (file)
@@ -1,4 +1,3 @@
-#include <iostream>
 #include <string>
 #include <vector>
 #include <sstream>
@@ -133,13 +132,31 @@ int main(void) {
         );
         formatted_chat.resize(res);
         std::string output(formatted_chat.data(), formatted_chat.size());
-        std::cout << output << "\n-------------------------\n";
+        printf("%s\n", output.c_str());
+        printf("-------------------------\n");
         assert(output == expected);
     }
 
-    // test llama_chat_format_single
-    std::cout << "\n\n=== llama_chat_format_single ===\n\n";
+
+    // test llama_chat_format_single for system message
+    printf("\n\n=== llama_chat_format_single (system message) ===\n\n");
     std::vector<llama_chat_msg> chat2;
+    llama_chat_msg sys_msg{"system", "You are a helpful assistant"};
+
+    auto fmt_sys = [&](std::string tmpl) {
+        auto output = llama_chat_format_single(nullptr, tmpl, chat2, sys_msg, false);
+        printf("fmt_sys(%s) : %s\n", tmpl.c_str(), output.c_str());
+        printf("-------------------------\n", output.c_str());
+        return output;
+    };
+    assert(fmt_sys("chatml") == "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n");
+    assert(fmt_sys("llama2") == "[INST] You are a helpful assistant\n");
+    assert(fmt_sys("gemma")  == ""); // for gemma, system message is merged with user message
+    assert(fmt_sys("llama3") == "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|>");
+
+
+    // test llama_chat_format_single for user message
+    printf("\n\n=== llama_chat_format_single (user message) ===\n\n");
     chat2.push_back({"system", "You are a helpful assistant"});
     chat2.push_back({"user", "Hello"});
     chat2.push_back({"assistant", "I am assistant"});
@@ -147,12 +164,13 @@ int main(void) {
 
     auto fmt_single = [&](std::string tmpl) {
         auto output = llama_chat_format_single(nullptr, tmpl, chat2, new_msg, true);
-        std::cout << "fmt_single(" << tmpl << ")\n" << output << "\n-------------------------\n";
+        printf("fmt_single(%s) : %s\n", tmpl.c_str(), output.c_str());
+        printf("-------------------------\n", output.c_str());
         return output;
     };
     assert(fmt_single("chatml") == "\n<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n");
     assert(fmt_single("llama2") == "[INST] How are you [/INST]");
-    assert(fmt_single("gemma") == "\n<start_of_turn>user\nHow are you<end_of_turn>\n<start_of_turn>model\n");
+    assert(fmt_single("gemma")  == "\n<start_of_turn>user\nHow are you<end_of_turn>\n<start_of_turn>model\n");
     assert(fmt_single("llama3") == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");
 
     return 0;