]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
memory : remove KV cache size padding (#16812)
authorGeorgi Gerganov <redacted>
Tue, 28 Oct 2025 18:19:44 +0000 (20:19 +0200)
committerGitHub <redacted>
Tue, 28 Oct 2025 18:19:44 +0000 (20:19 +0200)
* memory : remove KV cache size padding

* cont : restore padding for n_kv tensor shape

* server : use slot context size instead of training context size

* server : simplify context limit logic

src/llama-kv-cache.cpp
src/llama-kv-cache.h
src/llama-model.cpp
src/llama-model.h
tools/server/server.cpp
tools/server/tests/unit/test_ctx_shift.py

index add74391f0c47bb7ed6cd4179526f5ecb956db80..6d5dd6051e782e386b8b3ee057d442f160828e27 100644 (file)
@@ -961,10 +961,14 @@ bool llama_kv_cache::get_has_shift() const {
 uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
     uint32_t result = 0;
 
+    // pad the n_kv value so that the graph remains constant across batches and can be reused
+    // note: this also helps some backends with performance (f.ex https://github.com/ggml-org/llama.cpp/pull/16812#issuecomment-3455112220)
+    const uint32_t n_pad_cur = std::max(n_pad, 256u);
+
     for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
         const auto & cells = v_cells[sinfo.strm[s]];
 
-        result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result);
+        result = std::max(std::min(cells.size(), std::max(n_pad_cur, GGML_PAD(cells.used_max_p1(), n_pad_cur))), result);
     }
 
     return result;
@@ -2014,8 +2018,3 @@ void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ub
 void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
     kv->set_input_pos_bucket(dst, ubatch);
 }
-
-uint32_t llama_kv_cache::get_padding(const llama_cparams & cparams) {
-    // the FA kernels require padding to avoid extra runtime boundary checks
-    return cparams.flash_attn ? 256u : 32u;
-}
index 150e282596255d5e8121bf4b781523d9b59dbf79..bf7821c07ca8f5e6605ad117fcb2bc2178daa50c 100644 (file)
@@ -19,8 +19,6 @@ struct llama_context;
 
 class llama_kv_cache : public llama_memory_i {
 public:
-    static uint32_t get_padding(const llama_cparams & cparams);
-
     struct stream_copy_info {
         bool empty() const {
             assert(ssrc.size() == sdst.size());
index bb83a04e9605531d2dd0ffa36ce942d7d1a60fbf..ea6f59ed482bb24f5a9ccb0d2435164f3b382441 100644 (file)
@@ -19641,7 +19641,7 @@ struct llm_build_apertus : public llm_graph_context {
     }
 };
 
-llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
+llama_memory_i * llama_model::create_memory(const llama_memory_params & params, const llama_cparams & cparams) const {
     llama_memory_i * res;
 
     switch (arch) {
@@ -19692,17 +19692,13 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                         };
                     }
 
-                    const auto padding = llama_kv_cache::get_padding(cparams);
-
-                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
-
                     res = new llama_memory_hybrid(
                         /* model             */ *this,
                         /* attn_type_k       */ params.type_k,
                         /* attn_type_v       */ params.type_v,
                         /* attn_v_trans      */ !cparams.flash_attn,
                         /* attn_kv_size      */ cparams.n_ctx,
-                        /* attn_n_pad        */ padding,
+                        /* attn_n_pad        */ 1,
                         /* attn_n_swa        */ hparams.n_swa,
                         /* attn_swa_type     */ hparams.swa_type,
                         /* recurrent_type_k  */ GGML_TYPE_F32,
@@ -19714,23 +19710,12 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                         /* filter_attn       */ std::move(filter_attn),
                         /* filter_recr       */ std::move(filter_recr));
                 } else {
-                    const auto padding = llama_kv_cache::get_padding(cparams);
-
                     uint32_t n_ctx_per_stream = cparams.n_ctx;
 
                     if (!cparams.kv_unified) {
                         n_ctx_per_stream = (cparams.n_ctx + cparams.n_seq_max - 1)/cparams.n_seq_max;
-                        n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding);
-
-                        cparams.n_ctx = n_ctx_per_stream*cparams.n_seq_max;
-                    } else {
-                        n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding);
-
-                        cparams.n_ctx = n_ctx_per_stream;
                     }
 
-                    LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
-
                     llama_memory_i::layer_reuse_cb reuse = nullptr;
 
                     if (arch == LLM_ARCH_GEMMA3N) {
@@ -19757,7 +19742,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                                 n_ctx_per_stream,
                                 cparams.n_seq_max,
                                 cparams.n_ubatch,
-                                padding,
+                                1,
                                 nullptr,
                                 reuse);
                     } else {
@@ -19772,7 +19757,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                                 cparams.kv_unified,
                                 n_ctx_per_stream,
                                 cparams.n_seq_max,
-                                padding,
+                                1,
                                 hparams.n_swa,
                                 hparams.swa_type,
                                 nullptr,
index 248f854101cd740491e051dce16a3ab1d22e222f..1ab1cf7f8e94d69e926bf6cf8e88d0379db5eb4c 100644 (file)
@@ -500,9 +500,8 @@ struct llama_model {
 
     ggml_tensor * get_rope_factors(const llama_cparams & cparams, int il) const;
 
-    // note: can mutate `cparams`
     // TODO: move this to new llm_arch_model_i interface
-    llama_memory_i * create_memory(const llama_memory_params & params, llama_cparams & cparams) const;
+    llama_memory_i * create_memory(const llama_memory_params & params, const llama_cparams & cparams) const;
 
     // TODO: move this to new llm_arch_model_i interface
     ggml_cgraph * build_graph(const llm_graph_params & params) const;
index 4124bffa40f8565098dd019bef2a2b024cbed96a..cb794ab647eba1ea4997f989959b8ca9a10e180c 100644 (file)
@@ -2866,10 +2866,12 @@ struct server_context {
 
         // if context shifting is disabled, make sure that we don't run out of context
         if (!params_base.ctx_shift && slot.n_past + 1 >= slot.n_ctx) {
+            slot.truncated      = true;
             slot.stop           = STOP_TYPE_LIMIT;
             slot.has_next_token = false;
 
-            SLT_DBG(slot, "stopped due to running out of context, n_past = %d, n_ctx = %d\n", slot.n_past, slot.n_ctx);
+            SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n",
+                    slot.n_decoded, slot.n_prompt_tokens(), slot.n_past, slot.n_ctx);
         }
 
         // check the limits
@@ -2929,16 +2931,6 @@ struct server_context {
             }
         }
 
-        // if context shift is disabled, we stop when it reaches the context limit
-        if (slot.n_past >= slot.n_ctx) {
-            slot.truncated      = true;
-            slot.stop           = STOP_TYPE_LIMIT;
-            slot.has_next_token = false;
-
-            SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n",
-                    slot.n_decoded, slot.n_prompt_tokens(), slot.n_past, slot.n_ctx);
-        }
-
         if (llama_vocab_is_eog(vocab, result.tok)) {
             slot.stop           = STOP_TYPE_EOS;
             slot.has_next_token = false;
@@ -2946,19 +2938,6 @@ struct server_context {
             SLT_DBG(slot, "%s", "stopped by EOS\n");
         }
 
-        const auto n_ctx_train = llama_model_n_ctx_train(model);
-
-        if (slot.task->params.n_predict < 1 && slot.n_prompt_tokens() + slot.n_decoded >= n_ctx_train) {
-            slot.truncated      = true;
-            slot.stop           = STOP_TYPE_LIMIT;
-            slot.has_next_token = false; // stop prediction
-
-            SLT_WRN(slot,
-                    "n_predict (%d) is set for infinite generation. "
-                    "Limiting generated tokens to n_ctx_train (%d) to avoid EOS-less generation infinite loop\n",
-                    slot.task->params.n_predict, n_ctx_train);
-        }
-
         SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str());
 
         return slot.has_next_token; // continue
index 4adbbde64f5947239a2b11b2ceb19b78681b3f71..7b047b7b3b74d3ed6476f7702cd6d0117f5eed9c 100644 (file)
@@ -45,7 +45,7 @@ def test_ctx_shift_enabled():
 
 @pytest.mark.parametrize("n_predict,n_token_output,truncated", [
     (64, 64, False),
-    (-1, 120, True),
+    (-1, 248, True), # 8 tokens prompt + 248 tokens generated = 256 tokens total
 ])
 def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool):
     global server