]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
main : add option to save full output to session (#1338)
authorEvan Jones <redacted>
Wed, 10 May 2023 15:37:14 +0000 (11:37 -0400)
committerGitHub <redacted>
Wed, 10 May 2023 15:37:14 +0000 (11:37 -0400)
* main : add option to save full output to session

* split behavior into --session and --prompt-cache

* restore original implementation with new names

* PR comments

* move the check for incompatible parameters to gpt_params_parse

* Fix whitespace

Co-authored-by: DannyDaemonic <redacted>
---------

Co-authored-by: DannyDaemonic <redacted>
examples/common.cpp
examples/common.h
examples/main/README.md
examples/main/main.cpp

index 7aa77587b4605f08a9f277009a2f43f0a44d450c..f3085b08e5b25e2d086fd264174f8422b418377c 100644 (file)
@@ -118,12 +118,14 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.prompt = argv[i];
         } else if (arg == "-e") {
             escape_prompt = true;
-        } else if (arg == "--session") {
+        } else if (arg == "--prompt-cache") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
-            params.path_session = argv[i];
+            params.path_prompt_cache = argv[i];
+        } else if (arg == "--prompt-cache-all") {
+            params.prompt_cache_all = true;
         } else if (arg == "-f" || arg == "--file") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -342,6 +344,13 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
         gpt_print_usage(argc, argv, default_params);
         exit(1);
     }
+    if (params.prompt_cache_all &&
+            (params.interactive || params.interactive_first ||
+             params.instruct || params.antiprompt.size())) {
+        fprintf(stderr, "error: --prompt-cache-all not supported in interactive mode yet\n");
+        gpt_print_usage(argc, argv, default_params);
+        exit(1);
+    }
     if (escape_prompt) {
         process_escapes(params.prompt);
     }
@@ -367,7 +376,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "  -p PROMPT, --prompt PROMPT\n");
     fprintf(stderr, "                        prompt to start generation with (default: empty)\n");
     fprintf(stderr, "  -e                    process prompt escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n");
-    fprintf(stderr, "  --session FNAME       file to cache model state in (may be large!) (default: none)\n");
+    fprintf(stderr, "  --prompt-cache FNAME  file to cache prompt state for faster startup (default: none)\n");
+    fprintf(stderr, "  --prompt-cache-all    if specified, saves user input and generations to cache as well.\n");
+    fprintf(stderr, "                        not supported with --interactive or other interactive options\n");
     fprintf(stderr, "  --random-prompt       start with a randomized prompt.\n");
     fprintf(stderr, "  --in-prefix STRING    string to prefix user inputs with (default: empty)\n");
     fprintf(stderr, "  --in-suffix STRING    string to suffix after user inputs with (default: empty)\n");
index 43f1cc9ef09d57c31776cb2c4ae187b635a9ec59..499671b2e8d6dccf067c345dde7d4745985649a2 100644 (file)
@@ -46,9 +46,9 @@ struct gpt_params {
 
     std::string model  = "models/lamma-7B/ggml-model.bin"; // model path
     std::string prompt = "";
-    std::string path_session = "";       // path to file for saving/loading model eval state
-    std::string input_prefix = "";       // string to prefix user inputs with
-    std::string input_suffix = "";       // string to suffix user inputs with
+    std::string path_prompt_cache = "";  // path to file for saving/loading prompt eval state
+    std::string input_prefix      = "";  // string to prefix user inputs with
+    std::string input_suffix      = "";  // string to suffix user inputs with
     std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
 
     std::string lora_adapter = "";  // lora adapter path
@@ -58,6 +58,7 @@ struct gpt_params {
     bool random_prompt     = false; // do not randomize prompt if none provided
     bool use_color         = false; // use color to distinguish generations and inputs
     bool interactive       = false; // interactive mode
+    bool prompt_cache_all  = false; // save user input and generations to prompt cache
 
     bool embedding         = false; // get only sentence embedding
     bool interactive_first = false; // wait for user input immediately
index 35f87bcd594edac0db7cffcefc91c4ada3cbd5ae..7c03f92c897d9c7577fd01c1b1e8d4cc25cd8633 100644 (file)
@@ -270,9 +270,9 @@ These options help improve the performance and memory usage of the LLaMA models.
 
 -   `-b N, --batch_size N`: Set the batch size for prompt processing (default: 512). This large batch size benefits users who have BLAS installed and enabled it during the build. If you don't have BLAS enabled ("BLAS=0"), you can use a smaller number, such as 8, to see the prompt progress as it's evaluated in some situations.
 
-### Session Caching
+### Prompt Caching
 
--   `--session FNAME`: Specify a file to load/save the session, which caches the model state after the initial prompt. This can significantly speed up the startup time when you're using longer prompts. The session file is created during the first run and is reused in subsequent runs. If you change your prompt such that 75% or less of the session is reusable, the existing session file will be overwritten with a new, updated version to maintain optimal performance.
+-   `--prompt-cache FNAME`: Specify a file to cache the model state after the initial prompt. This can significantly speed up the startup time when you're using longer prompts. The file is created during the first run and is reused and updated in subsequent runs.
 
 ### Quantization
 
index 6e1172a48367d1b189eafdc8f352a0e409f7db83..bd1c4ab5585212e2db196d88106fe04557a54f2b 100644 (file)
@@ -139,7 +139,7 @@ int main(int argc, char ** argv) {
     // Add a space in front of the first character to match OG llama tokenizer behavior
     params.prompt.insert(0, 1, ' ');
 
-    std::string path_session = params.path_session;
+    std::string path_session = params.path_prompt_cache;
     std::vector<llama_token> session_tokens;
 
     if (!path_session.empty()) {
@@ -292,14 +292,9 @@ int main(int argc, char ** argv) {
         is_interacting = params.interactive_first;
     }
 
-    bool is_antiprompt = false;
-    bool input_echo    = true;
-
-    // HACK - because session saving incurs a non-negligible delay, for now skip re-saving session
-    // if we loaded a session with at least 75% similarity. It's currently just used to speed up the
-    // initial prompt so it doesn't need to be an exact match.
-    bool need_to_save_session = !path_session.empty() && n_matching_session_tokens < (embd_inp.size() * 3 / 4);
-
+    bool is_antiprompt        = false;
+    bool input_echo           = true;
+    bool need_to_save_session = !path_session.empty() && n_matching_session_tokens < embd_inp.size();
 
     int n_past             = 0;
     int n_remain           = params.n_predict;
@@ -328,7 +323,7 @@ int main(int argc, char ** argv) {
                 embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size());
 
                 // stop saving session if we run out of context
-                path_session = "";
+                path_session.clear();
 
                 //printf("\n---\n");
                 //printf("resetting: '");
@@ -603,6 +598,11 @@ int main(int argc, char ** argv) {
         }
     }
 
+    if (!path_session.empty() && params.prompt_cache_all) {
+        fprintf(stderr, "\n%s: saving final output to session file '%s'\n", __func__, path_session.c_str());
+        llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
+    }
+
     llama_print_timings(ctx);
     llama_free(ctx);