]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
add alias for chat template (#5858)
authorXuan Son Nguyen <redacted>
Mon, 4 Mar 2024 11:22:08 +0000 (12:22 +0100)
committerGitHub <redacted>
Mon, 4 Mar 2024 11:22:08 +0000 (12:22 +0100)
examples/server/server.cpp
llama.cpp

index 0ca388f47db7bd2b86c01c6e60a788d1c95cac31..208edd571cb0ed5c459e49b75a794aa50be930ad 100644 (file)
@@ -413,7 +413,7 @@ struct llama_server_context
         int res = llama_chat_apply_template(model, nullptr, chat, 1, true, buf.data(), buf.size());
         if (res < 0) {
             LOG_ERROR("The chat template comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
-            sparams.chat_template = "<|im_start|>"; // llama_chat_apply_template only checks if <|im_start|> exist in the template
+            sparams.chat_template = "chatml";
         }
     }
 
index c1f015791e826e303bc8d39fd6a9b5f1b8a86125..de579d9e372b4e1e5d180e48c208d02e58370013 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -13282,7 +13282,7 @@ static int32_t llama_chat_apply_template_internal(
     std::string & dest, bool add_ass) {
     // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
     std::stringstream ss;
-    if (tmpl.find("<|im_start|>") != std::string::npos) {
+    if (tmpl == "chatml" || tmpl.find("<|im_start|>") != std::string::npos) {
         // chatml template
         for (auto message : chat) {
             ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n";
@@ -13290,7 +13290,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|im_start|>assistant\n";
         }
-    } else if (tmpl.find("[INST]") != std::string::npos) {
+    } else if (tmpl == "llama2" || tmpl.find("[INST]") != std::string::npos) {
         // llama2 template and its variants
         // [variant] support system message
         bool support_system_message = tmpl.find("<<SYS>>") != std::string::npos;
@@ -13325,7 +13325,7 @@ static int32_t llama_chat_apply_template_internal(
             }
         }
         // llama2 templates seem to not care about "add_generation_prompt"
-    } else if (tmpl.find("<|user|>") != std::string::npos) {
+    } else if (tmpl == "zephyr" || tmpl.find("<|user|>") != std::string::npos) {
         // zephyr template
         for (auto message : chat) {
             ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n";
@@ -13333,7 +13333,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|assistant|>\n";
         }
-    } else if (tmpl.find("bos_token + message['role']") != std::string::npos) {
+    } else if (tmpl == "monarch" || tmpl.find("bos_token + message['role']") != std::string::npos) {
         // mlabonne/AlphaMonarch-7B template (the <s> is included inside history)
         for (auto message : chat) {
             std::string bos = (message == chat.front()) ? "" : "<s>"; // skip BOS for first message
@@ -13342,7 +13342,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<s>assistant\n";
         }
-    } else if (tmpl.find("<start_of_turn>") != std::string::npos) {
+    } else if (tmpl == "gemma" || tmpl.find("<start_of_turn>") != std::string::npos) {
         // google/gemma-7b-it
         std::string system_prompt = "";
         for (auto message : chat) {
@@ -13389,7 +13389,7 @@ LLAMA_API int32_t llama_chat_apply_template(
         int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
         if (res < 0) {
             // worst case: there is no information about template, we will use chatml by default
-            curr_tmpl = "<|im_start|>"; // see llama_chat_apply_template_internal
+            curr_tmpl = "chatml"; // see llama_chat_apply_template_internal
         } else {
             curr_tmpl = std::string(model_template.data(), model_template.size());
         }