]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
kv-cache : pad the cache size to 256 for performance (#17046)
authorGeorgi Gerganov <redacted>
Fri, 7 Nov 2025 18:03:25 +0000 (20:03 +0200)
committerGitHub <redacted>
Fri, 7 Nov 2025 18:03:25 +0000 (20:03 +0200)
* kv-cache : pad the size of the small SWA cache for performance

* context : pad the total context to 256

* cont : future-proof the swa pad

* server : adjust test params to new logic

include/llama.h
src/llama-context.cpp
src/llama-kv-cache-iswa.cpp
tools/server/tests/unit/test_speculative.py

index 98bed9d6150a069cad17af9a84b824fd59f8cfe1..aa9932afb844ba76593bbfc6e17494a258955ef7 100644 (file)
@@ -463,6 +463,7 @@ extern "C" {
 
     // NOTE: After creating a llama_context, it is recommended to query the actual values using these functions
     //       In some cases the requested values via llama_context_params may differ from the actual values used by the context
+    //       ref: https://github.com/ggml-org/llama.cpp/pull/17046#discussion_r2503085732
     LLAMA_API uint32_t llama_n_ctx      (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_ctx_seq  (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_batch    (const struct llama_context * ctx);
index 866514038e493cdfaaa58ecda67fa46029cc4d5e..e115fcd933f5315e5ff5372f8f816767a838f986 100644 (file)
@@ -114,10 +114,14 @@ llama_context::llama_context(
         }
     }
 
+    // ref: https://github.com/ggml-org/llama.cpp/pull/17046#discussion_r2503085732
+    cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256);
+
     if (cparams.kv_unified) {
         cparams.n_ctx_seq = cparams.n_ctx;
     } else {
         cparams.n_ctx_seq = cparams.n_ctx / cparams.n_seq_max;
+        cparams.n_ctx_seq = GGML_PAD(cparams.n_ctx_seq, 256);
 
         if (cparams.n_ctx_seq == 0) {
             throw std::runtime_error("n_ctx_seq == 0");
index facba1d004012b503ebe22f35f022a4d5dfe1a8f..3a34102a23d08b724a2bdcc26d6ac07dd7f246eb 100644 (file)
@@ -45,7 +45,9 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
 
     const uint32_t size_base = kv_size;
 
-    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad));
+    // note: the SWA cache is always padded to 256 for performance
+    //       https://github.com/ggml-org/llama.cpp/issues/17037
+    uint32_t size_swa = GGML_PAD(std::min(size_base, hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch), 256);
 
     // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
     if (swa_full) {
index 65952de8b8d4c51a8ba2539996a130f53d8811c3..d2f3fba5fe7a921baad8815406d8635cc13caf35 100644 (file)
@@ -77,10 +77,10 @@ def test_different_draft_min_draft_max():
 
 def test_slot_ctx_not_exceeded():
     global server
-    server.n_ctx = 64
+    server.n_ctx = 256
     server.start()
     res = server.make_request("POST", "/completion", data={
-        "prompt": "Hello " * 56,
+        "prompt": "Hello " * 248,
         "temperature": 0.0,
         "top_k": 1,
         "speculative.p_min": 0.0,
@@ -91,19 +91,19 @@ def test_slot_ctx_not_exceeded():
 
 def test_with_ctx_shift():
     global server
-    server.n_ctx = 64
+    server.n_ctx = 256
     server.enable_ctx_shift = True
     server.start()
     res = server.make_request("POST", "/completion", data={
-        "prompt": "Hello " * 56,
+        "prompt": "Hello " * 248,
         "temperature": 0.0,
         "top_k": 1,
-        "n_predict": 64,
+        "n_predict": 256,
         "speculative.p_min": 0.0,
     })
     assert res.status_code == 200
     assert len(res.body["content"]) > 0
-    assert res.body["tokens_predicted"] == 64
+    assert res.body["tokens_predicted"] == 256
     assert res.body["truncated"] == True