]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : remove KV cache defragmentation logic (#15473)
authorGeorgi Gerganov <redacted>
Fri, 22 Aug 2025 09:22:13 +0000 (12:22 +0300)
committerGitHub <redacted>
Fri, 22 Aug 2025 09:22:13 +0000 (12:22 +0300)
ggml-ci

16 files changed:
common/arg.cpp
common/common.cpp
common/common.h
examples/llama.vim
include/llama.h
scripts/compare-llama-bench.py
src/llama-context.cpp
src/llama-cparams.h
src/llama-kv-cache.cpp
src/llama-kv-cache.h
src/llama-kv-cells.h
src/llama-memory.h
tools/llama-bench/README.md
tools/llama-bench/llama-bench.cpp
tools/server/README.md
tools/server/bench/bench.py

index 1227aeb2a391539908c721b596e6fa6b29f42f79..81c4005c5e7fc0a8adce4f8caaade19c73cc8ac5 100644 (file)
@@ -2254,9 +2254,11 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
     add_opt(common_arg(
         {"-dt", "--defrag-thold"}, "N",
-        string_format("KV cache defragmentation threshold (default: %.1f, < 0 - disabled)", (double)params.defrag_thold),
+        string_format("KV cache defragmentation threshold (DEPRECATED)"),
         [](common_params & params, const std::string & value) {
-            params.defrag_thold = std::stof(value);
+            GGML_UNUSED(params);
+            GGML_UNUSED(value);
+            LOG_WRN("DEPRECATED: --defrag-thold is deprecated and no longer necessary to specify\n");
         }
     ).set_env("LLAMA_ARG_DEFRAG_THOLD"));
     add_opt(common_arg(
index decabcc2ed327ed8c9b07dea0b61a28e06ef4406..fdce1dcdec19b613efe0c73e8fcc26eb1d4c968e 100644 (file)
@@ -1152,7 +1152,6 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.yarn_orig_ctx     = params.yarn_orig_ctx;
     cparams.pooling_type      = params.pooling_type;
     cparams.attention_type    = params.attention_type;
-    cparams.defrag_thold      = params.defrag_thold;
     cparams.cb_eval           = params.cb_eval;
     cparams.cb_eval_user_data = params.cb_eval_user_data;
     cparams.offload_kqv       = !params.no_kv_offload;
index 614e41a2461e773734d63d53715096fe17b5f21e..390dda5e531bedb275504d72fa3751b292e9c303 100644 (file)
@@ -288,7 +288,6 @@ struct common_params {
     float   yarn_beta_fast        = 32.0f; // YaRN low correction dim
     float   yarn_beta_slow        =  1.0f; // YaRN high correction dim
     int32_t yarn_orig_ctx         =     0; // YaRN original context length
-    float   defrag_thold          =  0.1f; // KV cache defragmentation threshold
 
     // offload params
     std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
index af3fd3935d765d4cac760dbcbe361bf1f2abcfaa..736802d365541583cb806d27e0a6bdcc9809104c 100644 (file)
@@ -17,7 +17,7 @@
 "
 " start the llama.cpp server with a FIM-compatible model. for example:
 "
-"   $ llama-server -m {model.gguf} --port 8012 -ngl 99 -fa -dt 0.1 --ubatch-size 512 --batch-size 1024 --cache-reuse 256
+"   $ llama-server -m {model.gguf} --port 8012 -ngl 99 -fa --ubatch-size 512 --batch-size 1024 --cache-reuse 256
 "
 "   --batch-size [512, model max context]
 "
index 662e0971dff2fa71c44bcb75e9a4ef5d49a1fb44..c5622cc16b4c2181bc1dd2aad07bfa409add0681 100644 (file)
@@ -312,7 +312,7 @@ extern "C" {
         float    yarn_beta_fast;   // YaRN low correction dim
         float    yarn_beta_slow;   // YaRN high correction dim
         uint32_t yarn_orig_ctx;    // YaRN original context size
-        float    defrag_thold;     // defragment the KV cache if holes/size > thold, <= 0 disabled (default)
+        float    defrag_thold;     // [DEPRECATED] defragment the KV cache if holes/size > thold, <= 0 disabled (default)
 
         ggml_backend_sched_eval_callback cb_eval;
         void * cb_eval_user_data;
index 8366f89a08076683f8a518c7a660d19a399cb5fa..0141e0a350dc9fffcd37ea0f80300126e2ed6927 100755 (executable)
@@ -28,7 +28,6 @@ LLAMA_BENCH_DB_FIELDS = [
     "model_type",   "model_size",   "model_n_params", "n_batch",    "n_ubatch",     "n_threads",
     "cpu_mask",     "cpu_strict",   "poll",           "type_k",     "type_v",       "n_gpu_layers",
     "split_mode",   "main_gpu",     "no_kv_offload",  "flash_attn", "tensor_split", "tensor_buft_overrides",
-    "defrag_thold",
     "use_mmap",     "embeddings",   "no_op_offload",  "n_prompt",   "n_gen",        "n_depth",
     "test_time",    "avg_ns",       "stddev_ns",      "avg_ts",     "stddev_ts",
 ]
index e8e8b3450a5d2b5bd59c0ab232e260031758c367..18cf25079d283e5ba9da357acd7b009205d99328 100644 (file)
@@ -39,7 +39,6 @@ llama_context::llama_context(
     cparams.yarn_attn_factor = params.yarn_attn_factor;
     cparams.yarn_beta_fast   = params.yarn_beta_fast;
     cparams.yarn_beta_slow   = params.yarn_beta_slow;
-    cparams.defrag_thold     = params.defrag_thold;
     cparams.embeddings       = params.embeddings;
     cparams.offload_kqv      = params.offload_kqv;
     cparams.flash_attn       = params.flash_attn;
@@ -978,7 +977,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
 
     bool did_optimize = false;
 
-    // handle any pending defrags/shifts
+    // handle any pending shifts/copies
     memory_update(false);
 
     llama_memory_context_ptr mctx;
index 38750affc500b74504f53bf65320332ddcef0d09..dbbaba9f6274cb0a5a96113e1f42396344c897e4 100644 (file)
@@ -24,7 +24,6 @@ struct llama_cparams {
     float yarn_attn_factor;
     float yarn_beta_fast;
     float yarn_beta_slow;
-    float defrag_thold;
 
     bool embeddings;
     bool causal_attn;
index bb490cf9e82a228b6b64b94213014e4a033b90ed..70ddd5f4b952cda367ab6b18a9abee8863533556 100644 (file)
@@ -525,39 +525,11 @@ llama_memory_context_ptr llama_kv_cache::init_full() {
 }
 
 llama_memory_context_ptr llama_kv_cache::init_update(llama_context * lctx, bool optimize) {
-    bool do_shift = get_has_shift();
-
-    defrag_info dinfo;
-
-    // see if we need to defrag
-    if (n_stream == 1) {
-        // note : for now do not consider defrag for n_stream > 1
-        const auto & cells = v_cells[seq_to_stream[0]];
-
-        bool do_defrag = optimize;
-
-        const auto thold = lctx->get_cparams().defrag_thold;
-
-        if (!do_defrag && thold > 0.0f) {
-            const auto n_kv = cells.used_max_p1();
-
-            // - do not defrag small contexts (i.e. < 2048 tokens)
-            // - count the padding towards the number of used tokens
-            const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
-
-            if (fragmentation > thold) {
-                LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
-
-                do_defrag = true;
-            }
-        }
+    GGML_UNUSED(optimize);
 
-        if (do_defrag) {
-            dinfo = defrag_prepare(lctx->graph_max_nodes());
-        }
-    }
+    bool do_shift = get_has_shift();
 
-    return std::make_unique<llama_kv_cache_context>(this, lctx, do_shift, std::move(dinfo), std::move(sc_info));
+    return std::make_unique<llama_kv_cache_context>(this, lctx, do_shift, std::move(sc_info));
 }
 
 llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector<llama_ubatch> & ubatches) {
@@ -629,7 +601,7 @@ llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector<llama_
     return res;
 }
 
-bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info) {
+bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info) {
     bool updated = false;
 
     auto * sched = lctx->get_sched();
@@ -699,53 +671,6 @@ bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const defrag_in
         }
     }
 
-    if (!dinfo.empty()) {
-        LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
-
-        // note: for now do not consider defrag for n_stream > 1
-        auto & cells = v_cells[seq_to_stream[0]];
-        auto & head  = v_heads[seq_to_stream[0]];
-
-        // apply moves:
-        {
-            const auto n_kv = dinfo.ids.size();
-
-            for (uint32_t i = 0; i < n_kv; ++i) {
-                assert(dinfo.ids[i] <= n_kv);
-
-                if (dinfo.ids[i] == n_kv || dinfo.ids[i] == i) {
-                    continue;
-                }
-
-                cells.mv(i, dinfo.ids[i]);
-            }
-
-            // reset the head so we can find the first free slot during the next ubatch
-            head = 0;
-        }
-
-        ggml_backend_sched_reset(sched);
-
-        auto * res = lctx->get_gf_res_reserve();
-
-        res->reset();
-
-        auto * gf = build_graph_defrag(res, lctx, dinfo);
-        if (!ggml_backend_sched_alloc_graph(sched, gf)) {
-            LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
-            return updated;
-        }
-
-        res->set_inputs(nullptr);
-
-        if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
-            LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
-            return updated;
-        }
-
-        updated = true;
-    }
-
     return updated;
 }
 
@@ -1525,283 +1450,6 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
     return gf;
 }
 
-ggml_cgraph * llama_kv_cache::build_graph_defrag(
-         llm_graph_result * res,
-            llama_context * lctx,
-        const defrag_info & dinfo) const {
-    auto * ctx = res->get_ctx();
-    auto * gf  = res->get_gf();
-
-    GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag");
-
-    const auto & cells = v_cells[0];
-
-    const auto & ids = dinfo.ids;
-
-    const auto & cparams = lctx->get_cparams();
-
-#if 0
-    // CPU defrag
-    //
-    // TODO: optimizations are possible:
-    //       - multiple threads
-    //       - avoid copying to the host memory when already there
-    //
-    // likely not worth the effort, as we have ggml_graph based defrag
-    //
-
-    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
-    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
-
-    const uint32_t kv_size = size;
-
-    std::vector<uint8_t> buf_k;
-    std::vector<uint8_t> buf_v;
-
-    for (uint32_t il = 0; il < n_layer; ++il) {
-        const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
-        const size_t k_size     = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size);
-
-        const size_t v_size_el = ggml_type_size(v_l[il]->type);
-        const size_t v_size    = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size);
-
-        buf_k.resize(k_size);
-        buf_v.resize(v_size);
-
-        ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size());
-        ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size());
-
-        // batch move [i, i+nm) to [id, id+nm)
-        // note: cells can move only to a lower index
-        for (uint32_t i = 0; i < n_kv; ++i) {
-            const uint32_t id = ids[i];
-
-            if (i == id || id == n_kv) {
-                continue;
-            }
-
-            uint32_t nm = 1;
-
-            while (i + nm < n_kv && ids[i + nm] == id + nm) {
-                nm++;
-            }
-
-            // move keys
-            {
-                const int64_t os =  i*k_size_row;
-                const int64_t od = id*k_size_row;
-
-                memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
-            }
-
-            // move values (note: they are transposed)
-            {
-                const int64_t os =  i;
-                const int64_t od = id;
-
-                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                    memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
-                }
-            }
-
-            i += nm - 1;
-        }
-
-        ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size());
-        ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
-    }
-#else
-    for (uint32_t i = 0; i < ids.size(); ++i) {
-        const uint32_t id = ids[i];
-
-        if (i == id || id == ids.size()) {
-            continue;
-        }
-
-        uint32_t nm = 1;
-
-        while (i + nm < ids.size() && ids[i + nm] == id + nm) {
-            nm++;
-        }
-
-        for (const auto & layer : layers) {
-            const uint32_t il = layer.il;
-
-            const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-            const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-
-            ggml_tensor * view_k_src = ggml_view_2d(ctx, layer.k,
-                    n_embd_k_gqa, nm,
-                    ggml_row_size(layer.k->type, n_embd_k_gqa),
-                    ggml_row_size(layer.k->type, n_embd_k_gqa*i));
-
-            ggml_tensor * view_k_dst = ggml_view_2d(ctx, layer.k,
-                    n_embd_k_gqa, nm,
-                    ggml_row_size(layer.k->type, n_embd_k_gqa),
-                    ggml_row_size(layer.k->type, n_embd_k_gqa*id));
-
-            ggml_tensor * view_v_src;
-            ggml_tensor * view_v_dst;
-
-            if (cparams.flash_attn) {
-                // NOTE: the V cache is not transposed when using flash attention
-                view_v_src = ggml_view_2d(ctx, layer.v,
-                        n_embd_v_gqa, nm,
-                        ggml_row_size(layer.v->type, n_embd_v_gqa),
-                        ggml_row_size(layer.v->type, n_embd_v_gqa*i));
-
-                view_v_dst = ggml_view_2d(ctx, layer.v,
-                        n_embd_v_gqa, nm,
-                        ggml_row_size(layer.v->type, n_embd_v_gqa),
-                        ggml_row_size(layer.v->type, n_embd_v_gqa*id));
-            } else {
-                view_v_src = ggml_view_2d(ctx, layer.v,
-                        nm, n_embd_v_gqa,
-                        ggml_row_size(layer.v->type, cells.size()),
-                        ggml_row_size(layer.v->type, i));
-
-                view_v_dst = ggml_view_2d(ctx, layer.v,
-                        nm, n_embd_v_gqa,
-                        ggml_row_size(layer.v->type, cells.size()),
-                        ggml_row_size(layer.v->type, id));
-            }
-
-            ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst));
-            ggml_build_forward_expand(gf, ggml_cpy(ctx, view_v_src, view_v_dst));
-        }
-
-        i += nm - 1;
-    }
-
-    //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
-#endif
-
-    return gf;
-}
-
-llama_kv_cache::defrag_info llama_kv_cache::defrag_prepare(int32_t n_max_nodes) const {
-    GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag");
-
-    const auto & cells = v_cells[0];
-
-    const uint32_t n_layer = layers.size();
-
-    const uint32_t n_kv   = cells.used_max_p1();
-    const uint32_t n_used = cells.get_used();
-
-    assert(n_used <= n_kv);
-
-    //const int64_t t_start = ggml_time_us();
-
-    // number of cells moved
-    uint32_t n_moves = 0;
-
-    // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag)
-    //   - source view, destination view, copy operation
-    //   - x2 for keys and values
-    //const uint32_t max_moves = max_nodes()/(6*n_layer);
-    // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
-    const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
-
-    // determine which KV cells to move where
-    defrag_info res;
-    auto & ids = res.ids;
-
-    ids.resize(n_kv, n_kv);
-
-    for (uint32_t i0 = 0; i0 < n_used; ++i0) {
-        if (!cells.is_empty(i0)) {
-            ids[i0] = i0;
-
-            continue;
-        }
-
-        // found a hole - fill it with data from the end of the cache
-
-        uint32_t nh = 1;
-
-        // determine the size of the hole
-        while (i0 + nh < n_used && cells.is_empty(i0 + nh)) {
-            nh++;
-        }
-
-        uint32_t nf = 0;
-        uint32_t is = n_kv - 1;
-
-        // starting from the end, find nh non-empty cells
-        for (; is > i0; --is) {
-            if (cells.is_empty(is) || ids[is] != n_kv) {
-                continue;
-            }
-
-            // non-empty cell which is not yet moved
-            nf++;
-
-            if (nf == nh) {
-                break;
-            }
-        }
-
-        // this can only happen if `n_used` is not accurate, which would be a bug
-        GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh");
-
-        nf = 0;
-
-        uint32_t i1 = is;
-
-        // are we moving a continuous block of memory?
-        bool cont = false;
-
-        // should we stop searching for the next move?
-        bool stop = false;
-
-        // go back and move the nf cells to the hole
-        for (; i1 < n_kv; ++i1) {
-            if (cells.is_empty(i1) || ids[i1] != n_kv) {
-                if (n_moves == max_moves) {
-                    stop = true;
-                    break;
-                }
-
-                cont = false;
-                continue;
-            }
-
-            // this cell goes to (i0 + nf)
-            ids[i1] = i0 + nf;
-
-            if (!cont) {
-                n_moves++;
-                cont = true;
-            }
-
-            nf++;
-
-            if (nf == nh) {
-                break;
-            }
-        }
-
-        if (stop || n_moves == max_moves) {
-            break;
-        }
-
-        //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
-
-        i0 += nh - 1;
-    }
-
-    if (n_moves == 0) {
-        return {};
-    }
-
-    LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
-
-    LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);
-
-    return res;
-}
-
 bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const {
     assert(p0 >= 0 && p1 >= 0);
 
@@ -2300,9 +1948,8 @@ llama_kv_cache_context::llama_kv_cache_context(
         llama_kv_cache * kv,
         llama_context * lctx,
         bool do_shift,
-        defrag_info dinfo,
-        stream_copy_info sc_info) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)), sc_info(std::move(sc_info)) {
-    if (!do_shift && this->dinfo.empty() && this->sc_info.empty()) {
+        stream_copy_info sc_info) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), sc_info(std::move(sc_info)) {
+    if (!do_shift && this->sc_info.empty()) {
         status = LLAMA_MEMORY_STATUS_NO_UPDATE;
     }
 }
@@ -2330,7 +1977,7 @@ bool llama_kv_cache_context::apply() {
 
     // no ubatches -> this is a KV cache update
     if (ubatches.empty()) {
-        kv->update(lctx, do_shift, dinfo, sc_info);
+        kv->update(lctx, do_shift, sc_info);
 
         return true;
     }
index 5ca618e1b82e1efd01e72cfd4a0b9f79452450d3..297a0973dd467a641d8a5f945a919018bdb85e15 100644 (file)
@@ -24,17 +24,6 @@ public:
     // this callback is used to filter out layers that should not be included in the cache
     using layer_filter_cb = std::function<bool(int32_t il)>;
 
-    struct defrag_info {
-        bool empty() const {
-            return ids.empty();
-        }
-
-        // contains information about which cell moves where:
-        //  - cell i moves to ids[i]
-        //  - if ids[i] == i || ids[i] == ids.size(), then cell i is not moved
-        std::vector<uint32_t> ids;
-    };
-
     struct stream_copy_info {
         bool empty() const {
             assert(ssrc.size() == sdst.size());
@@ -173,7 +162,7 @@ public:
     // return empty vector on failure
     slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);
 
-    bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info);
+    bool update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info);
 
     // find a slot of kv cells that can hold the ubatch
     // if cont == true, then the slot must be continuous
@@ -254,9 +243,6 @@ private:
     // model layer id -> KV cache layer id
     std::unordered_map<int32_t, int32_t> map_layer_ids;
 
-    // return non-empty vector if cells have been moved
-    defrag_info defrag_prepare(int32_t n_max_nodes) const;
-
     size_t total_size() const;
 
     size_t size_k_bytes() const;
@@ -277,11 +263,6 @@ private:
                llm_graph_result * res,
                   llama_context * lctx) const;
 
-    ggml_cgraph * build_graph_defrag(
-               llm_graph_result * res,
-                  llama_context * lctx,
-              const defrag_info & dinfo) const;
-
     struct cell_ranges_t {
         uint32_t strm;
 
@@ -299,7 +280,6 @@ class llama_kv_cache_context : public llama_memory_context_i {
 public:
     // some shorthands
     using slot_info_vec_t  = llama_kv_cache::slot_info_vec_t;
-    using defrag_info      = llama_kv_cache::defrag_info;
     using stream_copy_info = llama_kv_cache::stream_copy_info;
 
     // used for errors
@@ -314,7 +294,6 @@ public:
             llama_kv_cache * kv,
             llama_context * lctx,
             bool do_shift,
-            defrag_info dinfo,
             stream_copy_info sc_info);
 
     // used to create a batch procesing context from a batch
@@ -374,8 +353,6 @@ private:
 
     bool do_shift = false;
 
-    defrag_info dinfo;
-
     stream_copy_info sc_info;
 
     //
index 2651e30331fd692a38b1caab1ff3dabd837f36a0..8f6bf01456c8fb734230b00170f6d2b84b869d89 100644 (file)
@@ -77,24 +77,24 @@ public:
     }
 
     // move cell isrc to idst (used during defrag)
-    void mv(uint32_t isrc, uint32_t idst) {
-        assert(isrc < pos.size());
-        assert(idst < pos.size());
+    //void mv(uint32_t isrc, uint32_t idst) {
+    //    assert(isrc < pos.size());
+    //    assert(idst < pos.size());
 
-        assert(pos[idst] == -1);
-        assert(pos[isrc] != -1);
+    //    assert(pos[idst] == -1);
+    //    assert(pos[isrc] != -1);
 
-        pos  [idst] = pos  [isrc];
-        shift[idst] = shift[isrc];
-        seq  [idst] = seq  [isrc];
+    //    pos  [idst] = pos  [isrc];
+    //    shift[idst] = shift[isrc];
+    //    seq  [idst] = seq  [isrc];
 
-        pos  [isrc] = -1;
-        shift[isrc] =  0;
-        seq  [isrc].reset();
+    //    pos  [isrc] = -1;
+    //    shift[isrc] =  0;
+    //    seq  [isrc].reset();
 
-        used.erase (isrc);
-        used.insert(idst);
-    }
+    //    used.erase (isrc);
+    //    used.insert(idst);
+    //}
 
     // copy the state of cells [i, i + n) (used for save/restore the state of the cells)
     llama_kv_cells cp(uint32_t i, uint32_t n) const {
index 42a7145c2f38703aa61acb8680c33c667034cdce..94d858bccc2e0393fd50767246ebd5d12e203a7f 100644 (file)
@@ -77,7 +77,7 @@ struct llama_memory_i {
     // simulate full cache, used for allocating worst-case compute buffers
     virtual llama_memory_context_ptr init_full() = 0;
 
-    // prepare for any pending memory updates, such as shifts, defrags, etc.
+    // prepare for any pending memory updates, such as shifts, copies, etc.
     // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
     virtual llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) = 0;
 
index 31a27308743469a00190e6e184b5978dc781a5ab..bf7fd29c8c55f48a45e7ecbc97e3080215218652 100644 (file)
@@ -43,7 +43,6 @@ test parameters:
   -ub, --ubatch-size <n>                    (default: 512)
   -ctk, --cache-type-k <t>                  (default: f16)
   -ctv, --cache-type-v <t>                  (default: f16)
-  -dt, --defrag-thold <f>                   (default: -1)
   -t, --threads <n>                         (default: system dependent)
   -C, --cpu-mask <hex,hex>                  (default: 0x0)
   --cpu-strict <0|1>                        (default: 0)
index 10b48c5568612f971e6aa2487425eb3528206db9..9378706a12a7c9c31fe656137bfbc7a8fa59f7f6 100644 (file)
@@ -245,7 +245,6 @@ struct cmd_params {
     std::vector<int>                 n_ubatch;
     std::vector<ggml_type>           type_k;
     std::vector<ggml_type>           type_v;
-    std::vector<float>               defrag_thold;
     std::vector<int>                 n_threads;
     std::vector<std::string>         cpu_mask;
     std::vector<bool>                cpu_strict;
@@ -282,7 +281,6 @@ static const cmd_params cmd_params_defaults = {
     /* n_ubatch             */ { 512 },
     /* type_k               */ { GGML_TYPE_F16 },
     /* type_v               */ { GGML_TYPE_F16 },
-    /* defrag_thold         */ { -1.0f },
     /* n_threads            */ { cpu_get_num_math() },
     /* cpu_mask             */ { "0x0" },
     /* cpu_strict           */ { false },
@@ -346,8 +344,6 @@ static void print_usage(int /* argc */, char ** argv) {
            join(transform_to_str(cmd_params_defaults.type_k, ggml_type_name), ",").c_str());
     printf("  -ctv, --cache-type-v <t>                  (default: %s)\n",
            join(transform_to_str(cmd_params_defaults.type_v, ggml_type_name), ",").c_str());
-    printf("  -dt, --defrag-thold <f>                   (default: %s)\n",
-           join(cmd_params_defaults.defrag_thold, ",").c_str());
     printf("  -t, --threads <n>                         (default: %s)\n",
            join(cmd_params_defaults.n_threads, ",").c_str());
     printf("  -C, --cpu-mask <hex,hex>                  (default: %s)\n",
@@ -533,13 +529,6 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
                     break;
                 }
                 params.type_v.insert(params.type_v.end(), types.begin(), types.end());
-            } else if (arg == "-dt" || arg == "--defrag-thold") {
-                if (++i >= argc) {
-                    invalid_param = true;
-                    break;
-                }
-                auto p = string_split<float>(argv[i], split_delim);
-                params.defrag_thold.insert(params.defrag_thold.end(), p.begin(), p.end());
             } else if (arg == "-t" || arg == "--threads") {
                 if (++i >= argc) {
                     invalid_param = true;
@@ -849,9 +838,6 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
     if (params.type_v.empty()) {
         params.type_v = cmd_params_defaults.type_v;
     }
-    if (params.defrag_thold.empty()) {
-        params.defrag_thold = cmd_params_defaults.defrag_thold;
-    }
     if (params.n_gpu_layers.empty()) {
         params.n_gpu_layers = cmd_params_defaults.n_gpu_layers;
     }
@@ -910,7 +896,6 @@ struct cmd_params_instance {
     int                n_ubatch;
     ggml_type          type_k;
     ggml_type          type_v;
-    float              defrag_thold;
     int                n_threads;
     std::string        cpu_mask;
     bool               cpu_strict;
@@ -1007,7 +992,6 @@ struct cmd_params_instance {
         cparams.n_ubatch     = n_ubatch;
         cparams.type_k       = type_k;
         cparams.type_v       = type_v;
-        cparams.defrag_thold = defrag_thold;
         cparams.offload_kqv  = !no_kv_offload;
         cparams.flash_attn   = flash_attn;
         cparams.embeddings   = embeddings;
@@ -1037,7 +1021,6 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
     for (const auto & nub : params.n_ubatch)
     for (const auto & tk : params.type_k)
     for (const auto & tv : params.type_v)
-    for (const auto & defrag_thold : params.defrag_thold)
     for (const auto & nkvo : params.no_kv_offload)
     for (const auto & fa : params.flash_attn)
     for (const auto & nt : params.n_threads)
@@ -1058,7 +1041,6 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .n_ubatch     = */ nub,
                 /* .type_k       = */ tk,
                 /* .type_v       = */ tv,
-                /* .defrag_thold = */ defrag_thold,
                 /* .n_threads    = */ nt,
                 /* .cpu_mask     = */ cm,
                 /* .cpu_strict   = */ cs,
@@ -1091,7 +1073,6 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .n_ubatch     = */ nub,
                 /* .type_k       = */ tk,
                 /* .type_v       = */ tv,
-                /* .defrag_thold = */ defrag_thold,
                 /* .n_threads    = */ nt,
                 /* .cpu_mask     = */ cm,
                 /* .cpu_strict   = */ cs,
@@ -1124,7 +1105,6 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .n_ubatch     = */ nub,
                 /* .type_k       = */ tk,
                 /* .type_v       = */ tv,
-                /* .defrag_thold = */ defrag_thold,
                 /* .n_threads    = */ nt,
                 /* .cpu_mask     = */ cm,
                 /* .cpu_strict   = */ cs,
@@ -1166,7 +1146,6 @@ struct test {
     int                      poll;
     ggml_type                type_k;
     ggml_type                type_v;
-    float                    defrag_thold;
     int                      n_gpu_layers;
     llama_split_mode         split_mode;
     int                      main_gpu;
@@ -1201,7 +1180,6 @@ struct test {
         poll           = inst.poll;
         type_k         = inst.type_k;
         type_v         = inst.type_v;
-        defrag_thold   = inst.defrag_thold;
         n_gpu_layers   = inst.n_gpu_layers;
         split_mode     = inst.split_mode;
         main_gpu       = inst.main_gpu;
@@ -1257,7 +1235,6 @@ struct test {
             "model_type",   "model_size",   "model_n_params", "n_batch",    "n_ubatch",     "n_threads",
             "cpu_mask",     "cpu_strict",   "poll",           "type_k",     "type_v",       "n_gpu_layers",
             "split_mode",   "main_gpu",     "no_kv_offload",  "flash_attn", "tensor_split", "tensor_buft_overrides",
-            "defrag_thold",
             "use_mmap",     "embeddings",   "no_op_offload",   "n_prompt",       "n_gen",      "n_depth",      "test_time",
             "avg_ns",       "stddev_ns",    "avg_ts",         "stddev_ts",
         };
@@ -1277,7 +1254,7 @@ struct test {
             field == "use_mmap" || field == "embeddings") {
             return BOOL;
         }
-        if (field == "avg_ts" || field == "stddev_ts" || field == "defrag_thold") {
+        if (field == "avg_ts" || field == "stddev_ts") {
             return FLOAT;
         }
         return STRING;
@@ -1344,7 +1321,6 @@ struct test {
                                             std::to_string(flash_attn),
                                             tensor_split_str,
                                             tensor_buft_overrides_str,
-                                            std::to_string(defrag_thold),
                                             std::to_string(use_mmap),
                                             std::to_string(embeddings),
                                             std::to_string(no_op_offload),
@@ -1611,9 +1587,6 @@ struct markdown_printer : public printer {
         if (params.type_v.size() > 1 || params.type_v != cmd_params_defaults.type_v) {
             fields.emplace_back("type_v");
         }
-        if (params.defrag_thold.size() > 1 || params.defrag_thold != cmd_params_defaults.defrag_thold) {
-            fields.emplace_back("defrag_thold");
-        }
         if (params.main_gpu.size() > 1 || params.main_gpu != cmd_params_defaults.main_gpu) {
             fields.emplace_back("main_gpu");
         }
index 86844225ff309cd9397b395a9ec882e49bf2d6e1..baf3730add67c113e4164c87015f7f0e2743997b 100644 (file)
@@ -66,7 +66,7 @@ The project is under active development, and we are [looking for feedback and co
 | `-nkvo, --no-kv-offload` | disable KV offload<br/>(env: LLAMA_ARG_NO_KV_OFFLOAD) |
 | `-ctk, --cache-type-k TYPE` | KV cache data type for K<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_CACHE_TYPE_K) |
 | `-ctv, --cache-type-v TYPE` | KV cache data type for V<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_CACHE_TYPE_V) |
-| `-dt, --defrag-thold N` | KV cache defragmentation threshold (default: 0.1, < 0 - disabled)<br/>(env: LLAMA_ARG_DEFRAG_THOLD) |
+| `-dt, --defrag-thold N` | KV cache defragmentation threshold (DEPRECATED)<br/>(env: LLAMA_ARG_DEFRAG_THOLD) |
 | `-np, --parallel N` | number of parallel sequences to decode (default: 1)<br/>(env: LLAMA_ARG_N_PARALLEL) |
 | `--mlock` | force system to keep model in RAM rather than swapping or compressing<br/>(env: LLAMA_ARG_MLOCK) |
 | `--no-mmap` | do not memory-map model (slower load but may reduce pageouts if not using mlock)<br/>(env: LLAMA_ARG_NO_MMAP) |
index 5cc6f92ab6c53ee6eae18fe3eee02281fd4095e9..0c57a2df04a60d3d914b6a443dea87ed9a14b0b1 100644 (file)
@@ -274,7 +274,6 @@ def start_server_background(args):
     server_args.extend(['--batch-size', args.batch_size])
     server_args.extend(['--ubatch-size', args.ubatch_size])
     server_args.extend(['--n-predict', args.max_tokens * 2])
-    server_args.extend(['--defrag-thold', "0.1"])
     server_args.append('--cont-batching')
     server_args.append('--metrics')
     server_args.append('--flash-attn')