}
virtual bool is_stop() {
// only used by server_task_result_cmpl_*
- return false;
+ return true;
}
virtual int get_index() {
return -1;
queue_results.send(std::move(res));
}
- //
- // Functions to create new task(s) and receive result(s)
- //
-
- void cancel_tasks(const std::unordered_set<int> & id_tasks) {
- std::vector<server_task> cancel_tasks;
- cancel_tasks.reserve(id_tasks.size());
- for (const auto & id_task : id_tasks) {
- SRV_WRN("cancel task, id_task = %d\n", id_task);
-
- server_task task(SERVER_TASK_TYPE_CANCEL);
- task.id_target = id_task;
- queue_results.remove_waiting_task_id(id_task);
- cancel_tasks.push_back(std::move(task));
- }
- // push to beginning of the queue, so it has highest priority
- queue_tasks.post(std::move(cancel_tasks), true);
- }
-
- // receive the results from task(s)
- void receive_multi_results(
- const std::unordered_set<int> & id_tasks,
- const std::function<void(std::vector<server_task_result_ptr>&)> & result_handler,
- const std::function<void(json)> & error_handler,
- const std::function<bool()> & is_connection_closed) {
- std::vector<server_task_result_ptr> results(id_tasks.size());
- for (int i = 0; i < (int)id_tasks.size(); i++) {
- server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS);
-
- if (is_connection_closed()) {
- cancel_tasks(id_tasks);
- return;
- }
-
- if (result == nullptr) {
- i--; // retry
- continue;
- }
-
- if (result->is_error()) {
- error_handler(result->to_json());
- cancel_tasks(id_tasks);
- return;
- }
-
- GGML_ASSERT(
- dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
- || dynamic_cast<server_task_result_embd*>(result.get()) != nullptr
- || dynamic_cast<server_task_result_rerank*>(result.get()) != nullptr
- );
- const size_t idx = result->get_index();
- GGML_ASSERT(idx < results.size() && "index out of range");
- results[idx] = std::move(result);
- }
- result_handler(results);
- }
-
- // receive the results from task(s), in stream mode
- void receive_cmpl_results_stream(
- const std::unordered_set<int> & id_tasks,
- const std::function<bool(server_task_result_ptr&)> & result_handler,
- const std::function<void(json)> & error_handler,
- const std::function<bool()> & is_connection_closed) {
- size_t n_finished = 0;
- while (true) {
- server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS);
-
- if (is_connection_closed()) {
- cancel_tasks(id_tasks);
- return;
- }
-
- if (result == nullptr) {
- continue; // retry
- }
-
- if (result->is_error()) {
- error_handler(result->to_json());
- cancel_tasks(id_tasks);
- return;
- }
-
- GGML_ASSERT(
- dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
- || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
- );
- if (!result_handler(result)) {
- cancel_tasks(id_tasks);
- break;
- }
-
- if (result->is_stop()) {
- if (++n_finished == id_tasks.size()) {
- break;
- }
- }
- }
- }
-
//
// Functions to process the task
//
}
};
+// generator-like API for server responses; supports polling the connection state and aggregating results
+struct server_response_reader {
+ std::unordered_set<int> id_tasks;
+ server_context & ctx_server;
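+    // received_count counts tasks that have delivered their final (is_stop) result
+    // cancelled is set once stop() has posted cancel tasks for the remaining ids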
+ size_t received_count = 0;
+ bool cancelled = false;
+
+ server_response_reader(server_context & ctx_server) : ctx_server(ctx_server) {}
+ ~server_response_reader() {
+ stop();
+ }
+
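+    // register the tasks as waiting in the result queue and post them to the task queue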
+ void post_tasks(std::vector<server_task> && tasks) {
+ id_tasks = server_task::get_list_id(tasks);
+ ctx_server.queue_results.add_waiting_tasks(tasks);
+ ctx_server.queue_tasks.post(std::move(tasks));
+ }
+
+ bool has_next() {
+ return !cancelled && received_count < id_tasks.size();
+ }
+
+    // returns nullptr if should_stop() returns true before a result is received
+    // note: if an error result is received, the remaining tasks are cancelled and the error result is returned
+ server_task_result_ptr next(const std::function<bool()> & should_stop) {
+ while (true) {
+ server_task_result_ptr result = ctx_server.queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS);
+ if (result == nullptr) {
+ // timeout, check stop condition
+ if (should_stop()) {
+ SRV_DBG("%s", "stopping wait for next result due to should_stop condition\n");
+ return nullptr;
+ }
+ } else {
+ if (result->is_error()) {
+ stop(); // cancel remaining tasks
+ SRV_DBG("%s", "received error result, stopping further processing\n");
+ return result;
+ }
+ if (result->is_stop()) {
+ received_count++;
+ }
+ return result;
+ }
+ }
+
+ // should not reach here
+ }
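+
+    // typical streaming usage (a minimal sketch, not the exact shape of the completion handler below;
+    // it assumes a handler-owned httplib::DataSink `sink` and the existing server_sent_event() helper):
+    //
+    //   while (rd->has_next()) {
+    //       server_task_result_ptr res = rd->next([&sink] { return !sink.is_writable(); });
+    //       if (res == nullptr) {
+    //           break; // stopped early, e.g. the client disconnected
+    //       }
+    //       if (res->is_error()) {
+    //           server_sent_event(sink, json {{ "error", res->to_json() }});
+    //           break; // next() has already cancelled the remaining tasks
+    //       }
+    //       server_sent_event(sink, res->to_json());
+    //   }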
+
+ struct batch_response {
+ bool is_terminated = false; // if true, indicates that processing was stopped before all results were received
+ std::vector<server_task_result_ptr> results;
+ server_task_result_ptr error; // nullptr if no error
+ };
+
+ batch_response wait_for_all(const std::function<bool()> & should_stop) {
+ batch_response batch_res;
+ batch_res.results.resize(id_tasks.size());
+ while (has_next()) {
+ auto res = next(should_stop);
+ if (res == nullptr) {
+ batch_res.is_terminated = true;
+ return batch_res;
+ }
+ if (res->is_error()) {
+ batch_res.error = std::move(res);
+ return batch_res;
+ }
+ const size_t idx = res->get_index();
+ GGML_ASSERT(idx < batch_res.results.size() && "index out of range");
+ GGML_ASSERT(batch_res.results[idx] == nullptr && "duplicate result received");
+ batch_res.results[idx] = std::move(res);
+ }
+ return batch_res;
+ }
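+
+    // typical non-stream usage (a minimal sketch mirroring the embeddings/rerank handlers below;
+    // `tasks`, `req`, `res` and res_error() come from the surrounding handler code):
+    //
+    //   server_response_reader rd(ctx_server);
+    //   rd.post_tasks(std::move(tasks));
+    //   auto all = rd.wait_for_all(req.is_connection_closed);
+    //   if (all.is_terminated) {
+    //       return; // connection closed; the destructor will cancel the remaining tasks
+    //   }
+    //   if (all.error) {
+    //       res_error(res, all.error->to_json());
+    //       return;
+    //   }
+    //   for (auto & r : all.results) {
+    //       // r is a non-null result, placed at the position given by get_index()
+    //   }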
+
+ void stop() {
+ ctx_server.queue_results.remove_waiting_task_ids(id_tasks);
+ if (has_next() && !cancelled) {
+            // the tasks are not all finished yet, cancel them
+ cancelled = true;
+ std::vector<server_task> cancel_tasks;
+ cancel_tasks.reserve(id_tasks.size());
+ for (const auto & id_task : id_tasks) {
+ SRV_WRN("cancel task, id_task = %d\n", id_task);
+ server_task task(SERVER_TASK_TYPE_CANCEL);
+ task.id_target = id_task;
+ ctx_server.queue_results.remove_waiting_task_id(id_task);
+ cancel_tasks.push_back(std::move(task));
+ }
+ // push to beginning of the queue, so it has highest priority
+ ctx_server.queue_tasks.post(std::move(cancel_tasks), true);
+ } else {
+ SRV_DBG("%s", "all tasks already finished, no need to cancel\n");
+ }
+ }
+};
+
static void log_server_request(const httplib::Request & req, const httplib::Response & res) {
// skip GH copilot requests when using default port
if (req.path == "/v1/health") {
GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
auto completion_id = gen_chatcmplid();
- std::unordered_set<int> task_ids;
+    // the reader must be heap-allocated so that it is not destroyed when the handler returns
+ // use shared_ptr as it's shared between the chunked_content_provider() and on_complete()
+ const auto rd = std::make_shared<server_response_reader>(ctx_server);
+
try {
std::vector<server_task> tasks;
// Everything else, including multimodal completions.
inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
}
- const size_t n_ctx_slot = ctx_server.slots.front().n_ctx;
tasks.reserve(inputs.size());
for (size_t i = 0; i < inputs.size(); i++) {
- auto n_prompt_tokens = inputs[i].size();
- if (n_prompt_tokens >= n_ctx_slot) {
- json error_data = format_error_response("the request exceeds the available context size, try increasing it", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
- error_data["n_prompt_tokens"] = n_prompt_tokens;
- error_data["n_ctx"] = n_ctx_slot;
- res_error(res, error_data);
- return;
- }
server_task task = server_task(type);
task.id = ctx_server.queue_tasks.get_new_id();
tasks.push_back(std::move(task));
}
- task_ids = server_task::get_list_id(tasks);
- ctx_server.queue_results.add_waiting_tasks(tasks);
- ctx_server.queue_tasks.post(std::move(tasks));
+ rd->post_tasks(std::move(tasks));
} catch (const std::exception & e) {
res_error(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
return;
bool stream = json_value(data, "stream", false);
if (!stream) {
- ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
- if (results.size() == 1) {
- // single result
- res_ok(res, results[0]->to_json());
- } else {
- // multiple results (multitask)
- json arr = json::array();
- for (auto & res : results) {
- arr.push_back(res->to_json());
- }
- res_ok(res, arr);
+ // non-stream, wait for the results
+ auto all_results = rd->wait_for_all(is_connection_closed);
+ if (all_results.is_terminated) {
+ return; // connection is closed
+ } else if (all_results.error) {
+ res_error(res, all_results.error->to_json());
+ return;
+ } else {
+ json arr = json::array();
+ for (auto & res : all_results.results) {
+ GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
+ arr.push_back(res->to_json());
}
- }, [&](const json & error_data) {
- res_error(res, error_data);
- }, is_connection_closed);
+ // if single request, return single object instead of array
+ res_ok(res, arr.size() == 1 ? arr[0] : arr);
+ }
- ctx_server.queue_results.remove_waiting_task_ids(task_ids);
} else {
- const auto chunked_content_provider = [task_ids, &ctx_server, oaicompat](size_t, httplib::DataSink & sink) {
- ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool {
- json res_json = result->to_json();
- if (res_json.is_array()) {
- for (const auto & res : res_json) {
- if (!server_sent_event(sink, res)) {
- // sending failed (HTTP connection closed), cancel the generation
- return false;
- }
- }
- return true;
- } else {
- return server_sent_event(sink, res_json);
+        // in streaming mode, the first error must be treated as a non-stream response
+ // this is to match the OAI API behavior
+ // ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309
+ server_task_result_ptr first_result = rd->next(is_connection_closed);
+ if (first_result == nullptr) {
+ return; // connection is closed
+ } else if (first_result->is_error()) {
+ res_error(res, first_result->to_json());
+ return;
+ } else {
+ GGML_ASSERT(
+ dynamic_cast<server_task_result_cmpl_partial*>(first_result.get()) != nullptr
+ || dynamic_cast<server_task_result_cmpl_final*>(first_result.get()) != nullptr
+ );
+ }
+
+        // the remaining responses are streamed
+ json first_result_json = first_result->to_json();
+ const auto chunked_content_provider = [first_result_json, rd, oaicompat](size_t, httplib::DataSink & sink) mutable -> bool {
+ // flush the first result as it's not an error
+ if (!first_result_json.empty()) {
+ if (!server_sent_event(sink, first_result_json)) {
+ sink.done();
+ return false; // sending failed, go to on_complete()
}
- }, [&](const json & error_data) {
- server_sent_event(sink, json{{"error", error_data}});
- }, [&sink]() {
- // note: do not use req.is_connection_closed here because req is already destroyed
- return !sink.is_writable();
- });
- if (oaicompat != OAICOMPAT_TYPE_NONE) {
- static const std::string ev_done = "data: [DONE]\n\n";
- sink.write(ev_done.data(), ev_done.size());
+ first_result_json.clear(); // mark as sent
}
- sink.done();
- return false;
+
+ // receive subsequent results
+ auto result = rd->next([&sink]{ return !sink.is_writable(); });
+ if (result == nullptr) {
+ sink.done();
+ return false; // connection is closed, go to on_complete()
+ }
+
+            // send the result
+            json res_json = result->to_json();
+            bool ok = false;
+            if (result->is_error()) {
+                server_sent_event(sink, json {{ "error", res_json }});
+                sink.done();
+                return false; // go to on_complete()
+ } else {
+ GGML_ASSERT(
+ dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
+ || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
+ );
+ ok = server_sent_event(sink, res_json);
+ }
+
+ if (!ok) {
+ sink.done();
+ return false; // sending failed, go to on_complete()
+ }
+
+ // check if there is more data
+ if (!rd->has_next()) {
+ if (oaicompat != OAICOMPAT_TYPE_NONE) {
+ static const std::string ev_done = "data: [DONE]\n\n";
+ sink.write(ev_done.data(), ev_done.size());
+ }
+ sink.done();
+ return false; // no more data, go to on_complete()
+ }
+
+ // has next data, continue
+ return true;
};
- auto on_complete = [task_ids, &ctx_server] (bool) {
- ctx_server.queue_results.remove_waiting_task_ids(task_ids);
+ auto on_complete = [rd](bool) {
+ rd->stop();
};
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
// create and queue the task
json responses = json::array();
- bool error = false;
- std::unordered_set<int> task_ids;
+ server_response_reader rd(ctx_server);
{
std::vector<server_task> tasks;
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
tasks.push_back(std::move(task));
}
-
- task_ids = server_task::get_list_id(tasks);
- ctx_server.queue_results.add_waiting_tasks(tasks);
- ctx_server.queue_tasks.post(std::move(tasks));
+ rd.post_tasks(std::move(tasks));
}
- // get the result
- ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
- for (auto & res : results) {
+ // wait for the results
+ auto all_results = rd.wait_for_all(req.is_connection_closed);
+
+ // collect results
+ if (all_results.is_terminated) {
+ return; // connection is closed
+ } else if (all_results.error) {
+ res_error(res, all_results.error->to_json());
+ return;
+ } else {
+ for (auto & res : all_results.results) {
GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
responses.push_back(res->to_json());
}
- }, [&](const json & error_data) {
- res_error(res, error_data);
- error = true;
- }, req.is_connection_closed);
-
- ctx_server.queue_results.remove_waiting_task_ids(task_ids);
-
- if (error) {
- return;
}
// write JSON response
// create and queue the task
json responses = json::array();
- bool error = false;
- std::unordered_set<int> task_ids;
+ server_response_reader rd(ctx_server);
{
std::vector<server_task> tasks;
tasks.reserve(documents.size());
task.tokens = std::move(tmp);
tasks.push_back(std::move(task));
}
-
- task_ids = server_task::get_list_id(tasks);
- ctx_server.queue_results.add_waiting_tasks(tasks);
- ctx_server.queue_tasks.post(std::move(tasks));
+ rd.post_tasks(std::move(tasks));
}
- ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
- for (auto & res : results) {
+ // wait for the results
+ auto all_results = rd.wait_for_all(req.is_connection_closed);
+
+ // collect results
+ if (all_results.is_terminated) {
+ return; // connection is closed
+ } else if (all_results.error) {
+ res_error(res, all_results.error->to_json());
+ return;
+ } else {
+ for (auto & res : all_results.results) {
GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
responses.push_back(res->to_json());
}
- }, [&](const json & error_data) {
- res_error(res, error_data);
- error = true;
- }, req.is_connection_closed);
-
- if (error) {
- return;
}
// write JSON response