int32_t n_threads_http = -1;
std::string hostname = "127.0.0.1";
- std::string public_path = "examples/server/public";
+ std::string public_path = "";
std::string chat_template = "";
std::string system_prompt = "";
printf(" --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n");
printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
printf(" --port PORT port to listen (default (default: %d)\n", sparams.port);
- printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str());
+ printf(" --path PUBLIC_PATH path from which to serve static files (default: disabled)\n");
printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n");
printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
invalid_param = true;
break;
}
- sparams.api_keys.emplace_back(argv[i]);
+ sparams.api_keys.push_back(argv[i]);
} else if (arg == "--api-key-file") {
if (++i >= argc) {
invalid_param = true;
res.set_header("Access-Control-Allow-Headers", "*");
});
- svr->Get("/health", [&](const httplib::Request & req, httplib::Response & res) {
- server_state current_state = state.load();
- switch (current_state) {
- case SERVER_STATE_READY:
- {
- // request slots data using task queue
- server_task task;
- task.id = ctx_server.queue_tasks.get_new_id();
- task.type = SERVER_TASK_TYPE_METRICS;
- task.id_target = -1;
-
- ctx_server.queue_results.add_waiting_task_id(task.id);
- ctx_server.queue_tasks.post(task);
-
- // get the result
- server_task_result result = ctx_server.queue_results.recv(task.id);
- ctx_server.queue_results.remove_waiting_task_id(task.id);
-
- const int n_idle_slots = result.data["idle"];
- const int n_processing_slots = result.data["processing"];
-
- json health = {
- {"status", "ok"},
- {"slots_idle", n_idle_slots},
- {"slots_processing", n_processing_slots}
- };
-
- res.status = 200; // HTTP OK
- if (sparams.slots_endpoint && req.has_param("include_slots")) {
- health["slots"] = result.data["slots"];
- }
-
- if (n_idle_slots == 0) {
- health["status"] = "no slot available";
- if (req.has_param("fail_on_no_slot")) {
- res.status = 503; // HTTP Service Unavailable
- }
- }
-
- res.set_content(health.dump(), "application/json");
- break;
- }
- case SERVER_STATE_LOADING_MODEL:
- {
- res.set_content(R"({"status": "loading model"})", "application/json");
- res.status = 503; // HTTP Service Unavailable
- } break;
- case SERVER_STATE_ERROR:
- {
- res.set_content(R"({"status": "error", "error": "Model failed to load"})", "application/json");
- res.status = 500; // HTTP Internal Server Error
- } break;
- }
- });
-
- if (sparams.slots_endpoint) {
- svr->Get("/slots", [&](const httplib::Request &, httplib::Response & res) {
- // request slots data using task queue
- server_task task;
- task.id = ctx_server.queue_tasks.get_new_id();
- task.id_multi = -1;
- task.id_target = -1;
- task.type = SERVER_TASK_TYPE_METRICS;
-
- ctx_server.queue_results.add_waiting_task_id(task.id);
- ctx_server.queue_tasks.post(task);
-
- // get the result
- server_task_result result = ctx_server.queue_results.recv(task.id);
- ctx_server.queue_results.remove_waiting_task_id(task.id);
-
- res.set_content(result.data["slots"].dump(), "application/json");
- res.status = 200; // HTTP OK
- });
- }
-
- if (sparams.metrics_endpoint) {
- svr->Get("/metrics", [&](const httplib::Request &, httplib::Response & res) {
- // request slots data using task queue
- server_task task;
- task.id = ctx_server.queue_tasks.get_new_id();
- task.id_multi = -1;
- task.id_target = -1;
- task.type = SERVER_TASK_TYPE_METRICS;
- task.data.push_back({{"reset_bucket", true}});
-
- ctx_server.queue_results.add_waiting_task_id(task.id);
- ctx_server.queue_tasks.post(task);
-
- // get the result
- server_task_result result = ctx_server.queue_results.recv(task.id);
- ctx_server.queue_results.remove_waiting_task_id(task.id);
-
- json data = result.data;
-
- const uint64_t n_prompt_tokens_processed = data["n_prompt_tokens_processed"];
- const uint64_t t_prompt_processing = data["t_prompt_processing"];
-
- const uint64_t n_tokens_predicted = data["n_tokens_predicted"];
- const uint64_t t_tokens_generation = data["t_tokens_generation"];
-
- const int32_t kv_cache_used_cells = data["kv_cache_used_cells"];
-
- // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names
- json all_metrics_def = json {
- {"counter", {{
- {"name", "prompt_tokens_total"},
- {"help", "Number of prompt tokens processed."},
- {"value", (uint64_t) data["n_prompt_tokens_processed_total"]}
- }, {
- {"name", "prompt_seconds_total"},
- {"help", "Prompt process time"},
- {"value", (uint64_t) data["t_prompt_processing_total"] / 1.e3}
- }, {
- {"name", "tokens_predicted_total"},
- {"help", "Number of generation tokens processed."},
- {"value", (uint64_t) data["n_tokens_predicted_total"]}
- }, {
- {"name", "tokens_predicted_seconds_total"},
- {"help", "Predict process time"},
- {"value", (uint64_t) data["t_tokens_generation_total"] / 1.e3}
- }}},
- {"gauge", {{
- {"name", "prompt_tokens_seconds"},
- {"help", "Average prompt throughput in tokens/s."},
- {"value", n_prompt_tokens_processed ? 1.e3 / t_prompt_processing * n_prompt_tokens_processed : 0.}
- },{
- {"name", "predicted_tokens_seconds"},
- {"help", "Average generation throughput in tokens/s."},
- {"value", n_tokens_predicted ? 1.e3 / t_tokens_generation * n_tokens_predicted : 0.}
- },{
- {"name", "kv_cache_usage_ratio"},
- {"help", "KV-cache usage. 1 means 100 percent usage."},
- {"value", 1. * kv_cache_used_cells / params.n_ctx}
- },{
- {"name", "kv_cache_tokens"},
- {"help", "KV-cache tokens."},
- {"value", (uint64_t) data["kv_cache_tokens_count"]}
- },{
- {"name", "requests_processing"},
- {"help", "Number of request processing."},
- {"value", (uint64_t) data["processing"]}
- },{
- {"name", "requests_deferred"},
- {"help", "Number of request deferred."},
- {"value", (uint64_t) data["deferred"]}
- }}}
- };
-
- std::stringstream prometheus;
-
- for (const auto & el : all_metrics_def.items()) {
- const auto & type = el.key();
- const auto & metrics_def = el.value();
-
- for (const auto & metric_def : metrics_def) {
- const std::string name = metric_def["name"];
- const std::string help = metric_def["help"];
-
- auto value = json_value(metric_def, "value", 0.);
- prometheus << "# HELP llamacpp:" << name << " " << help << "\n"
- << "# TYPE llamacpp:" << name << " " << type << "\n"
- << "llamacpp:" << name << " " << value << "\n";
- }
- }
-
- const int64_t t_start = data["t_start"];
- res.set_header("Process-Start-Time-Unix", std::to_string(t_start));
-
- res.set_content(prometheus.str(), "text/plain; version=0.0.4");
- res.status = 200; // HTTP OK
- });
- }
-
svr->set_logger(log_server_request);
svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
return 1;
}
- // Set the base directory for serving static files
- svr->set_base_dir(sparams.public_path);
-
std::unordered_map<std::string, std::string> log_data;
log_data["hostname"] = sparams.hostname;
log_data["port"] = std::to_string(sparams.port);
if (sparams.api_keys.size() == 1) {
- log_data["api_key"] = "api_key: ****" + sparams.api_keys[0].substr(sparams.api_keys[0].length() - 4);
+ const auto & key = sparams.api_keys[0];
+ log_data["api_key"] = "api_key: ****" + key.substr(key.length() < 4 ? 0 : key.length() - 4);
} else if (sparams.api_keys.size() > 1) {
log_data["api_key"] = "api_key: " + std::to_string(sparams.api_keys.size()) + " keys loaded";
}
}
}
- // Middleware for API key validation
- auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool {
+ //
+ // Middlewares
+ //
+
+ auto middleware_validate_api_key = [&sparams](const httplib::Request & req, httplib::Response & res) {
+ // TODO: should we apply API key to all endpoints, including "/health" and "/models"?
+ static const std::set<std::string> protected_endpoints = {
+ "/props",
+ "/completion",
+ "/completions",
+ "/v1/completions",
+ "/chat/completions",
+ "/v1/chat/completions",
+ "/infill",
+ "/tokenize",
+ "/detokenize",
+ "/embedding",
+ "/embeddings",
+ "/v1/embeddings",
+ };
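+ // endpoints not listed above (e.g. "/health", "/slots", "/metrics", "/v1/models") remain publicly accessible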
+
// If API key is not set, skip validation
if (sparams.api_keys.empty()) {
return true;
}
+ // If path is not in protected_endpoints list, skip validation
+ if (protected_endpoints.find(req.path) == protected_endpoints.end()) {
+ return true;
+ }
+
// Check for API key in the header
auto auth_header = req.get_header_value("Authorization");
}
// API key is invalid or not provided
+ // TODO: make another middleware for CORS related logic
+ res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
res.set_content("Unauthorized: Invalid API Key", "text/plain; charset=utf-8");
res.status = 401; // Unauthorized
return false;
};
- // this is only called if no index.html is found in the public --path
- svr->Get("/", [](const httplib::Request &, httplib::Response & res) {
- res.set_content(reinterpret_cast<const char*>(&index_html), index_html_len, "text/html; charset=utf-8");
- return false;
+ // register server middlewares
+ svr->set_pre_routing_handler([&middleware_validate_api_key](const httplib::Request & req, httplib::Response & res) {
+ if (!middleware_validate_api_key(req, res)) {
+ return httplib::Server::HandlerResponse::Handled;
+ }
+ return httplib::Server::HandlerResponse::Unhandled;
});
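+ // example of an authenticated request against a protected endpoint
+ // (assuming the default host/port and the Bearer scheme used by the key check above):
+ //   curl http://localhost:8080/completion -H "Authorization: Bearer $API_KEY" -d '{"prompt": "..."}'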
- // this is only called if no index.js is found in the public --path
- svr->Get("/index.js", [](const httplib::Request &, httplib::Response & res) {
- res.set_content(reinterpret_cast<const char *>(&index_js), index_js_len, "text/javascript; charset=utf-8");
- return false;
- });
+ //
+ // Route handlers (or controllers)
+ //
- // this is only called if no index.html is found in the public --path
- svr->Get("/completion.js", [](const httplib::Request &, httplib::Response & res) {
- res.set_content(reinterpret_cast<const char*>(&completion_js), completion_js_len, "application/javascript; charset=utf-8");
- return false;
- });
+ const auto handle_health = [&](const httplib::Request & req, httplib::Response & res) {
+ server_state current_state = state.load();
+ switch (current_state) {
+ case SERVER_STATE_READY:
+ {
+ // request slots data using task queue
+ server_task task;
+ task.id = ctx_server.queue_tasks.get_new_id();
+ task.type = SERVER_TASK_TYPE_METRICS;
+ task.id_target = -1;
- // this is only called if no index.html is found in the public --path
- svr->Get("/json-schema-to-grammar.mjs", [](const httplib::Request &, httplib::Response & res) {
- res.set_content(reinterpret_cast<const char*>(&json_schema_to_grammar_mjs), json_schema_to_grammar_mjs_len, "application/javascript; charset=utf-8");
- return false;
- });
+ ctx_server.queue_results.add_waiting_task_id(task.id);
+ ctx_server.queue_tasks.post(task);
+
+ // get the result
+ server_task_result result = ctx_server.queue_results.recv(task.id);
+ ctx_server.queue_results.remove_waiting_task_id(task.id);
+
+ const int n_idle_slots = result.data["idle"];
+ const int n_processing_slots = result.data["processing"];
+
+ json health = {
+ {"status", "ok"},
+ {"slots_idle", n_idle_slots},
+ {"slots_processing", n_processing_slots}
+ };
+
+ res.status = 200; // HTTP OK
+ if (sparams.slots_endpoint && req.has_param("include_slots")) {
+ health["slots"] = result.data["slots"];
+ }
- svr->Get("/props", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
+ if (n_idle_slots == 0) {
+ health["status"] = "no slot available";
+ if (req.has_param("fail_on_no_slot")) {
+ res.status = 503; // HTTP Service Unavailable
+ }
+ }
+
+ res.set_content(health.dump(), "application/json");
+ break;
+ }
+ case SERVER_STATE_LOADING_MODEL:
+ {
+ res.set_content(R"({"status": "loading model"})", "application/json");
+ res.status = 503; // HTTP Service Unavailable
+ } break;
+ case SERVER_STATE_ERROR:
+ {
+ res.set_content(R"({"status": "error", "error": "Model failed to load"})", "application/json");
+ res.status = 500; // HTTP Internal Server Error
+ } break;
+ }
+ };
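+ // usage example: GET /health?include_slots&fail_on_no_slot
+ // -> includes per-slot data (if the slots endpoint is enabled) and returns 503 when no slot is idle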
+
+ const auto handle_slots = [&](const httplib::Request &, httplib::Response & res) {
+ if (!sparams.slots_endpoint) {
+ res.status = 501;
+ res.set_content("This server does not support slots endpoint.", "text/plain; charset=utf-8");
+ return;
+ }
+
+ // request slots data using task queue
+ server_task task;
+ task.id = ctx_server.queue_tasks.get_new_id();
+ task.id_multi = -1;
+ task.id_target = -1;
+ task.type = SERVER_TASK_TYPE_METRICS;
+
+ ctx_server.queue_results.add_waiting_task_id(task.id);
+ ctx_server.queue_tasks.post(task);
+
+ // get the result
+ server_task_result result = ctx_server.queue_results.recv(task.id);
+ ctx_server.queue_results.remove_waiting_task_id(task.id);
+
+ res.set_content(result.data["slots"].dump(), "application/json");
+ res.status = 200; // HTTP OK
+ };
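+ // note: unlike the previous conditional registration, the route is always registered
+ // and reports 501 Not Implemented when the endpoint is disabled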
+
+ const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
+ if (!sparams.metrics_endpoint) {
+ res.status = 501;
+ res.set_content("This server does not support metrics endpoint.", "text/plain; charset=utf-8");
+ return;
+ }
+
+ // request slots data using task queue
+ server_task task;
+ task.id = ctx_server.queue_tasks.get_new_id();
+ task.id_multi = -1;
+ task.id_target = -1;
+ task.type = SERVER_TASK_TYPE_METRICS;
+ task.data.push_back({{"reset_bucket", true}});
+
+ ctx_server.queue_results.add_waiting_task_id(task.id);
+ ctx_server.queue_tasks.post(task);
+
+ // get the result
+ server_task_result result = ctx_server.queue_results.recv(task.id);
+ ctx_server.queue_results.remove_waiting_task_id(task.id);
+
+ json data = result.data;
+
+ const uint64_t n_prompt_tokens_processed = data["n_prompt_tokens_processed"];
+ const uint64_t t_prompt_processing = data["t_prompt_processing"];
+
+ const uint64_t n_tokens_predicted = data["n_tokens_predicted"];
+ const uint64_t t_tokens_generation = data["t_tokens_generation"];
+
+ const int32_t kv_cache_used_cells = data["kv_cache_used_cells"];
+
+ // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names
+ json all_metrics_def = json {
+ {"counter", {{
+ {"name", "prompt_tokens_total"},
+ {"help", "Number of prompt tokens processed."},
+ {"value", (uint64_t) data["n_prompt_tokens_processed_total"]}
+ }, {
+ {"name", "prompt_seconds_total"},
+ {"help", "Prompt process time"},
+ {"value", (uint64_t) data["t_prompt_processing_total"] / 1.e3}
+ }, {
+ {"name", "tokens_predicted_total"},
+ {"help", "Number of generation tokens processed."},
+ {"value", (uint64_t) data["n_tokens_predicted_total"]}
+ }, {
+ {"name", "tokens_predicted_seconds_total"},
+ {"help", "Predict process time"},
+ {"value", (uint64_t) data["t_tokens_generation_total"] / 1.e3}
+ }}},
+ {"gauge", {{
+ {"name", "prompt_tokens_seconds"},
+ {"help", "Average prompt throughput in tokens/s."},
+ {"value", n_prompt_tokens_processed ? 1.e3 / t_prompt_processing * n_prompt_tokens_processed : 0.}
+ },{
+ {"name", "predicted_tokens_seconds"},
+ {"help", "Average generation throughput in tokens/s."},
+ {"value", n_tokens_predicted ? 1.e3 / t_tokens_generation * n_tokens_predicted : 0.}
+ },{
+ {"name", "kv_cache_usage_ratio"},
+ {"help", "KV-cache usage. 1 means 100 percent usage."},
+ {"value", 1. * kv_cache_used_cells / params.n_ctx}
+ },{
+ {"name", "kv_cache_tokens"},
+ {"help", "KV-cache tokens."},
+ {"value", (uint64_t) data["kv_cache_tokens_count"]}
+ },{
+ {"name", "requests_processing"},
+ {"help", "Number of request processing."},
+ {"value", (uint64_t) data["processing"]}
+ },{
+ {"name", "requests_deferred"},
+ {"help", "Number of request deferred."},
+ {"value", (uint64_t) data["deferred"]}
+ }}}
+ };
+
+ std::stringstream prometheus;
+
+ for (const auto & el : all_metrics_def.items()) {
+ const auto & type = el.key();
+ const auto & metrics_def = el.value();
+
+ for (const auto & metric_def : metrics_def) {
+ const std::string name = metric_def["name"];
+ const std::string help = metric_def["help"];
+
+ auto value = json_value(metric_def, "value", 0.);
+ prometheus << "# HELP llamacpp:" << name << " " << help << "\n"
+ << "# TYPE llamacpp:" << name << " " << type << "\n"
+ << "llamacpp:" << name << " " << value << "\n";
+ }
+ }
+
+ const int64_t t_start = data["t_start"];
+ res.set_header("Process-Start-Time-Unix", std::to_string(t_start));
+
+ res.set_content(prometheus.str(), "text/plain; version=0.0.4");
+ res.status = 200; // HTTP OK
+ };
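+ // the loop above emits the Prometheus text exposition format, e.g.:
+ //   # HELP llamacpp:prompt_tokens_total Number of prompt tokens processed.
+ //   # TYPE llamacpp:prompt_tokens_total counter
+ //   llamacpp:prompt_tokens_total 1024
+ // (the value shown is illustrative)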
+
+ const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json data = {
{ "user_name", ctx_server.name_user.c_str() },
};
res.set_content(data.dump(), "application/json; charset=utf-8");
- });
+ };
- const auto completions = [&ctx_server, &validate_api_key](const httplib::Request & req, httplib::Response & res) {
+ const auto handle_completions = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
- if (!validate_api_key(req, res)) {
- return;
- }
json data = json::parse(req.body);
}
};
- svr->Post("/completion", completions); // legacy
- svr->Post("/completions", completions);
- svr->Post("/v1/completions", completions);
-
- svr->Get("/v1/models", [¶ms, &model_meta](const httplib::Request & req, httplib::Response & res) {
+ const auto handle_models = [¶ms, &model_meta](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json models = {
};
res.set_content(models.dump(), "application/json; charset=utf-8");
- });
+ };
- const auto chat_completions = [&ctx_server, &validate_api_key, &sparams](const httplib::Request & req, httplib::Response & res) {
+ const auto handle_chat_completions = [&ctx_server, &sparams](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
- if (!validate_api_key(req, res)) {
- return;
- }
-
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), sparams.chat_template);
const int id_task = ctx_server.queue_tasks.get_new_id();
}
};
- svr->Post("/chat/completions", chat_completions);
- svr->Post("/v1/chat/completions", chat_completions);
-
- svr->Post("/infill", [&ctx_server, &validate_api_key](const httplib::Request & req, httplib::Response & res) {
+ const auto handle_infill = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
- if (!validate_api_key(req, res)) {
- return;
- }
json data = json::parse(req.body);
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
}
- });
-
- svr->Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) {
- return res.set_content("", "application/json; charset=utf-8");
- });
+ };
- svr->Post("/tokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
+ const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const json body = json::parse(req.body);
}
const json data = format_tokenizer_response(tokens);
return res.set_content(data.dump(), "application/json; charset=utf-8");
- });
+ };
- svr->Post("/detokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
+ const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const json body = json::parse(req.body);
const json data = format_detokenized_response(content);
return res.set_content(data.dump(), "application/json; charset=utf-8");
- });
+ };
- svr->Post("/embedding", [¶ms, &ctx_server](const httplib::Request & req, httplib::Response & res) {
+ const auto handle_embeddings = [¶ms, &ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!params.embedding) {
res.status = 501;
}
const json body = json::parse(req.body);
+ bool is_openai = false;
- json prompt;
- if (body.count("content") != 0) {
- prompt = body["content"];
+ // an input prompt can be a string or a list of tokens (integers)
+ std::vector<json> prompts;
+ if (body.count("input") != 0) {
+ is_openai = true;
+ if (body["input"].is_array()) {
+ // support multiple prompts
+ for (const json & elem : body["input"]) {
+ prompts.push_back(elem);
+ }
+ } else {
+ // single input prompt
+ prompts.push_back(body["input"]);
+ }
+ } else if (body.count("content") != 0) {
+ // only support single prompt here
+ std::string content = body["content"];
+ prompts.push_back(content);
} else {
- prompt = "";
- }
-
- // create and queue the task
- const int id_task = ctx_server.queue_tasks.get_new_id();
-
- ctx_server.queue_results.add_waiting_task_id(id_task);
- ctx_server.request_completion(id_task, -1, { {"prompt", prompt}, { "n_predict", 0} }, false, true);
-
- // get the result
- server_task_result result = ctx_server.queue_results.recv(id_task);
- ctx_server.queue_results.remove_waiting_task_id(id_task);
-
- // send the result
- return res.set_content(result.data.dump(), "application/json; charset=utf-8");
- });
-
- svr->Post("/v1/embeddings", [¶ms, &ctx_server](const httplib::Request & req, httplib::Response & res) {
- res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
- if (!params.embedding) {
- res.status = 501;
- res.set_content("This server does not support embeddings. Start it with `--embeddings`", "text/plain; charset=utf-8");
- return;
+ // TODO @ngxson : should return an error here
+ prompts.push_back("");
}
- const json body = json::parse(req.body);
-
- json prompt;
- if (body.count("input") != 0) {
- prompt = body["input"];
- if (prompt.is_array()) {
- json data = json::array();
-
- int i = 0;
- for (const json & elem : prompt) {
- const int id_task = ctx_server.queue_tasks.get_new_id();
-
- ctx_server.queue_results.add_waiting_task_id(id_task);
- ctx_server.request_completion(id_task, -1, { {"prompt", elem}, { "n_predict", 0} }, false, true);
-
- // get the result
- server_task_result result = ctx_server.queue_results.recv(id_task);
- ctx_server.queue_results.remove_waiting_task_id(id_task);
-
- json embedding = json{
- {"embedding", json_value(result.data, "embedding", json::array())},
- {"index", i++},
- {"object", "embedding"}
- };
+ // process all prompts
+ json responses = json::array();
+ for (auto & prompt : prompts) {
+ // TODO @ngxson : maybe support multitask for this endpoint?
+ // create and queue the task
+ const int id_task = ctx_server.queue_tasks.get_new_id();
- data.push_back(embedding);
- }
-
- json result = format_embeddings_response_oaicompat(body, data);
+ ctx_server.queue_results.add_waiting_task_id(id_task);
+ ctx_server.request_completion(id_task, -1, { {"prompt", prompt}, { "n_predict", 0}}, false, true);
- return res.set_content(result.dump(), "application/json; charset=utf-8");
+ // get the result
+ server_task_result result = ctx_server.queue_results.recv(id_task);
+ ctx_server.queue_results.remove_waiting_task_id(id_task);
+ responses.push_back(result.data);
+ }
+
+ // write JSON response
+ json root;
+ if (is_openai) {
+ json res_oai = json::array();
+ int i = 0;
+ for (auto & elem : responses) {
+ res_oai.push_back(json{
+ {"embedding", json_value(elem, "embedding", json::array())},
+ {"index", i++},
+ {"object", "embedding"}
+ });
}
+ root = format_embeddings_response_oaicompat(body, res_oai);
} else {
- prompt = "";
+ root = responses[0];
}
+ return res.set_content(root.dump(), "application/json; charset=utf-8");
+ };
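+ // usage example (OpenAI-compatible): POST /v1/embeddings with body
+ //   {"input": ["hello", "world"]}
+ // returns one {"embedding", "index", "object"} entry per prompt; the legacy
+ // /embedding endpoint with {"content": "..."} returns the raw result instead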
- // create and queue the task
- const int id_task = ctx_server.queue_tasks.get_new_id();
-
- ctx_server.queue_results.add_waiting_task_id(id_task);
- ctx_server.request_completion(id_task, -1, { {"prompt", prompt}, { "n_predict", 0}}, false, true);
+ //
+ // Router
+ //
- // get the result
- server_task_result result = ctx_server.queue_results.recv(id_task);
- ctx_server.queue_results.remove_waiting_task_id(id_task);
-
- json data = json::array({json{
- {"embedding", json_value(result.data, "embedding", json::array())},
- {"index", 0},
- {"object", "embedding"}
- }}
- );
+ // register static assets routes
+ if (!sparams.public_path.empty()) {
+ // Set the base directory for serving static files
+ svr->set_base_dir(sparams.public_path);
+ }
- json root = format_embeddings_response_oaicompat(body, data);
+ // using embedded static files
+ auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
+ return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
+ res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
+ return false;
+ };
+ };
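+ // e.g. handle_static_file(index_html, index_html_len, "text/html; charset=utf-8")
+ // binds one embedded asset (pointer + length + mime type) into its own route handler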
- return res.set_content(root.dump(), "application/json; charset=utf-8");
+ svr->Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) {
+ // TODO @ngxson : I have no idea what it is... maybe this is redundant?
+ return res.set_content("", "application/json; charset=utf-8");
});
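+ // note: the embedded assets below are only served when the requested file is not found under --path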
+ svr->Get("/", handle_static_file(index_html, index_html_len, "text/html; charset=utf-8"));
+ svr->Get("/index.js", handle_static_file(index_js, index_js_len, "text/javascript; charset=utf-8"));
+ svr->Get("/completion.js", handle_static_file(completion_js, completion_js_len, "text/javascript; charset=utf-8"));
+ svr->Get("/json-schema-to-grammar.mjs", handle_static_file(
+ json_schema_to_grammar_mjs, json_schema_to_grammar_mjs_len, "text/javascript; charset=utf-8"));
+
+ // register API routes
+ svr->Get ("/health", handle_health);
+ svr->Get ("/slots", handle_slots);
+ svr->Get ("/metrics", handle_metrics);
+ svr->Get ("/props", handle_props);
+ svr->Get ("/v1/models", handle_models);
+ svr->Post("/completion", handle_completions); // legacy
+ svr->Post("/completions", handle_completions);
+ svr->Post("/v1/completions", handle_completions);
+ svr->Post("/chat/completions", handle_chat_completions);
+ svr->Post("/v1/chat/completions", handle_chat_completions);
+ svr->Post("/infill", handle_infill);
+ svr->Post("/embedding", handle_embeddings); // legacy
+ svr->Post("/embeddings", handle_embeddings);
+ svr->Post("/v1/embeddings", handle_embeddings);
+ svr->Post("/tokenize", handle_tokenize);
+ svr->Post("/detokenize", handle_detokenize);
+ //
+ // Start the server
+ //
if (sparams.n_threads_http < 1) {
// +2 threads for monitoring endpoints
sparams.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);