std::string generated_text;
llama_tokens generated_tokens;
- common_chat_msg chat_msg;
-
std::vector<completion_token_output> generated_token_probs;
bool has_next_token = true;
llama_token sampled;
- common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
- std::vector<std::string> generated_tool_call_ids;
-
// stats
size_t n_sent_text = 0; // number of sent text characters
stop = STOP_TYPE_NONE;
stopping_word = "";
n_sent_text = 0;
- chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
generated_tokens.clear();
generated_token_probs.clear();
- chat_msg = {};
json_schema = json();
- generated_tool_call_ids.clear();
// clear speculative decoding stats
n_draft_total = 0;
return timings;
}
- const common_chat_msg & update_chat_msg(std::vector<common_chat_msg_diff> & diffs) {
- GGML_ASSERT(task);
-
- auto previous_msg = chat_msg;
- SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
- auto new_msg = common_chat_parse(
- generated_text,
- /* is_partial= */ stop != STOP_TYPE_EOS,
- task->params.oaicompat_chat_syntax);
- if (!new_msg.empty()) {
- new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id);
- chat_msg = new_msg;
- diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg);
- }
- return chat_msg;
- }
-
size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
GGML_ASSERT(task);
} else {
res->content = tkn.text_to_send;
res->tokens = { tkn.tok };
-
- slot.update_chat_msg(res->oaicompat_msg_diffs);
}
res->n_decoded = slot.n_decoded;
res->id_slot = slot.id;
res->index = slot.task->index;
- res->content = slot.generated_text;
- res->tokens = std::move(slot.generated_tokens);
+ // in stream mode, content and tokens were already sent in the last partial chunk
+ if (slot.task->params.stream) {
+ res->content = "";
+ res->tokens = llama_tokens{};
+ } else {
+ res->content = std::move(slot.generated_text);
+ res->tokens = std::move(slot.generated_tokens);
+ }
res->timings = slot.get_timings();
res->prompt = slot.task->tokens.detokenize(ctx, true);
res->response_fields = std::move(slot.task->params.response_fields);
res->res_type = slot.task->params.res_type;
res->oaicompat_model = slot.task->params.oaicompat_model;
res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id;
- res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs);
// populate res.probs_output
if (slot.task->params.sampling.n_probs > 0) {
try {
std::vector<server_task> tasks;
+ // per-task states for tracking generation and partial tool-call parsing
+ std::vector<task_result_state> states;
+
const auto & prompt = data.at("prompt");
// TODO: this log can become very long, put it behind a flag or think about a more compact format
//SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
}
tasks.reserve(inputs.size());
+ states.reserve(inputs.size());
for (size_t i = 0; i < inputs.size(); i++) {
server_task task = server_task(type);
task.params.res_type = res_type;
task.params.oaicompat_cmpl_id = completion_id;
task.params.oaicompat_model = ctx_server.model_name;
+ states.push_back(task.params.oaicompat_chat_syntax); // implicitly constructs a task_result_state
tasks.push_back(std::move(task));
}
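+ // hand one state per task to the response reader; each result is expected
+ // to be update()'d with its matching state before it is serialized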
+ rd.set_states(std::move(states));
rd.post_tasks(std::move(tasks));
} catch (const std::exception & e) {
res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
// if single request, return single object instead of array
res->ok(arr.size() == 1 ? arr[0] : arr);
}
-
} else {
// in streaming mode, the first error must be treated as a non-stream response
// this is to match the OAI API behavior
}
// next responses are streamed
+ // serialize the first result once; it is flushed immediately on the first next() call
+ json first_result_json = first_result->to_json();
if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
- res->data = format_anthropic_sse(first_result->to_json());
+ res->data = format_anthropic_sse(first_result_json);
} else {
- res->data = format_oai_sse(first_result->to_json()); // to be sent immediately
+ res->data = format_oai_sse(first_result_json);
}
res->status = 200;
res->content_type = "text/event-stream";
res->next = [res_this = res.get(), res_type, &should_stop](std::string & output) -> bool {
- if (should_stop()) {
- SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
- return false; // should_stop condition met
- }
-
- if (!res_this->data.empty()) {
- // flush the first chunk
- output = std::move(res_this->data);
- res_this->data.clear();
- return true;
- }
-
- server_response_reader & rd = res_this->rd;
-
- // check if there is more data
- if (!rd.has_next()) {
- if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
- // Anthropic doesn't send [DONE], message_stop was already sent
- output = "";
- } else if (res_type != TASK_RESPONSE_TYPE_NONE) {
- output = "data: [DONE]\n\n";
- } else {
- output = "";
- }
- SRV_DBG("%s", "all results received, terminating stream\n");
- return false; // no more data, terminate
- }
-
- // receive subsequent results
- auto result = rd.next(should_stop);
- if (result == nullptr) {
- SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
- return false; // should_stop condition met
- }
-
- // send the results
- json res_json = result->to_json();
- if (result->is_error()) {
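+ // stateless helper, initialized once and reused across next() calls;
+ // formats an error payload as an SSE chunk for the active API flavor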
+ static auto format_error = [](task_response_type res_type, const json & res_json) {
if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
- output = format_anthropic_sse({
+ return format_anthropic_sse({
{"event", "error"},
{"data", res_json},
});
} else {
- output = format_oai_sse(json {{ "error", res_json }});
+ return format_oai_sse(json {{ "error", res_json }});
}
- SRV_DBG("%s", "error received during streaming, terminating stream\n");
- return false; // terminate on error
- } else {
- GGML_ASSERT(
- dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
- || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
- );
- if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
- output = format_anthropic_sse(res_json);
+ };
+
+ try {
+ if (should_stop()) {
+ SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
+ return false; // should_stop condition met
+ }
+
+ if (!res_this->data.empty()) {
+ // flush the first chunk
+ output = std::move(res_this->data);
+ res_this->data.clear();
+ return true;
+ }
+
+ server_response_reader & rd = res_this->rd;
+
+ // check if there is more data
+ if (!rd.has_next()) {
+ if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
+ // Anthropic doesn't send [DONE], message_stop was already sent
+ output = "";
+ } else if (res_type != TASK_RESPONSE_TYPE_NONE) {
+ output = "data: [DONE]\n\n";
+ } else {
+ output = "";
+ }
+ SRV_DBG("%s", "all results received, terminating stream\n");
+ return false; // no more data, terminate
+ }
+
+ // receive subsequent results
+ auto result = rd.next(should_stop);
+ if (result == nullptr) {
+ SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
+ return false; // should_stop condition met
+ }
+
+ // send the results
+ if (result->is_error()) {
+ json res_json = result->to_json();
+ output = format_error(res_type, res_json);
+ SRV_DBG("%s", "error received during streaming, terminating stream\n");
+ return false; // terminate on error
} else {
- output = format_oai_sse(res_json);
+ GGML_ASSERT(
+ dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
+ || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
+ );
+ json res_json = result->to_json();
+ if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
+ output = format_anthropic_sse(res_json);
+ } else {
+ output = format_oai_sse(res_json);
+ }
}
- }
- // has next data, continue
- return true;
+ // has next data, continue
+ return true;
+
+ } catch (const std::exception & e) {
+ json error_json = format_error_response(e.what(), ERROR_TYPE_SERVER);
+ output = format_error(res_type, error_json);
+
+ // terminate on exception
+ return false;
+ }
};
}
// server_task_result_cmpl_final
//
json server_task_result_cmpl_final::to_json() {
+ GGML_ASSERT(is_updated && "update() must be called before to_json()");
switch (res_type) {
case TASK_RESPONSE_TYPE_NONE:
return to_json_non_oaicompat();
json server_task_result_cmpl_final::to_json_non_oaicompat() {
json res = json {
{"index", index},
- {"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk
- {"tokens", stream ? llama_tokens {} : tokens},
+ {"content", content},
+ {"tokens", tokens},
{"id_slot", id_slot},
{"stop", true},
{"model", oaicompat_model},
json res = json {
{"choices", json::array({
json{
- {"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk
+ {"text", content},
{"index", index},
{"logprobs", logprobs},
{"finish_reason", finish_reason},
return res;
}
+common_chat_msg task_result_state::update_chat_msg(
+ const std::string & text_added,
+ bool is_partial,
+ std::vector<common_chat_msg_diff> & diffs) {
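+ // the full accumulated text is re-parsed on every call: a new chunk can
+ // change how earlier characters parse (e.g. by completing a tool-call tag)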
+ generated_text += text_added;
+ auto msg_prv_copy = chat_msg;
+ SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
+ auto new_msg = common_chat_parse(
+ generated_text,
+ is_partial,
+ oaicompat_chat_syntax);
+ if (!new_msg.empty()) {
+ new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id);
+ chat_msg = new_msg;
+ diffs = common_chat_msg_diff::compute_diffs(msg_prv_copy, new_msg);
+ }
+ return chat_msg;
+}
+
json server_task_result_cmpl_final::to_json_oaicompat_chat_stream() {
std::time_t t = std::time(0);
std::string finish_reason = "length";
// server_task_result_cmpl_partial
//
json server_task_result_cmpl_partial::to_json() {
+ GGML_ASSERT(is_updated && "update() must be called before to_json()");
switch (res_type) {
case TASK_RESPONSE_TYPE_NONE:
return to_json_non_oaicompat();
json to_json() const;
};
+// struct for tracking the state of a task (e.g., for streaming)
+struct task_result_state {
+ // tracking diffs for partial tool calls
+ std::vector<common_chat_msg_diff> diffs;
+ common_chat_syntax oaicompat_chat_syntax;
+ common_chat_msg chat_msg;
+ std::string generated_text; // append new chunks of generated text here
+ std::vector<std::string> generated_tool_call_ids;
+
+ task_result_state(const common_chat_syntax & oaicompat_chat_syntax)
+ : oaicompat_chat_syntax(oaicompat_chat_syntax) {}
+
+ // parse partial tool calls and update the internal state
+ common_chat_msg update_chat_msg(
+ const std::string & text_added,
+ bool is_partial,
+ std::vector<common_chat_msg_diff> & diffs);
+};
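+
+// Illustrative sketch (not code from this patch): driving one state across
+// streaming chunks; `syntax` stands in for a real common_chat_syntax value.
+//
+//   task_result_state state(syntax);
+//   std::vector<common_chat_msg_diff> diffs;
+//   state.update_chat_msg("Hello, wor", /* is_partial= */ true,  diffs);
+//   state.update_chat_msg("ld!",        /* is_partial= */ false, diffs);
+//   // each call appends the chunk to generated_text, re-parses the whole
+//   // text, and fills `diffs` with what changed since the previous call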
+
struct server_task_result {
int id = -1;
int id_slot = -1;
virtual int get_index() {
return -1;
}
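+ // the server is expected to call this with the owning task's state before
+ // to_json(); the cmpl_* subclasses assert this via is_updated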
+ virtual void update(task_result_state &) {
+ // only used by server_task_result_cmpl_*
+ }
virtual json to_json() = 0;
virtual ~server_task_result() = default;
};
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
- common_chat_msg oaicompat_msg;
+ common_chat_msg oaicompat_msg; // to be populated by update()
- std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
+ std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
+ bool is_updated = false;
virtual int get_index() override {
return index;
virtual json to_json() override;
+ virtual void update(task_result_state & state) override {
+ is_updated = true;
+ oaicompat_msg = state.update_chat_msg(content, /* is_partial= */ false, oaicompat_msg_diffs);
+ }
+
json to_json_non_oaicompat();
json to_json_oaicompat();
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
- std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
+ std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
+ bool is_updated = false;
virtual int get_index() override {
return index;
virtual json to_json() override;
+ virtual void update(task_result_state & state) override {
+ is_updated = true;
+ state.update_chat_msg(content, /* is_partial= */ true, oaicompat_msg_diffs); // return value unused; streaming only needs the diffs
+ }
+
json to_json_non_oaicompat();
json to_json_oaicompat();