return 0;
}
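+// make a deep copy: the token list is copied directly, while each multimodal chunk is duplicated via mtmd_input_chunk_copy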
+server_tokens server_tokens::clone() const {
+ server_tokens res;
+ res.has_mtmd = has_mtmd;
+ res.tokens = tokens;
+ for (const auto & [idx, chunk] : map_idx_to_media) {
+ res.map_idx_to_media[idx] = mtmd::input_chunk_ptr(mtmd_input_chunk_copy(chunk.get()));
+ }
+ return res;
+}
+
//
// tokenizer and input processing utils
//
llama_params["stop"] = json_value(body, "stop", json::array());
}
- // Handle "n" field
- int n_choices = json_value(body, "n", 1);
- if (n_choices != 1) {
- throw std::runtime_error("Only one completion choice is allowed");
- }
-
// Handle "echo" field
if (json_value(body, "echo", false)) {
throw std::runtime_error("Only no echo is supported");
llama_params["chat_parser"] = chat_params.parser;
}
- // Handle "n" field
- int n_choices = json_value(body, "n", 1);
- if (n_choices != 1) {
- throw std::invalid_argument("Only one completion choice is allowed");
- }
-
// Handle "logprobs" field
// TODO: The response format of this option is not yet OAI-compatible, but it seems that no one is really using it; we may need to fix it in the future
if (json_value(body, "logprobs", false)) {
// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
enum slot_state {
SLOT_STATE_IDLE,
- SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future
+ SLOT_STATE_WAIT_OTHER, // after assigning a task, but waiting for parent slot to process prompt
+ SLOT_STATE_STARTED, // after assigning a task and about to process prompt
SLOT_STATE_PROCESSING_PROMPT,
SLOT_STATE_DONE_PROMPT,
SLOT_STATE_GENERATING,
generated_token_probs.push_back(token);
}
+ // note: in addition to its normal role, a slot can act as either a parent or a child
+ bool is_parent() const {
+ return is_processing() && task->n_children > 0;
+ }
+
+ bool is_child() const {
+ return is_processing() && task->id_parent >= 0;
+ }
+
void release() {
if (is_processing()) {
GGML_ASSERT(task);
return res;
}
+
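+ // copy this slot's KV cache sequence and decoding state into "other"; used by a parent slot to share its processed prompt with its child slots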
+ void copy_state_to(server_slot & other) const {
+ llama_memory_seq_rm(llama_get_memory(ctx), other.id, 0, -1);
+ llama_memory_seq_cp(llama_get_memory(ctx), id, other.id, 0, -1);
+ other.n_decoded = n_decoded;
+ other.n_remaining = n_remaining;
+ other.i_batch = i_batch;
+ other.n_prompt_tokens_cache = n_prompt_tokens_cache;
+ other.n_prompt_tokens_processed = n_prompt_tokens_processed;
+ other.prompt = prompt.clone();
+ }
};
slot.task = std::make_unique<const server_task>(std::move(task));
- slot.state = SLOT_STATE_STARTED;
+ slot.state = slot.is_child()
+ ? SLOT_STATE_WAIT_OTHER // wait for the parent to process prompt
+ : SLOT_STATE_STARTED;
SLT_INF(slot, "%s", "processing task\n");
GGML_ABORT("not supported by multimodal");
}
+ if (slot.is_parent() || slot.is_child()) {
+ send_error(slot, "context shift cannot be used with a shared prompt", ERROR_TYPE_SERVER);
+ slot.release();
+ continue;
+ }
+
// Shift context
int n_keep = slot.task->params.n_keep < 0 ? slot.task->n_tokens() : slot.task->params.n_keep;
n_batch = llama_n_batch(ctx);
for (auto & slot : slots) {
+ // a parent slot that has finished processing its prompt may need to copy its state to the child slots
+ if (slot.state == SLOT_STATE_DONE_PROMPT && slot.is_parent()) {
+ std::vector<server_slot *> child_slots;
+ for (auto & other : slots) {
+ if (other.state == SLOT_STATE_WAIT_OTHER && slot.task->id == other.task->id_parent) {
+ child_slots.push_back(&other);
+ }
+ }
+
+ // we can only proceed once all child slots have been assigned their tasks
+ if (child_slots.size() == slot.task->n_children) {
+ // copy state to the child slots
+ for (auto & child : child_slots) {
+ SLT_INF(slot, "copying state to child %d\n", child->id);
+ slot.copy_state_to(*child);
+ child->state = SLOT_STATE_DONE_PROMPT;
+ }
+ }
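+ // otherwise, keep waiting; the copy is attempted again once all children have been assigned their tasks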
+ }
+
// optionally send prompt processing progress
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) {
if (slot.task->params.stream && slot.task->params.return_progress) {
}
tasks.reserve(inputs.size());
states.reserve(inputs.size());
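+ // "idx" is the result index; it is tracked separately from "i" because child tasks also receive indices of their own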
+ int idx = 0;
for (size_t i = 0; i < inputs.size(); i++) {
server_task task = server_task(type);
task.id = ctx_server.queue_tasks.get_new_id();
- task.index = i;
+ task.index = idx++;
task.tokens = std::move(inputs[i]);
task.params = server_task::params_from_json_cmpl(
task.params.oaicompat_model = ctx_server.model_name;
states.push_back(task.params.oaicompat_chat_syntax);
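+ // if more than one completion is requested ("n" > 1), create child tasks that will reuse the parent's processed prompt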
+ if (task.params.n_cmpl > 1) {
+ task.n_children = task.params.n_cmpl - 1;
+ for (size_t j = 0; j < task.n_children; j++) {
+ server_task child = task.create_child(
+ task.id,
+ ctx_server.queue_tasks.get_new_id(),
+ idx++);
+ states.push_back(child.params.oaicompat_chat_syntax);
+ tasks.push_back(std::move(child));
+ }
+ }
+
tasks.push_back(std::move(task));
}
GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
arr.push_back(res->to_json());
}
- // if single request, return single object instead of array
- res->ok(arr.size() == 1 ? arr[0] : arr);
+ GGML_ASSERT(!arr.empty() && "empty results");
+ if (arr.size() == 1) {
+ // if there is only a single result, return a single object instead of an array
+ res->ok(arr[0]);
+ } else if (res_type == TASK_RESPONSE_TYPE_OAI_CHAT || res_type == TASK_RESPONSE_TYPE_OAI_CMPL) {
+ // if there are multiple results in OAI format, merge their "choices" into a single response
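+ // e.g. {"choices":[{"index":0,...}]} and {"choices":[{"index":1,...}]} become {"choices":[{"index":0,...},{"index":1,...}]}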
+ json & choices = arr[0]["choices"];
+ for (size_t i = 1; i < arr.size(); i++) {
+ choices.push_back(std::move(arr[i]["choices"][0]));
+ }
+ res->ok(arr[0]);
+ } else {
+ // multiple results in a non-OAI-compatible format, return them as an array
+ res->ok(arr);
+ }
}
} else {
// in streaming mode, the first error must be treated as non-stream response
params.n_indent = json_value(data, "n_indent", defaults.n_indent);
params.n_keep = json_value(data, "n_keep", defaults.n_keep);
params.n_discard = json_value(data, "n_discard", defaults.n_discard);
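+ // number of completions to generate for the same prompt, e.g. "n": 3 requests three completions ("n" is the OAI-style alias for "n_cmpl")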
+ params.n_cmpl = json_value(data, "n_cmpl", json_value(data, "n", 1));
//params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
params.response_fields = json_value(data, "response_fields", std::vector<std::string>());
}
}
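+ // each completion occupies its own slot, so n_cmpl is limited by the number of parallel slots (-np)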
+ if (params.n_cmpl > params_base.n_parallel) {
+ throw std::runtime_error("n_cmpl cannot be greater than the number of parallel slots; please increase -np");
+ }
+
return params;
}
json choice {
{"finish_reason", finish_reason},
- {"index", 0},
+ {"index", index},
{"message", msg.to_json_oaicompat<json>()},
};
{"choices", json::array({
json {
{"finish_reason", nullptr},
- {"index", 0},
+ {"index", index},
{"delta", delta},
},
})},