]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
main : Add ChatML functionality to main example (#4046)
authorSeb C <redacted>
Mon, 20 Nov 2023 13:56:59 +0000 (00:26 +1030)
committerGitHub <redacted>
Mon, 20 Nov 2023 13:56:59 +0000 (14:56 +0100)
Co-authored-by: Sebastian Cramond <redacted>
common/common.cpp
common/common.h
examples/infill/infill.cpp
examples/main/main.cpp

index 3f10b5d7f3afe0327b6723cab28be27f7c5f542d..eec704b99f888ec7720e49aaa754678e768c93c3 100644 (file)
@@ -491,6 +491,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
             params.interactive_first = true;
         } else if (arg == "-ins" || arg == "--instruct") {
             params.instruct = true;
+        } else if (arg == "-cml" || arg == "--chatml") {
+            params.chatml = true;
         } else if (arg == "--infill") {
             params.infill = true;
         } else if (arg == "--multiline-input") {
@@ -730,6 +732,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  -i, --interactive     run in interactive mode\n");
     printf("  --interactive-first   run in interactive mode and wait for input right away\n");
     printf("  -ins, --instruct      run in instruction mode (use with Alpaca models)\n");
+    printf("  -cml, --chatml        run in chatml mode (use with ChatML-compatible models)\n");
     printf("  --multiline-input     allows you to write or paste multiple lines without ending each in '\\'\n");
     printf("  -r PROMPT, --reverse-prompt PROMPT\n");
     printf("                        halt generation at PROMPT, return control in interactive mode\n");
index cc048daab56730f053a6b7efed7d00338cf93420..88fa13fc067c2b1553c532ff5d05a6d163829106 100644 (file)
@@ -102,6 +102,7 @@ struct gpt_params {
     bool random_prompt     = false; // do not randomize prompt if none provided
     bool use_color         = false; // use color to distinguish generations and inputs
     bool interactive       = false; // interactive mode
+    bool chatml            = false; // chatml mode (used for models trained on chatml syntax)
     bool prompt_cache_all  = false; // save user input and generations to prompt cache
     bool prompt_cache_ro   = false; // open the prompt cache read-only and do not update it
 
index 11f7410edd6f84755029dac9ac6ff09e0d3b6ffd..4a7827876e2151a0ceba42f95e7319aa2668c746 100644 (file)
@@ -146,6 +146,13 @@ int main(int argc, char ** argv) {
 
         return 0;
     }
+    if (params.chatml) {
+        printf("\n************\n");
+        printf("%s: please use the 'main' tool for chatml mode\n", __func__);
+        printf("************\n\n");
+
+        return 0;
+    }
     if (!params.antiprompt.empty()) {
         printf("\n************\n");
         printf("%s: please use the 'main' tool for antiprompt mode\n", __func__);
index 99d219d6571d0de95c07b91ea16246980a613732..31ec8cade19be1fc1d6e0eb85b3031e4d6ba1032 100644 (file)
@@ -234,8 +234,11 @@ int main(int argc, char ** argv) {
 
     std::vector<llama_token> embd_inp;
 
-    if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
+    if (params.interactive_first || params.instruct || params.chatml || !params.prompt.empty() || session_tokens.empty()) {
         LOG("tokenize the prompt\n");
+        if (params.chatml) {
+            params.prompt = "<|im_start|>system\n" + params.prompt + "<|im_end|>";
+        }
         embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
     } else {
         LOG("use session tokens\n");
@@ -313,7 +316,7 @@ int main(int argc, char ** argv) {
     }
 
     // number of tokens to keep when resetting context
-    if (params.n_keep < 0 || params.n_keep > (int) embd_inp.size() || params.instruct) {
+    if (params.n_keep < 0 || params.n_keep > (int) embd_inp.size() || params.instruct || params.chatml) {
         params.n_keep = (int)embd_inp.size();
     }
 
@@ -324,11 +327,23 @@ int main(int argc, char ** argv) {
     LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx).c_str());
     LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx).c_str());
 
+    // chatml prefix & suffix
+    const auto cml_pfx = ::llama_tokenize(ctx, "\n<|im_start|>user\n", add_bos, true);
+    const auto cml_sfx = ::llama_tokenize(ctx, "<|im_end|>\n<|im_start|>assistant\n", false, true);
+
+    LOG("cml_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, cml_pfx).c_str());
+    LOG("cml_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, cml_sfx).c_str());
+
     // in instruct mode, we inject a prefix and a suffix to each input by the user
     if (params.instruct) {
         params.interactive_first = true;
         params.antiprompt.push_back("### Instruction:\n\n");
     }
+    // similar for chatml mode
+    else if (params.chatml) {
+        params.interactive_first = true;
+        params.antiprompt.push_back("<|im_start|>user\n");
+    }
 
     // enable interactive mode if interactive start is specified
     if (params.interactive_first) {
@@ -705,7 +720,7 @@ int main(int argc, char ** argv) {
 
                     is_interacting = true;
                     printf("\n");
-                } else if (params.instruct) {
+                } else if (params.instruct || params.chatml) {
                     is_interacting = true;
                 }
             }
@@ -713,7 +728,7 @@ int main(int argc, char ** argv) {
             if (n_past > 0 && is_interacting) {
                 LOG("waiting for user input\n");
 
-                if (params.instruct) {
+                if (params.instruct || params.chatml) {
                     printf("\n> ");
                 }
 
@@ -760,6 +775,12 @@ int main(int argc, char ** argv) {
                         n_consumed = embd_inp.size();
                         embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
                     }
+                    // chatml mode: insert user chat prefix
+                    if (params.chatml && !is_antiprompt) {
+                        LOG("inserting chatml prefix\n");
+                        n_consumed = embd_inp.size();
+                        embd_inp.insert(embd_inp.end(), cml_pfx.begin(), cml_pfx.end());
+                    }
                     if (params.escape) {
                         process_escapes(buffer);
                     }
@@ -778,6 +799,11 @@ int main(int argc, char ** argv) {
                         LOG("inserting instruction suffix\n");
                         embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
                     }
+                    // chatml mode: insert assistant chat suffix
+                    if (params.chatml) {
+                        LOG("inserting chatml suffix\n");
+                        embd_inp.insert(embd_inp.end(), cml_sfx.begin(), cml_sfx.end());
+                    }
 
                     for (size_t i = original_size; i < embd_inp.size(); ++i) {
                         const llama_token token = embd_inp[i];
@@ -803,7 +829,7 @@ int main(int argc, char ** argv) {
         }
 
         // end of text token
-        if (!embd.empty() && embd.back() == llama_token_eos(model) && !(params.instruct || params.interactive)) {
+        if (!embd.empty() && embd.back() == llama_token_eos(model) && !(params.instruct || params.interactive || params.chatml)) {
             LOG_TEE(" [end of text]\n");
             break;
         }