]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : add enum for built-in chat templates (#10623)
authorXuan Son Nguyen <redacted>
Mon, 2 Dec 2024 21:10:19 +0000 (22:10 +0100)
committerGitHub <redacted>
Mon, 2 Dec 2024 21:10:19 +0000 (22:10 +0100)
* llama : add enum for supported chat templates

* use "built-in" instead of "supported"

* arg: print list of built-in templates

* fix test

* update server README

common/arg.cpp
examples/server/README.md
include/llama.h
src/llama.cpp
tests/test-chat-template.cpp

index 32d9a964c1716dbbabe2b3ca968a2c298bc23fcc..078c7538490c4eab9432caf7257f30271087cacd 100644 (file)
@@ -348,6 +348,18 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e
     return true;
 }
 
+static std::string list_builtin_chat_templates() {
+    std::vector<const char *> supported_tmpl;
+    int32_t res = llama_chat_builtin_templates(nullptr, 0);
+    supported_tmpl.resize(res);
+    res = llama_chat_builtin_templates(supported_tmpl.data(), supported_tmpl.size());
+    std::ostringstream msg;
+    for (auto & tmpl : supported_tmpl) {
+        msg << tmpl << (&tmpl == &supported_tmpl.back() ? "" : ", ");
+    }
+    return msg.str();
+}
+
 common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **)) {
     // load dynamic backends
     ggml_backend_load_all();
@@ -1814,9 +1826,11 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     ).set_examples({LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"--chat-template"}, "JINJA_TEMPLATE",
-        "set custom jinja chat template (default: template taken from model's metadata)\n"
-        "if suffix/prefix are specified, template will be disabled\n"
-        "only commonly used templates are accepted:\nhttps://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template",
+        string_format(
+            "set custom jinja chat template (default: template taken from model's metadata)\n"
+            "if suffix/prefix are specified, template will be disabled\n"
+            "list of built-in templates:\n%s", list_builtin_chat_templates().c_str()
+        ),
         [](common_params & params, const std::string & value) {
             if (!common_chat_verify_template(value)) {
                 throw std::runtime_error(string_format(
index aa99d06f957312cc7f544da69e0479b507a68c16..3f0d45e5bed1b91b68eb1e2ac708d584b383b865 100644 (file)
@@ -69,6 +69,8 @@ The project is under active development, and we are [looking for feedback and co
 | `--mlock` | force system to keep model in RAM rather than swapping or compressing<br/>(env: LLAMA_ARG_MLOCK) |
 | `--no-mmap` | do not memory-map model (slower load but may reduce pageouts if not using mlock)<br/>(env: LLAMA_ARG_NO_MMAP) |
 | `--numa TYPE` | attempt optimizations that help on some NUMA systems<br/>- distribute: spread execution evenly over all nodes<br/>- isolate: only spawn threads on CPUs on the node that execution started on<br/>- numactl: use the CPU map provided by numactl<br/>if run without this previously, it is recommended to drop the system page cache before using this<br/>see https://github.com/ggerganov/llama.cpp/issues/1437<br/>(env: LLAMA_ARG_NUMA) |
+| `-dev, --device <dev1,dev2,..>` | comma-separated list of devices to use for offloading (none = don't offload)<br/>use --list-devices to see a list of available devices<br/>(env: LLAMA_ARG_DEVICE) |
+| `--list-devices` | print list of available devices and exit |
 | `-ngl, --gpu-layers, --n-gpu-layers N` | number of layers to store in VRAM<br/>(env: LLAMA_ARG_N_GPU_LAYERS) |
 | `-sm, --split-mode {none,layer,row}` | how to split the model across multiple GPUs, one of:<br/>- none: use one GPU only<br/>- layer (default): split layers and KV across GPUs<br/>- row: split rows across GPUs<br/>(env: LLAMA_ARG_SPLIT_MODE) |
 | `-ts, --tensor-split N0,N1,N2,...` | fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1<br/>(env: LLAMA_ARG_TENSOR_SPLIT) |
@@ -158,9 +160,16 @@ The project is under active development, and we are [looking for feedback and co
 | `--props` | enable changing global properties via POST /props (default: disabled)<br/>(env: LLAMA_ARG_ENDPOINT_PROPS) |
 | `--no-slots` | disables slots monitoring endpoint<br/>(env: LLAMA_ARG_NO_ENDPOINT_SLOTS) |
 | `--slot-save-path PATH` | path to save slot kv cache (default: disabled) |
-| `--chat-template JINJA_TEMPLATE` | set custom jinja chat template (default: template taken from model's metadata)<br/>if suffix/prefix are specified, template will be disabled<br/>only commonly used templates are accepted:<br/>https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template<br/>(env: LLAMA_ARG_CHAT_TEMPLATE) |
+| `--chat-template JINJA_TEMPLATE` | set custom jinja chat template (default: template taken from model's metadata)<br/>if suffix/prefix are specified, template will be disabled<br/>list of built-in templates:<br/>chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, exaone3, gemma, granite, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, monarch, openchat, orion, phi3, rwkv-world, vicuna, vicuna-orca, zephyr<br/>(env: LLAMA_ARG_CHAT_TEMPLATE) |
 | `-sps, --slot-prompt-similarity SIMILARITY` | how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.50, 0.0 = disabled)<br/> |
 | `--lora-init-without-apply` | load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: disabled) |
+| `--draft-max, --draft, --draft-n N` | number of tokens to draft for speculative decoding (default: 16) |
+| `--draft-min, --draft-n-min N` | minimum number of draft tokens to use for speculative decoding (default: 5) |
+| `--draft-p-min P` | minimum speculative decoding probability (greedy) (default: 0.9) |
+| `-cd, --ctx-size-draft N` | size of the prompt context for the draft model (default: 0, 0 = loaded from model) |
+| `-devd, --device-draft <dev1,dev2,..>` | comma-separated list of devices to use for offloading the draft model (none = don't offload)<br/>use --list-devices to see a list of available devices |
+| `-ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | number of layers to store in VRAM for the draft model |
+| `-md, --model-draft FNAME` | draft model for speculative decoding (default: unused) |
 
 
 Note: If both command line argument and environment variable are both set for the same param, the argument will take precedence over env var.
index ab5e376e6c7f21f025a4bdce61690a3bc826aa36..439e0ff0c7e0c8696035373bc577ef42fa0a228a 100644 (file)
@@ -990,6 +990,9 @@ extern "C" {
                                   char * buf,
                                int32_t   length);
 
+    // Get list of built-in chat templates
+    int32_t llama_chat_builtin_templates(const char ** output, size_t len);
+
     //
     // Sampling API
     //
index 6e9ba97272287711018751fee4533921fe4838e3..6a6f4c2a5eb7eba39fa8403cf653e8aba7e2d8c7 100644 (file)
@@ -1549,6 +1549,67 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
     },
 };
 
+enum llm_chat_template {
+    LLM_CHAT_TEMPLATE_CHATML,
+    LLM_CHAT_TEMPLATE_LLAMA_2,
+    LLM_CHAT_TEMPLATE_LLAMA_2_SYS,
+    LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS,
+    LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP,
+    LLM_CHAT_TEMPLATE_MISTRAL_V1,
+    LLM_CHAT_TEMPLATE_MISTRAL_V3,
+    LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN,
+    LLM_CHAT_TEMPLATE_MISTRAL_V7,
+    LLM_CHAT_TEMPLATE_PHI_3,
+    LLM_CHAT_TEMPLATE_ZEPHYR,
+    LLM_CHAT_TEMPLATE_MONARCH,
+    LLM_CHAT_TEMPLATE_GEMMA,
+    LLM_CHAT_TEMPLATE_ORION,
+    LLM_CHAT_TEMPLATE_OPENCHAT,
+    LLM_CHAT_TEMPLATE_VICUNA,
+    LLM_CHAT_TEMPLATE_VICUNA_ORCA,
+    LLM_CHAT_TEMPLATE_DEEPSEEK,
+    LLM_CHAT_TEMPLATE_DEEPSEEK_2,
+    LLM_CHAT_TEMPLATE_COMMAND_R,
+    LLM_CHAT_TEMPLATE_LLAMA_3,
+    LLM_CHAT_TEMPLATE_CHATGML_3,
+    LLM_CHAT_TEMPLATE_CHATGML_4,
+    LLM_CHAT_TEMPLATE_MINICPM,
+    LLM_CHAT_TEMPLATE_EXAONE_3,
+    LLM_CHAT_TEMPLATE_RWKV_WORLD,
+    LLM_CHAT_TEMPLATE_GRANITE,
+    LLM_CHAT_TEMPLATE_UNKNOWN,
+};
+
+static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
+    { "chatml",            LLM_CHAT_TEMPLATE_CHATML            },
+    { "llama2",            LLM_CHAT_TEMPLATE_LLAMA_2           },
+    { "llama2-sys",        LLM_CHAT_TEMPLATE_LLAMA_2_SYS       },
+    { "llama2-sys-bos",    LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS   },
+    { "llama2-sys-strip",  LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP },
+    { "mistral-v1",        LLM_CHAT_TEMPLATE_MISTRAL_V1        },
+    { "mistral-v3",        LLM_CHAT_TEMPLATE_MISTRAL_V3        },
+    { "mistral-v3-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN },
+    { "mistral-v7",        LLM_CHAT_TEMPLATE_MISTRAL_V7        },
+    { "phi3",              LLM_CHAT_TEMPLATE_PHI_3             },
+    { "zephyr",            LLM_CHAT_TEMPLATE_ZEPHYR            },
+    { "monarch",           LLM_CHAT_TEMPLATE_MONARCH           },
+    { "gemma",             LLM_CHAT_TEMPLATE_GEMMA             },
+    { "orion",             LLM_CHAT_TEMPLATE_ORION             },
+    { "openchat",          LLM_CHAT_TEMPLATE_OPENCHAT          },
+    { "vicuna",            LLM_CHAT_TEMPLATE_VICUNA            },
+    { "vicuna-orca",       LLM_CHAT_TEMPLATE_VICUNA_ORCA       },
+    { "deepseek",          LLM_CHAT_TEMPLATE_DEEPSEEK          },
+    { "deepseek2",         LLM_CHAT_TEMPLATE_DEEPSEEK_2        },
+    { "command-r",         LLM_CHAT_TEMPLATE_COMMAND_R         },
+    { "llama3",            LLM_CHAT_TEMPLATE_LLAMA_3           },
+    { "chatglm3",          LLM_CHAT_TEMPLATE_CHATGML_3         },
+    { "chatglm4",          LLM_CHAT_TEMPLATE_CHATGML_4         },
+    { "minicpm",           LLM_CHAT_TEMPLATE_MINICPM           },
+    { "exaone3",           LLM_CHAT_TEMPLATE_EXAONE_3          },
+    { "rwkv-world",        LLM_CHAT_TEMPLATE_RWKV_WORLD        },
+    { "granite",           LLM_CHAT_TEMPLATE_GRANITE           },
+};
+
 static llm_arch llm_arch_from_string(const std::string & name) {
     for (const auto & kv : LLM_ARCH_NAMES) { // NOLINT
         if (kv.second == name) {
@@ -21843,18 +21904,109 @@ int32_t llama_detokenize(
 // chat templates
 //
 
+static llm_chat_template llama_chat_detect_template(const std::string & tmpl) {
+    if (LLM_CHAT_TEMPLATES.find(tmpl) != LLM_CHAT_TEMPLATES.end()) {
+        return LLM_CHAT_TEMPLATES.at(tmpl);
+    }
+    auto tmpl_contains = [&tmpl](const char * haystack) -> bool {
+        return tmpl.find(haystack) != std::string::npos;
+    };
+    if (tmpl_contains("<|im_start|>")) {
+        return LLM_CHAT_TEMPLATE_CHATML;
+    } else if (tmpl.find("mistral") == 0 || tmpl_contains("[INST]")) {
+        if (tmpl_contains("[SYSTEM_PROMPT]")) {
+            return LLM_CHAT_TEMPLATE_MISTRAL_V7;
+        } else if (
+            // catches official 'v1' template
+            tmpl_contains("' [INST] ' + system_message")
+            // catches official 'v3' and 'v3-tekken' templates
+            || tmpl_contains("[AVAILABLE_TOOLS]")
+        ) {
+            // Official mistral 'v1', 'v3' and 'v3-tekken' templates
+            // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/chat_templates.md
+            // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/templates.md
+            if (tmpl_contains(" [INST]")) {
+                return LLM_CHAT_TEMPLATE_MISTRAL_V1;
+            } else if (tmpl_contains("\"[INST]\"")) {
+                return LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN;
+            }
+            return LLM_CHAT_TEMPLATE_MISTRAL_V3;
+        } else {
+            // llama2 template and its variants
+            // [variant] support system message
+            // See: https://huggingface.co/blog/llama2#how-to-prompt-llama-2
+            bool support_system_message = tmpl_contains("<<SYS>>");
+            bool add_bos_inside_history = tmpl_contains("bos_token + '[INST]");
+            bool strip_message = tmpl_contains("content.strip()");
+            if (strip_message) {
+                return LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP;
+            } else if (add_bos_inside_history) {
+                return LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS;
+            } else if (support_system_message) {
+                return LLM_CHAT_TEMPLATE_LLAMA_2_SYS;
+            } else {
+                return LLM_CHAT_TEMPLATE_LLAMA_2;
+            }
+        }
+    } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>")) {
+        return LLM_CHAT_TEMPLATE_PHI_3;
+    } else if (tmpl_contains("<|user|>") && tmpl_contains("<|endoftext|>")) {
+        return LLM_CHAT_TEMPLATE_ZEPHYR;
+    } else if (tmpl_contains("bos_token + message['role']")) {
+        return LLM_CHAT_TEMPLATE_MONARCH;
+    } else if (tmpl_contains("<start_of_turn>")) {
+        return LLM_CHAT_TEMPLATE_GEMMA;
+    } else if (tmpl_contains("'\\n\\nAssistant: ' + eos_token")) {
+        // OrionStarAI/Orion-14B-Chat
+        return LLM_CHAT_TEMPLATE_ORION;
+    } else if (tmpl_contains("GPT4 Correct ")) {
+        // openchat/openchat-3.5-0106
+        return LLM_CHAT_TEMPLATE_OPENCHAT;
+    } else if (tmpl_contains("USER: ") && tmpl_contains("ASSISTANT: ")) {
+        // eachadea/vicuna-13b-1.1 (and Orca variant)
+        if (tmpl_contains("SYSTEM: ")) {
+            return LLM_CHAT_TEMPLATE_VICUNA_ORCA;
+        }
+        return LLM_CHAT_TEMPLATE_VICUNA;
+    } else if (tmpl_contains("### Instruction:") && tmpl_contains("<|EOT|>")) {
+        // deepseek-ai/deepseek-coder-33b-instruct
+        return LLM_CHAT_TEMPLATE_DEEPSEEK;
+    } else if (tmpl_contains("<|START_OF_TURN_TOKEN|>") && tmpl_contains("<|USER_TOKEN|>")) {
+        // CohereForAI/c4ai-command-r-plus
+        return LLM_CHAT_TEMPLATE_COMMAND_R;
+    } else if (tmpl_contains("<|start_header_id|>") && tmpl_contains("<|end_header_id|>")) {
+        return LLM_CHAT_TEMPLATE_LLAMA_3;
+    } else if (tmpl_contains("[gMASK]sop")) {
+        // chatglm3-6b
+        return LLM_CHAT_TEMPLATE_CHATGML_3;
+    } else if (tmpl_contains("[gMASK]<sop>")) {
+        return LLM_CHAT_TEMPLATE_CHATGML_4;
+    } else if (tmpl_contains(LU8("<用户>"))) {
+        // MiniCPM-3B-OpenHermes-2.5-v2-GGUF
+        return LLM_CHAT_TEMPLATE_MINICPM;
+    } else if (tmpl_contains("'Assistant: ' + message['content'] + eos_token")) {
+        return LLM_CHAT_TEMPLATE_DEEPSEEK_2;
+    } else if (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]")) {
+        // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb
+        // EXAONE-3.0-7.8B-Instruct
+        return LLM_CHAT_TEMPLATE_EXAONE_3;
+    } else if (tmpl_contains("rwkv-world")) {
+        return LLM_CHAT_TEMPLATE_RWKV_WORLD;
+    } else if (tmpl_contains("<|start_of_role|>")) {
+        return LLM_CHAT_TEMPLATE_GRANITE;
+    }
+    return LLM_CHAT_TEMPLATE_UNKNOWN;
+}
+
 // Simple version of "llama_apply_chat_template" that only works with strings
 // This function uses heuristic checks to determine commonly used template. It is not a jinja parser.
 static int32_t llama_chat_apply_template_internal(
-    const std::string & tmpl,
+    const llm_chat_template tmpl,
     const std::vector<const llama_chat_message *> & chat,
     std::string & dest, bool add_ass) {
     // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
     std::stringstream ss;
-    auto tmpl_contains = [&tmpl](std::string haystack) -> bool {
-        return tmpl.find(haystack) != std::string::npos;
-    };
-    if (tmpl == "chatml" || tmpl_contains("<|im_start|>")) {
+    if (tmpl == LLM_CHAT_TEMPLATE_CHATML) {
         // chatml template
         for (auto message : chat) {
             ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n";
@@ -21862,86 +22014,84 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|im_start|>assistant\n";
         }
-    } else if (tmpl == "llama2" || tmpl.find("mistral") == 0 || tmpl_contains("[INST]")) {
-        if (tmpl == "mistral-v7" || tmpl_contains("[SYSTEM_PROMPT]")) {
-            // Official mistral 'v7' template
-            // See: https://huggingface.co/mistralai/Mistral-Large-Instruct-2411#basic-instruct-template-v7
-            for (auto message : chat) {
-                std::string role(message->role);
-                std::string content(message->content);
-                if (role == "system") {
-                    ss << "[SYSTEM_PROMPT] " << content << "[/SYSTEM_PROMPT]";
-                } else if (role == "user") {
-                    ss << "[INST] " << content << "[/INST]";
-                }
-                else {
-                    ss << " " << content << "</s>";
-                }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7) {
+        // Official mistral 'v7' template
+        // See: https://huggingface.co/mistralai/Mistral-Large-Instruct-2411#basic-instruct-template-v7
+        for (auto message : chat) {
+            std::string role(message->role);
+            std::string content(message->content);
+            if (role == "system") {
+                ss << "[SYSTEM_PROMPT] " << content << "[/SYSTEM_PROMPT]";
+            } else if (role == "user") {
+                ss << "[INST] " << content << "[/INST]";
             }
-        } else if (tmpl == "mistral-v1" || tmpl == "mistral-v3" || tmpl == "mistral-v3-tekken"
-                   || tmpl_contains("' [INST] ' + system_message") // catches official 'v1' template
-                   || tmpl_contains("[AVAILABLE_TOOLS]")) {        // catches official 'v3' and 'v3-tekken' templates
-            // Official mistral 'v1', 'v3' and 'v3-tekken' templates
-            // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/chat_templates.md
-            // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/templates.md
-            std::string leading_space = (tmpl == "mistral-v1" || tmpl_contains(" [INST]") ? " " : "");
-            std::string trailing_space = (tmpl == "mistral-v3-tekken" || tmpl_contains("\"[INST]\"") ? "" : " ");
-            bool trim_assistant_message = tmpl_contains("|trim + eos_token");
-            bool is_inside_turn = false;
-            for (auto message : chat) {
-                if (!is_inside_turn) {
-                    ss << leading_space << "[INST]" << trailing_space;
-                    is_inside_turn = true;
-                }
-                std::string role(message->role);
-                std::string content(message->content);
-                if (role == "system") {
-                    ss << content << "\n\n";
-                } else if (role == "user") {
-                    ss << content << leading_space << "[/INST]";
-                } else {
-                    ss << trailing_space << (trim_assistant_message ? trim(content) : content) << "</s>";
-                    is_inside_turn = false;
-                }
+            else {
+                ss << " " << content << "</s>";
             }
-        } else {
-            // llama2 template and its variants
-            // [variant] support system message
-            // See: https://huggingface.co/blog/llama2#how-to-prompt-llama-2
-            bool support_system_message = tmpl_contains("<<SYS>>") || tmpl == "llama2";
-            // [variant] space before + after response
-            bool space_around_response = tmpl_contains("' ' + eos_token");
-            // [variant] add BOS inside history
-            bool add_bos_inside_history = tmpl_contains("bos_token + '[INST]");
-            // [variant] trim spaces from the input message
-            bool strip_message = tmpl_contains("content.strip()");
-            // construct the prompt
-            bool is_inside_turn = true; // skip BOS at the beginning
-            ss << "[INST] ";
-            for (auto message : chat) {
-                std::string content = strip_message ? trim(message->content) : message->content;
-                std::string role(message->role);
-                if (!is_inside_turn) {
-                    is_inside_turn = true;
-                    ss << (add_bos_inside_history ? "<s>[INST] " : "[INST] ");
-                }
-                if (role == "system") {
-                    if (support_system_message) {
-                        ss << "<<SYS>>\n" << content << "\n<</SYS>>\n\n";
-                    } else {
-                        // if the model does not support system message, we still include it in the first message, but without <<SYS>>
-                        ss << content << "\n";
-                    }
-                } else if (role == "user") {
-                    ss << content << " [/INST]";
+        }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1
+            || tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3
+            || tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN) {
+        // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/chat_templates.md
+        // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/templates.md
+        std::string leading_space = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1 ? " " : "";
+        std::string trailing_space = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN ? "" : " ";
+        bool trim_assistant_message = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3;
+        bool is_inside_turn = false;
+        for (auto message : chat) {
+            if (!is_inside_turn) {
+                ss << leading_space << "[INST]" << trailing_space;
+                is_inside_turn = true;
+            }
+            std::string role(message->role);
+            std::string content(message->content);
+            if (role == "system") {
+                ss << content << "\n\n";
+            } else if (role == "user") {
+                ss << content << leading_space << "[/INST]";
+            } else {
+                ss << trailing_space << (trim_assistant_message ? trim(content) : content) << "</s>";
+                is_inside_turn = false;
+            }
+        }
+    } else if (
+            tmpl == LLM_CHAT_TEMPLATE_LLAMA_2
+            || tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS
+            || tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS
+            || tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP) {
+        // llama2 template and its variants
+        // [variant] support system message
+        // See: https://huggingface.co/blog/llama2#how-to-prompt-llama-2
+        bool support_system_message = tmpl != LLM_CHAT_TEMPLATE_LLAMA_2;
+        // [variant] add BOS inside history
+        bool add_bos_inside_history = tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS;
+        // [variant] trim spaces from the input message
+        bool strip_message = tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP;
+        // construct the prompt
+        bool is_inside_turn = true; // skip BOS at the beginning
+        ss << "[INST] ";
+        for (auto message : chat) {
+            std::string content = strip_message ? trim(message->content) : message->content;
+            std::string role(message->role);
+            if (!is_inside_turn) {
+                is_inside_turn = true;
+                ss << (add_bos_inside_history ? "<s>[INST] " : "[INST] ");
+            }
+            if (role == "system") {
+                if (support_system_message) {
+                    ss << "<<SYS>>\n" << content << "\n<</SYS>>\n\n";
                 } else {
-                    ss << (space_around_response ? " " : "") << content << (space_around_response ? " " : "") << "</s>";
-                    is_inside_turn = false;
+                    // if the model does not support system message, we still include it in the first message, but without <<SYS>>
+                    ss << content << "\n";
                 }
+            } else if (role == "user") {
+                ss << content << " [/INST]";
+            } else {
+                ss << content << "</s>";
+                is_inside_turn = false;
             }
-            // llama2 templates seem to not care about "add_generation_prompt
         }
-    } else if (tmpl == "phi3" || (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>"))) {
+    } else if (tmpl == LLM_CHAT_TEMPLATE_PHI_3) {
         // Phi 3
         for (auto message : chat) {
             std::string role(message->role);
@@ -21950,7 +22100,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|assistant|>\n";
         }
-    } else if (tmpl == "zephyr" || tmpl_contains("<|user|>")) {
+    } else if (tmpl == LLM_CHAT_TEMPLATE_ZEPHYR) {
         // zephyr template
         for (auto message : chat) {
             ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n";
@@ -21958,7 +22108,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|assistant|>\n";
         }
-    } else if (tmpl == "monarch" || tmpl_contains("bos_token + message['role']")) {
+    } else if (tmpl == LLM_CHAT_TEMPLATE_MONARCH) {
         // mlabonne/AlphaMonarch-7B template (the <s> is included inside history)
         for (auto message : chat) {
             std::string bos = (message == chat.front()) ? "" : "<s>"; // skip BOS for first message
@@ -21967,7 +22117,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<s>assistant\n";
         }
-    } else if (tmpl == "gemma" || tmpl == "gemma2" || tmpl_contains("<start_of_turn>")) {
+    } else if (tmpl == LLM_CHAT_TEMPLATE_GEMMA) {
         // google/gemma-7b-it
         std::string system_prompt = "";
         for (auto message : chat) {
@@ -21989,7 +22139,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<start_of_turn>model\n";
         }
-    } else if (tmpl == "orion" || tmpl_contains("'\\n\\nAssistant: ' + eos_token")) {
+    } else if (tmpl == LLM_CHAT_TEMPLATE_ORION) {
         // OrionStarAI/Orion-14B-Chat
         std::string system_prompt = "";
         for (auto message : chat) {
@@ -22009,7 +22159,7 @@ static int32_t llama_chat_apply_template_internal(
                 ss << message->content << "</s>";
             }
         }
-    } else if (tmpl == "openchat" || tmpl_contains("GPT4 Correct ")) {
+    } else if (tmpl == LLM_CHAT_TEMPLATE_OPENCHAT) {
         // openchat/openchat-3.5-0106,
         for (auto message : chat) {
             std::string role(message->role);
@@ -22023,13 +22173,13 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "GPT4 Correct Assistant:";
         }
-    } else if (tmpl == "vicuna" || tmpl == "vicuna-orca" || (tmpl_contains("USER: ") && tmpl_contains("ASSISTANT: "))) {
+    } else if (tmpl == LLM_CHAT_TEMPLATE_VICUNA || tmpl == LLM_CHAT_TEMPLATE_VICUNA_ORCA) {
         // eachadea/vicuna-13b-1.1 (and Orca variant)
         for (auto message : chat) {
             std::string role(message->role);
             if (role == "system") {
                 // Orca-Vicuna variant uses a system prefix
-                if (tmpl == "vicuna-orca" || tmpl_contains("SYSTEM: ")) {
+                if (tmpl == LLM_CHAT_TEMPLATE_VICUNA_ORCA) {
                     ss << "SYSTEM: " << message->content << "\n";
                 } else {
                     ss << message->content << "\n\n";
@@ -22043,7 +22193,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "ASSISTANT:";
         }
-    } else if (tmpl == "deepseek" || (tmpl_contains("### Instruction:") && tmpl_contains("<|EOT|>"))) {
+    } else if (tmpl == LLM_CHAT_TEMPLATE_DEEPSEEK) {
         // deepseek-ai/deepseek-coder-33b-instruct
         for (auto message : chat) {
             std::string role(message->role);
@@ -22058,7 +22208,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "### Response:\n";
         }
-    } else if (tmpl == "command-r" || (tmpl_contains("<|START_OF_TURN_TOKEN|>") && tmpl_contains("<|USER_TOKEN|>"))) {
+    } else if (tmpl == LLM_CHAT_TEMPLATE_COMMAND_R) {
         // CohereForAI/c4ai-command-r-plus
         for (auto message : chat) {
             std::string role(message->role);
@@ -22073,7 +22223,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>";
         }
-    } else if (tmpl == "llama3" || (tmpl_contains("<|start_header_id|>") && tmpl_contains("<|end_header_id|>"))) {
+    } else if (tmpl == LLM_CHAT_TEMPLATE_LLAMA_3) {
         // Llama 3
         for (auto message : chat) {
             std::string role(message->role);
@@ -22082,7 +22232,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|start_header_id|>assistant<|end_header_id|>\n\n";
         }
-    } else if (tmpl == "chatglm3" || tmpl_contains("[gMASK]sop")) {
+    } else if (tmpl == LLM_CHAT_TEMPLATE_CHATGML_3) {
         // chatglm3-6b
         ss << "[gMASK]" << "sop";
         for (auto message : chat) {
@@ -22092,7 +22242,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|assistant|>";
         }
-    } else if (tmpl == "chatglm4" || tmpl_contains("[gMASK]<sop>")) {
+    } else if (tmpl == LLM_CHAT_TEMPLATE_CHATGML_4) {
         ss << "[gMASK]" << "<sop>";
         for (auto message : chat) {
             std::string role(message->role);
@@ -22101,7 +22251,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|assistant|>";
         }
-    } else if (tmpl == "minicpm" || tmpl_contains(LU8("<用户>"))) {
+    } else if (tmpl == LLM_CHAT_TEMPLATE_MINICPM) {
         // MiniCPM-3B-OpenHermes-2.5-v2-GGUF
         for (auto message : chat) {
             std::string role(message->role);
@@ -22113,7 +22263,7 @@ static int32_t llama_chat_apply_template_internal(
                 ss << trim(message->content);
             }
         }
-    } else if (tmpl == "deepseek2" || tmpl_contains("'Assistant: ' + message['content'] + eos_token")) {
+    } else if (tmpl == LLM_CHAT_TEMPLATE_DEEPSEEK_2) {
         // DeepSeek-V2
         for (auto message : chat) {
             std::string role(message->role);
@@ -22128,7 +22278,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "Assistant:";
         }
-    } else if (tmpl == "exaone3" || (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]"))) {
+    } else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_3) {
         // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb
         // EXAONE-3.0-7.8B-Instruct
         for (auto message : chat) {
@@ -22144,7 +22294,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "[|assistant|]";
         }
-    } else if (tmpl == "rwkv-world" || tmpl_contains("rwkv-world")) {
+    } else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) {
         // this template requires the model to have "\n\n" as EOT token
         for (auto message : chat) {
             std::string role(message->role);
@@ -22154,7 +22304,7 @@ static int32_t llama_chat_apply_template_internal(
                 ss << message->content << "\n\n";
             }
         }
-    } else if (tmpl == "granite" || tmpl_contains("<|start_of_role|>")) {
+    } else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE) {
         // IBM Granite template
         for (const auto & message : chat) {
             std::string role(message->role);
@@ -22206,7 +22356,11 @@ int32_t llama_chat_apply_template(
     }
 
     std::string formatted_chat;
-    int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass);
+    llm_chat_template detected_tmpl = llama_chat_detect_template(curr_tmpl);
+    if (detected_tmpl == LLM_CHAT_TEMPLATE_UNKNOWN) {
+        return -1;
+    }
+    int32_t res = llama_chat_apply_template_internal(detected_tmpl, chat_vec, formatted_chat, add_ass);
     if (res < 0) {
         return res;
     }
@@ -22216,6 +22370,15 @@ int32_t llama_chat_apply_template(
     return res;
 }
 
+int32_t llama_chat_builtin_templates(const char ** output, size_t len) {
+    auto it = LLM_CHAT_TEMPLATES.begin();
+    for (size_t i = 0; i < std::min(len, LLM_CHAT_TEMPLATES.size()); i++) {
+        output[i] = it->first.c_str();
+        std::advance(it, 1);
+    }
+    return (int32_t) LLM_CHAT_TEMPLATES.size();
+}
+
 //
 // sampling
 //
index dd8f7d5f096dcb015fbb5c6a4f36cd560e363b10..aa140b5696f743a9bdd7af1ca5eb3d18cf0e483e 100644 (file)
@@ -82,9 +82,9 @@ int main(void) {
         // mistralai/Mistral-7B-Instruct-v0.2 (NOTE: Old pre-v1 without a system prompt)
         "[INST] You are a helpful assistant\nHello [/INST]Hi there</s>[INST] Who are you [/INST]   I am an assistant   </s>[INST] Another question [/INST]",
         // TheBloke/FusionNet_34Bx2_MoE-AWQ
-        "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST] Hi there </s><s>[INST] Who are you [/INST]    I am an assistant    </s><s>[INST] Another question [/INST]",
+        "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST]Hi there</s><s>[INST] Who are you [/INST]   I am an assistant   </s><s>[INST] Another question [/INST]",
         // bofenghuang/vigogne-2-70b-chat
-        "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST] Hi there </s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
+        "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST]Hi there</s>[INST] Who are you [/INST]I am an assistant</s>[INST] Another question [/INST]",
         // mlabonne/AlphaMonarch-7B
         "system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n   I am an assistant   </s>\n<s>user\nAnother question</s>\n<s>assistant\n",
         // google/gemma-7b-it
@@ -133,6 +133,17 @@ int main(void) {
     std::vector<char> formatted_chat(1024);
     int32_t res;
 
+    // list all supported templates
+    std::vector<const char *> supported_tmpl;
+    res = llama_chat_builtin_templates(nullptr, 0);
+    assert(res > 0);
+    supported_tmpl.resize(res);
+    res = llama_chat_builtin_templates(supported_tmpl.data(), supported_tmpl.size());
+    printf("Built-in chat templates:\n");
+    for (auto tmpl : supported_tmpl) {
+        printf("  %s\n", tmpl);
+    }
+
     // test invalid chat template
     res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", conversation, message_count, true, formatted_chat.data(), formatted_chat.size());
     assert(res < 0);
@@ -174,7 +185,8 @@ int main(void) {
     assert(fmt_sys("mistral-v3") == "[INST] You are a helpful assistant\n\n");
     assert(fmt_sys("mistral-v3-tekken") == "[INST]You are a helpful assistant\n\n");
     assert(fmt_sys("mistral-v7") == "[SYSTEM_PROMPT] You are a helpful assistant[/SYSTEM_PROMPT]");
-    assert(fmt_sys("llama2") == "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\n");
+    assert(fmt_sys("llama2") == "[INST] You are a helpful assistant\n");
+    assert(fmt_sys("llama2-sys") == "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\n");
     assert(fmt_sys("mistral") == "[INST] You are a helpful assistant\n"); // for old pre-v1 templates
     assert(fmt_sys("gemma")  == ""); // for gemma, system message is merged with user message
     assert(fmt_sys("llama3") == "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|>");
@@ -203,5 +215,7 @@ int main(void) {
     assert(fmt_single("gemma")  == "\n<start_of_turn>user\nHow are you<end_of_turn>\n<start_of_turn>model\n");
     assert(fmt_single("llama3") == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");
 
+    printf("Test chat templates: OK\n");
+
     return 0;
 }