]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : add speculative decoding support (#10455)
authorGeorgi Gerganov <redacted>
Mon, 25 Nov 2024 14:31:38 +0000 (16:31 +0200)
committerGitHub <redacted>
Mon, 25 Nov 2024 14:31:38 +0000 (16:31 +0200)
* server : add speculative decoding support

ggml-ci

* server : add helper function slot.can_speculate()

ggml-ci

examples/server/server.cpp

index 6c55d65c013305eecddbaad34c5fafd6ba20208d..f9d20fee5a6f669466949b72d1954ce67170efbe 100644 (file)
@@ -2,10 +2,11 @@
 
 #include "arg.h"
 #include "common.h"
-#include "log.h"
-#include "sampling.h"
 #include "json-schema-to-grammar.h"
 #include "llama.h"
+#include "log.h"
+#include "sampling.h"
+#include "speculative.h"
 
 // Change JSON_ASSERT from assert() to GGML_ASSERT:
 #define JSON_ASSERT GGML_ASSERT
@@ -121,12 +122,21 @@ struct slot_params {
     int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
 
     std::vector<std::string> antiprompt;
+
+    struct common_params_sampling sampling;
+    struct common_params_speculative speculative;
 };
 
 struct server_slot {
     int id;
     int id_task = -1;
 
+    llama_batch batch_spec;
+
+    llama_context * ctx_dft = nullptr;
+
+    common_speculative * spec = nullptr;
+
     // the index relative to completion multi-task request
     size_t index = 0;
 
@@ -175,7 +185,6 @@ struct server_slot {
     // sampling
     json json_schema;
 
-    struct common_params_sampling sparams;
     struct common_sampler * smpl = nullptr;
 
     llama_token sampled;
@@ -212,7 +221,7 @@ struct server_slot {
         generated_token_probs.clear();
     }
 
-    bool has_budget(common_params &global_params) {
+    bool has_budget(const common_params & global_params) {
         if (params.n_predict == -1 && global_params.n_predict == -1) {
             return true; // limitless
         }
@@ -232,6 +241,10 @@ struct server_slot {
         return state != SLOT_STATE_IDLE;
     }
 
+    bool can_speculate() const {
+        return ctx_dft && params.speculative.n_max > 0 && params.cache_prompt;
+    }
+
     void add_token(const completion_token_output & token) {
         if (!is_processing()) {
             SLT_WRN(*this, "%s", "slot is not processing\n");
@@ -591,11 +604,14 @@ struct server_response {
 };
 
 struct server_context {
+    common_params params_base;
+
     llama_model * model = nullptr;
     llama_context * ctx = nullptr;
     std::vector<common_lora_adapter_container> loras;
 
-    common_params params;
+    llama_model * model_dft = nullptr;
+    llama_context_params cparams_dft;
 
     llama_batch batch = {};
 
@@ -628,27 +644,41 @@ struct server_context {
             model = nullptr;
         }
 
+        if (model_dft) {
+            llama_free_model(model_dft);
+            model_dft = nullptr;
+        }
+
         // Clear any sampling context
         for (server_slot & slot : slots) {
-            if (slot.smpl != nullptr) {
-                common_sampler_free(slot.smpl);
-            }
+            common_sampler_free(slot.smpl);
+            slot.smpl = nullptr;
+
+            llama_free(slot.ctx_dft);
+            slot.ctx_dft = nullptr;
+
+            common_speculative_free(slot.spec);
+            slot.spec = nullptr;
+
+            llama_batch_free(slot.batch_spec);
         }
 
         llama_batch_free(batch);
     }
 
-    bool load_model(const common_params & params_) {
-        params = params_;
+    bool load_model(const common_params & params) {
+        SRV_INF("loading model '%s'\n", params.model.c_str());
 
-        common_init_result llama_init = common_init_from_params(params);
+        params_base = params;
+
+        common_init_result llama_init = common_init_from_params(params_base);
 
         model = llama_init.model;
         ctx   = llama_init.context;
         loras = llama_init.lora_adapters;
 
         if (model == nullptr) {
-            SRV_ERR("failed to load model, '%s'\n", params.model.c_str());
+            SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str());
             return false;
         }
 
@@ -657,6 +687,40 @@ struct server_context {
         add_bos_token = llama_add_bos_token(model);
         has_eos_token = !llama_add_eos_token(model);
 
+        if (!params_base.speculative.model.empty()) {
+            SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str());
+
+            auto params_dft = params_base;
+
+            params_dft.model        = params_base.speculative.model;
+            params_dft.n_ctx        = params_base.speculative.n_ctx;
+            params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
+
+            common_init_result llama_init_dft = common_init_from_params(params_dft);
+
+            model_dft = llama_init_dft.model;
+
+            if (model_dft == nullptr) {
+                SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.c_str());
+                return false;
+            }
+
+            if (!common_speculative_are_compatible(ctx, llama_init_dft.context)) {
+                SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params_base.speculative.model.c_str(), params_base.model.c_str());
+
+                llama_free      (llama_init_dft.context);
+                llama_free_model(llama_init_dft.model);
+
+                return false;
+            }
+
+            cparams_dft = common_context_params_to_llama(params_base);
+            cparams_dft.n_batch = llama_n_ctx(llama_init_dft.context);
+
+            // the context is not needed - we will create one for each slot
+            llama_free(llama_init_dft.context);
+        }
+
         return true;
     }
 
@@ -674,20 +738,36 @@ struct server_context {
     }
 
     void init() {
-        const int32_t n_ctx_slot = n_ctx / params.n_parallel;
+        const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;
 
-        SRV_INF("initializing slots, n_slots = %d\n", params.n_parallel);
+        SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);
 
-        for (int i = 0; i < params.n_parallel; i++) {
+        for (int i = 0; i < params_base.n_parallel; i++) {
             server_slot slot;
 
             slot.id = i;
             slot.n_ctx = n_ctx_slot;
-            slot.n_predict = params.n_predict;
+            slot.n_predict = params_base.n_predict;
+
+            if (model_dft) {
+                slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
+
+                slot.ctx_dft = llama_new_context_with_model(model_dft, cparams_dft);
+                if (slot.ctx_dft == nullptr) {
+                    SRV_ERR("%s", "failed to create draft context\n");
+                    return;
+                }
+
+                slot.spec = common_speculative_init(slot.ctx_dft);
+                if (slot.spec == nullptr) {
+                    SRV_ERR("%s", "failed to create speculator\n");
+                    return;
+                }
+            }
 
             SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
 
-            slot.sparams = params.sampling;
+            slot.params.sampling = params_base.sampling;
 
             slot.callback_on_release = [this](int) {
                 queue_tasks.pop_deferred_task();
@@ -707,7 +787,7 @@ struct server_context {
             const int32_t n_batch = llama_n_batch(ctx);
 
             // only a single seq_id per token is needed
-            batch = llama_batch_init(std::max(n_batch, params.n_parallel), 0, 1);
+            batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
         }
 
         metrics.init();
@@ -786,9 +866,11 @@ struct server_context {
     }
 
     bool launch_slot_with_task(server_slot & slot, const server_task & task) {
-        slot_params default_params;
         // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
-        auto default_sparams = params.sampling;
+        slot_params defaults;
+        defaults.sampling    = params_base.sampling;
+        defaults.speculative = params_base.speculative;
+
         const auto & data = task.data;
 
         if (data.count("__oaicompat") != 0) {
@@ -799,42 +881,48 @@ struct server_context {
             slot.oaicompat_model = "";
         }
 
-        slot.params.stream              = json_value(data, "stream",             false);
-        slot.params.cache_prompt        = json_value(data, "cache_prompt",       false);
-        slot.params.n_predict           = json_value(data, "n_predict",          json_value(data, "max_tokens", default_params.n_predict));
-        slot.params.n_indent            = json_value(data, "n_indent",           default_params.n_indent);
-        slot.sparams.top_k              = json_value(data, "top_k",              default_sparams.top_k);
-        slot.sparams.top_p              = json_value(data, "top_p",              default_sparams.top_p);
-        slot.sparams.min_p              = json_value(data, "min_p",              default_sparams.min_p);
-        slot.sparams.xtc_probability    = json_value(data, "xtc_probability",    default_sparams.xtc_probability);
-        slot.sparams.xtc_threshold      = json_value(data, "xtc_threshold",      default_sparams.xtc_threshold);
-        slot.sparams.typ_p              = json_value(data, "typical_p",          default_sparams.typ_p);
-        slot.sparams.temp               = json_value(data, "temperature",        default_sparams.temp);
-        slot.sparams.dynatemp_range     = json_value(data, "dynatemp_range",     default_sparams.dynatemp_range);
-        slot.sparams.dynatemp_exponent  = json_value(data, "dynatemp_exponent",  default_sparams.dynatemp_exponent);
-        slot.sparams.penalty_last_n     = json_value(data, "repeat_last_n",      default_sparams.penalty_last_n);
-        slot.sparams.penalty_repeat     = json_value(data, "repeat_penalty",     default_sparams.penalty_repeat);
-        slot.sparams.penalty_freq       = json_value(data, "frequency_penalty",  default_sparams.penalty_freq);
-        slot.sparams.penalty_present    = json_value(data, "presence_penalty",   default_sparams.penalty_present);
-        slot.sparams.dry_multiplier     = json_value(data, "dry_multiplier",     default_sparams.dry_multiplier);
-        slot.sparams.dry_base           = json_value(data, "dry_base",           default_sparams.dry_base);
-        slot.sparams.dry_allowed_length = json_value(data, "dry_allowed_length", default_sparams.dry_allowed_length);
-        slot.sparams.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", default_sparams.dry_penalty_last_n);
-        slot.sparams.mirostat           = json_value(data, "mirostat",           default_sparams.mirostat);
-        slot.sparams.mirostat_tau       = json_value(data, "mirostat_tau",       default_sparams.mirostat_tau);
-        slot.sparams.mirostat_eta       = json_value(data, "mirostat_eta",       default_sparams.mirostat_eta);
-        slot.sparams.penalize_nl        = json_value(data, "penalize_nl",        default_sparams.penalize_nl);
-        slot.params.n_keep              = json_value(data, "n_keep",             default_params.n_keep);
-        slot.params.n_discard           = json_value(data, "n_discard",          default_params.n_discard);
-        slot.sparams.seed               = json_value(data, "seed",               default_sparams.seed);
-        slot.sparams.n_probs            = json_value(data, "n_probs",            default_sparams.n_probs);
-        slot.sparams.min_keep           = json_value(data, "min_keep",           default_sparams.min_keep);
-      //slot.params.t_max_prompt_ms     = json_value(data, "t_max_prompt_ms",    default_params.t_max_prompt_ms); // TODO: implement
-        slot.params.t_max_predict_ms    = json_value(data, "t_max_predict_ms",   default_params.t_max_predict_ms);
-
-        if (slot.sparams.dry_base < 1.0f)
-        {
-           slot.sparams.dry_base = default_sparams.dry_base;
+        slot.params.stream           = json_value(data, "stream",             false);
+        slot.params.cache_prompt     = json_value(data, "cache_prompt",       false);
+        slot.params.n_predict        = json_value(data, "n_predict",          json_value(data, "max_tokens", defaults.n_predict));
+        slot.params.n_indent         = json_value(data, "n_indent",           defaults.n_indent);
+        slot.params.n_keep           = json_value(data, "n_keep",             defaults.n_keep);
+        slot.params.n_discard        = json_value(data, "n_discard",          defaults.n_discard);
+      //slot.params.t_max_prompt_ms  = json_value(data, "t_max_prompt_ms",    defaults.t_max_prompt_ms); // TODO: implement
+        slot.params.t_max_predict_ms = json_value(data, "t_max_predict_ms",   defaults.t_max_predict_ms);
+
+        slot.params.sampling.top_k              = json_value(data, "top_k",              defaults.sampling.top_k);
+        slot.params.sampling.top_p              = json_value(data, "top_p",              defaults.sampling.top_p);
+        slot.params.sampling.min_p              = json_value(data, "min_p",              defaults.sampling.min_p);
+        slot.params.sampling.xtc_probability    = json_value(data, "xtc_probability",    defaults.sampling.xtc_probability);
+        slot.params.sampling.xtc_threshold      = json_value(data, "xtc_threshold",      defaults.sampling.xtc_threshold);
+        slot.params.sampling.typ_p              = json_value(data, "typical_p",          defaults.sampling.typ_p);
+        slot.params.sampling.temp               = json_value(data, "temperature",        defaults.sampling.temp);
+        slot.params.sampling.dynatemp_range     = json_value(data, "dynatemp_range",     defaults.sampling.dynatemp_range);
+        slot.params.sampling.dynatemp_exponent  = json_value(data, "dynatemp_exponent",  defaults.sampling.dynatemp_exponent);
+        slot.params.sampling.penalty_last_n     = json_value(data, "repeat_last_n",      defaults.sampling.penalty_last_n);
+        slot.params.sampling.penalty_repeat     = json_value(data, "repeat_penalty",     defaults.sampling.penalty_repeat);
+        slot.params.sampling.penalty_freq       = json_value(data, "frequency_penalty",  defaults.sampling.penalty_freq);
+        slot.params.sampling.penalty_present    = json_value(data, "presence_penalty",   defaults.sampling.penalty_present);
+        slot.params.sampling.dry_multiplier     = json_value(data, "dry_multiplier",     defaults.sampling.dry_multiplier);
+        slot.params.sampling.dry_base           = json_value(data, "dry_base",           defaults.sampling.dry_base);
+        slot.params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length);
+        slot.params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n);
+        slot.params.sampling.mirostat           = json_value(data, "mirostat",           defaults.sampling.mirostat);
+        slot.params.sampling.mirostat_tau       = json_value(data, "mirostat_tau",       defaults.sampling.mirostat_tau);
+        slot.params.sampling.mirostat_eta       = json_value(data, "mirostat_eta",       defaults.sampling.mirostat_eta);
+        slot.params.sampling.penalize_nl        = json_value(data, "penalize_nl",        defaults.sampling.penalize_nl);
+        slot.params.sampling.seed               = json_value(data, "seed",               defaults.sampling.seed);
+        slot.params.sampling.n_probs            = json_value(data, "n_probs",            defaults.sampling.n_probs);
+        slot.params.sampling.min_keep           = json_value(data, "min_keep",           defaults.sampling.min_keep);
+
+        slot.params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
+        slot.params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
+        slot.params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
+
+        slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min);
+
+        if (slot.params.sampling.dry_base < 1.0f) {
+           slot.params.sampling.dry_base = defaults.sampling.dry_base;
         }
 
         // sequence breakers for DRY
@@ -843,8 +931,8 @@ struct server_context {
             // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
 
             if (data.contains("dry_sequence_breakers")) {
-                slot.sparams.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
-                if (slot.sparams.dry_sequence_breakers.empty()) {
+                slot.params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
+                if (slot.params.sampling.dry_sequence_breakers.empty()) {
                     send_error(task, "Error: dry_sequence_breakers must be a non-empty array of strings", ERROR_TYPE_INVALID_REQUEST);
                     return false;
                 }
@@ -858,14 +946,14 @@ struct server_context {
         }
         if (data.contains("json_schema") && !data.contains("grammar")) {
             try {
-                auto schema          = json_value(data, "json_schema", json::object());
-                slot.sparams.grammar = json_schema_to_grammar(schema);
+                auto schema                  = json_value(data, "json_schema", json::object());
+                slot.params.sampling.grammar = json_schema_to_grammar(schema);
             } catch (const std::exception & e) {
                 send_error(task, std::string("\"json_schema\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
                 return false;
             }
         } else {
-            slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
+            slot.params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
         }
 
         if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
@@ -875,10 +963,10 @@ struct server_context {
         }
 
         {
-            slot.sparams.logit_bias.clear();
+            slot.params.sampling.logit_bias.clear();
 
             if (json_value(data, "ignore_eos", false) && has_eos_token) {
-                slot.sparams.logit_bias.push_back({llama_token_eos(model), -INFINITY});
+                slot.params.sampling.logit_bias.push_back({llama_token_eos(model), -INFINITY});
             }
 
             const auto & logit_bias = data.find("logit_bias");
@@ -899,12 +987,12 @@ struct server_context {
                         if (el[0].is_number_integer()) {
                             llama_token tok = el[0].get<llama_token>();
                             if (tok >= 0 && tok < n_vocab) {
-                                slot.sparams.logit_bias.push_back({tok, bias});
+                                slot.params.sampling.logit_bias.push_back({tok, bias});
                             }
                         } else if (el[0].is_string()) {
                             auto toks = common_tokenize(model, el[0].get<std::string>(), false);
                             for (auto tok : toks) {
-                                slot.sparams.logit_bias.push_back({tok, bias});
+                                slot.params.sampling.logit_bias.push_back({tok, bias});
                             }
                         }
                     }
@@ -935,16 +1023,16 @@ struct server_context {
                             sampler_names.emplace_back(name);
                         }
                     }
-                    slot.sparams.samplers = common_sampler_types_from_names(sampler_names, false);
+                    slot.params.sampling.samplers = common_sampler_types_from_names(sampler_names, false);
                 } else if (samplers->is_string()){
                     std::string sampler_string;
                     for (const auto & name : *samplers) {
                         sampler_string += name;
                     }
-                    slot.sparams.samplers = common_sampler_types_from_chars(sampler_string);
+                    slot.params.sampling.samplers = common_sampler_types_from_chars(sampler_string);
                 }
             } else {
-                slot.sparams.samplers = default_sparams.samplers;
+                slot.params.sampling.samplers = defaults.sampling.samplers;
             }
         }
 
@@ -953,7 +1041,7 @@ struct server_context {
                 common_sampler_free(slot.smpl);
             }
 
-            slot.smpl = common_sampler_init(model, slot.sparams);
+            slot.smpl = common_sampler_init(model, slot.params.sampling);
             if (slot.smpl == nullptr) {
                 // for now, the only error that may happen here is invalid grammar
                 send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
@@ -961,6 +1049,12 @@ struct server_context {
             }
         }
 
+        if (slot.ctx_dft) {
+            llama_batch_free(slot.batch_spec);
+
+            slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);
+        }
+
         slot.state = SLOT_STATE_STARTED;
 
         SLT_INF(slot, "%s", "processing task\n");
@@ -978,7 +1072,7 @@ struct server_context {
 
     bool process_token(completion_token_output & result, server_slot & slot) {
         // remember which tokens were sampled - used for repetition penalties during sampling
-        const std::string token_str = common_token_to_piece(ctx, result.tok, params.special);
+        const std::string token_str = common_token_to_piece(ctx, result.tok, params_base.special);
         slot.sampled = result.tok;
 
         // search stop word and delete it
@@ -1043,7 +1137,7 @@ struct server_context {
         }
 
         // check the limits
-        if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params)) {
+        if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) {
             slot.stopped_limit  = true;
             slot.has_next_token = false;
 
@@ -1136,50 +1230,54 @@ struct server_context {
 
     json get_formated_generation(const server_slot & slot) const {
         std::vector<std::string> samplers;
-        samplers.reserve(slot.sparams.samplers.size());
-        for (const auto & sampler : slot.sparams.samplers) {
+        samplers.reserve(slot.params.sampling.samplers.size());
+        for (const auto & sampler : slot.params.sampling.samplers) {
             samplers.emplace_back(common_sampler_type_to_str(sampler));
         }
 
         return json {
             {"n_ctx",                     slot.n_ctx},
             {"n_predict",                 slot.n_predict},     // Server configured n_predict
-            {"model",                     params.model_alias},
-            {"seed",                      slot.sparams.seed},
+            {"model",                     params_base.model_alias},
+            {"seed",                      slot.params.sampling.seed},
             {"seed_cur",                  slot.smpl ? common_sampler_get_seed(slot.smpl) : 0},
-            {"temperature",               slot.sparams.temp},
-            {"dynatemp_range",            slot.sparams.dynatemp_range},
-            {"dynatemp_exponent",         slot.sparams.dynatemp_exponent},
-            {"top_k",                     slot.sparams.top_k},
-            {"top_p",                     slot.sparams.top_p},
-            {"min_p",                     slot.sparams.min_p},
-            {"xtc_probability",           slot.sparams.xtc_probability},
-            {"xtc_threshold",             slot.sparams.xtc_threshold},
-            {"typical_p",                 slot.sparams.typ_p},
-            {"repeat_last_n",             slot.sparams.penalty_last_n},
-            {"repeat_penalty",            slot.sparams.penalty_repeat},
-            {"presence_penalty",          slot.sparams.penalty_present},
-            {"frequency_penalty",         slot.sparams.penalty_freq},
-            {"dry_multiplier",            slot.sparams.dry_multiplier},
-            {"dry_base",                  slot.sparams.dry_base},
-            {"dry_allowed_length",        slot.sparams.dry_allowed_length},
-            {"dry_penalty_last_n",        slot.sparams.dry_penalty_last_n},
-            {"dry_sequence_breakers",     slot.sparams.dry_sequence_breakers},
-            {"mirostat",                  slot.sparams.mirostat},
-            {"mirostat_tau",              slot.sparams.mirostat_tau},
-            {"mirostat_eta",              slot.sparams.mirostat_eta},
-            {"penalize_nl",               slot.sparams.penalize_nl},
+            {"temperature",               slot.params.sampling.temp},
+            {"dynatemp_range",            slot.params.sampling.dynatemp_range},
+            {"dynatemp_exponent",         slot.params.sampling.dynatemp_exponent},
+            {"top_k",                     slot.params.sampling.top_k},
+            {"top_p",                     slot.params.sampling.top_p},
+            {"min_p",                     slot.params.sampling.min_p},
+            {"xtc_probability",           slot.params.sampling.xtc_probability},
+            {"xtc_threshold",             slot.params.sampling.xtc_threshold},
+            {"typical_p",                 slot.params.sampling.typ_p},
+            {"repeat_last_n",             slot.params.sampling.penalty_last_n},
+            {"repeat_penalty",            slot.params.sampling.penalty_repeat},
+            {"presence_penalty",          slot.params.sampling.penalty_present},
+            {"frequency_penalty",         slot.params.sampling.penalty_freq},
+            {"dry_multiplier",            slot.params.sampling.dry_multiplier},
+            {"dry_base",                  slot.params.sampling.dry_base},
+            {"dry_allowed_length",        slot.params.sampling.dry_allowed_length},
+            {"dry_penalty_last_n",        slot.params.sampling.dry_penalty_last_n},
+            {"dry_sequence_breakers",     slot.params.sampling.dry_sequence_breakers},
+            {"mirostat",                  slot.params.sampling.mirostat},
+            {"mirostat_tau",              slot.params.sampling.mirostat_tau},
+            {"mirostat_eta",              slot.params.sampling.mirostat_eta},
+            {"penalize_nl",               slot.params.sampling.penalize_nl},
             {"stop",                      slot.params.antiprompt},
             {"max_tokens",                slot.params.n_predict}, // User configured n_predict
             {"n_keep",                    slot.params.n_keep},
             {"n_discard",                 slot.params.n_discard},
-            {"ignore_eos",                slot.sparams.ignore_eos},
+            {"ignore_eos",                slot.params.sampling.ignore_eos},
             {"stream",                    slot.params.stream},
-          //{"logit_bias",                slot.sparams.logit_bias},
-            {"n_probs",                   slot.sparams.n_probs},
-            {"min_keep",                  slot.sparams.min_keep},
-            {"grammar",                   slot.sparams.grammar},
+          //{"logit_bias",                slot.params.sampling.logit_bias},
+            {"n_probs",                   slot.params.sampling.n_probs},
+            {"min_keep",                  slot.params.sampling.min_keep},
+            {"grammar",                   slot.params.sampling.grammar},
             {"samplers",                  samplers},
+            {"speculative",               slot.can_speculate()},
+            {"speculative.n_max",         slot.params.speculative.n_max},
+            {"speculative.n_min",         slot.params.speculative.n_min},
+            {"speculative.p_min",         slot.params.speculative.p_min},
         };
     }
 
@@ -1216,7 +1314,7 @@ struct server_context {
             {"index",      slot.index},
         };
 
-        if (slot.sparams.n_probs > 0) {
+        if (slot.params.sampling.n_probs > 0) {
             const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
             const size_t probs_pos      = std::min(slot.n_sent_token_probs,                       slot.generated_token_probs.size());
             const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());
@@ -1249,7 +1347,7 @@ struct server_context {
             {"content",             !slot.params.stream ? slot.generated_text : ""},
             {"id_slot",             slot.id},
             {"stop",                true},
-            {"model",               params.model_alias},
+            {"model",               params_base.model_alias},
             {"tokens_predicted",    slot.n_decoded},
             {"tokens_evaluated",    slot.n_prompt_tokens},
             {"generation_settings", get_formated_generation(slot)},
@@ -1265,7 +1363,7 @@ struct server_context {
             {"index",               slot.index},
         };
 
-        if (slot.sparams.n_probs > 0) {
+        if (slot.params.sampling.n_probs > 0) {
             std::vector<completion_token_output> probs;
             if (!slot.params.stream && slot.stopped_word) {
                 const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);
@@ -1422,10 +1520,10 @@ struct server_context {
                             data.at("input_prefix"),
                             data.at("input_suffix"),
                             data.at("input_extra"),
-                            params.n_batch,
-                            params.n_predict,
+                            params_base.n_batch,
+                            params_base.n_predict,
                             slots[0].n_ctx, // TODO: there should be a better way
-                            params.spm_infill,
+                            params_base.spm_infill,
                             tokenized_prompts[i]
                         );
                         create_task(data, tokens);
@@ -1798,7 +1896,7 @@ struct server_context {
         // TODO: simplify and improve
         for (server_slot & slot : slots) {
             if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) {
-                if (!params.ctx_shift) {
+                if (!params_base.ctx_shift) {
                     // this check is redundant (for good)
                     // we should never get here, because generation should already stopped in process_token()
                     slot.release();
@@ -1864,7 +1962,7 @@ struct server_context {
         int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
 
         // next, batch any pending prompts without exceeding n_batch
-        if (params.cont_batching || batch.n_tokens == 0) {
+        if (params_base.cont_batching || batch.n_tokens == 0) {
             for (auto & slot : slots) {
                 // this slot still has a prompt to be processed
                 if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
@@ -1917,7 +2015,7 @@ struct server_context {
                                 continue;
                             }
                         } else {
-                            if (!params.ctx_shift) {
+                            if (!params_base.ctx_shift) {
                                 // if context shift is disabled, we make sure prompt size is smaller than KV size
                                 // TODO: there should be a separate parameter that control prompt truncation
                                 //       context shift should be applied only during the generation phase
@@ -1963,11 +2061,11 @@ struct server_context {
                                 slot.n_past = common_lcp(slot.cache_tokens, prompt_tokens);
 
                                 // reuse chunks from the cached prompt by shifting their KV cache in the new position
-                                if (params.n_cache_reuse > 0) {
+                                if (params_base.n_cache_reuse > 0) {
                                     size_t head_c = slot.n_past; // cache
                                     size_t head_p = slot.n_past; // current prompt
 
-                                    SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params.n_cache_reuse, slot.n_past);
+                                    SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past);
 
                                     while (head_c < slot.cache_tokens.size() &&
                                            head_p < prompt_tokens.size()) {
@@ -1980,7 +2078,7 @@ struct server_context {
                                             n_match++;
                                         }
 
-                                        if (n_match >= (size_t) params.n_cache_reuse) {
+                                        if (n_match >= (size_t) params_base.n_cache_reuse) {
                                             SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
                                             //for (size_t i = head_p; i < head_p + n_match; i++) {
                                             //    SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
@@ -2168,38 +2266,99 @@ struct server_context {
                     continue; // continue loop of slots
                 }
 
-                completion_token_output result;
-                const llama_token id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
+                llama_token id;
+
+                {
+                    completion_token_output result;
+
+                    id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
+
+                    slot.i_batch = -1;
+
+                    common_sampler_accept(slot.smpl, id, true);
+
+                    slot.n_decoded += 1;
+                    if (slot.n_decoded == 1) {
+                        slot.t_start_generation = ggml_time_us();
+                        slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
+                        metrics.on_prompt_eval(slot);
+                    }
+
+                    result.tok = id;
+
+                    const auto * cur_p = common_sampler_get_candidates(slot.smpl);
 
-                common_sampler_accept(slot.smpl, id, true);
+                    for (size_t i = 0; i < (size_t) slot.params.sampling.n_probs; ++i) {
+                        result.probs.push_back({
+                            cur_p->data[i].id,
+                                i >= cur_p->size ? 0.0f : cur_p->data[i].p,
+                        });
+                    }
+
+                    if (!process_token(result, slot)) {
+                        // release slot because of stop condition
+                        slot.release();
+                        slot.print_timings();
+                        send_final_response(slot);
+                        metrics.on_prediction(slot);
+                        continue;
+                    }
+                }
 
-                slot.n_decoded += 1;
-                if (slot.n_decoded == 1) {
-                    slot.t_start_generation = ggml_time_us();
-                    slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
-                    metrics.on_prompt_eval(slot);
+                // check if the slot supports speculative decoding
+                if (!slot.can_speculate()) {
+                    continue;
                 }
 
-                result.tok = id;
+                struct common_speculative_params params_spec;
+                params_spec.n_draft   = slot.params.speculative.n_max;
+                params_spec.n_reuse   = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
+                params_spec.p_min     = slot.params.speculative.p_min;
 
-                const auto * cur_p = common_sampler_get_candidates(slot.smpl);
+                llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
 
-                for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) {
-                    result.probs.push_back({
-                        cur_p->data[i].id,
-                        i >= cur_p->size ? 0.0f : cur_p->data[i].p,
-                    });
+                // ignore small drafts
+                if (slot.params.speculative.n_min > (int) draft.size()) {
+                    continue;
                 }
 
-                if (!process_token(result, slot)) {
-                    // release slot because of stop condition
-                    slot.release();
-                    slot.print_timings();
-                    send_final_response(slot);
-                    metrics.on_prediction(slot);
+                // construct the speculation batch
+                common_batch_clear(slot.batch_spec);
+                common_batch_add  (slot.batch_spec, id, slot.n_past, { slot.id }, true);
+
+                for (size_t i = 0; i < draft.size(); ++i) {
+                    common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
+                }
+
+                llama_decode(ctx, slot.batch_spec);
+
+                // the accepted tokens from the speculation
+                const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
+
+                slot.n_past    += ids.size();
+                slot.n_decoded += ids.size();
+
+                slot.cache_tokens.push_back(id);
+                slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
+
+                llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
+
+                for (size_t i = 0; i < ids.size(); ++i) {
+                    completion_token_output result;
+
+                    result.tok = ids[i];
+
+                    if (!process_token(result, slot)) {
+                        // release slot because of stop condition
+                        slot.release();
+                        slot.print_timings();
+                        send_final_response(slot);
+                        metrics.on_prediction(slot);
+                        break;
+                    }
                 }
 
-                slot.i_batch = -1;
+                SRV_DBG("accepted %d/%d draft tokens\n", (int) ids.size() - 1, (int) draft.size());
             }
         }
 
@@ -2697,7 +2856,7 @@ int main(int argc, char ** argv) {
     const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
         json data = {
             { "default_generation_settings", ctx_server.default_generation_settings_for_props },
-            { "total_slots",                 ctx_server.params.n_parallel },
+            { "total_slots",                 ctx_server.params_base.n_parallel },
             { "chat_template",               llama_get_chat_template(ctx_server.model) },
         };
 
@@ -2705,7 +2864,7 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_props_change = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
-        if (!ctx_server.params.endpoint_props) {
+        if (!ctx_server.params_base.endpoint_props) {
             res_error(res, format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
@@ -2718,7 +2877,7 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_inf_type inf_type, json & data, httplib::Response & res) {
-        if (ctx_server.params.embedding) {
+        if (ctx_server.params_base.embedding) {
             res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
@@ -2824,7 +2983,7 @@ int main(int argc, char ** argv) {
 
     // TODO: maybe merge this function with "handle_completions_generic"
     const auto handle_chat_completions = [&ctx_server, &params, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) {
-        if (ctx_server.params.embedding) {
+        if (ctx_server.params_base.embedding) {
             res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
@@ -3001,7 +3160,7 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
-        if (!ctx_server.params.reranking || ctx_server.params.embedding) {
+        if (!ctx_server.params_base.reranking || ctx_server.params_base.embedding) {
             res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking` and without `--embedding`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }