]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : fix defrag bugs + add parameter (#5735)
authorGeorgi Gerganov <redacted>
Tue, 27 Feb 2024 12:35:51 +0000 (14:35 +0200)
committerGitHub <redacted>
Tue, 27 Feb 2024 12:35:51 +0000 (14:35 +0200)
* llama : fix defrag bugs + enable by default

ggml-ci

* llama : add defrag_thold parameter

ggml-ci

* llama : cont

* llama : disable log message

ggml-ci

* llama : fix graph size check during defrag

common/common.cpp
common/common.h
examples/passkey/passkey.cpp
llama.cpp
llama.h

index ec596f5a075deff14f7faf48ba61caa3e5da0467..18289755c9cebbf28399819bd08734334d447c81 100644 (file)
@@ -335,6 +335,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.yarn_beta_slow = std::stof(argv[i]);
+        } else if (arg == "--defrag-thold" || arg == "-dt") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.defrag_thold = std::stof(argv[i]);
         } else if (arg == "--samplers") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -1004,6 +1010,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --yarn-attn-factor N  YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
     printf("  --yarn-beta-slow N    YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
     printf("  --yarn-beta-fast N    YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
+    printf("  -dt N, --defrag-thold N\n");
+    printf("                        KV cache defragmentation threshold (default: %.1f, < 0 - disabled)\n", params.defrag_thold);
     printf("  --ignore-eos          ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
     printf("  --no-penalize-nl      do not penalize newline token\n");
     printf("  --temp N              temperature (default: %.1f)\n", (double)sparams.temp);
@@ -1285,6 +1293,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.yarn_beta_fast    = params.yarn_beta_fast;
     cparams.yarn_beta_slow    = params.yarn_beta_slow;
     cparams.yarn_orig_ctx     = params.yarn_orig_ctx;
+    cparams.defrag_thold      = params.defrag_thold;
     cparams.offload_kqv       = !params.no_kv_offload;
 
     cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
index 3e21579b005459e9717af264423b2222ae987612..25003df2600d18610c854595aca457eec2ecce72 100644 (file)
@@ -75,6 +75,7 @@ struct gpt_params {
     float   yarn_beta_fast        = 32.0f; // YaRN low correction dim
     float   yarn_beta_slow        = 1.0f;  // YaRN high correction dim
     int32_t yarn_orig_ctx         = 0;     // YaRN original context length
+    float   defrag_thold          = -1.0f; // KV cache defragmentation threshold
     int32_t rope_scaling_type     = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
     ggml_numa_strategy numa       = GGML_NUMA_STRATEGY_DISABLED;
 
index 47de67a93047f6778778d56b845e4527ab70f94a..2cbc9e1fa89ed1688d8e3e199e957282373bc0cb 100644 (file)
@@ -182,7 +182,7 @@ int main(int argc, char ** argv) {
 
         llama_kv_cache_seq_rm (ctx, 0, n_keep            , n_keep + n_discard);
         llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx,  -n_discard);
-        llama_kv_cache_defrag (ctx);
+      //llama_kv_cache_defrag (ctx);
         llama_kv_cache_update (ctx);
 
         n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
@@ -213,7 +213,7 @@ int main(int argc, char ** argv) {
 
             llama_kv_cache_seq_rm (ctx, 0, n_keep            , n_keep + n_discard);
             llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx,  -n_discard);
-            llama_kv_cache_defrag (ctx);
+          //llama_kv_cache_defrag (ctx);
             llama_kv_cache_update (ctx);
 
             n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
index 80dc4d166383e1ca649bc8f88f826c5b74849270..6729bb99c91fd0df8085334cadb77ae044c59ac0 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -1641,6 +1641,7 @@ struct llama_cparams {
     float yarn_attn_factor;
     float yarn_beta_fast;
     float yarn_beta_slow;
+    float defrag_thold;
 
     bool mul_mat_q;
     bool offload_kqv;
@@ -5117,16 +5118,16 @@ struct llm_build_context {
     struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
-        for (int i = 0; i < n_kv; ++i) {
-            const int id = ids[i];
+        for (uint32_t i = 0; i < ids.size(); ++i) {
+            const uint32_t id = ids[i];
 
-            if (i == id || id == n_kv) {
+            if (i == id || id == ids.size()) {
                 continue;
             }
 
-            int nm = 1;
+            uint32_t nm = 1;
 
-            while (i + nm < n_kv && (int) ids[i + nm] == id + nm) {
+            while (i + nm < ids.size() && ids[i + nm] == id + nm) {
                 nm++;
             }
 
@@ -5158,6 +5159,8 @@ struct llm_build_context {
             i += nm - 1;
         }
 
+        //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
+
         return gf;
     }
 
@@ -7938,6 +7941,8 @@ static int llama_decode_internal(
         batch.seq_id = seq_id_arr.data();
     }
 
+    llama_kv_cache_update(&lctx);
+
     // if we have enough unused cells before the current head ->
     //   better to start searching from the beginning of the cache, hoping to fill it
     if (kv_self.head > kv_self.used + 2*n_tokens) {
@@ -7956,8 +7961,6 @@ static int llama_decode_internal(
 
     //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
 
-    llama_kv_cache_update(&lctx);
-
     ggml_backend_sched_reset(lctx.sched);
     ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
 
@@ -8007,6 +8010,18 @@ static int llama_decode_internal(
         }
     }
 
+    // decide if we need to defrag the kv cache
+    if (cparams.defrag_thold >= 0.0f) {
+        const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used + n_tokens)/float(kv_self.n) : 0.0f;
+
+        // queue defragmentation for next llama_kv_cache_update
+        if (fragmentation > cparams.defrag_thold) {
+            //LLAMA_LOG_INFO("fragmentation: %.2f\n", fragmentation);
+
+            llama_kv_cache_defrag(kv_self);
+        }
+    }
+
 #ifdef GGML_PERF
     // print timing information per ggml operation (for debugging purposes)
     // requires GGML_PERF to be defined
@@ -8098,12 +8113,16 @@ static int llama_decode_internal(
 static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
     auto & kv_self = lctx.kv_self;
 
+    const auto & hparams = lctx.model.hparams;
+
+    const uint32_t n_layer = hparams.n_layer;
+
     const uint32_t n_kv   = llama_kv_cache_cell_max(kv_self);
     const uint32_t n_used = kv_self.used;
 
     assert(n_used <= n_kv);
 
-    const int64_t t_start = ggml_time_us();
+    //const int64_t t_start = ggml_time_us();
 
     // number of cells moved
     uint32_t n_moves = 0;
@@ -8127,15 +8146,26 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
 
         // found a hole - fill it with data from the end of the cache
 
-        // determine the size of the hole
         uint32_t nh = 1;
+
+        // determine the size of the hole
         while (i0 + nh < n_used && kv_self.cells[i0 + nh].is_empty()) {
             nh++;
         }
 
-        // starting from the end, find nh non-empty cells
+        // each move requires 6*n_layer tensors (see build_defrag)
+        //   - source view, destination view, copy operation
+        //   - x2 for keys and values
+        //
+        if (6*(n_moves + nh)*n_layer >= LLAMA_MAX_NODES) {
+            // the graph is too big, we cannot move more cells
+            break;
+        }
+
         uint32_t nf = 0;
         uint32_t is = n_kv - 1;
+
+        // starting from the end, find nh non-empty cells
         for (; is > i0; --is) {
             const auto & cell1 = kv_self.cells[is];
 
@@ -8156,11 +8186,17 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
 
         nf = 0;
 
+        uint32_t i1 = is;
+
+        // are we moving a continuous block of memory?
+        bool cont = false;
+
         // go back and move the nf cells to the hole
-        for (uint32_t i1 = is; i1 < n_kv; ++i1) {
-            const auto & cell1 = kv_self.cells[i1];
+        for (; i1 < n_kv; ++i1) {
+            auto & cell1 = kv_self.cells[i1];
 
             if (cell1.is_empty() || ids[i1] != n_kv) {
+                cont = false;
                 continue;
             }
 
@@ -8170,11 +8206,23 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
             // move the cell meta data
             kv_self.cells[i0 + nf] = cell1;
 
-            n_moves++;
+            // clear the old cell and move the head there
+            cell1 = llama_kv_cell();
+            kv_self.head = n_used;
+
+            if (!cont) {
+                n_moves++;
+                cont = true;
+            }
+
             nf++;
+
+            if (nf == nh) {
+                break;
+            }
         }
 
-        LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, n_kv, i0, i0 + nh);
+        //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
 
         i0 += nh - 1;
     }
@@ -8183,15 +8231,9 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
         return;
     }
 
-    LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves);
-
-    kv_self.head = n_used;
-    kv_self.used = n_used;
+    //LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves);
 
-    // zero the rest of the cells
-    for (uint32_t i = n_used; i < n_kv; ++i) {
-        kv_self.cells[i] = llama_kv_cell();
-    }
+    //LLAMA_LOG_INFO("expected gf nodes: %u\n", 6*n_moves*n_layer);
 
 #if 0
     // CPU defrag
@@ -8203,9 +8245,6 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
     // likely not worth the effort, as we have ggml_graph based defrag
     //
 
-    const auto & hparams = lctx.model.hparams;
-
-    const uint32_t n_layer      = hparams.n_layer;
     const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
     const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
 
@@ -8274,9 +8313,9 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
     llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
 #endif
 
-    const int64_t t_end = ggml_time_us();
+    //const int64_t t_end = ggml_time_us();
 
-    LLAMA_LOG_INFO("(tmp log) KV defrag time: %.3f ms\n", (t_end - t_start)/1000.0);
+    //LLAMA_LOG_INFO("(tmp log) KV defrag time: %.3f ms\n", (t_end - t_start)/1000.0);
 }
 
 static void llama_kv_cache_update_internal(struct llama_context & lctx) {
@@ -11670,6 +11709,7 @@ struct llama_context_params llama_context_default_params() {
         /*.yarn_beta_fast              =*/ 32.0f,
         /*.yarn_beta_slow              =*/ 1.0f,
         /*.yarn_orig_ctx               =*/ 0,
+        /*.defrag_thold                =*/ -1.0f,
         /*.cb_eval                     =*/ nullptr,
         /*.cb_eval_user_data           =*/ nullptr,
         /*.type_k                      =*/ GGML_TYPE_F16,
@@ -11834,6 +11874,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.yarn_attn_factor = params.yarn_attn_factor;
     cparams.yarn_beta_fast   = params.yarn_beta_fast;
     cparams.yarn_beta_slow   = params.yarn_beta_slow;
+    cparams.defrag_thold     = params.defrag_thold;
     cparams.mul_mat_q        = params.mul_mat_q;
     cparams.offload_kqv      = params.offload_kqv;
     cparams.do_pooling       = params.do_pooling;
@@ -12035,7 +12076,7 @@ struct llama_context * llama_new_context_with_model(
             }
 
             // buffer used to store the computation graph and the tensor meta data
-            ctx->buf_compute_meta.resize(ggml_tensor_overhead()*LLAMA_MAX_NODES + ggml_graph_overhead());
+            ctx->buf_compute_meta.resize(ggml_tensor_overhead()*LLAMA_MAX_NODES + ggml_graph_overhead_custom(LLAMA_MAX_NODES, false));
 
             ctx->sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(), ctx->backends.size(), LLAMA_MAX_NODES);
 
diff --git a/llama.h b/llama.h
index 3ff77d5a8997d1f27ff753b45689104750938760..6041618080344944f424d5d817aef06bb76dbb48 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -245,6 +245,7 @@ extern "C" {
         float    yarn_beta_fast;   // YaRN low correction dim
         float    yarn_beta_slow;   // YaRN high correction dim
         uint32_t yarn_orig_ctx;    // YaRN original context size
+        float    defrag_thold;     // defragment the KV cache if holes/size > thold, < 0 disabled (default)
 
         ggml_backend_sched_eval_callback cb_eval;
         void * cb_eval_user_data;