]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : add self-extend support (#5104)
authorMaximilian Winter <redacted>
Sat, 27 Jan 2024 13:38:05 +0000 (14:38 +0100)
committerGitHub <redacted>
Sat, 27 Jan 2024 13:38:05 +0000 (15:38 +0200)
* Ported self extension to server example

* Update server.cpp

* Fixed prompt caching without self extend

* Update server.cpp

* Added description to server readme.

* Update server.cpp

* Update server.cpp

* Update server.cpp

* Update server.cpp

* Update README.md

* Changed descriptions

* server : formatting

* Update examples/server/server.cpp

Co-authored-by: Georgi Gerganov <redacted>
* Update examples/server/server.cpp

Co-authored-by: Georgi Gerganov <redacted>
* Update server.cpp

* Update server.cpp

---------

Co-authored-by: Georgi Gerganov <redacted>
examples/server/README.md
examples/server/server.cpp

index fd3034b99c3d2a18d8b4f313de1a5f39524fa7a9..1c92a2041c43032fc9c91fadeb52e7d91a51a370 100644 (file)
@@ -30,7 +30,8 @@ Command line options:
 -   `-cb`, `--cont-batching`: enable continuous batching (a.k.a dynamic batching) (default: disabled)
 -   `-spf FNAME`, `--system-prompt-file FNAME` Set a file to load "a system prompt (initial prompt of all slots), this is useful for chat applications. [See more](#change-system-prompt-on-runtime)
 -   `--mmproj MMPROJ_FILE`: Path to a multimodal projector file for LLaVA.
-
+-   `--grp-attn-n`: Set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`
+-   `--grp-attn-w`: Set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`
 ## Build
 
 server is build alongside everything else from the root of the project
index 39283613256eda655a85408900233654e48f69fe..af63f2f6f7d125105b1fbbe6ceca23932eb58f4b 100644 (file)
@@ -184,6 +184,12 @@ struct llama_client_slot
     struct llama_sampling_params sparams;
     llama_sampling_context *ctx_sampling = nullptr;
 
+    int32_t ga_i = 0;   // group-attention state
+    int32_t ga_n = 1;// group-attention factor
+    int32_t ga_w = 512; // group-attention width
+
+    int32_t n_past_se = 0; // self-extend
+
     // multimodal
     std::vector<slot_image> images;
 
@@ -212,7 +218,8 @@ struct llama_client_slot
         sent_count             = 0;
         sent_token_probs_index = 0;
         infill                 = false;
-
+        ga_i                   = 0;
+        n_past_se  = 0;
         generated_token_probs.clear();
 
         for (slot_image & img : images)
@@ -399,9 +406,26 @@ struct llama_server_context
 
             slot.id = i;
             slot.n_ctx = n_ctx_slot;
-            slot.reset();
 
             LOG_TEE(" -> Slot %i - max context: %i\n", slot.id, n_ctx_slot);
+
+            const int ga_n = params.grp_attn_n;
+            const int ga_w = params.grp_attn_w;
+
+            if (ga_n != 1) {
+                GGML_ASSERT(ga_n > 0                    && "ga_n must be positive");                     // NOLINT
+                GGML_ASSERT(ga_w % ga_n == 0            && "ga_w must be a multiple of ga_n");     // NOLINT
+                //GGML_ASSERT(n_ctx_train % ga_w == 0     && "n_ctx_train must be a multiple of ga_w");    // NOLINT
+                //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT
+                LOG_TEE(" -> Slot %i - self-extend: ga_n = %d, ga_w = %d\n", slot.id, ga_n, ga_w);
+            }
+
+            slot.ga_i = 0;
+            slot.ga_n = ga_n;
+            slot.ga_w = ga_w;
+
+            slot.reset();
+
             slots.push_back(slot);
         }
 
@@ -1349,32 +1373,35 @@ struct llama_server_context
 
         for (llama_client_slot &slot : slots)
         {
-            if (slot.is_processing() && slot.cache_tokens.size() >= (size_t) slot.n_ctx)
+            if (slot.ga_n == 1)
             {
-                // Shift context
-                const int n_left    = slot.n_past - slot.params.n_keep - 1;
-                const int n_discard = n_left / 2;
+                if (slot.is_processing() && slot.cache_tokens.size() >= (size_t) slot.n_ctx)
+                {
+                    // Shift context
+                    const int n_left    = slot.n_past - slot.params.n_keep - 1;
+                    const int n_discard = n_left / 2;
 
-                LOG_TEE("slot %d: context shift - n_keep = %d, n_left = %d, n_discard = %d\n", slot.id, slot.params.n_keep, n_left, n_discard);
-                llama_kv_cache_seq_rm   (ctx, slot.id, slot.params.n_keep + 1            , slot.params.n_keep + n_discard + 1);
-                llama_kv_cache_seq_shift(ctx, slot.id, slot.params.n_keep + 1 + n_discard, slot.n_past, -n_discard);
+                    LOG_TEE("slot %d: context shift - n_keep = %d, n_left = %d, n_discard = %d\n", slot.id, slot.params.n_keep, n_left, n_discard);
+                    llama_kv_cache_seq_rm   (ctx, slot.id, slot.params.n_keep + 1            , slot.params.n_keep + n_discard + 1);
+                    llama_kv_cache_seq_shift(ctx, slot.id, slot.params.n_keep + 1 + n_discard, slot.n_past, -n_discard);
 
-                for (size_t i = slot.params.n_keep + 1 + n_discard; i < slot.cache_tokens.size(); i++)
-                {
-                    slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
-                }
+                    for (size_t i = slot.params.n_keep + 1 + n_discard; i < slot.cache_tokens.size(); i++)
+                    {
+                        slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
+                    }
 
-                slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
+                    slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
 
-                slot.n_past -= n_discard;
+                    slot.n_past -= n_discard;
 
-                slot.truncated = true;
+                    slot.truncated = true;
 
-                LOG_VERBOSE("context shift", {
-                                                {"n_ctx",  n_ctx},
-                                                {"n_keep", params.n_keep},
-                                                {"n_left", n_left},
-                                            });
+                    LOG_VERBOSE("context shift", {
+                        { "n_ctx", n_ctx },
+                        { "n_keep", params.n_keep },
+                        { "n_left", n_left },
+                    });
+                }
             }
         }
 
@@ -1401,7 +1428,8 @@ struct llama_server_context
 
             slot.i_batch = batch.n_tokens;
 
-            llama_batch_add(batch, slot.sampled, system_tokens.size() + slot.n_past, { slot.id }, true);
+            const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
+            llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true);
 
             slot.n_past += 1;
         }
@@ -1499,6 +1527,8 @@ struct llama_server_context
                         llama_sampling_reset(slot.ctx_sampling);
 
                         slot.n_past = 0;
+                        slot.n_past_se = 0;
+                        slot.ga_i = 0;
                         slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
                     }
                     else
@@ -1512,6 +1542,25 @@ struct llama_server_context
                         slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
                         slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
 
+                        if (slot.ga_n != 1)
+                        {
+                            int ga_i = 0;
+                            int32_t ga_n = slot.ga_n;
+                            int32_t ga_w = slot.ga_w;
+                            int32_t slot_npast = 0;
+                            for (int k = 0; k < slot.n_past; ++k)
+                            {
+                                while (slot_npast >= ga_i + ga_w) {
+                                    const int bd = (ga_w/ga_n)*(ga_n - 1);
+                                    slot_npast -= bd;
+                                    ga_i += ga_w/ga_n;
+                                }
+                                slot_npast++;
+                            }
+                            slot.n_past_se = slot_npast;
+                            slot.ga_i = ga_i;
+                        }
+
                         LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
                     }
 
@@ -1526,6 +1575,10 @@ struct llama_server_context
                         // we have to evaluate at least 1 token to generate logits.
                         LOG_TEE("slot %d : we have to evaluate at least 1 token to generate logits\n", slot.id);
                         slot.n_past--;
+                        if (slot.ga_i > 0)
+                        {
+                            slot.n_past_se--;
+                        }
                     }
 
                     LOG_VERBOSE("prompt ingested", {
@@ -1538,9 +1591,22 @@ struct llama_server_context
 
                     // process the prefix of first image
                     std::vector<llama_token> prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens;
+                    int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
+                    int ga_i = slot.ga_i;
+                    int32_t ga_n = slot.ga_n;
+                    int32_t ga_w = slot.ga_w;
                     for (; slot.n_past < (int) prefix_tokens.size(); ++slot.n_past)
                     {
-                       llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot.n_past, { slot.id }, false);
+                        if (slot.ga_n != 1)
+                        {
+                            while (slot_npast >= ga_i + ga_w) {
+                                const int bd = (ga_w/ga_n)*(ga_n - 1);
+                                slot_npast -= bd;
+                                ga_i += ga_w/ga_n;
+                            }
+                        }
+                        llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false);
+                        slot_npast += 1;
                     }
 
                     if (has_images && !ingest_images(slot, n_batch))
@@ -1570,6 +1636,36 @@ struct llama_server_context
         for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch)
         {
             const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
+
+            for (auto & slot : slots)
+            {
+                if (slot.ga_n != 1)
+                {
+                    // context extension via Self-Extend
+                    while (slot.n_past_se >= slot.ga_i + slot.ga_w)
+                    {
+                        const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
+                        const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
+                        const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
+
+                        LOG_TEE("\n");
+                        LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
+                        LOG_TEE("div:   [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
+                        LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
+
+                        llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
+                        llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n);
+                        llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_se + ib * bd, dd);
+
+                        slot.n_past_se -= bd;
+
+                        slot.ga_i += slot.ga_w / slot.ga_n;
+
+                        LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
+                    }
+                    slot.n_past_se += n_tokens;
+                }
+            }
             llama_batch batch_view =
             {
                 n_tokens,
@@ -1583,6 +1679,7 @@ struct llama_server_context
             };
 
             const int ret = llama_decode(ctx, batch_view);
+
             if (ret != 0)
             {
                 if (n_batch == 1 || ret < 0)
@@ -1728,6 +1825,8 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     printf("  --override-kv KEY=TYPE:VALUE\n");
     printf("                        advanced option to override model metadata by key. may be specified multiple times.\n");
     printf("                        types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
+    printf("  -gan N, --grp-attn-n N    Set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`");
+    printf("  -gaw N, --grp-attn-w N    Set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`");
     printf("\n");
 }
 
@@ -1913,6 +2012,25 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
             }
             params.n_threads = std::stoi(argv[i]);
         }
+        else if (arg == "--grp-attn-n" || arg == "-gan")
+        {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+
+            params.grp_attn_n = std::stoi(argv[i]);
+        }
+        else if (arg == "--grp-attn-w" || arg == "-gaw")
+        {
+            if (++i >= argc)
+            {
+                invalid_param = true;
+                break;
+            }
+
+            params.grp_attn_w = std::stoi(argv[i]);
+        }
         else if (arg == "--threads-batch" || arg == "-tb")
         {
             if (++i >= argc)