]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
main : make reverse prompt option act as a stop token in non-interactive mode (#1032)
authorJason McCartney <redacted>
Fri, 19 May 2023 17:24:59 +0000 (10:24 -0700)
committerGitHub <redacted>
Fri, 19 May 2023 17:24:59 +0000 (20:24 +0300)
* Make reverse prompt option act as a stop token in non-interactive scenarios

* Making requested review changes

* Update gpt_params_parse and fix a merge error

* Revert "Update gpt_params_parse and fix a merge error"

This reverts commit 2bb2ff1748513591ad45b175a75ed1d8089d84c8.

* Update gpt_params_parse and fix a merge error take 2

examples/common.cpp
examples/main/main.cpp

index a4fea4af48076dc1d3b203f94c7a8d44ccbb1f28..e89df537eca49f788d6a770ad8de285c8f6b852c 100644 (file)
@@ -351,7 +351,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
     }
     if (params.prompt_cache_all &&
             (params.interactive || params.interactive_first ||
-             params.instruct || params.antiprompt.size())) {
+             params.instruct)) {
         fprintf(stderr, "error: --prompt-cache-all not supported in interactive mode yet\n");
         gpt_print_usage(argc, argv, default_params);
         exit(1);
@@ -373,8 +373,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "  -ins, --instruct      run in instruction mode (use with Alpaca models)\n");
     fprintf(stderr, "  --multiline-input     allows you to write or paste multiple lines without ending each in '\\'\n");
     fprintf(stderr, "  -r PROMPT, --reverse-prompt PROMPT\n");
-    fprintf(stderr, "                        run in interactive mode and poll user input upon seeing PROMPT (can be\n");
-    fprintf(stderr, "                        specified more than once for multiple prompts).\n");
+    fprintf(stderr, "                        halt generation at PROMPT, return control in interactive mode\n");
+    fprintf(stderr, "                        (can be specified more than once for multiple prompts).\n");
     fprintf(stderr, "  --color               colorise output to distinguish prompt and user input from generations\n");
     fprintf(stderr, "  -s SEED, --seed SEED  RNG seed (default: -1, use random seed for < 0)\n");
     fprintf(stderr, "  -t N, --threads N     number of threads to use during computation (default: %d)\n", params.n_threads);
index 18673ed2e00f2926b961d199114a5f2e2b24035a..4d886f8defe568ac21d8d3f5278abd92e56d2530 100644 (file)
@@ -208,8 +208,8 @@ int main(int argc, char ** argv) {
         params.antiprompt.push_back("### Instruction:\n\n");
     }
 
-    // enable interactive mode if reverse prompt or interactive start is specified
-    if (params.antiprompt.size() != 0 || params.interactive_first) {
+    // enable interactive mode if interactive start is specified
+    if (params.interactive_first) {
         params.interactive = true;
     }
 
@@ -305,7 +305,7 @@ int main(int argc, char ** argv) {
 
     std::vector<llama_token> embd;
 
-    while (n_remain != 0 || params.interactive) {
+    while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
         // predict
         if (embd.size() > 0) {
             // infinite text generation via context swapping
@@ -503,9 +503,8 @@ int main(int argc, char ** argv) {
             console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
         }
 
-        // in interactive mode, and not currently processing queued inputs;
-        // check if we should prompt the user for more
-        if (params.interactive && (int) embd_inp.size() <= n_consumed) {
+        // if not currently processing queued inputs;
+        if ((int) embd_inp.size() <= n_consumed) {
 
             // check for reverse prompt
             if (params.antiprompt.size()) {
@@ -516,10 +515,21 @@ int main(int argc, char ** argv) {
 
                 is_antiprompt = false;
                 // Check if each of the reverse prompts appears at the end of the output.
+                // If we're not running interactively, the reverse prompt might be tokenized with some following characters
+                // so we'll compensate for that by widening the search window a bit.
                 for (std::string & antiprompt : params.antiprompt) {
-                    if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
-                        is_interacting = true;
+                    size_t extra_padding = params.interactive ? 0 : 2;
+                    size_t search_start_pos = last_output.length() > static_cast<size_t>(antiprompt.length() + extra_padding)
+                        ? last_output.length() - static_cast<size_t>(antiprompt.length() + extra_padding)
+                        : 0;
+
+                    if (last_output.find(antiprompt.c_str(), search_start_pos) != std::string::npos) {
+                        if (params.interactive) {
+                            is_interacting = true;
+                            console_set_color(con_st, CONSOLE_COLOR_USER_INPUT);
+                        }
                         is_antiprompt = true;
+                        fflush(stdout);
                         break;
                     }
                 }