]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : support unified cache across slots (#16736)
authorGeorgi Gerganov <redacted>
Sun, 2 Nov 2025 16:14:04 +0000 (18:14 +0200)
committerGitHub <redacted>
Sun, 2 Nov 2025 16:14:04 +0000 (18:14 +0200)
* server : support unified context across slots

* cont : fix speculative decoding initialization

* context : fix n_ctx_per_seq computation

* server : purge slots one by one

* tests : add unified cache server tests

* llama : update per-seq context computation

* test-thread-safety : handle tiny training context of the input model

* server : fix server_tokens clear()

* server : use 4 slots + unified KV by default

* llama : add note about context size queries

* cont : update todos [no ci]

* context : do not cap the size of the context

* tests : adjust parameters to be CI friendlier

* context : add warning

12 files changed:
include/llama.h
src/llama-context.cpp
src/llama-context.h
src/llama-cparams.h
src/llama-model.cpp
tests/test-thread-safety.cpp
tools/server/server.cpp
tools/server/tests/unit/test_chat_completion.py
tools/server/tests/unit/test_completion.py
tools/server/tests/unit/test_infill.py
tools/server/tests/utils.py
tools/server/utils.hpp

index 05baa43da08d53a45c18cc00962e3f38b45eb2c3..98bed9d6150a069cad17af9a84b824fd59f8cfe1 100644 (file)
@@ -461,7 +461,10 @@ extern "C" {
     LLAMA_API bool llama_supports_gpu_offload(void);
     LLAMA_API bool llama_supports_rpc        (void);
 
+    // NOTE: After creating a llama_context, it is recommended to query the actual values using these functions
+    //       In some cases the requested values via llama_context_params may differ from the actual values used by the context
     LLAMA_API uint32_t llama_n_ctx      (const struct llama_context * ctx);
+    LLAMA_API uint32_t llama_n_ctx_seq  (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_batch    (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_ubatch   (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_seq_max  (const struct llama_context * ctx);
@@ -585,7 +588,7 @@ extern "C" {
     LLAMA_API int32_t llama_adapter_meta_val_str_by_index(const struct llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size);
 
     // Manually free a LoRA adapter
-    // Note: loaded adapters will be free when the associated model is deleted
+    // NOTE: loaded adapters will be free when the associated model is deleted
     LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter);
 
     // Get the invocation tokens if the current lora is an alora
index f6192a36e0ee5db58fd2e2399a972b8b0a2537fd..2b39366271ff9e95f68c6f726c7dedf92eb9d272 100644 (file)
@@ -112,11 +112,24 @@ llama_context::llama_context(
         }
     }
 
-    const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
+    if (cparams.kv_unified) {
+        cparams.n_ctx_seq = cparams.n_ctx;
+    } else {
+        cparams.n_ctx_seq = cparams.n_ctx / cparams.n_seq_max;
+
+        if (cparams.n_ctx_seq == 0) {
+            throw std::runtime_error("n_ctx_seq == 0");
+        }
+
+        if (cparams.n_ctx != cparams.n_ctx_seq * cparams.n_seq_max) {
+            cparams.n_ctx =  cparams.n_ctx_seq * cparams.n_seq_max;
+            LLAMA_LOG_WARN("%s: n_ctx is not divisible by n_seq_max - rounding down to %u\n", __func__, cparams.n_ctx);
+        }
+    }
 
     LLAMA_LOG_INFO("%s: n_seq_max     = %u\n",   __func__, cparams.n_seq_max);
     LLAMA_LOG_INFO("%s: n_ctx         = %u\n",   __func__, cparams.n_ctx);
-    LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n",   __func__, n_ctx_per_seq);
+    LLAMA_LOG_INFO("%s: n_ctx_seq     = %u\n",   __func__, cparams.n_ctx_seq);
     LLAMA_LOG_INFO("%s: n_batch       = %u\n",   __func__, cparams.n_batch);
     LLAMA_LOG_INFO("%s: n_ubatch      = %u\n",   __func__, cparams.n_ubatch);
     LLAMA_LOG_INFO("%s: causal_attn   = %d\n",   __func__, cparams.causal_attn);
@@ -125,14 +138,14 @@ llama_context::llama_context(
     LLAMA_LOG_INFO("%s: freq_base     = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale    = %g\n",   __func__, cparams.rope_freq_scale);
 
-    if (n_ctx_per_seq < hparams.n_ctx_train) {
-        LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
-                __func__, n_ctx_per_seq, hparams.n_ctx_train);
+    if (cparams.n_ctx_seq < hparams.n_ctx_train) {
+        LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
+                __func__, cparams.n_ctx_seq, hparams.n_ctx_train);
     }
 
-    if (n_ctx_per_seq > hparams.n_ctx_train) {
-        LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
-                __func__, n_ctx_per_seq, hparams.n_ctx_train);
+    if (cparams.n_ctx_seq > hparams.n_ctx_train) {
+        LLAMA_LOG_WARN("%s: n_ctx_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
+                __func__, cparams.n_ctx_seq, hparams.n_ctx_train);
     }
 
     if (!hparams.vocab_only) {
@@ -453,8 +466,8 @@ uint32_t llama_context::n_ctx() const {
     return cparams.n_ctx;
 }
 
-uint32_t llama_context::n_ctx_per_seq() const {
-    return cparams.n_ctx / cparams.n_seq_max;
+uint32_t llama_context::n_ctx_seq() const {
+    return cparams.n_ctx_seq;
 }
 
 uint32_t llama_context::n_batch() const {
@@ -2383,6 +2396,10 @@ uint32_t llama_n_ctx(const llama_context * ctx) {
     return ctx->n_ctx();
 }
 
+uint32_t llama_n_ctx_seq(const llama_context * ctx) {
+    return ctx->n_ctx_seq();
+}
+
 uint32_t llama_n_batch(const llama_context * ctx) {
     return ctx->n_batch();
 }
index ed6d82cb396f99542abcadc339fd7536df7702c2..20cbd78955412c3bdec9e462bd23e02d7347b99e 100644 (file)
@@ -43,11 +43,11 @@ struct llama_context {
 
     ggml_backend_sched_t get_sched() const;
 
-    uint32_t n_ctx()         const;
-    uint32_t n_ctx_per_seq() const;
-    uint32_t n_batch()       const;
-    uint32_t n_ubatch()      const;
-    uint32_t n_seq_max()     const;
+    uint32_t n_ctx()     const;
+    uint32_t n_ctx_seq() const;
+    uint32_t n_batch()   const;
+    uint32_t n_ubatch()  const;
+    uint32_t n_seq_max() const;
 
     uint32_t n_threads()       const;
     uint32_t n_threads_batch() const;
index eae7b839f4857da2df56be6cd8d4d9a5fe362b7a..fcef8fa97603868db0a80143ff0d1616743e1eb0 100644 (file)
@@ -8,6 +8,7 @@
 
 struct llama_cparams {
     uint32_t n_ctx;           // context size used during inference
+    uint32_t n_ctx_seq;       // context for a single sequence
     uint32_t n_batch;
     uint32_t n_ubatch;
     uint32_t n_seq_max;
index 04239181c77657334c65729c65a9acdb210257c9..896725466ce24649e024e3136715a3bf8affa119 100644 (file)
@@ -6712,14 +6712,14 @@ float llama_model::get_rope_freq_scale(const llama_cparams & cparams, int il) co
 }
 
 ggml_tensor * llama_model::get_rope_factors(const llama_cparams & cparams, int il) const {
-    const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
+    const uint32_t n_ctx_seq = cparams.n_ctx_seq;
 
     // choose long/short freq factors based on the context size
     if (layers[il].rope_freqs != nullptr) {
         return layers[il].rope_freqs;
     }
 
-    if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
+    if (n_ctx_seq > hparams.n_ctx_orig_yarn) {
         return layers[il].rope_long;
     }
 
@@ -6795,12 +6795,6 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                         /* filter_attn       */ std::move(filter_attn),
                         /* filter_recr       */ std::move(filter_recr));
                 } else {
-                    uint32_t n_ctx_per_stream = cparams.n_ctx;
-
-                    if (!cparams.kv_unified) {
-                        n_ctx_per_stream = (cparams.n_ctx + cparams.n_seq_max - 1)/cparams.n_seq_max;
-                    }
-
                     llama_memory_i::layer_reuse_cb reuse = nullptr;
 
                     if (arch == LLM_ARCH_GEMMA3N) {
@@ -6824,7 +6818,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                                 cparams.offload_kqv,
                                 params.swa_full,
                                 cparams.kv_unified,
-                                n_ctx_per_stream,
+                                cparams.n_ctx_seq,
                                 cparams.n_seq_max,
                                 cparams.n_ubatch,
                                 1,
@@ -6840,7 +6834,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                                 !cparams.flash_attn,
                                 cparams.offload_kqv,
                                 cparams.kv_unified,
-                                n_ctx_per_stream,
+                                cparams.n_ctx_seq,
                                 cparams.n_seq_max,
                                 1,
                                 hparams.n_swa,
index e5158fb5062f02d08a38f405c23ed23c0ee1c8ae..bcb86c35e66526aedc79de231ac688846c89be8b 100644 (file)
@@ -131,7 +131,14 @@ int main(int argc, char ** argv) {
                     }
 
                     batch = llama_batch_get_one(&token, 1);
-                    if (llama_decode(ctx.get(), batch)) {
+
+                    int ret = llama_decode(ctx.get(), batch);
+                    if (ret == 1 && i > 0) {
+                        LOG_INF("Context full, stopping generation.\n");
+                        break;
+                    }
+
+                    if (ret != 0) {
                         LOG_ERR("Model %d/%d, Context %d/%d: failed to decode\n", m + 1, num_models, c + 1, num_contexts);
                         failed.store(true);
                         return;
index 92d30664e41f458b218b5b20160a826b60164c7e..aa4981585200adb7bad00ae1db54ddbd72a735fc 100644 (file)
@@ -2407,7 +2407,7 @@ struct server_context {
 
             params_dft.devices      = params_base.speculative.devices;
             params_dft.model        = params_base.speculative.model;
-            params_dft.n_ctx        = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx;
+            params_dft.n_ctx        = params_base.speculative.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_base.speculative.n_ctx;
             params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
             params_dft.n_parallel   = 1;
             params_dft.cache_type_k = params_base.speculative.cache_type_k;
@@ -2495,10 +2495,16 @@ struct server_context {
     }
 
     void init() {
-        const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;
-
         SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);
 
+        const int n_ctx_train = llama_model_n_ctx_train(model);
+
+        int n_ctx_slot = llama_n_ctx_seq(ctx);
+        if (n_ctx_slot > n_ctx_train) {
+            SRV_WRN("the slot context (%d) exceeds the training context of the model (%d) - capping\n", n_ctx_slot, n_ctx_train);
+            n_ctx_slot = n_ctx_train;
+        }
+
         for (int i = 0; i < params_base.n_parallel; i++) {
             server_slot slot;
 
@@ -2527,7 +2533,7 @@ struct server_context {
                 }
             }
 
-            SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
+            SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx);
 
             slot.callback_on_release = [this](int) {
                 queue_tasks.pop_deferred_task();
@@ -2699,6 +2705,39 @@ struct server_context {
         return ret;
     }
 
+    // return true if at least one slot has been purged
+    // TODO: improve logic
+    //       - smarter decision which slot to purge (LRU or longest prompt?)
+    //       - move slot to level 2 cache instead of removing?
+    //       - instead of purging, try to store and resume later?
+    bool try_purge_idle_slots() {
+        bool res = false;
+
+        if (!params_base.kv_unified) {
+            return res;
+        }
+
+        for (auto & slot : slots) {
+            if (slot.is_processing()) {
+                continue;
+            }
+
+            if (slot.prompt.n_tokens() > 0) {
+                SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());
+
+                llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
+                slot.prompt.tokens.clear();
+
+                res = true;
+
+                // purge slots one by one
+                break;
+            }
+        }
+
+        return res;
+    }
+
     bool launch_slot_with_task(server_slot & slot, server_task && task) {
         slot.reset();
 
@@ -3635,9 +3674,10 @@ struct server_context {
         int32_t n_batch  = llama_n_batch(ctx);
         int32_t n_ubatch = llama_n_ubatch(ctx);
 
-        // next, batch any pending prompts without exceeding n_batch
-        float alora_scale = -1.0f;
+        float  alora_scale       = -1.0f;
         size_t alora_disabled_id = 0;
+
+        // next, batch any pending prompts without exceeding n_batch
         if (params_base.cont_batching || batch.n_tokens == 0) {
             for (auto & slot : slots) {
                 // check if we can batch this slot with the previous one
@@ -3914,8 +3954,11 @@ struct server_context {
 
                     // truncate any tokens that are beyond n_past for this slot
                     const llama_pos p0 = slot.prompt.tokens.pos_next();
+
+                    SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);
+
                     if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
-                        SLT_WRN(slot, "failed to truncate tokens with position >= %d\n", p0);
+                        SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);
                         llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
 
                         // there is no common part left
@@ -3924,8 +3967,6 @@ struct server_context {
                         slot.prompt.tokens.clear();
                     }
 
-                    SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);
-
                     // check if we should process the image
                     if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
                         // process the image
@@ -4126,6 +4167,8 @@ struct server_context {
                     std::string err;
 
                     if (n_batch == 1 && ret == 1) {
+                        // TODO: try to terminate only the largest active slot/sequence and continue with the rest
+                        //       need to remove the tokens from the current batch too
                         err = "Context size has been exceeded.";
                     }
 
@@ -4141,17 +4184,23 @@ struct server_context {
                     // TODO: handle ret == 2 (abort) when we start aborting
 
                     if (!err.empty()) {
-                        SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
+                        SRV_ERR("%s i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
+
                         for (auto & slot : slots) {
-                            send_error(slot, err);
-                            slot.release();
+                            if (slot.is_processing()) {
+                                send_error(slot, err);
+                                slot.release();
+                            }
                         }
+
                         break;
                     }
                 }
 
                 // retry with half the batch size to try to find a free slot in the KV cache
-                n_batch /= 2;
+                if (!try_purge_idle_slots()) {
+                    n_batch /= 2;
+                }
 
                 SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
 
@@ -4391,6 +4440,15 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    // TODO: should we have a separate n_parallel parameter for the server?
+    //       https://github.com/ggml-org/llama.cpp/pull/16736#discussion_r2483763177
+    if (params.n_parallel == 1 && params.kv_unified == false) {
+        LOG_WRN("%s: setting n_parallel = 4 and kv_unified = true\n", __func__);
+
+        params.n_parallel = 4;
+        params.kv_unified = true;
+    }
+
     common_init();
 
     // struct that contains llama context and inference
@@ -4944,7 +5002,7 @@ int main(int argc, char ** argv) {
                 // Everything else, including multimodal completions.
                 inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
             }
-            const size_t n_ctx_slot = ctx_server.n_ctx / ctx_server.params_base.n_parallel;
+            const size_t n_ctx_slot = ctx_server.slots.front().n_ctx;
             tasks.reserve(inputs.size());
             for (size_t i = 0; i < inputs.size(); i++) {
                 auto n_prompt_tokens = inputs[i].size();
index d56d3d5f178b80d37ced9c95f4db65a82e05b364..392e0efecdbbdf7f73a394022af7924972da6e4a 100644 (file)
@@ -433,21 +433,21 @@ def test_context_size_exceeded_stream():
 @pytest.mark.parametrize(
     "n_batch,batch_count,reuse_cache",
     [
-        (64, 15, False),
+        (64, 3, False),
         (64, 1, True),
     ]
 )
-def test_return_progresssss(n_batch, batch_count, reuse_cache):
+def test_return_progress(n_batch, batch_count, reuse_cache):
     global server
     server.n_batch = n_batch
-    server.n_ctx = 2048
+    server.n_ctx = 256
     server.n_slots = 1
     server.start()
     def make_cmpl_request():
         return server.make_stream_request("POST", "/chat/completions", data={
             "max_tokens": 10,
             "messages": [
-                {"role": "user", "content": "This is a test" * 100},
+                {"role": "user", "content": "This is a test" * 10},
             ],
             "stream": True,
             "return_progress": True,
index 00ba78cf67c0935c71ea5dcb3909fd1e2285e91b..3c0ce98973f4b8aa84506c7fad74534ba3aab387 100644 (file)
@@ -368,6 +368,37 @@ def test_completion_parallel_slots(n_slots: int, n_requests: int):
         # assert match_regex(re_content, res.body["content"])
 
 
+@pytest.mark.parametrize(
+    "n_ctx,n_slots,n_predict_vals,expected_success",
+    [
+        (256, 4, [80, 40, 80, 80], [True,  True,  True,  True]),
+        (256, 4, [70, 70, 70, 70], [False, False, False, False]),
+        (256, 4, [90, 90, 40, 90], [False, False, True,  False]),
+        (256, 4, [90, 90, 40, 75], [True,  True,  True,  True]),
+    ],
+)
+def test_completion_unified(n_ctx, n_slots, n_predict_vals, expected_success):
+    global server
+    server.n_slots = n_slots
+    server.kv_unified = True
+    server.n_ctx = n_ctx
+    server.start()
+    prompt = "A"
+    tasks = []
+    for n_predict in n_predict_vals:
+        tasks.append((server.make_request, ("POST", "/completion", {"prompt": prompt, "n_predict": n_predict})))
+    results = parallel_function_calls(tasks)
+    for res, n_predict, expect_ok in zip(results, n_predict_vals, expected_success):
+        if expect_ok:
+            assert res.status_code == 200
+            assert "content" in res.body
+            if "timings" in res.body:
+                assert res.body["timings"]["predicted_n"] == n_predict
+        else:
+            assert res.status_code == 500
+            assert "content" not in res.body
+
+
 @pytest.mark.parametrize(
     "prompt,n_predict,response_fields",
     [
index 73dacdae812b8f835acd08489e7288d813ab9a77..cd1a391b4adbc78fdecfe4d1a02a24fa070f60c8 100644 (file)
@@ -18,7 +18,7 @@ def test_infill_without_input_extra():
         "input_suffix": "}\n",
     })
     assert res.status_code == 200
-    assert match_regex("(Ann|small|shiny|Daddy)+", res.body["content"])
+    assert match_regex("(Ann|small|shiny|Daddy|Jimmy)+", res.body["content"])
 
 
 def test_infill_with_input_extra():
@@ -34,7 +34,7 @@ def test_infill_with_input_extra():
         "input_suffix": "}\n",
     })
     assert res.status_code == 200
-    assert match_regex("(Dad|excited|park)+", res.body["content"])
+    assert match_regex("(Dad|excited|park|Jimmy)+", res.body["content"])
 
 
 @pytest.mark.parametrize("input_extra", [
index 4ba3d43c330442a2cfe305d790f12144a5b411a5..da703c4c51a15461334ecc9c8ca973274f42d6d7 100644 (file)
@@ -78,6 +78,7 @@ class ServerProcess:
     server_embeddings: bool | None = False
     server_reranking: bool | None = False
     server_metrics: bool | None = False
+    kv_unified: bool | None = False
     server_slots: bool | None = False
     pooling: str | None = None
     draft: int | None = None
@@ -159,6 +160,8 @@ class ServerProcess:
             server_args.append("--reranking")
         if self.server_metrics:
             server_args.append("--metrics")
+        if self.kv_unified:
+            server_args.append("--kv-unified")
         if self.server_slots:
             server_args.append("--slots")
         else:
index b6198edfc487cf64ca2b006bc3c23b10e1c41dcf..2bce2f4a47af9c501a028267d3ce53b843e1991a 100644 (file)
@@ -1212,7 +1212,7 @@ public:
             for (auto it = tokens.map_idx_to_media.begin(); it != tokens.map_idx_to_media.end(); ) {
                 auto * chunk = tokens.map_idx_to_media[it->first].get();
                 mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
-                map_idx_to_media[start_idx+it->first] = std::move(new_chunk);
+                map_idx_to_media[start_idx + it->first] = std::move(new_chunk);
             }
         }
     }
@@ -1244,6 +1244,7 @@ public:
     }
 
     void clear() {
+        map_idx_to_media.clear();
         tokens.clear();
     }