]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : add llama2 chat template (#5425)
authorXuan Son Nguyen <redacted>
Sun, 11 Feb 2024 10:16:22 +0000 (11:16 +0100)
committerGitHub <redacted>
Sun, 11 Feb 2024 10:16:22 +0000 (12:16 +0200)
* server: add mistral chat template

* server: fix typo

* server: rename template mistral to llama2

* server: format_llama2: remove BOS

* server: validate "--chat-template" argument

* server: clean up using_chatml variable

Co-authored-by: Jared Van Bortel <redacted>
---------

Co-authored-by: Jared Van Bortel <redacted>
examples/server/oai.hpp
examples/server/server.cpp
examples/server/utils.hpp

index 43410f803d469326f5098dd6be4fc45cde155150..2eca8a9fb45600ffe033c7ca025a933737af3469 100644 (file)
 using json = nlohmann::json;
 
 inline static json oaicompat_completion_params_parse(
-    const json &body /* openai api json semantics */)
+    const json &body, /* openai api json semantics */
+    const std::string &chat_template)
 {
     json llama_params;
+    std::string formatted_prompt = chat_template == "chatml"
+        ? format_chatml(body["messages"])  // OpenAI 'messages' to chatml (with <|im_start|>,...)
+        : format_llama2(body["messages"]); // OpenAI 'messages' to llama2 (with [INST],...)
 
     llama_params["__oaicompat"] = true;
 
@@ -30,7 +34,7 @@ inline static json oaicompat_completion_params_parse(
     // https://platform.openai.com/docs/api-reference/chat/create
     llama_sampling_params default_sparams;
     llama_params["model"]             = json_value(body, "model", std::string("unknown"));
-    llama_params["prompt"]            = format_chatml(body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt'
+    llama_params["prompt"]            = formatted_prompt;
     llama_params["cache_prompt"]      = json_value(body, "cache_prompt", false);
     llama_params["temperature"]       = json_value(body, "temperature", 0.0);
     llama_params["top_k"]             = json_value(body, "top_k", default_sparams.top_k);
index 8d668f7989eedcb066f7fca88d4a82480ee2de76..4d212f1f0e65cafcba10a65df38443a7e2ac6123 100644 (file)
@@ -36,6 +36,7 @@ struct server_params
     std::string hostname = "127.0.0.1";
     std::vector<std::string> api_keys;
     std::string public_path = "examples/server/public";
+    std::string chat_template = "chatml";
     int32_t port = 8080;
     int32_t read_timeout = 600;
     int32_t write_timeout = 600;
@@ -1859,6 +1860,8 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     printf("                            types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
     printf("  -gan N, --grp-attn-n N    set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`");
     printf("  -gaw N, --grp-attn-w N    set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`");
+    printf("  --chat-template FORMAT_NAME");
+    printf("                            set chat template, possible valus is: llama2, chatml (default %s)", sparams.chat_template.c_str());
     printf("\n");
 }
 
@@ -2290,6 +2293,21 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
             log_set_target(stdout);
             LOG_INFO("logging to file is disabled.", {});
         }
+        else if (arg == "--chat-template")
+        {
+            if (++i >= argc)
+            {
+                invalid_param = true;
+                break;
+            }
+            std::string value(argv[i]);
+            if (value != "chatml" && value != "llama2") {
+                fprintf(stderr, "error: chat template can be \"llama2\" or \"chatml\", but got: %s\n", value.c_str());
+                invalid_param = true;
+                break;
+            }
+            sparams.chat_template = value;
+        }
         else if (arg == "--override-kv")
         {
             if (++i >= argc) {
@@ -2743,13 +2761,13 @@ int main(int argc, char **argv)
 
 
     // TODO: add mount point without "/v1" prefix -- how?
-    svr.Post("/v1/chat/completions", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
+    svr.Post("/v1/chat/completions", [&llama, &validate_api_key, &sparams](const httplib::Request &req, httplib::Response &res)
             {
                 res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
                 if (!validate_api_key(req, res)) {
                     return;
                 }
-                json data = oaicompat_completion_params_parse(json::parse(req.body));
+                json data = oaicompat_completion_params_parse(json::parse(req.body), sparams.chat_template);
 
                 const int task_id = llama.queue_tasks.get_new_id();
                 llama.queue_results.add_waiting_task_id(task_id);
index 70cce0721be08234a677f9f8dad7b475437a1a39..5485489627d5d073adc448270064de1a53b33fe8 100644 (file)
@@ -167,6 +167,34 @@ static T json_value(const json &body, const std::string &key, const T &default_v
         : default_value;
 }
 
+inline std::string format_llama2(std::vector<json> messages)
+{
+    std::ostringstream output;
+    bool is_inside_turn = false;
+
+    for (auto it = messages.begin(); it != messages.end(); ++it) {
+        if (!is_inside_turn) {
+            output << "[INST] ";
+        }
+        std::string role    = json_value(*it, "role", std::string("user"));
+        std::string content = json_value(*it, "content", std::string(""));
+        if (role == "system") {
+            output << "<<SYS>>\n" << content << "\n<<SYS>>\n\n";
+            is_inside_turn = true;
+        } else if (role == "user") {
+            output << content << " [/INST]";
+            is_inside_turn = true;
+        } else {
+            output << " " << content << " </s>";
+            is_inside_turn = false;
+        }
+    }
+
+    LOG_VERBOSE("format_llama2", {{"text", output.str()}});
+
+    return output.str();
+}
+
 inline std::string format_chatml(std::vector<json> messages)
 {
     std::ostringstream chatml_msgs;
@@ -180,6 +208,8 @@ inline std::string format_chatml(std::vector<json> messages)
 
     chatml_msgs << "<|im_start|>assistant" << '\n';
 
+    LOG_VERBOSE("format_chatml", {{"text", chatml_msgs.str()}});
+
     return chatml_msgs.str();
 }