]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : add llama_chat_apply_template() (#5538)
authorXuan Son Nguyen <redacted>
Mon, 19 Feb 2024 08:23:37 +0000 (09:23 +0100)
committerGitHub <redacted>
Mon, 19 Feb 2024 08:23:37 +0000 (10:23 +0200)
* llama: add llama_chat_apply_template

* test-chat-template: remove dedundant vector

* chat_template: do not use std::string for buffer

* add clarification for llama_chat_apply_template

* llama_chat_apply_template: add zephyr template

* llama_chat_apply_template: correct docs

* llama_chat_apply_template: use term "chat" everywhere

* llama_chat_apply_template: change variable name to "tmpl"

Makefile
llama.cpp
llama.h
tests/CMakeLists.txt
tests/test-chat-template.cpp [new file with mode: 0644]

index f5f6d32a7673584fe1ccf48482b52edff00b2f48..59352eb5325975861ca9530050ce8d8b5ad0e680 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -867,3 +867,7 @@ tests/test-model-load-cancel: tests/test-model-load-cancel.cpp ggml.o llama.o te
 tests/test-autorelease: tests/test-autorelease.cpp ggml.o llama.o tests/get-model.cpp $(COMMON_DEPS) $(OBJS)
        $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
        $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
+
+tests/test-chat-template: tests/test-chat-template.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
+       $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
+       $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
index 5cfebb3b1c99c685d92af33952b0e3776db58ad7..143870645d6371012f0beee1b0e252400a2d028f 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -12508,6 +12508,123 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
     return 0;
 }
 
+// trim whitespace from the beginning and end of a string
+static std::string trim(const std::string & str) {
+    size_t start = 0;
+    size_t end = str.size();
+    while (start < end && isspace(str[start])) {
+        start += 1;
+    }
+    while (end > start && isspace(str[end - 1])) {
+        end -= 1;
+    }
+    return str.substr(start, end - start);
+}
+
+// Simple version of "llama_apply_chat_template" that only works with strings
+// This function uses heuristic checks to determine commonly used template. It is not a jinja parser.
+static int32_t llama_chat_apply_template_internal(
+    const std::string & tmpl,
+    const std::vector<const llama_chat_message *> & chat, 
+    std::string & dest, bool add_ass) {
+    // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
+    std::stringstream ss;
+    if (tmpl.find("<|im_start|>") != std::string::npos) {
+        // chatml template
+        for (auto message : chat) {
+            ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n";
+        }
+        if (add_ass) {
+            ss << "<|im_start|>assistant\n";
+        }
+    } else if (tmpl.find("[INST]") != std::string::npos) {
+        // llama2 template and its variants
+        // [variant] support system message
+        bool support_system_message = tmpl.find("<<SYS>>") != std::string::npos;
+        // [variant] space before + after response
+        bool space_around_response = tmpl.find("' ' + eos_token") != std::string::npos;
+        // [variant] add BOS inside history
+        bool add_bos_inside_history = tmpl.find("bos_token + '[INST]") != std::string::npos;
+        // [variant] trim spaces from the input message
+        bool strip_message = tmpl.find("content.strip()") != std::string::npos;
+        // construct the prompt
+        bool is_inside_turn = true; // skip BOS at the beginning
+        ss << "[INST] ";
+        for (auto message : chat) {
+            std::string content = strip_message ? trim(message->content) : message->content;
+            std::string role(message->role);
+            if (!is_inside_turn) {
+                is_inside_turn = true;
+                ss << (add_bos_inside_history ? "<s>[INST] " : "[INST] ");
+            }
+            if (role == "system") {
+                if (support_system_message) {
+                    ss << "<<SYS>>\n" << content << "\n<</SYS>>\n\n";
+                } else {
+                    // if the model does not support system message, we still include it in the first message, but without <<SYS>>
+                    ss << content << "\n";
+                }
+            } else if (role == "user") {
+                ss << content << " [/INST]";
+            } else {
+                ss << (space_around_response ? " " : "") << content << (space_around_response ? " " : "") << "</s>";
+                is_inside_turn = false;
+            }
+        }
+        // llama2 templates seem to not care about "add_generation_prompt"
+    } else if (tmpl.find("<|user|>") != std::string::npos) {
+        // zephyr template
+        for (auto message : chat) {
+            ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n";
+        }
+        if (add_ass) {
+            ss << "<|assistant|>\n";
+        }
+    } else {
+        // template not supported
+        return -1;
+    }
+    dest = ss.str();
+    return dest.size();
+}
+
+LLAMA_API int32_t llama_chat_apply_template(
+                const struct llama_model * model,
+                              const char * tmpl,
+         const struct llama_chat_message * chat,
+                                  size_t   n_msg,
+                                    bool   add_ass,
+                                    char * buf,
+                                 int32_t   length) {
+    std::string curr_tmpl(tmpl == nullptr ? "" : tmpl);
+    if (tmpl == nullptr) {
+        GGML_ASSERT(model != nullptr);
+        // load template from model
+        std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
+        std::string template_key = "tokenizer.chat_template";
+        int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), curr_tmpl.size());
+        if (res < 0) {
+            // worst case: there is no information about template, we will use chatml by default
+            curr_tmpl = "<|im_start|>"; // see llama_chat_apply_template_internal
+        } else {
+            curr_tmpl = std::string(model_template.data(), model_template.size());
+        }
+    }
+    // format the chat to string
+    std::vector<const llama_chat_message *> chat_vec;
+    chat_vec.resize(n_msg);
+    for (size_t i = 0; i < n_msg; i++) {
+        chat_vec[i] = &chat[i];
+    }
+    std::string formatted_chat;
+    int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass);
+    if (res < 0) {
+        return res;
+    }
+    strncpy(buf, formatted_chat.c_str(), length);
+    return res;
+}
+
 struct llama_timings llama_get_timings(struct llama_context * ctx) {
     struct llama_timings result = {
         /*.t_start_ms  =*/ 1e-3 * ctx->t_start_us,
diff --git a/llama.h b/llama.h
index 5a97abcc94d7cb37b7e8140a189d1f64ebfca26f..77a84c18a69cfbc56841683ac0de3e9e7cc9d5f7 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -305,6 +305,12 @@ extern "C" {
         int32_t n_eval;
     };
 
+    // used in chat template
+    typedef struct llama_chat_message {
+        const char * role;
+        const char * content;
+    } llama_chat_message;
+
     // Helpers for getting default parameters
     LLAMA_API struct llama_model_params llama_model_default_params(void);
     LLAMA_API struct llama_context_params llama_context_default_params(void);
@@ -699,6 +705,25 @@ extern "C" {
                                   char * buf,
                                int32_t   length);
 
+    /// Apply chat template. Inspired by hf apply_chat_template() on python.
+    /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
+    /// NOTE: This function only support some known jinja templates. It is not a jinja parser.
+    /// @param tmpl A Jinja template to use for this chat. If this is nullptr, the model’s default chat template will be used instead.
+    /// @param chat Pointer to a list of multiple llama_chat_message
+    /// @param n_msg Number of llama_chat_message in this chat
+    /// @param add_ass Whether to end the prompt with the token(s) that indicate the start of an assistant message.
+    /// @param buf A buffer to hold the output formatted prompt. The recommended alloc size is 2 * (total number of characters of all messages)
+    /// @param length The size of the allocated buffer
+    /// @return The total number of bytes of the formatted prompt. If is it larger than the size of buffer, you may need to re-alloc it and then re-apply the template.
+    LLAMA_API int32_t llama_chat_apply_template(
+              const struct llama_model * model,
+                            const char * tmpl,
+       const struct llama_chat_message * chat,
+                                size_t   n_msg,
+                                  bool   add_ass,
+                                  char * buf,
+                               int32_t   length);
+
     //
     // Grammar
     //
index 3e40a78cdeac9ebad40487ad9c38e03ad2fcc013..10326d531e9fb7e7c296cdcd523a05215f50b5c5 100644 (file)
@@ -28,6 +28,7 @@ endfunction()
 llama_build_and_test_executable(test-quantize-fns.cpp)
 llama_build_and_test_executable(test-quantize-perf.cpp)
 llama_build_and_test_executable(test-sampling.cpp)
+llama_build_and_test_executable(test-chat-template.cpp)
 
 llama_build_executable(test-tokenizer-0-llama.cpp)
 llama_test_executable (test-tokenizer-0-llama test-tokenizer-0-llama.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf)
diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp
new file mode 100644 (file)
index 0000000..9830650
--- /dev/null
@@ -0,0 +1,64 @@
+#include <iostream>
+#include <string>
+#include <vector>
+#include <sstream>
+
+#undef NDEBUG
+#include <cassert>
+
+#include "llama.h"
+
+int main(void) {
+    llama_chat_message conversation[] = {
+        {"system", "You are a helpful assistant"},
+        {"user", "Hello"},
+        {"assistant", "Hi there"},
+        {"user", "Who are you"},
+        {"assistant", "   I am an assistant   "},
+        {"user", "Another question"},
+    };
+    size_t message_count = 6;
+    std::vector<std::string> templates = {
+        // teknium/OpenHermes-2.5-Mistral-7B
+        "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
+        // mistralai/Mistral-7B-Instruct-v0.2
+        "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
+        // TheBloke/FusionNet_34Bx2_MoE-AWQ
+        "{%- for idx in range(0, messages|length) -%}\\n{%- if messages[idx]['role'] == 'user' -%}\\n{%- if idx > 1 -%}\\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\\n{%- else -%}\\n{{- messages[idx]['content'] + ' [/INST]' -}}\\n{%- endif -%}\\n{% elif messages[idx]['role'] == 'system' %}\\n{{- '[INST] <<SYS>>\\\\n' + messages[idx]['content'] + '\\\\n<</SYS>>\\\\n\\\\n' -}}\\n{%- elif messages[idx]['role'] == 'assistant' -%}\\n{{- ' '  + messages[idx]['content'] + ' ' + eos_token -}}\\n{% endif %}\\n{% endfor %}",
+        // bofenghuang/vigogne-2-70b-chat
+        "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez autant que vous le pouvez.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\\\n' + system_message + '\\\\n<</SYS>>\\\\n\\\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\\\n' + content.strip() + '\\\\n<</SYS>>\\\\n\\\\n' }}{% elif message['role'] == 'assistant' %}{{ ' '  + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}",
+    };
+    std::vector<std::string> expected_substr = {
+        "<|im_start|>assistant\n   I am an assistant   <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant",
+        "[/INST]Hi there</s>[INST] Who are you [/INST]   I am an assistant   </s>[INST] Another question [/INST]",
+        "</s><s>[INST] Who are you [/INST]    I am an assistant    </s><s>[INST] Another question [/INST]",
+        "[/INST] Hi there </s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
+    };
+    std::vector<char> formatted_chat(1024);
+    int32_t res;
+
+    // test invalid chat template
+    res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", conversation, message_count, true, formatted_chat.data(), formatted_chat.size());
+    assert(res < 0);
+
+    for (size_t i = 0; i < templates.size(); i++) {
+        std::string custom_template = templates[i];
+        std::string substr = expected_substr[i];
+        formatted_chat.resize(1024);
+        res = llama_chat_apply_template(
+            nullptr,
+            custom_template.c_str(),
+            conversation,
+            message_count,
+            true,
+            formatted_chat.data(),
+            formatted_chat.size()
+        );
+        formatted_chat.resize(res);
+        std::string output(formatted_chat.data(), formatted_chat.size());
+        std::cout << output << "\n-------------------------\n";
+        // expect the "formatted_chat" to contain pre-defined strings
+        assert(output.find(substr) != std::string::npos);
+    }
+    return 0;
+}