// NOTE: After creating a llama_context, it is recommended to query the actual values using these functions
// In some cases the requested values via llama_context_params may differ from the actual values used by the context
+ // ref: https://github.com/ggml-org/llama.cpp/pull/17046#discussion_r2503085732
LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_ctx_seq (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
}
}
+ // ref: https://github.com/ggml-org/llama.cpp/pull/17046#discussion_r2503085732
+ cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256);
+
if (cparams.kv_unified) {
cparams.n_ctx_seq = cparams.n_ctx;
} else {
cparams.n_ctx_seq = cparams.n_ctx / cparams.n_seq_max;
+ cparams.n_ctx_seq = GGML_PAD(cparams.n_ctx_seq, 256);
if (cparams.n_ctx_seq == 0) {
throw std::runtime_error("n_ctx_seq == 0");
const uint32_t size_base = kv_size;
- uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad));
+ // note: the SWA cache is always padded to 256 for performance
+ // https://github.com/ggml-org/llama.cpp/issues/17037
+ uint32_t size_swa = GGML_PAD(std::min(size_base, hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch), 256);
// when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
if (swa_full) {
def test_slot_ctx_not_exceeded():
global server
- server.n_ctx = 64
+ server.n_ctx = 256
server.start()
res = server.make_request("POST", "/completion", data={
- "prompt": "Hello " * 56,
+ "prompt": "Hello " * 248,
"temperature": 0.0,
"top_k": 1,
"speculative.p_min": 0.0,
# NOTE(review): this is a diff hunk — "-" lines are the pre-change test,
# "+" lines are the post-change test; they are not both live Python.
# The constants scale 64 -> 256 because the context size is now padded to a
# multiple of 256 (GGML_PAD(cparams.n_ctx, 256) elsewhere in this change),
# so requesting n_ctx = 64 would no longer yield an actual 64-token context.
#
# Test intent: with context shifting enabled, a prompt that nearly fills the
# context window (presumably "Hello " * 248 tokenizes to ~n_ctx tokens —
# TODO confirm tokenizer behavior) plus n_predict == n_ctx must still
# complete: exactly n_predict tokens are generated and the response is
# reported as truncated.
def test_with_ctx_shift():
    global server
-    server.n_ctx = 64
+    server.n_ctx = 256
    server.enable_ctx_shift = True
    server.start()
    res = server.make_request("POST", "/completion", data={
-        "prompt": "Hello " * 56,
+        "prompt": "Hello " * 248,
        "temperature": 0.0,
        "top_k": 1,
-        "n_predict": 64,
+        "n_predict": 256,
        "speculative.p_min": 0.0,
    })
    assert res.status_code == 200
    assert len(res.body["content"]) > 0
-    assert res.body["tokens_predicted"] == 64
+    assert res.body["tokens_predicted"] == 256
    assert res.body["truncated"] == True