]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
main : alternative instruct mode (Vicuna support, etc.) (#863)
authorTomáš Pazdiora <redacted>
Fri, 14 Apr 2023 15:19:17 +0000 (17:19 +0200)
committerGitHub <redacted>
Fri, 14 Apr 2023 15:19:17 +0000 (18:19 +0300)
* Add support for configs, add configurable prefixes / suffixes, deprecate instruct mode, add stop prompt

* Add multiline mode, update text input.

* bugfix

* update implementation

* typos

* Change --multiline implementation to be toggled by EOF.

* bugfix

* default multiline mode

* add more configs

* update formating

* update formatting

* apply suggestions

12 files changed:
configs/alpaca-native-enhanced.txt [new file with mode: 0644]
configs/alpaca.txt [new file with mode: 0644]
configs/chat-with-bob.txt [new file with mode: 0644]
configs/llama.txt [new file with mode: 0644]
configs/vicuna-simple.txt [new file with mode: 0644]
configs/vicuna-stop.txt [new file with mode: 0644]
configs/vicuna.txt [new file with mode: 0644]
examples/common.cpp
examples/common.h
examples/main/main.cpp
prompts/alpaca.txt [deleted file]
prompts/chat-with-bob.txt [deleted file]

diff --git a/configs/alpaca-native-enhanced.txt b/configs/alpaca-native-enhanced.txt
new file mode 100644 (file)
index 0000000..109d315
--- /dev/null
@@ -0,0 +1,21 @@
+--ctx_size 2048
+--batch_size 16
+--repeat_penalty 1.15
+--temp 0.4
+--top_k 30
+--top_p 0.18
+
+--interactive-first
+--keep -1
+
+--ins-prefix-bos
+--ins-prefix "\n\nUser: "
+--ins-suffix "\n\nAssistant: "
+--reverse-prompt "User: "
+
+-p "You are an AI language model designed to assist the User by answering their questions, offering advice, and engaging in casual conversation in a friendly, helpful, and informative manner. You respond clearly, coherently, and you consider the conversation history.
+
+User: Hey, how's it going?
+
+Assistant: Hey there! I'm doing great, thank you. What can I help you with today? Let's have a fun chat!"
+
diff --git a/configs/alpaca.txt b/configs/alpaca.txt
new file mode 100644 (file)
index 0000000..99a3ab4
--- /dev/null
@@ -0,0 +1,9 @@
+--clean-interface
+--interactive-first
+--keep -1
+--ins-prefix-bos
+--ins-prefix "\n\n### Instruction:\n\n"
+--ins-suffix "\n\n### Response:\n\n"
+--reverse-prompt "### Instruction:\n\n"
+
+-p "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n"
diff --git a/configs/chat-with-bob.txt b/configs/chat-with-bob.txt
new file mode 100644 (file)
index 0000000..0caa749
--- /dev/null
@@ -0,0 +1,15 @@
+--interactive-first
+--keep -1
+--ins-prefix-bos
+--ins-prefix "\nUser: "
+--ins-suffix "\nBob: "
+--reverse-prompt "User: "
+--rm-trailing-space-workaround
+
+-p "Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.
+
+User: Hello, Bob.
+Bob: Hello. How may I help you today?
+User: Please tell me the largest city in Europe.
+Bob: Sure. The largest city in Europe is Moscow, the capital of Russia."
+
diff --git a/configs/llama.txt b/configs/llama.txt
new file mode 100644 (file)
index 0000000..9d23e75
--- /dev/null
@@ -0,0 +1,3 @@
+--interactive-first
+--keep -1
+--temp 0.1
diff --git a/configs/vicuna-simple.txt b/configs/vicuna-simple.txt
new file mode 100644 (file)
index 0000000..efa60d9
--- /dev/null
@@ -0,0 +1,7 @@
+--interactive-first
+--keep -1
+--ins-prefix-bos
+--ins-prefix "\n### Human: "
+--ins-suffix "\n### Assistant: "
+--reverse-prompt "### Human: "
+--rm-trailing-space-workaround
diff --git a/configs/vicuna-stop.txt b/configs/vicuna-stop.txt
new file mode 100644 (file)
index 0000000..911d067
--- /dev/null
@@ -0,0 +1,8 @@
+--interactive-first
+--keep -1
+--ins-prefix-bos
+--ins-prefix "\n### Human: "
+--ins-suffix "\n### Assistant: "
+--reverse-prompt "### Human: "
+--stop-prompt "### Assistant: "
+--rm-trailing-space-workaround
diff --git a/configs/vicuna.txt b/configs/vicuna.txt
new file mode 100644 (file)
index 0000000..6d81141
--- /dev/null
@@ -0,0 +1,9 @@
+--interactive-first
+--keep -1
+--ins-prefix-bos
+--ins-prefix "\n### Human: "
+--ins-suffix "\n### Assistant: "
+--reverse-prompt "### Human: "
+--rm-trailing-space-workaround
+
+-p "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
index 0772dbfe142ffe167f2c3e5ffa3815f6488af04e..eaa5aceea8eb0ef9ce0155b49a08b71d3e4a0104 100644 (file)
@@ -2,10 +2,13 @@
 
 #include <cassert>
 #include <cstring>
+#include <iostream>
 #include <fstream>
+#include <sstream>
 #include <string>
 #include <iterator>
 #include <algorithm>
+#include <regex>
 
 #if defined (_WIN32)
 #include <fcntl.h>
@@ -23,6 +26,43 @@ extern "C" __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int
 #define CP_UTF8 65001
 #endif
 
+void split_args(const std::string & args_string, std::vector<std::string> & output_args)
+{
+    std::string current_arg = "";
+    bool in_quotes = false;
+    char quote_type;
+
+    for (char c : args_string) {
+        if (c == '"' || c == '\'') {
+            if (!in_quotes) {
+                in_quotes = true;
+                quote_type = c;
+            } else if (quote_type == c) {
+                in_quotes = false;
+            } else {
+                current_arg += c;
+            }
+        } else if (in_quotes) {
+            current_arg += c;
+        } else if (std::isspace(c)) {
+            if (current_arg != "") {
+                output_args.push_back(current_arg);
+                current_arg = "";
+            }
+        } else {
+            current_arg += c;
+        }
+    }
+
+    if (current_arg != "") {
+        output_args.push_back(current_arg);
+    }
+}
+
+std::string unescape(const std::string & str) {
+    return std::regex_replace(str, std::regex("\\\\n"), "\n");
+}
+
 bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
     // determine sensible default number of threads.
     // std::thread::hardware_concurrency may not be equal to the number of cores, or may return 0.
@@ -40,35 +80,66 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
     std::string arg;
     gpt_params default_params;
 
+    // get additional arguments from config files
+    std::vector<std::string> args;
     for (int i = 1; i < argc; i++) {
         arg = argv[i];
+        if (arg == "--config") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            std::ifstream file(argv[i]);
+            if (!file) {
+                fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
+                invalid_param = true;
+                break;
+            }
+            std::string args_string;
+            std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(args_string));
+            if (args_string.back() == '\n') {
+                args_string.pop_back();
+            }
+            split_args(args_string, args);
+            for (int j = 0; j < args.size(); j++) {
+                args[j] = unescape(args[j]);
+            }
+        } else {
+            args.emplace_back(argv[i]);
+        }
+    }
+
+    // parse args
+    int args_c = static_cast<int>(args.size());
+    for (int i = 0; i < args_c && !invalid_param; i++) {
+        arg = args[i];
 
         if (arg == "-s" || arg == "--seed") {
-            if (++i >= argc) {
+            if (++i >= args_c) {
                 invalid_param = true;
                 break;
             }
-            params.seed = std::stoi(argv[i]);
+            params.seed = std::stoi(args[i]);
         } else if (arg == "-t" || arg == "--threads") {
-            if (++i >= argc) {
+            if (++i >= args_c) {
                 invalid_param = true;
                 break;
             }
-            params.n_threads = std::stoi(argv[i]);
+            params.n_threads = std::stoi(args[i]);
         } else if (arg == "-p" || arg == "--prompt") {
-            if (++i >= argc) {
+            if (++i >= args_c) {
                 invalid_param = true;
                 break;
             }
-            params.prompt = argv[i];
+            params.prompt = args[i];
         } else if (arg == "-f" || arg == "--file") {
-            if (++i >= argc) {
+            if (++i >= args_c) {
                 invalid_param = true;
                 break;
             }
-            std::ifstream file(argv[i]);
+            std::ifstream file(args[i]);
             if (!file) {
-                fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
+                fprintf(stderr, "error: failed to open file '%s'\n", args[i].c_str());
                 invalid_param = true;
                 break;
             }
@@ -77,80 +148,100 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 params.prompt.pop_back();
             }
         } else if (arg == "-n" || arg == "--n_predict") {
-            if (++i >= argc) {
+            if (++i >= args_c) {
                 invalid_param = true;
                 break;
             }
-            params.n_predict = std::stoi(argv[i]);
+            params.n_predict = std::stoi(args[i]);
         } else if (arg == "--top_k") {
-            if (++i >= argc) {
+            if (++i >= args_c) {
                 invalid_param = true;
                 break;
             }
-            params.top_k = std::stoi(argv[i]);
+            params.top_k = std::stoi(args[i]);
         } else if (arg == "-c" || arg == "--ctx_size") {
-            if (++i >= argc) {
+            if (++i >= args_c) {
                 invalid_param = true;
                 break;
             }
-            params.n_ctx = std::stoi(argv[i]);
+            params.n_ctx = std::stoi(args[i]);
         } else if (arg == "--memory_f32") {
             params.memory_f16 = false;
         } else if (arg == "--top_p") {
-            if (++i >= argc) {
+            if (++i >= args_c) {
                 invalid_param = true;
                 break;
             }
-            params.top_p = std::stof(argv[i]);
+            params.top_p = std::stof(args[i]);
         } else if (arg == "--temp") {
-            if (++i >= argc) {
+            if (++i >= args_c) {
                 invalid_param = true;
                 break;
             }
-            params.temp = std::stof(argv[i]);
+            params.temp = std::stof(args[i]);
         } else if (arg == "--repeat_last_n") {
-            if (++i >= argc) {
+            if (++i >= args_c) {
                 invalid_param = true;
                 break;
             }
-            params.repeat_last_n = std::stoi(argv[i]);
+            params.repeat_last_n = std::stoi(args[i]);
         } else if (arg == "--repeat_penalty") {
-            if (++i >= argc) {
+            if (++i >= args_c) {
                 invalid_param = true;
                 break;
             }
-            params.repeat_penalty = std::stof(argv[i]);
+            params.repeat_penalty = std::stof(args[i]);
         } else if (arg == "-b" || arg == "--batch_size") {
-            if (++i >= argc) {
+            if (++i >= args_c) {
                 invalid_param = true;
                 break;
             }
-            params.n_batch = std::stoi(argv[i]);
+            params.n_batch = std::stoi(args[i]);
             params.n_batch = std::min(512, params.n_batch);
         } else if (arg == "--keep") {
-            if (++i >= argc) {
+            if (++i >= args_c) {
                 invalid_param = true;
                 break;
             }
-            params.n_keep = std::stoi(argv[i]);
+            params.n_keep = std::stoi(args[i]);
         } else if (arg == "-m" || arg == "--model") {
-            if (++i >= argc) {
+            if (++i >= args_c) {
                 invalid_param = true;
                 break;
             }
-            params.model = argv[i];
+            params.model = args[i];
         } else if (arg == "-i" || arg == "--interactive") {
             params.interactive = true;
         } else if (arg == "--embedding") {
             params.embedding = true;
+        } else if (arg == "--clean-interface") {
+            params.clean_interface = true;
         } else if (arg == "--interactive-start") {
             params.interactive = true;
         } else if (arg == "--interactive-first") {
             params.interactive_start = true;
         } else if (arg == "-ins" || arg == "--instruct") {
-            params.instruct = true;
+            fprintf(stderr, "\n\nWarning: instruct mode is deprecated! Use: \n"
+                "--clean-interface "
+                "--interactive-first "
+                "--keep -1 "
+                "--ins-prefix-bos "
+                "--ins-prefix \"\\n\\n### Instruction:\\n\\n\" "
+                "--ins-suffix \"\\n\\n### Response:\\n\\n\" "
+                "-r \"### Instruction:\\n\\n\" "
+            "\n\n");
+            // params.instruct = true;
+            params.clean_interface = true;
+            params.interactive_start = true;
+            params.n_keep = -1;
+            params.instruct_prefix_bos = true;
+            params.instruct_prefix = "\n\n### Instruction:\n\n";
+            params.instruct_suffix = "\n\n### Response:\n\n";
+            params.antiprompt.push_back("### Instruction:\n\n");
         } else if (arg == "--color") {
             params.use_color = true;
+        } else if (arg == "--disable-multiline") {
+            params.multiline_mode = false;
         } else if (arg == "--mlock") {
             params.use_mlock = true;
         } else if (arg == "--no-mmap") {
@@ -160,65 +251,94 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
         } else if (arg == "--verbose-prompt") {
             params.verbose_prompt = true;
         } else if (arg == "-r" || arg == "--reverse-prompt") {
-            if (++i >= argc) {
+            if (++i >= args_c) {
                 invalid_param = true;
                 break;
             }
-            params.antiprompt.push_back(argv[i]);
+            params.antiprompt.push_back(args[i]);
+        } else if (arg == "--stop-prompt") {
+            if (++i >= args_c) {
+                invalid_param = true;
+                break;
+            }
+            params.stopprompt.push_back(args[i]);
+        } else if (arg == "--rm-trailing-space-workaround") {
+            params.rm_trailing_space_workaround = true;
         } else if (arg == "--perplexity") {
             params.perplexity = true;
         } else if (arg == "--ignore-eos") {
             params.ignore_eos = true;
         } else if (arg == "--n_parts") {
-            if (++i >= argc) {
+            if (++i >= args_c) {
                 invalid_param = true;
                 break;
             }
-            params.n_parts = std::stoi(argv[i]);
+            params.n_parts = std::stoi(args[i]);
         } else if (arg == "-h" || arg == "--help") {
-            gpt_print_usage(argc, argv, default_params);
+            gpt_print_usage(argv[0], default_params);
             exit(0);
         } else if (arg == "--random-prompt") {
             params.random_prompt = true;
         } else if (arg == "--in-prefix") {
-            if (++i >= argc) {
+            if (++i >= args_c) {
+                invalid_param = true;
+                break;
+            }
+            params.input_prefix = args[i];
+        } else if (arg == "--ins-prefix-bos") {
+            params.instruct_prefix_bos = true;
+        } else if (arg == "--ins-prefix") {
+            if (++i >= args_c) {
                 invalid_param = true;
                 break;
             }
-            params.input_prefix = argv[i];
+            params.instruct_prefix = args[i];
+        } else if (arg == "--ins-suffix-bos") {
+            params.instruct_suffix_bos = true;
+        } else if (arg == "--ins-suffix") {
+            if (++i >= args_c) {
+                invalid_param = true;
+                break;
+            }
+            params.instruct_suffix = args[i];
         } else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
-            gpt_print_usage(argc, argv, default_params);
+            gpt_print_usage(argv[0], default_params);
             exit(1);
         }
     }
     if (invalid_param) {
         fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
-        gpt_print_usage(argc, argv, default_params);
+        gpt_print_usage(argv[0], default_params);
         exit(1);
     }
 
     return true;
 }
 
-void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
-    fprintf(stderr, "usage: %s [options]\n", argv[0]);
+void gpt_print_usage(char * argv_0, const gpt_params & params) {
+    fprintf(stderr, "usage: %s [options]\n", argv_0);
     fprintf(stderr, "\n");
     fprintf(stderr, "options:\n");
     fprintf(stderr, "  -h, --help            show this help message and exit\n");
     fprintf(stderr, "  -i, --interactive     run in interactive mode\n");
     fprintf(stderr, "  --interactive-first   run in interactive mode and wait for input right away\n");
-    fprintf(stderr, "  -ins, --instruct      run in instruction mode (use with Alpaca models)\n");
+    fprintf(stderr, "  --clean-interface     hides input prefix & suffix and displays '>' instead\n");
     fprintf(stderr, "  -r PROMPT, --reverse-prompt PROMPT\n");
     fprintf(stderr, "                        run in interactive mode and poll user input upon seeing PROMPT (can be\n");
     fprintf(stderr, "                        specified more than once for multiple prompts).\n");
     fprintf(stderr, "  --color               colorise output to distinguish prompt and user input from generations\n");
+    fprintf(stderr, "  --disable-multiline   disable multiline mode (use Ctrl+D on Linux/Mac and Ctrl+Z then Return on Windows to toggle multiline)\n");
     fprintf(stderr, "  -s SEED, --seed SEED  RNG seed (default: -1, use random seed for <= 0)\n");
     fprintf(stderr, "  -t N, --threads N     number of threads to use during computation (default: %d)\n", params.n_threads);
     fprintf(stderr, "  -p PROMPT, --prompt PROMPT\n");
     fprintf(stderr, "                        prompt to start generation with (default: empty)\n");
     fprintf(stderr, "  --random-prompt       start with a randomized prompt.\n");
     fprintf(stderr, "  --in-prefix STRING    string to prefix user inputs with (default: empty)\n");
+    fprintf(stderr, "  --ins-prefix STRING   (instruct) prefix user inputs with tokenized string (default: empty)\n");
+    fprintf(stderr, "  --ins-prefix-bos      (instruct) prepend bos token to instruct prefix.\n");
+    fprintf(stderr, "  --ins-suffix STRING   (instruct) suffix user inputs with tokenized string (default: empty)\n");
+    fprintf(stderr, "  --ins-suffix-bos      (instruct) prepend bos token to instruct suffix.\n");
     fprintf(stderr, "  -f FNAME, --file FNAME\n");
     fprintf(stderr, "                        prompt file to start generation.\n");
     fprintf(stderr, "  -n N, --n_predict N   number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict);
@@ -328,3 +448,61 @@ void win32_utf8_encode(const std::wstring & wstr, std::string & str) {
     str = strTo;
 }
 #endif
+
+bool get_input_text(std::string & input_text, bool eof_toggled_multiline_mode) {
+    bool another_line = true;
+    bool is_eof_multiline_toggled = false;
+    do {
+        std::string line;
+#if defined(_WIN32)
+        auto & stdcin = std::wcin;
+        std::wstring wline;
+        if (!std::getline(stdcin, wline)) {
+            // input stream is bad or EOF received
+            if (stdcin.bad()) {
+                fprintf(stderr, "%s: error: input stream bad\n", __func__);
+                return 1;
+            }
+        }
+        win32_utf8_encode(wline, line);
+#else
+        auto & stdcin = std::cin;
+        if (!std::getline(stdcin, line)) {
+            // input stream is bad or EOF received
+            if (stdcin.bad()) {
+                fprintf(stderr, "%s: error: input stream bad\n", __func__);
+                return 1;
+            }
+        }
+#endif
+        if (stdcin.eof()) {
+            stdcin.clear();
+            stdcin.seekg(0, std::ios::beg);
+            if (!eof_toggled_multiline_mode) {
+                another_line = false;
+            } else {
+                is_eof_multiline_toggled = !is_eof_multiline_toggled;
+                if (is_eof_multiline_toggled) {
+                    input_text += line;
+                    continue;
+                }
+            }
+        }
+        if (!eof_toggled_multiline_mode) {
+            if (line.empty() || line.back() != '\\') {
+                another_line = false;
+            } else {
+                line.pop_back(); // Remove the continue character
+            }
+        } else {
+            if (!is_eof_multiline_toggled) {
+                another_line = false;
+            }
+        }
+        input_text += line;
+        if (another_line) {
+            input_text += '\n'; // Append the line to the result
+        }
+    } while (another_line);
+    return true;
+}
index 1ea6f74451811130a378892f5f5a0ff6927633be..df8e4c6ccb990186c1be7793fa436f6e414be210 100644 (file)
 //
 
 struct gpt_params {
-    int32_t seed          = -1;   // RNG seed
-    int32_t n_threads     = std::min(4, (int32_t) std::thread::hardware_concurrency());
-    int32_t n_predict     = 128;  // new tokens to predict
-    int32_t repeat_last_n = 64;   // last n tokens to penalize
-    int32_t n_parts       = -1;   // amount of model parts (-1 = determine from model dimensions)
-    int32_t n_ctx         = 512;  // context size
-    int32_t n_batch       = 8;    // batch size for prompt processing
-    int32_t n_keep        = 0;    // number of tokens to keep from initial prompt
+    int32_t seed          = -1;    // RNG seed
+    int32_t n_threads     = std::min(4, (int32_t) std::thread::hardware_concurrency()); // max 4 threads (default)
+    int32_t n_predict     = 128;   // new tokens to predict
+    int32_t repeat_last_n = 64;    // last n tokens to penalize
+    int32_t n_parts       = -1;    // amount of model parts (-1 = determine from model dimensions)
+    int32_t n_ctx         = 512;   // context size
+    int32_t n_batch       = 8;     // batch size for prompt processing
+    int32_t n_keep        = 0;     // number of tokens to keep from initial prompt (-1 for all)
 
     // sampling parameters
     int32_t top_k = 40;
@@ -33,8 +33,15 @@ struct gpt_params {
     std::string prompt = "";
     std::string input_prefix = ""; // string to prefix user inputs with
 
+    std::string instruct_prefix = ""; // prefix user inputs with tokenized string
+    bool instruct_prefix_bos = false; // prepend bos token to instruct prefix
+    std::string instruct_suffix = ""; // suffix user inputs with tokenized string
+    bool instruct_suffix_bos = false; // prepend bos token to instruct suffix
 
     std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
+    std::vector<std::string> stopprompt; // string upon seeing which more user input is prompted (without adding instruct prefixes and suffixes)
+
+    bool rm_trailing_space_workaround = false; // workaround for removing trailing space from reverse/stop prompts
 
     bool memory_f16        = true;  // use f16 instead of f32 for memory kv
     bool random_prompt     = false; // do not randomize prompt if none provided
@@ -51,11 +58,14 @@ struct gpt_params {
     bool use_mlock         = false; // use mlock to keep model in memory
     bool mem_test          = false; // compute maximum memory usage
     bool verbose_prompt    = false; // print prompt tokens before generation
+
+    bool clean_interface   = false; // hides input prefix & suffix and displays '>'
+    bool multiline_mode    = true; // enables multi-line mode, to send input press CTRL+D on Linux/Max, Ctrl+Z then Return on Windows
 };
 
 bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
 
-void gpt_print_usage(int argc, char ** argv, const gpt_params & params);
+void gpt_print_usage(char * argv_0, const gpt_params & params);
 
 std::string gpt_random_prompt(std::mt19937 & rng);
 
@@ -95,3 +105,5 @@ void set_console_color(console_state & con_st, console_color_t color);
 void win32_console_init(bool enable_color);
 void win32_utf8_encode(const std::wstring & wstr, std::string & str);
 #endif
+
+bool get_input_text(std::string & input_text, bool escape_newline_mode);
index ba153cb82dcf672cd3954f22531dc36ef13d83d1..68b4b2840858e7caba10837d55120c27337cd1b0 100644 (file)
@@ -30,7 +30,8 @@ static bool is_interacting = false;
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
 void sigint_handler(int signo) {
     set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
-    printf("\n"); // this also force flush stdout.
+    fflush(stdout);
+    fflush(stderr);
     if (signo == SIGINT) {
         if (!is_interacting) {
             is_interacting=true;
@@ -89,6 +90,8 @@ int main(int argc, char ** argv) {
         params.prompt = gpt_random_prompt(rng);
     }
 
+    bool instruct_mode = !params.instruct_prefix.empty() || !params.instruct_suffix.empty();
+
 //    params.prompt = R"(// this function checks if the number n is prime
 //bool is_prime(int n) {)";
 
@@ -153,22 +156,20 @@ int main(int argc, char ** argv) {
     }
 
     // number of tokens to keep when resetting context
-    if (params.n_keep < 0 || params.n_keep > (int)embd_inp.size() || params.instruct) {
+    if (params.n_keep < 0 || params.n_keep > (int)embd_inp.size()) {
         params.n_keep = (int)embd_inp.size();
     }
 
     // prefix & suffix for instruct mode
-    const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", true);
-    const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false);
-
-    // in instruct mode, we inject a prefix and a suffix to each input by the user
-    if (params.instruct) {
-        params.interactive_start = true;
-        params.antiprompt.push_back("### Instruction:\n\n");
+    const auto inp_pfx = ::llama_tokenize(ctx, params.instruct_prefix, params.instruct_prefix_bos);
+    std::string instruct_suffix = params.instruct_suffix;
+    if (params.rm_trailing_space_workaround) {
+        if (instruct_suffix.back() == ' ') { instruct_suffix.pop_back(); }
     }
+    const auto inp_sfx = ::llama_tokenize(ctx, instruct_suffix, params.instruct_suffix_bos);
 
     // enable interactive mode if reverse prompt or interactive start is specified
-    if (params.antiprompt.size() != 0 || params.interactive_start) {
+    if (params.antiprompt.size() != 0 || params.stopprompt.size() != 0 || params.interactive_start) {
         params.interactive = true;
     }
 
@@ -210,10 +211,21 @@ int main(int argc, char ** argv) {
                 fprintf(stderr, "Reverse prompt: '%s'\n", antiprompt.c_str());
             }
         }
+        if (params.stopprompt.size()) {
+            for (auto stopprompt : params.stopprompt) {
+                fprintf(stderr, "Stop prompt: '%s'\n", stopprompt.c_str());
+            }
+        }
 
         if (!params.input_prefix.empty()) {
             fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
         }
+        if (!params.instruct_prefix.empty()) {
+            fprintf(stderr, "Instruct prefix %s: '%s'\n", params.instruct_prefix_bos ? "(with bos token)" : "", params.instruct_prefix.c_str());
+        }
+        if (!params.instruct_suffix.empty()) {
+            fprintf(stderr, "Instruct suffix %s: '%s'\n", params.instruct_suffix_bos ? "(with bos token)" : "", params.instruct_suffix.c_str());
+        }
     }
     fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n",
         params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
@@ -229,12 +241,29 @@ int main(int argc, char ** argv) {
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
                " - Press Ctrl+C to interject at any time.\n"
 #endif
-               " - Press Return to return control to LLaMa.\n"
-               " - If you want to submit another line, end your input in '\\'.\n\n");
+        );
+        if (params.multiline_mode) {
+            fprintf(stderr, " - Press Return to return control to LLaMa.\n"
+#if defined (_WIN32)
+                            " - [MULTILINE MODE] Press Ctrl+Z then Return (EOF) to toggle.\n\n");
+#else
+                            " - [MULTILINE MODE] Press Ctrl+D (EOF) to toggle.\n\n");
+#endif
+        }
+        else {
+            fprintf(stderr, " - Press Return to return control to LLaMa.\n"
+                            " - If you want to submit another line, end your input in '\\'.\n\n");
+        }
         is_interacting = params.interactive_start;
     }
 
-    bool is_antiprompt = false;
+    struct Antiprompt {
+        bool any = false;
+        bool trailing_space = false;
+        size_t len;
+        bool is_stop_prompt = false;
+    } antiprompt;
+
     bool input_noecho  = false;
 
     int n_past     = 0;
@@ -304,7 +333,7 @@ int main(int argc, char ** argv) {
             }
 
             // replace end of text token with newline token when in interactive mode
-            if (id == llama_token_eos() && params.interactive && !params.instruct) {
+            if (id == llama_token_eos() && params.interactive && !instruct_mode) {
                 id = llama_token_newline.front();
                 if (params.antiprompt.size() != 0) {
                     // tokenize and inject first reverse prompt
@@ -350,27 +379,72 @@ int main(int argc, char ** argv) {
         // check if we should prompt the user for more
         if (params.interactive && (int) embd_inp.size() <= n_consumed) {
 
-            // check for reverse prompt
-            if (params.antiprompt.size()) {
+            // check for reverse prompt or stop prompt
+            if (params.antiprompt.size() || params.stopprompt.size()) {
                 std::string last_output;
                 for (auto id : last_n_tokens) {
                     last_output += llama_token_to_str(ctx, id);
                 }
 
-                is_antiprompt = false;
+                antiprompt.any = false;
+                antiprompt.is_stop_prompt = false;
                 // Check if each of the reverse prompts appears at the end of the output.
-                for (std::string & antiprompt : params.antiprompt) {
-                    if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
+                for (std::string & prompt : params.antiprompt) {
+                    if (params.rm_trailing_space_workaround) {
+                        antiprompt.trailing_space = prompt.back() == ' ';
+                        antiprompt.len = prompt.length() - (antiprompt.trailing_space ? 1 : 0);
+                    }
+                    if (last_output.find(prompt.c_str(), last_output.length() - antiprompt.len, antiprompt.len) != std::string::npos) {
                         is_interacting = true;
-                        is_antiprompt = true;
+                        antiprompt.any = true;
                         set_console_color(con_st, CONSOLE_COLOR_USER_INPUT);
                         fflush(stdout);
                         break;
                     }
                 }
+                if (!antiprompt.any) {
+                    for (std::string & prompt : params.stopprompt) {
+                        if (params.rm_trailing_space_workaround) {
+                            antiprompt.trailing_space = prompt.back() == ' ';
+                            antiprompt.len = prompt.length() - (antiprompt.trailing_space ? 1 : 0);
+                        }
+                        if (last_output.find(prompt.c_str(), last_output.length() - antiprompt.len, antiprompt.len) != std::string::npos) {
+                            is_interacting = true;
+                            antiprompt.any = true;
+                            antiprompt.is_stop_prompt = true;
+                            set_console_color(con_st, CONSOLE_COLOR_USER_INPUT);
+                            fflush(stdout);
+                            break;
+                        }
+                    }
+                }
             }
 
-            if (n_past > 0 && is_interacting) {
+            if (n_past > 0 && is_interacting)
+            {
+                std::string buffer;
+                if (!params.clean_interface && !params.instruct_prefix.empty() && !antiprompt.any) {
+                    // avoid printing again user's new line (TODO: try to revert enter press and print newline)
+                    int i = params.instruct_prefix.front() == '\n' ? 1 : 0;
+                    for (; i < inp_pfx.size(); i++) {
+                        printf("%s", llama_token_to_str(ctx, inp_pfx[i]));
+                    }
+                    fflush(stdout);
+                }
+                if (params.rm_trailing_space_workaround) {
+                    // add only if not stopprompt (as stopprompt could be used to pause
+                        //     assistant and then continue without input - adding back trailing
+                        //     space may mess it up.)
+                    if (!antiprompt.is_stop_prompt && antiprompt.any && antiprompt.trailing_space) {
+                        // add back removed trailing space to buffer(workaround)
+                        buffer += ' ';
+                        if (!params.clean_interface) {
+                            printf("%s", buffer.c_str());
+                        }
+                        fflush(stdout);
+                    }
+                }
+
                 // potentially set color to indicate we are taking user input
                 set_console_color(con_st, CONSOLE_COLOR_USER_INPUT);
 
@@ -379,49 +453,45 @@ int main(int argc, char ** argv) {
                 signal(SIGINT, sigint_handler);
 #endif
 
-                if (params.instruct) {
+                if (params.clean_interface) {
                     printf("\n> ");
                 }
 
-                std::string buffer;
                 if (!params.input_prefix.empty()) {
                     buffer += params.input_prefix;
                     printf("%s", buffer.c_str());
                 }
 
-                std::string line;
-                bool another_line = true;
-                do {
-#if defined(_WIN32)
-                    std::wstring wline;
-                    if (!std::getline(std::wcin, wline)) {
-                        // input stream is bad or EOF received
-                        return 0;
-                    }
-                    win32_utf8_encode(wline, line);
-#else
-                    if (!std::getline(std::cin, line)) {
-                        // input stream is bad or EOF received
-                        return 0;
-                    }
-#endif
-                    if (line.empty() || line.back() != '\\') {
-                        another_line = false;
-                    } else {
-                        line.pop_back(); // Remove the continue character
-                    }
-                    buffer += line + '\n'; // Append the line to the result
-                } while (another_line);
+                if (!get_input_text(buffer, params.multiline_mode)) {
+                    // input stream is bad
+                    return 1;
+                }
+                if (!antiprompt.is_stop_prompt) {
+                    buffer += "\n";
+                }
 
                 // done taking input, reset color
                 set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
 
+                if (!params.clean_interface && !params.instruct_suffix.empty() && !antiprompt.is_stop_prompt) {
+                    // avoid printing again user's new line (TODO: try to revert enter press and print newline)
+                    int i = params.instruct_suffix.front() == '\n' ? 1 : 0;
+                    for (; i < inp_sfx.size(); i++) {
+                        printf("%s", llama_token_to_str(ctx, inp_sfx[i]));
+                    }
+                    // if (remove trailing space workaround) {
+                    //     We won't add back removed trailing space here, because assistant continues here,
+                    //         and it may mess up it's output (remove trailing space workaround).
+                    // }
+                    fflush(stdout);
+                }
+
                 // Add tokens to embd only if the input buffer is non-empty
                 // Entering a empty line lets the user pass control back
                 if (buffer.length() > 1) {
 
-                    // instruct mode: insert instruction prefix
-                    if (params.instruct && !is_antiprompt) {
+                    // insert input prefix
+                    if (!params.instruct_prefix.empty() && !antiprompt.any) {
                         n_consumed = embd_inp.size();
                         embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
                     }
@@ -429,8 +499,8 @@ int main(int argc, char ** argv) {
                     auto line_inp = ::llama_tokenize(ctx, buffer, false);
                     embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
 
-                    // instruct mode: insert response suffix
-                    if (params.instruct) {
+                    // insert response suffix
+                    if (!params.instruct_suffix.empty() && !antiprompt.is_stop_prompt) {
                         embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
                     }
 
@@ -447,7 +517,7 @@ int main(int argc, char ** argv) {
 
         // end of text token
         if (!embd.empty() && embd.back() == llama_token_eos()) {
-            if (params.instruct) {
+            if (instruct_mode) {
                 is_interacting = true;
             } else {
                 fprintf(stderr, " [end of text]\n");
diff --git a/prompts/alpaca.txt b/prompts/alpaca.txt
deleted file mode 100644 (file)
index 2224bde..0000000
+++ /dev/null
@@ -1 +0,0 @@
-Below is an instruction that describes a task. Write a response that appropriately completes the request.
diff --git a/prompts/chat-with-bob.txt b/prompts/chat-with-bob.txt
deleted file mode 100644 (file)
index ad494d8..0000000
+++ /dev/null
@@ -1,7 +0,0 @@
-Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.
-
-User: Hello, Bob.
-Bob: Hello. How may I help you today?
-User: Please tell me the largest city in Europe.
-Bob: Sure. The largest city in Europe is Moscow, the capital of Russia.
-User:
\ No newline at end of file