]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Add chat template support for llama-cli (#8068)
authorXuan Son Nguyen <redacted>
Tue, 25 Jun 2024 11:56:49 +0000 (13:56 +0200)
committerGitHub <redacted>
Tue, 25 Jun 2024 11:56:49 +0000 (21:56 +1000)
* add chat template support for llama-cli

* add help message

* server: simplify format_chat

* more consistent naming

* improve

* add llama_chat_format_example

* fix server

* code style

* code style

* Update examples/main/main.cpp

Co-authored-by: Georgi Gerganov <redacted>
---------

Co-authored-by: Georgi Gerganov <redacted>
common/common.cpp
common/common.h
examples/main/main.cpp
examples/server/server.cpp
examples/server/utils.hpp
llama.cpp
tests/test-chat-template.cpp

index 0ca7b4430f765a1124f707abd3809307fe52bc90..da6db4dc6a09cd8b226e5a4de30817e8e411df5c 100644 (file)
@@ -1444,7 +1444,10 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "main",        "       --cfg-negative-prompt-file FNAME",
                                                                         "negative prompt file to use for guidance" });
     options.push_back({ "main",        "       --cfg-scale N",          "strength of guidance (default: %.1f, 1.0 = disable)", (double)sparams.cfg_scale });
-
+    options.push_back({ "main",        "       --chat-template JINJA_TEMPLATE",
+                                                                        "set custom jinja chat template (default: template taken from model's metadata)\n"
+                                                                        "only commonly used templates are accepted:\n"
+                                                                        "https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" });
     options.push_back({ "grammar" });
     options.push_back({ "*",           "       --grammar GRAMMAR",      "BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", sparams.grammar.c_str() });
     options.push_back({ "*",           "       --grammar-file FNAME",   "file to read grammar from" });
@@ -2604,12 +2607,67 @@ bool llama_should_add_bos_token(const llama_model * model) {
     return add_bos != -1 ? bool(add_bos) : (llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM);
 }
 
+//
+// Chat template utils
+//
+
 bool llama_chat_verify_template(const std::string & tmpl) {
     llama_chat_message chat[] = {{"user", "test"}};
     int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
     return res >= 0;
 }
 
+std::string llama_chat_apply_template(const struct llama_model * model,
+        const std::string & tmpl,
+        const std::vector<llama_chat_msg> & msgs,
+        bool add_ass) {
+    int alloc_size = 0;
+    std::vector<llama_chat_message> chat;
+    for (auto & msg : msgs) {
+        chat.push_back({msg.role.c_str(), msg.content.c_str()});
+        alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
+    }
+
+    const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
+    std::vector<char> buf(alloc_size);
+
+    // run the first time to get the total output length
+    int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+
+    // if it turns out that our buffer is too small, we resize it
+    if ((size_t) res > buf.size()) {
+        buf.resize(res);
+        res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+    }
+
+    std::string formatted_chat(buf.data(), res);
+    return formatted_chat;
+}
+
+std::string llama_chat_format_single(const struct llama_model * model,
+        const std::string & tmpl,
+        const std::vector<llama_chat_msg> & past_msg,
+        const llama_chat_msg & new_msg,
+        bool add_ass) {
+    auto fmt_past_msg = llama_chat_apply_template(model, tmpl, past_msg, false);
+    std::vector<llama_chat_msg> chat_new(past_msg);
+    chat_new.push_back(new_msg);
+    auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass);
+    auto formatted = fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
+    return formatted;
+}
+
+std::string llama_chat_format_example(const struct llama_model * model,
+        const std::string & tmpl) {
+    std::vector<llama_chat_msg> msgs = {
+        {"system",    "You are a helpful assistant"},
+        {"user",      "Hello"},
+        {"assistant", "Hi there"},
+        {"user",      "How are you?"},
+    };
+    return llama_chat_apply_template(model, tmpl, msgs, true);
+}
+
 //
 // KV cache utils
 //
index a5c738f8b643f87787f9ea4ec2dafcd01cef7887..de90eec5113f79416b116b75e089d4df6c86910b 100644 (file)
@@ -365,9 +365,32 @@ bool llama_should_add_bos_token(const llama_model * model);
 // Chat template utils
 //
 
+// same with llama_chat_message, but uses std::string
+struct llama_chat_msg {
+    std::string role;
+    std::string content;
+};
+
 // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
 bool llama_chat_verify_template(const std::string & tmpl);
 
+// CPP wrapper for llama_chat_apply_template
+std::string llama_chat_apply_template(const struct llama_model * model,
+        const std::string & tmpl,
+        const std::vector<llama_chat_msg> & chat,
+        bool add_ass);
+
+// Format single message, while taking into account the position of that message in chat history
+std::string llama_chat_format_single(const struct llama_model * model,
+        const std::string & tmpl,
+        const std::vector<llama_chat_msg> & past_msg,
+        const llama_chat_msg & new_msg,
+        bool add_ass);
+
+// Returns an example of formatted chat
+std::string llama_chat_format_example(const struct llama_model * model,
+        const std::string & tmpl);
+
 //
 // KV cache utils
 //
index b97b7b7937f02ac846ddc4c4dd5a85b843c7b3f2..cfaf6a6e8ba4a364a816d1cf19aa8d436fbd863a 100644 (file)
@@ -39,12 +39,12 @@ static std::ostringstream       * g_output_ss;
 static std::vector<llama_token> * g_output_tokens;
 static bool is_interacting = false;
 
-static bool file_exists(const std::string &path) {
+static bool file_exists(const std::string & path) {
     std::ifstream f(path.c_str());
     return f.good();
 }
 
-static bool file_is_empty(const std::string &path) {
+static bool file_is_empty(const std::string & path) {
     std::ifstream f;
     f.exceptions(std::ifstream::failbit | std::ifstream::badbit);
     f.open(path.c_str(), std::ios::in | std::ios::binary | std::ios::ate);
@@ -117,6 +117,14 @@ static void llama_log_callback_logTee(ggml_log_level level, const char * text, v
     LOG_TEE("%s", text);
 }
 
+static std::string chat_add_and_format(struct llama_model * model, std::vector<llama_chat_msg> & chat_msgs, std::string role, std::string content) {
+    llama_chat_msg new_msg{role, content};
+    auto formatted = llama_chat_format_single(
+        model, g_params->chat_template, chat_msgs, new_msg, role == "user");
+    chat_msgs.push_back({role, content});
+    return formatted;
+}
+
 int main(int argc, char ** argv) {
     gpt_params params;
     g_params = &params;
@@ -190,6 +198,7 @@ int main(int argc, char ** argv) {
     llama_model * model;
     llama_context * ctx;
     llama_context * ctx_guidance = NULL;
+    std::vector<llama_chat_msg> chat_msgs;
     g_model = &model;
     g_ctx = &ctx;
 
@@ -215,6 +224,8 @@ int main(int argc, char ** argv) {
                 __func__, n_ctx_train, n_ctx);
     }
 
+    LOG_TEE("%s: chat template example: %s\n", __func__, llama_chat_format_example(model, params.chat_template).c_str());
+
     // print system information
     {
         LOG_TEE("\n");
@@ -249,16 +260,21 @@ int main(int argc, char ** argv) {
 
     std::vector<llama_token> embd_inp;
 
-    if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
-        LOG("tokenize the prompt\n");
-        embd_inp = ::llama_tokenize(ctx, params.prompt, true, true);
-    } else {
-        LOG("use session tokens\n");
-        embd_inp = session_tokens;
-    }
+    {
+        auto prompt = params.conversation
+            ? chat_add_and_format(model, chat_msgs, "system", params.prompt) // format the system prompt in conversation mode
+            : params.prompt;
+        if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
+            LOG("tokenize the prompt\n");
+            embd_inp = ::llama_tokenize(ctx, prompt, true, true);
+        } else {
+            LOG("use session tokens\n");
+            embd_inp = session_tokens;
+        }
 
-    LOG("prompt: \"%s\"\n", log_tostr(params.prompt));
-    LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
+        LOG("prompt: \"%s\"\n", log_tostr(prompt));
+        LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
+    }
 
     // Should not run without any tokens
     if (embd_inp.empty()) {
@@ -478,6 +494,7 @@ int main(int argc, char ** argv) {
     std::vector<int>   input_tokens;  g_input_tokens  = &input_tokens;
     std::vector<int>   output_tokens; g_output_tokens = &output_tokens;
     std::ostringstream output_ss;     g_output_ss     = &output_ss;
+    std::ostringstream assistant_ss; // for storing current assistant message, used in conversation mode
 
     // the first thing we will do is to output the prompt, so set color accordingly
     console::set_display(console::prompt);
@@ -793,11 +810,18 @@ int main(int argc, char ** argv) {
                         is_antiprompt = true;
                     }
 
+                    chat_add_and_format(model, chat_msgs, "system", assistant_ss.str());
                     is_interacting = true;
                     printf("\n");
                 }
             }
 
+            // if current token is not EOG, we add it to current assistant message
+            if (params.conversation) {
+                auto id = llama_sampling_last(ctx_sampling);
+                assistant_ss << llama_token_to_piece(ctx, id, false);
+            }
+
             if (n_past > 0 && is_interacting) {
                 LOG("waiting for user input\n");
 
@@ -848,8 +872,12 @@ int main(int argc, char ** argv) {
                         string_process_escapes(buffer);
                     }
 
+                    std::string user_inp = params.conversation
+                        ? chat_add_and_format(model, chat_msgs, "user", std::move(buffer))
+                        : std::move(buffer);
+                    // TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
                     const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true);
-                    const auto line_inp = ::llama_tokenize(ctx, buffer,              false, false);
+                    const auto line_inp = ::llama_tokenize(ctx, user_inp,            false, params.conversation);
                     const auto line_sfx = ::llama_tokenize(ctx, params.input_suffix, false, true);
 
                     LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());
@@ -864,6 +892,9 @@ int main(int argc, char ** argv) {
                         output_ss << llama_token_to_piece(ctx, token);
                     }
 
+                    // reset assistant message
+                    assistant_ss.str("");
+
                     n_remain -= line_inp.size();
                     LOG("n_remain: %d\n", n_remain);
                 } else {
index f9a86961f9c8e6d9f2ff2e65e129e2f1ccf33ba3..ae768097baa0e5a3f5941f1eaa0e1830a7acd49d 100644 (file)
@@ -2606,17 +2606,9 @@ int main(int argc, char ** argv) {
 
     // print sample chat example to make it clear which template is used
     {
-        json chat;
-        chat.push_back({{"role", "system"},    {"content", "You are a helpful assistant"}});
-        chat.push_back({{"role", "user"},      {"content", "Hello"}});
-        chat.push_back({{"role", "assistant"}, {"content", "Hi there"}});
-        chat.push_back({{"role", "user"},      {"content", "How are you?"}});
-
-        const std::string chat_example = format_chat(ctx_server.model, params.chat_template, chat);
-
         LOG_INFO("chat template", {
-            {"chat_example", chat_example},
-            {"built_in", params.chat_template.empty()},
+            {"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)},
+            {"built_in",     params.chat_template.empty()},
         });
     }
 
index 63fde9c9faabe3cd68ce1692245706c6aa5dcce7..7ef2a519a10c76f22fe47219aecee366aafc6d82 100644 (file)
@@ -118,36 +118,17 @@ static inline void server_log(const char * level, const char * function, int lin
 
 // Format given chat. If tmpl is empty, we take the template from model metadata
 inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) {
-    size_t alloc_size = 0;
-    // vector holding all allocated string to be passed to llama_chat_apply_template
-    std::vector<std::string> str(messages.size() * 2);
-    std::vector<llama_chat_message> chat(messages.size());
+    std::vector<llama_chat_msg> chat;
 
     for (size_t i = 0; i < messages.size(); ++i) {
         const auto & curr_msg = messages[i];
-        str[i*2 + 0]    = json_value(curr_msg, "role",    std::string(""));
-        str[i*2 + 1]    = json_value(curr_msg, "content", std::string(""));
-        alloc_size     += str[i*2 + 1].length();
-        chat[i].role    = str[i*2 + 0].c_str();
-        chat[i].content = str[i*2 + 1].c_str();
+        std::string role    = json_value(curr_msg, "role",    std::string(""));
+        std::string content = json_value(curr_msg, "content", std::string(""));
+        chat.push_back({role, content});
     }
 
-    const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
-    std::vector<char> buf(alloc_size * 2);
-
-    // run the first time to get the total output length
-    int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size());
-
-    // if it turns out that our buffer is too small, we resize it
-    if ((size_t) res > buf.size()) {
-        buf.resize(res);
-        res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size());
-    }
-
-    const std::string formatted_chat(buf.data(), res);
-
+    auto formatted_chat = llama_chat_apply_template(model, tmpl, chat, true);
     LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}});
-
     return formatted_chat;
 }
 
index 49bc93c028a2a66ef34498a0c80ee4400a9bd934..33e6cb7229aab2489565d258811505f9d7f4c1fd 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -18818,10 +18818,10 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|im_start|>assistant\n";
         }
-    } else if (tmpl == "llama2" || tmpl.find("[INST]") != std::string::npos) {
+    } else if (tmpl == "llama2" || tmpl == "mistral" || tmpl.find("[INST]") != std::string::npos) {
         // llama2 template and its variants
         // [variant] support system message
-        bool support_system_message = tmpl.find("<<SYS>>") != std::string::npos;
+        bool support_system_message = tmpl.find("<<SYS>>") != std::string::npos || tmpl == "mistral";
         // [variant] space before + after response
         bool space_around_response = tmpl.find("' ' + eos_token") != std::string::npos;
         // [variant] add BOS inside history
index cef9a650bdfdfc09bd9b920254d73f9b91c6792a..d19ba8633e8c23ea01134c539effbc6893e77d24 100644 (file)
@@ -7,6 +7,7 @@
 #include <cassert>
 
 #include "llama.h"
+#include "common.h"
 
 int main(void) {
     llama_chat_message conversation[] = {
@@ -119,5 +120,24 @@ int main(void) {
         std::cout << output << "\n-------------------------\n";
         assert(output == expected);
     }
+
+    // test llama_chat_format_single
+    std::cout << "\n\n=== llama_chat_format_single ===\n\n";
+    std::vector<llama_chat_msg> chat2;
+    chat2.push_back({"system", "You are a helpful assistant"});
+    chat2.push_back({"user", "Hello"});
+    chat2.push_back({"assistant", "I am assistant"});
+    llama_chat_msg new_msg{"user", "How are you"};
+
+    auto fmt_single = [&](std::string tmpl) {
+        auto output = llama_chat_format_single(nullptr, tmpl, chat2, new_msg, true);
+        std::cout << "fmt_single(" << tmpl << ")\n" << output << "\n-------------------------\n";
+        return output;
+    };
+    assert(fmt_single("chatml") == "<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n");
+    assert(fmt_single("llama2") == "[INST] How are you [/INST]");
+    assert(fmt_single("gemma") == "<start_of_turn>user\nHow are you<end_of_turn>\n<start_of_turn>model\n");
+    assert(fmt_single("llama3") == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");
+
     return 0;
 }