]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
main : option to disable context shift (#9484)
authorVinesh Janarthanan <redacted>
Mon, 16 Sep 2024 06:20:01 +0000 (01:20 -0500)
committerGitHub <redacted>
Mon, 16 Sep 2024 06:20:01 +0000 (09:20 +0300)
* added cli arg to disable context shift

* reverted precommit

* updated README.md for main

* white space

* allow disabling context shift in the server

* Update common/arg.cpp

no-context-shift only works for main example

Co-authored-by: Georgi Gerganov <redacted>
* added server example to --no-context-shift args

* removed server changes

* white space

---------

Co-authored-by: Georgi Gerganov <redacted>
common/arg.cpp
common/common.h
examples/main/README.md
examples/main/main.cpp

index 8fcb8c25f862be9df566651cc330e48667e0fc85..60e37a89a68e8599e124181fe9dff948c9ce47b4 100644 (file)
@@ -685,6 +685,13 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
             params.n_keep = value;
         }
     ));
+    add_opt(llama_arg(
+        {"--no-context-shift"},
+        format("disables context shift on inifinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),
+        [](gpt_params & params) {
+            params.ctx_shift = false;
+        }
+    ).set_examples({LLAMA_EXAMPLE_MAIN}));
     add_opt(llama_arg(
         {"--chunks"}, "N",
         format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks),
@@ -1985,4 +1992,3 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
 
     return ctx_arg;
 }
-
index e100c8fa73ecd81ed03be5a55ce6c36ce64f777c..cb87c4479ed0a029c8e29e062260c020d94d347d 100644 (file)
@@ -246,6 +246,7 @@ struct gpt_params {
     bool cont_batching     = true;  // insert new sequences for decoding on-the-fly
     bool flash_attn        = false; // flash attention
     bool no_perf           = false; // disable performance metrics
+    bool ctx_shift         = true;  // context shift on inifinite text generation
 
     bool input_prefix_bos  = false; // prefix BOS to user inputs, preceding input_prefix
     bool logits_all        = false; // return logits for all tokens in the batch
index 9396a34fa5a31d36bad9191b90be00b6e196efee..6730effdf2d661d070254cbdc8e07d83bb7693c4 100644 (file)
@@ -161,6 +161,8 @@ A value of -1 will enable infinite text generation, even though we have a finite
 
 If the pause is undesirable, a value of -2 will stop generation immediately when the context is filled.
 
+The `--no-context-shift` option allows you to stop the infinite text generation once the finite context window is full.
+
 It is important to note that the generated text may be shorter than the specified number of tokens if an End-of-Sequence (EOS) token or a reverse prompt is encountered. In interactive mode, text generation will pause and control will be returned to the user. In non-interactive mode, the program will end. In both cases, the text generation may stop before reaching the specified `--predict` value. If you want the model to keep going without ever producing End-of-Sequence on its own, you can use the `--ignore-eos` parameter.
 
 ### Temperature
index d9e45ce2fb5374777a0ae879538d85291df029eb..91fea93266c860af8145c5d19082139de24433d7 100644 (file)
@@ -559,29 +559,35 @@ int main(int argc, char ** argv) {
                 // if we run out of context:
                 // - take the n_keep first tokens from the original prompt (via n_past)
                 // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
+
                 if (n_past + (int) embd.size() >= n_ctx) {
-                    if (params.n_predict == -2) {
-                        LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
+                    if (!params.ctx_shift){
+                        LOG_DBG("\n\n%s: context full and context shift is disabled => stopping\n", __func__);
                         break;
-                    }
+                    } else {
+                        if (params.n_predict == -2) {
+                            LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
+                            break;
+                        }
 
-                    const int n_left    = n_past - params.n_keep;
-                    const int n_discard = n_left/2;
+                        const int n_left    = n_past - params.n_keep;
+                        const int n_discard = n_left/2;
 
-                    LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
-                            n_past, n_left, n_ctx, params.n_keep, n_discard);
+                        LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
+                                n_past, n_left, n_ctx, params.n_keep, n_discard);
 
-                    llama_kv_cache_seq_rm (ctx, 0, params.n_keep            , params.n_keep + n_discard);
-                    llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
+                        llama_kv_cache_seq_rm (ctx, 0, params.n_keep            , params.n_keep + n_discard);
+                        llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
 
-                    n_past -= n_discard;
+                        n_past -= n_discard;
 
-                    LOG_DBG("after swap: n_past = %d\n", n_past);
+                        LOG_DBG("after swap: n_past = %d\n", n_past);
 
-                    LOG_DBG("embd: %s\n", string_from(ctx, embd).c_str());
+                        LOG_DBG("embd: %s\n", string_from(ctx, embd).c_str());
 
-                    LOG_DBG("clear session path\n");
-                    path_session.clear();
+                        LOG_DBG("clear session path\n");
+                        path_session.clear();
+                    }
                 }
             } else {
                 // context extension via Self-Extend