]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : add SSL support (#5926)
authorGabe Goodhart <redacted>
Sat, 9 Mar 2024 09:57:09 +0000 (02:57 -0700)
committerGitHub <redacted>
Sat, 9 Mar 2024 09:57:09 +0000 (11:57 +0200)
* add cmake build toggle to enable ssl support in server

Signed-off-by: Gabe Goodhart <redacted>
* add flags for ssl key/cert files and use SSLServer if set

All SSL setup is hidden behind CPPHTTPLIB_OPENSSL_SUPPORT in the same
way that the base httlib hides the SSL support

Signed-off-by: Gabe Goodhart <redacted>
* Update readme for SSL support in server

Signed-off-by: Gabe Goodhart <redacted>
* Add LLAMA_SERVER_SSL variable setup to top-level Makefile

Signed-off-by: Gabe Goodhart <redacted>
---------

Signed-off-by: Gabe Goodhart <redacted>
Makefile
examples/server/CMakeLists.txt
examples/server/README.md
examples/server/server.cpp

index efce10bb8bd7e819144f87b878122b762b4b045f..aea96922272980ac5679fa26fef8b97c323e24b7 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -201,6 +201,10 @@ ifdef LLAMA_SERVER_VERBOSE
        MK_CPPFLAGS += -DSERVER_VERBOSE=$(LLAMA_SERVER_VERBOSE)
 endif
 
+ifdef LLAMA_SERVER_SSL
+       MK_CPPFLAGS += -DCPPHTTPLIB_OPENSSL_SUPPORT
+       MK_LDFLAGS += -lssl -lcrypto
+endif
 
 ifdef LLAMA_CODE_COVERAGE
        MK_CXXFLAGS += -fprofile-arcs -ftest-coverage -dumpbase ''
index c21eba634310b861a2d28b62c29a78e3a0615a76..f94de1e99b7e95cff1359a22367d4e27b0cc61ea 100644 (file)
@@ -1,5 +1,6 @@
 set(TARGET server)
 option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON)
+option(LLAMA_SERVER_SSL "Build SSL support for the server" OFF)
 include_directories(${CMAKE_CURRENT_SOURCE_DIR})
 add_executable(${TARGET} server.cpp utils.hpp json.hpp httplib.h)
 install(TARGETS ${TARGET} RUNTIME)
@@ -7,6 +8,11 @@ target_compile_definitions(${TARGET} PRIVATE
     SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>
 )
 target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT})
+if (LLAMA_SERVER_SSL)
+    find_package(OpenSSL REQUIRED)
+    target_link_libraries(${TARGET} PRIVATE OpenSSL::SSL OpenSSL::Crypto)
+    target_compile_definitions(${TARGET} PRIVATE CPPHTTPLIB_OPENSSL_SUPPORT)
+endif()
 if (WIN32)
     TARGET_LINK_LIBRARIES(${TARGET} PRIVATE ws2_32)
 endif()
index 591f748f84cd78780717c16c31cfd6540702f098..bf8c450b60223827d6d3e571276a2446de6dc68d 100644 (file)
@@ -59,6 +59,10 @@ see https://github.com/ggerganov/llama.cpp/issues/1437
 - `--log-disable`: Output logs to stdout only, default: enabled.
 - `--log-format FORMAT`: Define the log output to FORMAT: json or text (default: json)
 
+**If compiled with `LLAMA_SERVER_SSL=ON`**
+- `--ssl-key-file FNAME`: path to file a PEM-encoded SSL private key
+- `--ssl-cert-file FNAME`: path to file a PEM-encoded SSL certificate
+
 ## Build
 
 server is build alongside everything else from the root of the project
@@ -75,6 +79,28 @@ server is build alongside everything else from the root of the project
   cmake --build . --config Release
   ```
 
+## Build with SSL
+
+server can also be built with SSL support using OpenSSL 3
+
+- Using `make`:
+
+  ```bash
+  # NOTE: For non-system openssl, use the following:
+  #   CXXFLAGS="-I /path/to/openssl/include"
+  #   LDFLAGS="-L /path/to/openssl/lib"
+  make LLAMA_SERVER_SSL=true server
+  ```
+
+- Using `CMake`:
+
+  ```bash
+  mkdir build
+  cd build
+  cmake .. -DLLAMA_SERVER_SSL=ON
+  make server
+  ```
+
 ## Quick Start
 
 To get started right away, run the following command, making sure to use the correct path for the model you have:
index 6f44499843a633b914858864a2fcdc5b516cbbd0..c3b87c846a6e25712b162aaaa2889da8cf865d16 100644 (file)
@@ -27,6 +27,7 @@
 #include <mutex>
 #include <thread>
 #include <signal.h>
+#include <memory>
 
 using json = nlohmann::json;
 
@@ -118,6 +119,11 @@ struct server_params {
 
     std::vector<std::string> api_keys;
 
+#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
+    std::string ssl_key_file = "";
+    std::string ssl_cert_file = "";
+#endif
+
     bool slots_endpoint   = true;
     bool metrics_endpoint = false;
 };
@@ -2142,6 +2148,10 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
     printf("  --path PUBLIC_PATH        path from which to serve static files (default %s)\n", sparams.public_path.c_str());
     printf("  --api-key API_KEY         optional api key to enhance server security. If set, requests must include this key for access.\n");
     printf("  --api-key-file FNAME      path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
+#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
+    printf("  --ssl-key-file FNAME      path to file a PEM-encoded SSL private key\n");
+    printf("  --ssl-cert-file FNAME     path to file a PEM-encoded SSL certificate\n");
+#endif
     printf("  -to N, --timeout N        server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
     printf("  --embeddings              enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
     printf("  -np N, --parallel N       number of slots for process requests (default: %d)\n", params.n_parallel);
@@ -2220,7 +2230,24 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
                }
             }
             key_file.close();
-        } else if (arg == "--timeout" || arg == "-to") {
+
+        }
+#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
+        else if (arg == "--ssl-key-file") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            sparams.ssl_key_file = argv[i];
+        } else if (arg == "--ssl-cert-file") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            sparams.ssl_cert_file = argv[i];
+        }
+#endif
+        else if (arg == "--timeout" || arg == "-to") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
@@ -2658,21 +2685,34 @@ int main(int argc, char ** argv) {
         {"system_info",     llama_print_system_info()},
     });
 
-    httplib::Server svr;
+    std::unique_ptr<httplib::Server> svr;
+#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
+    if (sparams.ssl_key_file != "" && sparams.ssl_cert_file != "") {
+        LOG_INFO("Running with SSL", {{"key", sparams.ssl_key_file}, {"cert", sparams.ssl_cert_file}});
+        svr.reset(
+            new httplib::SSLServer(sparams.ssl_cert_file.c_str(), sparams.ssl_key_file.c_str())
+        );
+    } else {
+        LOG_INFO("Running without SSL", {});
+        svr.reset(new httplib::Server());
+    }
+#else
+    svr.reset(new httplib::Server());
+#endif
 
     std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};
 
-    svr.set_default_headers({{"Server", "llama.cpp"}});
+    svr->set_default_headers({{"Server", "llama.cpp"}});
 
     // CORS preflight
-    svr.Options(R"(.*)", [](const httplib::Request & req, httplib::Response & res) {
+    svr->Options(R"(.*)", [](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin",      req.get_header_value("Origin"));
         res.set_header("Access-Control-Allow-Credentials", "true");
         res.set_header("Access-Control-Allow-Methods",     "POST");
         res.set_header("Access-Control-Allow-Headers",     "*");
     });
 
-    svr.Get("/health", [&](const httplib::Request & req, httplib::Response & res) {
+    svr->Get("/health", [&](const httplib::Request & req, httplib::Response & res) {
         server_state current_state = state.load();
         switch (current_state) {
             case SERVER_STATE_READY:
@@ -2728,7 +2768,7 @@ int main(int argc, char ** argv) {
     });
 
     if (sparams.slots_endpoint) {
-        svr.Get("/slots", [&](const httplib::Request &, httplib::Response & res) {
+        svr->Get("/slots", [&](const httplib::Request &, httplib::Response & res) {
             // request slots data using task queue
             server_task task;
             task.id = ctx_server.queue_tasks.get_new_id();
@@ -2749,7 +2789,7 @@ int main(int argc, char ** argv) {
     }
 
     if (sparams.metrics_endpoint) {
-        svr.Get("/metrics", [&](const httplib::Request &, httplib::Response & res) {
+        svr->Get("/metrics", [&](const httplib::Request &, httplib::Response & res) {
             // request slots data using task queue
             server_task task;
             task.id = ctx_server.queue_tasks.get_new_id();
@@ -2846,9 +2886,9 @@ int main(int argc, char ** argv) {
         });
     }
 
-    svr.set_logger(log_server_request);
+    svr->set_logger(log_server_request);
 
-    svr.set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
+    svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
         const char fmt[] = "500 Internal Server Error\n%s";
 
         char buf[BUFSIZ];
@@ -2864,7 +2904,7 @@ int main(int argc, char ** argv) {
         res.status = 500;
     });
 
-    svr.set_error_handler([](const httplib::Request &, httplib::Response & res) {
+    svr->set_error_handler([](const httplib::Request &, httplib::Response & res) {
         if (res.status == 401) {
             res.set_content("Unauthorized", "text/plain; charset=utf-8");
         }
@@ -2877,16 +2917,16 @@ int main(int argc, char ** argv) {
     });
 
     // set timeouts and change hostname and port
-    svr.set_read_timeout (sparams.read_timeout);
-    svr.set_write_timeout(sparams.write_timeout);
+    svr->set_read_timeout (sparams.read_timeout);
+    svr->set_write_timeout(sparams.write_timeout);
 
-    if (!svr.bind_to_port(sparams.hostname, sparams.port)) {
+    if (!svr->bind_to_port(sparams.hostname, sparams.port)) {
         fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", sparams.hostname.c_str(), sparams.port);
         return 1;
     }
 
     // Set the base directory for serving static files
-    svr.set_base_dir(sparams.public_path);
+    svr->set_base_dir(sparams.public_path);
 
     std::unordered_map<std::string, std::string> log_data;
 
@@ -2947,30 +2987,30 @@ int main(int argc, char ** argv) {
     };
 
     // this is only called if no index.html is found in the public --path
-    svr.Get("/", [](const httplib::Request &, httplib::Response & res) {
+    svr->Get("/", [](const httplib::Request &, httplib::Response & res) {
         res.set_content(reinterpret_cast<const char*>(&index_html), index_html_len, "text/html; charset=utf-8");
         return false;
     });
 
     // this is only called if no index.js is found in the public --path
-    svr.Get("/index.js", [](const httplib::Request &, httplib::Response & res) {
+    svr->Get("/index.js", [](const httplib::Request &, httplib::Response & res) {
         res.set_content(reinterpret_cast<const char *>(&index_js), index_js_len, "text/javascript; charset=utf-8");
         return false;
     });
 
     // this is only called if no index.html is found in the public --path
-    svr.Get("/completion.js", [](const httplib::Request &, httplib::Response & res) {
+    svr->Get("/completion.js", [](const httplib::Request &, httplib::Response & res) {
         res.set_content(reinterpret_cast<const char*>(&completion_js), completion_js_len, "application/javascript; charset=utf-8");
         return false;
     });
 
     // this is only called if no index.html is found in the public --path
-    svr.Get("/json-schema-to-grammar.mjs", [](const httplib::Request &, httplib::Response & res) {
+    svr->Get("/json-schema-to-grammar.mjs", [](const httplib::Request &, httplib::Response & res) {
         res.set_content(reinterpret_cast<const char*>(&json_schema_to_grammar_mjs), json_schema_to_grammar_mjs_len, "application/javascript; charset=utf-8");
         return false;
     });
 
-    svr.Get("/props", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
+    svr->Get("/props", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         json data = {
             { "user_name",                   ctx_server.name_user.c_str() },
@@ -3062,11 +3102,11 @@ int main(int argc, char ** argv) {
         }
     };
 
-    svr.Post("/completion", completions); // legacy
-    svr.Post("/completions", completions);
-    svr.Post("/v1/completions", completions);
+    svr->Post("/completion", completions); // legacy
+    svr->Post("/completions", completions);
+    svr->Post("/v1/completions", completions);
 
-    svr.Get("/v1/models", [&params, &model_meta](const httplib::Request & req, httplib::Response & res) {
+    svr->Get("/v1/models", [&params, &model_meta](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
 
         json models = {
@@ -3161,10 +3201,10 @@ int main(int argc, char ** argv) {
         }
     };
 
-    svr.Post("/chat/completions",    chat_completions);
-    svr.Post("/v1/chat/completions", chat_completions);
+    svr->Post("/chat/completions",    chat_completions);
+    svr->Post("/v1/chat/completions", chat_completions);
 
-    svr.Post("/infill", [&ctx_server, &validate_api_key](const httplib::Request & req, httplib::Response & res) {
+    svr->Post("/infill", [&ctx_server, &validate_api_key](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         if (!validate_api_key(req, res)) {
             return;
@@ -3228,11 +3268,11 @@ int main(int argc, char ** argv) {
         }
     });
 
-    svr.Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) {
+    svr->Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) {
         return res.set_content("", "application/json; charset=utf-8");
     });
 
-    svr.Post("/tokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
+    svr->Post("/tokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         const json body = json::parse(req.body);
 
@@ -3244,7 +3284,7 @@ int main(int argc, char ** argv) {
         return res.set_content(data.dump(), "application/json; charset=utf-8");
     });
 
-    svr.Post("/detokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
+    svr->Post("/detokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         const json body = json::parse(req.body);
 
@@ -3258,7 +3298,7 @@ int main(int argc, char ** argv) {
         return res.set_content(data.dump(), "application/json; charset=utf-8");
     });
 
-    svr.Post("/embedding", [&params, &ctx_server](const httplib::Request & req, httplib::Response & res) {
+    svr->Post("/embedding", [&params, &ctx_server](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         if (!params.embedding) {
             res.status = 501;
@@ -3289,7 +3329,7 @@ int main(int argc, char ** argv) {
         return res.set_content(result.data.dump(), "application/json; charset=utf-8");
     });
 
-    svr.Post("/v1/embeddings", [&params, &ctx_server](const httplib::Request & req, httplib::Response & res) {
+    svr->Post("/v1/embeddings", [&params, &ctx_server](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         if (!params.embedding) {
             res.status = 501;
@@ -3360,13 +3400,13 @@ int main(int argc, char ** argv) {
         sparams.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);
     }
     log_data["n_threads_http"] =  std::to_string(sparams.n_threads_http);
-    svr.new_task_queue = [&sparams] { return new httplib::ThreadPool(sparams.n_threads_http); };
+    svr->new_task_queue = [&sparams] { return new httplib::ThreadPool(sparams.n_threads_http); };
 
     LOG_INFO("HTTP server listening", log_data);
 
     // run the HTTP server in a thread - see comment below
     std::thread t([&]() {
-        if (!svr.listen_after_bind()) {
+        if (!svr->listen_after_bind()) {
             state.store(SERVER_STATE_ERROR);
             return 1;
         }
@@ -3407,7 +3447,7 @@ int main(int argc, char ** argv) {
 
     ctx_server.queue_tasks.start_loop();
 
-    svr.stop();
+    svr->stop();
     t.join();
 
     llama_backend_free();