]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Add "--instruct" argument for usage with Alpaca (#240)
authorGeorgi Gerganov <redacted>
Sun, 19 Mar 2023 16:37:02 +0000 (18:37 +0200)
committerGeorgi Gerganov <redacted>
Sun, 19 Mar 2023 16:37:02 +0000 (18:37 +0200)
Also start adding prompts in "./prompts"

main.cpp
prompts/alpaca.txt [new file with mode: 0644]
prompts/chat-with-bob.txt [new file with mode: 0644]
utils.cpp
utils.h

index 105dd91ee60654c173a926a6eee39022fb29c3c1..a95e2e72151fc23c522d6511d4eb75d144242815 100644 (file)
--- a/main.cpp
+++ b/main.cpp
@@ -176,8 +176,6 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
                 }
     }
 
-    const ggml_type wtype2 = GGML_TYPE_F32;
-
     auto & ctx = model.ctx;
 
     size_t ctx_size = 0;
@@ -237,7 +235,6 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
 
         const int n_embd  = hparams.n_embd;
         const int n_layer = hparams.n_layer;
-        const int n_ctx   = hparams.n_ctx;
         const int n_vocab = hparams.n_vocab;
 
         model.layers.resize(n_layer);
@@ -539,9 +536,7 @@ bool llama_eval(
     const int n_vocab = hparams.n_vocab;
     const int n_rot   = hparams.n_embd/hparams.n_head;
 
-    const int d_key = n_embd/n_head;
-
-     // TODO: check if this size scales with n_ctx linearly and remove constant. somehow I feel it wasn't the case
+    // TODO: check if this size scales with n_ctx linearly and remove constant. somehow I feel it wasn't the case
     // static size_t buf_size = hparams.n_ctx*1024*1024;
     static size_t buf_size = 512u*1024*1024;
     static void * buf = malloc(buf_size);
@@ -792,7 +787,7 @@ int main(int argc, char ** argv) {
     if (gpt_params_parse(argc, argv, params) == false) {
         return 1;
     }
-    
+
     if (params.n_ctx > 2048) {
         fprintf(stderr, "%s: warning: model does not support context sizes greater than 2048 tokens (%d specified);"
                 "expect poor results\n", __func__, params.n_ctx);
@@ -820,7 +815,7 @@ int main(int argc, char ** argv) {
     // load the model
     {
         const int64_t t_start_us = ggml_time_us();
-        if (!llama_model_load(params.model, model, vocab, params.n_ctx)) {  
+        if (!llama_model_load(params.model, model, vocab, params.n_ctx)) {
             fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
             return 1;
         }
@@ -849,9 +844,25 @@ int main(int argc, char ** argv) {
 
     params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());
 
+    // prefix & suffix for instruct mode
+    const std::vector<gpt_vocab::id> inp_pfx = ::llama_tokenize(vocab, "\n\n### Instruction:\n\n", true);
+    const std::vector<gpt_vocab::id> inp_sfx = ::llama_tokenize(vocab, "\n\n### Response:\n\n", false);
+
+    // in instruct mode, we inject a prefix and a suffix to each input by the user
+    if (params.instruct) {
+        fprintf(stderr, "== Instruction mode enabled ==\n");
+        params.interactive = true;
+        params.antiprompt = "### Instruction:\n\n";
+    }
+
     // tokenize the reverse prompt
     std::vector<gpt_vocab::id> antiprompt_inp = ::llama_tokenize(vocab, params.antiprompt, false);
 
+    // enable interactive mode if reverse prompt is specified
+    if (!antiprompt_inp.empty()) {
+        params.interactive = true;
+    }
+
     fprintf(stderr, "\n");
     fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
     fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
@@ -872,7 +883,7 @@ int main(int argc, char ** argv) {
 
         fprintf(stderr, "%s: interactive mode on.\n", __func__);
 
-        if(antiprompt_inp.size()) {
+        if (antiprompt_inp.size()) {
             fprintf(stderr, "%s: reverse prompt: '%s'\n", __func__, params.antiprompt.c_str());
             fprintf(stderr, "%s: number of tokens in reverse prompt = %zu\n", __func__, antiprompt_inp.size());
             for (int i = 0; i < (int) antiprompt_inp.size(); i++) {
@@ -894,31 +905,27 @@ int main(int argc, char ** argv) {
     std::vector<gpt_vocab::id> last_n_tokens(last_n_size);
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
 
-
     if (params.interactive) {
         fprintf(stderr, "== Running in interactive mode. ==\n"
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
                " - Press Ctrl+C to interject at any time.\n"
 #endif
                " - Press Return to return control to LLaMa.\n"
-               " - If you want to submit another line, end your input in '\\'.\n");
+               " - If you want to submit another line, end your input in '\\'.\n\n");
+        is_interacting = true;
     }
 
-    int remaining_tokens = params.n_predict;
     int input_consumed = 0;
     bool input_noecho = false;
 
-    // prompt user immediately after the starting prompt has been loaded
-    if (params.interactive_start) {
-        is_interacting = true;
-    }
+    int remaining_tokens = params.n_predict;
 
     // set the color for the prompt which will be output initially
     if (params.use_color) {
         printf(ANSI_COLOR_YELLOW);
     }
 
-    while (remaining_tokens > 0) {
+    while (remaining_tokens > 0 || params.interactive) {
         // predict
         if (embd.size() > 0) {
             const int64_t t_start_us = ggml_time_us();
@@ -971,13 +978,13 @@ int main(int argc, char ** argv) {
                 last_n_tokens.erase(last_n_tokens.begin());
                 last_n_tokens.push_back(embd_inp[input_consumed]);
                 ++input_consumed;
-                if (embd.size() > params.n_batch) {
+                if ((int) embd.size() > params.n_batch) {
                     break;
                 }
             }
 
             // reset color to default if we there is no pending user input
-            if (!input_noecho && params.use_color && embd_inp.size() == input_consumed) {
+            if (!input_noecho && params.use_color && (int) embd_inp.size() == input_consumed) {
                 printf(ANSI_COLOR_RESET);
             }
         }
@@ -999,19 +1006,26 @@ int main(int argc, char ** argv) {
                 is_interacting = true;
             }
             if (is_interacting) {
+                if (params.instruct) {
+                    input_consumed = embd_inp.size();
+                    embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
+
+                    printf("\n> ");
+                }
+
                 // currently being interactive
-                bool another_line=true;
+                bool another_line = true;
                 while (another_line) {
                     fflush(stdout);
                     char buf[256] = {0};
                     int n_read;
-                    if(params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN);
+                    if (params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN);
                     if (scanf("%255[^\n]%n%*c", buf, &n_read) <= 0) {
                         // presumable empty line, consume the newline
                         std::ignore = scanf("%*c");
                         n_read=0;
                     }
-                    if(params.use_color) printf(ANSI_COLOR_RESET);
+                    if (params.use_color) printf(ANSI_COLOR_RESET);
 
                     if (n_read > 0 && buf[n_read-1]=='\\') {
                         another_line = true;
@@ -1026,6 +1040,10 @@ int main(int argc, char ** argv) {
                     std::vector<gpt_vocab::id> line_inp = ::llama_tokenize(vocab, buf, false);
                     embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
 
+                    if (params.instruct) {
+                        embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
+                    }
+
                     remaining_tokens -= line_inp.size();
 
                     input_noecho = true; // do not echo this again
@@ -1037,8 +1055,12 @@ int main(int argc, char ** argv) {
 
         // end of text token
         if (embd.back() == 2) {
-            fprintf(stderr, " [end of text]\n");
-            break;
+            if (params.interactive) {
+                is_interacting = true;
+            } else {
+                fprintf(stderr, " [end of text]\n");
+                break;
+            }
         }
     }
 
diff --git a/prompts/alpaca.txt b/prompts/alpaca.txt
new file mode 100644 (file)
index 0000000..2224bde
--- /dev/null
@@ -0,0 +1 @@
+Below is an instruction that describes a task. Write a response that appropriately completes the request.
diff --git a/prompts/chat-with-bob.txt b/prompts/chat-with-bob.txt
new file mode 100644 (file)
index 0000000..009da39
--- /dev/null
@@ -0,0 +1,7 @@
+Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.
+
+User: Hello, Bob.
+Bob: Hello. How may I help you today?
+User: Please tell me the largest city in Europe.
+Bob: Sure. The largest city in Europe is Moscow, the capital of Russia.
+User:
index efa2e3c35f728e7addb7806b8bcf7bfc6430f0ab..be81c6cd08cfcdd10a43898eeb8d1e181f00cb06 100644 (file)
--- a/utils.cpp
+++ b/utils.cpp
@@ -38,13 +38,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
         } else if (arg == "-p" || arg == "--prompt") {
             params.prompt = argv[++i];
         } else if (arg == "-f" || arg == "--file") {
-
             std::ifstream file(argv[++i]);
-
-            std::copy(std::istreambuf_iterator<char>(file),
-                    std::istreambuf_iterator<char>(),
-                    back_inserter(params.prompt));
-                
+            std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.prompt));
         } else if (arg == "-n" || arg == "--n_predict") {
             params.n_predict = std::stoi(argv[++i]);
         } else if (arg == "--top_k") {
@@ -65,9 +60,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.model = argv[++i];
         } else if (arg == "-i" || arg == "--interactive") {
             params.interactive = true;
-        } else if (arg == "--interactive-start") {
-            params.interactive = true;
-            params.interactive_start = true;
+        } else if (arg == "-ins" || arg == "--instruct") {
+            params.instruct = true;
         } else if (arg == "--color") {
             params.use_color = true;
         } else if (arg == "-r" || arg == "--reverse-prompt") {
@@ -85,13 +79,13 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
     return true;
 }
 
-void gpt_print_usage(int argc, char ** argv, const gpt_params & params) {
+void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "usage: %s [options]\n", argv[0]);
     fprintf(stderr, "\n");
     fprintf(stderr, "options:\n");
     fprintf(stderr, "  -h, --help            show this help message and exit\n");
     fprintf(stderr, "  -i, --interactive     run in interactive mode\n");
-    fprintf(stderr, "  --interactive-start   run in interactive mode and poll user input at startup\n");
+    fprintf(stderr, "  -ins, --instruct      run in instruction mode (use with Alpaca models)\n");
     fprintf(stderr, "  -r PROMPT, --reverse-prompt PROMPT\n");
     fprintf(stderr, "                        in interactive mode, poll user input upon seeing PROMPT\n");
     fprintf(stderr, "  --color               colorise output to distinguish prompt and user input from generations\n");
@@ -398,7 +392,7 @@ gpt_vocab::id llama_sample_top_p_top_k(
                     logits_id.push_back(std::make_pair(logits[i]*scale*repeat_penalty, i));
                 } else {
                     logits_id.push_back(std::make_pair(logits[i]*scale/repeat_penalty, i));
-                }                
+                }
             } else {
                 logits_id.push_back(std::make_pair(logits[i]*scale, i));
             }
diff --git a/utils.h b/utils.h
index c1a8498a78d68dfdfd2fbffb69f4cd8187823827..e329ba168b45fd929a0b86e65db98d7cd3d34f26 100644 (file)
--- a/utils.h
+++ b/utils.h
@@ -27,14 +27,14 @@ struct gpt_params {
 
     int32_t n_batch = 8; // batch size for prompt processing
 
-    std::string model = "models/lamma-7B/ggml-model.bin"; // model path
-    std::string prompt;
+    std::string model      = "models/lamma-7B/ggml-model.bin"; // model path
+    std::string prompt     = "";
+    std::string antiprompt = ""; // string upon seeing which more user input is prompted
 
     bool use_color = false; // use color to distinguish generations and inputs
 
     bool interactive = false; // interactive mode
-    bool interactive_start = false; // reverse prompt immediately
-    std::string antiprompt = ""; // string upon seeing which more user input is prompted
+    bool instruct    = false; // instruction mode (used for Alpaca models)
 };
 
 bool gpt_params_parse(int argc, char ** argv, gpt_params & params);