]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
main : add self-extend support (#4815)
authorGeorgi Gerganov <redacted>
Mon, 8 Jan 2024 09:18:32 +0000 (11:18 +0200)
committerGitHub <redacted>
Mon, 8 Jan 2024 09:18:32 +0000 (11:18 +0200)
* examples : add passkey test

* passkey : better prints

* passkey : select pass key pos from CLI

* passkey : simplify n_past logic

* llama : "self-extend"-like context extension

* passkey : add comment

* main : add Self-Extend support

* llama : add comment about llama_kv_cache_seq_div

common/common.cpp
common/common.h
examples/main/main.cpp
llama.h

index eacaee18e09071d9b41facf1453fadbc0e573b23..6b4913a6565731e862969d2c823469117fb87999 100644 (file)
@@ -220,6 +220,20 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_ctx = std::stoi(argv[i]);
+        } else if (arg == "--grp-attn-n" || arg == "-gan") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+
+            params.grp_attn_n = std::stoi(argv[i]);
+        } else if (arg == "--grp-attn-w" || arg == "-gaw") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+
+            params.grp_attn_w = std::stoi(argv[i]);
         } else if (arg == "--rope-freq-base") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -904,6 +918,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("                        Not recommended since this is both slower and uses more VRAM.\n");
 #endif // GGML_USE_CUBLAS
 #endif
+    printf("  -gan N, --grp-attn-n N\n");
+    printf("                        group-attention factor (default: %d)\n", params.grp_attn_n);
+    printf("  -gat N, --grp-attn-w N\n");
+    printf("                        group-attention width (default: %.1f)\n", (double)params.grp_attn_w);
     printf("  --verbose-prompt      print prompt before generation\n");
     printf("  -dkvc, --dump-kv-cache\n");
     printf("                        verbose print of the KV cache\n");
index 9659aa0453ff8135e926eae6d83e660a9848a5d5..e2bbfc258b6467cb24e5d40a6e28cd54ab148368 100644 (file)
@@ -62,6 +62,8 @@ struct gpt_params {
     int32_t main_gpu                        = 0;     // the GPU that is used for scratch and small tensors
     float   tensor_split[LLAMA_MAX_DEVICES] = {0};   // how split tensors should be distributed across GPUs
     int32_t n_beams                         = 0;     // if non-zero then use beam search of given width.
+    int32_t grp_attn_n                      = 1;     // group-attention factor
+    int32_t grp_attn_w                      = 512;   // group-attention width
     float   rope_freq_base                  = 0.0f;  // RoPE base frequency
     float   rope_freq_scale                 = 0.0f;  // RoPE frequency scaling factor
     float   yarn_ext_factor                 = -1.0f; // YaRN extrapolation mix factor
index c096f110b32c55f19959dba82820dc36c38369c8..5ea67051f36546aa5d11c6c7fb80489a94e44252 100644 (file)
@@ -439,6 +439,21 @@ int main(int argc, char ** argv) {
     LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
     LOG_TEE("sampling order: \n%s\n", llama_sampling_order_print(sparams).c_str());
     LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
+
+    // group-attention state
+    // number of grouped KV tokens so far (used only if params.grp_attn_n > 1)
+    int ga_i = 0;
+
+    const int ga_n = params.grp_attn_n;
+    const int ga_w = params.grp_attn_w;
+
+    if (ga_n != 1) {
+        GGML_ASSERT(ga_n > 0                    && "grp_attn_n must be positive");                     // NOLINT
+        GGML_ASSERT(ga_w % ga_n == 0            && "grp_attn_w must be a multiple of grp_attn_n");     // NOLINT
+      //GGML_ASSERT(n_ctx_train % ga_w == 0     && "n_ctx_train must be a multiple of grp_attn_w");    // NOLINT
+      //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * grp_attn_n"); // NOLINT
+        LOG_TEE("self-extend: n_ctx_train = %d, grp_attn_n = %d, grp_attn_w = %d\n", n_ctx_train, ga_n, ga_w);
+    }
     LOG_TEE("\n\n");
 
     if (params.interactive) {
@@ -500,37 +515,61 @@ int main(int argc, char ** argv) {
                 fflush(stdout);
             }
 
-            // infinite text generation via context swapping
-            // if we run out of context:
-            // - take the n_keep first tokens from the original prompt (via n_past)
-            // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
-            if (n_past + (int) embd.size() + std::max<int>(0, guidance_offset) > n_ctx) {
-                if (params.n_predict == -2) {
-                    LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
-                    break;
-                }
+            if (ga_n == 1) {
+                // infinite text generation via context shifting
+                // if we run out of context:
+                // - take the n_keep first tokens from the original prompt (via n_past)
+                // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
+                if (n_past + (int) embd.size() + std::max<int>(0, guidance_offset) > n_ctx) {
+                    if (params.n_predict == -2) {
+                        LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
+                        break;
+                    }
 
-                const int n_left    = n_past - params.n_keep - 1;
-                const int n_discard = n_left/2;
+                    const int n_left    = n_past - params.n_keep - 1;
+                    const int n_discard = n_left/2;
 
-                LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
-                    n_past, n_left, n_ctx, params.n_keep, n_discard);
+                    LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
+                            n_past, n_left, n_ctx, params.n_keep, n_discard);
 
-                llama_kv_cache_seq_rm   (ctx, 0, params.n_keep + 1            , params.n_keep + n_discard + 1);
-                llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
+                    llama_kv_cache_seq_rm   (ctx, 0, params.n_keep + 1            , params.n_keep + n_discard + 1);
+                    llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
 
-                n_past -= n_discard;
+                    n_past -= n_discard;
 
-                if (ctx_guidance) {
-                    n_past_guidance -= n_discard;
+                    if (ctx_guidance) {
+                        n_past_guidance -= n_discard;
+                    }
+
+                    LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
+
+                    LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
+
+                    LOG("clear session path\n");
+                    path_session.clear();
                 }
+            } else {
+                // context extension via Self-Extend
+                while (n_past >= ga_i + ga_w) {
+                    const int ib = (ga_n*ga_i)/ga_w;
+                    const int bd = (ga_w/ga_n)*(ga_n - 1);
+                    const int dd = (ga_w/ga_n) - ib*bd - ga_w;
 
-                LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
+                    LOG("\n");
+                    LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i, n_past, ib*bd, ga_i + ib*bd, n_past + ib*bd);
+                    LOG("div:   [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n);
+                    LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd);
 
-                LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
+                    llama_kv_cache_seq_shift(ctx, 0, ga_i,                n_past,              ib*bd);
+                    llama_kv_cache_seq_div  (ctx, 0, ga_i + ib*bd,        ga_i + ib*bd + ga_w, ga_n);
+                    llama_kv_cache_seq_shift(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd,      dd);
 
-                LOG("clear session path\n");
-                path_session.clear();
+                    n_past -= bd;
+
+                    ga_i += ga_w/ga_n;
+
+                    LOG("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", n_past + bd, n_past, ga_i);
+                }
             }
 
             // try to reuse a matching prefix from the loaded session instead of re-eval (via n_past)
diff --git a/llama.h b/llama.h
index 5305de90be5c1faede855c4dbecd53ef158a60c1..869ff0acf525a825763c2353e01b3bf42ed0ee09 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -484,6 +484,10 @@ extern "C" {
                        llama_pos   p1,
                        llama_pos   delta);
 
+    // Integer division of the positions by factor of `d > 1`
+    // If the KV cache is RoPEd, the KV data is updated accordingly
+    // p0 < 0 : [0,  p1]
+    // p1 < 0 : [p0, inf)
     LLAMA_API void llama_kv_cache_seq_div(
             struct llama_context * ctx,
                     llama_seq_id   seq_id,