#include "arg.h"
#include "common.h"
-#include "log.h"
-#include "sampling.h"
#include "json-schema-to-grammar.h"
#include "llama.h"
+#include "log.h"
+#include "sampling.h"
+#include "speculative.h"
// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this many milliseconds
std::vector<std::string> antiprompt;
+
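+    // per-request sampling and speculative-decoding parameters, seeded from the server-wide defaults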
+ struct common_params_sampling sampling;
+ struct common_params_speculative speculative;
};
struct server_slot {
int id;
int id_task = -1;
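+    // speculative decoding state (per slot): the draft batch, a dedicated draft-model context, and the speculator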
+ llama_batch batch_spec;
+
+ llama_context * ctx_dft = nullptr;
+
+ common_speculative * spec = nullptr;
+
    // the index of this result within a multi-task completion request
size_t index = 0;
// sampling
json json_schema;
- struct common_params_sampling sparams;
struct common_sampler * smpl = nullptr;
llama_token sampled;
generated_token_probs.clear();
}
- bool has_budget(common_params &global_params) {
+ bool has_budget(const common_params & global_params) {
if (params.n_predict == -1 && global_params.n_predict == -1) {
return true; // limitless
}
return state != SLOT_STATE_IDLE;
}
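+    // speculative decoding needs a per-slot draft context, a positive draft budget, and prompt caching
+    // (the speculator drafts from the slot's cached tokens)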
+ bool can_speculate() const {
+ return ctx_dft && params.speculative.n_max > 0 && params.cache_prompt;
+ }
+
void add_token(const completion_token_output & token) {
if (!is_processing()) {
SLT_WRN(*this, "%s", "slot is not processing\n");
};
struct server_context {
+ common_params params_base;
+
llama_model * model = nullptr;
llama_context * ctx = nullptr;
std::vector<common_lora_adapter_container> loras;
- common_params params;
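+    // draft model for speculative decoding; cparams_dft are the params used to create one draft context per slot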
+ llama_model * model_dft = nullptr;
+ llama_context_params cparams_dft;
llama_batch batch = {};
model = nullptr;
}
+ if (model_dft) {
+ llama_free_model(model_dft);
+ model_dft = nullptr;
+ }
+
        // free the per-slot sampling and speculative-decoding state
for (server_slot & slot : slots) {
- if (slot.smpl != nullptr) {
- common_sampler_free(slot.smpl);
- }
+ common_sampler_free(slot.smpl);
+ slot.smpl = nullptr;
+
+ llama_free(slot.ctx_dft);
+ slot.ctx_dft = nullptr;
+
+ common_speculative_free(slot.spec);
+ slot.spec = nullptr;
+
+ llama_batch_free(slot.batch_spec);
}
llama_batch_free(batch);
}
- bool load_model(const common_params & params_) {
- params = params_;
+ bool load_model(const common_params & params) {
+ SRV_INF("loading model '%s'\n", params.model.c_str());
- common_init_result llama_init = common_init_from_params(params);
+ params_base = params;
+
+ common_init_result llama_init = common_init_from_params(params_base);
model = llama_init.model;
ctx = llama_init.context;
loras = llama_init.lora_adapters;
if (model == nullptr) {
- SRV_ERR("failed to load model, '%s'\n", params.model.c_str());
+            SRV_ERR("failed to load model '%s'\n", params_base.model.c_str());
return false;
}
add_bos_token = llama_add_bos_token(model);
has_eos_token = !llama_add_eos_token(model);
+ if (!params_base.speculative.model.empty()) {
+ SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str());
+
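+            // start from the target model's params and override only the draft-specific fields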
+ auto params_dft = params_base;
+
+ params_dft.model = params_base.speculative.model;
+ params_dft.n_ctx = params_base.speculative.n_ctx;
+ params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
+
+ common_init_result llama_init_dft = common_init_from_params(params_dft);
+
+ model_dft = llama_init_dft.model;
+
+ if (model_dft == nullptr) {
+                SRV_ERR("failed to load draft model '%s'\n", params_base.speculative.model.c_str());
+ return false;
+ }
+
+ if (!common_speculative_are_compatible(ctx, llama_init_dft.context)) {
+ SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params_base.speculative.model.c_str(), params_base.model.c_str());
+
+ llama_free (llama_init_dft.context);
+ llama_free_model(llama_init_dft.model);
+
+ return false;
+ }
+
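+            // the draft contexts use the target's params, except the batch size spans the whole draft context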
+ cparams_dft = common_context_params_to_llama(params_base);
+ cparams_dft.n_batch = llama_n_ctx(llama_init_dft.context);
+
+            // this context is not needed any further - each slot creates its own draft context
+ llama_free(llama_init_dft.context);
+ }
+
return true;
}
}
void init() {
- const int32_t n_ctx_slot = n_ctx / params.n_parallel;
+ const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;
- SRV_INF("initializing slots, n_slots = %d\n", params.n_parallel);
+ SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);
- for (int i = 0; i < params.n_parallel; i++) {
+ for (int i = 0; i < params_base.n_parallel; i++) {
server_slot slot;
slot.id = i;
slot.n_ctx = n_ctx_slot;
- slot.n_predict = params.n_predict;
+ slot.n_predict = params_base.n_predict;
+
+ if (model_dft) {
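+                // room for n_max draft tokens plus the one token sampled from the target model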
+ slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
+
+ slot.ctx_dft = llama_new_context_with_model(model_dft, cparams_dft);
+ if (slot.ctx_dft == nullptr) {
+ SRV_ERR("%s", "failed to create draft context\n");
+ return;
+ }
+
+ slot.spec = common_speculative_init(slot.ctx_dft);
+ if (slot.spec == nullptr) {
+ SRV_ERR("%s", "failed to create speculator\n");
+ return;
+ }
+ }
SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
- slot.sparams = params.sampling;
+ slot.params.sampling = params_base.sampling;
slot.callback_on_release = [this](int) {
queue_tasks.pop_deferred_task();
const int32_t n_batch = llama_n_batch(ctx);
// only a single seq_id per token is needed
- batch = llama_batch_init(std::max(n_batch, params.n_parallel), 0, 1);
+ batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
}
metrics.init();
}
bool launch_slot_with_task(server_slot & slot, const server_task & task) {
- slot_params default_params;
// Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
- auto default_sparams = params.sampling;
+ slot_params defaults;
+ defaults.sampling = params_base.sampling;
+ defaults.speculative = params_base.speculative;
+
const auto & data = task.data;
if (data.count("__oaicompat") != 0) {
slot.oaicompat_model = "";
}
- slot.params.stream = json_value(data, "stream", false);
- slot.params.cache_prompt = json_value(data, "cache_prompt", false);
- slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict));
- slot.params.n_indent = json_value(data, "n_indent", default_params.n_indent);
- slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
- slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
- slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
- slot.sparams.xtc_probability = json_value(data, "xtc_probability", default_sparams.xtc_probability);
- slot.sparams.xtc_threshold = json_value(data, "xtc_threshold", default_sparams.xtc_threshold);
- slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
- slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
- slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
- slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
- slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
- slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
- slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
- slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present);
- slot.sparams.dry_multiplier = json_value(data, "dry_multiplier", default_sparams.dry_multiplier);
- slot.sparams.dry_base = json_value(data, "dry_base", default_sparams.dry_base);
- slot.sparams.dry_allowed_length = json_value(data, "dry_allowed_length", default_sparams.dry_allowed_length);
- slot.sparams.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", default_sparams.dry_penalty_last_n);
- slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
- slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
- slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
- slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
- slot.params.n_keep = json_value(data, "n_keep", default_params.n_keep);
- slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
- slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
- slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
- slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
- //slot.params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", default_params.t_max_prompt_ms); // TODO: implement
- slot.params.t_max_predict_ms = json_value(data, "t_max_predict_ms", default_params.t_max_predict_ms);
-
- if (slot.sparams.dry_base < 1.0f)
- {
- slot.sparams.dry_base = default_sparams.dry_base;
+ slot.params.stream = json_value(data, "stream", false);
+ slot.params.cache_prompt = json_value(data, "cache_prompt", false);
+ slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
+ slot.params.n_indent = json_value(data, "n_indent", defaults.n_indent);
+ slot.params.n_keep = json_value(data, "n_keep", defaults.n_keep);
+ slot.params.n_discard = json_value(data, "n_discard", defaults.n_discard);
+ //slot.params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
+ slot.params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
+
+ slot.params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
+ slot.params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
+ slot.params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p);
+ slot.params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability);
+ slot.params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold);
+ slot.params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p);
+ slot.params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp);
+ slot.params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range);
+ slot.params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent);
+ slot.params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n);
+ slot.params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat);
+ slot.params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq);
+ slot.params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present);
+ slot.params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier);
+ slot.params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base);
+ slot.params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length);
+ slot.params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n);
+ slot.params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat);
+ slot.params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau);
+ slot.params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
+ slot.params.sampling.penalize_nl = json_value(data, "penalize_nl", defaults.sampling.penalize_nl);
+ slot.params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
+ slot.params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
+ slot.params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
+
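+        // per-request speculative-decoding overrides (the dotted keys mirror the server defaults above)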
+ slot.params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
+ slot.params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
+ slot.params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
+
+ slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min);
+
+ if (slot.params.sampling.dry_base < 1.0f) {
+ slot.params.sampling.dry_base = defaults.sampling.dry_base;
}
// sequence breakers for DRY
// Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
if (data.contains("dry_sequence_breakers")) {
- slot.sparams.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
- if (slot.sparams.dry_sequence_breakers.empty()) {
+ slot.params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
+ if (slot.params.sampling.dry_sequence_breakers.empty()) {
send_error(task, "Error: dry_sequence_breakers must be a non-empty array of strings", ERROR_TYPE_INVALID_REQUEST);
return false;
}
}
if (data.contains("json_schema") && !data.contains("grammar")) {
try {
- auto schema = json_value(data, "json_schema", json::object());
- slot.sparams.grammar = json_schema_to_grammar(schema);
+ auto schema = json_value(data, "json_schema", json::object());
+ slot.params.sampling.grammar = json_schema_to_grammar(schema);
} catch (const std::exception & e) {
send_error(task, std::string("\"json_schema\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
return false;
}
} else {
- slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
+ slot.params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
}
if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
}
{
- slot.sparams.logit_bias.clear();
+ slot.params.sampling.logit_bias.clear();
if (json_value(data, "ignore_eos", false) && has_eos_token) {
- slot.sparams.logit_bias.push_back({llama_token_eos(model), -INFINITY});
+ slot.params.sampling.logit_bias.push_back({llama_token_eos(model), -INFINITY});
}
const auto & logit_bias = data.find("logit_bias");
if (el[0].is_number_integer()) {
llama_token tok = el[0].get<llama_token>();
if (tok >= 0 && tok < n_vocab) {
- slot.sparams.logit_bias.push_back({tok, bias});
+ slot.params.sampling.logit_bias.push_back({tok, bias});
}
} else if (el[0].is_string()) {
auto toks = common_tokenize(model, el[0].get<std::string>(), false);
for (auto tok : toks) {
- slot.sparams.logit_bias.push_back({tok, bias});
+ slot.params.sampling.logit_bias.push_back({tok, bias});
}
}
}
sampler_names.emplace_back(name);
}
}
- slot.sparams.samplers = common_sampler_types_from_names(sampler_names, false);
+ slot.params.sampling.samplers = common_sampler_types_from_names(sampler_names, false);
        } else if (samplers->is_string()) {
std::string sampler_string;
for (const auto & name : *samplers) {
sampler_string += name;
}
- slot.sparams.samplers = common_sampler_types_from_chars(sampler_string);
+ slot.params.sampling.samplers = common_sampler_types_from_chars(sampler_string);
}
} else {
- slot.sparams.samplers = default_sparams.samplers;
+ slot.params.sampling.samplers = defaults.sampling.samplers;
}
}
common_sampler_free(slot.smpl);
}
- slot.smpl = common_sampler_init(model, slot.sparams);
+ slot.smpl = common_sampler_init(model, slot.params.sampling);
if (slot.smpl == nullptr) {
// for now, the only error that may happen here is invalid grammar
send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
}
}
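+        // the speculation batch is sized per request, so rebuild it with this request's draft budget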
+ if (slot.ctx_dft) {
+ llama_batch_free(slot.batch_spec);
+
+ slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);
+ }
+
slot.state = SLOT_STATE_STARTED;
SLT_INF(slot, "%s", "processing task\n");
bool process_token(completion_token_output & result, server_slot & slot) {
// remember which tokens were sampled - used for repetition penalties during sampling
- const std::string token_str = common_token_to_piece(ctx, result.tok, params.special);
+ const std::string token_str = common_token_to_piece(ctx, result.tok, params_base.special);
slot.sampled = result.tok;
// search stop word and delete it
}
// check the limits
- if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params)) {
+ if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) {
slot.stopped_limit = true;
slot.has_next_token = false;
json get_formated_generation(const server_slot & slot) const {
std::vector<std::string> samplers;
- samplers.reserve(slot.sparams.samplers.size());
- for (const auto & sampler : slot.sparams.samplers) {
+ samplers.reserve(slot.params.sampling.samplers.size());
+ for (const auto & sampler : slot.params.sampling.samplers) {
samplers.emplace_back(common_sampler_type_to_str(sampler));
}
return json {
{"n_ctx", slot.n_ctx},
{"n_predict", slot.n_predict}, // Server configured n_predict
- {"model", params.model_alias},
- {"seed", slot.sparams.seed},
+ {"model", params_base.model_alias},
+ {"seed", slot.params.sampling.seed},
{"seed_cur", slot.smpl ? common_sampler_get_seed(slot.smpl) : 0},
- {"temperature", slot.sparams.temp},
- {"dynatemp_range", slot.sparams.dynatemp_range},
- {"dynatemp_exponent", slot.sparams.dynatemp_exponent},
- {"top_k", slot.sparams.top_k},
- {"top_p", slot.sparams.top_p},
- {"min_p", slot.sparams.min_p},
- {"xtc_probability", slot.sparams.xtc_probability},
- {"xtc_threshold", slot.sparams.xtc_threshold},
- {"typical_p", slot.sparams.typ_p},
- {"repeat_last_n", slot.sparams.penalty_last_n},
- {"repeat_penalty", slot.sparams.penalty_repeat},
- {"presence_penalty", slot.sparams.penalty_present},
- {"frequency_penalty", slot.sparams.penalty_freq},
- {"dry_multiplier", slot.sparams.dry_multiplier},
- {"dry_base", slot.sparams.dry_base},
- {"dry_allowed_length", slot.sparams.dry_allowed_length},
- {"dry_penalty_last_n", slot.sparams.dry_penalty_last_n},
- {"dry_sequence_breakers", slot.sparams.dry_sequence_breakers},
- {"mirostat", slot.sparams.mirostat},
- {"mirostat_tau", slot.sparams.mirostat_tau},
- {"mirostat_eta", slot.sparams.mirostat_eta},
- {"penalize_nl", slot.sparams.penalize_nl},
+ {"temperature", slot.params.sampling.temp},
+ {"dynatemp_range", slot.params.sampling.dynatemp_range},
+ {"dynatemp_exponent", slot.params.sampling.dynatemp_exponent},
+ {"top_k", slot.params.sampling.top_k},
+ {"top_p", slot.params.sampling.top_p},
+ {"min_p", slot.params.sampling.min_p},
+ {"xtc_probability", slot.params.sampling.xtc_probability},
+ {"xtc_threshold", slot.params.sampling.xtc_threshold},
+ {"typical_p", slot.params.sampling.typ_p},
+ {"repeat_last_n", slot.params.sampling.penalty_last_n},
+ {"repeat_penalty", slot.params.sampling.penalty_repeat},
+ {"presence_penalty", slot.params.sampling.penalty_present},
+ {"frequency_penalty", slot.params.sampling.penalty_freq},
+ {"dry_multiplier", slot.params.sampling.dry_multiplier},
+ {"dry_base", slot.params.sampling.dry_base},
+ {"dry_allowed_length", slot.params.sampling.dry_allowed_length},
+ {"dry_penalty_last_n", slot.params.sampling.dry_penalty_last_n},
+ {"dry_sequence_breakers", slot.params.sampling.dry_sequence_breakers},
+ {"mirostat", slot.params.sampling.mirostat},
+ {"mirostat_tau", slot.params.sampling.mirostat_tau},
+ {"mirostat_eta", slot.params.sampling.mirostat_eta},
+ {"penalize_nl", slot.params.sampling.penalize_nl},
{"stop", slot.params.antiprompt},
{"max_tokens", slot.params.n_predict}, // User configured n_predict
{"n_keep", slot.params.n_keep},
{"n_discard", slot.params.n_discard},
- {"ignore_eos", slot.sparams.ignore_eos},
+ {"ignore_eos", slot.params.sampling.ignore_eos},
{"stream", slot.params.stream},
- //{"logit_bias", slot.sparams.logit_bias},
- {"n_probs", slot.sparams.n_probs},
- {"min_keep", slot.sparams.min_keep},
- {"grammar", slot.sparams.grammar},
+ //{"logit_bias", slot.params.sampling.logit_bias},
+ {"n_probs", slot.params.sampling.n_probs},
+ {"min_keep", slot.params.sampling.min_keep},
+ {"grammar", slot.params.sampling.grammar},
{"samplers", samplers},
+ {"speculative", slot.can_speculate()},
+ {"speculative.n_max", slot.params.speculative.n_max},
+ {"speculative.n_min", slot.params.speculative.n_min},
+ {"speculative.p_min", slot.params.speculative.p_min},
};
}
{"index", slot.index},
};
- if (slot.sparams.n_probs > 0) {
+ if (slot.params.sampling.n_probs > 0) {
const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());
{"content", !slot.params.stream ? slot.generated_text : ""},
{"id_slot", slot.id},
{"stop", true},
- {"model", params.model_alias},
+ {"model", params_base.model_alias},
{"tokens_predicted", slot.n_decoded},
{"tokens_evaluated", slot.n_prompt_tokens},
{"generation_settings", get_formated_generation(slot)},
{"index", slot.index},
};
- if (slot.sparams.n_probs > 0) {
+ if (slot.params.sampling.n_probs > 0) {
std::vector<completion_token_output> probs;
if (!slot.params.stream && slot.stopped_word) {
const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);
data.at("input_prefix"),
data.at("input_suffix"),
data.at("input_extra"),
- params.n_batch,
- params.n_predict,
+ params_base.n_batch,
+ params_base.n_predict,
slots[0].n_ctx, // TODO: there should be a better way
- params.spm_infill,
+ params_base.spm_infill,
tokenized_prompts[i]
);
create_task(data, tokens);
// TODO: simplify and improve
for (server_slot & slot : slots) {
if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) {
- if (!params.ctx_shift) {
+ if (!params_base.ctx_shift) {
                // this check is redundant (which is fine)
                // we should never get here, because generation should have already stopped in process_token()
slot.release();
int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
// next, batch any pending prompts without exceeding n_batch
- if (params.cont_batching || batch.n_tokens == 0) {
+ if (params_base.cont_batching || batch.n_tokens == 0) {
for (auto & slot : slots) {
// this slot still has a prompt to be processed
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
continue;
}
} else {
- if (!params.ctx_shift) {
+ if (!params_base.ctx_shift) {
// if context shift is disabled, we make sure prompt size is smaller than KV size
                // TODO: there should be a separate parameter that controls prompt truncation
// context shift should be applied only during the generation phase
slot.n_past = common_lcp(slot.cache_tokens, prompt_tokens);
                // reuse chunks from the cached prompt by shifting their KV cache to the new position
- if (params.n_cache_reuse > 0) {
+ if (params_base.n_cache_reuse > 0) {
size_t head_c = slot.n_past; // cache
size_t head_p = slot.n_past; // current prompt
- SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params.n_cache_reuse, slot.n_past);
+ SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past);
while (head_c < slot.cache_tokens.size() &&
head_p < prompt_tokens.size()) {
n_match++;
}
- if (n_match >= (size_t) params.n_cache_reuse) {
+ if (n_match >= (size_t) params_base.n_cache_reuse) {
SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
//for (size_t i = head_p; i < head_p + n_match; i++) {
// SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
continue; // continue loop of slots
}
- completion_token_output result;
- const llama_token id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
+ llama_token id;
+
+ {
+ completion_token_output result;
+
+ id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
+
+ slot.i_batch = -1;
+
+ common_sampler_accept(slot.smpl, id, true);
+
+ slot.n_decoded += 1;
+ if (slot.n_decoded == 1) {
+ slot.t_start_generation = ggml_time_us();
+ slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
+ metrics.on_prompt_eval(slot);
+ }
+
+ result.tok = id;
+
+ const auto * cur_p = common_sampler_get_candidates(slot.smpl);
- common_sampler_accept(slot.smpl, id, true);
+                    // push at most n_probs candidates; cur_p may hold fewer after sampler truncation
+                    for (size_t i = 0; i < std::min((size_t) slot.params.sampling.n_probs, cur_p->size); ++i) {
+                        result.probs.push_back({
+                            cur_p->data[i].id,
+                            cur_p->data[i].p,
+                        });
+                    }
+
+ if (!process_token(result, slot)) {
+ // release slot because of stop condition
+ slot.release();
+ slot.print_timings();
+ send_final_response(slot);
+ metrics.on_prediction(slot);
+ continue;
+ }
+ }
- slot.n_decoded += 1;
- if (slot.n_decoded == 1) {
- slot.t_start_generation = ggml_time_us();
- slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
- metrics.on_prompt_eval(slot);
+ // check if the slot supports speculative decoding
+ if (!slot.can_speculate()) {
+ continue;
}
- result.tok = id;
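+                // per-step draft parameters; n_reuse is sized so that n_max new draft tokens still fit in the draft context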
+ struct common_speculative_params params_spec;
+ params_spec.n_draft = slot.params.speculative.n_max;
+ params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
+ params_spec.p_min = slot.params.speculative.p_min;
- const auto * cur_p = common_sampler_get_candidates(slot.smpl);
+ llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
- for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) {
- result.probs.push_back({
- cur_p->data[i].id,
- i >= cur_p->size ? 0.0f : cur_p->data[i].p,
- });
+ // ignore small drafts
+ if (slot.params.speculative.n_min > (int) draft.size()) {
+ continue;
}
- if (!process_token(result, slot)) {
- // release slot because of stop condition
- slot.release();
- slot.print_timings();
- send_final_response(slot);
- metrics.on_prediction(slot);
+ // construct the speculation batch
+ common_batch_clear(slot.batch_spec);
+ common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true);
+
+ for (size_t i = 0; i < draft.size(); ++i) {
+ common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
+ }
+
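+                // verify the sampled token and the full draft with a single decode of the target model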
+ llama_decode(ctx, slot.batch_spec);
+
+ // the accepted tokens from the speculation
+ const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
+
+ slot.n_past += ids.size();
+ slot.n_decoded += ids.size();
+
+ slot.cache_tokens.push_back(id);
+ slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
+
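+                // remove the KV cache entries of the draft tokens that were not accepted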
+ llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
+
+ for (size_t i = 0; i < ids.size(); ++i) {
+ completion_token_output result;
+
+ result.tok = ids[i];
+
+ if (!process_token(result, slot)) {
+ // release slot because of stop condition
+ slot.release();
+ slot.print_timings();
+ send_final_response(slot);
+ metrics.on_prediction(slot);
+ break;
+ }
}
- slot.i_batch = -1;
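+                // ids always includes the one token sampled by the target after the accepted run, hence the -1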
+ SRV_DBG("accepted %d/%d draft tokens\n", (int) ids.size() - 1, (int) draft.size());
}
}
const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
json data = {
{ "default_generation_settings", ctx_server.default_generation_settings_for_props },
- { "total_slots", ctx_server.params.n_parallel },
+ { "total_slots", ctx_server.params_base.n_parallel },
{ "chat_template", llama_get_chat_template(ctx_server.model) },
};
};
const auto handle_props_change = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
- if (!ctx_server.params.endpoint_props) {
+ if (!ctx_server.params_base.endpoint_props) {
res_error(res, format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
};
const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_inf_type inf_type, json & data, httplib::Response & res) {
- if (ctx_server.params.embedding) {
+ if (ctx_server.params_base.embedding) {
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
// TODO: maybe merge this function with "handle_completions_generic"
    const auto handle_chat_completions = [&ctx_server, &params, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) {
- if (ctx_server.params.embedding) {
+ if (ctx_server.params_base.embedding) {
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
};
const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
- if (!ctx_server.params.reranking || ctx_server.params.embedding) {
+ if (!ctx_server.params_base.reranking || ctx_server.params_base.embedding) {
res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking` and without `--embedding`", ERROR_TYPE_NOT_SUPPORTED));
return;
}