#include "chat.h"
#include "utils.hpp"
+#include "server-http.h"
#include "arg.h"
#include "common.h"
#include "speculative.h"
#include "mtmd.h"
-// mime type for sending response
-#define MIMETYPE_JSON "application/json; charset=utf-8"
-
-// auto generated files (see README.md for details)
-#include "index.html.gz.hpp"
-#include "loading.html.hpp"
-
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <deque>
#include <memory>
#include <mutex>
+#include <list>
#include <signal.h>
#include <thread>
-#include <unordered_map>
#include <unordered_set>
+// prevent windows.h from defining min/max macros that conflict with std::min and std::max
+#if defined(_WIN32)
+#define WIN32_LEAN_AND_MEAN
+#ifndef NOMINMAX
+# define NOMINMAX
+#endif
+#include <windows.h>
+#endif
+
using json = nlohmann::ordered_json;
constexpr int HTTP_POLLING_SECONDS = 1;
server_prompt prompt;
void prompt_save(server_prompt_cache & prompt_cache) const {
- assert(prompt.data.size() == 0);
+ GGML_ASSERT(prompt.data.size() == 0);
const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0);
llama_batch_free(batch);
}
+ // load the model and initialize llama_context
bool load_model(const common_params & params) {
SRV_INF("loading model '%s'\n", params.model.path.c_str());
return true;
}
+ // initialize slots and server-related data
void init() {
SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);
/* allow_audio */ mctx ? mtmd_support_audio (mctx) : false,
/* enable_thinking */ enable_thinking,
};
+
+ // print sample chat example to make it clear which template is used
+ LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
+ common_chat_templates_source(chat_templates.get()),
+ common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str());
}
server_slot * get_slot_by_id(int id) {
}
};
+
// generator-like API for server responses; supports polling the connection state and aggregating results
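// typical usage (sketch; handle_completions_impl below shows the full flow):
//   server_response_reader rd(ctx_server);
//   rd.post_tasks(std::move(tasks));                      // one task per prompt
//   while (rd.has_next()) {
//       server_task_result_ptr result = rd.next(should_stop);
//       if (result == nullptr) break;                      // connection closed / stopped
//   }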
struct server_response_reader {
std::unordered_set<int> id_tasks;
ctx_server.queue_tasks.post(std::move(tasks));
}
- bool has_next() {
+ bool has_next() const {
return !cancelled && received_count < id_tasks.size();
}
}
};
-static void log_server_request(const httplib::Request & req, const httplib::Response & res) {
- // skip GH copilot requests when using default port
- if (req.path == "/v1/health") {
- return;
- }
-
- // reminder: this function is not covered by httplib's exception handler; if someone does more complicated stuff, think about wrapping it in try-catch
-
- SRV_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status);
-
- SRV_DBG("request: %s\n", req.body.c_str());
- SRV_DBG("response: %s\n", res.body.c_str());
-}
-
-static void res_err(httplib::Response & res, const json & error_data) {
- json final_response {{"error", error_data}};
- res.set_content(safe_json_to_str(final_response), MIMETYPE_JSON);
- res.status = json_value(error_data, "code", 500);
-}
-
-static void res_ok(httplib::Response & res, const json & data) {
- res.set_content(safe_json_to_str(data), MIMETYPE_JSON);
- res.status = 200;
-}
-
-std::function<void(int)> shutdown_handler;
-std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;
-
-inline void signal_handler(int signal) {
- if (is_terminating.test_and_set()) {
- // in case it hangs, we can force terminate the server by hitting Ctrl+C twice
- // this is for better developer experience, we can remove when the server is stable enough
- fprintf(stderr, "Received second interrupt, terminating immediately.\n");
- exit(1);
- }
-
- shutdown_handler(signal);
-}
-
-int main(int argc, char ** argv) {
- // own arguments required by this example
- common_params params;
-
- if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) {
- return 1;
- }
-
- // TODO: should we have a separate n_parallel parameter for the server?
- // https://github.com/ggml-org/llama.cpp/pull/16736#discussion_r2483763177
- // TODO: this is a common configuration that is suitable for most local use cases
- // however, overriding the parameters is a bit confusing - figure out something more intuitive
- if (params.n_parallel == 1 && params.kv_unified == false && !params.has_speculative()) {
- LOG_WRN("%s: setting n_parallel = 4 and kv_unified = true (add -kvu to disable this)\n", __func__);
-
- params.n_parallel = 4;
- params.kv_unified = true;
- }
-
- common_init();
-
- // struct that contains llama context and inference
- server_context ctx_server;
-
- llama_backend_init();
- llama_numa_init(params.numa);
-
- LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads, params.cpuparams_batch.n_threads, std::thread::hardware_concurrency());
- LOG_INF("\n");
- LOG_INF("%s\n", common_params_get_system_info(params).c_str());
- LOG_INF("\n");
-
- std::unique_ptr<httplib::Server> svr;
-#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
- if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
- LOG_INF("Running with SSL: key = %s, cert = %s\n", params.ssl_file_key.c_str(), params.ssl_file_cert.c_str());
- svr.reset(
- new httplib::SSLServer(params.ssl_file_cert.c_str(), params.ssl_file_key.c_str())
- );
- } else {
- LOG_INF("Running without SSL\n");
- svr.reset(new httplib::Server());
+// generator-like API for building HTTP responses
+struct server_res_generator : server_http_res {
+ server_response_reader rd;
+ server_res_generator(server_context & ctx_server_) : rd(ctx_server_) {}
+ void ok(const json & response_data) {
+ status = 200;
+ data = safe_json_to_str(response_data);
}
-#else
- if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
- LOG_ERR("Server is built without SSL support\n");
- return 1;
- }
- svr.reset(new httplib::Server());
-#endif
-
- std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};
-
- svr->set_default_headers({{"Server", "llama.cpp"}});
- svr->set_logger(log_server_request);
- svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, const std::exception_ptr & ep) {
- std::string message;
- try {
- std::rethrow_exception(ep);
- } catch (const std::exception & e) {
- message = e.what();
- } catch (...) {
- message = "Unknown Exception";
- }
-
- try {
- json formatted_error = format_error_response(message, ERROR_TYPE_SERVER);
- LOG_WRN("got exception: %s\n", formatted_error.dump().c_str());
- res_err(res, formatted_error);
- } catch (const std::exception & e) {
- LOG_ERR("got another exception: %s | while hanlding exception: %s\n", e.what(), message.c_str());
- }
- });
-
- svr->set_error_handler([](const httplib::Request &, httplib::Response & res) {
- if (res.status == 404) {
- res_err(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND));
- }
- // for other error codes, we skip processing here because it's already done by res_err()
- });
-
- // set timeouts and change hostname and port
- svr->set_read_timeout (params.timeout_read);
- svr->set_write_timeout(params.timeout_write);
-
- std::unordered_map<std::string, std::string> log_data;
-
- log_data["hostname"] = params.hostname;
- log_data["port"] = std::to_string(params.port);
-
- if (params.api_keys.size() == 1) {
- auto key = params.api_keys[0];
- log_data["api_key"] = "api_key: ****" + key.substr(std::max((int)(key.length() - 4), 0));
- } else if (params.api_keys.size() > 1) {
- log_data["api_key"] = "api_key: " + std::to_string(params.api_keys.size()) + " keys loaded";
+ void error(const json & error_data) {
+ status = json_value(error_data, "code", 500);
+ data = safe_json_to_str({{ "error", error_data }});
}
+};
- // Necessary similarity of prompt for slot selection
- ctx_server.slot_prompt_similarity = params.slot_prompt_similarity;
-
- //
- // Middlewares
- //
-
- auto middleware_validate_api_key = [¶ms](const httplib::Request & req, httplib::Response & res) {
- static const std::unordered_set<std::string> public_endpoints = {
- "/health",
- "/v1/health",
- "/models",
- "/v1/models",
- "/api/tags"
- };
-
- // If API key is not set, skip validation
- if (params.api_keys.empty()) {
- return true;
- }
-
- // If path is public or is static file, skip validation
- if (public_endpoints.find(req.path) != public_endpoints.end() || req.path == "/") {
- return true;
- }
-
- // Check for API key in the header
- auto auth_header = req.get_header_value("Authorization");
-
- std::string prefix = "Bearer ";
- if (auth_header.substr(0, prefix.size()) == prefix) {
- std::string received_api_key = auth_header.substr(prefix.size());
- if (std::find(params.api_keys.begin(), params.api_keys.end(), received_api_key) != params.api_keys.end()) {
- return true; // API key is valid
- }
- }
-
- // API key is invalid or not provided
- res_err(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION));
-
- LOG_WRN("Unauthorized: Invalid API Key\n");
-
- return false;
- };
-
- auto middleware_server_state = [&state](const httplib::Request & req, httplib::Response & res) {
- server_state current_state = state.load();
- if (current_state == SERVER_STATE_LOADING_MODEL) {
- auto tmp = string_split<std::string>(req.path, '.');
- if (req.path == "/" || tmp.back() == "html") {
- res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
- res.status = 503;
- } else if (req.path == "/models" || req.path == "/v1/models" || req.path == "/api/tags") {
- // allow the models endpoint to be accessed during loading
- return true;
- } else {
- res_err(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
- }
- return false;
- }
- return true;
- };
-
- // register server middlewares
- svr->set_pre_routing_handler([&middleware_validate_api_key, &middleware_server_state](const httplib::Request & req, httplib::Response & res) {
- res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
- // If this is OPTIONS request, skip validation because browsers don't include Authorization header
- if (req.method == "OPTIONS") {
- res.set_header("Access-Control-Allow-Credentials", "true");
- res.set_header("Access-Control-Allow-Methods", "GET, POST");
- res.set_header("Access-Control-Allow-Headers", "*");
- res.set_content("", "text/html"); // blank response, no data
- return httplib::Server::HandlerResponse::Handled; // skip further processing
- }
- if (!middleware_server_state(req, res)) {
- return httplib::Server::HandlerResponse::Handled;
- }
- if (!middleware_validate_api_key(req, res)) {
- return httplib::Server::HandlerResponse::Handled;
- }
- return httplib::Server::HandlerResponse::Unhandled;
- });
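+// groups all route handlers; holds only references to the shared server, parameter and HTTP contexts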
+struct server_routes {
+ const common_params & params;
+ server_context & ctx_server;
+ server_http_context & ctx_http; // for reading is_ready
+ server_routes(const common_params & params, server_context & ctx_server, server_http_context & ctx_http)
+ : params(params), ctx_server(ctx_server), ctx_http(ctx_http) {}
- //
- // Route handlers (or controllers)
- //
+public:
+    // handlers are implemented as lambdas so that they can capture `this` without `std::bind`
- const auto handle_health = [&](const httplib::Request &, httplib::Response & res) {
+ server_http_context::handler_t get_health = [this](const server_http_req &) {
// error and loading states are handled by middleware
- json health = {{"status", "ok"}};
- res_ok(res, health);
- };
-
- const auto handle_slots = [&](const httplib::Request & req, httplib::Response & res) {
- if (!params.endpoint_slots) {
- res_err(res, format_error_response("This server does not support slots endpoint. Start it with `--slots`", ERROR_TYPE_NOT_SUPPORTED));
- return;
- }
-
- // request slots data using task queue
- int task_id = ctx_server.queue_tasks.get_new_id();
- {
- server_task task(SERVER_TASK_TYPE_METRICS);
- task.id = task_id;
- ctx_server.queue_results.add_waiting_task_id(task_id);
- ctx_server.queue_tasks.post(std::move(task), true); // high-priority task
- }
-
- // get the result
- server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
- ctx_server.queue_results.remove_waiting_task_id(task_id);
-
- if (result->is_error()) {
- res_err(res, result->to_json());
- return;
- }
-
- // TODO: get rid of this dynamic_cast
- auto res_task = dynamic_cast<server_task_result_metrics*>(result.get());
- GGML_ASSERT(res_task != nullptr);
-
- // optionally return "fail_on_no_slot" error
- if (req.has_param("fail_on_no_slot")) {
- if (res_task->n_idle_slots == 0) {
- res_err(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE));
- return;
- }
- }
-
- res_ok(res, res_task->slots_data);
+ auto res = std::make_unique<server_res_generator>(ctx_server);
+ res->ok({{"status", "ok"}});
+ return res;
};
- const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
+ server_http_context::handler_t get_metrics = [this](const server_http_req &) {
+ auto res = std::make_unique<server_res_generator>(ctx_server);
if (!params.endpoint_metrics) {
- res_err(res, format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED));
- return;
+ res->error(format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED));
+ return res;
}
// request slots data using task queue
+ // TODO: use server_response_reader
int task_id = ctx_server.queue_tasks.get_new_id();
{
server_task task(SERVER_TASK_TYPE_METRICS);
ctx_server.queue_results.remove_waiting_task_id(task_id);
if (result->is_error()) {
- res_err(res, result->to_json());
- return;
+ res->error(result->to_json());
+ return res;
}
// TODO: get rid of this dynamic_cast
}
}
- res.set_header("Process-Start-Time-Unix", std::to_string(res_task->t_start));
-
- res.set_content(prometheus.str(), "text/plain; version=0.0.4");
- res.status = 200; // HTTP OK
- };
-
- const auto handle_slots_save = [&ctx_server, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) {
- json request_data = json::parse(req.body);
- std::string filename = request_data.at("filename");
- if (!fs_validate_filename(filename)) {
- res_err(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
- return;
- }
- std::string filepath = params.slot_save_path + filename;
-
- int task_id = ctx_server.queue_tasks.get_new_id();
- {
- server_task task(SERVER_TASK_TYPE_SLOT_SAVE);
- task.id = task_id;
- task.slot_action.slot_id = id_slot;
- task.slot_action.filename = filename;
- task.slot_action.filepath = filepath;
-
- ctx_server.queue_results.add_waiting_task_id(task_id);
- ctx_server.queue_tasks.post(std::move(task));
- }
-
- server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
- ctx_server.queue_results.remove_waiting_task_id(task_id);
-
- if (result->is_error()) {
- res_err(res, result->to_json());
- return;
- }
-
- res_ok(res, result->to_json());
+ res->headers["Process-Start-Time-Unix"] = std::to_string(res_task->t_start);
+ res->content_type = "text/plain; version=0.0.4";
+        res->status = 200;
+        res->data   = prometheus.str(); // raw Prometheus text; ok() would JSON-encode the string
+ return res;
};
- const auto handle_slots_restore = [&ctx_server, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) {
- json request_data = json::parse(req.body);
- std::string filename = request_data.at("filename");
- if (!fs_validate_filename(filename)) {
- res_err(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
- return;
+ server_http_context::handler_t get_slots = [this](const server_http_req & req) {
+ auto res = std::make_unique<server_res_generator>(ctx_server);
+ if (!params.endpoint_slots) {
+ res->error(format_error_response("This server does not support slots endpoint. Start it with `--slots`", ERROR_TYPE_NOT_SUPPORTED));
+ return res;
}
- std::string filepath = params.slot_save_path + filename;
+ // request slots data using task queue
int task_id = ctx_server.queue_tasks.get_new_id();
{
- server_task task(SERVER_TASK_TYPE_SLOT_RESTORE);
+ server_task task(SERVER_TASK_TYPE_METRICS);
task.id = task_id;
- task.slot_action.slot_id = id_slot;
- task.slot_action.filename = filename;
- task.slot_action.filepath = filepath;
-
ctx_server.queue_results.add_waiting_task_id(task_id);
- ctx_server.queue_tasks.post(std::move(task));
+ ctx_server.queue_tasks.post(std::move(task), true); // high-priority task
}
+ // get the result
server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
ctx_server.queue_results.remove_waiting_task_id(task_id);
if (result->is_error()) {
- res_err(res, result->to_json());
- return;
- }
-
- GGML_ASSERT(dynamic_cast<server_task_result_slot_save_load*>(result.get()) != nullptr);
- res_ok(res, result->to_json());
- };
-
- const auto handle_slots_erase = [&ctx_server](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
- int task_id = ctx_server.queue_tasks.get_new_id();
- {
- server_task task(SERVER_TASK_TYPE_SLOT_ERASE);
- task.id = task_id;
- task.slot_action.slot_id = id_slot;
-
- ctx_server.queue_results.add_waiting_task_id(task_id);
- ctx_server.queue_tasks.post(std::move(task));
+ res->error(result->to_json());
+ return res;
}
- server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
- ctx_server.queue_results.remove_waiting_task_id(task_id);
+ // TODO: get rid of this dynamic_cast
+ auto res_task = dynamic_cast<server_task_result_metrics*>(result.get());
+ GGML_ASSERT(res_task != nullptr);
- if (result->is_error()) {
- res_err(res, result->to_json());
- return;
+ // optionally return "fail_on_no_slot" error
+ if (!req.get_param("fail_on_no_slot").empty()) {
+ if (res_task->n_idle_slots == 0) {
+ res->error(format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE));
+ return res;
+ }
}
- GGML_ASSERT(dynamic_cast<server_task_result_slot_erase*>(result.get()) != nullptr);
- res_ok(res, result->to_json());
+ res->ok(res_task->slots_data);
+ return res;
};
- const auto handle_slots_action = [¶ms, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
+ server_http_context::handler_t post_slots = [this](const server_http_req & req) {
+ auto res = std::make_unique<server_res_generator>(ctx_server);
if (params.slot_save_path.empty()) {
- res_err(res, format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED));
- return;
+ res->error(format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED));
+ return res;
}
- std::string id_slot_str = req.path_params.at("id_slot");
+ std::string id_slot_str = req.get_param("id_slot");
int id_slot;
try {
id_slot = std::stoi(id_slot_str);
} catch (const std::exception &) {
- res_err(res, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST));
- return;
+ res->error(format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST));
+ return res;
}
- std::string action = req.get_param_value("action");
+ std::string action = req.get_param("action");
if (action == "save") {
- handle_slots_save(req, res, id_slot);
+ return handle_slots_save(req, id_slot);
} else if (action == "restore") {
- handle_slots_restore(req, res, id_slot);
+ return handle_slots_restore(req, id_slot);
} else if (action == "erase") {
- handle_slots_erase(req, res, id_slot);
+ return handle_slots_erase(req, id_slot);
} else {
- res_err(res, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST));
+ res->error(format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST));
+ return res;
}
};
- const auto handle_props = [¶ms, &ctx_server](const httplib::Request &, httplib::Response & res) {
+ server_http_context::handler_t get_props = [this](const server_http_req &) {
+ auto res = std::make_unique<server_res_generator>(ctx_server);
json default_generation_settings_for_props;
{
}
}
- res_ok(res, data);
+ res->ok(data);
+ return res;
};
- const auto handle_props_change = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
- if (!ctx_server.params_base.endpoint_props) {
- res_err(res, format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED));
- return;
+ server_http_context::handler_t post_props = [this](const server_http_req &) {
+ auto res = std::make_unique<server_res_generator>(ctx_server);
+ if (!params.endpoint_props) {
+ res->error(format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED));
+ return res;
}
-
- json data = json::parse(req.body);
-
// update any props here
- res_ok(res, {{ "success", true }});
+ res->ok({{ "success", true }});
+ return res;
};
- const auto handle_api_show = [&ctx_server](const httplib::Request &, httplib::Response & res) {
+ server_http_context::handler_t get_api_show = [this](const server_http_req &) {
+ auto res = std::make_unique<server_res_generator>(ctx_server);
bool has_mtmd = ctx_server.mctx != nullptr;
json data = {
{
{"capabilities", has_mtmd ? json({"completion","multimodal"}) : json({"completion"})}
};
- res_ok(res, data);
+ res->ok(data);
+ return res;
};
- // handle completion-like requests (completion, chat, infill)
- // we can optionally provide a custom format for partial results and final results
- const auto handle_completions_impl = [&ctx_server](
- server_task_type type,
- json & data,
- const std::vector<raw_buffer> & files,
- const std::function<bool()> & is_connection_closed,
- httplib::Response & res,
- oaicompat_type oaicompat) -> void {
- GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
-
- auto completion_id = gen_chatcmplid();
- // need to store the reader as a pointer, so that it won't be destroyed when the handle returns
- // use shared_ptr as it's shared between the chunked_content_provider() and on_complete()
- const auto rd = std::make_shared<server_response_reader>(ctx_server);
-
- try {
- std::vector<server_task> tasks;
-
- const auto & prompt = data.at("prompt");
- // TODO: this log can become very long, put it behind a flag or think about a more compact format
- //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
-
- // process prompt
- std::vector<server_tokens> inputs;
-
- if (oaicompat && ctx_server.mctx != nullptr) {
- // This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below.
- inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get<std::string>(), files));
- } else {
- // Everything else, including multimodal completions.
- inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
- }
- tasks.reserve(inputs.size());
- for (size_t i = 0; i < inputs.size(); i++) {
- server_task task = server_task(type);
-
- task.id = ctx_server.queue_tasks.get_new_id();
- task.index = i;
-
- task.tokens = std::move(inputs[i]);
- task.params = server_task::params_from_json_cmpl(
- ctx_server.ctx,
- ctx_server.params_base,
- data);
- task.id_slot = json_value(data, "id_slot", -1);
-
- // OAI-compat
- task.params.oaicompat = oaicompat;
- task.params.oaicompat_cmpl_id = completion_id;
- // oaicompat_model is already populated by params_from_json_cmpl
-
- tasks.push_back(std::move(task));
- }
-
- rd->post_tasks(std::move(tasks));
- } catch (const std::exception & e) {
- res_err(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
- return;
- }
-
- bool stream = json_value(data, "stream", false);
-
- if (!stream) {
- // non-stream, wait for the results
- auto all_results = rd->wait_for_all(is_connection_closed);
- if (all_results.is_terminated) {
- return; // connection is closed
- } else if (all_results.error) {
- res_err(res, all_results.error->to_json());
- return;
- } else {
- json arr = json::array();
- for (auto & res : all_results.results) {
- GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
- arr.push_back(res->to_json());
- }
- // if single request, return single object instead of array
- res_ok(res, arr.size() == 1 ? arr[0] : arr);
- }
-
- } else {
- // in streaming mode, the first error must be treated as non-stream response
- // this is to match the OAI API behavior
- // ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309
- server_task_result_ptr first_result = rd->next(is_connection_closed);
- if (first_result == nullptr) {
- return; // connection is closed
- } else if (first_result->is_error()) {
- res_err(res, first_result->to_json());
- return;
- } else {
- GGML_ASSERT(
- dynamic_cast<server_task_result_cmpl_partial*>(first_result.get()) != nullptr
- || dynamic_cast<server_task_result_cmpl_final*>(first_result.get()) != nullptr
- );
- }
-
- // next responses are streamed
- json first_result_json = first_result->to_json();
- const auto chunked_content_provider = [first_result_json, rd, oaicompat](size_t, httplib::DataSink & sink) mutable -> bool {
- // flush the first result as it's not an error
- if (!first_result_json.empty()) {
- if (!server_sent_event(sink, first_result_json)) {
- sink.done();
- return false; // sending failed, go to on_complete()
- }
- first_result_json.clear(); // mark as sent
- }
-
- // receive subsequent results
- auto result = rd->next([&sink]{ return !sink.is_writable(); });
- if (result == nullptr) {
- sink.done();
- return false; // connection is closed, go to on_complete()
- }
-
- // send the results
- json res_json = result->to_json();
- bool ok = false;
- if (result->is_error()) {
- ok = server_sent_event(sink, json {{ "error", result->to_json() }});
- sink.done();
- return false; // go to on_complete()
- } else {
- GGML_ASSERT(
- dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
- || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
- );
- ok = server_sent_event(sink, res_json);
- }
-
- if (!ok) {
- sink.done();
- return false; // sending failed, go to on_complete()
- }
-
- // check if there is more data
- if (!rd->has_next()) {
- if (oaicompat != OAICOMPAT_TYPE_NONE) {
- static const std::string ev_done = "data: [DONE]\n\n";
- sink.write(ev_done.data(), ev_done.size());
- }
- sink.done();
- return false; // no more data, go to on_complete()
- }
-
- // has next data, continue
- return true;
- };
-
- auto on_complete = [rd](bool) {
- rd->stop();
- };
-
- res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
- }
- };
-
- const auto handle_completions = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
- json data = json::parse(req.body);
- std::vector<raw_buffer> files; // dummy
- handle_completions_impl(
- SERVER_TASK_TYPE_COMPLETION,
- data,
- files,
- req.is_connection_closed,
- res,
- OAICOMPAT_TYPE_NONE);
- };
-
- const auto handle_completions_oai = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
- json data = oaicompat_completion_params_parse(json::parse(req.body));
- std::vector<raw_buffer> files; // dummy
- handle_completions_impl(
- SERVER_TASK_TYPE_COMPLETION,
- data,
- files,
- req.is_connection_closed,
- res,
- OAICOMPAT_TYPE_COMPLETION);
- };
-
- const auto handle_infill = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
+ server_http_context::handler_t post_infill = [this](const server_http_req & req) {
+ auto res = std::make_unique<server_res_generator>(ctx_server);
// check model compatibility
std::string err;
if (llama_vocab_fim_pre(ctx_server.vocab) == LLAMA_TOKEN_NULL) {
err += "middle token is missing. ";
}
if (!err.empty()) {
- res_err(res, format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED));
- return;
+ res->error(format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED));
+ return res;
}
- json data = json::parse(req.body);
-
// validate input
+ json data = json::parse(req.body);
if (data.contains("prompt") && !data.at("prompt").is_string()) {
// prompt is optional
- res_err(res, format_error_response("\"prompt\" must be a string", ERROR_TYPE_INVALID_REQUEST));
+ res->error(format_error_response("\"prompt\" must be a string", ERROR_TYPE_INVALID_REQUEST));
}
if (!data.contains("input_prefix")) {
- res_err(res, format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST));
+ res->error(format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST));
}
if (!data.contains("input_suffix")) {
- res_err(res, format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST));
+ res->error(format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST));
}
if (data.contains("input_extra") && !data.at("input_extra").is_array()) {
// input_extra is optional
- res_err(res, format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST));
- return;
+ res->error(format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST));
+ return res;
}
json input_extra = json_value(data, "input_extra", json::array());
for (const auto & chunk : input_extra) {
// { "text": string, "filename": string }
if (!chunk.contains("text") || !chunk.at("text").is_string()) {
- res_err(res, format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST));
- return;
+ res->error(format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST));
+ return res;
}
// filename is optional
if (chunk.contains("filename") && !chunk.at("filename").is_string()) {
- res_err(res, format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST));
- return;
+ res->error(format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST));
+ return res;
}
}
data["input_extra"] = input_extra; // default to empty array if it's not exist
);
std::vector<raw_buffer> files; // dummy
- handle_completions_impl(
+ return handle_completions_impl(
SERVER_TASK_TYPE_INFILL,
data,
files,
- req.is_connection_closed,
- res,
+ req.should_stop,
OAICOMPAT_TYPE_NONE); // infill is not OAI compatible
};
- const auto handle_chat_completions = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
- LOG_DBG("request: %s\n", req.body.c_str());
+ server_http_context::handler_t post_completions = [this](const server_http_req & req) {
+ std::vector<raw_buffer> files; // dummy
+ const json body = json::parse(req.body);
+ return handle_completions_impl(
+ SERVER_TASK_TYPE_COMPLETION,
+ body,
+ files,
+ req.should_stop,
+ OAICOMPAT_TYPE_NONE);
+ };
+
+ server_http_context::handler_t post_completions_oai = [this](const server_http_req & req) {
+ std::vector<raw_buffer> files; // dummy
+ const json body = json::parse(req.body);
+ return handle_completions_impl(
+ SERVER_TASK_TYPE_COMPLETION,
+ body,
+ files,
+ req.should_stop,
+ OAICOMPAT_TYPE_COMPLETION);
+ };
- auto body = json::parse(req.body);
+ server_http_context::handler_t post_chat_completions = [this](const server_http_req & req) {
std::vector<raw_buffer> files;
- json data = oaicompat_chat_params_parse(
+ json body = json::parse(req.body);
+ json body_parsed = oaicompat_chat_params_parse(
body,
ctx_server.oai_parser_opt,
files);
-
- handle_completions_impl(
+ return handle_completions_impl(
SERVER_TASK_TYPE_COMPLETION,
- data,
+ body_parsed,
files,
- req.is_connection_closed,
- res,
+ req.should_stop,
OAICOMPAT_TYPE_CHAT);
};
    // same as post_chat_completions, but without the inference part
- const auto handle_apply_template = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
- auto body = json::parse(req.body);
+ server_http_context::handler_t post_apply_template = [this](const server_http_req & req) {
+ auto res = std::make_unique<server_res_generator>(ctx_server);
std::vector<raw_buffer> files; // dummy, unused
+ json body = json::parse(req.body);
json data = oaicompat_chat_params_parse(
body,
ctx_server.oai_parser_opt,
files);
- res_ok(res, {{ "prompt", std::move(data.at("prompt")) }});
+ res->ok({{ "prompt", std::move(data.at("prompt")) }});
+ return res;
};
- const auto handle_models = [¶ms, &ctx_server, &state](const httplib::Request &, httplib::Response & res) {
- server_state current_state = state.load();
+ server_http_context::handler_t get_models = [this](const server_http_req &) {
+ auto res = std::make_unique<server_res_generator>(ctx_server);
+ bool is_model_ready = ctx_http.is_ready.load();
json model_meta = nullptr;
- if (current_state == SERVER_STATE_READY) {
+ if (is_model_ready) {
model_meta = ctx_server.model_meta();
}
bool has_mtmd = ctx_server.mctx != nullptr;
}}
};
- res_ok(res, models);
+ res->ok(models);
+ return res;
};
- const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
+ server_http_context::handler_t post_tokenize = [this](const server_http_req & req) {
+ auto res = std::make_unique<server_res_generator>(ctx_server);
const json body = json::parse(req.body);
-
json tokens_response = json::array();
if (body.count("content") != 0) {
const bool add_special = json_value(body, "add_special", false);
}
const json data = format_tokenizer_response(tokens_response);
- res_ok(res, data);
+ res->ok(data);
+ return res;
};
- const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
+ server_http_context::handler_t post_detokenize = [this](const server_http_req & req) {
+ auto res = std::make_unique<server_res_generator>(ctx_server);
const json body = json::parse(req.body);
std::string content;
}
const json data = format_detokenized_response(content);
- res_ok(res, data);
- };
-
- const auto handle_embeddings_impl = [&ctx_server](const httplib::Request & req, httplib::Response & res, oaicompat_type oaicompat) {
- if (!ctx_server.params_base.embedding) {
- res_err(res, format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
- return;
- }
-
- if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
- res_err(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
- return;
- }
-
- const json body = json::parse(req.body);
-
- // for the shape of input/content, see tokenize_input_prompts()
- json prompt;
- if (body.count("input") != 0) {
- prompt = body.at("input");
- } else if (body.contains("content")) {
- oaicompat = OAICOMPAT_TYPE_NONE; // "content" field is not OAI compatible
- prompt = body.at("content");
- } else {
- res_err(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
- return;
- }
-
- bool use_base64 = false;
- if (body.count("encoding_format") != 0) {
- const std::string& format = body.at("encoding_format");
- if (format == "base64") {
- use_base64 = true;
- } else if (format != "float") {
- res_err(res, format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST));
- return;
- }
- }
-
- auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
- for (const auto & tokens : tokenized_prompts) {
- // this check is necessary for models that do not add BOS token to the input
- if (tokens.empty()) {
- res_err(res, format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST));
- return;
- }
- }
-
- int embd_normalize = 2; // default to Euclidean/L2 norm
- if (body.count("embd_normalize") != 0) {
- embd_normalize = body.at("embd_normalize");
- if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
- SRV_DBG("embd_normalize is not supported by pooling type %d, ignoring it\n", llama_pooling_type(ctx_server.ctx));
- }
- }
-
- // create and queue the task
- json responses = json::array();
- server_response_reader rd(ctx_server);
- {
- std::vector<server_task> tasks;
- for (size_t i = 0; i < tokenized_prompts.size(); i++) {
- server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
-
- task.id = ctx_server.queue_tasks.get_new_id();
- task.index = i;
- task.tokens = std::move(tokenized_prompts[i]);
-
- // OAI-compat
- task.params.oaicompat = oaicompat;
- task.params.embd_normalize = embd_normalize;
-
- tasks.push_back(std::move(task));
- }
- rd.post_tasks(std::move(tasks));
- }
-
- // wait for the results
- auto all_results = rd.wait_for_all(req.is_connection_closed);
-
- // collect results
- if (all_results.is_terminated) {
- return; // connection is closed
- } else if (all_results.error) {
- res_err(res, all_results.error->to_json());
- return;
- } else {
- for (auto & res : all_results.results) {
- GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
- responses.push_back(res->to_json());
- }
- }
-
- // write JSON response
- json root = oaicompat == OAICOMPAT_TYPE_EMBEDDING
- ? format_embeddings_response_oaicompat(body, responses, use_base64)
- : json(responses);
- res_ok(res, root);
+ res->ok(data);
+ return res;
};
- const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
- handle_embeddings_impl(req, res, OAICOMPAT_TYPE_NONE);
+ server_http_context::handler_t post_embeddings = [this](const server_http_req & req) {
+ return handle_embeddings_impl(req, OAICOMPAT_TYPE_NONE);
};
- const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
- handle_embeddings_impl(req, res, OAICOMPAT_TYPE_EMBEDDING);
+ server_http_context::handler_t post_embeddings_oai = [this](const server_http_req & req) {
+ return handle_embeddings_impl(req, OAICOMPAT_TYPE_EMBEDDING);
};
- const auto handle_rerank = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
+ server_http_context::handler_t post_rerank = [this](const server_http_req & req) {
+ auto res = std::make_unique<server_res_generator>(ctx_server);
if (!ctx_server.params_base.embedding || ctx_server.params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) {
- res_err(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
- return;
+ res->error(format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
+ return res;
}
const json body = json::parse(req.body);
if (body.count("query") == 1) {
query = body.at("query");
if (!query.is_string()) {
- res_err(res, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST));
- return;
+ res->error(format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST));
+ return res;
}
} else {
- res_err(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST));
- return;
+ res->error(format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST));
+ return res;
}
std::vector<std::string> documents = json_value(body, "documents",
json_value(body, "texts", std::vector<std::string>()));
if (documents.empty()) {
- res_err(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
- return;
+ res->error(format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
+ return res;
}
int top_n = json_value(body, "top_n", (int)documents.size());
}
// wait for the results
- auto all_results = rd.wait_for_all(req.is_connection_closed);
+ auto all_results = rd.wait_for_all(req.should_stop);
// collect results
if (all_results.is_terminated) {
- return; // connection is closed
+ return res; // connection is closed
} else if (all_results.error) {
- res_err(res, all_results.error->to_json());
- return;
+ res->error(all_results.error->to_json());
+ return res;
} else {
for (auto & res : all_results.results) {
GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
documents,
top_n);
- res_ok(res, root);
+ res->ok(root);
+ return res;
};
- const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
+ server_http_context::handler_t get_lora_adapters = [this](const server_http_req &) {
+ auto res = std::make_unique<server_res_generator>(ctx_server);
json result = json::array();
const auto & loras = ctx_server.params_base.lora_adapters;
for (size_t i = 0; i < loras.size(); ++i) {
}
result.push_back(std::move(entry));
}
- res_ok(res, result);
- res.status = 200; // HTTP OK
+ res->ok(result);
+ return res;
};
- const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
+ server_http_context::handler_t post_lora_adapters = [this](const server_http_req & req) {
+ auto res = std::make_unique<server_res_generator>(ctx_server);
const json body = json::parse(req.body);
if (!body.is_array()) {
- res_err(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST));
- return;
+ res->error(format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST));
+ return res;
}
int task_id = ctx_server.queue_tasks.get_new_id();
ctx_server.queue_results.remove_waiting_task_id(task_id);
if (result->is_error()) {
- res_err(res, result->to_json());
- return;
+ res->error(result->to_json());
+ return res;
}
GGML_ASSERT(dynamic_cast<server_task_result_apply_lora*>(result.get()) != nullptr);
- res_ok(res, result->to_json());
+ res->ok(result->to_json());
+ return res;
};
- //
- // Router
- //
+private:
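+    // shared implementation for completion-like requests (completion, chat, infill):
+    // builds one task per input prompt, then either waits for all results (non-streaming)
+    // or sets up res->next to emit the results as SSE chunks (streaming)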
+ std::unique_ptr<server_res_generator> handle_completions_impl(
+ server_task_type type,
+ const json & data,
+ const std::vector<raw_buffer> & files,
+ const std::function<bool()> & should_stop,
+ oaicompat_type oaicompat) {
+ GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
+
+ auto res = std::make_unique<server_res_generator>(ctx_server);
+ auto completion_id = gen_chatcmplid();
+ auto & rd = res->rd;
- if (!params.webui) {
- LOG_INF("Web UI is disabled\n");
- } else {
- // register static assets routes
- if (!params.public_path.empty()) {
- // Set the base directory for serving static files
- bool is_found = svr->set_mount_point(params.api_prefix + "/", params.public_path);
- if (!is_found) {
- LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str());
- return 1;
+ try {
+ std::vector<server_task> tasks;
+
+ const auto & prompt = data.at("prompt");
+ // TODO: this log can become very long, put it behind a flag or think about a more compact format
+ //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
+
+ // process prompt
+ std::vector<server_tokens> inputs;
+
+ if (oaicompat && ctx_server.mctx != nullptr) {
+                // This is the case used by the OAI-compatible chat path with MTMD. TODO: it can be moved to the path below.
+ inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get<std::string>(), files));
+ } else {
+ // Everything else, including multimodal completions.
+ inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
}
+ tasks.reserve(inputs.size());
+ for (size_t i = 0; i < inputs.size(); i++) {
+ server_task task = server_task(type);
+
+ task.id = ctx_server.queue_tasks.get_new_id();
+ task.index = i;
+
+ task.tokens = std::move(inputs[i]);
+ task.params = server_task::params_from_json_cmpl(
+ ctx_server.ctx,
+ ctx_server.params_base,
+ data);
+ task.id_slot = json_value(data, "id_slot", -1);
+
+ // OAI-compat
+ task.params.oaicompat = oaicompat;
+ task.params.oaicompat_cmpl_id = completion_id;
+ // oaicompat_model is already populated by params_from_json_cmpl
+
+ tasks.push_back(std::move(task));
+ }
+
+ rd.post_tasks(std::move(tasks));
+ } catch (const std::exception & e) {
+ res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
+ return res;
+ }
+
+ bool stream = json_value(data, "stream", false);
+
+ if (!stream) {
+ // non-stream, wait for the results
+ auto all_results = rd.wait_for_all(should_stop);
+ if (all_results.is_terminated) {
+ return res; // connection is closed
+ } else if (all_results.error) {
+ res->error(all_results.error->to_json());
+ return res;
+ } else {
+ json arr = json::array();
+ for (auto & res : all_results.results) {
+ GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
+ arr.push_back(res->to_json());
+ }
+ // if single request, return single object instead of array
+ res->ok(arr.size() == 1 ? arr[0] : arr);
+ }
+
} else {
- // using embedded static index.html
- svr->Get(params.api_prefix + "/", [](const httplib::Request & req, httplib::Response & res) {
- if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) {
- res.set_content("Error: gzip is not supported by this browser", "text/plain");
+ // in streaming mode, the first error must be treated as non-stream response
+ // this is to match the OAI API behavior
+ // ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309
+ server_task_result_ptr first_result = rd.next(should_stop);
+ if (first_result == nullptr) {
+ return res; // connection is closed
+ } else if (first_result->is_error()) {
+ res->error(first_result->to_json());
+ return res;
+ } else {
+ GGML_ASSERT(
+ dynamic_cast<server_task_result_cmpl_partial*>(first_result.get()) != nullptr
+ || dynamic_cast<server_task_result_cmpl_final*>(first_result.get()) != nullptr
+ );
+ }
+
+ // next responses are streamed
+ res->data = format_sse(first_result->to_json()); // to be sent immediately
+ res->status = 200;
+ res->content_type = "text/event-stream";
+ res->next = [res_this = res.get(), oaicompat, &should_stop](std::string & output) -> bool {
+ if (should_stop()) {
+ SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
+ return false; // should_stop condition met
+ }
+
+ if (!res_this->data.empty()) {
+ // flush the first chunk
+ output = std::move(res_this->data);
+ res_this->data.clear();
+ return true;
+ }
+
+ server_response_reader & rd = res_this->rd;
+
+ // check if there is more data
+ if (!rd.has_next()) {
+ if (oaicompat != OAICOMPAT_TYPE_NONE) {
+ output = "data: [DONE]\n\n";
+ } else {
+ output = "";
+ }
+ SRV_DBG("%s", "all results received, terminating stream\n");
+ return false; // no more data, terminate
+ }
+
+ // receive subsequent results
+ auto result = rd.next(should_stop);
+ if (result == nullptr) {
+ SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
+ return false; // should_stop condition met
+ }
+
+ // send the results
+ json res_json = result->to_json();
+ if (result->is_error()) {
+ output = format_sse(json {{ "error", res_json }});
+ SRV_DBG("%s", "error received during streaming, terminating stream\n");
+ return false; // terminate on error
} else {
- res.set_header("Content-Encoding", "gzip");
- // COEP and COOP headers, required by pyodide (python interpreter)
- res.set_header("Cross-Origin-Embedder-Policy", "require-corp");
- res.set_header("Cross-Origin-Opener-Policy", "same-origin");
- res.set_content(reinterpret_cast<const char*>(index_html_gz), index_html_gz_len, "text/html; charset=utf-8");
+ GGML_ASSERT(
+ dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
+ || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
+ );
+ output = format_sse(res_json);
}
- return false;
- });
+
+ // has next data, continue
+ return true;
+ };
+ }
+
+ return res;
+ }
+
+ std::unique_ptr<server_res_generator> handle_slots_save(const server_http_req & req, int id_slot) {
+ auto res = std::make_unique<server_res_generator>(ctx_server);
+ const json request_data = json::parse(req.body);
+ std::string filename = request_data.at("filename");
+ if (!fs_validate_filename(filename)) {
+ res->error(format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
+ return res;
+ }
+ std::string filepath = params.slot_save_path + filename;
+
+ int task_id = ctx_server.queue_tasks.get_new_id();
+ {
+ server_task task(SERVER_TASK_TYPE_SLOT_SAVE);
+ task.id = task_id;
+ task.slot_action.slot_id = id_slot;
+ task.slot_action.filename = filename;
+ task.slot_action.filepath = filepath;
+
+ // TODO: use server_response_reader
+ ctx_server.queue_results.add_waiting_task_id(task_id);
+ ctx_server.queue_tasks.post(std::move(task));
+ }
+
+ server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
+ ctx_server.queue_results.remove_waiting_task_id(task_id);
+
+ if (result->is_error()) {
+ res->error(result->to_json());
+ return res;
+ }
+
+ res->ok(result->to_json());
+ return res;
+ }
+
+ std::unique_ptr<server_res_generator> handle_slots_restore(const server_http_req & req, int id_slot) {
+ auto res = std::make_unique<server_res_generator>(ctx_server);
+ const json request_data = json::parse(req.body);
+ std::string filename = request_data.at("filename");
+ if (!fs_validate_filename(filename)) {
+ res->error(format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
+ return res;
+ }
+ std::string filepath = params.slot_save_path + filename;
+
+ int task_id = ctx_server.queue_tasks.get_new_id();
+ {
+ server_task task(SERVER_TASK_TYPE_SLOT_RESTORE);
+ task.id = task_id;
+ task.slot_action.slot_id = id_slot;
+ task.slot_action.filename = filename;
+ task.slot_action.filepath = filepath;
+
+ // TODO: use server_response_reader
+ ctx_server.queue_results.add_waiting_task_id(task_id);
+ ctx_server.queue_tasks.post(std::move(task));
+ }
+
+ server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
+ ctx_server.queue_results.remove_waiting_task_id(task_id);
+
+ if (result->is_error()) {
+ res->error(result->to_json());
+ return res;
+ }
+
+ GGML_ASSERT(dynamic_cast<server_task_result_slot_save_load*>(result.get()) != nullptr);
+ res->ok(result->to_json());
+ return res;
+ }
+
+ std::unique_ptr<server_res_generator> handle_slots_erase(const server_http_req &, int id_slot) {
+ auto res = std::make_unique<server_res_generator>(ctx_server);
+ int task_id = ctx_server.queue_tasks.get_new_id();
+ {
+ server_task task(SERVER_TASK_TYPE_SLOT_ERASE);
+ task.id = task_id;
+ task.slot_action.slot_id = id_slot;
+
+ // TODO: use server_response_reader
+ ctx_server.queue_results.add_waiting_task_id(task_id);
+ ctx_server.queue_tasks.post(std::move(task));
+ }
+
+ server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
+ ctx_server.queue_results.remove_waiting_task_id(task_id);
+
+ if (result->is_error()) {
+ res->error(result->to_json());
+ return res;
+ }
+
+ GGML_ASSERT(dynamic_cast<server_task_result_slot_erase*>(result.get()) != nullptr);
+ res->ok(result->to_json());
+ return res;
+ }
+
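+    // shared implementation for the embeddings endpoints (native and OAI-compatible)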
+ std::unique_ptr<server_res_generator> handle_embeddings_impl(const server_http_req & req, oaicompat_type oaicompat) {
+ auto res = std::make_unique<server_res_generator>(ctx_server);
+ if (!ctx_server.params_base.embedding) {
+ res->error(format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+ return res;
+ }
+
+ if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
+ res->error(format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
+ return res;
+ }
+
+ const json body = json::parse(req.body);
+
+ // for the shape of input/content, see tokenize_input_prompts()
+ json prompt;
+ if (body.count("input") != 0) {
+ prompt = body.at("input");
+ } else if (body.contains("content")) {
+ oaicompat = OAICOMPAT_TYPE_NONE; // "content" field is not OAI compatible
+ prompt = body.at("content");
+ } else {
+ res->error(format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
+ return res;
+ }
+
+ bool use_base64 = false;
+ if (body.count("encoding_format") != 0) {
+ const std::string& format = body.at("encoding_format");
+ if (format == "base64") {
+ use_base64 = true;
+ } else if (format != "float") {
+ res->error(format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST));
+ return res;
+ }
+ }
+
+ auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
+ for (const auto & tokens : tokenized_prompts) {
+ // this check is necessary for models that do not add BOS token to the input
+ if (tokens.empty()) {
+ res->error(format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST));
+ return res;
+ }
+ }
+
+ int embd_normalize = 2; // default to Euclidean/L2 norm
+ if (body.count("embd_normalize") != 0) {
+ embd_normalize = body.at("embd_normalize");
+ if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
+ SRV_DBG("embd_normalize is not supported by pooling type %d, ignoring it\n", llama_pooling_type(ctx_server.ctx));
+ }
+ }
+
+ // create and queue the task
+ json responses = json::array();
+ server_response_reader rd(ctx_server);
+ {
+ std::vector<server_task> tasks;
+ for (size_t i = 0; i < tokenized_prompts.size(); i++) {
+ server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
+
+ task.id = ctx_server.queue_tasks.get_new_id();
+ task.index = i;
+ task.tokens = std::move(tokenized_prompts[i]);
+
+ // OAI-compat
+ task.params.oaicompat = oaicompat;
+ task.params.embd_normalize = embd_normalize;
+
+ tasks.push_back(std::move(task));
+ }
+ rd.post_tasks(std::move(tasks));
+ }
+
+ // wait for the results
+ auto all_results = rd.wait_for_all(req.should_stop);
+
+ // collect results
+ if (all_results.is_terminated) {
+ return res; // connection is closed
+ } else if (all_results.error) {
+ res->error(all_results.error->to_json());
+ return res;
+ } else {
+ for (auto & res : all_results.results) {
+ GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
+ responses.push_back(res->to_json());
+ }
+ }
+
+ // write JSON response
+ json root = oaicompat == OAICOMPAT_TYPE_EMBEDDING
+ ? format_embeddings_response_oaicompat(body, responses, use_base64)
+ : json(responses);
+ res->ok(root);
+ return res;
+ }
+};
+
+std::function<void(int)> shutdown_handler;
+std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;
+
+inline void signal_handler(int signal) {
+ if (is_terminating.test_and_set()) {
+ // in case it hangs, we can force terminate the server by hitting Ctrl+C twice
+ // this is for better developer experience, we can remove when the server is stable enough
+ fprintf(stderr, "Received second interrupt, terminating immediately.\n");
+ exit(1);
+ }
+
+ shutdown_handler(signal);
+}
+
+// wrapper function that handles exceptions and logs errors
+// this is to make sure handler_t never throws exceptions; instead, it returns an error response
+static server_http_context::handler_t ex_wrapper(server_http_context::handler_t func) {
+ return [func = std::move(func)](const server_http_req & req) -> server_http_res_ptr {
+ std::string message;
+ try {
+ return func(req);
+ } catch (const std::exception & e) {
+ message = e.what();
+ } catch (...) {
+ message = "unknown error";
}
+
+ auto res = std::make_unique<server_http_res>();
+ res->status = 500;
+ try {
+ json error_data = format_error_response(message, ERROR_TYPE_SERVER);
+ res->status = json_value(error_data, "code", 500);
+ res->data = safe_json_to_str({{ "error", error_data }});
+ LOG_WRN("got exception: %s\n", res->data.c_str());
+ } catch (const std::exception & e) {
+ LOG_ERR("got another exception: %s | while hanlding exception: %s\n", e.what(), message.c_str());
+ res->data = "Internal Server Error";
+ }
+ return res;
+ };
+}
+
+int main(int argc, char ** argv) {
+ // own arguments required by this example
+ common_params params;
+
+ if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) {
+ return 1;
}
+ // TODO: should we have a separate n_parallel parameter for the server?
+ // https://github.com/ggml-org/llama.cpp/pull/16736#discussion_r2483763177
+ // TODO: this is a common configuration that is suitable for most local use cases
+ // however, overriding the parameters is a bit confusing - figure out something more intuitive
+ if (params.n_parallel == 1 && params.kv_unified == false && !params.has_speculative()) {
+ LOG_WRN("%s: setting n_parallel = 4 and kv_unified = true (add -kvu to disable this)\n", __func__);
+
+ params.n_parallel = 4;
+ params.kv_unified = true;
+ }
+
+ common_init();
+
+ // struct that contains llama context and inference
+ server_context ctx_server;
+
+    // minimum prompt similarity required for a slot to be selected for reuse
+ ctx_server.slot_prompt_similarity = params.slot_prompt_similarity;
+
+ llama_backend_init();
+ llama_numa_init(params.numa);
+
+ LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads, params.cpuparams_batch.n_threads, std::thread::hardware_concurrency());
+ LOG_INF("\n");
+ LOG_INF("%s\n", common_params_get_system_info(params).c_str());
+ LOG_INF("\n");
+
+ server_http_context ctx_http;
+ if (!ctx_http.init(params)) {
+ LOG_ERR("%s: failed to initialize HTTP server\n", __func__);
+ return 1;
+ }
+
+ //
+ // Router
+ //
+
// register API routes
- svr->Get (params.api_prefix + "/health", handle_health); // public endpoint (no API key check)
- svr->Get (params.api_prefix + "/v1/health", handle_health); // public endpoint (no API key check)
- svr->Get (params.api_prefix + "/metrics", handle_metrics);
- svr->Get (params.api_prefix + "/props", handle_props);
- svr->Post(params.api_prefix + "/props", handle_props_change);
- svr->Post(params.api_prefix + "/api/show", handle_api_show);
- svr->Get (params.api_prefix + "/models", handle_models); // public endpoint (no API key check)
- svr->Get (params.api_prefix + "/v1/models", handle_models); // public endpoint (no API key check)
- svr->Get (params.api_prefix + "/api/tags", handle_models); // ollama specific endpoint. public endpoint (no API key check)
- svr->Post(params.api_prefix + "/completion", handle_completions); // legacy
- svr->Post(params.api_prefix + "/completions", handle_completions);
- svr->Post(params.api_prefix + "/v1/completions", handle_completions_oai);
- svr->Post(params.api_prefix + "/chat/completions", handle_chat_completions);
- svr->Post(params.api_prefix + "/v1/chat/completions", handle_chat_completions);
- svr->Post(params.api_prefix + "/api/chat", handle_chat_completions); // ollama specific endpoint
- svr->Post(params.api_prefix + "/infill", handle_infill);
- svr->Post(params.api_prefix + "/embedding", handle_embeddings); // legacy
- svr->Post(params.api_prefix + "/embeddings", handle_embeddings);
- svr->Post(params.api_prefix + "/v1/embeddings", handle_embeddings_oai);
- svr->Post(params.api_prefix + "/rerank", handle_rerank);
- svr->Post(params.api_prefix + "/reranking", handle_rerank);
- svr->Post(params.api_prefix + "/v1/rerank", handle_rerank);
- svr->Post(params.api_prefix + "/v1/reranking", handle_rerank);
- svr->Post(params.api_prefix + "/tokenize", handle_tokenize);
- svr->Post(params.api_prefix + "/detokenize", handle_detokenize);
- svr->Post(params.api_prefix + "/apply-template", handle_apply_template);
+ server_routes routes(params, ctx_server, ctx_http);
+
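+    // note: every handler is wrapped with ex_wrapper() so that exceptions are turned into JSON error responses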
+ ctx_http.get ("/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check)
+ ctx_http.get ("/v1/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check)
+ ctx_http.get ("/metrics", ex_wrapper(routes.get_metrics));
+ ctx_http.get ("/props", ex_wrapper(routes.get_props));
+ ctx_http.post("/props", ex_wrapper(routes.post_props));
+ ctx_http.post("/api/show", ex_wrapper(routes.get_api_show));
+ ctx_http.get ("/models", ex_wrapper(routes.get_models)); // public endpoint (no API key check)
+ ctx_http.get ("/v1/models", ex_wrapper(routes.get_models)); // public endpoint (no API key check)
+ ctx_http.get ("/api/tags", ex_wrapper(routes.get_models)); // ollama specific endpoint. public endpoint (no API key check)
+ ctx_http.post("/completion", ex_wrapper(routes.post_completions)); // legacy
+ ctx_http.post("/completions", ex_wrapper(routes.post_completions));
+ ctx_http.post("/v1/completions", ex_wrapper(routes.post_completions_oai));
+ ctx_http.post("/chat/completions", ex_wrapper(routes.post_chat_completions));
+ ctx_http.post("/v1/chat/completions", ex_wrapper(routes.post_chat_completions));
+ ctx_http.post("/api/chat", ex_wrapper(routes.post_chat_completions)); // ollama specific endpoint
+ ctx_http.post("/infill", ex_wrapper(routes.post_infill));
+ ctx_http.post("/embedding", ex_wrapper(routes.post_embeddings)); // legacy
+ ctx_http.post("/embeddings", ex_wrapper(routes.post_embeddings));
+ ctx_http.post("/v1/embeddings", ex_wrapper(routes.post_embeddings_oai));
+ ctx_http.post("/rerank", ex_wrapper(routes.post_rerank));
+ ctx_http.post("/reranking", ex_wrapper(routes.post_rerank));
+ ctx_http.post("/v1/rerank", ex_wrapper(routes.post_rerank));
+ ctx_http.post("/v1/reranking", ex_wrapper(routes.post_rerank));
+ ctx_http.post("/tokenize", ex_wrapper(routes.post_tokenize));
+ ctx_http.post("/detokenize", ex_wrapper(routes.post_detokenize));
+ ctx_http.post("/apply-template", ex_wrapper(routes.post_apply_template));
// LoRA adapters hotswap
- svr->Get (params.api_prefix + "/lora-adapters", handle_lora_adapters_list);
- svr->Post(params.api_prefix + "/lora-adapters", handle_lora_adapters_apply);
+ ctx_http.get ("/lora-adapters", ex_wrapper(routes.get_lora_adapters));
+ ctx_http.post("/lora-adapters", ex_wrapper(routes.post_lora_adapters));
// Save & load slots
- svr->Get (params.api_prefix + "/slots", handle_slots);
- svr->Post(params.api_prefix + "/slots/:id_slot", handle_slots_action);
+ ctx_http.get ("/slots", ex_wrapper(routes.get_slots));
+ ctx_http.post("/slots/:id_slot", ex_wrapper(routes.post_slots));
//
// Start the server
//
- if (params.n_threads_http < 1) {
- // +2 threads for monitoring endpoints
- params.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);
- }
- log_data["n_threads_http"] = std::to_string(params.n_threads_http);
-    svr->new_task_queue = [&params] { return new httplib::ThreadPool(params.n_threads_http); };
- // clean up function, to be called before exit
- auto clean_up = [&svr, &ctx_server]() {
+    // set up the cleanup function, to be called before exit
+ auto clean_up = [&ctx_http, &ctx_server]() {
SRV_INF("%s: cleaning up before exit...\n", __func__);
- svr->stop();
+ ctx_http.stop();
ctx_server.queue_results.terminate();
llama_backend_free();
};
- bool was_bound = false;
- bool is_sock = false;
- if (string_ends_with(std::string(params.hostname), ".sock")) {
- is_sock = true;
- LOG_INF("%s: setting address family to AF_UNIX\n", __func__);
- svr->set_address_family(AF_UNIX);
- // bind_to_port requires a second arg, any value other than 0 should
- // simply get ignored
- was_bound = svr->bind_to_port(params.hostname, 8080);
- } else {
- LOG_INF("%s: binding port with default address family\n", __func__);
- // bind HTTP listen port
- if (params.port == 0) {
- int bound_port = svr->bind_to_any_port(params.hostname);
- if ((was_bound = (bound_port >= 0))) {
- params.port = bound_port;
- }
- } else {
- was_bound = svr->bind_to_port(params.hostname, params.port);
- }
- }
-
- if (!was_bound) {
- LOG_ERR("%s: couldn't bind HTTP server socket, hostname: %s, port: %d\n", __func__, params.hostname.c_str(), params.port);
+ // start the HTTP server before loading the model to be able to serve /health requests
+ if (!ctx_http.start()) {
clean_up();
+ LOG_ERR("%s: exiting due to HTTP server error\n", __func__);
return 1;
}
- // run the HTTP server in a thread
- std::thread t([&]() { svr->listen_after_bind(); });
- svr->wait_until_ready();
-
- LOG_INF("%s: HTTP server is listening, hostname: %s, port: %d, http threads: %d\n", __func__, params.hostname.c_str(), params.port, params.n_threads_http);
-
// load the model
LOG_INF("%s: loading model\n", __func__);
if (!ctx_server.load_model(params)) {
clean_up();
- t.join();
+ if (ctx_http.thread.joinable()) {
+ ctx_http.thread.join();
+ }
LOG_ERR("%s: exiting due to model loading error\n", __func__);
return 1;
}
ctx_server.init();
- state.store(SERVER_STATE_READY);
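+    // signal the HTTP layer that the model is loaded and requests can now be fully served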
+ ctx_http.is_ready.store(true);
LOG_INF("%s: model loaded\n", __func__);
- // print sample chat example to make it clear which template is used
- LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
- common_chat_templates_source(ctx_server.chat_templates.get()),
- common_chat_format_example(ctx_server.chat_templates.get(), ctx_server.params_base.use_jinja, ctx_server.params_base.default_template_kwargs).c_str());
-
ctx_server.queue_tasks.on_new_task([&ctx_server](server_task && task) {
ctx_server.process_single_task(std::move(task));
});
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif
- LOG_INF("%s: server is listening on %s - starting the main loop\n", __func__,
- is_sock ? string_format("unix://%s", params.hostname.c_str()).c_str() :
- string_format("http://%s:%d", params.hostname.c_str(), params.port).c_str());
-
+ LOG_INF("%s: server is listening on %s\n", __func__, ctx_http.listening_address.c_str());
+ LOG_INF("%s: starting the main loop...\n", __func__);
// this call blocks the main thread until queue_tasks.terminate() is called
ctx_server.queue_tasks.start_loop();
clean_up();
- t.join();
+ if (ctx_http.thread.joinable()) {
+ ctx_http.thread.join();
+ }
llama_memory_breakdown_print(ctx_server.ctx);
return 0;