};
enum server_task_type {
- SERVER_TASK_TYPE_INFERENCE,
+ SERVER_TASK_TYPE_COMPLETION,
+ SERVER_TASK_TYPE_EMBEDDING,
+ SERVER_TASK_TYPE_RERANK,
+ SERVER_TASK_TYPE_INFILL,
SERVER_TASK_TYPE_CANCEL,
SERVER_TASK_TYPE_NEXT_RESPONSE,
SERVER_TASK_TYPE_METRICS,
SERVER_TASK_TYPE_SET_LORA,
};
-enum server_task_inf_type {
- SERVER_TASK_INF_TYPE_COMPLETION,
- SERVER_TASK_INF_TYPE_EMBEDDING,
- SERVER_TASK_INF_TYPE_RERANK,
- SERVER_TASK_INF_TYPE_INFILL,
-};
-
// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
enum error_type {
ERROR_TYPE_INVALID_REQUEST,
ERROR_TYPE_NOT_SUPPORTED, // custom error
};
-struct server_task {
- int id = -1; // to be filled by server_queue
- int id_target = -1; // used by SERVER_TASK_TYPE_CANCEL
-
- llama_tokens prompt_tokens;
- server_task_type type;
-
- // TODO @ngxson : we should get rid of json type here
- json data;
-
- server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
-
- // utility function
- static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
- std::unordered_set<int> ids(tasks.size());
- for (size_t i = 0; i < tasks.size(); i++) {
- ids.insert(tasks[i].id);
- }
- return ids;
- }
-};
-
struct slot_params {
bool stream = true;
bool cache_prompt = true; // remember the prompt to avoid reprocessing the entire prompt
std::vector<std::string> antiprompt;
bool timings_per_token = false;
+ bool ignore_eos = false;
struct common_params_sampling sampling;
struct common_params_speculative speculative;
{"n_discard", n_discard},
{"ignore_eos", sampling.ignore_eos},
{"stream", stream},
- //{"logit_bias", sampling.logit_bias},
+ {"logit_bias", format_logit_bias(sampling.logit_bias)},
{"n_probs", sampling.n_probs},
{"min_keep", sampling.min_keep},
{"grammar", sampling.grammar},
}
};
+struct server_task {
+ int id = -1; // to be filled by server_queue
+ int index = -1; // used when there are multiple prompts (batch request)
+
+ server_task_type type;
+
+ // used by SERVER_TASK_TYPE_CANCEL
+ int id_target = -1;
+
+ // used by SERVER_TASK_TYPE_INFERENCE
+ slot_params params;
+ llama_tokens prompt_tokens;
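+ // if >= 0, the task is pinned to this specific slot (taken from the "id_slot" field of the request)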
+ int id_selected_slot = -1;
+
+ // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE
+ struct slot_action {
+ int slot_id;
+ std::string filename;
+ std::string filepath;
+ };
+ slot_action slot_action;
+
+ // used by SERVER_TASK_TYPE_METRICS
+ bool metrics_reset_bucket = false;
+
+ server_task(server_task_type type) : type(type) {}
+
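+ // build the per-request slot_params from the JSON body of a completion-style request (/completion, /infill, OAI-compatible chat)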
+ static slot_params params_from_json_cmpl(
+ const llama_model * model,
+ const common_params & params_base,
+ const json & data) {
+ slot_params params;
+
+ // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
+ slot_params defaults;
+ defaults.sampling = params_base.sampling;
+ defaults.speculative = params_base.speculative;
+
+ // enabling this will output extra debug information in the HTTP responses from the server
+ params.verbose = params_base.verbosity > 9;
+ params.timings_per_token = json_value(data, "timings_per_token", false);
+
+ params.stream = json_value(data, "stream", false);
+ params.cache_prompt = json_value(data, "cache_prompt", true);
+ params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
+ params.n_indent = json_value(data, "n_indent", defaults.n_indent);
+ params.n_keep = json_value(data, "n_keep", defaults.n_keep);
+ params.n_discard = json_value(data, "n_discard", defaults.n_discard);
+ //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
+ params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
+
+ params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
+ params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
+ params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p);
+ params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability);
+ params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold);
+ params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p);
+ params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp);
+ params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range);
+ params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent);
+ params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n);
+ params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat);
+ params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq);
+ params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present);
+ params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier);
+ params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base);
+ params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length);
+ params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n);
+ params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat);
+ params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau);
+ params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
+ params.sampling.penalize_nl = json_value(data, "penalize_nl", defaults.sampling.penalize_nl);
+ params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
+ params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
+ params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
+
+ params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
+ params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
+ params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
+
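+ // clamp the speculative decoding window: n_min is capped at n_max, then raised to at least 2; n_max must be non-negative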
+ params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min);
+ params.speculative.n_min = std::max(params.speculative.n_min, 2);
+ params.speculative.n_max = std::max(params.speculative.n_max, 0);
+
+ if (params.sampling.dry_base < 1.0f) {
+ params.sampling.dry_base = defaults.sampling.dry_base;
+ }
+
+ // sequence breakers for DRY
+ {
+ // Currently, this is not compatible with the TextGen WebUI, Koboldcpp, or SillyTavern formats
+ // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
+
+ if (data.contains("dry_sequence_breakers")) {
+ params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
+ if (params.sampling.dry_sequence_breakers.empty()) {
+ throw std::runtime_error("Error: dry_sequence_breakers must be a non-empty array of strings");
+ }
+ }
+ }
+
+ // process "json_schema" and "grammar"
+ if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
+ throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
+ }
+ if (data.contains("json_schema") && !data.contains("grammar")) {
+ try {
+ auto schema = json_value(data, "json_schema", json::object());
+ params.sampling.grammar = json_schema_to_grammar(schema);
+ } catch (const std::exception & e) {
+ throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
+ }
+ } else {
+ params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
+ }
+
+ {
+ params.sampling.logit_bias.clear();
+ params.ignore_eos = json_value(data, "ignore_eos", false);
+
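+ // "logit_bias" is an array of [token, bias] pairs; a token may be given as an id or as a string (which is tokenized), and a bias of false maps to -infinity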
+ const auto & logit_bias = data.find("logit_bias");
+ if (logit_bias != data.end() && logit_bias->is_array()) {
+ const int n_vocab = llama_n_vocab(model);
+ for (const auto & el : *logit_bias) {
+ // TODO: we may want to throw errors here, in case "el" is incorrect
+ if (el.is_array() && el.size() == 2) {
+ float bias;
+ if (el[1].is_number()) {
+ bias = el[1].get<float>();
+ } else if (el[1].is_boolean() && !el[1].get<bool>()) {
+ bias = -INFINITY;
+ } else {
+ continue;
+ }
+
+ if (el[0].is_number_integer()) {
+ llama_token tok = el[0].get<llama_token>();
+ if (tok >= 0 && tok < n_vocab) {
+ params.sampling.logit_bias.push_back({tok, bias});
+ }
+ } else if (el[0].is_string()) {
+ auto toks = common_tokenize(model, el[0].get<std::string>(), false);
+ for (auto tok : toks) {
+ params.sampling.logit_bias.push_back({tok, bias});
+ }
+ }
+ }
+ }
+ }
+ }
+
+ {
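+ // "stop" is an optional array of stop strings (antiprompts) that end generation when they appear in the output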
+ params.antiprompt.clear();
+
+ const auto & stop = data.find("stop");
+ if (stop != data.end() && stop->is_array()) {
+ for (const auto & word : *stop) {
+ if (!word.empty()) {
+ params.antiprompt.push_back(word);
+ }
+ }
+ }
+ }
+
+ {
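+ // "samplers" may be an array of sampler names or a single string of one-character sampler codes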
+ const auto & samplers = data.find("samplers");
+ if (samplers != data.end()) {
+ if (samplers->is_array()) {
+ std::vector<std::string> sampler_names;
+ for (const auto & name : *samplers) {
+ if (name.is_string()) {
+ sampler_names.emplace_back(name);
+ }
+ }
+ params.sampling.samplers = common_sampler_types_from_names(sampler_names, false);
+ } else if (samplers->is_string()) {
+ params.sampling.samplers = common_sampler_types_from_chars(samplers->get<std::string>());
+ }
+ } else {
+ params.sampling.samplers = defaults.sampling.samplers;
+ }
+ }
+
+ std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias;
+ params.oaicompat_model = json_value(data, "model", model_name);
+
+ return params;
+ }
+
+ // utility function
+ static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
+ std::unordered_set<int> ids(tasks.size());
+ for (size_t i = 0; i < tasks.size(); i++) {
+ ids.insert(tasks[i].id);
+ }
+ return ids;
+ }
+};
+
struct result_timings {
int32_t prompt_n = -1;
double prompt_ms;
double predicted_per_token_ms;
double predicted_per_second;
- json to_json() {
+ json to_json() const {
return {
{"prompt_n", prompt_n},
{"prompt_ms", prompt_ms},
uint64_t n_decode_total = 0;
uint64_t n_busy_slots_total = 0;
- // TODO: get rid of this json object and use to_json() instead
+ // while we could also use std::vector<server_slot>, that would require copying the slot objects, which can be quite messy
+ // therefore, we use json to temporarily store the slot.to_json() result
json slots_data = json::array();
virtual json to_json() override {
int id;
int id_task = -1;
+ // only used for completion/embedding/infill/rerank
+ server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;
+
llama_batch batch_spec = {};
llama_context * ctx = nullptr;
llama_tokens cache_tokens;
std::vector<completion_token_output> generated_token_probs;
- server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
-
bool has_next_token = true;
bool has_new_line = false;
bool truncated = false;
n_past = 0;
n_sent_text = 0;
n_sent_token_probs = 0;
- inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
+ task_type = SERVER_TASK_TYPE_COMPLETION;
generated_token_probs.clear();
}
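+ // embedding and rerank tasks only evaluate the prompt (no tokens are generated) and must fit it in a single physical batch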
+ bool is_non_causal() const {
+ return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK;
+ }
+
bool has_budget(const common_params & global_params) {
if (params.n_predict == -1 && global_params.n_predict == -1) {
return true; // limitless
{"n_ctx", n_ctx},
{"speculative", can_speculate()},
{"is_processing", is_processing()},
+ {"non_causal", is_non_causal()},
{"params", params.to_json()},
{"prompt", common_detokenize(ctx, prompt_tokens)},
{"next_token",
// Add a new task to the end of the queue
int post(server_task task, bool front = false) {
std::unique_lock<std::mutex> lock(mutex_tasks);
- if (task.id == -1) {
- task.id = id++;
- }
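+ // the caller must assign a valid id (via get_new_id()) before posting a single task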
+ GGML_ASSERT(task.id != -1);
QUE_DBG("new task, id = %d, front = %d\n", task.id, front);
if (front) {
queue_tasks.push_front(std::move(task));
}
bool launch_slot_with_task(server_slot & slot, const server_task & task) {
- // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
- slot_params defaults;
- defaults.sampling = params_base.sampling;
- defaults.speculative = params_base.speculative;
+ slot.reset();
+ slot.id_task = task.id;
+ slot.index = task.index;
+ slot.task_type = task.type;
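+ // note: task is a const reference, so the std::move calls below degrade to copies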
+ slot.params = std::move(task.params);
+ slot.prompt_tokens = std::move(task.prompt_tokens);
- const auto & data = task.data;
-
- if (data.count("__oaicompat") != 0) {
- std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias;
- slot.params.oaicompat = true;
- slot.params.oaicompat_chat = json_value(data, "__oaicompat_chat", false);
- slot.params.oaicompat_model = json_value(data, "model", model_name);
- slot.params.oaicompat_cmpl_id = json_value(data, "completion_id", std::string());
- } else {
- slot.params.oaicompat = false;
- }
-
-
- // enabling this will output extra debug information in the HTTP responses from the server
- slot.params.verbose = params_base.verbosity > 9;
- slot.params.timings_per_token = json_value(data, "timings_per_token", false);
-
- slot.params.stream = json_value(data, "stream", false);
- slot.params.cache_prompt = json_value(data, "cache_prompt", true);
- slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
- slot.params.n_indent = json_value(data, "n_indent", defaults.n_indent);
- slot.params.n_keep = json_value(data, "n_keep", defaults.n_keep);
- slot.params.n_discard = json_value(data, "n_discard", defaults.n_discard);
- //slot.params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
- slot.params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
-
- slot.params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
- slot.params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
- slot.params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p);
- slot.params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability);
- slot.params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold);
- slot.params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p);
- slot.params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp);
- slot.params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range);
- slot.params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent);
- slot.params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n);
- slot.params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat);
- slot.params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq);
- slot.params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present);
- slot.params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier);
- slot.params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base);
- slot.params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length);
- slot.params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n);
- slot.params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat);
- slot.params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau);
- slot.params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
- slot.params.sampling.penalize_nl = json_value(data, "penalize_nl", defaults.sampling.penalize_nl);
- slot.params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
- slot.params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
- slot.params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
-
- slot.params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
- slot.params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
- slot.params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
-
- slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min);
- slot.params.speculative.n_min = std::max(slot.params.speculative.n_min, 2);
- slot.params.speculative.n_max = std::max(slot.params.speculative.n_max, 0);
-
- if (slot.params.sampling.dry_base < 1.0f) {
- slot.params.sampling.dry_base = defaults.sampling.dry_base;
- }
-
- // sequence breakers for DRY
- {
- // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format
- // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
-
- if (data.contains("dry_sequence_breakers")) {
- slot.params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
- if (slot.params.sampling.dry_sequence_breakers.empty()) {
- send_error(task, "Error: dry_sequence_breakers must be a non-empty array of strings", ERROR_TYPE_INVALID_REQUEST);
- return false;
- }
- }
- }
-
- // process "json_schema" and "grammar"
- if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
- send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST);
- return false;
- }
- if (data.contains("json_schema") && !data.contains("grammar")) {
- try {
- auto schema = json_value(data, "json_schema", json::object());
- slot.params.sampling.grammar = json_schema_to_grammar(schema);
- } catch (const std::exception & e) {
- send_error(task, std::string("\"json_schema\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
- return false;
- }
- } else {
- slot.params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
- }
+ SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
// Might be better to reject the request with a 400 ?
SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d\n", slot.params.n_predict, slot.n_predict);
}
- {
- slot.params.sampling.logit_bias.clear();
-
- if (json_value(data, "ignore_eos", false) && has_eos_token) {
- slot.params.sampling.logit_bias.push_back({llama_token_eos(model), -INFINITY});
- }
-
- const auto & logit_bias = data.find("logit_bias");
- if (logit_bias != data.end() && logit_bias->is_array()) {
- const int n_vocab = llama_n_vocab(model);
- for (const auto & el : *logit_bias) {
- // TODO: we may want to throw errors here, in case "el" is incorrect
- if (el.is_array() && el.size() == 2) {
- float bias;
- if (el[1].is_number()) {
- bias = el[1].get<float>();
- } else if (el[1].is_boolean() && !el[1].get<bool>()) {
- bias = -INFINITY;
- } else {
- continue;
- }
-
- if (el[0].is_number_integer()) {
- llama_token tok = el[0].get<llama_token>();
- if (tok >= 0 && tok < n_vocab) {
- slot.params.sampling.logit_bias.push_back({tok, bias});
- }
- } else if (el[0].is_string()) {
- auto toks = common_tokenize(model, el[0].get<std::string>(), false);
- for (auto tok : toks) {
- slot.params.sampling.logit_bias.push_back({tok, bias});
- }
- }
- }
- }
- }
- }
-
- {
- slot.params.antiprompt.clear();
-
- const auto & stop = data.find("stop");
- if (stop != data.end() && stop->is_array()) {
- for (const auto & word : *stop) {
- if (!word.empty()) {
- slot.params.antiprompt.push_back(word);
- }
- }
- }
- }
-
- {
- const auto & samplers = data.find("samplers");
- if (samplers != data.end()) {
- if (samplers->is_array()) {
- std::vector<std::string> sampler_names;
- for (const auto & name : *samplers) {
- if (name.is_string()) {
- sampler_names.emplace_back(name);
- }
- }
- slot.params.sampling.samplers = common_sampler_types_from_names(sampler_names, false);
- } else if (samplers->is_string()){
- std::string sampler_string;
- for (const auto & name : *samplers) {
- sampler_string += name;
- }
- slot.params.sampling.samplers = common_sampler_types_from_chars(sampler_string);
- }
- } else {
- slot.params.sampling.samplers = defaults.sampling.samplers;
- }
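+ // "ignore_eos" is implemented by biasing the EOS token to -infinity (only when the model actually has an EOS token)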
+ if (slot.params.ignore_eos && has_eos_token) {
+ slot.params.sampling.logit_bias.push_back({llama_token_eos(model), -INFINITY});
}
{
// Functions to create new task(s) and receive result(s)
//
- // break the input "prompt" into multiple tasks if needed, then format and tokenize the input prompt(s)
- std::vector<server_task> create_tasks_inference(json data, server_task_inf_type inf_type) {
- std::vector<server_task> tasks;
- auto create_task = [&](json & task_data, llama_tokens & prompt_tokens) {
- SRV_DBG("create task, n_tokens = %d\n", (int) prompt_tokens.size());
-
- server_task task;
- task.id = queue_tasks.get_new_id();
- task.inf_type = inf_type;
- task.type = SERVER_TASK_TYPE_INFERENCE;
- task.data = task_data;
- task.prompt_tokens = std::move(prompt_tokens);
- tasks.push_back(std::move(task));
- };
-
- static constexpr const char * error_msg = "\"prompt\" must be a string, an array of token ids or an array of prompts";
- if (!data.contains("prompt")) {
- throw std::runtime_error(error_msg);
- }
-
- // because llama_tokenize api is thread-safe, we can tokenize the prompt from HTTP thread
- bool add_special = inf_type != SERVER_TASK_INF_TYPE_RERANK && inf_type != SERVER_TASK_INF_TYPE_INFILL;
- std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx, data.at("prompt"), add_special, true);
- switch (inf_type) {
- case SERVER_TASK_INF_TYPE_RERANK:
- {
- // prompts[0] is the question
- // the rest are the answers/documents
- GGML_ASSERT(tokenized_prompts.size() > 1);
- SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) tokenized_prompts.size() - 1);
- for (size_t i = 1; i < tokenized_prompts.size(); i++) {
- data["index"] = i - 1;
- auto tokens = format_rerank(model, tokenized_prompts[0], tokenized_prompts[i]);
- create_task(data, tokens);
- }
- } break;
- case SERVER_TASK_INF_TYPE_INFILL:
- {
- SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
- for (size_t i = 0; i < tokenized_prompts.size(); i++) {
- data["index"] = i;
- auto tokens = format_infill(
- ctx,
- data.at("input_prefix"),
- data.at("input_suffix"),
- data.at("input_extra"),
- params_base.n_batch,
- params_base.n_predict,
- slots[0].n_ctx, // TODO: there should be a better way
- params_base.spm_infill,
- tokenized_prompts[i]
- );
- create_task(data, tokens);
- }
- } break;
- default:
- {
- SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
- for (size_t i = 0; i < tokenized_prompts.size(); i++) {
- data["index"] = i;
- create_task(data, tokenized_prompts[i]);
- }
- }
- }
-
- return tasks;
- }
-
void cancel_tasks(const std::unordered_set<int> & id_tasks) {
std::vector<server_task> cancel_tasks;
cancel_tasks.reserve(id_tasks.size());
for (const auto & id_task : id_tasks) {
SRV_WRN("cancel task, id_task = %d\n", id_task);
- server_task task;
- task.type = SERVER_TASK_TYPE_CANCEL;
+ server_task task(SERVER_TASK_TYPE_CANCEL);
task.id_target = id_task;
cancel_tasks.push_back(task);
queue_results.remove_waiting_task_id(id_task);
queue_tasks.post(cancel_tasks, true);
}
- // receive the results from task(s) created by create_tasks_inference
+ // receive the results from task(s)
void receive_multi_results(
const std::unordered_set<int> & id_tasks,
const std::function<void(std::vector<server_task_result_ptr>&)> & result_handler,
result_handler(results);
}
- // receive the results from task(s) created by create_tasks_inference, in stream mode
+ // receive the results from task(s), in stream mode
void receive_cmpl_results_stream(
const std::unordered_set<int> & id_tasks,
const std::function<bool(server_task_result_ptr&)> & result_handler,
void process_single_task(server_task task) {
switch (task.type) {
- case SERVER_TASK_TYPE_INFERENCE:
+ case SERVER_TASK_TYPE_COMPLETION:
+ case SERVER_TASK_TYPE_INFILL:
+ case SERVER_TASK_TYPE_EMBEDDING:
+ case SERVER_TASK_TYPE_RERANK:
{
- const int id_slot = json_value(task.data, "id_slot", -1);
+ const int id_slot = task.id_selected_slot;
server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task);
break;
}
- slot->reset();
-
- slot->id_task = task.id;
- slot->inf_type = task.inf_type;
- slot->index = json_value(task.data, "index", 0);
- slot->prompt_tokens = std::move(task.prompt_tokens);
-
if (!launch_slot_with_task(*slot, task)) {
SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
break;
res->n_decode_total = metrics.n_decode_total;
res->n_busy_slots_total = metrics.n_busy_slots_total;
- if (json_value(task.data, "reset_bucket", false)) {
+ if (task.metrics_reset_bucket) {
metrics.reset_bucket();
}
queue_results.send(std::move(res));
} break;
case SERVER_TASK_TYPE_SLOT_SAVE:
{
- int id_slot = task.data.at("id_slot");
+ int id_slot = task.slot_action.slot_id;
server_slot * slot = get_slot_by_id(id_slot);
if (slot == nullptr) {
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
const size_t token_count = slot->cache_tokens.size();
const int64_t t_start = ggml_time_us();
- std::string filename = task.data.at("filename");
- std::string filepath = task.data.at("filepath");
+ std::string filename = task.slot_action.filename;
+ std::string filepath = task.slot_action.filepath;
const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count);
} break;
case SERVER_TASK_TYPE_SLOT_RESTORE:
{
- int id_slot = task.data.at("id_slot");
+ int id_slot = task.slot_action.slot_id;
server_slot * slot = get_slot_by_id(id_slot);
if (slot == nullptr) {
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
const int64_t t_start = ggml_time_us();
- std::string filename = task.data.at("filename");
- std::string filepath = task.data.at("filepath");
+ std::string filename = task.slot_action.filename;
+ std::string filepath = task.slot_action.filepath;
slot->cache_tokens.resize(slot->n_ctx);
size_t token_count = 0;
} break;
case SERVER_TASK_TYPE_SLOT_ERASE:
{
- int id_slot = task.data.at("id_slot");
+ int id_slot = task.slot_action.slot_id;
server_slot * slot = get_slot_by_id(id_slot);
if (slot == nullptr) {
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
{
SRV_DBG("%s", "posting NEXT_RESPONSE\n");
- server_task task;
- task.type = SERVER_TASK_TYPE_NEXT_RESPONSE;
- task.id_target = -1;
-
+ server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE);
+ task.id = queue_tasks.get_new_id();
queue_tasks.post(task);
}
continue;
}
- if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
+ if (slot.is_non_causal()) {
if (slot.n_prompt_tokens > n_ubatch) {
slot.release();
send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
}
// non-causal tasks require the entire prompt to fit in a single physical batch
- if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
+ if (slot.is_non_causal()) {
// cannot fit the prompt in the current batch - will try next iter
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
continue;
}
// check that we are in the right batch_type, if not defer the slot
- const bool slot_type =
- slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING ||
- slot.inf_type == SERVER_TASK_INF_TYPE_RERANK ? 1 : 0;
-
+ int slot_type = slot.is_non_causal();
if (batch_type == -1) {
batch_type = slot_type;
} else if (batch_type != slot_type) {
}
if (slot.state == SLOT_STATE_DONE_PROMPT) {
- if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING) {
+ if (slot.task_type == SERVER_TASK_TYPE_EMBEDDING) {
// prompt evaluated for embedding
send_embedding(slot, batch_view);
slot.release();
continue; // continue loop of slots
}
- if (slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
+ if (slot.task_type == SERVER_TASK_TYPE_RERANK) {
send_rerank(slot, batch_view);
slot.release();
slot.i_batch = -1;
auto res_error = [](httplib::Response & res, const json & error_data) {
json final_response {{"error", error_data}};
- res.set_content(final_response.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
+ res.set_content(safe_json_to_str(final_response), MIMETYPE_JSON);
res.status = json_value(error_data, "code", 500);
};
auto res_ok = [](httplib::Response & res, const json & data) {
- res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
+ res.set_content(safe_json_to_str(data), MIMETYPE_JSON);
res.status = 200;
};
}
// request slots data using task queue
- server_task task;
+ server_task task(SERVER_TASK_TYPE_METRICS);
task.id = ctx_server.queue_tasks.get_new_id();
- task.type = SERVER_TASK_TYPE_METRICS;
-
ctx_server.queue_results.add_waiting_task_id(task.id);
ctx_server.queue_tasks.post(task, true); // high-priority task
}
// request slots data using task queue
- server_task task;
+ server_task task(SERVER_TASK_TYPE_METRICS);
task.id = ctx_server.queue_tasks.get_new_id();
- task.id_target = -1;
- task.type = SERVER_TASK_TYPE_METRICS;
- task.data.push_back({{"reset_bucket", true}});
+ task.metrics_reset_bucket = true;
ctx_server.queue_results.add_waiting_task_id(task.id);
ctx_server.queue_tasks.post(task, true); // high-priority task
}
std::string filepath = params.slot_save_path + filename;
- server_task task;
- task.type = SERVER_TASK_TYPE_SLOT_SAVE;
- task.data = {
- { "id_slot", id_slot },
- { "filename", filename },
- { "filepath", filepath },
- };
+ server_task task(SERVER_TASK_TYPE_SLOT_SAVE);
+ task.id = ctx_server.queue_tasks.get_new_id();
+ task.slot_action.slot_id = id_slot;
+ task.slot_action.filename = filename;
+ task.slot_action.filepath = filepath;
- const int id_task = ctx_server.queue_tasks.post(task);
- ctx_server.queue_results.add_waiting_task_id(id_task);
+ ctx_server.queue_results.add_waiting_task_id(task.id);
+ ctx_server.queue_tasks.post(task);
- server_task_result_ptr result = ctx_server.queue_results.recv(id_task);
- ctx_server.queue_results.remove_waiting_task_id(id_task);
+ server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
+ ctx_server.queue_results.remove_waiting_task_id(task.id);
if (result->is_error()) {
res_error(res, result->to_json());
}
std::string filepath = params.slot_save_path + filename;
- server_task task;
- task.type = SERVER_TASK_TYPE_SLOT_RESTORE;
- task.data = {
- { "id_slot", id_slot },
- { "filename", filename },
- { "filepath", filepath },
- };
+ server_task task(SERVER_TASK_TYPE_SLOT_RESTORE);
+ task.id = ctx_server.queue_tasks.get_new_id();
+ task.slot_action.slot_id = id_slot;
+ task.slot_action.filename = filename;
+ task.slot_action.filepath = filepath;
- const int id_task = ctx_server.queue_tasks.post(task);
- ctx_server.queue_results.add_waiting_task_id(id_task);
+ ctx_server.queue_results.add_waiting_task_id(task.id);
+ ctx_server.queue_tasks.post(task);
- server_task_result_ptr result = ctx_server.queue_results.recv(id_task);
- ctx_server.queue_results.remove_waiting_task_id(id_task);
+ server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
+ ctx_server.queue_results.remove_waiting_task_id(task.id);
if (result->is_error()) {
res_error(res, result->to_json());
};
const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
- server_task task;
- task.type = SERVER_TASK_TYPE_SLOT_ERASE;
- task.data = {
- { "id_slot", id_slot },
- };
+ server_task task(SERVER_TASK_TYPE_SLOT_ERASE);
+ task.id = ctx_server.queue_tasks.get_new_id();
+ task.slot_action.slot_id = id_slot;
- const int id_task = ctx_server.queue_tasks.post(task);
- ctx_server.queue_results.add_waiting_task_id(id_task);
+ ctx_server.queue_results.add_waiting_task_id(task.id);
+ ctx_server.queue_tasks.post(task);
- server_task_result_ptr result = ctx_server.queue_results.recv(id_task);
- ctx_server.queue_results.remove_waiting_task_id(id_task);
+ server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
+ ctx_server.queue_results.remove_waiting_task_id(task.id);
if (result->is_error()) {
res_error(res, result->to_json());
};
const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
+ // this endpoint is publicly available, so only return information that is safe to expose
json data = {
{ "default_generation_settings", ctx_server.default_generation_settings_for_props },
{ "total_slots", ctx_server.params_base.n_parallel },
+ { "model_path", ctx_server.params_base.model },
{ "chat_template", llama_get_chat_template(ctx_server.model) },
};
// handle completion-like requests (completion, chat, infill)
// we can optionally provide a custom format for partial results and final results
const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](
- server_task_inf_type inf_type,
+ server_task_type type,
json & data,
httplib::Response & res,
- bool oai_compat = false) {
+ bool oaicompat = false,
+ bool oaicompat_chat = false) {
+ GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
+
if (ctx_server.params_base.embedding) {
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
- data["completion_id"] = gen_chatcmplid();
- std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, inf_type);
+ auto completion_id = gen_chatcmplid();
+ std::vector<server_task> tasks;
+
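+ // tokenization and parameter parsing can throw (e.g. malformed "prompt" or "json_schema"), so report failures as invalid-request errors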
+ try {
+ std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, data.at("prompt"), true, true);
+ tasks.reserve(tokenized_prompts.size());
+ for (size_t i = 0; i < tokenized_prompts.size(); i++) {
+ server_task task = server_task(type);
+
+ task.id = ctx_server.queue_tasks.get_new_id();
+ task.index = i;
+
+ task.prompt_tokens = std::move(tokenized_prompts[i]);
+ task.params = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.params_base, data);
+ task.id_selected_slot = json_value(data, "id_slot", -1);
+
+ // OAI-compat
+ task.params.oaicompat = oaicompat;
+ task.params.oaicompat_chat = oaicompat_chat;
+ task.params.oaicompat_cmpl_id = completion_id;
+ // oaicompat_model is already populated by params_from_json_cmpl
+
+ tasks.push_back(task);
+ }
+ } catch (const std::exception & e) {
+ res_error(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
+ return;
+ }
+
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
} else {
- const auto chunked_content_provider = [task_ids, &ctx_server, oai_compat](size_t, httplib::DataSink & sink) {
+ const auto chunked_content_provider = [task_ids, &ctx_server, oaicompat](size_t, httplib::DataSink & sink) {
ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool {
json res_json = result->to_json();
if (res_json.is_array()) {
}, [&](const json & error_data) {
server_sent_event(sink, "error", error_data);
});
- if (oai_compat) {
+ if (oaicompat) {
static const std::string ev_done = "data: [DONE]\n\n";
sink.write(ev_done.data(), ev_done.size());
}
const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
json data = json::parse(req.body);
- return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res);
+ return handle_completions_generic(
+ SERVER_TASK_TYPE_COMPLETION,
+ data,
+ res,
+ /* oaicompat */ false,
+ /* oaicompat_chat */ false);
};
const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
}
data["input_extra"] = input_extra; // default to empty array if it's not exist
- return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res);
+ return handle_completions_generic(SERVER_TASK_TYPE_INFILL, data, res);
};
const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
}
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
- data["__oaicompat_chat"] = true;
- return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res, true);
+ return handle_completions_generic(
+ SERVER_TASK_TYPE_COMPLETION,
+ data,
+ res,
+ /* oaicompat */ true,
+ /* oaicompat_chat */ true);
};
- const auto handle_models = [¶ms, &ctx_server](const httplib::Request &, httplib::Response & res) {
+ const auto handle_models = [¶ms, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
json models = {
{"object", "list"},
{"data", {
}}
};
- res.set_content(models.dump(), MIMETYPE_JSON);
+ res_ok(res, models);
};
const auto handle_tokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
json responses = json::array();
bool error = false;
{
- std::vector<server_task> tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_EMBEDDING);
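+ // one EMBEDDING task per input prompt; results are matched back to the inputs via task.index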
+ std::vector<server_task> tasks;
+ std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, /* add_special */ true, true);
+ for (size_t i = 0; i < tokenized_prompts.size(); i++) {
+ server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
+ task.id = ctx_server.queue_tasks.get_new_id();
+ task.index = i;
+ task.prompt_tokens = std::move(tokenized_prompts[i]);
+ tasks.push_back(task);
+ }
+
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
// write JSON response
json root = oaicompat
? format_embeddings_response_oaicompat(body, responses)
- : responses[0];
+ : responses.size() == 1 ? responses[0] : json(responses);
res_ok(res, root);
};
return;
}
- // construct prompt object: array of ["query", "doc0", "doc1", ...]
- json prompt;
- prompt.push_back(query);
- for (const auto & doc : documents) {
- prompt.push_back(doc);
- }
-
- LOG_DBG("rerank prompt: %s\n", prompt.dump().c_str());
+ llama_tokens tokenized_query = tokenize_input_prompts(ctx_server.ctx, query, /* add_special */ false, true)[0];
// create and queue the task
json responses = json::array();
bool error = false;
{
- std::vector<server_task> tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_RERANK);
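+ // one RERANK task per document; format_rerank pairs the tokenized query with each tokenized document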
+ std::vector<server_task> tasks;
+ std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts(ctx_server.ctx, documents, /* add_special */ false, true);
+ tasks.reserve(tokenized_docs.size());
+ for (size_t i = 0; i < tokenized_docs.size(); i++) {
+ server_task task = server_task(SERVER_TASK_TYPE_RERANK);
+ task.id = ctx_server.queue_tasks.get_new_id();
+ task.index = i;
+ task.prompt_tokens = format_rerank(ctx_server.model, tokenized_query, tokenized_docs[i]);
+ tasks.push_back(task);
+ }
+
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
}
}
- server_task task;
- task.type = SERVER_TASK_TYPE_SET_LORA;
- const int id_task = ctx_server.queue_tasks.post(task);
- ctx_server.queue_results.add_waiting_task_id(id_task);
+ server_task task(SERVER_TASK_TYPE_SET_LORA);
+ task.id = ctx_server.queue_tasks.get_new_id();
+ ctx_server.queue_results.add_waiting_task_id(task.id);
+ ctx_server.queue_tasks.post(task);
- server_task_result_ptr result = ctx_server.queue_results.recv(id_task);
- ctx_server.queue_results.remove_waiting_task_id(id_task);
+ server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
+ ctx_server.queue_results.remove_waiting_task_id(task.id);
if (result->is_error()) {
res_error(res, result->to_json());