]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Fix new line issue with chat template, disable template when in-prefix/suffix is...
authorXuan Son Nguyen <redacted>
Sun, 30 Jun 2024 18:27:13 +0000 (20:27 +0200)
committerGitHub <redacted>
Sun, 30 Jun 2024 18:27:13 +0000 (20:27 +0200)
* preserve new line llama_chat_format_single

* disable chat template if in-prefix/suffix is set

* remove redundant change

common/common.cpp
common/common.h
examples/main/main.cpp
tests/test-chat-template.cpp

index 6a00d25be1316a1128a5dfc06714c01d215d407e..5a0d0ee038123933762c47fd331685d3ab0b5d50 100644 (file)
@@ -1014,16 +1014,19 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
     }
     if (arg == "--in-prefix-bos") {
         params.input_prefix_bos = true;
+        params.enable_chat_template = false;
         return true;
     }
     if (arg == "--in-prefix") {
         CHECK_ARG
         params.input_prefix = argv[i];
+        params.enable_chat_template = false;
         return true;
     }
     if (arg == "--in-suffix") {
         CHECK_ARG
         params.input_suffix = argv[i];
+        params.enable_chat_template = false;
         return true;
     }
     if (arg == "--spm-infill") {
@@ -1406,7 +1409,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
                                                                         "halt generation at PROMPT, return control in interactive mode\n"
                                                                         "can be specified more than once for multiple prompts" });
     options.push_back({ "main",        "-sp,   --special",              "special tokens output enabled (default: %s)", params.special ? "true" : "false" });
-    options.push_back({ "main",        "-cnv,  --conversation",         "run in conversation mode (does not print special tokens and suffix/prefix) (default: %s)", params.conversation ? "true" : "false" });
+    options.push_back({ "main",        "-cnv,  --conversation",         "run in conversation mode (does not print special tokens and suffix/prefix, use default chat template) (default: %s)", params.conversation ? "true" : "false" });
     options.push_back({ "main infill", "-i,    --interactive",          "run in interactive mode (default: %s)", params.interactive ? "true" : "false" });
     options.push_back({ "main infill", "-if,   --interactive-first",    "run in interactive mode and wait for input right away (default: %s)", params.interactive_first ? "true" : "false" });
     options.push_back({ "main infill", "-mli,  --multiline-input",      "allows you to write or paste multiple lines without ending each in '\\'" });
@@ -2668,12 +2671,19 @@ std::string llama_chat_format_single(const struct llama_model * model,
         const std::vector<llama_chat_msg> & past_msg,
         const llama_chat_msg & new_msg,
         bool add_ass) {
+    std::ostringstream ss;
     auto fmt_past_msg = llama_chat_apply_template(model, tmpl, past_msg, false);
     std::vector<llama_chat_msg> chat_new(past_msg);
+    // if the past_msg ends with a newline, we must preserve it in the formatted version
+    if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
+        ss << "\n";
+    };
+    // format chat with new_msg
     chat_new.push_back(new_msg);
     auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass);
-    auto formatted = fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
-    return formatted;
+    // get the diff part
+    ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
+    return ss.str();
 }
 
 std::string llama_chat_format_example(const struct llama_model * model,
index d6cb814b990e9ba8049377a11dcb26e85f194c16..627b7ed854757eb33252895c13152ed6c563acb6 100644 (file)
@@ -200,6 +200,7 @@ struct gpt_params {
     std::string public_path   = "";
     std::string chat_template = "";
     std::string system_prompt = "";
+    bool enable_chat_template = true;
 
     std::vector<std::string> api_keys;
 
index 1114073b84370d015a07869801b9dc91d504d767..d512953b9635c8a5844ed60db58301b3acede4ce 100644 (file)
@@ -261,7 +261,7 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> embd_inp;
 
     {
-        auto prompt = params.conversation
+        auto prompt = (params.conversation && params.enable_chat_template)
             ? chat_add_and_format(model, chat_msgs, "system", params.prompt) // format the system prompt in conversation mode
             : params.prompt;
         if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
@@ -810,7 +810,9 @@ int main(int argc, char ** argv) {
                         is_antiprompt = true;
                     }
 
-                    chat_add_and_format(model, chat_msgs, "assistant", assistant_ss.str());
+                    if (params.enable_chat_template) {
+                        chat_add_and_format(model, chat_msgs, "assistant", assistant_ss.str());
+                    }
                     is_interacting = true;
                     printf("\n");
                 }
@@ -872,12 +874,13 @@ int main(int argc, char ** argv) {
                         string_process_escapes(buffer);
                     }
 
-                    std::string user_inp = params.conversation
+                    bool format_chat = params.conversation && params.enable_chat_template;
+                    std::string user_inp = format_chat
                         ? chat_add_and_format(model, chat_msgs, "user", std::move(buffer))
                         : std::move(buffer);
                     // TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
                     const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true);
-                    const auto line_inp = ::llama_tokenize(ctx, user_inp,            false, params.conversation);
+                    const auto line_inp = ::llama_tokenize(ctx, user_inp,            false, format_chat);
                     const auto line_sfx = ::llama_tokenize(ctx, params.input_suffix, false, true);
 
                     LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());
index b154038b2d5c03b7816ac8c8c90f62c06891f9e2..03f5369109716f07b62599dc1fb165313f085975 100644 (file)
@@ -142,9 +142,9 @@ int main(void) {
         std::cout << "fmt_single(" << tmpl << ")\n" << output << "\n-------------------------\n";
         return output;
     };
-    assert(fmt_single("chatml") == "<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n");
+    assert(fmt_single("chatml") == "\n<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n");
     assert(fmt_single("llama2") == "[INST] How are you [/INST]");
-    assert(fmt_single("gemma") == "<start_of_turn>user\nHow are you<end_of_turn>\n<start_of_turn>model\n");
+    assert(fmt_single("gemma") == "\n<start_of_turn>user\nHow are you<end_of_turn>\n<start_of_turn>model\n");
     assert(fmt_single("llama3") == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");
 
     return 0;