]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : fallback to chatml, add AlphaMonarch chat template (#5628)
authorXuan Son Nguyen <redacted>
Thu, 22 Feb 2024 08:33:24 +0000 (09:33 +0100)
committerGitHub <redacted>
Thu, 22 Feb 2024 08:33:24 +0000 (10:33 +0200)
* server: fallback to chatml

* add new chat template

* server: add AlphaMonarch to test chat template

* server: only check model template if there is no custom tmpl

* remove TODO

examples/server/server.cpp
llama.cpp
tests/test-chat-template.cpp

index c84719a0d15d0739f450c21614f643010c02c333..369121e885b27e4f7c98f6eff960f1d53f106093 100644 (file)
@@ -400,6 +400,16 @@ struct llama_server_context
         return true;
     }
 
+    void validate_model_chat_template(server_params & sparams) {
+        llama_chat_message chat[] = {{"user", "test"}};
+        std::vector<char> buf(1);
+        int res = llama_chat_apply_template(model, nullptr, chat, 1, true, buf.data(), buf.size());
+        if (res < 0) {
+            LOG_ERROR("The chat template comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
+            sparams.chat_template = "<|im_start|>"; // llama_chat_apply_template only checks if <|im_start|> exist in the template
+        }
+    }
+
     void initialize() {
         // create slots
         all_slots_are_idle = true;
@@ -2752,6 +2762,11 @@ int main(int argc, char **argv)
         LOG_INFO("model loaded", {});
     }
 
+    if (sparams.chat_template.empty()) { // custom chat template is not supplied
+        // check if the template comes with the model is supported by us
+        llama.validate_model_chat_template(sparams);
+    }
+
     // Middleware for API key validation
     auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool {
         // If API key is not set, skip validation
index 9cae8c761f3acec8c3875c9f6c235f061db984ca..055b57e3187f2c84d48e131d2e2bf3b642f2ebe2 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -12773,6 +12773,15 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|assistant|>\n";
         }
+    } else if (tmpl.find("bos_token + message['role']") != std::string::npos) {
+        // mlabonne/AlphaMonarch-7B template (the <s> is included inside history)
+        for (auto message : chat) {
+            std::string bos = (message == chat.front()) ? "" : "<s>"; // skip BOS for first message
+            ss << bos << message->role << "\n" << message->content << "</s>\n";
+        }
+        if (add_ass) {
+            ss << "<s>assistant\n";
+        }
     } else {
         // template not supported
         return -1;
index 9830650d4f8dda593079627d77620294ad1e87de..d02b39e1449479afca74240fa86dd2fda19695e6 100644 (file)
@@ -27,12 +27,20 @@ int main(void) {
         "{%- for idx in range(0, messages|length) -%}\\n{%- if messages[idx]['role'] == 'user' -%}\\n{%- if idx > 1 -%}\\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\\n{%- else -%}\\n{{- messages[idx]['content'] + ' [/INST]' -}}\\n{%- endif -%}\\n{% elif messages[idx]['role'] == 'system' %}\\n{{- '[INST] <<SYS>>\\\\n' + messages[idx]['content'] + '\\\\n<</SYS>>\\\\n\\\\n' -}}\\n{%- elif messages[idx]['role'] == 'assistant' -%}\\n{{- ' '  + messages[idx]['content'] + ' ' + eos_token -}}\\n{% endif %}\\n{% endfor %}",
         // bofenghuang/vigogne-2-70b-chat
         "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez autant que vous le pouvez.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\\\n' + system_message + '\\\\n<</SYS>>\\\\n\\\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\\\n' + content.strip() + '\\\\n<</SYS>>\\\\n\\\\n' }}{% elif message['role'] == 'assistant' %}{{ ' '  + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}",
+        // mlabonne/AlphaMonarch-7B
+        "{% for message in messages %}{{bos_token + message['role'] + '\\n' + message['content'] + eos_token + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ bos_token + 'assistant\\n' }}{% endif %}",
     };
-    std::vector<std::string> expected_substr = {
-        "<|im_start|>assistant\n   I am an assistant   <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant",
-        "[/INST]Hi there</s>[INST] Who are you [/INST]   I am an assistant   </s>[INST] Another question [/INST]",
-        "</s><s>[INST] Who are you [/INST]    I am an assistant    </s><s>[INST] Another question [/INST]",
-        "[/INST] Hi there </s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
+    std::vector<std::string> expected_output = {
+        // teknium/OpenHermes-2.5-Mistral-7B
+        "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\n   I am an assistant   <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant\n",
+        // mistralai/Mistral-7B-Instruct-v0.2
+        "[INST] You are a helpful assistant\nHello [/INST]Hi there</s>[INST] Who are you [/INST]   I am an assistant   </s>[INST] Another question [/INST]",
+        // TheBloke/FusionNet_34Bx2_MoE-AWQ
+        "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST] Hi there </s><s>[INST] Who are you [/INST]    I am an assistant    </s><s>[INST] Another question [/INST]",
+        // bofenghuang/vigogne-2-70b-chat
+        "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST] Hi there </s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
+        // mlabonne/AlphaMonarch-7B
+        "system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n   I am an assistant   </s>\n<s>user\nAnother question</s>\n<s>assistant\n",
     };
     std::vector<char> formatted_chat(1024);
     int32_t res;
@@ -43,7 +51,7 @@ int main(void) {
 
     for (size_t i = 0; i < templates.size(); i++) {
         std::string custom_template = templates[i];
-        std::string substr = expected_substr[i];
+        std::string expected = expected_output[i];
         formatted_chat.resize(1024);
         res = llama_chat_apply_template(
             nullptr,
@@ -57,8 +65,7 @@ int main(void) {
         formatted_chat.resize(res);
         std::string output(formatted_chat.data(), formatted_chat.size());
         std::cout << output << "\n-------------------------\n";
-        // expect the "formatted_chat" to contain pre-defined strings
-        assert(output.find(substr) != std::string::npos);
+        assert(output == expected);
     }
     return 0;
 }