#include <mutex>
#include <thread>
#include <signal.h>
+#include <memory>
using json = nlohmann::json;
std::vector<std::string> api_keys;
+#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
+ std::string ssl_key_file = "";
+ std::string ssl_cert_file = "";
+#endif
+
bool slots_endpoint = true;
bool metrics_endpoint = false;
};
printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str());
printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n");
printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
+#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
+ printf(" --ssl-key-file FNAME path to file a PEM-encoded SSL private key\n");
+ printf(" --ssl-cert-file FNAME path to file a PEM-encoded SSL certificate\n");
+#endif
printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
printf(" --embeddings enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel);
}
}
key_file.close();
- } else if (arg == "--timeout" || arg == "-to") {
+
+ }
+#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
+ else if (arg == "--ssl-key-file") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ sparams.ssl_key_file = argv[i];
+ } else if (arg == "--ssl-cert-file") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ sparams.ssl_cert_file = argv[i];
+ }
+#endif
+ else if (arg == "--timeout" || arg == "-to") {
if (++i >= argc) {
invalid_param = true;
break;
{"system_info", llama_print_system_info()},
});
- httplib::Server svr;
+ std::unique_ptr<httplib::Server> svr;
+#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
+ if (sparams.ssl_key_file != "" && sparams.ssl_cert_file != "") {
+ LOG_INFO("Running with SSL", {{"key", sparams.ssl_key_file}, {"cert", sparams.ssl_cert_file}});
+ svr.reset(
+ new httplib::SSLServer(sparams.ssl_cert_file.c_str(), sparams.ssl_key_file.c_str())
+ );
+ } else {
+ LOG_INFO("Running without SSL", {});
+ svr.reset(new httplib::Server());
+ }
+#else
+ svr.reset(new httplib::Server());
+#endif
std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};
- svr.set_default_headers({{"Server", "llama.cpp"}});
+ svr->set_default_headers({{"Server", "llama.cpp"}});
// CORS preflight
- svr.Options(R"(.*)", [](const httplib::Request & req, httplib::Response & res) {
+ svr->Options(R"(.*)", [](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
res.set_header("Access-Control-Allow-Credentials", "true");
res.set_header("Access-Control-Allow-Methods", "POST");
res.set_header("Access-Control-Allow-Headers", "*");
});
- svr.Get("/health", [&](const httplib::Request & req, httplib::Response & res) {
+ svr->Get("/health", [&](const httplib::Request & req, httplib::Response & res) {
server_state current_state = state.load();
switch (current_state) {
case SERVER_STATE_READY:
});
if (sparams.slots_endpoint) {
- svr.Get("/slots", [&](const httplib::Request &, httplib::Response & res) {
+ svr->Get("/slots", [&](const httplib::Request &, httplib::Response & res) {
// request slots data using task queue
server_task task;
task.id = ctx_server.queue_tasks.get_new_id();
}
if (sparams.metrics_endpoint) {
- svr.Get("/metrics", [&](const httplib::Request &, httplib::Response & res) {
+ svr->Get("/metrics", [&](const httplib::Request &, httplib::Response & res) {
// request slots data using task queue
server_task task;
task.id = ctx_server.queue_tasks.get_new_id();
});
}
- svr.set_logger(log_server_request);
+ svr->set_logger(log_server_request);
- svr.set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
+ svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
const char fmt[] = "500 Internal Server Error\n%s";
char buf[BUFSIZ];
res.status = 500;
});
- svr.set_error_handler([](const httplib::Request &, httplib::Response & res) {
+ svr->set_error_handler([](const httplib::Request &, httplib::Response & res) {
if (res.status == 401) {
res.set_content("Unauthorized", "text/plain; charset=utf-8");
}
});
// set timeouts and change hostname and port
- svr.set_read_timeout (sparams.read_timeout);
- svr.set_write_timeout(sparams.write_timeout);
+ svr->set_read_timeout (sparams.read_timeout);
+ svr->set_write_timeout(sparams.write_timeout);
- if (!svr.bind_to_port(sparams.hostname, sparams.port)) {
+ if (!svr->bind_to_port(sparams.hostname, sparams.port)) {
fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", sparams.hostname.c_str(), sparams.port);
return 1;
}
// Set the base directory for serving static files
- svr.set_base_dir(sparams.public_path);
+ svr->set_base_dir(sparams.public_path);
std::unordered_map<std::string, std::string> log_data;
};
// this is only called if no index.html is found in the public --path
- svr.Get("/", [](const httplib::Request &, httplib::Response & res) {
+ svr->Get("/", [](const httplib::Request &, httplib::Response & res) {
res.set_content(reinterpret_cast<const char*>(&index_html), index_html_len, "text/html; charset=utf-8");
return false;
});
// this is only called if no index.js is found in the public --path
- svr.Get("/index.js", [](const httplib::Request &, httplib::Response & res) {
+ svr->Get("/index.js", [](const httplib::Request &, httplib::Response & res) {
res.set_content(reinterpret_cast<const char *>(&index_js), index_js_len, "text/javascript; charset=utf-8");
return false;
});
// this is only called if no index.html is found in the public --path
- svr.Get("/completion.js", [](const httplib::Request &, httplib::Response & res) {
+ svr->Get("/completion.js", [](const httplib::Request &, httplib::Response & res) {
res.set_content(reinterpret_cast<const char*>(&completion_js), completion_js_len, "application/javascript; charset=utf-8");
return false;
});
// this is only called if no index.html is found in the public --path
- svr.Get("/json-schema-to-grammar.mjs", [](const httplib::Request &, httplib::Response & res) {
+ svr->Get("/json-schema-to-grammar.mjs", [](const httplib::Request &, httplib::Response & res) {
res.set_content(reinterpret_cast<const char*>(&json_schema_to_grammar_mjs), json_schema_to_grammar_mjs_len, "application/javascript; charset=utf-8");
return false;
});
- svr.Get("/props", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
+ svr->Get("/props", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json data = {
{ "user_name", ctx_server.name_user.c_str() },
}
};
- svr.Post("/completion", completions); // legacy
- svr.Post("/completions", completions);
- svr.Post("/v1/completions", completions);
+ svr->Post("/completion", completions); // legacy
+ svr->Post("/completions", completions);
+ svr->Post("/v1/completions", completions);
- svr.Get("/v1/models", [¶ms, &model_meta](const httplib::Request & req, httplib::Response & res) {
+ svr->Get("/v1/models", [¶ms, &model_meta](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json models = {
}
};
- svr.Post("/chat/completions", chat_completions);
- svr.Post("/v1/chat/completions", chat_completions);
+ svr->Post("/chat/completions", chat_completions);
+ svr->Post("/v1/chat/completions", chat_completions);
- svr.Post("/infill", [&ctx_server, &validate_api_key](const httplib::Request & req, httplib::Response & res) {
+ svr->Post("/infill", [&ctx_server, &validate_api_key](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!validate_api_key(req, res)) {
return;
}
});
- svr.Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) {
+ svr->Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) {
return res.set_content("", "application/json; charset=utf-8");
});
- svr.Post("/tokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
+ svr->Post("/tokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const json body = json::parse(req.body);
return res.set_content(data.dump(), "application/json; charset=utf-8");
});
- svr.Post("/detokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
+ svr->Post("/detokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const json body = json::parse(req.body);
return res.set_content(data.dump(), "application/json; charset=utf-8");
});
- svr.Post("/embedding", [¶ms, &ctx_server](const httplib::Request & req, httplib::Response & res) {
+ svr->Post("/embedding", [¶ms, &ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!params.embedding) {
res.status = 501;
return res.set_content(result.data.dump(), "application/json; charset=utf-8");
});
- svr.Post("/v1/embeddings", [¶ms, &ctx_server](const httplib::Request & req, httplib::Response & res) {
+ svr->Post("/v1/embeddings", [¶ms, &ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!params.embedding) {
res.status = 501;
sparams.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);
}
log_data["n_threads_http"] = std::to_string(sparams.n_threads_http);
- svr.new_task_queue = [&sparams] { return new httplib::ThreadPool(sparams.n_threads_http); };
+ svr->new_task_queue = [&sparams] { return new httplib::ThreadPool(sparams.n_threads_http); };
LOG_INFO("HTTP server listening", log_data);
// run the HTTP server in a thread - see comment below
std::thread t([&]() {
- if (!svr.listen_after_bind()) {
+ if (!svr->listen_after_bind()) {
state.store(SERVER_STATE_ERROR);
return 1;
}
ctx_server.queue_tasks.start_loop();
- svr.stop();
+ svr->stop();
t.join();
llama_backend_free();