#include "sampling.h"
#include "speculative.h"
#include "mtmd.h"
-#include "mtmd-helper.h"
// MIME type for JSON responses
#define MIMETYPE_JSON "application/json; charset=utf-8"
if (only_metrics) {
return json {
- {"n_predict", n_predict}, // Server configured n_predict
{"seed", sampling.seed},
{"temperature", sampling.temp},
{"dynatemp_range", sampling.dynatemp_range},
{"mirostat", sampling.mirostat},
{"mirostat_tau", sampling.mirostat_tau},
{"mirostat_eta", sampling.mirostat_eta},
- {"max_tokens", n_predict}, // User configured n_predict
+ {"max_tokens", n_predict},
+ {"n_predict", n_predict}, // TODO: deduplicate?
{"n_keep", n_keep},
{"n_discard", n_discard},
{"ignore_eos", sampling.ignore_eos},
}
return json {
- {"n_predict", n_predict}, // Server configured n_predict
{"seed", sampling.seed},
{"temperature", sampling.temp},
{"dynatemp_range", sampling.dynatemp_range},
{"mirostat_tau", sampling.mirostat_tau},
{"mirostat_eta", sampling.mirostat_eta},
{"stop", antiprompt},
- {"max_tokens", n_predict}, // User configured n_predict
+ {"max_tokens", n_predict},
+ {"n_predict", n_predict}, // TODO: deduplicate?
{"n_keep", n_keep},
{"n_discard", n_discard},
{"ignore_eos", sampling.ignore_eos},
int id = -1; // to be filled by server_queue
int index = -1; // used when there are multiple prompts (batch request)
- server_task_type type;
-
// used by SERVER_TASK_TYPE_CANCEL
int id_target = -1;
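+ // the slot to use for this task (-1 = let the server pick an available slot)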
+ int id_slot = -1;
// used by SERVER_TASK_TYPE_INFERENCE
slot_params params;
- server_tokens prompt_tokens;
- int id_selected_slot = -1;
+ server_tokens tokens;
+
+ server_task_type type;
// used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE
struct slot_action {
// used by SERVER_TASK_TYPE_SET_LORA
std::vector<common_adapter_lora_info> set_lora;
+ server_task() = default;
+
server_task(server_task_type type) : type(type) {}
static slot_params params_from_json_cmpl(
defaults.sampling = params_base.sampling;
defaults.speculative = params_base.speculative;
defaults.n_keep = params_base.n_keep;
+ defaults.n_predict = params_base.n_predict;
defaults.antiprompt = params_base.antiprompt;
// enabling this will output extra debug information in the HTTP responses from the server
params.n_discard = json_value(data, "n_discard", defaults.n_discard);
//params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
- params.response_fields = json_value(data, "response_fields", std::vector<std::string>());
-
- params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
- params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
- params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p);
- params.sampling.top_n_sigma = json_value(data, "top_n_sigma", defaults.sampling.top_n_sigma);
- params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability);
- params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold);
- params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p);
- params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp);
- params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range);
- params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent);
- params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n);
- params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat);
- params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq);
- params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present);
- params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier);
- params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base);
- params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length);
- params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n);
- params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat);
- params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau);
- params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
- params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
- params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
- params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
+ params.response_fields = json_value(data, "response_fields", std::vector<std::string>());
+
+ params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
+ params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
+ params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p);
+ params.sampling.top_n_sigma = json_value(data, "top_n_sigma", defaults.sampling.top_n_sigma);
+ params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability);
+ params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold);
+ params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p);
+ params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp);
+ params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range);
+ params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent);
+ params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n);
+ params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat);
+ params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq);
+ params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present);
+ params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier);
+ params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base);
+ params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length);
+ params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n);
+ params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat);
+ params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau);
+ params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
+ params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
+ params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
+ params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);
params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
// using shared_ptr for polymorphism of server_task_result
using server_task_result_ptr = std::unique_ptr<server_task_result>;
-inline std::string stop_type_to_str(stop_type type) {
+static inline std::string stop_type_to_str(stop_type type) {
switch (type) {
case STOP_TYPE_EOS: return "eos";
case STOP_TYPE_WORD: return "word";
}
};
-struct ctx_checkpoint {
- llama_pos pos_min;
- llama_pos pos_max;
-
- std::vector<uint8_t> data;
-};
-
struct server_task_result_cmpl_final : server_task_result {
int index = 0;
slot_params generation_params;
// OAI-compat fields
- bool verbose = false;
- oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
- std::string oaicompat_model;
- std::string oaicompat_cmpl_id;
- common_chat_msg oaicompat_msg;
+ bool verbose = false;
+ oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
+ std::string oaicompat_model;
+ std::string oaicompat_cmpl_id;
+ common_chat_msg oaicompat_msg;
+
std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
virtual int get_index() override {
{ "save_ms", t_ms }
}},
};
- } else {
- return json {
- { "id_slot", id_slot },
- { "filename", filename },
- { "n_restored", n_tokens },
- { "n_read", n_bytes },
- { "timings", {
- { "restore_ms", t_ms }
- }},
- };
}
+
+ return json {
+ { "id_slot", id_slot },
+ { "filename", filename },
+ { "n_restored", n_tokens },
+ { "n_read", n_bytes },
+ { "timings", {
+ { "restore_ms", t_ms }
+ }},
+ };
}
};
}
};
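+// a checkpoint of the parts of a sequence's context state that cannot be rolled back (e.g. SWA memory), covering positions [pos_min, pos_max]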
+struct server_prompt_checkpoint {
+ llama_pos pos_min;
+ llama_pos pos_max;
+
+ std::vector<uint8_t> data;
+
+ size_t size() const {
+ return data.size();
+ }
+};
+
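+// a processed prompt: the tokens that produced it, the (optionally) serialized context state and any context checkpoints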
+struct server_prompt {
+ server_tokens tokens;
+
+ std::vector<uint8_t> data;
+
+ std::list<server_prompt_checkpoint> checkpoints;
+
+ size_t size() const {
+ size_t res = data.size();
+
+ for (const auto & checkpoint : checkpoints) {
+ res += checkpoint.size();
+ }
+
+ return res;
+ }
+
+ int n_tokens() const {
+ return tokens.size();
+ }
+};
+
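+// FIFO cache of processed prompts: the oldest entries are evicted when the size or token limits are exceeded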
+struct server_prompt_cache {
+ server_prompt_cache(int32_t limit_size_mib, size_t limit_tokens) {
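+ // a negative size limit is treated as "no limit" (stored internally as 0)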
+ this->limit_size = 1024ull*1024ull*(limit_size_mib < 0 ? 0 : limit_size_mib);
+ this->limit_tokens = limit_tokens;
+ }
+
+ std::list<server_prompt> states;
+
+ // in bytes, 0 = no limit
+ size_t limit_size = 0;
+
+ // in tokens, 0 = no limit
+ size_t limit_tokens = 0;
+
+ size_t size() const {
+ size_t res = 0;
+
+ for (const auto & state : states) {
+ res += state.size();
+ }
+
+ return res;
+ }
+
+ size_t n_tokens() const {
+ size_t res = 0;
+
+ for (const auto & state : states) {
+ res += state.n_tokens();
+ }
+
+ return res;
+ }
+
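+ // create a new cache entry for the given prompt and reserve memory for its serialized state
+ // returns nullptr if the prompt is already fully cached or if the allocation fails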
+ server_prompt * alloc(const server_prompt & prompt, size_t state_size) {
+ // first check if the current state is contained fully in the cache
+ for (auto it = states.begin(); it != states.end(); ++it) {
+ const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens);
+
+ if (cur_lcp_len == (int) prompt.tokens.size()) {
+ SRV_WRN("%s", " - prompt is already in the cache, skipping\n");
+ return nullptr;
+ }
+ }
+
+ // next, remove any cached prompts that are fully contained in the current prompt
+ for (auto it = states.begin(); it != states.end();) {
+ const int len = it->tokens.get_common_prefix(prompt.tokens);
+
+ if (len == (int) it->tokens.size()) {
+ SRV_WRN(" - removing obsolete cached prompt with length %d\n", len);
+
+ it = states.erase(it);
+ } else {
+ ++it;
+ }
+ }
+
+ std::vector<uint8_t> state_data;
+
+ // check if we can allocate enough memory for the new state
+ try {
+ state_data.resize(state_size);
+ } catch (const std::bad_alloc & e) {
+ SRV_ERR("failed to allocate memory for prompt cache state: %s\n", e.what());
+
+ limit_size = std::max<size_t>(1, 0.4*size());
+
+ SRV_WRN(" - cache size limit reduced to %.3f MiB\n", limit_size / (1024.0 * 1024.0));
+
+ update();
+
+ return nullptr;
+ }
+
+ // TODO: for some reason we can't copy server_tokens, so we have to do this workaround
+ auto & cur = states.emplace_back();
+ cur = {
+ /*.tokens =*/ server_tokens(prompt.tokens.get_text_tokens(), false),
+ /*.data =*/ std::move(state_data),
+ /*.checkpoints =*/ prompt.checkpoints,
+ };
+
+ return &cur;
+ }
+
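+ // look for a cached prompt that matches tokens_new better than the slot's current prompt
+ // if one is found, its state is restored into sequence id_slot and moved into `prompt`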
+ bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) {
+ const int lcp_best = prompt.tokens.get_common_prefix(tokens_new);
+
+ float f_keep_best = float(lcp_best) / prompt.tokens.size();
+ float sim_best = float(lcp_best) / tokens_new.size();
+
+ SRV_WRN(" - looking for better prompt, base f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
+
+ auto it_best = states.end();
+
+ // find the most similar cached prompt, that would also preserve the most context
+ for (auto it = states.begin(); it != states.end(); ++it) {
+ const int lcp_cur = it->tokens.get_common_prefix(tokens_new);
+
+ const float f_keep_cur = float(lcp_cur) / it->tokens.size();
+ const float sim_cur = float(lcp_cur) / tokens_new.size();
+
+ // don't trash large prompts
+ if (f_keep_cur < 0.25f) {
+ continue;
+ }
+
+ if (f_keep_best < f_keep_cur && sim_best < sim_cur) {
+ f_keep_best = f_keep_cur;
+ sim_best = sim_cur;
+
+ it_best = it;
+ }
+ }
+
+ if (it_best != states.end()) {
+ SRV_WRN(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
+
+ const size_t size = it_best->data.size();
+ const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id_slot, 0);
+ if (n != size) {
+ SRV_WRN("failed to restore state with size %zu\n", size);
+
+ return false;
+ }
+
+ it_best->data.clear();
+ it_best->data.shrink_to_fit();
+
+ prompt = std::move(*it_best);
+
+ states.erase(it_best);
+ }
+
+ return true;
+ }
+
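+ // enforce the size and token limits by evicting the oldest cached prompts (always keeping at least one)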
+ void update() {
+ if (limit_size > 0) {
+ // always keep at least one state, regardless of the limits
+ while (states.size() > 1 && size() > limit_size) {
+ if (states.empty()) {
+ break;
+ }
+
+ SRV_WRN(" - cache size limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0));
+
+ states.pop_front();
+ }
+ }
+
+ if (limit_tokens > 0) {
+ while (states.size() > 1 && n_tokens() > limit_tokens) {
+ if (states.empty()) {
+ break;
+ }
+
+ SRV_WRN(" - cache token limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0));
+
+ states.pop_front();
+ }
+ }
+
+ SRV_WRN(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens)\n",
+ states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens);
+
+ for (const auto & state : states) {
+ SRV_WRN(" - prompt %p: %7d tokens, checkpoints: %2zu, %9.3f MiB\n", (const void *)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0));
+ }
+ }
+};
+
struct server_slot {
int id;
- int id_task = -1;
-
- // only used for completion/embedding/infill/rerank
- server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;
llama_batch batch_spec = {};
+ // TODO: change to unique_ptrs for consistency:
llama_context * ctx = nullptr;
llama_context * ctx_dft = nullptr;
common_speculative * spec = nullptr;
- std::vector<common_adapter_lora_info> lora;
- int32_t alora_invocation_start = -1;
-
- // the index relative to completion multi-task request
- size_t index = 0;
-
- struct slot_params params;
-
- slot_state state = SLOT_STATE_IDLE;
+ std::unique_ptr<const server_task> task;
+ std::unique_ptr<const server_task> task_prev; // used for debugging
// used to determine the slot that has been used the longest
int64_t t_last_used = -1;
// generation props
int32_t n_ctx = 0; // context size per slot
int32_t n_past = 0;
+ int32_t n_keep = 0;
int32_t n_decoded = 0;
int32_t n_remaining = -1;
int32_t i_batch = -1;
- int32_t n_predict = -1; // TODO: disambiguate from params.n_predict
- // n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated
- int32_t n_prompt_tokens = 0;
int32_t n_prompt_tokens_cache = 0;
int32_t n_prompt_tokens_processed = 0;
- // input prompt tokens
- server_tokens prompt_tokens;
+ int32_t n_prompt_tokens() const {
+ return task->tokens.size();
+ }
size_t last_nl_pos = 0;
std::string generated_text;
llama_tokens generated_tokens;
- common_chat_msg chat_msg;
- server_tokens cache_tokens;
+ common_chat_msg chat_msg;
std::vector<completion_token_output> generated_token_probs;
- std::vector<ctx_checkpoint> ctx_checkpoints;
-
bool has_next_token = true;
bool has_new_line = false;
bool truncated = false;
+
stop_type stop;
std::string stopping_word;
+ // state
+ slot_state state = SLOT_STATE_IDLE;
+
+ server_prompt prompt;
+
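+ // serialize the slot's current context state and store it in the prompt cache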
+ void prompt_save(server_prompt_cache & prompt_cache) const {
+ assert(prompt.data.size() == 0);
+
+ const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0);
+
+ SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB\n",
+ (int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0));
+
+ auto * cur = prompt_cache.alloc(prompt, cur_size);
+ if (cur == nullptr) {
+ return;
+ }
+
+ llama_state_seq_get_data_ext(ctx, cur->data.data(), cur_size, id, 0);
+ }
+
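+ // try to replace the slot's prompt with a cached prompt that better matches the given tokens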
+ void prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) {
+ bool res = prompt_cache.load(prompt, tokens, ctx, id);
+ if (!res) {
+ SLT_WRN(*this, "%s", "failed to load prompt from cache\n");
+ }
+ }
+
+ std::vector<common_adapter_lora_info> lora;
+ int32_t alora_invocation_start = -1;
+
// sampling
json json_schema;
std::vector<std::string> generated_tool_call_ids;
// stats
- size_t n_sent_text = 0; // number of sent text character
+ size_t n_sent_text = 0; // number of text characters sent to the client
int64_t t_start_process_prompt;
int64_t t_start_generation;
void reset() {
SLT_DBG(*this, "%s", "\n");
- n_prompt_tokens = 0;
n_prompt_tokens_cache = 0;
- last_nl_pos = 0;
- generated_text = "";
- has_new_line = false;
- truncated = false;
- stop = STOP_TYPE_NONE;
- stopping_word = "";
- n_past = 0;
- n_sent_text = 0;
- task_type = SERVER_TASK_TYPE_COMPLETION;
- chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
+ last_nl_pos = 0;
+ generated_text = "";
+ has_new_line = false;
+ truncated = false;
+ stop = STOP_TYPE_NONE;
+ stopping_word = "";
+ n_past = 0;
+ n_sent_text = 0;
+ chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
generated_tokens.clear();
generated_token_probs.clear();
n_draft_total = 0;
n_draft_accepted = 0;
+ task.reset();
+ task_prev.reset();
+
// clear alora start
alora_invocation_start = -1;
}
bool need_embd() const {
- return server_task_type_need_embd(task_type);
+ GGML_ASSERT(task);
+
+ return server_task_type_need_embd(task->type);
}
bool need_logits() const {
- return server_task_type_need_logits(task_type);
+ GGML_ASSERT(task);
+
+ return server_task_type_need_logits(task->type);
}
// if the context does not have a memory module then all embeddings have to be computed within a single ubatch
}
bool can_batch_with(server_slot & other_slot) const {
- return task_type == other_slot.task_type && are_lora_equal(lora, other_slot.lora);
+ GGML_ASSERT(task);
+
+ return task->type == other_slot.task->type && are_lora_equal(lora, other_slot.lora);
}
bool has_budget(const common_params & global_params) {
- if (params.n_predict == -1 && global_params.n_predict == -1) {
+ GGML_ASSERT(task);
+
+ if (task->params.n_predict == -1 && global_params.n_predict == -1) {
return true; // limitless
}
n_remaining = -1;
- if (params.n_predict != -1) {
- n_remaining = params.n_predict - n_decoded;
+ if (task->params.n_predict != -1) {
+ n_remaining = task->params.n_predict - n_decoded;
} else if (global_params.n_predict != -1) {
n_remaining = global_params.n_predict - n_decoded;
}
}
bool can_speculate() const {
- return ctx_dft && params.speculative.n_max > 0 && params.cache_prompt;
+ return ctx_dft;
}
void add_token(const completion_token_output & token) {
void release() {
if (is_processing()) {
+ GGML_ASSERT(task);
+
SLT_INF(*this, "stop processing: n_past = %d, truncated = %d\n", n_past, truncated);
t_last_used = ggml_time_us();
t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
state = SLOT_STATE_IDLE;
+
+ task_prev = std::move(task);
+ task.reset();
+
callback_on_release(id);
}
}
result_timings timings;
timings.cache_n = n_prompt_tokens_cache;
- timings.prompt_n = n_prompt_tokens_processed;
- timings.prompt_ms = t_prompt_processing;
+ timings.prompt_n = n_prompt_tokens_processed;
+ timings.prompt_ms = t_prompt_processing;
timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed;
- timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
+ timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
- timings.predicted_n = n_decoded;
- timings.predicted_ms = t_token_generation;
+ timings.predicted_n = n_decoded;
+ timings.predicted_ms = t_token_generation;
timings.predicted_per_token_ms = t_token_generation / n_decoded;
- timings.predicted_per_second = 1e3 / t_token_generation * n_decoded;
+ timings.predicted_per_second = 1e3 / t_token_generation * n_decoded;
// Add speculative metrics
if (n_draft_total > 0) {
- timings.draft_n = n_draft_total;
+ timings.draft_n = n_draft_total;
timings.draft_n_accepted = n_draft_accepted;
}
}
const common_chat_msg & update_chat_msg(std::vector<common_chat_msg_diff> & diffs) {
+ GGML_ASSERT(task);
+
auto previous_msg = chat_msg;
SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
auto new_msg = common_chat_parse(
generated_text,
/* is_partial= */ stop != STOP_TYPE_EOS,
- params.oaicompat_chat_syntax);
+ task->params.oaicompat_chat_syntax);
if (!new_msg.empty()) {
- new_msg.ensure_tool_call_ids_set(generated_tool_call_ids, gen_tool_call_id);
+ new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id);
chat_msg = new_msg;
diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg);
}
}
size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
+ GGML_ASSERT(task);
+
size_t stop_pos = std::string::npos;
- for (const std::string & word : params.antiprompt) {
+ for (const std::string & word : task->params.antiprompt) {
size_t pos;
if (is_full_stop) {
}
json to_json(bool only_metrics = false) const {
- if (only_metrics) {
- return json {
- {"id", id},
- {"id_task", id_task},
- {"n_ctx", n_ctx},
- {"speculative", can_speculate()},
- {"is_processing", is_processing()},
- {"params", params.to_json(true)},
- {"next_token",
- {
- {"has_next_token", has_next_token},
- {"has_new_line", has_new_line},
- {"n_remain", n_remaining},
- {"n_decoded", n_decoded},
- }
- },
- };
- }
+ json res;
- return json {
+ res = {
{"id", id},
- {"id_task", id_task},
{"n_ctx", n_ctx},
{"speculative", can_speculate()},
{"is_processing", is_processing()},
- {"params", params.to_json()},
- {"prompt", prompt_tokens.detokenize(ctx, true)},
- {"next_token",
+ };
+
+ const auto & ptask = task ? task : task_prev;
+
+ if (ptask) {
+ res["id_task"] = ptask->id;
+ res["params"] = ptask->params.to_json(only_metrics);
+ res["next_token"] = {
{
{"has_next_token", has_next_token},
{"has_new_line", has_new_line},
{"n_remain", n_remaining},
{"n_decoded", n_decoded},
- {"stopping_word", stopping_word},
}
- },
- };
+ };
+
+ if (!only_metrics) {
+ res["prompt"] = ptask->tokens.detokenize(ctx, true);
+ res["generated"] = generated_text;
+ }
+ }
+
+ return res;
}
};
// slots / clients
std::vector<server_slot> slots;
- json default_generation_settings_for_props;
+
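+ // when non-zero, slot state reports include full details (prompt, generated text); set via LLAMA_SERVER_SLOTS_DEBUG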
+ int slots_debug = 0;
server_queue queue_tasks;
server_response queue_results;
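+ // in-memory prompt cache shared by all slots (see --cache-ram)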
+ std::unique_ptr<server_prompt_cache> prompt_cache;
+
server_metrics metrics;
// Necessary similarity of prompt for slot selection
slot.id = i;
slot.ctx = ctx;
slot.n_ctx = n_ctx_slot;
- slot.n_predict = params_base.n_predict;
slot.mctx = mctx;
- slot.cache_tokens.has_mtmd = mctx != nullptr;
+ slot.prompt.tokens.has_mtmd = mctx != nullptr;
if (model_dft) {
slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
SRV_ERR("%s", "failed to create speculator\n");
return;
}
- for (auto &pair : params_base.speculative.replacements) {
+ for (auto & pair : params_base.speculative.replacements) {
common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str());
}
}
SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
- slot.params.sampling = params_base.sampling;
- slot.params.n_keep = params_base.n_keep;
-
slot.callback_on_release = [this](int) {
queue_tasks.pop_deferred_task();
};
slots.push_back(std::move(slot));
}
- default_generation_settings_for_props = slots[0].to_json();
+ {
+ const char * LLAMA_SERVER_SLOTS_DEBUG = getenv("LLAMA_SERVER_SLOTS_DEBUG");
+ slots_debug = LLAMA_SERVER_SLOTS_DEBUG ? atoi(LLAMA_SERVER_SLOTS_DEBUG) : 0;
+
+ if (slots_debug) {
+ SRV_WRN("slots debug = %d\n", slots_debug);
+ }
+ }
// the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens
// note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
metrics.init();
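+ // initialize the prompt cache: --cache-ram sets the size limit in MiB (0 disables the cache, negative values mean no limit)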
+ if (params_base.cache_ram_mib != 0) {
+ if (params_base.cache_ram_mib < 0) {
+ SRV_WRN("prompt cache is enabled, size limit: %s\n", "no limit");
+ } else {
+ SRV_WRN("prompt cache is enabled, size limit: %d MiB\n", params_base.cache_ram_mib);
+ }
+ SRV_WRN("%s", "use `--cache-ram 0` to disable the prompt cache\n");
+
+ prompt_cache = std::make_unique<server_prompt_cache>(params_base.cache_ram_mib, n_ctx);
+ } else {
+ SRV_WRN("%s", "prompt cache is disabled - use `--cache-ram N` to enable it\n");
+ }
+ SRV_WRN("%s", "for more info see https://github.com/ggml-org/llama.cpp/pull/16391\n");
+
// thinking is enabled if:
// 1. It's not explicitly disabled (reasoning_budget == 0)
// 2. The chat template supports it
const bool enable_thinking = params_base.use_jinja && params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get());
- SRV_INF("Enable thinking? %d\n", enable_thinking);
+ SRV_INF("thinking = %d\n", enable_thinking);
oai_parser_opt = {
/* use_jinja */ params_base.use_jinja,
server_slot * get_available_slot(const server_task & task) {
server_slot * ret = nullptr;
+ bool update_cache = false;
+
// find the slot that has at least n% prompt similarity
if (ret == nullptr && slot_prompt_similarity != 0.0f) {
- int lcs_len = 0;
- float similarity = 0;
+ float sim_best = 0;
for (server_slot & slot : slots) {
// skip the slot if it is not available
continue;
}
+ const auto & tokens = slot.prompt.tokens;
+
// skip the slot if it does not contain cached tokens
- if (slot.cache_tokens.empty()) {
+ if (tokens.empty()) {
continue;
}
- // length of the Longest Common Subsequence between the current slot's prompt and the input prompt
- int cur_lcs_len = slot.cache_tokens.get_common_prefix(task.prompt_tokens);
-
- // fraction of the common subsequence length compared to the current slot's prompt length
- float cur_similarity = static_cast<float>(cur_lcs_len) / static_cast<int>(slot.cache_tokens.size());
+ // fraction of the Longest Common Prefix length with respect to the input prompt length
+ const float sim_cur = float(tokens.get_common_prefix(task.tokens)) / task.tokens.size();
// select the current slot if the criteria match
- if (cur_lcs_len > lcs_len && cur_similarity > slot_prompt_similarity) {
- lcs_len = cur_lcs_len;
- similarity = cur_similarity;
+ if (sim_cur > sim_best && sim_cur > slot_prompt_similarity) {
+ sim_best = sim_cur;
+
ret = &slot;
}
}
if (ret != nullptr) {
- SLT_INF(*ret, "selected slot by lcs similarity, lcs_len = %d, similarity = %.3f (> %.3f thold)\n", lcs_len, similarity, slot_prompt_similarity);
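+ // fraction of the slot's currently cached prompt that would be reused for the new task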
+ const float f_keep = (sim_best*task.tokens.size()) / ret->prompt.tokens.size();
+
+ SLT_INF(*ret, "selected slot by LCP similarity, sim_best = %.3f (> %.3f thold), f_keep = %.3f\n",
+ sim_best, slot_prompt_similarity, f_keep);
+
+ // if we are about to lose a large portion of the existing context - save it in the prompt cache
+ if (f_keep < 0.5f) {
+ update_cache = true;
+ }
}
}
if (ret != nullptr) {
SLT_INF(*ret, "selected slot by LRU, t_last = %" PRId64 "\n", t_last);
+
+ update_cache = true;
+ }
+ }
+
+ if (ret) {
+ const auto & tokens = ret->prompt.tokens;
+
+ update_cache = update_cache && prompt_cache;
+
+ // cache prompts only for completion tasks
+ update_cache = update_cache && task.type == SERVER_TASK_TYPE_COMPLETION;
+
+ // don't update the cache if the slot's context is empty
+ update_cache = update_cache && tokens.size() > 0;
+
+ // TODO: mtmd does not support prompt cache
+ update_cache = update_cache && (ret->mctx == nullptr);
+
+ if (update_cache) {
+ SRV_WRN("%s", "updating prompt cache\n");
+
+ const int64_t t_start = ggml_time_us();
+
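+ // save the slot's current prompt state to the cache, then try to load a cached prompt
+ // that better matches the incoming task tokens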
+ ret->prompt_save(*prompt_cache);
+ ret->prompt_load(*prompt_cache, task.tokens);
+
+ prompt_cache->update();
+
+ SRV_WRN("prompt cache update took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0);
}
}
bool launch_slot_with_task(server_slot & slot, server_task && task) {
slot.reset();
- slot.id_task = task.id;
- slot.index = task.index;
- slot.task_type = task.type;
- slot.params = std::move(task.params);
- slot.prompt_tokens = std::move(task.prompt_tokens);
- if (!are_lora_equal(slot.params.lora, slot.lora)) {
+ if (!are_lora_equal(task.params.lora, slot.lora)) {
// if lora has changed, check to see if the cache should be cleared
- if (lora_should_clear_cache(slot.lora, slot.params.lora)) {
- SLT_INF(slot, "clearing cache for lora change. %zu loras -> %zu loras\n", slot.lora.size(), slot.params.lora.size());
- slot.cache_tokens.clear();
+ if (lora_should_clear_cache(slot.lora, task.params.lora)) {
+ SLT_INF(slot, "clearing cache for lora change. %zu loras -> %zu loras\n", slot.lora.size(), task.params.lora.size());
+ slot.prompt.tokens.clear();
} else {
- SLT_INF(slot, "keeping cache for alora. %zu target loras\n", slot.params.lora.size());
+ SLT_INF(slot, "keeping cache for alora. %zu target loras\n", task.params.lora.size());
}
- slot.lora = slot.params.lora;
+ slot.lora = task.params.lora;
}
// if using alora, make sure it's only a single one requested and active
- size_t alora_invocation_start = slot.prompt_tokens.size();
+ size_t alora_invocation_start = task.tokens.size();
if (lora_all_alora(slot.lora)) {
-
const auto & enabled_ids = lora_get_enabled_ids(slot.lora);
// TODO: This will error out if a user requests two aloras, but only
// provides the activation string for one. We could instead scan
// backwards through the prompt tokens to find the last
// occurrence of the invocation sequence
int match_idx = static_cast<int>(n_invocation_tokens) - 1;
- for (int i = slot.prompt_tokens.size() - 1; i >= 0; --i) {
+ for (int i = task.tokens.size() - 1; i >= 0; --i) {
// the token in this position matches the next token to find in
// the invocation sequence
- if (slot.prompt_tokens[i] == invocation_tokens[match_idx]) {
+ if (task.tokens[i] == invocation_tokens[match_idx]) {
// if it's a full match, we've found the start
if (match_idx == 0) {
alora_invocation_start = i;
}
// if the activation string is not found, disable the alora
- if (alora_invocation_start == slot.prompt_tokens.size()) {
+ if (alora_invocation_start == task.tokens.size()) {
SLT_DBG(slot, "alora %zu requested, but not found. deactivating\n", enabled_ids[0]);
slot.lora[enabled_ids[0]].scale = 0.0f;
} else {
}
}
- if (!slot.prompt_tokens.validate(ctx)) {
+ if (!task.tokens.validate(ctx)) {
send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST);
return false;
}
- SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
- if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
- // Might be better to reject the request with a 400 ?
- SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d\n", slot.params.n_predict, slot.n_predict);
- slot.params.n_predict = slot.n_predict;
- }
+ SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
+ // initialize samplers
{
if (slot.smpl != nullptr) {
common_sampler_free(slot.smpl);
}
- slot.smpl = common_sampler_init(model, slot.params.sampling);
+ slot.smpl = common_sampler_init(model, task.params.sampling);
if (slot.smpl == nullptr) {
// for now, the only error that may happen here is invalid grammar
send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
}
}
+ // initialize draft batch
if (slot.ctx_dft) {
llama_batch_free(slot.batch_spec);
- slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);
+ slot.batch_spec = llama_batch_init(task.params.speculative.n_max + 1, 0, 1);
}
+ slot.task = std::make_unique<const server_task>(std::move(task));
+
slot.state = SLOT_STATE_STARTED;
SLT_INF(slot, "%s", "processing task\n");
slot.sampled = result.tok;
slot.generated_text += token_str;
- if (slot.params.return_tokens) {
+ if (slot.task->params.return_tokens) {
slot.generated_tokens.push_back(result.tok);
}
slot.has_next_token = true;
}
slot.add_token(result);
- if (slot.params.stream) {
+ if (slot.task->params.stream) {
send_partial_response(slot, result, false);
}
}
slot.stop = STOP_TYPE_LIMIT;
slot.has_next_token = false;
- SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
+ SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.task->params.n_predict);
}
if (slot.has_new_line) {
// require that each new line has a whitespace prefix (i.e. indentation) of at least slot.task->params.n_indent
- if (slot.params.n_indent > 0) {
+ if (slot.task->params.n_indent > 0) {
// check the current indentation
// TODO: improve by not doing it more than once for each new line
if (slot.last_nl_pos > 0) {
pos++;
}
- if (pos < slot.generated_text.size() && n_indent < slot.params.n_indent) {
+ if (pos < slot.generated_text.size() && n_indent < slot.task->params.n_indent) {
slot.stop = STOP_TYPE_LIMIT;
slot.has_next_token = false;
slot.has_new_line = true;
// if we have seen a new line, we stop after a certain time limit, but only upon another new line
- if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) {
+ if (slot.task->params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.task->params.t_max_predict_ms)) {
slot.stop = STOP_TYPE_LIMIT;
slot.has_next_token = false;
- SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms);
+ SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.task->params.t_max_predict_ms);
}
}
slot.has_next_token = false;
SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n",
- slot.n_decoded, slot.n_prompt_tokens, slot.n_past, slot.n_ctx);
+ slot.n_decoded, slot.n_prompt_tokens(), slot.n_past, slot.n_ctx);
}
if (llama_vocab_is_eog(vocab, result.tok)) {
const auto n_ctx_train = llama_model_n_ctx_train(model);
- if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
+ if (slot.task->params.n_predict < 1 && slot.n_prompt_tokens() + slot.n_decoded >= n_ctx_train) {
slot.truncated = true;
slot.stop = STOP_TYPE_LIMIT;
slot.has_next_token = false; // stop prediction
SLT_WRN(slot,
"n_predict (%d) is set for infinite generation. "
"Limiting generated tokens to n_ctx_train (%d) to avoid EOS-less generation infinite loop\n",
- slot.params.n_predict, n_ctx_train);
+ slot.task->params.n_predict, n_ctx_train);
}
SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str());
}
void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const {
- size_t n_probs = slot.params.sampling.n_probs;
+ size_t n_probs = slot.task->params.sampling.n_probs;
size_t n_vocab = llama_vocab_n_tokens(vocab);
if (post_sampling) {
}
void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
- send_error(slot.id_task, error, type, slot.n_prompt_tokens, slot.n_ctx);
+ send_error(slot.task->id, error, type, slot.n_prompt_tokens(), slot.n_ctx);
}
void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0) {
}
// if multimodal is enabled, send an error and return false
- bool ensure_no_mtmd(const int id_task) {
+ bool check_no_mtmd(const int id_task) {
if (mctx) {
send_error(id_task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED);
return false;
void send_partial_response(server_slot & slot, const completion_token_output & tkn, bool is_progress) {
auto res = std::make_unique<server_task_result_cmpl_partial>();
- res->id = slot.id_task;
- res->index = slot.index;
+ res->id = slot.task->id;
+ res->index = slot.task->index;
if (is_progress) {
res->is_progress = true;
- res->progress.total = slot.n_prompt_tokens;
+ res->progress.total = slot.n_prompt_tokens();
res->progress.cache = slot.n_prompt_tokens_cache;
- res->progress.processed = slot.cache_tokens.size();
+ res->progress.processed = slot.prompt.tokens.size();
res->progress.time_ms = (ggml_time_us() - slot.t_start_process_prompt) / 1000;
} else {
res->content = tkn.text_to_send;
}
res->n_decoded = slot.n_decoded;
- res->n_prompt_tokens = slot.n_prompt_tokens;
- res->post_sampling_probs = slot.params.post_sampling_probs;
+ res->n_prompt_tokens = slot.n_prompt_tokens();
+ res->post_sampling_probs = slot.task->params.post_sampling_probs;
- res->verbose = slot.params.verbose;
- res->oaicompat = slot.params.oaicompat;
- res->oaicompat_model = slot.params.oaicompat_model;
- res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
+ res->verbose = slot.task->params.verbose;
+ res->oaicompat = slot.task->params.oaicompat;
+ res->oaicompat_model = slot.task->params.oaicompat_model;
+ res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id;
// populate res.probs_output
- if (slot.params.sampling.n_probs > 0) {
+ if (slot.task->params.sampling.n_probs > 0) {
res->prob_output = tkn; // copy the token probs
}
// populate timings if this is final response or timings_per_token is enabled
- if (slot.stop != STOP_TYPE_NONE || slot.params.timings_per_token) {
+ if (slot.stop != STOP_TYPE_NONE || slot.task->params.timings_per_token) {
res->timings = slot.get_timings();
}
void send_final_response(server_slot & slot) {
auto res = std::make_unique<server_task_result_cmpl_final>();
- res->id = slot.id_task;
- res->id_slot = slot.id;
- res->index = slot.index;
+ res->id = slot.task->id;
+ res->id_slot = slot.id;
+
+ res->index = slot.task->index;
res->content = slot.generated_text;
res->tokens = std::move(slot.generated_tokens);
res->timings = slot.get_timings();
- res->prompt = slot.prompt_tokens.detokenize(ctx, true);
- res->response_fields = std::move(slot.params.response_fields);
+ res->prompt = slot.task->tokens.detokenize(ctx, true);
+ res->response_fields = std::move(slot.task->params.response_fields);
res->truncated = slot.truncated;
res->n_decoded = slot.n_decoded;
- res->n_prompt_tokens = slot.n_prompt_tokens;
+ res->n_prompt_tokens = slot.n_prompt_tokens();
res->n_tokens_cached = slot.n_past;
res->has_new_line = slot.has_new_line;
res->stopping_word = slot.stopping_word;
res->stop = slot.stop;
- res->post_sampling_probs = slot.params.post_sampling_probs;
+ res->post_sampling_probs = slot.task->params.post_sampling_probs;
- res->verbose = slot.params.verbose;
- res->stream = slot.params.stream;
- res->include_usage = slot.params.include_usage;
- res->oaicompat = slot.params.oaicompat;
- res->oaicompat_model = slot.params.oaicompat_model;
- res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
- res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs);
+ res->verbose = slot.task->params.verbose;
+ res->stream = slot.task->params.stream;
+ res->include_usage = slot.task->params.include_usage;
+ res->oaicompat = slot.task->params.oaicompat;
+ res->oaicompat_model = slot.task->params.oaicompat_model;
+ res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id;
+ res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs);
// populate res.probs_output
- if (slot.params.sampling.n_probs > 0) {
- if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) {
+ if (slot.task->params.sampling.n_probs > 0) {
+ if (!slot.task->params.stream && slot.stop == STOP_TYPE_WORD) {
const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);
size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
}
}
- res->generation_params = slot.params; // copy the parameters
+ res->generation_params = slot.task->params; // copy the parameters
queue_results.send(std::move(res));
}
void send_embedding(const server_slot & slot, const llama_batch & batch) {
auto res = std::make_unique<server_task_result_embd>();
- res->id = slot.id_task;
- res->index = slot.index;
- res->n_tokens = slot.n_prompt_tokens;
- res->oaicompat = slot.params.oaicompat;
+ res->id = slot.task->id;
+ res->index = slot.task->index;
+ res->n_tokens = slot.n_prompt_tokens();
+ res->oaicompat = slot.task->params.oaicompat;
const int n_embd = llama_model_n_embd(model);
// normalize only when there is pooling
if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
- common_embd_normalize(embd, embd_res.data(), n_embd, slot.params.embd_normalize);
+ common_embd_normalize(embd, embd_res.data(), n_embd, slot.task->params.embd_normalize);
res->embedding.push_back(embd_res);
break;
- } else {
- res->embedding.emplace_back(embd, embd + n_embd);
}
+
+ res->embedding.emplace_back(embd, embd + n_embd);
}
SLT_DBG(slot, "%s", "sending embeddings\n");
void send_rerank(const server_slot & slot, const llama_batch & batch) {
auto res = std::make_unique<server_task_result_rerank>();
- res->id = slot.id_task;
- res->index = slot.index;
- res->n_tokens = slot.n_prompt_tokens;
+ res->id = slot.task->id;
+ res->index = slot.task->index;
+ res->n_tokens = slot.n_prompt_tokens();
for (int i = 0; i < batch.n_tokens; ++i) {
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
case SERVER_TASK_TYPE_EMBEDDING:
case SERVER_TASK_TYPE_RERANK:
{
- const int id_slot = task.id_selected_slot;
+ const int id_slot = task.id_slot;
server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task);
{
// release slot linked with the task id
for (auto & slot : slots) {
- if (slot.id_task == task.id_target) {
+ if (slot.task && slot.task->id == task.id_target) {
slot.release();
break;
}
int n_processing_slots = 0;
for (server_slot & slot : slots) {
- json slot_data = slot.to_json(true);
+ json slot_data = slot.to_json(slots_debug == 0);
if (slot.is_processing()) {
n_processing_slots++;
} break;
case SERVER_TASK_TYPE_SLOT_SAVE:
{
- if (!ensure_no_mtmd(task.id)) {
+ if (!check_no_mtmd(task.id)) {
break;
}
break;
}
- const size_t token_count = slot->cache_tokens.size();
+ const size_t token_count = slot->prompt.tokens.size();
const int64_t t_start = ggml_time_us();
std::string filename = task.slot_action.filename;
std::string filepath = task.slot_action.filepath;
- const llama_tokens & tokens = slot->cache_tokens.get_text_tokens();
+ const llama_tokens & tokens = slot->prompt.tokens.get_text_tokens();
const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count);
const int64_t t_end = ggml_time_us();
} break;
case SERVER_TASK_TYPE_SLOT_RESTORE:
{
- if (!ensure_no_mtmd(task.id)) break;
+ if (!check_no_mtmd(task.id)) break;
int id_slot = task.slot_action.slot_id;
server_slot * slot = get_slot_by_id(id_slot);
if (slot == nullptr) {
size_t token_count = 0;
size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count);
if (nread == 0) {
- slot->cache_tokens.clear(); // KV may already been invalidated?
+ slot->prompt.tokens.clear(); // KV may have already been invalidated?
send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
break;
}
tokens.resize(token_count);
- slot->cache_tokens.clear();
- slot->cache_tokens.insert(tokens);
+ slot->prompt.tokens.clear();
+ slot->prompt.tokens.insert(tokens);
const int64_t t_end = ggml_time_us();
const double t_restore_ms = (t_end - t_start) / 1000.0;
} break;
case SERVER_TASK_TYPE_SLOT_ERASE:
{
- if (!ensure_no_mtmd(task.id)) break;
+ if (!check_no_mtmd(task.id)) {
+ break;
+ }
int id_slot = task.slot_action.slot_id;
server_slot * slot = get_slot_by_id(id_slot);
if (slot == nullptr) {
}
// Erase token cache
- const size_t n_erased = slot->cache_tokens.size();
+ const size_t n_erased = slot->prompt.tokens.size();
llama_memory_seq_rm(llama_get_memory(ctx), slot->id, -1, -1);
- slot->cache_tokens.clear();
+ slot->prompt.tokens.clear();
auto res = std::make_unique<server_task_result_slot_erase>();
res->id = task.id;
if (!params_base.ctx_shift) {
// this check is redundant (for good)
// we should never get here, because generation should have already stopped in process_token()
- slot.release();
send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
+ slot.release();
continue;
}
}
// Shift context
- const int n_keep = slot.params.n_keep + add_bos_token;
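+ // number of tokens to keep at the start of the context (n_keep < 0 keeps the whole prompt), capped at n_ctx - 4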
+ int n_keep = slot.task->params.n_keep < 0 ? slot.n_prompt_tokens() : slot.task->params.n_keep;
+
+ if (add_bos_token) {
+ n_keep += 1;
+ }
+
+ n_keep = std::min(slot.n_ctx - 4, n_keep);
+
const int n_left = slot.n_past - n_keep;
- const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
+ const int n_discard = slot.task->params.n_discard ? slot.task->params.n_discard : (n_left / 2);
SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
// add generated tokens to cache
{
- llama_tokens new_tokens = slot.cache_tokens.get_text_tokens(); // copy
+ llama_tokens new_tokens = slot.prompt.tokens.get_text_tokens(); // copy
for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
new_tokens[i - n_discard] = new_tokens[i];
}
- new_tokens.resize(slot.cache_tokens.size() - n_discard);
- slot.cache_tokens.clear();
- slot.cache_tokens.insert(new_tokens);
+ new_tokens.resize(slot.prompt.tokens.size() - n_discard);
+ slot.prompt.tokens.clear();
+ slot.prompt.tokens.insert(new_tokens);
}
slot.n_past -= n_discard;
server_slot * slot_batched = nullptr;
auto accept_special_token = [&](server_slot & slot, llama_token token) {
- return params_base.special || slot.params.sampling.preserved_tokens.find(token) != slot.params.sampling.preserved_tokens.end();
+ return params_base.special ||
+ slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end();
};
// first, add sampled tokens from any ongoing sequences
common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
slot.n_past += 1;
- slot.cache_tokens.push_back(slot.sampled);
+ slot.prompt.tokens.push_back(slot.sampled);
SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n",
- slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated);
+ slot.n_ctx, slot.n_past, (int) slot.prompt.tokens.size(), slot.truncated);
}
// process in chunks of params.n_batch
// this slot still has a prompt to be processed
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
- auto & prompt_tokens = slot.prompt_tokens;
+ const auto & input_tokens = slot.task->tokens;
// TODO: maybe move branch to outside of this loop in the future
if (slot.state == SLOT_STATE_STARTED) {
slot.t_start_generation = 0;
slot.n_past = 0;
- slot.n_prompt_tokens = prompt_tokens.size();
slot.state = SLOT_STATE_PROCESSING_PROMPT;
- SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
+ SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n",
+ slot.n_ctx, slot.task->params.n_keep, slot.n_prompt_tokens());
// print prompt tokens (for debugging)
/*if (1) {
// first 16 tokens (avoid flooding logs)
- for (int i = 0; i < std::min<int>(16, prompt_tokens.size()); i++) {
- SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+ for (int i = 0; i < std::min<int>(16, input_tokens.size()); i++) {
+ SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str());
}
} else {
// all
- for (int i = 0; i < (int) prompt_tokens.size(); i++) {
- SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+ for (int i = 0; i < (int) input_tokens.size(); i++) {
+ SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str());
}
}*/
// empty prompt passed -> release the slot and send empty response
- if (prompt_tokens.empty()) {
+ if (input_tokens.empty()) {
SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");
- slot.release();
slot.print_timings();
send_final_response(slot);
+ slot.release();
+
continue;
}
// TODO: support memory-less logits computation
if (slot.need_logits() && !llama_get_memory(ctx)) {
- slot.release();
send_error(slot, "the current context does not support logits computation. skipping", ERROR_TYPE_SERVER);
+ slot.release();
continue;
}
if (!slot.can_split()) {
- if (slot.n_prompt_tokens > n_ubatch) {
- slot.release();
+ if (slot.n_prompt_tokens() > n_ubatch) {
send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
+ slot.release();
continue;
}
- if (slot.n_prompt_tokens > slot.n_ctx) {
- slot.release();
+ if (slot.n_prompt_tokens() > slot.n_ctx) {
send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
+ slot.release();
continue;
}
} else {
- if (!params_base.ctx_shift) {
- // if context shift is disabled, we make sure prompt size is smaller than KV size
- // TODO: there should be a separate parameter that control prompt truncation
- // context shift should be applied only during the generation phase
- if (slot.n_prompt_tokens >= slot.n_ctx) {
- slot.release();
- send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
- continue;
- }
- }
- if (slot.params.n_keep < 0) {
- slot.params.n_keep = slot.n_prompt_tokens;
- }
- slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
-
- // if input prompt is too big, truncate it
- if (slot.n_prompt_tokens >= slot.n_ctx) {
- if (mctx) {
- // we should never reach this
- GGML_ABORT("not supported by multimodal");
- }
- const int n_left = slot.n_ctx - slot.params.n_keep;
-
- const int n_block_size = n_left / 2;
- const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
-
- const llama_tokens & curr_tokens = slot.prompt_tokens.get_text_tokens();
- llama_tokens new_tokens(
- curr_tokens.begin(),
- curr_tokens.begin() + slot.params.n_keep);
-
- new_tokens.insert(
- new_tokens.end(),
- curr_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
- curr_tokens.end());
-
- prompt_tokens.clear();
- prompt_tokens.insert(new_tokens);
-
- slot.truncated = true;
- slot.n_prompt_tokens = prompt_tokens.size();
-
- SLT_WRN(slot, "input truncated, n_ctx = %d, n_keep = %d, n_left = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, n_left, slot.n_prompt_tokens);
-
- GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
+ if (slot.n_prompt_tokens() >= slot.n_ctx) {
+ send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
+ slot.release();
+ continue;
}
- if (slot.params.cache_prompt) {
+ if (slot.task->params.cache_prompt) {
// reuse any previously computed tokens that are common with the new prompt
- slot.n_past = slot.cache_tokens.get_common_prefix(prompt_tokens);
+ slot.n_past = slot.prompt.tokens.get_common_prefix(input_tokens);
// if an alora is invoked, don't cache beyond the invocation start
if (slot.alora_invocation_start >= 0) {
SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past);
- while (head_c < slot.cache_tokens.size() &&
- head_p < prompt_tokens.size()) {
+ while (head_c < slot.prompt.tokens.size() &&
+ head_p < input_tokens.size()) {
size_t n_match = 0;
- while (head_c + n_match < slot.cache_tokens.size() &&
- head_p + n_match < prompt_tokens.size() &&
- slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
+ while (head_c + n_match < slot.prompt.tokens.size() &&
+ head_p + n_match < input_tokens.size() &&
+ slot.prompt.tokens[head_c + n_match] == input_tokens[head_p + n_match]) {
n_match++;
}
llama_memory_seq_add(llama_get_memory(ctx), slot.id, head_c, head_c + n_match, kv_shift);
for (size_t i = 0; i < n_match; i++) {
- slot.cache_tokens.set_token(head_p + i, slot.cache_tokens[head_c + i]);
+ slot.prompt.tokens.set_token(head_p + i, slot.prompt.tokens[head_c + i]);
slot.n_past++;
}
// the largest pos_min required for a checkpoint to be useful
const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
- if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) {
+ if (slot.n_past > 0 && slot.n_past < (int) slot.prompt.tokens.size()) {
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
if (pos_min == -1) {
- SLT_ERR(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min);
+ SLT_ERR(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);
GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
}
+ // when the prompt prefix does not match, print the tokens around the mismatch
+ // this is useful for debugging prompt caching
+ {
+ const int np0 = std::max<int>(slot.n_past - 4, 0);
+ const int np1 = std::min<int>(slot.n_past + 6, std::min(slot.prompt.tokens.size(), slot.task->tokens.size()));
+
+ std::stringstream ss0;
+ std::stringstream ss1;
+
+ std::stringstream st0;
+ std::stringstream st1;
+
+ ss0 << "old: ... ";
+ ss1 << "new: ... ";
+
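+ // print the tokens around the divergence point; the '|' marker shows where the cached and the new prompt differ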
+ for (int i = np0; i < np1; i++) {
+ if (i == slot.n_past) {
+ ss0 << " | ";
+ ss1 << " | ";
+ }
+
+ {
+ const auto token = slot.prompt.tokens[i];
+ const auto piece = common_token_to_piece(ctx, token);
+ ss0 << piece;
+ st0 << std::setw(8) << token;
+ }
+
+ {
+ const auto token = slot.task->tokens[i];
+ const auto piece = common_token_to_piece(ctx, token);
+ ss1 << piece;
+ st1 << std::setw(8) << token;
+ }
+ }
+
+ SLT_WRN(slot, "%s\n", ss0.str().c_str());
+ SLT_WRN(slot, "%s\n", ss1.str().c_str());
+
+ SLT_WRN(slot, "%s\n", st0.str().c_str());
+ SLT_WRN(slot, "%s\n", st1.str().c_str());
+ }
+
if (pos_min > pos_min_thold) {
- SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);
+ SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);
// search for a context checkpoint
const auto it = std::find_if(
- slot.ctx_checkpoints.rbegin(),
- slot.ctx_checkpoints.rend(),
+ slot.prompt.checkpoints.rbegin(),
+ slot.prompt.checkpoints.rend(),
[&](const auto & cur) {
// guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
return cur.pos_min < pos_min_thold;
}
);
- bool do_reset = it == slot.ctx_checkpoints.rend();
- //printf("[DEBUG] `do_reset` was set to `%s`\n", do_reset ? "true" : "false");
+ bool do_reset = it == slot.prompt.checkpoints.rend();
if (!do_reset) {
// restore the context checkpoint
- const size_t ctx_checkpoint_size = it->data.size();
- const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), ctx_checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+ const size_t checkpoint_size = it->data.size();
+ const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
- if (n != ctx_checkpoint_size) {
- SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024);
+ if (n != checkpoint_size) {
+ SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
do_reset = true;
} else {
slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max));
- SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024);
+ SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
}
}
{
// erase any checkpoints with pos_min > pos_min_thold
- for (int i = (int) slot.ctx_checkpoints.size() - 1; i >= 0; i--) {
- const auto & cur = slot.ctx_checkpoints[i];
+ for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) {
+ const auto & cur = *it;
if (cur.pos_min > pos_min_thold) {
SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
- slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin() + i);
+ it = slot.prompt.checkpoints.erase(it);
+ } else {
+ ++it;
}
}
}
}
// [TAG_PROMPT_LOGITS]
- if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
- SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, n_prompt_tokens = %d)\n", slot.n_past, slot.n_prompt_tokens);
+ if (slot.n_past == slot.n_prompt_tokens() && slot.n_past > 0) {
+ SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, n_prompt_tokens = %d)\n", slot.n_past, slot.n_prompt_tokens());
slot.n_past--;
SLT_WRN(slot, "n_past was set to %d\n", slot.n_past);
}
if (!slot.can_split()) {
// cannot fit the prompt in the current batch - will try next iter
- if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
+ if (batch.n_tokens + slot.n_prompt_tokens() > n_batch) {
continue;
}
}
SLT_INF(slot, "n_past = %d, memory_seq_rm [%d, end)\n", slot.n_past, slot.n_past);
// remove the non-common part from the cache
- slot.cache_tokens.keep_first(slot.n_past);
+ slot.prompt.tokens.keep_first(slot.n_past);
// check if we should process the image
- if (slot.n_past < slot.n_prompt_tokens && slot.prompt_tokens[slot.n_past] == LLAMA_TOKEN_NULL) {
+ if (slot.n_past < slot.n_prompt_tokens() && input_tokens[slot.n_past] == LLAMA_TOKEN_NULL) {
// process the image
int32_t new_n_past;
- int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past);
- int32_t n_pos = new_n_past - slot.n_past;
-
+ int32_t res = input_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past);
if (res != 0) {
SLT_ERR(slot, "failed to process image, res = %d\n", res);
- slot.release();
send_error(slot, "failed to process image", ERROR_TYPE_SERVER);
+ slot.release();
continue;
}
// add the image chunk to cache
{
- const auto & chunk = slot.prompt_tokens.find_chunk(slot.n_past);
- slot.cache_tokens.push_back(chunk.get()); // copy
+ const auto & chunk = input_tokens.find_chunk(slot.n_past);
+ slot.prompt.tokens.push_back(chunk.get()); // copy
}
+ const int32_t n_pos = new_n_past - slot.n_past;
+
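+ // an image chunk may span multiple positions, so advance by however many it consumed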
slot.n_past += n_pos;
slot.n_prompt_tokens_processed += n_pos;
}
bool do_checkpoint = params_base.n_ctx_checkpoints > 0;
+ // make checkpoints only for completion tasks
+ do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION;
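+ // (embedding and rerank tasks evaluate the prompt only once, so checkpoints would only add overhead)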
+
// make a checkpoint of the parts of the memory that cannot be rolled back.
// checkpoints are created only if:
// - the model uses SWA and we are not using `swa_full`
);
// add prompt tokens for processing in the current batch
- while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
+ while (slot.n_past < slot.n_prompt_tokens() && batch.n_tokens < n_batch) {
// get next token to process
- llama_token cur_tok = slot.prompt_tokens[slot.n_past];
+ llama_token cur_tok = input_tokens[slot.n_past];
if (cur_tok == LLAMA_TOKEN_NULL) {
break; // end of text chunk
}
}
// embedding requires all tokens in the batch to be output
- const bool need_embd = server_task_type_need_embd(slot.task_type);
-
- common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd);
- slot.cache_tokens.push_back(cur_tok);
+ common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, slot.need_embd());
+ slot.prompt.tokens.push_back(cur_tok);
slot.n_prompt_tokens_processed++;
slot.n_past++;
// process the last few tokens of the prompt separately in order to allow for a checkpoint to be created.
- if (do_checkpoint && slot.n_prompt_tokens - slot.n_past == 64) {
+ if (do_checkpoint && slot.n_prompt_tokens() - slot.n_past == 64) {
break;
}
}
// SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str());
- SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
+ SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_past / slot.n_prompt_tokens());
// entire prompt has been processed
- if (slot.n_past == slot.n_prompt_tokens) {
+ if (slot.n_past == slot.n_prompt_tokens()) {
slot.state = SLOT_STATE_DONE_PROMPT;
GGML_ASSERT(batch.n_tokens > 0);
- GGML_ASSERT((size_t) slot.n_prompt_tokens == slot.prompt_tokens.size());
common_sampler_reset(slot.smpl);
// Process all prompt tokens through sampler system
- for (int i = 0; i < slot.n_prompt_tokens; ++i) {
- llama_token id = slot.prompt_tokens[i];
+ for (int i = 0; i < slot.n_prompt_tokens(); ++i) {
+ llama_token id = input_tokens[i];
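+ // multimodal placeholder positions hold LLAMA_TOKEN_NULL and have no text token to accept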
if (id != LLAMA_TOKEN_NULL) {
common_sampler_accept(slot.smpl, id, false);
}
do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);
// no need to create checkpoints that are too close together
- do_checkpoint = do_checkpoint && (slot.ctx_checkpoints.empty() || pos_max > slot.ctx_checkpoints.back().pos_max + 64);
+ do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64);
if (do_checkpoint) {
- while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
+ while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
// make room for the new checkpoint, if needed
- const auto & cur = slot.ctx_checkpoints.front();
+ const auto & cur = slot.prompt.checkpoints.front();
+
SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
- slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
+ slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
}
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
- auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
+ auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{
/*.pos_min = */ pos_min,
/*.pos_max = */ pos_max,
/*.data = */ std::vector<uint8_t>(checkpoint_size),
llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
- SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
- (int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+ SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+ (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
}
}
}
if (!err.empty()) {
SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
for (auto & slot : slots) {
- slot.release();
send_error(slot, err);
+ slot.release();
}
break;
}
for (auto & slot : slots) {
// optionally send prompt processing progress
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) {
- if (slot.params.stream && slot.params.return_progress) {
+ if (slot.task->params.stream && slot.task->params.return_progress) {
send_partial_response(slot, {}, true);
}
}
}
if (slot.state == SLOT_STATE_DONE_PROMPT) {
- if (slot.task_type == SERVER_TASK_TYPE_EMBEDDING) {
+ if (slot.task->type == SERVER_TASK_TYPE_EMBEDDING) {
// prompt evaluated for embedding
send_embedding(slot, batch_view);
slot.release();
continue; // continue loop of slots
}
- if (slot.task_type == SERVER_TASK_TYPE_RERANK) {
+ if (slot.task->type == SERVER_TASK_TYPE_RERANK) {
send_rerank(slot, batch_view);
slot.release();
slot.i_batch = -1;
result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
- if (slot.params.sampling.n_probs > 0) {
- populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx);
+ if (slot.task->params.sampling.n_probs > 0) {
+ populate_token_probs(slot, result, slot.task->params.post_sampling_probs, params_base.special, tok_idx);
}
if (!process_token(result, slot)) {
// release slot because of stop condition
- slot.release();
slot.print_timings();
send_final_response(slot);
metrics.on_prediction(slot);
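+ // release only after the final response has been sent - building it still uses the slot's task state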
+ slot.release();
+
continue;
}
}
}
// determine the max draft that fits the current slot state
- int n_draft_max = slot.params.speculative.n_max;
+ int n_draft_max = slot.task->params.speculative.n_max;
// note: n_past is not yet increased for the `id` token sampled above
// also, need to leave space for 1 extra token to allow context shifts
SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
- if (n_draft_max < slot.params.speculative.n_min) {
- SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min);
+ if (n_draft_max < slot.task->params.speculative.n_min) {
+ SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.task->params.speculative.n_min);
continue;
}
llama_token id = slot.sampled;
struct common_speculative_params params_spec;
- params_spec.n_draft = n_draft_max;
- params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
- params_spec.p_min = slot.params.speculative.p_min;
+ params_spec.n_draft = n_draft_max;
+ params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max;
+ params_spec.p_min = slot.task->params.speculative.p_min;
- const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens();
+ const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
// ignore small drafts
- if (slot.params.speculative.n_min > (int) draft.size()) {
- SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
+ if (slot.task->params.speculative.n_min > (int) draft.size()) {
+ SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min);
continue;
}
// update how many tokens out of those tested were accepted
slot.n_draft_accepted += ids.size() - 1;
- slot.cache_tokens.push_back(id);
- slot.cache_tokens.insert({ids.begin(), ids.end() - 1});
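+ // append the sampled token and the accepted draft tokens to the cached prompt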
+ slot.prompt.tokens.push_back(id);
+ slot.prompt.tokens.insert({ids.begin(), ids.end() - 1});
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1);
// TODO: set result.probs
if (!process_token(result, slot)) {
- // release slot because of stop condition
- slot.release();
slot.print_timings();
send_final_response(slot);
metrics.on_prediction(slot);
+ slot.release();
+
break;
}
}
}
// TODO: get rid of this dynamic_cast
- auto res_metrics = dynamic_cast<server_task_result_metrics*>(result.get());
- GGML_ASSERT(res_metrics != nullptr);
+ auto res_task = dynamic_cast<server_task_result_metrics*>(result.get());
+ GGML_ASSERT(res_task != nullptr);
// optionally return "fail_on_no_slot" error
if (req.has_param("fail_on_no_slot")) {
- if (res_metrics->n_idle_slots == 0) {
+ if (res_task->n_idle_slots == 0) {
res_error(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE));
return;
}
}
- res_ok(res, res_metrics->slots_data);
+ res_ok(res, res_task->slots_data);
};
const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
}
// TODO: get rid of this dynamic_cast
- auto res_metrics = dynamic_cast<server_task_result_metrics*>(result.get());
- GGML_ASSERT(res_metrics != nullptr);
+ auto res_task = dynamic_cast<server_task_result_metrics*>(result.get());
+ GGML_ASSERT(res_task != nullptr);
// metrics definition: https://prometheus.io/docs/practices/naming/#metric-names
json all_metrics_def = json {
{"counter", {{
{"name", "prompt_tokens_total"},
{"help", "Number of prompt tokens processed."},
- {"value", (uint64_t) res_metrics->n_prompt_tokens_processed_total}
+ {"value", (uint64_t) res_task->n_prompt_tokens_processed_total}
}, {
{"name", "prompt_seconds_total"},
{"help", "Prompt process time"},
- {"value", (uint64_t) res_metrics->t_prompt_processing_total / 1.e3}
+ {"value", (uint64_t) res_task->t_prompt_processing_total / 1.e3}
}, {
{"name", "tokens_predicted_total"},
{"help", "Number of generation tokens processed."},
- {"value", (uint64_t) res_metrics->n_tokens_predicted_total}
+ {"value", (uint64_t) res_task->n_tokens_predicted_total}
}, {
{"name", "tokens_predicted_seconds_total"},
{"help", "Predict process time"},
- {"value", (uint64_t) res_metrics->t_tokens_generation_total / 1.e3}
+ {"value", (uint64_t) res_task->t_tokens_generation_total / 1.e3}
}, {
{"name", "n_decode_total"},
{"help", "Total number of llama_decode() calls"},
- {"value", res_metrics->n_decode_total}
+ {"value", res_task->n_decode_total}
}, {
{"name", "n_past_max"},
{"help", "Largest observed n_past."},
- {"value", res_metrics->n_past_max}
+ {"value", res_task->n_past_max}
}, {
{"name", "n_busy_slots_per_decode"},
{"help", "Average number of busy slots per llama_decode() call"},
- {"value", (float) res_metrics->n_busy_slots_total / std::max((float) res_metrics->n_decode_total, 1.f)}
+ {"value", (float) res_task->n_busy_slots_total / std::max((float) res_task->n_decode_total, 1.f)}
}}},
{"gauge", {{
{"name", "prompt_tokens_seconds"},
{"help", "Average prompt throughput in tokens/s."},
- {"value", res_metrics->n_prompt_tokens_processed ? 1.e3 / res_metrics->t_prompt_processing * res_metrics->n_prompt_tokens_processed : 0.}
+ {"value", res_task->n_prompt_tokens_processed ? 1.e3 / res_task->t_prompt_processing * res_task->n_prompt_tokens_processed : 0.}
},{
{"name", "predicted_tokens_seconds"},
{"help", "Average generation throughput in tokens/s."},
- {"value", res_metrics->n_tokens_predicted ? 1.e3 / res_metrics->t_tokens_generation * res_metrics->n_tokens_predicted : 0.}
+ {"value", res_task->n_tokens_predicted ? 1.e3 / res_task->t_tokens_generation * res_task->n_tokens_predicted : 0.}
},{
{"name", "requests_processing"},
{"help", "Number of requests processing."},
- {"value", (uint64_t) res_metrics->n_processing_slots}
+ {"value", (uint64_t) res_task->n_processing_slots}
},{
{"name", "requests_deferred"},
{"help", "Number of requests deferred."},
- {"value", (uint64_t) res_metrics->n_tasks_deferred}
+ {"value", (uint64_t) res_task->n_tasks_deferred}
}}}
};
}
}
- res.set_header("Process-Start-Time-Unix", std::to_string(res_metrics->t_start));
+ res.set_header("Process-Start-Time-Unix", std::to_string(res_task->t_start));
res.set_content(prometheus.str(), "text/plain; version=0.0.4");
res.status = 200; // HTTP OK
};
const auto handle_props = [&params, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
+ json default_generation_settings_for_props;
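+ // build the default generation settings on demand from the base sampling params (previously cached in the server context)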
+
+ {
+ slot_params params;
+
+ params.sampling = ctx_server.params_base.sampling;
+
+ default_generation_settings_for_props = json {
+ {"params", params.to_json(true)},
+ {"n_ctx", ctx_server.slots[0].n_ctx},
+ };
+ }
+
// this endpoint is publicly available, please only return what is safe to be exposed
json data = {
- { "default_generation_settings", ctx_server.default_generation_settings_for_props },
+ { "default_generation_settings", default_generation_settings_for_props },
{ "total_slots", ctx_server.params_base.n_parallel },
{ "model_path", ctx_server.params_base.model.path },
{ "modalities", json {
task.id = ctx_server.queue_tasks.get_new_id();
task.index = i;
- task.prompt_tokens = std::move(inputs[i]);
- task.params = server_task::params_from_json_cmpl(
+ task.tokens = std::move(inputs[i]);
+ task.params = server_task::params_from_json_cmpl(
ctx_server.ctx,
ctx_server.params_base,
data);
- task.id_selected_slot = json_value(data, "id_slot", -1);
+ task.id_slot = json_value(data, "id_slot", -1);
// OAI-compat
task.params.oaicompat = oaicompat;
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
- task.id = ctx_server.queue_tasks.get_new_id();
- task.index = i;
- task.prompt_tokens = std::move(tokenized_prompts[i]);
+ task.id = ctx_server.queue_tasks.get_new_id();
+ task.index = i;
+ task.tokens = std::move(tokenized_prompts[i]);
// OAI-compat
task.params.oaicompat = oaicompat;
tasks.reserve(documents.size());
for (size_t i = 0; i < documents.size(); i++) {
auto tmp = format_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]);
- server_task task = server_task(SERVER_TASK_TYPE_RERANK);
- task.id = ctx_server.queue_tasks.get_new_id();
- task.index = i;
- task.prompt_tokens = std::move(tmp);
+ server_task task = server_task(SERVER_TASK_TYPE_RERANK);
+ task.id = ctx_server.queue_tasks.get_new_id();
+ task.index = i;
+ task.tokens = std::move(tmp);
tasks.push_back(std::move(task));
}
#endif
LOG_INF("%s: server is listening on %s - starting the main loop\n", __func__,
- is_sock ? string_format("unix://%s", params.hostname.c_str()).c_str() :
+ is_sock ? string_format("unix://%s", params.hostname.c_str()).c_str() :
string_format("http://%s:%d", params.hostname.c_str(), params.port).c_str());
// this call blocks the main thread until queue_tasks.terminate() is called