json data;
bool infill_mode = false;
bool embedding_mode = false;
+ int multitask_id = -1;
};
struct task_result {
int id;
+ int multitask_id = -1;
bool stop;
bool error;
json result_json;
};
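+// bookkeeping for one multi-prompt request: the subtask ids still pending and
+// the results collected so far; it is resolved once subtasks_remaining is empty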
+struct task_multi {
+ int id;
+ std::set<int> subtasks_remaining{};
+ std::vector<task_result> results{};
+};
+
// TODO: can become bool if we can't find use of more states
enum slot_state
{
double t_prompt_processing; // ms
double t_token_generation; // ms
+ // id of the parent multitask, or -1 if this slot's task is not part of one
+ int multitask_id = -1;
+
void reset() {
num_prompt_tokens = 0;
generated_text = "";
std::vector<task_server> queue_tasks;
std::vector<task_result> queue_results;
- std::mutex mutex_tasks;
+ std::vector<task_multi> queue_multitasks;
+ std::mutex mutex_tasks; // also guards id_gen and queue_multitasks
std::mutex mutex_results;
~llama_server_context()
return slot.images.size() > 0;
}
- void send_error(int id, std::string error)
+ void send_error(task_server& task, const std::string &error)
{
std::lock_guard<std::mutex> lock(mutex_results);
task_result res;
- res.id = id;
+ res.id = task.id;
+ res.multitask_id = task.multitask_id;
res.stop = false;
res.error = true;
res.result_json = { { "content", error } };
queue_results.push_back(res);
}
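+ // registers a multitask entry so the completion of its subtasks can be tracked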
+ void add_multi_task(int id, std::vector<int>& sub_ids)
+ {
+ std::lock_guard<std::mutex> lock(mutex_tasks);
+ task_multi multi;
+ multi.id = id;
+ std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
+ queue_multitasks.push_back(multi);
+ }
+
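+ // records a finished subtask's result on its parent multitask and removes the subtask from the pending set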
+ void update_multi_task(int multitask_id, int subtask_id, task_result& result)
+ {
+ std::lock_guard<std::mutex> lock(mutex_tasks);
+ for (auto& multitask : queue_multitasks)
+ {
+ if (multitask.id == multitask_id)
+ {
+ multitask.subtasks_remaining.erase(subtask_id);
+ multitask.results.push_back(result);
+ }
+ }
+ }
+
json get_model_props()
{
return get_formated_generation(slots[0]);
std::lock_guard<std::mutex> lock(mutex_results);
task_result res;
res.id = slot.task_id;
+ res.multitask_id = slot.multitask_id;
res.error = false;
res.stop = false;
std::lock_guard<std::mutex> lock(mutex_results);
task_result res;
res.id = slot.task_id;
+ res.multitask_id = slot.multitask_id;
res.error = false;
res.stop = true;
res.result_json["model"] = slot.oaicompat_model;
}
+ // parent multitask, if any, needs to be updated
+ if (slot.multitask_id != -1)
+ {
+ update_multi_task(slot.multitask_id, slot.task_id, res);
+ }
+
queue_results.push_back(res);
}
std::lock_guard<std::mutex> lock(mutex_results);
task_result res;
res.id = slot.task_id;
+ res.multitask_id = slot.multitask_id;
res.error = false;
res.stop = true;
queue_results.push_back(res);
}
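+ // multitask_id is the id of the parent multitask when this call creates a subtask, or -1 for a standalone request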
- int request_completion(json data, bool infill, bool embedding)
+ int request_completion(json data, bool infill, bool embedding, int multitask_id)
{
- std::lock_guard<std::mutex> lock(mutex_tasks);
+ std::unique_lock<std::mutex> lock(mutex_tasks);
task_server task;
task.id = id_gen++;
task.target_id = 0;
task.infill_mode = infill;
task.embedding_mode = embedding;
task.type = COMPLETION_TASK;
+ task.multitask_id = multitask_id;
+
+ // when the prompt is an array with more than one entry, split the task into one subtask per prompt
+ if (task.data.count("prompt") && task.data.at("prompt").size() > 1)
+ {
+ lock.unlock(); // release mutex_tasks: split_multiprompt_task re-acquires it (via request_completion and add_multi_task)
+ return split_multiprompt_task(task);
+ }
+
+ // otherwise this is a single-prompt task, so queue it directly
queue_tasks.push_back(task);
return task.id;
}
for (int i = 0; i < (int) queue_results.size(); i++)
{
+ // for now, results that belong to a parent multitask are folded into the multitask here and then erased from the result queue
+ if (queue_results[i].multitask_id == task_id)
+ {
+ update_multi_task(task_id, queue_results[i].id, queue_results[i]);
+ queue_results.erase(queue_results.begin() + i);
+ i--; // the erase shifted the next element into slot i; step back so the loop's i++ does not skip it
+ continue;
+ }
+
if (queue_results[i].id == task_id)
{
+ assert(queue_results[i].multitask_id == -1);
task_result res = queue_results[i];
queue_results.erase(queue_results.begin() + i);
return res;
queue_tasks.push_back(task);
}
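+ // note: must be called without holding mutex_tasks; the lock is re-acquired below (for id_gen) and inside request_completion and add_multi_task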
+ int split_multiprompt_task(task_server& multiprompt_task)
+ {
+ auto prompt_count = multiprompt_task.data.at("prompt").size();
+ assert(prompt_count > 1);
+
+ int multitask_id;
+ {
+ // id_gen is guarded by mutex_tasks; take the lock briefly since the caller released it
+ std::lock_guard<std::mutex> lock(mutex_tasks);
+ multitask_id = id_gen++;
+ }
+ std::vector<int> subtask_ids(prompt_count);
+ for (size_t i = 0; i < prompt_count; i++)
+ {
+ json subtask_data = multiprompt_task.data;
+ subtask_data["prompt"] = subtask_data["prompt"][i];
+
+ // subtasks inherit everything else (infill mode, embedding mode, etc.)
+ subtask_ids[i] = request_completion(subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
+ }
+
+ // queue up the multitask so we can track its subtask progression
+ add_multi_task(multitask_id, subtask_ids);
+ return multitask_id;
+ }
+
void process_tasks()
{
std::lock_guard<std::mutex> lock(mutex_tasks);
{
LOG_TEE("slot unavailable\n");
// send error result
- send_error(task.id, "slot unavailable");
+ send_error(task, "slot unavailable");
return;
}
slot->infill = task.infill_mode;
slot->embedding = task.embedding_mode;
slot->task_id = task.id;
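+ // remember the parent multitask (if any) so results from this slot can be routed back to it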
+ slot->multitask_id = task.multitask_id;
if (!launch_slot_with_data(slot, task.data))
{
// send error result
- send_error(task.id, "internal_error");
+ send_error(task, "internal_error");
break;
}
} break;
} break;
}
}
+
+ // remove finished multitasks from queue_multitasks and push their aggregated results onto the result queue
+ auto queue_iterator = queue_multitasks.begin();
+ while (queue_iterator != queue_multitasks.end())
+ {
+ if (queue_iterator->subtasks_remaining.empty())
+ {
+ // all subtasks done == multitask is done
+ task_result aggregate_result;
+ aggregate_result.id = queue_iterator->id;
+ aggregate_result.stop = true;
+ aggregate_result.error = false;
+
+ // collect json results into one json result
+ std::vector<json> result_jsons;
+ for (auto& subres : queue_iterator->results)
+ {
+ result_jsons.push_back(subres.result_json);
+ aggregate_result.error = aggregate_result.error || subres.error; // the aggregate is an error if any subtask failed
+ }
+ aggregate_result.result_json = json{ { "results", result_jsons } };
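+ // the aggregated payload has the form {"results": [ <subtask result_json>, ... ]}, in the order the subtasks finished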
+
+ std::lock_guard<std::mutex> lock_results(mutex_results);
+ queue_results.push_back(aggregate_result);
+
+ queue_iterator = queue_multitasks.erase(queue_iterator);
+ }
+ else
+ {
+ ++queue_iterator;
+ }
+ }
}
bool update_slots() {
svr.Post("/completion", [&llama](const httplib::Request &req, httplib::Response &res)
{
json data = json::parse(req.body);
- const int task_id = llama.request_completion(data, false, false);
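+ // -1: no parent multitask; if "prompt" is an array with several entries, the returned id refers to the aggregated multitask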
+ const int task_id = llama.request_completion(data, false, false, -1);
if (!json_value(data, "stream", false)) {
std::string completion_text;
task_result result = llama.next_result(task_id);
{
json data = oaicompat_completion_params_parse(json::parse(req.body));
- const int task_id = llama.request_completion(data, false, false);
+ const int task_id = llama.request_completion(data, false, false, -1);
if (!json_value(data, "stream", false)) {
std::string completion_text;
svr.Post("/infill", [&llama](const httplib::Request &req, httplib::Response &res)
{
json data = json::parse(req.body);
- const int task_id = llama.request_completion(data, true, false);
+ const int task_id = llama.request_completion(data, true, false, -1);
if (!json_value(data, "stream", false)) {
std::string completion_text;
task_result result = llama.next_result(task_id);
{
prompt = "";
}
- const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0} }, false, true);
+ const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0} }, false, true, -1);
task_result result = llama.next_result(task_id);
return res.set_content(result.result_json.dump(), "application/json");
});