struct server_params
{
std::string hostname = "127.0.0.1";
+ std::string api_key;
std::string public_path = "examples/server/public";
int32_t port = 8080;
int32_t read_timeout = 600;
printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
printf(" --port PORT port to listen (default (default: %d)\n", sparams.port);
printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str());
+ printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n");
printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
printf(" --embedding enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel);
}
sparams.public_path = argv[i];
}
+ else if (arg == "--api-key")
+ {
+ if (++i >= argc)
+ {
+ invalid_param = true;
+ break;
+ }
+ sparams.api_key = argv[i];
+ }
else if (arg == "--timeout" || arg == "-to")
{
if (++i >= argc)
httplib::Server svr;
+ // Middleware for API key validation
+ auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool {
+ // If API key is not set, skip validation
+ if (sparams.api_key.empty()) {
+ return true;
+ }
+
+ // Check for API key in the header
+ auto auth_header = req.get_header_value("Authorization");
+ std::string prefix = "Bearer ";
+ if (auth_header.substr(0, prefix.size()) == prefix) {
+ std::string received_api_key = auth_header.substr(prefix.size());
+ if (received_api_key == sparams.api_key) {
+ return true; // API key is valid
+ }
+ }
+
+ // API key is invalid or not provided
+ res.set_content("Unauthorized: Invalid API Key", "text/plain");
+ res.status = 401; // Unauthorized
+
+ LOG_WARNING("Unauthorized: Invalid API Key", {});
+
+ return false;
+ };
+
svr.set_default_headers({{"Server", "llama.cpp"},
{"Access-Control-Allow-Origin", "*"},
{"Access-Control-Allow-Headers", "content-type"}});
res.set_content(data.dump(), "application/json");
});
- svr.Post("/completion", [&llama](const httplib::Request &req, httplib::Response &res)
+ svr.Post("/completion", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
{
+ if (!validate_api_key(req, res)) {
+ return;
+ }
json data = json::parse(req.body);
const int task_id = llama.request_completion(data, false, false, -1);
if (!json_value(data, "stream", false)) {
});
// TODO: add mount point without "/v1" prefix -- how?
- svr.Post("/v1/chat/completions", [&llama](const httplib::Request &req, httplib::Response &res)
+ svr.Post("/v1/chat/completions", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
{
+ if (!validate_api_key(req, res)) {
+ return;
+ }
json data = oaicompat_completion_params_parse(json::parse(req.body));
const int task_id = llama.request_completion(data, false, false, -1);
}
});
- svr.Post("/infill", [&llama](const httplib::Request &req, httplib::Response &res)
+ svr.Post("/infill", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
{
+ if (!validate_api_key(req, res)) {
+ return;
+ }
json data = json::parse(req.body);
const int task_id = llama.request_completion(data, true, false, -1);
if (!json_value(data, "stream", false)) {
svr.set_error_handler([](const httplib::Request &, httplib::Response &res)
{
+ if (res.status == 401)
+ {
+ res.set_content("Unauthorized", "text/plain");
+ }
if (res.status == 400)
{
res.set_content("Invalid request", "text/plain");
}
- else if (res.status != 500)
+ else if (res.status == 404)
{
res.set_content("File Not Found", "text/plain");
res.status = 404;
// to make it ctrl+clickable:
LOG_TEE("\nllama server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port);
- LOG_INFO("HTTP server listening", {
- {"hostname", sparams.hostname},
- {"port", sparams.port},
- });
+ std::unordered_map<std::string, std::string> log_data;
+ log_data["hostname"] = sparams.hostname;
+ log_data["port"] = std::to_string(sparams.port);
+
+ if (!sparams.api_key.empty()) {
+ log_data["api_key"] = "api_key: ****" + sparams.api_key.substr(sparams.api_key.length() - 4);
+ }
+ LOG_INFO("HTTP server listening", log_data);
// run the HTTP server in a thread - see comment below
std::thread t([&]()
{