]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Add chatml fallback for cpp `llama_chat_apply_template` (#8160)
authorXuan Son Nguyen <redacted>
Thu, 27 Jun 2024 16:14:19 +0000 (18:14 +0200)
committerGitHub <redacted>
Thu, 27 Jun 2024 16:14:19 +0000 (18:14 +0200)
* add chatml fallback for cpp `llama_chat_apply_template`

* remove redundant code

common/common.cpp
common/common.h

index 70349ad70891c7154d1a9f89bc991657537eb69c..57d03a5789edd614c5fa66f6207704b7b9887397 100644 (file)
@@ -2618,6 +2618,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
         const std::vector<llama_chat_msg> & msgs,
         bool add_ass) {
     int alloc_size = 0;
+    bool fallback = false; // indicate if we must fallback to default chatml
     std::vector<llama_chat_message> chat;
     for (auto & msg : msgs) {
         chat.push_back({msg.role.c_str(), msg.content.c_str()});
@@ -2630,10 +2631,26 @@ std::string llama_chat_apply_template(const struct llama_model * model,
     // run the first time to get the total output length
     int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
 
+    // error: chat template is not supported
+    if (res < 0) {
+        if (ptr_tmpl != nullptr) {
+            // if the custom "tmpl" is not supported, we throw an error
+            // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
+            throw std::runtime_error("this custom template is not supported");
+        } else {
+            // If the built-in template is not supported, we default to chatml
+            res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+            fallback = true;
+        }
+    }
+
     // if it turns out that our buffer is too small, we resize it
     if ((size_t) res > buf.size()) {
         buf.resize(res);
-        res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+        res = llama_chat_apply_template(
+            fallback ? nullptr : model,
+            fallback ? "chatml" : ptr_tmpl,
+            chat.data(), chat.size(), add_ass, buf.data(), buf.size());
     }
 
     std::string formatted_chat(buf.data(), res);
index c541204f6743b4c541a1e706a85cd9d31e9ac994..0486ba3800ed754f95b5b89a393ae95dcdeb7a10 100644 (file)
@@ -380,6 +380,8 @@ struct llama_chat_msg {
 bool llama_chat_verify_template(const std::string & tmpl);
 
 // CPP wrapper for llama_chat_apply_template
+// If the built-in template is not supported, we default to chatml
+// If the custom "tmpl" is not supported, we throw an error
 std::string llama_chat_apply_template(const struct llama_model * model,
         const std::string & tmpl,
         const std::vector<llama_chat_msg> & chat,