]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
main : add `--in-prefix-bos` to prefix BOS to user inputs; keep EOS (#2304)
authorXiao-Yong Jin <redacted>
Tue, 25 Jul 2023 12:19:11 +0000 (07:19 -0500)
committerGitHub <redacted>
Tue, 25 Jul 2023 12:19:11 +0000 (15:19 +0300)
* add `--in-prefix-bos` to prefix BOS to user inputs; keep EOS

The BOS precedes the string specified by `--in-prefix`.
Model generated EOS is now kept in the context.

It provides a way to strictly following the prompt format used in
Llama-2-chat.

The EOS handling also benefits some existing finetunes that uses
EOS to mark the end of turn.

* examples/common: move input_prefix_bos to other bools

examples/common.cpp
examples/common.h
examples/main/main.cpp

index 0e88a128ad1956e609d5d4a38bd70eadb00a405e..dd964c8a7481a32808781f8b10d96d8c16c124d8 100644 (file)
@@ -432,6 +432,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             exit(0);
         } else if (arg == "--random-prompt") {
             params.random_prompt = true;
+        } else if (arg == "--in-prefix-bos") {
+            params.input_prefix_bos = true;
         } else if (arg == "--in-prefix") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -517,6 +519,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, "                        not supported with --interactive or other interactive options\n");
     fprintf(stdout, "  --prompt-cache-ro     if specified, uses the prompt cache but does not update it.\n");
     fprintf(stdout, "  --random-prompt       start with a randomized prompt.\n");
+    fprintf(stdout, "  --in-prefix-bos       prefix BOS to user inputs, preceding the `--in-prefix` string\n");
     fprintf(stdout, "  --in-prefix STRING    string to prefix user inputs with (default: empty)\n");
     fprintf(stdout, "  --in-suffix STRING    string to suffix after user inputs with (default: empty)\n");
     fprintf(stdout, "  -f FNAME, --file FNAME\n");
index 894a0850a68c6de1bf8de5d4a55b3d70a599cdc7..2d87c923b43f05866a215c503513e5ecef095f92 100644 (file)
@@ -82,6 +82,7 @@ struct gpt_params {
     bool interactive_first = false; // wait for user input immediately
     bool multiline_input   = false; // reverse the usage of `\`
 
+    bool input_prefix_bos  = false; // prefix BOS to user inputs, preceding input_prefix
     bool instruct          = false; // instruction mode (used for Alpaca models)
     bool penalize_nl       = true;  // consider newlines as a repeatable token
     bool perplexity        = false; // compute perplexity over the prompt
index 16ddc22747f6b549530c905d3b9e7fae1b6b09c3..3796a9230113653402911ffbde6e0d21d44f4ce9 100644 (file)
@@ -325,6 +325,10 @@ int main(int argc, char ** argv) {
             }
         }
 
+        if (params.input_prefix_bos) {
+            fprintf(stderr, "Input prefix with BOS\n");
+        }
+
         if (!params.input_prefix.empty()) {
             fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
         }
@@ -633,16 +637,6 @@ int main(int argc, char ** argv) {
                 last_n_tokens.push_back(id);
             }
 
-            // replace end of text token with newline token when in interactive mode
-            if (id == llama_token_eos() && params.interactive && !params.instruct) {
-                id = llama_token_newline.front();
-                if (params.antiprompt.size() != 0) {
-                    // tokenize and inject first reverse prompt
-                    const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false);
-                    embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
-                }
-            }
-
             // add it to the context
             embd.push_back(id);
 
@@ -708,11 +702,34 @@ int main(int argc, char ** argv) {
                 }
             }
 
+            // deal with end of text token in interactive mode
+            if (last_n_tokens.back() == llama_token_eos()) {
+                if (params.interactive) {
+                    if (params.antiprompt.size() != 0) {
+                        // tokenize and inject first reverse prompt
+                        const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false);
+                        embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
+                        is_antiprompt = true;
+                    }
+
+                    is_interacting = true;
+                    printf("\n");
+                    console_set_color(con_st, CONSOLE_COLOR_USER_INPUT);
+                    fflush(stdout);
+                } else if (params.instruct) {
+                    is_interacting = true;
+                }
+            }
+
             if (n_past > 0 && is_interacting) {
                 if (params.instruct) {
                     printf("\n> ");
                 }
 
+                if (params.input_prefix_bos) {
+                    embd_inp.push_back(llama_token_bos());
+                }
+
                 std::string buffer;
                 if (!params.input_prefix.empty()) {
                     buffer += params.input_prefix;
@@ -776,13 +793,9 @@ int main(int argc, char ** argv) {
         }
 
         // end of text token
-        if (!embd.empty() && embd.back() == llama_token_eos()) {
-            if (params.instruct) {
-                is_interacting = true;
-            } else {
-                fprintf(stderr, " [end of text]\n");
-                break;
-            }
+        if (!embd.empty() && embd.back() == llama_token_eos() && !(params.instruct || params.interactive)) {
+            fprintf(stderr, " [end of text]\n");
+            break;
         }
 
         // In interactive mode, respect the maximum number of tokens and drop back to user input when reached.