--- /dev/null
+#pragma once
+
+#include <string>
+#include <vector>
+#include <set>
+#include <mutex>
+#include <condition_variable>
+#include <unordered_map>
+
+#include "json.hpp"
+#include "utils.hpp"
+
+#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
+
+using json = nlohmann::json;
+
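+// Convert an OpenAI-style chat completion request body into llama.cpp completion parameters.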
+inline static json oaicompat_completion_params_parse(
+ const json &body /* openai api json semantics */)
+{
+ json llama_params;
+
+ llama_params["__oaicompat"] = true;
+
+ // Map OpenAI parameters to llama.cpp parameters
+ //
+ // For parameters that are defined by the OpenAI documentation (e.g.
+ // temperature), we explicitly specify OpenAI's intended default; we
+ // need to do that because sometimes OpenAI disagrees with llama.cpp
+ //
+ // https://platform.openai.com/docs/api-reference/chat/create
+ llama_sampling_params default_sparams;
+ llama_params["model"] = json_value(body, "model", std::string("unknown"));
+ llama_params["prompt"] = format_chatml(body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt'
+ llama_params["cache_prompt"] = json_value(body, "cache_prompt", false);
+ llama_params["temperature"] = json_value(body, "temperature", 0.0);
+ llama_params["top_k"] = json_value(body, "top_k", default_sparams.top_k);
+ llama_params["top_p"] = json_value(body, "top_p", 1.0);
+ llama_params["n_predict"] = json_value(body, "max_tokens", -1);
+ llama_params["logit_bias"] = json_value(body, "logit_bias",json::object());
+ llama_params["frequency_penalty"] = json_value(body, "frequency_penalty", 0.0);
+ llama_params["presence_penalty"] = json_value(body, "presence_penalty", 0.0);
+ llama_params["seed"] = json_value(body, "seed", LLAMA_DEFAULT_SEED);
+ llama_params["stream"] = json_value(body, "stream", false);
+ llama_params["mirostat"] = json_value(body, "mirostat", default_sparams.mirostat);
+ llama_params["mirostat_tau"] = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
+ llama_params["mirostat_eta"] = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
+ llama_params["penalize_nl"] = json_value(body, "penalize_nl", default_sparams.penalize_nl);
+ llama_params["typical_p"] = json_value(body, "typical_p", default_sparams.typical_p);
+ llama_params["repeat_last_n"] = json_value(body, "repeat_last_n", default_sparams.penalty_last_n);
+ llama_params["ignore_eos"] = json_value(body, "ignore_eos", false);
+ llama_params["tfs_z"] = json_value(body, "tfs_z", default_sparams.tfs_z);
+
+ if (body.count("grammar") != 0) {
+ llama_params["grammar"] = json_value(body, "grammar", json::object());
+ }
+
+ // Handle 'stop' field
+ if (body.contains("stop") && body["stop"].is_string()) {
+ llama_params["stop"] = json::array({body["stop"].get<std::string>()});
+ } else {
+ llama_params["stop"] = json_value(body, "stop", json::array());
+ }
+
+ // Ensure the ChatML-specific end sequence is among the stop words
+ llama_params["stop"].push_back("<|im_end|>");
+
+ return llama_params;
+}
+
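+// Build an OpenAI-compatible "chat.completion" response (or the final "chat.completion.chunk"
+// when streaming) from a finished task result.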
+inline static json format_final_response_oaicompat(const json &request, const task_result &response, bool streaming = false)
+{
+ json result = response.result_json;
+
+ bool stopped_word = result.count("stopped_word") != 0;
+ bool stopped_eos = json_value(result, "stopped_eos", false);
+ int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
+ int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
+ std::string content = json_value(result, "content", std::string(""));
+
+ std::string finish_reason = "length";
+ if (stopped_word || stopped_eos) {
+ finish_reason = "stop";
+ }
+
+ json choices =
+ streaming ? json::array({json{{"finish_reason", finish_reason},
+ {"index", 0},
+ {"delta", json::object()}}})
+ : json::array({json{{"finish_reason", finish_reason},
+ {"index", 0},
+ {"message", json{{"content", content},
+ {"role", "assistant"}}}}});
+
+ std::time_t t = std::time(0);
+
+ json res =
+ json{{"choices", choices},
+ {"created", t},
+ {"model",
+ json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
+ {"object", streaming ? "chat.completion.chunk" : "chat.completion"},
+ {"usage",
+ json{{"completion_tokens", num_tokens_predicted},
+ {"prompt_tokens", num_prompt_tokens},
+ {"total_tokens", num_tokens_predicted + num_prompt_tokens}}},
+ {"id", gen_chatcmplid()}};
+
+ if (server_verbose) {
+ res["__verbose"] = result;
+ }
+
+ if (result.contains("completion_probabilities")) {
+ res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
+ }
+
+ return res;
+}
+
+// the return value is a vector because there is one case where we need to generate two responses
+inline static std::vector<json> format_partial_response_oaicompat(const task_result &response) {
+ json result = response.result_json;
+
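+ // results that lack the OAI-compat bookkeeping fields are passed through unchanged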
+ if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
+ return std::vector<json>({response.result_json});
+ }
+
+ bool first = json_value(result, "oaicompat_token_ctr", 0) == 0;
+ std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
+
+ bool stopped_word = json_value(result, "stopped_word", false);
+ bool stopped_eos = json_value(result, "stopped_eos", false);
+ bool stopped_limit = json_value(result, "stopped_limit", false);
+ std::string content = json_value(result, "content", std::string(""));
+
+ std::string finish_reason;
+ if (stopped_word || stopped_eos) {
+ finish_reason = "stop";
+ }
+ if (stopped_limit) {
+ finish_reason = "length";
+ }
+
+ std::time_t t = std::time(0);
+
+ json choices;
+
+ if (!finish_reason.empty()) {
+ choices = json::array({json{{"finish_reason", finish_reason},
+ {"index", 0},
+ {"delta", json::object()}}});
+ } else {
+ if (first) {
+ if (content.empty()) {
+ choices = json::array({json{{"finish_reason", nullptr},
+ {"index", 0},
+ {"delta", json{{"role", "assistant"}}}}});
+ } else {
+ // We have to send this as two updates to conform to OpenAI's behavior
+ json initial_ret = json{{"choices", json::array({json{
+ {"finish_reason", nullptr},
+ {"index", 0},
+ {"delta", json{
+ {"role", "assistant"}
+ }}}})},
+ {"created", t},
+ {"id", gen_chatcmplid()},
+ {"model", modelname},
+ {"object", "chat.completion.chunk"}};
+
+ json second_ret = json{
+ {"choices", json::array({json{{"finish_reason", nullptr},
+ {"index", 0},
+ {"delta", json{
+ {"content", content}}}
+ }})},
+ {"created", t},
+ {"id", gen_chatcmplid()},
+ {"model", modelname},
+ {"object", "chat.completion.chunk"}};
+
+ return std::vector<json>({initial_ret, second_ret});
+ }
+ } else {
+ // An idiosyncrasy in the task processing logic produces several trailing calls
+ // with empty content; we ignore these at the callee site.
+ if (content.empty()) {
+ return std::vector<json>({json::object()});
+ }
+
+ choices = json::array({json{
+ {"finish_reason", nullptr},
+ {"index", 0},
+ {"delta",
+ json{
+ {"content", content},
+ }},
+ }});
+ }
+ }
+
+ json ret = json{{"choices", choices},
+ {"created", t},
+ {"id", gen_chatcmplid()},
+ {"model", modelname},
+ {"object", "chat.completion.chunk"}};
+
+ return std::vector<json>({ret});
+}
#include "common.h"
#include "llama.h"
#include "grammar-parser.h"
+#include "utils.hpp"
+#include "oai.hpp"
#include "../llava/clip.h"
#include <cstddef>
#include <thread>
-#include <mutex>
#include <chrono>
#include <condition_variable>
#include <atomic>
-#ifndef SERVER_VERBOSE
-#define SERVER_VERBOSE 1
-#endif
-
-#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
-
using json = nlohmann::json;
struct server_params
int32_t write_timeout = 600;
};
-static bool server_verbose = false;
-
-#if SERVER_VERBOSE != 1
-#define LOG_VERBOSE(MSG, ...)
-#else
-#define LOG_VERBOSE(MSG, ...) \
- do \
- { \
- if (server_verbose) \
- { \
- server_log("VERBOSE", __func__, __LINE__, MSG, __VA_ARGS__); \
- } \
- } while (0)
-#endif
-
-#define LOG_ERROR( MSG, ...) server_log("ERROR", __func__, __LINE__, MSG, __VA_ARGS__)
-#define LOG_WARNING(MSG, ...) server_log("WARNING", __func__, __LINE__, MSG, __VA_ARGS__)
-#define LOG_INFO( MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__)
-
-json oaicompat_completion_params_parse(const json &body);
-std::string format_chatml(std::vector<json> messages);
-
-
-//
-// base64 utils (TODO: move to common in the future)
-//
-
-static const std::string base64_chars =
- "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
- "abcdefghijklmnopqrstuvwxyz"
- "0123456789+/";
-
-static inline bool is_base64(uint8_t c)
-{
- return (isalnum(c) || (c == '+') || (c == '/'));
-}
-
-static std::vector<uint8_t> base64_decode(const std::string & encoded_string)
-{
- int i = 0;
- int j = 0;
- int in_ = 0;
-
- int in_len = encoded_string.size();
-
- uint8_t char_array_4[4];
- uint8_t char_array_3[3];
-
- std::vector<uint8_t> ret;
-
- while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_]))
- {
- char_array_4[i++] = encoded_string[in_]; in_++;
- if (i == 4)
- {
- for (i = 0; i <4; i++)
- {
- char_array_4[i] = base64_chars.find(char_array_4[i]);
- }
-
- char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4);
- char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
- char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
-
- for (i = 0; (i < 3); i++)
- {
- ret.push_back(char_array_3[i]);
- }
- i = 0;
- }
- }
-
- if (i)
- {
- for (j = i; j <4; j++)
- {
- char_array_4[j] = 0;
- }
-
- for (j = 0; j <4; j++)
- {
- char_array_4[j] = base64_chars.find(char_array_4[j]);
- }
-
- char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4);
- char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
- char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
-
- for (j = 0; (j < i - 1); j++)
- {
- ret.push_back(char_array_3[j]);
- }
- }
-
- return ret;
-}
-
-//
-// parallel
-//
-
-enum server_state {
- SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet
- SERVER_STATE_READY, // Server is ready and model is loaded
- SERVER_STATE_ERROR // An error occurred, load_model failed
-};
-
-enum task_type {
- TASK_TYPE_COMPLETION,
- TASK_TYPE_CANCEL,
-};
-
-struct task_server {
- int id;
- int target_id;
- task_type type;
- json data;
- bool infill_mode = false;
- bool embedding_mode = false;
- int multitask_id = -1;
-};
-
-struct task_result {
- int id;
- int multitask_id = -1;
- bool stop;
- bool error;
- json result_json;
-};
-
-struct task_multi {
- int id;
- std::set<int> subtasks_remaining{};
- std::vector<task_result> results{};
-};
-
-// TODO: can become bool if we can't find use of more states
-enum slot_state
-{
- IDLE,
- PROCESSING,
-};
-
-enum slot_command
-{
- NONE,
- LOAD_PROMPT,
- RELEASE,
-};
-
-struct slot_params
-{
- bool stream = true;
- bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt
-
- uint32_t seed = -1; // RNG seed
- int32_t n_keep = 0; // number of tokens to keep from initial prompt
- int32_t n_predict = -1; // new tokens to predict
-
- std::vector<std::string> antiprompt;
-
- json input_prefix;
- json input_suffix;
-};
-
-struct slot_image
-{
- int32_t id;
-
- bool request_encode_image = false;
- float * image_embedding = nullptr;
- int32_t image_tokens = 0;
-
- clip_image_u8 * img_data;
-
- std::string prefix_prompt; // before of this image
-};
-
-// completion token output with probabilities
-struct completion_token_output
-{
- struct token_prob
- {
- llama_token tok;
- float prob;
- };
-
- std::vector<token_prob> probs;
- llama_token tok;
- std::string text_to_send;
-};
+bool server_verbose = false;
static size_t common_part(const std::vector<llama_token> &a, const std::vector<llama_token> &b)
{
return ret;
}
-static void server_log(const char *level, const char *function, int line,
- const char *message, const nlohmann::ordered_json &extra)
-{
- nlohmann::ordered_json log
- {
- {"timestamp", time(nullptr)},
- {"level", level},
- {"function", function},
- {"line", line},
- {"message", message},
- };
-
- if (!extra.empty())
- {
- log.merge_patch(extra);
- }
-
- const std::string str = log.dump(-1, ' ', false, json::error_handler_t::replace);
- printf("%.*s\n", (int)str.size(), str.data());
- fflush(stdout);
-}
-
// format incomplete utf-8 multibyte character for output
static std::string tokens_to_output_formatted_string(const llama_context *ctx, const llama_token token)
{
return out;
}
-template <typename T>
-static T json_value(const json &body, const std::string &key, const T &default_value)
-{
- // Fallback null to default value
- return body.contains(key) && !body.at(key).is_null()
- ? body.value(key, default_value)
- : default_value;
-}
-
struct llama_client_slot
{
int id;
}
void release() {
- if (state == IDLE || state == PROCESSING)
+ if (state == PROCESSING)
{
t_token_generation = (ggml_time_us() - t_start_genereration) / 1e3;
command = RELEASE;
bool all_slots_are_idle = false;
bool add_bos_token = true;
- int32_t id_gen;
int32_t n_ctx; // total context for all clients / slots
// system prompt
// slots / clients
std::vector<llama_client_slot> slots;
- std::vector<task_server> queue_tasks;
- std::vector<task_result> queue_results;
- std::vector<task_multi> queue_multitasks;
- std::mutex mutex_tasks; // also guards id_gen, and queue_multitasks
- std::condition_variable condition_tasks;
- std::mutex mutex_results;
- std::condition_variable condition_results;
+ llama_server_queue queue_tasks;
+ llama_server_response queue_results;
~llama_server_context()
{
}
void initialize() {
- id_gen = 0;
-
// create slots
all_slots_are_idle = true;
void send_error(task_server& task, const std::string &error)
{
LOG_TEE("task %i - error: %s\n", task.id, error.c_str());
- std::unique_lock<std::mutex> lock(mutex_results);
task_result res;
res.id = task.id;
res.multitask_id = task.multitask_id;
res.stop = false;
res.error = true;
res.result_json = { { "content", error } };
- queue_results.push_back(res);
- condition_results.notify_all();
- }
-
- void add_multi_task(int id, std::vector<int>& sub_ids)
- {
- std::lock_guard<std::mutex> lock(mutex_tasks);
- task_multi multi;
- multi.id = id;
- std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
- queue_multitasks.push_back(multi);
- condition_tasks.notify_one();
- }
-
- void update_multi_task(int multitask_id, int subtask_id, task_result& result)
- {
- std::lock_guard<std::mutex> lock(mutex_tasks);
- for (auto& multitask : queue_multitasks)
- {
- if (multitask.id == multitask_id)
- {
- multitask.subtasks_remaining.erase(subtask_id);
- multitask.results.push_back(result);
- condition_tasks.notify_one();
- }
- }
+ queue_results.send(res);
}
json get_model_props()
void send_partial_response(llama_client_slot &slot, completion_token_output tkn)
{
- std::unique_lock<std::mutex> lock(mutex_results);
task_result res;
res.id = slot.task_id;
res.multitask_id = slot.multitask_id;
res.result_json["model"] = slot.oaicompat_model;
}
- queue_results.push_back(res);
- condition_results.notify_all();
+ queue_results.send(res);
}
void send_final_response(llama_client_slot &slot)
{
- std::unique_lock<std::mutex> lock(mutex_results);
task_result res;
res.id = slot.task_id;
res.multitask_id = slot.multitask_id;
res.result_json["model"] = slot.oaicompat_model;
}
- queue_results.push_back(res);
- condition_results.notify_all();
-
- // done with results, unlock
- lock.unlock();
-
- // parent multitask, if any, needs to be updated
- if (slot.multitask_id != -1)
- {
- update_multi_task(slot.multitask_id, slot.task_id, res);
- }
+ queue_results.send(res);
}
void send_embedding(llama_client_slot &slot)
{
- std::unique_lock<std::mutex> lock(mutex_results);
task_result res;
res.id = slot.task_id;
res.multitask_id = slot.multitask_id;
{"embedding", embedding },
};
}
- queue_results.push_back(res);
- condition_results.notify_all();
+ queue_results.send(res);
}
- int request_completion(json data, bool infill, bool embedding, int multitask_id)
+ void request_completion(int task_id, json data, bool infill, bool embedding, int multitask_id)
{
- std::unique_lock<std::mutex> lock(mutex_tasks);
task_server task;
- task.id = id_gen++;
+ task.id = task_id;
task.target_id = 0;
task.data = std::move(data);
task.infill_mode = infill;
// when a completion task's prompt array is not a singleton, we split it into multiple requests
if (task.data.count("prompt") && task.data.at("prompt").size() > 1)
{
- lock.unlock(); // entering new func scope
- return split_multiprompt_task(task);
+ split_multiprompt_task(task_id, task);
}
// otherwise, it's a single-prompt task, we actually queue it
- queue_tasks.push_back(task);
- condition_tasks.notify_one();
- return task.id;
- }
-
- task_result next_result(int task_id)
- {
- while (true)
- {
- std::unique_lock<std::mutex> lock(mutex_results);
- condition_results.wait(lock, [&]{
- return !queue_results.empty();
- });
-
- for (int i = 0; i < (int) queue_results.size(); i++)
- {
- // for now, tasks that have associated parent multitasks just get erased once multitask picks up the result
- if (queue_results[i].multitask_id == task_id)
- {
- update_multi_task(task_id, queue_results[i].id, queue_results[i]);
- queue_results.erase(queue_results.begin() + i);
- continue;
- }
-
- if (queue_results[i].id == task_id)
- {
- assert(queue_results[i].multitask_id == -1);
- task_result res = queue_results[i];
- queue_results.erase(queue_results.begin() + i);
- return res;
- }
- }
- }
-
- // never reached
- //return task_result{-1, false, false, {}};
+ queue_tasks.post(task);
}
// for multiple images processing
void request_cancel(int task_id)
{
- std::unique_lock<std::mutex> lock(mutex_tasks);
task_server task;
- task.id = id_gen++;
task.type = TASK_TYPE_CANCEL;
task.target_id = task_id;
- queue_tasks.push_back(task);
- condition_tasks.notify_one();
+ queue_tasks.post(task);
}
- int split_multiprompt_task(task_server& multiprompt_task)
+ void split_multiprompt_task(int multitask_id, task_server& multiprompt_task)
{
int prompt_count = multiprompt_task.data.at("prompt").size();
assert(prompt_count > 1);
- int multitask_id = id_gen++;
+ // generate IDs for all the subtasks
std::vector<int> subtask_ids(prompt_count);
for (int i = 0; i < prompt_count; i++)
+ {
+ subtask_ids[i] = queue_tasks.get_new_id();
+ }
+
+ // queue up the multitask so we can track its subtask progression
+ queue_tasks.add_multitask(multitask_id, subtask_ids);
+
+ // add subtasks
+ for (int i = 0; i < prompt_count; i++)
{
json subtask_data = multiprompt_task.data;
subtask_data["prompt"] = subtask_data["prompt"][i];
// subtasks inherit everything else (infill mode, embedding mode, etc.)
- subtask_ids[i] = request_completion(subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
+ request_completion(subtask_ids[i], subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
}
-
- // queue up the multitask so we can track its subtask progression
- add_multi_task(multitask_id, subtask_ids);
- return multitask_id;
}
- void process_tasks()
+ void process_single_task(task_server& task)
{
- std::unique_lock<std::mutex> lock(mutex_tasks);
- std::vector<task_server> deferred_tasks;
- while (!queue_tasks.empty())
+ switch (task.type)
{
- task_server task = queue_tasks.front();
- queue_tasks.erase(queue_tasks.begin());
- switch (task.type)
- {
- case TASK_TYPE_COMPLETION: {
- llama_client_slot *slot = get_slot(json_value(task.data, "slot_id", -1));
- if (slot == nullptr)
- {
- // if no slot is available, we defer this task for processing later
- deferred_tasks.push_back(task);
+ case TASK_TYPE_COMPLETION: {
+ llama_client_slot *slot = get_slot(json_value(task.data, "slot_id", -1));
+ if (slot == nullptr)
+ {
+ // if no slot is available, we defer this task for processing later
+ LOG_VERBOSE("no slot is available", {});
+ queue_tasks.defer(task);
+ break;
+ }
+
+ if (task.data.contains("system_prompt"))
+ {
+ if (!all_slots_are_idle) {
+ send_error(task, "system prompt can only be updated when all slots are idle");
break;
}
+ process_system_prompt_data(task.data["system_prompt"]);
- if (task.data.contains("system_prompt"))
+ // reset cache_tokens for all slots
+ for (llama_client_slot &slot : slots)
{
- if (!all_slots_are_idle) {
- send_error(task, "system prompt can only be updated when all slots are idle");
- break;
- }
- process_system_prompt_data(task.data["system_prompt"]);
-
- // reset cache_tokens for all slots
- for (llama_client_slot &slot : slots)
- {
- slot.cache_tokens.clear();
- }
+ slot.cache_tokens.clear();
}
+ }
- slot->reset();
+ slot->reset();
- slot->infill = task.infill_mode;
- slot->embedding = task.embedding_mode;
- slot->task_id = task.id;
- slot->multitask_id = task.multitask_id;
+ slot->infill = task.infill_mode;
+ slot->embedding = task.embedding_mode;
+ slot->task_id = task.id;
+ slot->multitask_id = task.multitask_id;
- if (!launch_slot_with_data(slot, task.data))
+ if (!launch_slot_with_data(slot, task.data))
+ {
+ // send error result
+ send_error(task, "internal_error");
+ break;
+ }
+ } break;
+ case TASK_TYPE_CANCEL: { // release slot linked with the task id
+ for (auto & slot : slots)
+ {
+ if (slot.task_id == task.target_id)
{
- // send error result
- send_error(task, "internal_error");
+ slot.release();
break;
}
- } break;
- case TASK_TYPE_CANCEL: { // release slot linked with the task id
- for (auto & slot : slots)
- {
- if (slot.task_id == task.target_id)
- {
- slot.release();
- break;
- }
- }
- } break;
- }
+ }
+ } break;
+ case TASK_TYPE_NEXT_RESPONSE: {
+ // do nothing
+ } break;
}
+ }
- // add all the deferred tasks back the the queue
- for (task_server &task : deferred_tasks)
- {
- queue_tasks.push_back(task);
- }
+ void on_finish_multitask(task_multi& multitask)
+ {
+ // all subtasks done == multitask is done
+ task_result result;
+ result.id = multitask.id;
+ result.stop = true;
+ result.error = false;
- // remove finished multitasks from the queue of multitasks, and add the corresponding result to the result queue
- std::vector<task_result> agg_results;
- auto queue_iterator = queue_multitasks.begin();
- while (queue_iterator != queue_multitasks.end())
+ // collect json results into one json result
+ std::vector<json> result_jsons;
+ for (auto& subres : multitask.results)
{
- if (queue_iterator->subtasks_remaining.empty())
- {
- // all subtasks done == multitask is done
- task_result aggregate_result;
- aggregate_result.id = queue_iterator->id;
- aggregate_result.stop = true;
- aggregate_result.error = false;
-
- // collect json results into one json result
- std::vector<json> result_jsons;
- for (auto& subres : queue_iterator->results)
- {
- result_jsons.push_back(subres.result_json);
- aggregate_result.error = aggregate_result.error && subres.error;
- }
- aggregate_result.result_json = json{ "results", result_jsons };
-
-
- agg_results.push_back(aggregate_result);
-
- condition_results.notify_all();
-
- queue_iterator = queue_multitasks.erase(queue_iterator);
- }
- else
- {
- ++queue_iterator;
- }
+ result_jsons.push_back(subres.result_json);
+ result.error = result.error && subres.error;
}
-
- // done with tasks, unlock
- lock.unlock();
-
- // copy aggregate results of complete multi-tasks to the results queue
- std::lock_guard<std::mutex> lock_results(mutex_results);
- queue_results.insert(queue_results.end(), agg_results.begin(), agg_results.end());
+ result.result_json = json{ { "results", result_jsons } };
+ queue_results.send(result);
}
bool update_slots() {
- // attend tasks
- process_tasks();
-
if (system_need_update)
{
LOG_TEE("updating system prompt\n");
LOG_TEE("all slots are idle and system prompt is empty, clear the KV cache\n");
kv_cache_clear();
}
- std::unique_lock<std::mutex> lock(mutex_tasks);
- condition_tasks.wait(lock, [&]{
- return !queue_tasks.empty();
- });
+ return true;
+ } else {
+ task_server task;
+ task.type = TASK_TYPE_NEXT_RESPONSE;
+ task.target_id = -1;
+ queue_tasks.post(task);
}
for (llama_client_slot &slot : slots)
slot.t_last_used = ggml_time_us();
LOG_TEE("slot %d released (%d tokens in cache)\n", slot.id, (int) slot.cache_tokens.size());
+ queue_tasks.notify_slot_changed();
continue;
}
}
return true;
}
+
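+ // invoked by the task queue after each drain of pending tasks:
+ // run one round of slot processing / token generation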
+ void run_on_all_tasks_finished() {
+ update_slots();
+ }
};
static void server_print_usage(const char *argv0, const gpt_params ¶ms,
}
}
-static std::string random_string()
-{
- static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz");
-
- std::random_device rd;
- std::mt19937 generator(rd());
-
- std::string result(32, ' ');
-
- for (int i = 0; i < 32; ++i) {
- result[i] = str[generator() % str.size()];
- }
-
- return result;
-}
-
-static std::string gen_chatcmplid()
-{
- std::stringstream chatcmplid;
- chatcmplid << "chatcmpl-" << random_string();
- return chatcmplid.str();
-}
-
-std::string format_chatml(std::vector<json> messages)
-{
- std::ostringstream chatml_msgs;
-
- for (auto it = messages.begin(); it != messages.end(); ++it) {
- chatml_msgs << "<|im_start|>"
- << json_value(*it, "role", std::string("user")) << '\n';
- chatml_msgs << json_value(*it, "content", std::string(""))
- << "<|im_end|>\n";
- }
-
- chatml_msgs << "<|im_start|>assistant" << '\n';
-
- return chatml_msgs.str();
-}
-
/* llama.cpp completion api semantics */
-json oaicompat_completion_params_parse(
- const json &body /* openai api json semantics */)
-{
- json llama_params;
-
- llama_params["__oaicompat"] = true;
-
- // Map OpenAI parameters to llama.cpp parameters
- //
- // For parameters that are defined by the OpenAI documentation (e.g.
- // temperature), we explicitly specify OpenAI's intended default; we
- // need to do that because sometimes OpenAI disagrees with llama.cpp
- //
- // https://platform.openai.com/docs/api-reference/chat/create
- llama_sampling_params default_sparams;
- llama_params["model"] = json_value(body, "model", std::string("unknown"));
- llama_params["prompt"] = format_chatml(body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt'
- llama_params["cache_prompt"] = json_value(body, "cache_prompt", false);
- llama_params["temperature"] = json_value(body, "temperature", 0.0);
- llama_params["top_k"] = json_value(body, "top_k", default_sparams.top_k);
- llama_params["top_p"] = json_value(body, "top_p", 1.0);
- llama_params["n_predict"] = json_value(body, "max_tokens", -1);
- llama_params["logit_bias"] = json_value(body, "logit_bias",json::object());
- llama_params["frequency_penalty"] = json_value(body, "frequency_penalty", 0.0);
- llama_params["presence_penalty"] = json_value(body, "presence_penalty", 0.0);
- llama_params["seed"] = json_value(body, "seed", LLAMA_DEFAULT_SEED);
- llama_params["stream"] = json_value(body, "stream", false);
- llama_params["mirostat"] = json_value(body, "mirostat", default_sparams.mirostat);
- llama_params["mirostat_tau"] = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
- llama_params["mirostat_eta"] = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
- llama_params["penalize_nl"] = json_value(body, "penalize_nl", default_sparams.penalize_nl);
- llama_params["typical_p"] = json_value(body, "typical_p", default_sparams.typical_p);
- llama_params["repeat_last_n"] = json_value(body, "repeat_last_n", default_sparams.penalty_last_n);
- llama_params["ignore_eos"] = json_value(body, "ignore_eos", false);
- llama_params["tfs_z"] = json_value(body, "tfs_z", default_sparams.tfs_z);
-
- if (body.count("grammar") != 0) {
- llama_params["grammar"] = json_value(body, "grammar", json::object());
- }
-
- // Handle 'stop' field
- if (body.contains("stop") && body["stop"].is_string()) {
- llama_params["stop"] = json::array({body["stop"].get<std::string>()});
- } else {
- llama_params["stop"] = json_value(body, "stop", json::array());
- }
-
- // Ensure there is ChatML-specific end sequence among stop words
- llama_params["stop"].push_back("<|im_end|>");
-
- return llama_params;
-}
-
-static json format_final_response_oaicompat(const json &request, const task_result &response, bool streaming = false)
-{
- json result = response.result_json;
-
- bool stopped_word = result.count("stopped_word") != 0;
- bool stopped_eos = json_value(result, "stopped_eos", false);
- int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
- int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
- std::string content = json_value(result, "content", std::string(""));
-
- std::string finish_reason = "length";
- if (stopped_word || stopped_eos) {
- finish_reason = "stop";
- }
-
- json choices =
- streaming ? json::array({json{{"finish_reason", finish_reason},
- {"index", 0},
- {"delta", json::object()}}})
- : json::array({json{{"finish_reason", finish_reason},
- {"index", 0},
- {"message", json{{"content", content},
- {"role", "assistant"}}}}});
-
- std::time_t t = std::time(0);
-
- json res =
- json{{"choices", choices},
- {"created", t},
- {"model",
- json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
- {"object", streaming ? "chat.completion.chunk" : "chat.completion"},
- {"usage",
- json{{"completion_tokens", num_tokens_predicted},
- {"prompt_tokens", num_prompt_tokens},
- {"total_tokens", num_tokens_predicted + num_prompt_tokens}}},
- {"id", gen_chatcmplid()}};
-
- if (server_verbose) {
- res["__verbose"] = result;
- }
-
- if (result.contains("completion_probabilities")) {
- res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
- }
-
- return res;
-}
-
-// return value is vector as there is one case where we might need to generate two responses
-static std::vector<json> format_partial_response_oaicompat(const task_result &response) {
- json result = response.result_json;
-
- if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
- return std::vector<json>({response.result_json});
- }
-
- bool first = json_value(result, "oaicompat_token_ctr", 0) == 0;
- std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
-
- bool stopped_word = json_value(result, "stopped_word", false);
- bool stopped_eos = json_value(result, "stopped_eos", false);
- bool stopped_limit = json_value(result, "stopped_limit", false);
- std::string content = json_value(result, "content", std::string(""));
-
- std::string finish_reason;
- if (stopped_word || stopped_eos) {
- finish_reason = "stop";
- }
- if (stopped_limit) {
- finish_reason = "length";
- }
-
- std::time_t t = std::time(0);
-
- json choices;
-
- if (!finish_reason.empty()) {
- choices = json::array({json{{"finish_reason", finish_reason},
- {"index", 0},
- {"delta", json::object()}}});
- } else {
- if (first) {
- if (content.empty()) {
- choices = json::array({json{{"finish_reason", nullptr},
- {"index", 0},
- {"delta", json{{"role", "assistant"}}}}});
- } else {
- // We have to send this as two updates to conform to openai behavior
- json initial_ret = json{{"choices", json::array({json{
- {"finish_reason", nullptr},
- {"index", 0},
- {"delta", json{
- {"role", "assistant"}
- }}}})},
- {"created", t},
- {"id", gen_chatcmplid()},
- {"model", modelname},
- {"object", "chat.completion.chunk"}};
-
- json second_ret = json{
- {"choices", json::array({json{{"finish_reason", nullptr},
- {"index", 0},
- {"delta", json{
- {"content", content}}}
- }})},
- {"created", t},
- {"id", gen_chatcmplid()},
- {"model", modelname},
- {"object", "chat.completion.chunk"}};
-
- return std::vector<json>({initial_ret, second_ret});
- }
- } else {
- // Some idiosyncrasy in task processing logic makes several trailing calls
- // with empty content, we ignore these at the calee site.
- if (content.empty()) {
- return std::vector<json>({json::object()});
- }
-
- choices = json::array({json{
- {"finish_reason", nullptr},
- {"index", 0},
- {"delta",
- json{
- {"content", content},
- }},
- }});
- }
- }
-
- json ret = json{{"choices", choices},
- {"created", t},
- {"id", gen_chatcmplid()},
- {"model", modelname},
- {"object", "chat.completion.chunk"}};
-
- return std::vector<json>({ret});
-}
-
static json format_partial_response(
llama_server_context &llama, llama_client_slot *slot, const std::string &content, const std::vector<completion_token_output> &probs
) {
return;
}
json data = json::parse(req.body);
- const int task_id = llama.request_completion(data, false, false, -1);
+ const int task_id = llama.queue_tasks.get_new_id();
+ llama.queue_results.add_waiting_task_id(task_id);
+ llama.request_completion(task_id, data, false, false, -1);
if (!json_value(data, "stream", false)) {
std::string completion_text;
- task_result result = llama.next_result(task_id);
+ task_result result = llama.queue_results.recv(task_id);
if (!result.error && result.stop) {
res.set_content(result.result_json.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
}
{
res.status = 404;
res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
- return;
}
+ llama.queue_results.remove_waiting_task_id(task_id);
} else {
const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink)
{
while (true)
{
- task_result result = llama.next_result(task_id);
+ task_result result = llama.queue_results.recv(task_id);
if (!result.error) {
const std::string str =
"data: " +
});
if (!sink.write(str.c_str(), str.size()))
{
+ llama.queue_results.remove_waiting_task_id(task_id);
return false;
}
if (result.stop) {
});
if (!sink.write(str.c_str(), str.size()))
{
+ llama.queue_results.remove_waiting_task_id(task_id);
return false;
}
break;
}
}
+
+ llama.queue_results.remove_waiting_task_id(task_id);
sink.done();
return true;
};
{
// cancel
llama.request_cancel(task_id);
+ llama.queue_results.remove_waiting_task_id(task_id);
};
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
}
json data = oaicompat_completion_params_parse(json::parse(req.body));
- const int task_id = llama.request_completion(data, false, false, -1);
+ const int task_id = llama.queue_tasks.get_new_id();
+ llama.queue_results.add_waiting_task_id(task_id);
+ llama.request_completion(task_id, data, false, false, -1);
if (!json_value(data, "stream", false)) {
std::string completion_text;
- task_result result = llama.next_result(task_id);
+ task_result result = llama.queue_results.recv(task_id);
if (!result.error && result.stop) {
json oaicompat_result = format_final_response_oaicompat(data, result);
} else {
res.status = 500;
res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
- return;
}
+ llama.queue_results.remove_waiting_task_id(task_id);
} else {
const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink &sink) {
while (true) {
- task_result llama_result = llama.next_result(task_id);
+ task_result llama_result = llama.queue_results.recv(task_id);
if (!llama_result.error) {
std::vector<json> result_array = format_partial_response_oaicompat( llama_result);
"\n\n";
LOG_VERBOSE("data stream", {{"to_send", str}});
if (!sink.write(str.c_str(), str.size())) {
+ llama.queue_results.remove_waiting_task_id(task_id);
return false;
}
}
"\n\n";
LOG_VERBOSE("data stream", {{"to_send", str}});
if (!sink.write(str.c_str(), str.size())) {
+ llama.queue_results.remove_waiting_task_id(task_id);
return false;
}
break;
}
}
sink.done();
+ llama.queue_results.remove_waiting_task_id(task_id);
return true;
};
auto on_complete = [task_id, &llama](bool) {
// cancel request
llama.request_cancel(task_id);
+ llama.queue_results.remove_waiting_task_id(task_id);
};
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
return;
}
json data = json::parse(req.body);
- const int task_id = llama.request_completion(data, true, false, -1);
+ const int task_id = llama.queue_tasks.get_new_id();
+ llama.queue_results.add_waiting_task_id(task_id);
+ llama.request_completion(task_id, data, true, false, -1);
if (!json_value(data, "stream", false)) {
std::string completion_text;
- task_result result = llama.next_result(task_id);
+ task_result result = llama.queue_results.recv(task_id);
if (!result.error && result.stop)
{
res.set_content(result.result_json.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
{
res.status = 404;
res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
- return;
}
+ llama.queue_results.remove_waiting_task_id(task_id);
} else {
const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink) {
while (true)
{
- task_result result = llama.next_result(task_id);
+ task_result result = llama.queue_results.recv(task_id);
if (!result.error) {
const std::string str =
"data: " +
});
if (!sink.write(str.c_str(), str.size()))
{
+ llama.queue_results.remove_waiting_task_id(task_id);
return false;
}
if (result.stop)
}
}
+ llama.queue_results.remove_waiting_task_id(task_id);
sink.done();
-
return true;
};
image_data = "";
}
- const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, false, true, -1);
- task_result result = llama.next_result(task_id);
+ // create and queue the task
+ const int task_id = llama.queue_tasks.get_new_id();
+ llama.queue_results.add_waiting_task_id(task_id);
+ llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, false, true, -1);
+
+ // get the result
+ task_result result = llama.queue_results.recv(task_id);
+ llama.queue_results.remove_waiting_task_id(task_id);
+
+ // send the result
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
});
// GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!?
// "Bus error: 10" - this is on macOS, it does not crash on Linux
//std::thread t2([&]()
- {
+ /*{
bool running = true;
while (running)
{
running = llama.update_slots();
}
- }
+ }*/
//);
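+ // wire the task and result queues to the server context callbacks,
+ // then run the blocking task loop on this thread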
+ llama.queue_tasks.on_new_task(std::bind(
+ &llama_server_context::process_single_task, &llama, std::placeholders::_1));
+ llama.queue_tasks.on_finish_multitask(std::bind(
+ &llama_server_context::on_finish_multitask, &llama, std::placeholders::_1));
+ llama.queue_tasks.on_all_tasks_finished(std::bind(
+ &llama_server_context::run_on_all_tasks_finished, &llama));
+ llama.queue_results.on_multitask_update(std::bind(
+ &llama_server_queue::update_multitask,
+ &llama.queue_tasks,
+ std::placeholders::_1,
+ std::placeholders::_2,
+ std::placeholders::_3
+ ));
+ llama.queue_tasks.start_loop();
+
t.join();
llama_backend_free();
--- /dev/null
+#pragma once
+
+#include <string>
+#include <vector>
+#include <set>
+#include <mutex>
+#include <condition_variable>
+#include <unordered_map>
+
+#include "json.hpp"
+
+#include "../llava/clip.h"
+
+using json = nlohmann::json;
+
+extern bool server_verbose;
+
+#ifndef SERVER_VERBOSE
+#define SERVER_VERBOSE 1
+#endif
+
+#if SERVER_VERBOSE != 1
+#define LOG_VERBOSE(MSG, ...)
+#else
+#define LOG_VERBOSE(MSG, ...) \
+ do \
+ { \
+ if (server_verbose) \
+ { \
+ server_log("VERBOSE", __func__, __LINE__, MSG, __VA_ARGS__); \
+ } \
+ } while (0)
+#endif
+
+#define LOG_ERROR( MSG, ...) server_log("ERROR", __func__, __LINE__, MSG, __VA_ARGS__)
+#define LOG_WARNING(MSG, ...) server_log("WARNING", __func__, __LINE__, MSG, __VA_ARGS__)
+#define LOG_INFO( MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__)
+
+//
+// parallel
+//
+
+enum server_state {
+ SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet
+ SERVER_STATE_READY, // Server is ready and model is loaded
+ SERVER_STATE_ERROR // An error occurred, load_model failed
+};
+
+enum task_type {
+ TASK_TYPE_COMPLETION,
+ TASK_TYPE_CANCEL,
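+ // no-op task posted by update_slots() to keep the task loop running while slots are still generating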
+ TASK_TYPE_NEXT_RESPONSE
+};
+
+struct task_server {
+ int id = -1; // to be filled by llama_server_queue
+ int target_id;
+ task_type type;
+ json data;
+ bool infill_mode = false;
+ bool embedding_mode = false;
+ int multitask_id = -1;
+};
+
+struct task_result {
+ int id;
+ int multitask_id = -1;
+ bool stop;
+ bool error;
+ json result_json;
+};
+
+struct task_multi {
+ int id;
+ std::set<int> subtasks_remaining{};
+ std::vector<task_result> results{};
+};
+
+// TODO: can become bool if we can't find use of more states
+enum slot_state
+{
+ IDLE,
+ PROCESSING,
+};
+
+enum slot_command
+{
+ NONE,
+ LOAD_PROMPT,
+ RELEASE,
+};
+
+struct slot_params
+{
+ bool stream = true;
+ bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt
+
+ uint32_t seed = -1; // RNG seed
+ int32_t n_keep = 0; // number of tokens to keep from initial prompt
+ int32_t n_predict = -1; // new tokens to predict
+
+ std::vector<std::string> antiprompt;
+
+ json input_prefix;
+ json input_suffix;
+};
+
+struct slot_image
+{
+ int32_t id;
+
+ bool request_encode_image = false;
+ float * image_embedding = nullptr;
+ int32_t image_tokens = 0;
+
+ clip_image_u8 * img_data;
+
+ std::string prefix_prompt; // prompt that comes before this image
+};
+
+// completion token output with probabilities
+struct completion_token_output
+{
+ struct token_prob
+ {
+ llama_token tok;
+ float prob;
+ };
+
+ std::vector<token_prob> probs;
+ llama_token tok;
+ std::string text_to_send;
+};
+
+static inline void server_log(const char *level, const char *function, int line,
+ const char *message, const nlohmann::ordered_json &extra)
+{
+ nlohmann::ordered_json log
+ {
+ {"timestamp", time(nullptr)},
+ {"level", level},
+ {"function", function},
+ {"line", line},
+ {"message", message},
+ };
+
+ if (!extra.empty())
+ {
+ log.merge_patch(extra);
+ }
+
+ const std::string str = log.dump(-1, ' ', false, json::error_handler_t::replace);
+ printf("%.*s\n", (int)str.size(), str.data());
+ fflush(stdout);
+}
+
+//
+// server utils
+//
+
+template <typename T>
+static T json_value(const json &body, const std::string &key, const T &default_value)
+{
+ // Fallback null to default value
+ return body.contains(key) && !body.at(key).is_null()
+ ? body.value(key, default_value)
+ : default_value;
+}
+
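+// Render OpenAI-style 'messages' into a single ChatML prompt string, ending with an open assistant turn.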
+inline std::string format_chatml(std::vector<json> messages)
+{
+ std::ostringstream chatml_msgs;
+
+ for (auto it = messages.begin(); it != messages.end(); ++it) {
+ chatml_msgs << "<|im_start|>"
+ << json_value(*it, "role", std::string("user")) << '\n';
+ chatml_msgs << json_value(*it, "content", std::string(""))
+ << "<|im_end|>\n";
+ }
+
+ chatml_msgs << "<|im_start|>assistant" << '\n';
+
+ return chatml_msgs.str();
+}
+
+//
+// work queue utils
+//
+
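+// Task queue consumed by the main loop (start_loop): HTTP handler threads post
+// tasks, the loop pops them and dispatches them to the registered callbacks.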
+struct llama_server_queue {
+ int id = 0;
+ std::mutex mutex_tasks;
+ // queues
+ std::vector<task_server> queue_tasks;
+ std::vector<task_server> queue_tasks_deferred;
+ std::vector<task_multi> queue_multitasks;
+ std::condition_variable condition_tasks;
+ // callback functions
+ std::function<void(task_server&)> callback_new_task;
+ std::function<void(task_multi&)> callback_finish_multitask;
+ std::function<void(void)> callback_all_task_finished;
+
+ // Add a new task to the end of the queue
+ int post(task_server task) {
+ std::unique_lock<std::mutex> lock(mutex_tasks);
+ if (task.id == -1) {
+ task.id = id++;
+ }
+ queue_tasks.push_back(std::move(task));
+ condition_tasks.notify_one();
+ return task.id;
+ }
+
+ // Add a new task, but defer it until a slot becomes available
+ void defer(task_server task) {
+ std::unique_lock<std::mutex> lock(mutex_tasks);
+ queue_tasks_deferred.push_back(std::move(task));
+ }
+
+ // Get the next id for creating a new task
+ int get_new_id() {
+ std::unique_lock<std::mutex> lock(mutex_tasks);
+ return id++;
+ }
+
+ // Register the function to process a new task
+ void on_new_task(std::function<void(task_server&)> callback) {
+ callback_new_task = callback;
+ }
+
+ // Register the function to process a multitask once all of its subtasks are finished
+ void on_finish_multitask(std::function<void(task_multi&)> callback) {
+ callback_finish_multitask = callback;
+ }
+
+ // Register the function to be called when the batch of tasks is finished
+ void on_all_tasks_finished(std::function<void(void)> callback) {
+ callback_all_task_finished = callback;
+ }
+
+ // Call when the state of one slot is changed
+ void notify_slot_changed() {
+ // move deferred tasks back into the main queue so the loop picks them up
+ std::unique_lock<std::mutex> lock(mutex_tasks);
+ for (auto & task : queue_tasks_deferred) {
+ queue_tasks.push_back(std::move(task));
+ }
+ queue_tasks_deferred.clear();
+ }
+
+ // Start the main loop. This call is blocking
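+ // Each iteration drains queue_tasks through callback_new_task, resolves any finished
+ // multitasks, calls callback_all_task_finished, then waits until a new task arrives.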
+ void start_loop() {
+ while (true) {
+ // new task arrived
+ LOG_VERBOSE("have new task", {});
+ {
+ while (true)
+ {
+ std::unique_lock<std::mutex> lock(mutex_tasks);
+ if (queue_tasks.empty()) {
+ lock.unlock();
+ break;
+ }
+ task_server task = queue_tasks.front();
+ queue_tasks.erase(queue_tasks.begin());
+ lock.unlock();
+ LOG_VERBOSE("callback_new_task", {});
+ callback_new_task(task);
+ }
+ LOG_VERBOSE("callback_all_task_finished", {});
+ // process and update all the multitasks
+ auto queue_iterator = queue_multitasks.begin();
+ while (queue_iterator != queue_multitasks.end())
+ {
+ if (queue_iterator->subtasks_remaining.empty())
+ {
+ // all subtasks done == multitask is done
+ task_multi current_multitask = *queue_iterator;
+ callback_finish_multitask(current_multitask);
+ // remove this multitask
+ queue_iterator = queue_multitasks.erase(queue_iterator);
+ }
+ else
+ {
+ ++queue_iterator;
+ }
+ }
+ // all tasks in the current loop iteration are finished
+ callback_all_task_finished();
+ }
+ LOG_VERBOSE("wait for new task", {});
+ // wait for new task
+ {
+ std::unique_lock<std::mutex> lock(mutex_tasks);
+ if (queue_tasks.empty()) {
+ condition_tasks.wait(lock, [&]{
+ return !queue_tasks.empty();
+ });
+ }
+ }
+ }
+ }
+
+ //
+ // functions to manage multitasks
+ //
+
+ // add a multitask by specifying the ids of all its subtasks (each subtask is a task_server)
+ void add_multitask(int multitask_id, std::vector<int>& sub_ids)
+ {
+ std::lock_guard<std::mutex> lock(mutex_tasks);
+ task_multi multi;
+ multi.id = multitask_id;
+ std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
+ queue_multitasks.push_back(multi);
+ }
+
+ // update the remaining subtasks, appending the result to the multitask
+ void update_multitask(int multitask_id, int subtask_id, task_result& result)
+ {
+ std::lock_guard<std::mutex> lock(mutex_tasks);
+ for (auto& multitask : queue_multitasks)
+ {
+ if (multitask.id == multitask_id)
+ {
+ multitask.subtasks_remaining.erase(subtask_id);
+ multitask.results.push_back(result);
+ }
+ }
+ }
+};
+
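+// Routes task results from the main loop back to the HTTP handler threads that
+// block in recv(), and forwards multitask results to the registered callback.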
+struct llama_server_response {
+ typedef std::function<void(int, int, task_result&)> callback_multitask_t;
+ callback_multitask_t callback_update_multitask;
+ // for keeping track of all tasks waiting for the result
+ std::set<int> waiting_task_ids;
+ // the main result queue
+ std::vector<task_result> queue_results;
+ std::mutex mutex_results;
+ std::condition_variable condition_results;
+
+ void add_waiting_task_id(int task_id) {
+ std::unique_lock<std::mutex> lock(mutex_results);
+ waiting_task_ids.insert(task_id);
+ }
+
+ void remove_waiting_task_id(int task_id) {
+ std::unique_lock<std::mutex> lock(mutex_results);
+ waiting_task_ids.erase(task_id);
+ }
+
+ // This function blocks the thread until there is a response for this task_id
+ task_result recv(int task_id) {
+ while (true)
+ {
+ std::unique_lock<std::mutex> lock(mutex_results);
+ condition_results.wait(lock, [&]{
+ return !queue_results.empty();
+ });
+ LOG_VERBOSE("condition_results unblock", {});
+
+ for (int i = 0; i < (int) queue_results.size(); i++)
+ {
+ if (queue_results[i].id == task_id)
+ {
+ assert(queue_results[i].multitask_id == -1);
+ task_result res = queue_results[i];
+ queue_results.erase(queue_results.begin() + i);
+ return res;
+ }
+ }
+ }
+
+ // should never reach here
+ }
+
+ // Register the function to update multitask
+ void on_multitask_update(callback_multitask_t callback) {
+ callback_update_multitask = callback;
+ }
+
+ // Send a new result to a waiting task_id
+ void send(task_result result) {
+ std::unique_lock<std::mutex> lock(mutex_results);
+ LOG_VERBOSE("send new result", {});
+ for (auto& task_id : waiting_task_ids) {
+ // LOG_TEE("waiting task id %i \n", task_id);
+ // for now, results that belong to a waiting parent multitask are forwarded to it instead of being queued here
+ if (result.multitask_id == task_id)
+ {
+ LOG_VERBOSE("callback_update_multitask", {});
+ callback_update_multitask(task_id, result.id, result);
+ continue;
+ }
+
+ if (result.id == task_id)
+ {
+ LOG_VERBOSE("queue_results.push_back", {});
+ queue_results.push_back(result);
+ condition_results.notify_one();
+ return;
+ }
+ }
+ }
+};
+
+//
+// base64 utils (TODO: move to common in the future)
+//
+
+static const std::string base64_chars =
+ "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+ "abcdefghijklmnopqrstuvwxyz"
+ "0123456789+/";
+
+static inline bool is_base64(uint8_t c)
+{
+ return (isalnum(c) || (c == '+') || (c == '/'));
+}
+
+static inline std::vector<uint8_t> base64_decode(const std::string & encoded_string)
+{
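+ // decode the input 4 characters at a time: map each character through base64_chars
+ // and emit 3 bytes; a trailing partial group (before '=' padding) is flushed after the loop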
+ int i = 0;
+ int j = 0;
+ int in_ = 0;
+
+ int in_len = encoded_string.size();
+
+ uint8_t char_array_4[4];
+ uint8_t char_array_3[3];
+
+ std::vector<uint8_t> ret;
+
+ while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_]))
+ {
+ char_array_4[i++] = encoded_string[in_]; in_++;
+ if (i == 4)
+ {
+ for (i = 0; i <4; i++)
+ {
+ char_array_4[i] = base64_chars.find(char_array_4[i]);
+ }
+
+ char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4);
+ char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
+ char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
+
+ for (i = 0; (i < 3); i++)
+ {
+ ret.push_back(char_array_3[i]);
+ }
+ i = 0;
+ }
+ }
+
+ if (i)
+ {
+ for (j = i; j <4; j++)
+ {
+ char_array_4[j] = 0;
+ }
+
+ for (j = 0; j <4; j++)
+ {
+ char_array_4[j] = base64_chars.find(char_array_4[j]);
+ }
+
+ char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4);
+ char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
+ char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
+
+ for (j = 0; (j < i - 1); j++)
+ {
+ ret.push_back(char_array_3[j]);
+ }
+ }
+
+ return ret;
+}
+
+//
+// random string / id
+//
+
+static std::string random_string()
+{
+ static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz");
+
+ std::random_device rd;
+ std::mt19937 generator(rd());
+
+ std::string result(32, ' ');
+
+ for (int i = 0; i < 32; ++i) {
+ result[i] = str[generator() % str.size()];
+ }
+
+ return result;
+}
+
+static std::string gen_chatcmplid()
+{
+ std::stringstream chatcmplid;
+ chatcmplid << "chatcmpl-" << random_string();
+ return chatcmplid.str();
+}