server : host-memory prompt caching (#16391)
author    Georgi Gerganov <redacted>
Thu, 9 Oct 2025 15:54:51 +0000 (18:54 +0300)
committer GitHub <redacted>
Thu, 9 Oct 2025 15:54:51 +0000 (18:54 +0300)
* minor : code style

* server : fix prompt similarity calculation

* server : initial host-memory prompt caching

* cont

* server : refactor

* cont

* cont : make the server task of the slot const

* cont : minor [no ci]

* server : cache prompts and checkpoints only for completion tasks

* server : improve prompt caching logic

* cont : fix check for number of cached prompts [no ci]

* server : improve caching logic, add -cram CLI arg

* server : print prompt mismatch info

* cont : better naming [no ci]

* server : improve prompt cache loading logic

* server : add option to debug the slot contents (#16482)

* server : add option to debug the slot contents

* Update tools/server/server.cpp

---------

Co-authored-by: Xuan-Son Nguyen <redacted>
* server : add option to disable prompt cache

---------

Co-authored-by: Xuan-Son Nguyen <redacted>
common/arg.cpp
common/chat.h
common/common.h
src/llama-kv-cache.cpp
tools/server/server.cpp
tools/server/tests/unit/test_basic.py
tools/server/tests/unit/test_chat_completion.py
tools/server/tests/unit/test_completion.py
tools/server/tests/unit/test_ctx_shift.py
tools/server/utils.hpp

diff --git a/common/arg.cpp b/common/arg.cpp
index 4204f6c6908fbd57c2007e77d8d88aea2e1c6263..d17645cf2f395cba2910f26ceeae3f11627f0845 100644
@@ -1935,6 +1935,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.n_ctx_checkpoints = value;
         }
     ).set_env("LLAMA_ARG_CTX_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER}));
+    add_opt(common_arg(
+        {"--cache-ram", "-cram"}, "N",
+        string_format("set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)\n"
+            "[(more info)](https://github.com/ggml-org/llama.cpp/pull/16391)", params.cache_ram_mib),
+        [](common_params & params, int value) {
+            params.cache_ram_mib = value;
+        }
+    ).set_env("LLAMA_ARG_CACHE_RAM").set_examples({LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"--kv-unified", "-kvu"},
         string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n"
diff --git a/common/chat.h b/common/chat.h
index a1afe574bd0cade3bbe085ca2740752455fd9c2c..f7b36ec711df422ffd66c2e9bf72355b25f5d44c 100644
@@ -33,8 +33,8 @@ struct common_chat_msg_content_part {
 struct common_chat_msg {
     std::string role;
     std::string content;
-    std::vector<common_chat_msg_content_part> content_parts = {};
-    std::vector<common_chat_tool_call> tool_calls = {};
+    std::vector<common_chat_msg_content_part> content_parts;
+    std::vector<common_chat_tool_call> tool_calls;
     std::string reasoning_content;
     std::string tool_name;
     std::string tool_call_id;
@@ -44,7 +44,7 @@ struct common_chat_msg {
     bool empty() const {
         return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty();
     }
-    void ensure_tool_call_ids_set(std::vector<std::string> & ids_cache, const std::function<std::string()> & gen_tool_call_id) {
+    void set_tool_call_ids(std::vector<std::string> & ids_cache, const std::function<std::string()> & gen_tool_call_id) {
         for (auto i = 0u; i < tool_calls.size(); i++) {
             if (ids_cache.size() <= i) {
                 auto id = tool_calls[i].id;
diff --git a/common/common.h b/common/common.h
index 0d3638c9c6228c3fce31aba1c8e22a8ea7645d86..040a44ebd89b0de8cd6d9e8718f3aa09a371af18 100644
@@ -378,7 +378,7 @@ struct common_params {
     bool simple_io         = false; // improves compatibility with subprocesses and limited consoles
     bool cont_batching     = true;  // insert new sequences for decoding on-the-fly
     bool no_perf           = false; // disable performance metrics
-    bool ctx_shift         = false;  // context shift on infinite text generation
+    bool ctx_shift         = false; // context shift on infinite text generation
     bool swa_full          = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
     bool kv_unified        = false; // enable unified KV cache
 
@@ -425,7 +425,8 @@ struct common_params {
     int32_t timeout_write     = timeout_read; // http write timeout in seconds
     int32_t n_threads_http    = -1;           // number of threads to process HTTP requests (TODO: support threadpool)
     int32_t n_cache_reuse     = 0;            // min chunk size to reuse from the cache via KV shifting
-    int32_t n_ctx_checkpoints = 3;            // max number of context checkpoints per slot
+    int32_t n_ctx_checkpoints = 8;            // max number of context checkpoints per slot
+    int32_t cache_ram_mib     = 8192;         // -1 = no limit, 0 = disable, 1 = 1 MiB, etc.
 
     std::string hostname      = "127.0.0.1";
     std::string public_path   = "";                                                                         // NOLINT
diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index 816f2d5de592b9949cfab53f071e0acae80ace20..736693e174527d76aeaabd7f27855d72e033ae66 100644
@@ -123,11 +123,8 @@ llama_kv_cache::llama_kv_cache(
             throw std::runtime_error("failed to create ggml context for kv cache");
         }
 
-        ggml_tensor * k;
-        ggml_tensor * v;
-
-        k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream);
-        v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream);
+        ggml_tensor * k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream);
+        ggml_tensor * v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream);
 
         ggml_format_name(k, "cache_k_l%d", il);
         ggml_format_name(v, "cache_v_l%d", il);
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index de6e1a322b2c22c0e6d4dd707e3a7076c22b6923..41ecb279feb890cbb82134a0ef0589f113e648a0 100644
@@ -9,7 +9,6 @@
 #include "sampling.h"
 #include "speculative.h"
 #include "mtmd.h"
-#include "mtmd-helper.h"
 
 // mime type for sending response
 #define MIMETYPE_JSON "application/json; charset=utf-8"
@@ -158,7 +157,6 @@ struct slot_params {
 
         if (only_metrics) {
             return json {
-                {"n_predict",                 n_predict},     // Server configured n_predict
                 {"seed",                      sampling.seed},
                 {"temperature",               sampling.temp},
                 {"dynatemp_range",            sampling.dynatemp_range},
@@ -181,7 +179,8 @@ struct slot_params {
                 {"mirostat",                  sampling.mirostat},
                 {"mirostat_tau",              sampling.mirostat_tau},
                 {"mirostat_eta",              sampling.mirostat_eta},
-                {"max_tokens",                n_predict}, // User configured n_predict
+                {"max_tokens",                n_predict},
+                {"n_predict",                 n_predict}, // TODO: deduplicate?
                 {"n_keep",                    n_keep},
                 {"n_discard",                 n_discard},
                 {"ignore_eos",                sampling.ignore_eos},
@@ -209,7 +208,6 @@ struct slot_params {
         }
 
         return json {
-            {"n_predict",                 n_predict},     // Server configured n_predict
             {"seed",                      sampling.seed},
             {"temperature",               sampling.temp},
             {"dynatemp_range",            sampling.dynatemp_range},
@@ -234,7 +232,8 @@ struct slot_params {
             {"mirostat_tau",              sampling.mirostat_tau},
             {"mirostat_eta",              sampling.mirostat_eta},
             {"stop",                      antiprompt},
-            {"max_tokens",                n_predict}, // User configured n_predict
+            {"max_tokens",                n_predict},
+            {"n_predict",                 n_predict}, // TODO: deduplicate?
             {"n_keep",                    n_keep},
             {"n_discard",                 n_discard},
             {"ignore_eos",                sampling.ignore_eos},
@@ -265,15 +264,15 @@ struct server_task {
     int id    = -1; // to be filled by server_queue
     int index = -1; // used when there are multiple prompts (batch request)
 
-    server_task_type type;
-
     // used by SERVER_TASK_TYPE_CANCEL
     int id_target = -1;
+    int id_slot   = -1;
 
     // used by SERVER_TASK_TYPE_INFERENCE
     slot_params   params;
-    server_tokens prompt_tokens;
-    int id_selected_slot = -1;
+    server_tokens tokens;
+
+    server_task_type type;
 
     // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE
     struct slot_action {
@@ -289,6 +288,8 @@ struct server_task {
     // used by SERVER_TASK_TYPE_SET_LORA
     std::vector<common_adapter_lora_info> set_lora;
 
+    server_task() = default;
+
     server_task(server_task_type type) : type(type) {}
 
     static slot_params params_from_json_cmpl(
@@ -305,6 +306,7 @@ struct server_task {
         defaults.sampling    = params_base.sampling;
         defaults.speculative = params_base.speculative;
         defaults.n_keep      = params_base.n_keep;
+        defaults.n_predict   = params_base.n_predict;
         defaults.antiprompt  = params_base.antiprompt;
 
         // enabling this will output extra debug information in the HTTP responses from the server
@@ -323,32 +325,32 @@ struct server_task {
         params.n_discard        = json_value(data,       "n_discard",          defaults.n_discard);
       //params.t_max_prompt_ms  = json_value(data,       "t_max_prompt_ms",    defaults.t_max_prompt_ms); // TODO: implement
         params.t_max_predict_ms = json_value(data,       "t_max_predict_ms",   defaults.t_max_predict_ms);
-        params.response_fields  = json_value(data,       "response_fields",   std::vector<std::string>());
-
-        params.sampling.top_k              = json_value(data, "top_k",              defaults.sampling.top_k);
-        params.sampling.top_p              = json_value(data, "top_p",              defaults.sampling.top_p);
-        params.sampling.min_p              = json_value(data, "min_p",              defaults.sampling.min_p);
-        params.sampling.top_n_sigma        = json_value(data, "top_n_sigma",        defaults.sampling.top_n_sigma);
-        params.sampling.xtc_probability    = json_value(data, "xtc_probability",    defaults.sampling.xtc_probability);
-        params.sampling.xtc_threshold      = json_value(data, "xtc_threshold",      defaults.sampling.xtc_threshold);
-        params.sampling.typ_p              = json_value(data, "typical_p",          defaults.sampling.typ_p);
-        params.sampling.temp               = json_value(data, "temperature",        defaults.sampling.temp);
-        params.sampling.dynatemp_range     = json_value(data, "dynatemp_range",     defaults.sampling.dynatemp_range);
-        params.sampling.dynatemp_exponent  = json_value(data, "dynatemp_exponent",  defaults.sampling.dynatemp_exponent);
-        params.sampling.penalty_last_n     = json_value(data, "repeat_last_n",      defaults.sampling.penalty_last_n);
-        params.sampling.penalty_repeat     = json_value(data, "repeat_penalty",     defaults.sampling.penalty_repeat);
-        params.sampling.penalty_freq       = json_value(data, "frequency_penalty",  defaults.sampling.penalty_freq);
-        params.sampling.penalty_present    = json_value(data, "presence_penalty",   defaults.sampling.penalty_present);
-        params.sampling.dry_multiplier     = json_value(data, "dry_multiplier",     defaults.sampling.dry_multiplier);
-        params.sampling.dry_base           = json_value(data, "dry_base",           defaults.sampling.dry_base);
-        params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length);
-        params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n);
-        params.sampling.mirostat           = json_value(data, "mirostat",           defaults.sampling.mirostat);
-        params.sampling.mirostat_tau       = json_value(data, "mirostat_tau",       defaults.sampling.mirostat_tau);
-        params.sampling.mirostat_eta       = json_value(data, "mirostat_eta",       defaults.sampling.mirostat_eta);
-        params.sampling.seed               = json_value(data, "seed",               defaults.sampling.seed);
-        params.sampling.n_probs            = json_value(data, "n_probs",            defaults.sampling.n_probs);
-        params.sampling.min_keep           = json_value(data, "min_keep",           defaults.sampling.min_keep);
+        params.response_fields  = json_value(data,       "response_fields",    std::vector<std::string>());
+
+        params.sampling.top_k              = json_value(data, "top_k",               defaults.sampling.top_k);
+        params.sampling.top_p              = json_value(data, "top_p",               defaults.sampling.top_p);
+        params.sampling.min_p              = json_value(data, "min_p",               defaults.sampling.min_p);
+        params.sampling.top_n_sigma        = json_value(data, "top_n_sigma",         defaults.sampling.top_n_sigma);
+        params.sampling.xtc_probability    = json_value(data, "xtc_probability",     defaults.sampling.xtc_probability);
+        params.sampling.xtc_threshold      = json_value(data, "xtc_threshold",       defaults.sampling.xtc_threshold);
+        params.sampling.typ_p              = json_value(data, "typical_p",           defaults.sampling.typ_p);
+        params.sampling.temp               = json_value(data, "temperature",         defaults.sampling.temp);
+        params.sampling.dynatemp_range     = json_value(data, "dynatemp_range",      defaults.sampling.dynatemp_range);
+        params.sampling.dynatemp_exponent  = json_value(data, "dynatemp_exponent",   defaults.sampling.dynatemp_exponent);
+        params.sampling.penalty_last_n     = json_value(data, "repeat_last_n",       defaults.sampling.penalty_last_n);
+        params.sampling.penalty_repeat     = json_value(data, "repeat_penalty",      defaults.sampling.penalty_repeat);
+        params.sampling.penalty_freq       = json_value(data, "frequency_penalty",   defaults.sampling.penalty_freq);
+        params.sampling.penalty_present    = json_value(data, "presence_penalty",    defaults.sampling.penalty_present);
+        params.sampling.dry_multiplier     = json_value(data, "dry_multiplier",      defaults.sampling.dry_multiplier);
+        params.sampling.dry_base           = json_value(data, "dry_base",            defaults.sampling.dry_base);
+        params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length",  defaults.sampling.dry_allowed_length);
+        params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n",  defaults.sampling.dry_penalty_last_n);
+        params.sampling.mirostat           = json_value(data, "mirostat",            defaults.sampling.mirostat);
+        params.sampling.mirostat_tau       = json_value(data, "mirostat_tau",        defaults.sampling.mirostat_tau);
+        params.sampling.mirostat_eta       = json_value(data, "mirostat_eta",        defaults.sampling.mirostat_eta);
+        params.sampling.seed               = json_value(data, "seed",                defaults.sampling.seed);
+        params.sampling.n_probs            = json_value(data, "n_probs",             defaults.sampling.n_probs);
+        params.sampling.min_keep           = json_value(data, "min_keep",            defaults.sampling.min_keep);
         params.post_sampling_probs         = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);
 
         params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
@@ -690,7 +692,7 @@ struct server_task_result {
 // using shared_ptr for polymorphism of server_task_result
 using server_task_result_ptr = std::unique_ptr<server_task_result>;
 
-inline std::string stop_type_to_str(stop_type type) {
+static inline std::string stop_type_to_str(stop_type type) {
     switch (type) {
         case STOP_TYPE_EOS:   return "eos";
         case STOP_TYPE_WORD:  return "word";
@@ -764,13 +766,6 @@ struct completion_token_output {
     }
 };
 
-struct ctx_checkpoint {
-    llama_pos pos_min;
-    llama_pos pos_max;
-
-    std::vector<uint8_t> data;
-};
-
 struct server_task_result_cmpl_final : server_task_result {
     int index = 0;
 
@@ -797,11 +792,12 @@ struct server_task_result_cmpl_final : server_task_result {
     slot_params generation_params;
 
     // OAI-compat fields
-    bool               verbose                  = false;
-    oaicompat_type     oaicompat                = OAICOMPAT_TYPE_NONE;
-    std::string        oaicompat_model;
-    std::string        oaicompat_cmpl_id;
-    common_chat_msg    oaicompat_msg;
+    bool            verbose   = false;
+    oaicompat_type  oaicompat = OAICOMPAT_TYPE_NONE;
+    std::string     oaicompat_model;
+    std::string     oaicompat_cmpl_id;
+    common_chat_msg oaicompat_msg;
+
     std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
 
     virtual int get_index() override {
@@ -1373,17 +1369,17 @@ struct server_task_result_slot_save_load : server_task_result {
                     { "save_ms", t_ms }
                 }},
             };
-        } else {
-            return json {
-                { "id_slot",    id_slot },
-                { "filename",   filename },
-                { "n_restored", n_tokens },
-                { "n_read",     n_bytes },
-                { "timings", {
-                    { "restore_ms", t_ms }
-                }},
-            };
         }
+
+        return json {
+            { "id_slot",    id_slot },
+            { "filename",   filename },
+            { "n_restored", n_tokens },
+            { "n_read",     n_bytes },
+            { "timings", {
+                { "restore_ms", t_ms }
+            }},
+        };
     }
 };
 
@@ -1404,15 +1400,218 @@ struct server_task_result_apply_lora : server_task_result {
     }
 };
 
+struct server_prompt_checkpoint {
+    llama_pos pos_min;
+    llama_pos pos_max;
+
+    std::vector<uint8_t> data;
+
+    size_t size() const {
+        return data.size();
+    }
+};
+
+struct server_prompt {
+    server_tokens tokens;
+
+    std::vector<uint8_t> data;
+
+    std::list<server_prompt_checkpoint> checkpoints;
+
+    size_t size() const {
+        size_t res = data.size();
+
+        for (const auto & checkpoint : checkpoints) {
+            res += checkpoint.size();
+        }
+
+        return res;
+    }
+
+    int n_tokens() const {
+        return tokens.size();
+    }
+};
+
+struct server_prompt_cache {
+    server_prompt_cache(int32_t limit_size_mib, size_t limit_tokens) {
+        this->limit_size   = 1024ull*1024ull*(limit_size_mib < 0 ? 0 : limit_size_mib);
+        this->limit_tokens = limit_tokens;
+    }
+
+    std::list<server_prompt> states;
+
+    // in bytes, 0 = no limit
+    size_t limit_size = 0;
+
+    // in tokens, 0 = no limit
+    size_t limit_tokens = 0;
+
+    size_t size() const {
+        size_t res = 0;
+
+        for (const auto & state : states) {
+            res += state.size();
+        }
+
+        return res;
+    }
+
+    size_t n_tokens() const {
+        size_t res = 0;
+
+        for (const auto & state : states) {
+            res += state.n_tokens();
+        }
+
+        return res;
+    }
+
+    server_prompt * alloc(const server_prompt & prompt, size_t state_size) {
+        // first check if the current state is contained fully in the cache
+        for (auto it = states.begin(); it != states.end(); ++it) {
+            const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens);
+
+            if (cur_lcp_len == (int) prompt.tokens.size()) {
+                SRV_WRN("%s", " - prompt is already in the cache, skipping\n");
+                return nullptr;
+            }
+        }
+
+        // next, remove any cached prompts that are fully contained in the current prompt
+        for (auto it = states.begin(); it != states.end();) {
+            const int len = it->tokens.get_common_prefix(prompt.tokens);
+
+            if (len == (int) it->tokens.size()) {
+                SRV_WRN(" - removing obsolete cached prompt with length %d\n", len);
+
+                it = states.erase(it);
+            } else {
+                ++it;
+            }
+        }
+
+        std::vector<uint8_t> state_data;
+
+        // check if we can allocate enough memory for the new state
+        try {
+            state_data.resize(state_size);
+        } catch (const std::bad_alloc & e) {
+            SRV_ERR("failed to allocate memory for prompt cache state: %s\n", e.what());
+
+            limit_size = std::max<size_t>(1, 0.4*size());
+
+            SRV_WRN(" - cache size limit reduced to %.3f MiB\n", limit_size / (1024.0 * 1024.0));
+
+            update();
+
+            return nullptr;
+        }
+
+        // TODO: for some reason we can't copy server_tokens, so we have to do this workaround
+        auto & cur = states.emplace_back();
+        cur = {
+            /*.tokens      =*/ server_tokens(prompt.tokens.get_text_tokens(), false),
+            /*.data        =*/ std::move(state_data),
+            /*.checkpoints =*/ prompt.checkpoints,
+        };
+
+        return &cur;
+    }
+
+    bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) {
+        const int lcp_best = prompt.tokens.get_common_prefix(tokens_new);
+
+        float f_keep_best = float(lcp_best) / prompt.tokens.size();
+        float sim_best    = float(lcp_best) / tokens_new.size();
+
+        SRV_WRN(" - looking for better prompt, base f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
+
+        auto it_best = states.end();
+
+        // find the most similar cached prompt, that would also preserve the most context
+        for (auto it = states.begin(); it != states.end(); ++it) {
+            const int lcp_cur = it->tokens.get_common_prefix(tokens_new);
+
+            const float f_keep_cur = float(lcp_cur) / it->tokens.size();
+            const float sim_cur    = float(lcp_cur) / tokens_new.size();
+
+            // don't trash large prompts
+            if (f_keep_cur < 0.25f) {
+                continue;
+            }
+
+            if (f_keep_best < f_keep_cur && sim_best < sim_cur) {
+                f_keep_best = f_keep_cur;
+                sim_best    = sim_cur;
+
+                it_best = it;
+            }
+        }
+
+        if (it_best != states.end()) {
+            SRV_WRN(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
+
+            const size_t size = it_best->data.size();
+            const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id_slot, 0);
+            if (n != size) {
+                SRV_WRN("failed to restore state with size %zu\n", size);
+
+                return false;
+            }
+
+            it_best->data.clear();
+            it_best->data.shrink_to_fit();
+
+            prompt = std::move(*it_best);
+
+            states.erase(it_best);
+        }
+
+        return true;
+    }
+
+    void update() {
+        if (limit_size > 0) {
+            // always keep at least one state, regardless of the limits
+            while (states.size() > 1 && size() > limit_size) {
+                if (states.empty()) {
+                    break;
+                }
+
+                SRV_WRN(" - cache size limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0));
+
+                states.pop_front();
+            }
+        }
+
+        if (limit_tokens > 0) {
+            while (states.size() > 1 && n_tokens() > limit_tokens) {
+                if (states.empty()) {
+                    break;
+                }
+
+                SRV_WRN(" - cache token limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0));
+
+                states.pop_front();
+            }
+        }
+
+        SRV_WRN(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens)\n",
+                states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens);
+
+        for (const auto & state : states) {
+            SRV_WRN("   - prompt %p: %7d tokens, checkpoints: %2zu, %9.3f MiB\n", (const void *)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0));
+        }
+    }
+};
+
 struct server_slot {
     int id;
-    int id_task = -1;
-
-    // only used for completion/embedding/infill/rerank
-    server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;
 
     llama_batch batch_spec = {};
 
+    // TODO: change to unique_ptrs for consistency:
     llama_context * ctx = nullptr;
     llama_context * ctx_dft = nullptr;
 
@@ -1421,15 +1620,8 @@ struct server_slot {
 
     common_speculative * spec = nullptr;
 
-    std::vector<common_adapter_lora_info> lora;
-    int32_t alora_invocation_start = -1;
-
-    // the index relative to completion multi-task request
-    size_t index = 0;
-
-    struct slot_params params;
-
-    slot_state state = SLOT_STATE_IDLE;
+    std::unique_ptr<const server_task> task;
+    std::unique_ptr<const server_task> task_prev; // used for debugging
 
     // used to determine the slot that has been used the longest
     int64_t t_last_used = -1;
@@ -1437,38 +1629,66 @@ struct server_slot {
     // generation props
     int32_t n_ctx       = 0;  // context size per slot
     int32_t n_past      = 0;
+    int32_t n_keep      = 0;
     int32_t n_decoded   = 0;
     int32_t n_remaining = -1;
     int32_t i_batch     = -1;
-    int32_t n_predict   = -1; // TODO: disambiguate from params.n_predict
 
-    // n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated
-    int32_t n_prompt_tokens           = 0;
     int32_t n_prompt_tokens_cache     = 0;
     int32_t n_prompt_tokens_processed = 0;
 
-    // input prompt tokens
-    server_tokens prompt_tokens;
+    int32_t n_prompt_tokens() const {
+        return task->tokens.size();
+    }
 
     size_t last_nl_pos = 0;
 
     std::string  generated_text;
     llama_tokens generated_tokens;
-    common_chat_msg chat_msg;
 
-    server_tokens cache_tokens;
+    common_chat_msg chat_msg;
 
     std::vector<completion_token_output> generated_token_probs;
 
-    std::vector<ctx_checkpoint> ctx_checkpoints;
-
     bool has_next_token = true;
     bool has_new_line   = false;
     bool truncated      = false;
+
     stop_type stop;
 
     std::string stopping_word;
 
+    // state
+    slot_state state = SLOT_STATE_IDLE;
+
+    server_prompt prompt;
+
+    void prompt_save(server_prompt_cache & prompt_cache) const {
+        assert(prompt.data.size() == 0);
+
+        const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0);
+
+        SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB\n",
+                (int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0));
+
+        auto * cur = prompt_cache.alloc(prompt, cur_size);
+        if (cur == nullptr) {
+            return;
+        }
+
+        llama_state_seq_get_data_ext(ctx, cur->data.data(), cur_size, id, 0);
+    }
+
+    void prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) {
+        bool res = prompt_cache.load(prompt, tokens, ctx, id);
+        if (!res) {
+            SLT_WRN(*this, "%s", "failed to load prompt from cache\n");
+        }
+    }
+
+    std::vector<common_adapter_lora_info> lora;
+    int32_t alora_invocation_start = -1;
+
     // sampling
     json json_schema;
 
@@ -1480,7 +1700,7 @@ struct server_slot {
     std::vector<std::string> generated_tool_call_ids;
 
     // stats
-    size_t n_sent_text        = 0; // number of sent text character
+    size_t n_sent_text = 0; // number of sent text character
 
     int64_t t_start_process_prompt;
     int64_t t_start_generation;
@@ -1497,19 +1717,17 @@ struct server_slot {
     void reset() {
         SLT_DBG(*this, "%s", "\n");
 
-        n_prompt_tokens       = 0;
         n_prompt_tokens_cache = 0;
 
-        last_nl_pos        = 0;
-        generated_text     = "";
-        has_new_line       = false;
-        truncated          = false;
-        stop               = STOP_TYPE_NONE;
-        stopping_word      = "";
-        n_past             = 0;
-        n_sent_text        = 0;
-        task_type          = SERVER_TASK_TYPE_COMPLETION;
-        chat_format        = COMMON_CHAT_FORMAT_CONTENT_ONLY;
+        last_nl_pos    = 0;
+        generated_text = "";
+        has_new_line   = false;
+        truncated      = false;
+        stop           = STOP_TYPE_NONE;
+        stopping_word  = "";
+        n_past         = 0;
+        n_sent_text    = 0;
+        chat_format    = COMMON_CHAT_FORMAT_CONTENT_ONLY;
 
         generated_tokens.clear();
         generated_token_probs.clear();
@@ -1521,16 +1739,23 @@ struct server_slot {
         n_draft_total = 0;
         n_draft_accepted = 0;
 
+        task.reset();
+        task_prev.reset();
+
         // clear alora start
         alora_invocation_start = -1;
     }
 
     bool need_embd() const {
-        return server_task_type_need_embd(task_type);
+        GGML_ASSERT(task);
+
+        return server_task_type_need_embd(task->type);
     }
 
     bool need_logits() const {
-        return server_task_type_need_logits(task_type);
+        GGML_ASSERT(task);
+
+        return server_task_type_need_logits(task->type);
     }
 
     // if the context does not have a memory module then all embeddings have to be computed within a single ubatch
@@ -1542,18 +1767,22 @@ struct server_slot {
     }
 
     bool can_batch_with(server_slot & other_slot) const {
-        return task_type == other_slot.task_type && are_lora_equal(lora, other_slot.lora);
+        GGML_ASSERT(task);
+
+        return task->type == other_slot.task->type && are_lora_equal(lora, other_slot.lora);
     }
 
     bool has_budget(const common_params & global_params) {
-        if (params.n_predict == -1 && global_params.n_predict == -1) {
+        GGML_ASSERT(task);
+
+        if (task->params.n_predict == -1 && global_params.n_predict == -1) {
             return true; // limitless
         }
 
         n_remaining = -1;
 
-        if (params.n_predict != -1) {
-            n_remaining = params.n_predict - n_decoded;
+        if (task->params.n_predict != -1) {
+            n_remaining = task->params.n_predict - n_decoded;
         } else if (global_params.n_predict != -1) {
             n_remaining = global_params.n_predict - n_decoded;
         }
@@ -1566,7 +1795,7 @@ struct server_slot {
     }
 
     bool can_speculate() const {
-        return ctx_dft && params.speculative.n_max > 0 && params.cache_prompt;
+        return ctx_dft;
     }
 
     void add_token(const completion_token_output & token) {
@@ -1579,11 +1808,17 @@ struct server_slot {
 
     void release() {
         if (is_processing()) {
+            GGML_ASSERT(task);
+
             SLT_INF(*this, "stop processing: n_past = %d, truncated = %d\n", n_past, truncated);
 
             t_last_used = ggml_time_us();
             t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
             state = SLOT_STATE_IDLE;
+
+            task_prev = std::move(task);
+            task.reset();
+
             callback_on_release(id);
         }
     }
@@ -1592,19 +1827,19 @@ struct server_slot {
         result_timings timings;
         timings.cache_n = n_prompt_tokens_cache;
 
-        timings.prompt_n = n_prompt_tokens_processed;
-        timings.prompt_ms = t_prompt_processing;
+        timings.prompt_n            = n_prompt_tokens_processed;
+        timings.prompt_ms           = t_prompt_processing;
         timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed;
-        timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
+        timings.prompt_per_second   = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
 
-        timings.predicted_n = n_decoded;
-        timings.predicted_ms = t_token_generation;
+        timings.predicted_n            = n_decoded;
+        timings.predicted_ms           = t_token_generation;
         timings.predicted_per_token_ms = t_token_generation / n_decoded;
-        timings.predicted_per_second = 1e3 / t_token_generation * n_decoded;
+        timings.predicted_per_second   = 1e3 / t_token_generation * n_decoded;
 
         // Add speculative metrics
         if (n_draft_total > 0) {
-            timings.draft_n = n_draft_total;
+            timings.draft_n          = n_draft_total;
             timings.draft_n_accepted = n_draft_accepted;
         }
 
@@ -1612,14 +1847,16 @@ struct server_slot {
     }
 
     const common_chat_msg & update_chat_msg(std::vector<common_chat_msg_diff> & diffs) {
+        GGML_ASSERT(task);
+
         auto previous_msg = chat_msg;
         SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
         auto new_msg = common_chat_parse(
             generated_text,
             /* is_partial= */ stop != STOP_TYPE_EOS,
-            params.oaicompat_chat_syntax);
+            task->params.oaicompat_chat_syntax);
         if (!new_msg.empty()) {
-            new_msg.ensure_tool_call_ids_set(generated_tool_call_ids, gen_tool_call_id);
+            new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id);
             chat_msg = new_msg;
             diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg);
         }
@@ -1627,9 +1864,11 @@ struct server_slot {
     }
 
     size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
+        GGML_ASSERT(task);
+
         size_t stop_pos = std::string::npos;
 
-        for (const std::string & word : params.antiprompt) {
+        for (const std::string & word : task->params.antiprompt) {
             size_t pos;
 
             if (is_full_stop) {
@@ -1682,43 +1921,36 @@ struct server_slot {
     }
 
     json to_json(bool only_metrics = false) const {
-        if (only_metrics) {
-            return json {
-                {"id",            id},
-                {"id_task",       id_task},
-                {"n_ctx",         n_ctx},
-                {"speculative",   can_speculate()},
-                {"is_processing", is_processing()},
-                {"params",        params.to_json(true)},
-                {"next_token",
-                    {
-                        {"has_next_token", has_next_token},
-                        {"has_new_line",   has_new_line},
-                        {"n_remain",       n_remaining},
-                        {"n_decoded",      n_decoded},
-                    }
-                },
-            };
-        }
+        json res;
 
-        return json {
+        res = {
             {"id",            id},
-            {"id_task",       id_task},
             {"n_ctx",         n_ctx},
             {"speculative",   can_speculate()},
             {"is_processing", is_processing()},
-            {"params",        params.to_json()},
-            {"prompt",        prompt_tokens.detokenize(ctx, true)},
-            {"next_token",
+        };
+
+        const auto & ptask = task ? task : task_prev;
+
+        if (ptask) {
+            res["id_task"] = ptask->id;
+            res["params"] = ptask->params.to_json(only_metrics);
+            res["next_token"] = {
                 {
                     {"has_next_token", has_next_token},
                     {"has_new_line",   has_new_line},
                     {"n_remain",       n_remaining},
                     {"n_decoded",      n_decoded},
-                    {"stopping_word",  stopping_word},
                 }
-            },
-        };
+            };
+
+            if (!only_metrics) {
+                res["prompt"] = ptask->tokens.detokenize(ctx, true);
+                res["generated"] = generated_text;
+            }
+        }
+
+        return res;
     }
 };
 
@@ -2109,11 +2341,14 @@ struct server_context {
 
     // slots / clients
     std::vector<server_slot> slots;
-    json default_generation_settings_for_props;
+
+    int slots_debug = 0;
 
     server_queue    queue_tasks;
     server_response queue_results;
 
+    std::unique_ptr<server_prompt_cache> prompt_cache;
+
     server_metrics metrics;
 
     // Necessary similarity of prompt for slot selection
@@ -2268,9 +2503,8 @@ struct server_context {
             slot.id = i;
             slot.ctx = ctx;
             slot.n_ctx = n_ctx_slot;
-            slot.n_predict = params_base.n_predict;
             slot.mctx = mctx;
-            slot.cache_tokens.has_mtmd = mctx != nullptr;
+            slot.prompt.tokens.has_mtmd = mctx != nullptr;
 
             if (model_dft) {
                 slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
@@ -2286,16 +2520,13 @@ struct server_context {
                     SRV_ERR("%s", "failed to create speculator\n");
                     return;
                 }
-                for (auto &pair : params_base.speculative.replacements) {
+                for (auto & pair : params_base.speculative.replacements) {
                     common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str());
                 }
             }
 
             SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
 
-            slot.params.sampling = params_base.sampling;
-            slot.params.n_keep = params_base.n_keep;
-
             slot.callback_on_release = [this](int) {
                 queue_tasks.pop_deferred_task();
             };
@@ -2305,7 +2536,14 @@ struct server_context {
             slots.push_back(std::move(slot));
         }
 
-        default_generation_settings_for_props = slots[0].to_json();
+        {
+            const char * LLAMA_SERVER_SLOTS_DEBUG = getenv("LLAMA_SERVER_SLOTS_DEBUG");
+            slots_debug = LLAMA_SERVER_SLOTS_DEBUG ? atoi(LLAMA_SERVER_SLOTS_DEBUG) : 0;
+
+            if (slots_debug) {
+                SRV_WRN("slots debug = %d\n", slots_debug);
+            }
+        }
 
         // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens
         // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
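
Slot-content debugging is controlled purely through the environment: a non-zero LLAMA_SERVER_SLOTS_DEBUG sets slots_debug, which later parts of server.cpp (not shown in this excerpt) can consult when deciding how much per-slot detail to expose, such as the prompt and generated text emitted by server_slot::to_json when only_metrics is false. A minimal standalone sketch of the same env-var read, with an illustrative helper name:

    #include <cstdlib>

    // Illustrative helper (the patch inlines this logic in init()):
    // unset or 0 disables the extra slot debug output, any non-zero integer enables it.
    static int read_slots_debug() {
        const char * env = std::getenv("LLAMA_SERVER_SLOTS_DEBUG");
        return env ? std::atoi(env) : 0;
    }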
@@ -2316,11 +2554,25 @@ struct server_context {
 
         metrics.init();
 
+        if (params_base.cache_ram_mib != 0) {
+            if (params_base.cache_ram_mib < 0) {
+                SRV_WRN("prompt cache is enabled, size limit: %s\n", "no limit");
+            } else {
+                SRV_WRN("prompt cache is enabled, size limit: %d MiB\n", params_base.cache_ram_mib);
+            }
+            SRV_WRN("%s", "use `--cache-ram 0` to disable the prompt cache\n");
+
+            prompt_cache = std::make_unique<server_prompt_cache>(params_base.cache_ram_mib, n_ctx);
+        } else {
+            SRV_WRN("%s", "prompt cache is disabled - use `--cache-ram N` to enable it\n");
+        }
+        SRV_WRN("%s", "for more info see https://github.com/ggml-org/llama.cpp/pull/16391\n");
+
         // thinking is enabled if:
         // 1. It's not explicitly disabled (reasoning_budget == 0)
         // 2. The chat template supports it
         const bool enable_thinking = params_base.use_jinja && params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get());
-        SRV_INF("Enable thinking? %d\n", enable_thinking);
+        SRV_INF("thinking = %d\n", enable_thinking);
 
         oai_parser_opt = {
             /* use_jinja             */ params_base.use_jinja,
@@ -2347,10 +2599,11 @@ struct server_context {
     server_slot * get_available_slot(const server_task & task) {
         server_slot * ret = nullptr;
 
+        bool update_cache = false;
+
         // find the slot that has at least n% prompt similarity
         if (ret == nullptr && slot_prompt_similarity != 0.0f) {
-            int lcs_len = 0;
-            float similarity = 0;
+            float sim_best = 0;
 
             for (server_slot & slot : slots) {
                 // skip the slot if it is not available
@@ -2358,27 +2611,34 @@ struct server_context {
                     continue;
                 }
 
+                const auto & tokens = slot.prompt.tokens;
+
                 // skip the slot if it does not contains cached tokens
-                if (slot.cache_tokens.empty()) {
+                if (tokens.empty()) {
                     continue;
                 }
 
-                // length of the Longest Common Subsequence between the current slot's prompt and the input prompt
-                int cur_lcs_len = slot.cache_tokens.get_common_prefix(task.prompt_tokens);
-
-                // fraction of the common subsequence length compared to the current slot's prompt length
-                float cur_similarity = static_cast<float>(cur_lcs_len) / static_cast<int>(slot.cache_tokens.size());
+                // fraction of the Longest Common Prefix length with respect to the input prompt length
+                const float sim_cur = float(tokens.get_common_prefix(task.tokens)) / task.tokens.size();
 
                 // select the current slot if the criteria match
-                if (cur_lcs_len > lcs_len && cur_similarity > slot_prompt_similarity) {
-                    lcs_len = cur_lcs_len;
-                    similarity = cur_similarity;
+                if (sim_cur > sim_best && sim_cur > slot_prompt_similarity) {
+                    sim_best = sim_cur;
+
                     ret = &slot;
                 }
             }
 
             if (ret != nullptr) {
-                SLT_INF(*ret, "selected slot by lcs similarity, lcs_len = %d, similarity = %.3f (> %.3f thold)\n", lcs_len, similarity, slot_prompt_similarity);
+                const float f_keep = (sim_best*task.tokens.size()) / ret->prompt.tokens.size();
+
+                SLT_INF(*ret, "selected slot by LCP similarity, sim_best = %.3f (> %.3f thold), f_keep = %.3f\n",
+                        sim_best, slot_prompt_similarity, f_keep);
+
+                // if we are about to lose a large portion of the existing context - save it in the prompt cache
+                if (f_keep < 0.5f) {
+                    update_cache = true;
+                }
             }
         }
 
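Slot selection now scores candidates with the Longest Common Prefix (LCP) instead of the longest common subsequence: sim is the fraction of the incoming prompt already covered by a slot's cached tokens, while f_keep is the fraction of the slot's existing prompt that would survive reuse (a low f_keep is what later triggers saving the old state into the prompt cache). A standalone sketch of the two ratios over plain token vectors (function names here are illustrative):

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    using llama_token = int32_t; // assumption: matches the llama.cpp typedef

    // Length of the longest common prefix between a slot's cached prompt and the new prompt.
    static size_t common_prefix_len(const std::vector<llama_token> & cached,
                                    const std::vector<llama_token> & incoming) {
        const size_t n = std::min(cached.size(), incoming.size());
        size_t i = 0;
        while (i < n && cached[i] == incoming[i]) {
            i++;
        }
        return i;
    }

    // sim    : how much of the *new* prompt is already computed in this slot
    // f_keep : how much of the slot's *existing* context would be preserved on reuse
    static void slot_reuse_scores(const std::vector<llama_token> & cached,
                                  const std::vector<llama_token> & incoming,
                                  float & sim, float & f_keep) {
        const size_t lcp = common_prefix_len(cached, incoming);
        sim    = incoming.empty() ? 0.0f : float(lcp) / incoming.size();
        f_keep = cached.empty()   ? 0.0f : float(lcp) / cached.size();
    }
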
@@ -2401,6 +2661,36 @@ struct server_context {
 
             if (ret != nullptr) {
                 SLT_INF(*ret, "selected slot by LRU, t_last = %" PRId64 "\n", t_last);
+
+                update_cache = true;
+            }
+        }
+
+        if (ret) {
+            const auto & tokens = ret->prompt.tokens;
+
+            update_cache = update_cache && prompt_cache;
+
+            // cache prompts only for completion tasks
+            update_cache = update_cache && task.type == SERVER_TASK_TYPE_COMPLETION;
+
+            // don't update the cache if the slot's context is empty
+            update_cache = update_cache && tokens.size() > 0;
+
+            // TODO: mtmd does not support prompt cache
+            update_cache = update_cache && (ret->mctx == nullptr);
+
+            if (update_cache) {
+                SRV_WRN("%s", "updating prompt cache\n");
+
+                const int64_t t_start = ggml_time_us();
+
+                ret->prompt_save(*prompt_cache);
+                ret->prompt_load(*prompt_cache, task.tokens);
+
+                prompt_cache->update();
+
+                SRV_WRN("prompt cache update took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0);
             }
         }
 
@@ -2409,27 +2699,21 @@ struct server_context {
 
     bool launch_slot_with_task(server_slot & slot, server_task && task) {
         slot.reset();
-        slot.id_task       = task.id;
-        slot.index         = task.index;
-        slot.task_type     = task.type;
-        slot.params        = std::move(task.params);
-        slot.prompt_tokens = std::move(task.prompt_tokens);
 
-        if (!are_lora_equal(slot.params.lora, slot.lora)) {
+        if (!are_lora_equal(task.params.lora, slot.lora)) {
             // if lora has changed, check to see if the cache should be cleared
-            if (lora_should_clear_cache(slot.lora, slot.params.lora)) {
-                SLT_INF(slot, "clearing cache for lora change. %zu loras -> %zu loras\n", slot.lora.size(), slot.params.lora.size());
-                slot.cache_tokens.clear();
+            if (lora_should_clear_cache(slot.lora, task.params.lora)) {
+                SLT_INF(slot, "clearing cache for lora change. %zu loras -> %zu loras\n", slot.lora.size(), task.params.lora.size());
+                slot.prompt.tokens.clear();
             } else {
-                SLT_INF(slot, "keeping cache for alora. %zu target loras\n", slot.params.lora.size());
+                SLT_INF(slot, "keeping cache for alora. %zu target loras\n", task.params.lora.size());
             }
-            slot.lora = slot.params.lora;
+            slot.lora = task.params.lora;
         }
 
         // if using alora, make sure it's only a single one requested and active
-        size_t alora_invocation_start = slot.prompt_tokens.size();
+        size_t alora_invocation_start = task.tokens.size();
         if (lora_all_alora(slot.lora)) {
-
             const auto & enabled_ids = lora_get_enabled_ids(slot.lora);
             // TODO: This will error out if a user requests two aloras, but only
             // provides the activation string for one. We could, instead search
@@ -2448,10 +2732,10 @@ struct server_context {
             // scan backwards through the prompt tokens to find the last
             // occurrence of the invocation sequence
             int match_idx = static_cast<int>(n_invocation_tokens) - 1;
-            for (int i = slot.prompt_tokens.size() - 1; i >= 0; --i) {
+            for (int i = task.tokens.size() - 1; i >= 0; --i) {
                 // the token in this position matches the next token to find in
                 // the invocation sequence
-                if (slot.prompt_tokens[i] == invocation_tokens[match_idx]) {
+                if (task.tokens[i] == invocation_tokens[match_idx]) {
                     // if it's a full match, we've found the start
                     if (match_idx == 0) {
                         alora_invocation_start = i;
@@ -2466,7 +2750,7 @@ struct server_context {
             }
 
             // if the activation string is not found, disable the alora
-            if (alora_invocation_start == slot.prompt_tokens.size()) {
+            if (alora_invocation_start == task.tokens.size()) {
                 SLT_DBG(slot, "alora %zu requested, but not found. deactivating\n", enabled_ids[0]);
                 slot.lora[enabled_ids[0]].scale = 0.0f;
             } else {
@@ -2475,24 +2759,20 @@ struct server_context {
             }
         }
 
-        if (!slot.prompt_tokens.validate(ctx)) {
+        if (!task.tokens.validate(ctx)) {
             send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST);
             return false;
         }
-        SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
 
-        if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
-            // Might be better to reject the request with a 400 ?
-            SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d\n", slot.params.n_predict, slot.n_predict);
-            slot.params.n_predict = slot.n_predict;
-        }
+        SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
 
+        // initialize samplers
         {
             if (slot.smpl != nullptr) {
                 common_sampler_free(slot.smpl);
             }
 
-            slot.smpl = common_sampler_init(model, slot.params.sampling);
+            slot.smpl = common_sampler_init(model, task.params.sampling);
             if (slot.smpl == nullptr) {
                 // for now, the only error that may happen here is invalid grammar
                 send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
@@ -2500,12 +2780,15 @@ struct server_context {
             }
         }
 
+        // initialize draft batch
         if (slot.ctx_dft) {
             llama_batch_free(slot.batch_spec);
 
-            slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);
+            slot.batch_spec = llama_batch_init(task.params.speculative.n_max + 1, 0, 1);
         }
 
+        slot.task = std::make_unique<const server_task>(std::move(task));
+
         slot.state = SLOT_STATE_STARTED;
 
         SLT_INF(slot, "%s", "processing task\n");
@@ -2527,7 +2810,7 @@ struct server_context {
         slot.sampled = result.tok;
 
         slot.generated_text += token_str;
-        if (slot.params.return_tokens) {
+        if (slot.task->params.return_tokens) {
             slot.generated_tokens.push_back(result.tok);
         }
         slot.has_next_token = true;
@@ -2564,7 +2847,7 @@ struct server_context {
             }
 
             slot.add_token(result);
-            if (slot.params.stream) {
+            if (slot.task->params.stream) {
                 send_partial_response(slot, result, false);
             }
         }
@@ -2586,12 +2869,12 @@ struct server_context {
             slot.stop           = STOP_TYPE_LIMIT;
             slot.has_next_token = false;
 
-            SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
+            SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.task->params.n_predict);
         }
 
         if (slot.has_new_line) {
             // require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent
-            if (slot.params.n_indent > 0) {
+            if (slot.task->params.n_indent > 0) {
                 // check the current indentation
                 // TODO: improve by not doing it more than once for each new line
                 if (slot.last_nl_pos > 0) {
@@ -2603,7 +2886,7 @@ struct server_context {
                         pos++;
                     }
 
-                    if (pos < slot.generated_text.size() && n_indent < slot.params.n_indent) {
+                    if (pos < slot.generated_text.size() && n_indent < slot.task->params.n_indent) {
                         slot.stop           = STOP_TYPE_LIMIT;
                         slot.has_next_token = false;
 
@@ -2630,11 +2913,11 @@ struct server_context {
             slot.has_new_line = true;
 
             // if we have seen a new line, we stop after a certain time limit, but only upon another new line
-            if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) {
+            if (slot.task->params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.task->params.t_max_predict_ms)) {
                 slot.stop           = STOP_TYPE_LIMIT;
                 slot.has_next_token = false;
 
-                SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms);
+                SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.task->params.t_max_predict_ms);
             }
         }
 
@@ -2645,7 +2928,7 @@ struct server_context {
             slot.has_next_token = false;
 
             SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n",
-                    slot.n_decoded, slot.n_prompt_tokens, slot.n_past, slot.n_ctx);
+                    slot.n_decoded, slot.n_prompt_tokens(), slot.n_past, slot.n_ctx);
         }
 
         if (llama_vocab_is_eog(vocab, result.tok)) {
@@ -2657,7 +2940,7 @@ struct server_context {
 
         const auto n_ctx_train = llama_model_n_ctx_train(model);
 
-        if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
+        if (slot.task->params.n_predict < 1 && slot.n_prompt_tokens() + slot.n_decoded >= n_ctx_train) {
             slot.truncated      = true;
             slot.stop           = STOP_TYPE_LIMIT;
             slot.has_next_token = false; // stop prediction
@@ -2665,7 +2948,7 @@ struct server_context {
             SLT_WRN(slot,
                     "n_predict (%d) is set for infinite generation. "
                     "Limiting generated tokens to n_ctx_train (%d) to avoid EOS-less generation infinite loop\n",
-                    slot.params.n_predict, n_ctx_train);
+                    slot.task->params.n_predict, n_ctx_train);
         }
 
         SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str());
@@ -2674,7 +2957,7 @@ struct server_context {
     }
 
     void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const {
-        size_t n_probs = slot.params.sampling.n_probs;
+        size_t n_probs = slot.task->params.sampling.n_probs;
         size_t n_vocab = llama_vocab_n_tokens(vocab);
 
         if (post_sampling) {
@@ -2728,7 +3011,7 @@ struct server_context {
     }
 
     void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
-        send_error(slot.id_task, error, type, slot.n_prompt_tokens, slot.n_ctx);
+        send_error(slot.task->id, error, type, slot.n_prompt_tokens(), slot.n_ctx);
     }
 
     void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0) {
@@ -2749,7 +3032,7 @@ struct server_context {
     }
 
     // if multimodal is enabled, send an error and return false
-    bool ensure_no_mtmd(const int id_task) {
+    bool check_no_mtmd(const int id_task) {
         if (mctx) {
             send_error(id_task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED);
             return false;
@@ -2760,14 +3043,14 @@ struct server_context {
     void send_partial_response(server_slot & slot, const completion_token_output & tkn, bool is_progress) {
         auto res = std::make_unique<server_task_result_cmpl_partial>();
 
-        res->id    = slot.id_task;
-        res->index = slot.index;
+        res->id    = slot.task->id;
+        res->index = slot.task->index;
 
         if (is_progress) {
             res->is_progress        = true;
-            res->progress.total     = slot.n_prompt_tokens;
+            res->progress.total     = slot.n_prompt_tokens();
             res->progress.cache     = slot.n_prompt_tokens_cache;
-            res->progress.processed = slot.cache_tokens.size();
+            res->progress.processed = slot.prompt.tokens.size();
             res->progress.time_ms   = (ggml_time_us() - slot.t_start_process_prompt / 1000);
         } else {
             res->content = tkn.text_to_send;
@@ -2777,21 +3060,21 @@ struct server_context {
         }
 
         res->n_decoded           = slot.n_decoded;
-        res->n_prompt_tokens     = slot.n_prompt_tokens;
-        res->post_sampling_probs = slot.params.post_sampling_probs;
+        res->n_prompt_tokens     = slot.n_prompt_tokens();
+        res->post_sampling_probs = slot.task->params.post_sampling_probs;
 
-        res->verbose               = slot.params.verbose;
-        res->oaicompat             = slot.params.oaicompat;
-        res->oaicompat_model       = slot.params.oaicompat_model;
-        res->oaicompat_cmpl_id     = slot.params.oaicompat_cmpl_id;
+        res->verbose           = slot.task->params.verbose;
+        res->oaicompat         = slot.task->params.oaicompat;
+        res->oaicompat_model   = slot.task->params.oaicompat_model;
+        res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id;
 
         // populate res.probs_output
-        if (slot.params.sampling.n_probs > 0) {
+        if (slot.task->params.sampling.n_probs > 0) {
             res->prob_output = tkn; // copy the token probs
         }
 
         // populate timings if this is final response or timings_per_token is enabled
-        if (slot.stop != STOP_TYPE_NONE || slot.params.timings_per_token) {
+        if (slot.stop != STOP_TYPE_NONE || slot.task->params.timings_per_token) {
             res->timings = slot.get_timings();
         }
 
@@ -2800,36 +3083,37 @@ struct server_context {
 
     void send_final_response(server_slot & slot) {
         auto res = std::make_unique<server_task_result_cmpl_final>();
-        res->id              = slot.id_task;
-        res->id_slot         = slot.id;
 
-        res->index           = slot.index;
+        res->id      = slot.task->id;
+        res->id_slot = slot.id;
+
+        res->index           = slot.task->index;
         res->content         = slot.generated_text;
         res->tokens          = std::move(slot.generated_tokens);
         res->timings         = slot.get_timings();
-        res->prompt          = slot.prompt_tokens.detokenize(ctx, true);
-        res->response_fields = std::move(slot.params.response_fields);
+        res->prompt          = slot.task->tokens.detokenize(ctx, true);
+        res->response_fields = std::move(slot.task->params.response_fields);
 
         res->truncated           = slot.truncated;
         res->n_decoded           = slot.n_decoded;
-        res->n_prompt_tokens     = slot.n_prompt_tokens;
+        res->n_prompt_tokens     = slot.n_prompt_tokens();
         res->n_tokens_cached     = slot.n_past;
         res->has_new_line        = slot.has_new_line;
         res->stopping_word       = slot.stopping_word;
         res->stop                = slot.stop;
-        res->post_sampling_probs = slot.params.post_sampling_probs;
+        res->post_sampling_probs = slot.task->params.post_sampling_probs;
 
-        res->verbose               = slot.params.verbose;
-        res->stream                = slot.params.stream;
-        res->include_usage         = slot.params.include_usage;
-        res->oaicompat             = slot.params.oaicompat;
-        res->oaicompat_model       = slot.params.oaicompat_model;
-        res->oaicompat_cmpl_id     = slot.params.oaicompat_cmpl_id;
-        res->oaicompat_msg         = slot.update_chat_msg(res->oaicompat_msg_diffs);
+        res->verbose           = slot.task->params.verbose;
+        res->stream            = slot.task->params.stream;
+        res->include_usage     = slot.task->params.include_usage;
+        res->oaicompat         = slot.task->params.oaicompat;
+        res->oaicompat_model   = slot.task->params.oaicompat_model;
+        res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id;
+        res->oaicompat_msg     = slot.update_chat_msg(res->oaicompat_msg_diffs);
 
         // populate res.probs_output
-        if (slot.params.sampling.n_probs > 0) {
-            if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) {
+        if (slot.task->params.sampling.n_probs > 0) {
+            if (!slot.task->params.stream && slot.stop == STOP_TYPE_WORD) {
                 const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);
 
                 size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
@@ -2843,17 +3127,17 @@ struct server_context {
             }
         }
 
-        res->generation_params = slot.params; // copy the parameters
+        res->generation_params = slot.task->params; // copy the parameters
 
         queue_results.send(std::move(res));
     }
 
     void send_embedding(const server_slot & slot, const llama_batch & batch) {
         auto res = std::make_unique<server_task_result_embd>();
-        res->id        = slot.id_task;
-        res->index     = slot.index;
-        res->n_tokens  = slot.n_prompt_tokens;
-        res->oaicompat = slot.params.oaicompat;
+        res->id        = slot.task->id;
+        res->index     = slot.task->index;
+        res->n_tokens  = slot.n_prompt_tokens();
+        res->oaicompat = slot.task->params.oaicompat;
 
         const int n_embd = llama_model_n_embd(model);
 
@@ -2880,12 +3164,12 @@ struct server_context {
 
             // normalize only when there is pooling
             if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
-                common_embd_normalize(embd, embd_res.data(), n_embd, slot.params.embd_normalize);
+                common_embd_normalize(embd, embd_res.data(), n_embd, slot.task->params.embd_normalize);
                 res->embedding.push_back(embd_res);
                 break;
-            } else {
-                res->embedding.emplace_back(embd, embd + n_embd);
             }
+
+            res->embedding.emplace_back(embd, embd + n_embd);
         }
 
         SLT_DBG(slot, "%s", "sending embeddings\n");
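
Below is a brief, standalone sketch of what the normalization step amounts to when the Euclidean (p = 2) mode is selected; the helper name embd_normalize_l2 is hypothetical, and the real code delegates to common_embd_normalize with slot.task->params.embd_normalize:

    #include <cmath>

    // illustrative L2 normalization of one embedding vector of length n
    static void embd_normalize_l2(const float * inp, float * out, int n) {
        double sum = 0.0;
        for (int i = 0; i < n; i++) {
            sum += (double) inp[i] * inp[i];
        }
        const double norm = sum > 0.0 ? std::sqrt(sum) : 1.0;
        for (int i = 0; i < n; i++) {
            out[i] = (float) (inp[i] / norm);
        }
    }

When the pooling type is NONE, the hunk above skips this step and returns the raw per-token embeddings instead.
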
@@ -2895,9 +3179,9 @@ struct server_context {
 
     void send_rerank(const server_slot & slot, const llama_batch & batch) {
         auto res = std::make_unique<server_task_result_rerank>();
-        res->id    = slot.id_task;
-        res->index = slot.index;
-        res->n_tokens = slot.n_prompt_tokens;
+        res->id       = slot.task->id;
+        res->index    = slot.task->index;
+        res->n_tokens = slot.n_prompt_tokens();
 
         for (int i = 0; i < batch.n_tokens; ++i) {
             if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
@@ -3034,7 +3318,7 @@ struct server_context {
             case SERVER_TASK_TYPE_EMBEDDING:
             case SERVER_TASK_TYPE_RERANK:
                 {
-                    const int id_slot = task.id_selected_slot;
+                    const int id_slot = task.id_slot;
 
                     server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task);
 
@@ -3061,7 +3345,7 @@ struct server_context {
                 {
                     // release slot linked with the task id
                     for (auto & slot : slots) {
-                        if (slot.id_task == task.id_target) {
+                        if (slot.task && slot.task->id == task.id_target) {
                             slot.release();
                             break;
                         }
@@ -3079,7 +3363,7 @@ struct server_context {
                     int n_processing_slots = 0;
 
                     for (server_slot & slot : slots) {
-                        json slot_data = slot.to_json(true);
+                        json slot_data = slot.to_json(slots_debug == 0);
 
                         if (slot.is_processing()) {
                             n_processing_slots++;
@@ -3121,7 +3405,7 @@ struct server_context {
                 } break;
             case SERVER_TASK_TYPE_SLOT_SAVE:
                 {
-                    if (!ensure_no_mtmd(task.id)) {
+                    if (!check_no_mtmd(task.id)) {
                         break;
                     }
 
@@ -3138,13 +3422,13 @@ struct server_context {
                         break;
                     }
 
-                    const size_t token_count = slot->cache_tokens.size();
+                    const size_t token_count = slot->prompt.tokens.size();
                     const int64_t t_start = ggml_time_us();
 
                     std::string filename = task.slot_action.filename;
                     std::string filepath = task.slot_action.filepath;
 
-                    const llama_tokens & tokens = slot->cache_tokens.get_text_tokens();
+                    const llama_tokens & tokens = slot->prompt.tokens.get_text_tokens();
                     const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count);
 
                     const int64_t t_end = ggml_time_us();
@@ -3162,7 +3446,7 @@ struct server_context {
                 } break;
             case SERVER_TASK_TYPE_SLOT_RESTORE:
                 {
-                    if (!ensure_no_mtmd(task.id)) break;
+                    if (!check_no_mtmd(task.id)) break;
                     int id_slot = task.slot_action.slot_id;
                     server_slot * slot = get_slot_by_id(id_slot);
                     if (slot == nullptr) {
@@ -3186,13 +3470,13 @@ struct server_context {
                     size_t token_count = 0;
                     size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count);
                     if (nread == 0) {
-                        slot->cache_tokens.clear(); // KV may already been invalidated?
+                        slot->prompt.tokens.clear(); // KV may have already been invalidated?
                         send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
                         break;
                     }
                     tokens.resize(token_count);
-                    slot->cache_tokens.clear();
-                    slot->cache_tokens.insert(tokens);
+                    slot->prompt.tokens.clear();
+                    slot->prompt.tokens.insert(tokens);
 
                     const int64_t t_end = ggml_time_us();
                     const double t_restore_ms = (t_end - t_start) / 1000.0;
@@ -3209,7 +3493,9 @@ struct server_context {
                 } break;
             case SERVER_TASK_TYPE_SLOT_ERASE:
                 {
-                    if (!ensure_no_mtmd(task.id)) break;
+                    if (!check_no_mtmd(task.id)) {
+                        break;
+                    }
                     int id_slot = task.slot_action.slot_id;
                     server_slot * slot = get_slot_by_id(id_slot);
                     if (slot == nullptr) {
@@ -3224,9 +3510,9 @@ struct server_context {
                     }
 
                     // Erase token cache
-                    const size_t n_erased = slot->cache_tokens.size();
+                    const size_t n_erased = slot->prompt.tokens.size();
                     llama_memory_seq_rm(llama_get_memory(ctx), slot->id, -1, -1);
-                    slot->cache_tokens.clear();
+                    slot->prompt.tokens.clear();
 
                     auto res = std::make_unique<server_task_result_slot_erase>();
                     res->id       = task.id;
@@ -3282,8 +3568,8 @@ struct server_context {
                 if (!params_base.ctx_shift) {
                     // this check is redundant (kept for good measure)
                     // we should never get here, because generation should have already stopped in process_token()
-                    slot.release();
                     send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
+                    slot.release();
                     continue;
                 }
 
@@ -3294,9 +3580,16 @@ struct server_context {
                 }
 
                 // Shift context
-                const int n_keep    = slot.params.n_keep + add_bos_token;
+                int n_keep = slot.task->params.n_keep < 0 ? slot.n_prompt_tokens() : slot.task->params.n_keep;
+
+                if (add_bos_token) {
+                    n_keep += 1;
+                }
+
+                n_keep = std::min(slot.n_ctx - 4, n_keep);
+
                 const int n_left    = slot.n_past - n_keep;
-                const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
+                const int n_discard = slot.task->params.n_discard ? slot.task->params.n_discard : (n_left / 2);
 
                 SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
 
@@ -3305,14 +3598,14 @@ struct server_context {
 
                 // add generated tokens to cache
                 {
-                    llama_tokens new_tokens = slot.cache_tokens.get_text_tokens(); // copy
+                    llama_tokens new_tokens = slot.prompt.tokens.get_text_tokens(); // copy
                     for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
                         new_tokens[i - n_discard] = new_tokens[i];
                     }
 
-                    new_tokens.resize(slot.cache_tokens.size() - n_discard);
-                    slot.cache_tokens.clear();
-                    slot.cache_tokens.insert(new_tokens);
+                    new_tokens.resize(slot.prompt.tokens.size() - n_discard);
+                    slot.prompt.tokens.clear();
+                    slot.prompt.tokens.insert(new_tokens);
                 }
 
                 slot.n_past -= n_discard;
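
The rewritten hunk folds the old n_keep bookkeeping into the shift itself. The arithmetic is easiest to follow in isolation; the sketch below uses hypothetical values and plain local variables instead of the slot/task fields:

    #include <algorithm>
    #include <cstdio>

    int main() {
        // hypothetical slot state
        const int  n_ctx            = 4096;
        const int  n_prompt_tokens  = 300;   // slot.n_prompt_tokens()
        const int  n_past           = 4000;  // positions currently in the KV cache
        const int  params_n_keep    = -1;    // -1 => keep the entire prompt
        const int  params_n_discard = 0;     // 0  => discard half of the shiftable region
        const bool add_bos_token    = true;

        int n_keep = params_n_keep < 0 ? n_prompt_tokens : params_n_keep;
        if (add_bos_token) {
            n_keep += 1;
        }
        n_keep = std::min(n_ctx - 4, n_keep);

        const int n_left    = n_past - n_keep;
        const int n_discard = params_n_discard ? params_n_discard : (n_left / 2);

        // prints: n_keep = 301, n_left = 3699, n_discard = 1849
        printf("n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
        return 0;
    }
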
@@ -3328,7 +3621,8 @@ struct server_context {
         server_slot * slot_batched = nullptr;
 
         auto accept_special_token = [&](server_slot & slot, llama_token token) {
-            return params_base.special || slot.params.sampling.preserved_tokens.find(token) != slot.params.sampling.preserved_tokens.end();
+            return params_base.special ||
+                slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end();
         };
 
         // first, add sampled tokens from any ongoing sequences
@@ -3349,10 +3643,10 @@ struct server_context {
             common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
 
             slot.n_past += 1;
-            slot.cache_tokens.push_back(slot.sampled);
+            slot.prompt.tokens.push_back(slot.sampled);
 
             SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n",
-                    slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated);
+                    slot.n_ctx, slot.n_past, (int) slot.prompt.tokens.size(), slot.truncated);
         }
 
         // process in chunks of params.n_batch
@@ -3375,7 +3669,7 @@ struct server_context {
 
                 // this slot still has a prompt to be processed
                 if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
-                    auto & prompt_tokens = slot.prompt_tokens;
+                    const auto & input_tokens = slot.task->tokens;
 
                     // TODO: maybe move branch to outside of this loop in the future
                     if (slot.state == SLOT_STATE_STARTED) {
@@ -3383,104 +3677,64 @@ struct server_context {
                         slot.t_start_generation = 0;
 
                         slot.n_past = 0;
-                        slot.n_prompt_tokens = prompt_tokens.size();
                         slot.state = SLOT_STATE_PROCESSING_PROMPT;
 
-                        SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
+                        SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n",
+                                slot.n_ctx, slot.task->params.n_keep, slot.n_prompt_tokens());
 
                         // print prompt tokens (for debugging)
                         /*if (1) {
                             // first 16 tokens (avoid flooding logs)
-                            for (int i = 0; i < std::min<int>(16, prompt_tokens.size()); i++) {
-                                SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+                            for (int i = 0; i < std::min<int>(16, input_tokens.size()); i++) {
+                                SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str());
                             }
                         } else {
                             // all
-                            for (int i = 0; i < (int) prompt_tokens.size(); i++) {
-                                SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+                            for (int i = 0; i < (int) input_tokens.size(); i++) {
+                                SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str());
                             }
                         }*/
 
                         // empty prompt passed -> release the slot and send empty response
-                        if (prompt_tokens.empty()) {
+                        if (input_tokens.empty()) {
                             SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");
 
-                            slot.release();
                             slot.print_timings();
                             send_final_response(slot);
+                            slot.release();
+
                             continue;
                         }
 
                         // TODO: support memory-less logits computation
                         if (slot.need_logits() && !llama_get_memory(ctx)) {
-                            slot.release();
                             send_error(slot, "the current context does not logits computation. skipping", ERROR_TYPE_SERVER);
+                            slot.release();
                             continue;
                         }
 
                         if (!slot.can_split()) {
-                            if (slot.n_prompt_tokens > n_ubatch) {
-                                slot.release();
+                            if (slot.n_prompt_tokens() > n_ubatch) {
                                 send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
+                                slot.release();
                                 continue;
                             }
 
-                            if (slot.n_prompt_tokens > slot.n_ctx) {
-                                slot.release();
+                            if (slot.n_prompt_tokens() > slot.n_ctx) {
                                 send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
+                                slot.release();
                                 continue;
                             }
                         } else {
-                            if (!params_base.ctx_shift) {
-                                // if context shift is disabled, we make sure prompt size is smaller than KV size
-                                // TODO: there should be a separate parameter that control prompt truncation
-                                //       context shift should be applied only during the generation phase
-                                if (slot.n_prompt_tokens >= slot.n_ctx) {
-                                    slot.release();
-                                    send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
-                                    continue;
-                                }
-                            }
-                            if (slot.params.n_keep < 0) {
-                                slot.params.n_keep = slot.n_prompt_tokens;
-                            }
-                            slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
-
-                            // if input prompt is too big, truncate it
-                            if (slot.n_prompt_tokens >= slot.n_ctx) {
-                                if (mctx) {
-                                    // we should never reach this
-                                    GGML_ABORT("not supported by multimodal");
-                                }
-                                const int n_left = slot.n_ctx - slot.params.n_keep;
-
-                                const int n_block_size = n_left / 2;
-                                const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
-
-                                const llama_tokens & curr_tokens = slot.prompt_tokens.get_text_tokens();
-                                llama_tokens new_tokens(
-                                        curr_tokens.begin(),
-                                        curr_tokens.begin() + slot.params.n_keep);
-
-                                new_tokens.insert(
-                                        new_tokens.end(),
-                                        curr_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
-                                        curr_tokens.end());
-
-                                prompt_tokens.clear();
-                                prompt_tokens.insert(new_tokens);
-
-                                slot.truncated = true;
-                                slot.n_prompt_tokens = prompt_tokens.size();
-
-                                SLT_WRN(slot, "input truncated, n_ctx = %d, n_keep = %d, n_left = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, n_left, slot.n_prompt_tokens);
-
-                                GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
+                            if (slot.n_prompt_tokens() >= slot.n_ctx) {
+                                send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
+                                slot.release();
+                                continue;
                             }
 
-                            if (slot.params.cache_prompt) {
+                            if (slot.task->params.cache_prompt) {
                                 // reuse any previously computed tokens that are common with the new prompt
-                                slot.n_past = slot.cache_tokens.get_common_prefix(prompt_tokens);
+                                slot.n_past = slot.prompt.tokens.get_common_prefix(input_tokens);
 
                                 // if there is an alora invoked, don't cache after the invocation start
                                 if (slot.alora_invocation_start >= 0) {
@@ -3500,13 +3754,13 @@ struct server_context {
 
                                     SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past);
 
-                                    while (head_c < slot.cache_tokens.size() &&
-                                           head_p < prompt_tokens.size()) {
+                                    while (head_c < slot.prompt.tokens.size() &&
+                                           head_p < input_tokens.size()) {
 
                                         size_t n_match = 0;
-                                        while (head_c + n_match < slot.cache_tokens.size() &&
-                                               head_p + n_match < prompt_tokens.size()     &&
-                                               slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
+                                        while (head_c + n_match < slot.prompt.tokens.size() &&
+                                               head_p + n_match < input_tokens.size()     &&
+                                               slot.prompt.tokens[head_c + n_match] == input_tokens[head_p + n_match]) {
 
                                             n_match++;
                                         }
@@ -3523,7 +3777,7 @@ struct server_context {
                                             llama_memory_seq_add(llama_get_memory(ctx), slot.id, head_c, head_c + n_match, kv_shift);
 
                                             for (size_t i = 0; i < n_match; i++) {
-                                                slot.cache_tokens.set_token(head_p + i, slot.cache_tokens[head_c + i]);
+                                                slot.prompt.tokens.set_token(head_p + i, slot.prompt.tokens[head_c + i]);
                                                 slot.n_past++;
                                             }
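
With slot.task->params.cache_prompt enabled, the number of tokens that can be skipped is the length of the common prefix between the previously cached prompt and the new one. A hedged sketch over plain token vectors (the real get_common_prefix is a method of the server's token container and also handles multimodal chunks):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    using llama_token = int32_t; // matches the llama.cpp typedef

    // hypothetical free-function counterpart of the container method used above
    static size_t common_prefix_len(const std::vector<llama_token> & cached,
                                    const std::vector<llama_token> & prompt) {
        size_t n = 0;
        while (n < cached.size() && n < prompt.size() && cached[n] == prompt[n]) {
            n++;
        }
        return n;
    }

Only the suffix past this prefix, plus whatever the n_cache_reuse chunk shifting above can recover, has to be re-evaluated.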
 
@@ -3547,41 +3801,83 @@ struct server_context {
                             // the largest pos_min required for a checkpoint to be useful
                             const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
 
-                            if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) {
+                            if (slot.n_past > 0 && slot.n_past < (int) slot.prompt.tokens.size()) {
                                 const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
                                 if (pos_min == -1) {
-                                    SLT_ERR(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min);
+                                    SLT_ERR(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);
                                     GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
                                 }
 
+                                // when the prompt prefix does not match, print the tokens around the mismatch
+                                // this is useful for debugging prompt caching
+                                {
+                                    const int np0 = std::max<int>(slot.n_past - 4, 0);
+                                    const int np1 = std::min<int>(slot.n_past + 6, std::min(slot.prompt.tokens.size(), slot.task->tokens.size()));
+
+                                    std::stringstream ss0;
+                                    std::stringstream ss1;
+
+                                    std::stringstream st0;
+                                    std::stringstream st1;
+
+                                    ss0 << "old: ... ";
+                                    ss1 << "new: ... ";
+
+                                    for (int i = np0; i < np1; i++) {
+                                        if (i == slot.n_past) {
+                                            ss0 << " | ";
+                                            ss1 << " | ";
+                                        }
+
+                                        {
+                                            const auto token = slot.prompt.tokens[i];
+                                            const auto piece = common_token_to_piece(ctx, token);
+                                            ss0 << piece;
+                                            st0 << std::setw(8) << token;
+                                        }
+
+                                        {
+                                            const auto token = slot.task->tokens[i];
+                                            const auto piece = common_token_to_piece(ctx, token);
+                                            ss1 << piece;
+                                            st1 << std::setw(8) << token;
+                                        }
+                                    }
+
+                                    SLT_WRN(slot, "%s\n", ss0.str().c_str());
+                                    SLT_WRN(slot, "%s\n", ss1.str().c_str());
+
+                                    SLT_WRN(slot, "%s\n", st0.str().c_str());
+                                    SLT_WRN(slot, "%s\n", st1.str().c_str());
+                                }
+
                                 if (pos_min > pos_min_thold) {
-                                    SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);
+                                    SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);
 
                                     // search for a context checkpoint
                                     const auto it = std::find_if(
-                                        slot.ctx_checkpoints.rbegin(),
-                                        slot.ctx_checkpoints.rend(),
+                                        slot.prompt.checkpoints.rbegin(),
+                                        slot.prompt.checkpoints.rend(),
                                         [&](const auto & cur) {
                                             // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
                                             return cur.pos_min < pos_min_thold;
                                         }
                                     );
 
-                                    bool do_reset = it == slot.ctx_checkpoints.rend();
-                                    //printf("[DEBUG] `do_reset` was set to `%s`\n", do_reset ? "true" : "false");
+                                    bool do_reset = it == slot.prompt.checkpoints.rend();
 
                                     if (!do_reset) {
                                         // restore the context checkpoint
-                                        const size_t ctx_checkpoint_size = it->data.size();
-                                        const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), ctx_checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+                                        const size_t checkpoint_size = it->data.size();
+                                        const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
 
-                                        if (n != ctx_checkpoint_size) {
-                                            SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024);
+                                        if (n != checkpoint_size) {
+                                            SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
                                             do_reset = true;
                                             //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
                                         } else {
                                             slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max));
-                                            SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024);
+                                            SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
                                         }
                                     }
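
The checkpoint lookup scans from newest to oldest and takes the first entry whose pos_min lies below the threshold, which guarantees that at least one prompt token is still evaluated after restoring. A simplified sketch with a hypothetical checkpoint type mirroring server_prompt_checkpoint:

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    struct checkpoint { // hypothetical stand-in for server_prompt_checkpoint
        int pos_min;
        int pos_max;
        std::vector<uint8_t> data;
    };

    static const checkpoint * find_usable_checkpoint(const std::vector<checkpoint> & cps, int pos_min_thold) {
        // reverse search: prefer the most recent checkpoint that is still usable
        auto it = std::find_if(cps.rbegin(), cps.rend(), [&](const checkpoint & cur) {
            return cur.pos_min < pos_min_thold;
        });
        return it == cps.rend() ? nullptr : &*it;
    }

If no usable checkpoint is found, the do_reset path above discards the sequence state and reprocesses the prompt instead.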
 
@@ -3595,19 +3891,21 @@ struct server_context {
 
                             {
                                 // erase any checkpoints with pos_min > pos_min_thold
-                                for (int i = (int) slot.ctx_checkpoints.size() - 1; i >= 0; i--) {
-                                    const auto & cur = slot.ctx_checkpoints[i];
+                                for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) {
+                                    const auto & cur = *it;
                                     if (cur.pos_min > pos_min_thold) {
                                         SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
-                                        slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin() + i);
+                                        it = slot.prompt.checkpoints.erase(it);
+                                    } else {
+                                        ++it;
                                     }
                                 }
                             }
                         }
 
                         // [TAG_PROMPT_LOGITS]
-                        if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
-                            SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, n_prompt_tokens = %d)\n", slot.n_past, slot.n_prompt_tokens);
+                        if (slot.n_past == slot.n_prompt_tokens() && slot.n_past > 0) {
+                            SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, n_prompt_tokens = %d)\n", slot.n_past, slot.n_prompt_tokens());
                             slot.n_past--;
                             SLT_WRN(slot, "n_past was set to %d\n", slot.n_past);
                         }
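
The [TAG_PROMPT_LOGITS] adjustment steps n_past back by one when the whole prompt is already cached, so that the upcoming decode still produces logits to sample from. In isolation the rule is just:

    // minimal illustration of the adjustment above (standalone helper, not the server API)
    static int adjust_n_past(int n_past, int n_prompt_tokens) {
        if (n_past == n_prompt_tokens && n_past > 0) {
            n_past--; // re-evaluate the last prompt token to obtain its logits
        }
        return n_past;
    }
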
@@ -3618,7 +3916,7 @@ struct server_context {
 
                     if (!slot.can_split()) {
                         // cannot fit the prompt in the current batch - will try next iter
-                        if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
+                        if (batch.n_tokens + slot.n_prompt_tokens() > n_batch) {
                             continue;
                         }
                     }
@@ -3636,28 +3934,28 @@ struct server_context {
                     SLT_INF(slot, "n_past = %d, memory_seq_rm [%d, end)\n", slot.n_past, slot.n_past);
 
                     // remove the non-common part from the cache
-                    slot.cache_tokens.keep_first(slot.n_past);
+                    slot.prompt.tokens.keep_first(slot.n_past);
 
                     // check if we should process the image
-                    if (slot.n_past < slot.n_prompt_tokens && slot.prompt_tokens[slot.n_past] == LLAMA_TOKEN_NULL) {
+                    if (slot.n_past < slot.n_prompt_tokens() && input_tokens[slot.n_past] == LLAMA_TOKEN_NULL) {
                         // process the image
                         int32_t new_n_past;
-                        int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past);
-                        int32_t n_pos = new_n_past - slot.n_past;
-
+                        int32_t res = input_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past);
                         if (res != 0) {
                             SLT_ERR(slot, "failed to process image, res = %d\n", res);
-                            slot.release();
                             send_error(slot, "failed to process image", ERROR_TYPE_SERVER);
+                            slot.release();
                             continue;
                         }
 
                         // add the image chunk to cache
                         {
-                            const auto & chunk = slot.prompt_tokens.find_chunk(slot.n_past);
-                            slot.cache_tokens.push_back(chunk.get()); // copy
+                            const auto & chunk = input_tokens.find_chunk(slot.n_past);
+                            slot.prompt.tokens.push_back(chunk.get()); // copy
                         }
 
+                        const int32_t n_pos = new_n_past - slot.n_past;
+
                         slot.n_past                    += n_pos;
                         slot.n_prompt_tokens_processed += n_pos;
                     }
@@ -3678,6 +3976,9 @@ struct server_context {
 
                     bool do_checkpoint = params_base.n_ctx_checkpoints > 0;
 
+                    // make checkpoints only for completion tasks
+                    do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION;
+
                     // make a checkpoint of the parts of the memory that cannot be rolled back.
                     // checkpoints are created only if:
                     // - the model uses SWA and we are not using `swa_full`
@@ -3691,9 +3992,9 @@ struct server_context {
                             );
 
                     // add prompt tokens for processing in the current batch
-                    while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
+                    while (slot.n_past < slot.n_prompt_tokens() && batch.n_tokens < n_batch) {
                         // get next token to process
-                        llama_token cur_tok = slot.prompt_tokens[slot.n_past];
+                        llama_token cur_tok = input_tokens[slot.n_past];
                         if (cur_tok == LLAMA_TOKEN_NULL) {
                             break; // end of text chunk
                         }
@@ -3707,36 +4008,33 @@ struct server_context {
                         }
 
                         // embedding requires all tokens in the batch to be output
-                        const bool need_embd = server_task_type_need_embd(slot.task_type);
-
-                        common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd);
-                        slot.cache_tokens.push_back(cur_tok);
+                        common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, slot.need_embd());
+                        slot.prompt.tokens.push_back(cur_tok);
 
                         slot.n_prompt_tokens_processed++;
                         slot.n_past++;
 
                         // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created.
-                        if (do_checkpoint && slot.n_prompt_tokens - slot.n_past == 64) {
+                        if (do_checkpoint && slot.n_prompt_tokens() - slot.n_past == 64) {
                             break;
                         }
                     }
 
                     // SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str());
 
-                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
+                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_past / slot.n_prompt_tokens());
 
                     // entire prompt has been processed
-                    if (slot.n_past == slot.n_prompt_tokens) {
+                    if (slot.n_past == slot.n_prompt_tokens()) {
                         slot.state = SLOT_STATE_DONE_PROMPT;
 
                         GGML_ASSERT(batch.n_tokens > 0);
-                        GGML_ASSERT((size_t) slot.n_prompt_tokens == slot.prompt_tokens.size());
 
                         common_sampler_reset(slot.smpl);
 
                         // Process all prompt tokens through sampler system
-                        for (int i = 0; i < slot.n_prompt_tokens; ++i) {
-                            llama_token id = slot.prompt_tokens[i];
+                        for (int i = 0; i < slot.n_prompt_tokens(); ++i) {
+                            llama_token id = input_tokens[i];
                             if (id != LLAMA_TOKEN_NULL) {
                                 common_sampler_accept(slot.smpl, id, false);
                             }
@@ -3757,21 +4055,22 @@ struct server_context {
                         do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);
 
                         // no need to create checkpoints that are too close together
-                        do_checkpoint = do_checkpoint && (slot.ctx_checkpoints.empty() || pos_max > slot.ctx_checkpoints.back().pos_max + 64);
+                        do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64);
 
                         if (do_checkpoint) {
-                            while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
+                            while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
                                 // make room for the new checkpoint, if needed
-                                const auto & cur = slot.ctx_checkpoints.front();
+                                const auto & cur = slot.prompt.checkpoints.front();
+
                                 SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
                                         cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
 
-                                slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
+                                slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
                             }
 
                             const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
 
-                            auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
+                            auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{
                                 /*.pos_min = */ pos_min,
                                 /*.pos_max = */ pos_max,
                                 /*.data    = */ std::vector<uint8_t>(checkpoint_size),
@@ -3779,8 +4078,8 @@ struct server_context {
 
                             llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
 
-                            SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
-                                    (int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+                            SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+                                    (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
                         }
                     }
                 }
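
Checkpoint creation keeps at most params_base.n_ctx_checkpoints entries per slot and evicts the oldest first. A hedged sketch of that policy, reusing the server_prompt_checkpoint name from the hunk above but otherwise standalone:

    #include <cstddef>
    #include <cstdint>
    #include <utility>
    #include <vector>

    struct server_prompt_checkpoint { // simplified: the real struct lives in the server sources
        int pos_min;
        int pos_max;
        std::vector<uint8_t> data;
    };

    static void push_checkpoint(std::vector<server_prompt_checkpoint> & checkpoints,
                                server_prompt_checkpoint cp, size_t n_max) {
        // drop the oldest checkpoints until there is room for the new one
        // (the empty() guard is defensive; the server only checkpoints when the limit is > 0)
        while (!checkpoints.empty() && checkpoints.size() >= n_max) {
            checkpoints.erase(checkpoints.begin());
        }
        checkpoints.push_back(std::move(cp));
    }
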
@@ -3854,8 +4153,8 @@ struct server_context {
                     if (!err.empty()) {
                         SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
                         for (auto & slot : slots) {
-                            slot.release();
                             send_error(slot, err);
+                            slot.release();
                         }
                         break;
                     }
@@ -3878,7 +4177,7 @@ struct server_context {
             for (auto & slot : slots) {
                 // optionally send prompt processing progress
                 if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) {
-                    if (slot.params.stream && slot.params.return_progress) {
+                    if (slot.task->params.stream && slot.task->params.return_progress) {
                         send_partial_response(slot, {}, true);
                     }
                 }
@@ -3888,7 +4187,7 @@ struct server_context {
                 }
 
                 if (slot.state == SLOT_STATE_DONE_PROMPT) {
-                    if (slot.task_type == SERVER_TASK_TYPE_EMBEDDING) {
+                    if (slot.task->type == SERVER_TASK_TYPE_EMBEDDING) {
                         // prompt evaluated for embedding
                         send_embedding(slot, batch_view);
                         slot.release();
@@ -3896,7 +4195,7 @@ struct server_context {
                         continue; // continue loop of slots
                     }
 
-                    if (slot.task_type == SERVER_TASK_TYPE_RERANK) {
+                    if (slot.task->type == SERVER_TASK_TYPE_RERANK) {
                         send_rerank(slot, batch_view);
                         slot.release();
                         slot.i_batch = -1;
@@ -3934,16 +4233,17 @@ struct server_context {
                 result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
                 result.prob         = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
 
-                if (slot.params.sampling.n_probs > 0) {
-                    populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx);
+                if (slot.task->params.sampling.n_probs > 0) {
+                    populate_token_probs(slot, result, slot.task->params.post_sampling_probs, params_base.special, tok_idx);
                 }
 
                 if (!process_token(result, slot)) {
                     // release slot because of stop condition
-                    slot.release();
                     slot.print_timings();
                     send_final_response(slot);
                     metrics.on_prediction(slot);
+                    slot.release();
+
                     continue;
                 }
             }
@@ -3964,7 +4264,7 @@ struct server_context {
                 }
 
                 // determine the max draft that fits the current slot state
-                int n_draft_max = slot.params.speculative.n_max;
+                int n_draft_max = slot.task->params.speculative.n_max;
 
                 // note: n_past is not yet increased for the `id` token sampled above
                 //       also, need to leave space for 1 extra token to allow context shifts
@@ -3976,8 +4276,8 @@ struct server_context {
 
                 SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
 
-                if (n_draft_max < slot.params.speculative.n_min) {
-                    SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min);
+                if (n_draft_max < slot.task->params.speculative.n_min) {
+                    SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.task->params.speculative.n_min);
 
                     continue;
                 }
@@ -3985,16 +4285,16 @@ struct server_context {
                 llama_token id = slot.sampled;
 
                 struct common_speculative_params params_spec;
-                params_spec.n_draft   = n_draft_max;
-                params_spec.n_reuse   = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
-                params_spec.p_min     = slot.params.speculative.p_min;
+                params_spec.n_draft = n_draft_max;
+                params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max;
+                params_spec.p_min   = slot.task->params.speculative.p_min;
 
-                const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens();
+                const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
                 llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
 
                 // ignore small drafts
-                if (slot.params.speculative.n_min > (int) draft.size()) {
-                    SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
+                if (slot.task->params.speculative.n_min > (int) draft.size()) {
+                    SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min);
 
                     continue;
                 }
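
Speculative decoding is skipped both when the remaining budget cannot fit a useful draft and when the generated draft itself is shorter than n_min. The two guards condensed into a hypothetical helper:

    #include <cstddef>

    // returns true only if the draft budget and the produced draft both meet the minimum size
    static bool should_use_draft(int n_draft_max, int n_min, size_t draft_size) {
        if (n_draft_max < n_min) {
            return false; // the max possible draft is too small
        }
        if ((int) draft_size < n_min) {
            return false; // ignore small drafts
        }
        return true;
    }
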
@@ -4023,8 +4323,8 @@ struct server_context {
                 // update how many tokens out of those tested were accepted
                 slot.n_draft_accepted += ids.size() - 1;
 
-                slot.cache_tokens.push_back(id);
-                slot.cache_tokens.insert({ids.begin(), ids.end() - 1});
+                slot.prompt.tokens.push_back(id);
+                slot.prompt.tokens.insert({ids.begin(), ids.end() - 1});
 
                 llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1);
 
@@ -4038,11 +4338,11 @@ struct server_context {
                     // TODO: set result.probs
 
                     if (!process_token(result, slot)) {
-                        // release slot because of stop condition
-                        slot.release();
                         slot.print_timings();
                         send_final_response(slot);
                         metrics.on_prediction(slot);
+                        slot.release();
+
                         break;
                     }
                 }
@@ -4310,18 +4610,18 @@ int main(int argc, char ** argv) {
         }
 
         // TODO: get rid of this dynamic_cast
-        auto res_metrics = dynamic_cast<server_task_result_metrics*>(result.get());
-        GGML_ASSERT(res_metrics != nullptr);
+        auto res_task = dynamic_cast<server_task_result_metrics*>(result.get());
+        GGML_ASSERT(res_task != nullptr);
 
         // optionally return "fail_on_no_slot" error
         if (req.has_param("fail_on_no_slot")) {
-            if (res_metrics->n_idle_slots == 0) {
+            if (res_task->n_idle_slots == 0) {
                 res_error(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE));
                 return;
             }
         }
 
-        res_ok(res, res_metrics->slots_data);
+        res_ok(res, res_task->slots_data);
     };
 
     const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
@@ -4349,56 +4649,56 @@ int main(int argc, char ** argv) {
         }
 
         // TODO: get rid of this dynamic_cast
-        auto res_metrics = dynamic_cast<server_task_result_metrics*>(result.get());
-        GGML_ASSERT(res_metrics != nullptr);
+        auto res_task = dynamic_cast<server_task_result_metrics*>(result.get());
+        GGML_ASSERT(res_task != nullptr);
 
         // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names
         json all_metrics_def = json {
             {"counter", {{
                     {"name",  "prompt_tokens_total"},
                     {"help",  "Number of prompt tokens processed."},
-                    {"value",  (uint64_t) res_metrics->n_prompt_tokens_processed_total}
+                    {"value",  (uint64_t) res_task->n_prompt_tokens_processed_total}
             }, {
                     {"name",  "prompt_seconds_total"},
                     {"help",  "Prompt process time"},
-                    {"value",  (uint64_t) res_metrics->t_prompt_processing_total / 1.e3}
+                    {"value",  (uint64_t) res_task->t_prompt_processing_total / 1.e3}
             }, {
                     {"name",  "tokens_predicted_total"},
                     {"help",  "Number of generation tokens processed."},
-                    {"value",  (uint64_t) res_metrics->n_tokens_predicted_total}
+                    {"value",  (uint64_t) res_task->n_tokens_predicted_total}
             }, {
                     {"name",  "tokens_predicted_seconds_total"},
                     {"help",  "Predict process time"},
-                    {"value",  (uint64_t) res_metrics->t_tokens_generation_total / 1.e3}
+                    {"value",  (uint64_t) res_task->t_tokens_generation_total / 1.e3}
             }, {
                     {"name",  "n_decode_total"},
                     {"help",  "Total number of llama_decode() calls"},
-                    {"value",  res_metrics->n_decode_total}
+                    {"value",  res_task->n_decode_total}
             }, {
                     {"name",  "n_past_max"},
                     {"help",  "Largest observed n_past."},
-                    {"value",  res_metrics->n_past_max}
+                    {"value",  res_task->n_past_max}
             }, {
                     {"name",  "n_busy_slots_per_decode"},
                     {"help",  "Average number of busy slots per llama_decode() call"},
-                    {"value",  (float) res_metrics->n_busy_slots_total / std::max((float) res_metrics->n_decode_total, 1.f)}
+                    {"value",  (float) res_task->n_busy_slots_total / std::max((float) res_task->n_decode_total, 1.f)}
             }}},
             {"gauge", {{
                     {"name",  "prompt_tokens_seconds"},
                     {"help",  "Average prompt throughput in tokens/s."},
-                    {"value",  res_metrics->n_prompt_tokens_processed ? 1.e3 / res_metrics->t_prompt_processing * res_metrics->n_prompt_tokens_processed : 0.}
+                    {"value",  res_task->n_prompt_tokens_processed ? 1.e3 / res_task->t_prompt_processing * res_task->n_prompt_tokens_processed : 0.}
             },{
                     {"name",  "predicted_tokens_seconds"},
                     {"help",  "Average generation throughput in tokens/s."},
-                    {"value",  res_metrics->n_tokens_predicted ? 1.e3 / res_metrics->t_tokens_generation * res_metrics->n_tokens_predicted : 0.}
+                    {"value",  res_task->n_tokens_predicted ? 1.e3 / res_task->t_tokens_generation * res_task->n_tokens_predicted : 0.}
             },{
                     {"name",  "requests_processing"},
                     {"help",  "Number of requests processing."},
-                    {"value",  (uint64_t) res_metrics->n_processing_slots}
+                    {"value",  (uint64_t) res_task->n_processing_slots}
             },{
                     {"name",  "requests_deferred"},
                     {"help",  "Number of requests deferred."},
-                    {"value",  (uint64_t) res_metrics->n_tasks_deferred}
+                    {"value",  (uint64_t) res_task->n_tasks_deferred}
             }}}
         };
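
The throughput gauges convert the accumulated millisecond totals into tokens per second. A standalone check of the expression with hypothetical numbers:

    #include <cstdio>

    int main() {
        const double t_prompt_processing       = 250.0;  // ms spent processing prompts (hypothetical)
        const double n_prompt_tokens_processed = 1000.0; // tokens processed in that time (hypothetical)

        const double prompt_tokens_seconds = n_prompt_tokens_processed
            ? 1.e3 / t_prompt_processing * n_prompt_tokens_processed
            : 0.0;

        printf("prompt_tokens_seconds = %.1f\n", prompt_tokens_seconds); // 4000.0 tokens/s
        return 0;
    }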
 
@@ -4419,7 +4719,7 @@ int main(int argc, char ** argv) {
             }
         }
 
-        res.set_header("Process-Start-Time-Unix", std::to_string(res_metrics->t_start));
+        res.set_header("Process-Start-Time-Unix", std::to_string(res_task->t_start));
 
         res.set_content(prometheus.str(), "text/plain; version=0.0.4");
         res.status = 200; // HTTP OK
@@ -4543,9 +4843,22 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_props = [&params, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
+        json default_generation_settings_for_props;
+
+        {
+            slot_params params;
+
+            params.sampling = ctx_server.params_base.sampling;
+
+            default_generation_settings_for_props = json {
+                {"params", params.to_json(true)},
+                {"n_ctx",  ctx_server.slots[0].n_ctx},
+            };
+        }
+
         // this endpoint is publicly available, please only return what is safe to be exposed
         json data = {
-            { "default_generation_settings", ctx_server.default_generation_settings_for_props },
+            { "default_generation_settings", default_generation_settings_for_props },
             { "total_slots",                 ctx_server.params_base.n_parallel },
             { "model_path",                  ctx_server.params_base.model.path },
             { "modalities",                  json {
@@ -4650,12 +4963,12 @@ int main(int argc, char ** argv) {
                 task.id    = ctx_server.queue_tasks.get_new_id();
                 task.index = i;
 
-                task.prompt_tokens    = std::move(inputs[i]);
-                task.params           = server_task::params_from_json_cmpl(
+                task.tokens = std::move(inputs[i]);
+                task.params = server_task::params_from_json_cmpl(
                         ctx_server.ctx,
                         ctx_server.params_base,
                         data);
-                task.id_selected_slot = json_value(data, "id_slot", -1);
+                task.id_slot = json_value(data, "id_slot", -1);
 
                 // OAI-compat
                 task.params.oaicompat                 = oaicompat;
@@ -5024,9 +5337,9 @@ int main(int argc, char ** argv) {
             for (size_t i = 0; i < tokenized_prompts.size(); i++) {
                 server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
 
-                task.id            = ctx_server.queue_tasks.get_new_id();
-                task.index         = i;
-                task.prompt_tokens = std::move(tokenized_prompts[i]);
+                task.id     = ctx_server.queue_tasks.get_new_id();
+                task.index  = i;
+                task.tokens = std::move(tokenized_prompts[i]);
 
                 // OAI-compat
                 task.params.oaicompat = oaicompat;
@@ -5122,10 +5435,10 @@ int main(int argc, char ** argv) {
             tasks.reserve(documents.size());
             for (size_t i = 0; i < documents.size(); i++) {
                 auto tmp = format_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]);
-                server_task task   = server_task(SERVER_TASK_TYPE_RERANK);
-                task.id            = ctx_server.queue_tasks.get_new_id();
-                task.index         = i;
-                task.prompt_tokens = std::move(tmp);
+                server_task task = server_task(SERVER_TASK_TYPE_RERANK);
+                task.id     = ctx_server.queue_tasks.get_new_id();
+                task.index  = i;
+                task.tokens = std::move(tmp);
                 tasks.push_back(std::move(task));
             }
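
The embedding and rerank paths above follow the same pattern: one task per input, each carrying the renamed tokens field plus an index. Presumably the index lets the caller map results, which may complete in any order, back to their input positions; a small illustrative sketch of that mapping (types and names are hypothetical, not the server's own):

    #include <cstddef>
    #include <vector>

    struct doc_result {
        size_t index; // mirrors task.index above
        float  score;
    };

    // place each result back at the position of the input it was created for
    std::vector<float> order_by_index(const std::vector<doc_result> & results, size_t n_docs) {
        std::vector<float> scores(n_docs, 0.0f);
        for (const auto & r : results) {
            scores[r.index] = r.score;
        }
        return scores;
    }
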
 
@@ -5383,7 +5696,7 @@ int main(int argc, char ** argv) {
 #endif
 
     LOG_INF("%s: server is listening on %s - starting the main loop\n", __func__,
-            is_sock ? string_format("unix://%s", params.hostname.c_str()).c_str() :
+            is_sock ? string_format("unix://%s",    params.hostname.c_str()).c_str() :
                       string_format("http://%s:%d", params.hostname.c_str(), params.port).c_str());
 
     // this call blocks the main thread until queue_tasks.terminate() is called
index 829af2ebe7bfb359a59685b91135b189e3f40a37..720b136b051750955f2f428a55cc56a6b37161aa 100644 (file)
@@ -66,8 +66,7 @@ def test_server_slots():
     assert len(res.body) == server.n_slots
     assert server.n_ctx is not None and server.n_slots is not None
     assert res.body[0]["n_ctx"] == server.n_ctx / server.n_slots
-    assert "params" in res.body[0]
-    assert res.body[0]["params"]["seed"] == server.seed
+    assert "params" not in res.body[0]
 
 
 def test_load_split_model():
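
The updated /slots test above expects the per-slot "params" object to be absent by default; presumably it is only returned when a debug option is enabled. One way such gating could look, purely as a sketch (the flag name and field set are hypothetical, not the server's actual code):

    #include <nlohmann/json.hpp>

    using json = nlohmann::ordered_json;

    json slot_to_json(int id, int n_ctx, const json & generation_params, bool slots_debug) {
        json out = {
            {"id",    id},
            {"n_ctx", n_ctx},
        };
        if (slots_debug) {
            out["params"] = generation_params; // only exposed when debugging is explicitly enabled
        }
        return out;
    }
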
index 2979ed4bb7b1268c852503fa0f9dac627a4d6b79..6e5a3488e789bac81c50c5684d84bce1fbaa73ef 100644 (file)
@@ -19,8 +19,8 @@ def create_server():
         (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True,  None),
         (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, 'chatml'),
         (None, "Book", "What is the best book", 8, "^ blue",                    23, 8, "length", True, "This is not a chat template, it is"),
-        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None),
-        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None),
+        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length", False, None),
+        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length", True, None),
         (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", False, None),
         (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", True, None),
     ]
@@ -54,7 +54,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
     "system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
     [
         ("Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
-        ("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
+        ("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length"),
     ]
 )
 def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
index 11483e679a505f8de994e557064be8a553fa3ded..00ba78cf67c0935c71ea5dcb3909fd1e2285e91b 100644 (file)
@@ -16,7 +16,7 @@ def create_server():
 
 @pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated,return_tokens", [
     ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False, False),
-    ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False, True),
+    ("Write a joke about AI from a very long prompt which will not be truncated", 64, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False, True),
 ])
 def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool, return_tokens: bool):
     global server
@@ -41,7 +41,7 @@ def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int,
 
 @pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
     ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
-    ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
+    ("Write a joke about AI from a very long prompt which will not be truncated", 64, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
 ])
 def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool):
     global server
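
In the chat-completion and completion tests above, the expected n_predicted values are now aligned with the requested budgets (128 with max_tokens 128, 64 with n_predict 64). The invariant the fixtures encode, as a small sketch:

    #include <cassert>
    #include <string>

    // when the prompt is not truncated and generation stops on the length limit,
    // the number of generated tokens equals the requested budget
    void check_predicted(int n_predict_requested, int n_predicted, bool truncated, const std::string & finish_reason) {
        if (!truncated && finish_reason == "length") {
            assert(n_predicted == n_predict_requested);
        }
    }
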
index 92e49f2bb05a48b1825796fa4d9bd14be90c61d0..4adbbde64f5947239a2b11b2ceb19b78681b3f71 100644 (file)
@@ -4,6 +4,12 @@ from utils import *
 server = ServerPreset.tinyllama2()
 
 
+SHORT_TEXT = """
+Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
+Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
+Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
+""".strip()
+
 LONG_TEXT = """
 Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
 Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
@@ -21,19 +27,18 @@ def create_server():
 
 
 def test_ctx_shift_enabled():
-    # the prompt is 301 tokens
+    # the prompt is 226 tokens
     # the slot context is 512/2 = 256 tokens
-    # the prompt is truncated to keep the last (301 - 256/2) = 173 tokens
     # 96 tokens are generated thanks to shifting the context when it gets full
     global server
     server.enable_ctx_shift = True
     server.start()
     res = server.make_request("POST", "/completion", data={
         "n_predict": 96,
-        "prompt": LONG_TEXT,
+        "prompt": SHORT_TEXT,
     })
     assert res.status_code == 200
-    assert res.body["timings"]["prompt_n"] == 173
+    assert res.body["timings"]["prompt_n"] == 226
     assert res.body["timings"]["predicted_n"] == 96
     assert res.body["truncated"] is True
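
The rewritten context-shift test uses SHORT_TEXT (226 tokens, per the comment) instead of LONG_TEXT (301 tokens), so the prompt now fits into the 256-token slot context without being truncated up front, and the 96 generated tokens are accommodated by shifting the context once it fills up. The rough arithmetic, as a sketch (the exact number of shifted tokens depends on the shifting parameters):

    #include <cstdio>

    int main() {
        const int n_ctx_total = 512;                   // server context used by the test
        const int n_slots     = 2;
        const int n_ctx_slot  = n_ctx_total / n_slots; // 256 tokens per slot

        const int n_prompt  = 226;                     // SHORT_TEXT
        const int n_predict = 96;

        // tokens that do not fit in the slot context and must be shifted out
        const int n_overflow = n_prompt + n_predict - n_ctx_slot; // 66

        printf("slot context: %d, overflow absorbed by context shift: %d tokens\n", n_ctx_slot, n_overflow);
        return 0;
    }
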
 
index 4ca1423aaf2d4bac5b4938cf08b3692e5ad14a6b..f175115f4fd6aaa85f462e471c528fbe6e931e69 100644 (file)
 
 using json = nlohmann::ordered_json;
 
-#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
-#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
-#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
-#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
+#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
+#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
+#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
+#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
 
 #define SRV_INF(fmt, ...) LOG_INF("srv  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
 #define SRV_WRN(fmt, ...) LOG_WRN("srv  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
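
The SLT_* macros above now read the task id through the slot's task pointer, which can be null while a slot is idle, so the dereference is guarded and -1 is logged instead. A simplified illustration of the same guard (the types are stand-ins, not the server's structs):

    #include <cstdio>

    struct task_t {
        int id;
    };

    struct slot_t {
        int            id   = 0;
        const task_t * task = nullptr; // null while the slot has no task assigned
    };

    void log_slot(const slot_t & slot) {
        // same guard as the macros: never dereference a null task pointer
        printf("slot: id %2d | task %d\n", slot.id, slot.task ? slot.task->id : -1);
    }

    int main() {
        slot_t idle;
        log_slot(idle); // prints "task -1"

        task_t t { 42 };
        slot_t busy;
        busy.task = &t;
        log_slot(busy); // prints "task 42"

        return 0;
    }
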
@@ -1102,6 +1102,7 @@ public:
     ~server_tokens() = default;
 
     // Prevent copying
+    // TODO: server_tokens should be copyable - remove this:
     server_tokens(const server_tokens&) = delete;
     server_tokens& operator=(const server_tokens&) = delete;
 
@@ -1119,7 +1120,7 @@ public:
         }
     }
 
-    server_tokens(llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {}
+    server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {}
 
     // for debugging
     std::string str() const {
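
Taking the token vector by const reference lets server_tokens be constructed from const containers and temporaries, which a non-const reference would reject. A minimal sketch of the difference, with llama_tokens mirrored locally as a plain token vector (an assumption made for the sketch):

    #include <cstdint>
    #include <vector>

    using llama_token  = int32_t;
    using llama_tokens = std::vector<llama_token>;

    struct tokens_holder {
        llama_tokens tokens;
        explicit tokens_holder(const llama_tokens & t) : tokens(t) {} // const& accepts both cases below
    };

    int main() {
        const llama_tokens cached = {1, 2, 3};
        tokens_holder a(cached);                // const lvalue
        tokens_holder b(llama_tokens{4, 5, 6}); // temporary
        (void) a; (void) b;
        return 0;
    }
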
@@ -1144,9 +1145,8 @@ public:
         auto it = map_pos_to_media.find(pos);
         if (it != map_pos_to_media.end()) {
             return it->second;
-        } else {
-            throw std::runtime_error("Chunk not found");
         }
+        throw std::runtime_error("Chunk not found");
     }
 
     void push_back(llama_token tok) {
@@ -1170,7 +1170,7 @@ public:
             map_pos_to_media[start_pos] = std::move(new_chunk);
         } else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
             size_t n_tokens;
-            auto text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
+            const auto * text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
             for (size_t i = 0; i < n_tokens; ++i) {
                 push_back(text_tokens[i]);
             }
@@ -1190,7 +1190,7 @@ public:
             // We could also just check, but this will prevent silently dropping MTMD data.
             GGML_ASSERT(has_mtmd);
             for (auto it = tokens.map_pos_to_media.begin(); it != tokens.map_pos_to_media.end(); ) {
-                auto chunk = tokens.map_pos_to_media[it->first].get();
+                auto chunk = tokens.map_pos_to_media[it->first].get();
                 mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
                 map_pos_to_media[start_pos+it->first] = std::move(new_chunk);
             }
@@ -1271,33 +1271,52 @@ public:
     }
 
     size_t get_common_prefix(const server_tokens & b) const {
-        size_t max_idx = std::min(tokens.size(), b.tokens.size());
+        const size_t max_idx = std::min(tokens.size(), b.tokens.size());
+
+        if (!has_mtmd) {
+            for (size_t i = 0; i < max_idx; ++i) {
+                if (tokens[i] == b.tokens[i]) {
+                    continue;
+                }
+
+                return i;
+            }
+
+            return max_idx;
+        }
+
         for (size_t i = 0; i < max_idx; ++i) {
-            auto & ai =   tokens[i];
-            auto & bi = b.tokens[i];
+            const llama_token ai =   tokens[i];
+            const llama_token bi = b.tokens[i];
 
             if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) {
-                GGML_ASSERT(has_mtmd);
                 const auto & a_chunk =   find_chunk(i);
                 const auto & b_chunk = b.find_chunk(i);
+
                 GGML_ASSERT(a_chunk && b_chunk);
-                std::string ai_id  = mtmd_input_chunk_get_id(a_chunk.get());
-                std::string bi_id  = mtmd_input_chunk_get_id(b_chunk.get());
-                size_t a_pos       = mtmd_input_chunk_get_n_pos(a_chunk.get());
-                size_t b_pos       = mtmd_input_chunk_get_n_pos(b_chunk.get());
-                if (ai_id == bi_id && a_pos == b_pos) {
-                    GGML_ASSERT(a_pos > 0 && "Invalid media chunk"); // should never happen
-                    i += a_pos - 1; // will be +1 by the for loop
+
+                const std::string id_ai = mtmd_input_chunk_get_id(a_chunk.get());
+                const std::string id_bi = mtmd_input_chunk_get_id(b_chunk.get());
+
+                const size_t pos_a = mtmd_input_chunk_get_n_pos(a_chunk.get());
+                const size_t pos_b = mtmd_input_chunk_get_n_pos(b_chunk.get());
+
+                if (id_ai == id_bi && pos_a == pos_b) {
+                    GGML_ASSERT(pos_a > 0 && "Invalid media chunk"); // should never happen
+                    i += pos_a - 1; // will be +1 by the for loop
                     continue;
-                } else {
-                    return i;
                 }
-            } else if (ai == bi) {
-                continue;
-            } else {
+
                 return i;
             }
+
+            if (ai == bi) {
+                continue;
+            }
+
+            return i;
         }
+
         return max_idx; // all tokens are equal
     }
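
The refactored get_common_prefix above splits into a text-only fast path and an mtmd-aware path that compares media chunks by id and position. For the text-only case the logic reduces to finding the first index at which the two token sequences differ; an equivalent standalone form, for illustration only:

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    using llama_token = int32_t;

    // length of the common prefix of two token sequences
    size_t common_prefix_len(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
        const size_t max_idx = std::min(a.size(), b.size());

        const auto mism = std::mismatch(a.begin(), a.begin() + max_idx, b.begin());

        return (size_t) std::distance(a.begin(), mism.first);
    }
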
 
@@ -1308,7 +1327,7 @@ public:
         const int32_t n_vocab = llama_vocab_n_tokens(vocab);
 
         for (size_t i = 0; i < tokens.size(); ++i) {
-            auto & t = tokens[i];
+            const auto & t = tokens[i];
             if (t == LLAMA_TOKEN_NULL) {
                 try {
                     const auto & chunk = find_chunk(i);
@@ -1330,8 +1349,8 @@ public:
                 mtmd_context * mctx,
                 llama_pos n_past,
                 int32_t seq_id,
-                llama_pos & n_pos_out) {
-        auto & chunk = find_chunk(n_past);
+                llama_pos & n_pos_out) const {
+        const auto & chunk = find_chunk(n_past);
         const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE
                             ? "image" : "audio";
         SRV_INF("processing %s...\n", name);