]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : support for multiple api keys (#4864)
authorMichael Coppola <redacted>
Thu, 11 Jan 2024 17:51:17 +0000 (12:51 -0500)
committerGitHub <redacted>
Thu, 11 Jan 2024 17:51:17 +0000 (19:51 +0200)
* server: added support for multiple api keys, added loading api keys from file

* minor: fix whitespace

* added file error handling to --api-key-file, changed code to better
reflect current style

* server: update README.md for --api-key-file

---------

Co-authored-by: Michael Coppola <redacted>
examples/server/README.md
examples/server/server.cpp

index dc27e72b9d0847f0387b280d9b648fb89dc5a1ae..fd3034b99c3d2a18d8b4f313de1a5f39524fa7a9 100644 (file)
@@ -23,7 +23,8 @@ Command line options:
 -   `--host`: Set the hostname or ip address to listen. Default `127.0.0.1`.
 -   `--port`: Set the port to listen. Default: `8080`.
 -   `--path`: path from which to serve static files (default examples/server/public)
--   `--api-key`: Set an api key for request authorization. By default the server responds to every request. With an api key set, the requests must have the Authorization header set with the api key as Bearer token.
+-   `--api-key`: Set an api key for request authorization. By default the server responds to every request. With an api key set, the requests must have the Authorization header set with the api key as Bearer token. May be used multiple times to enable multiple valid keys.
+-   `--api-key-file`: path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access. May be used in conjunction with `--api-key`'s.
 -   `--embedding`: Enable embedding extraction, Default: disabled.
 -   `-np N`, `--parallel N`: Set the number of slots for process requests (default: 1)
 -   `-cb`, `--cont-batching`: enable continuous batching (a.k.a dynamic batching) (default: disabled)
index 51a4b689f78b1e739b7b19340095b8dc8ebfc176..345004fa15f5b172a3d047050eebf918d37575c7 100644 (file)
@@ -39,7 +39,7 @@ using json = nlohmann::json;
 struct server_params
 {
     std::string hostname = "127.0.0.1";
-    std::string api_key;
+    std::vector<std::string> api_keys;
     std::string public_path = "examples/server/public";
     int32_t port = 8080;
     int32_t read_timeout = 600;
@@ -2021,6 +2021,7 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     printf("  --port PORT           port to listen (default  (default: %d)\n", sparams.port);
     printf("  --path PUBLIC_PATH    path from which to serve static files (default %s)\n", sparams.public_path.c_str());
     printf("  --api-key API_KEY     optional api key to enhance server security. If set, requests must include this key for access.\n");
+    printf("  --api-key-file FNAME  path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
     printf("  -to N, --timeout N    server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
     printf("  --embedding           enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
     printf("  -np N, --parallel N   number of slots for process requests (default: %d)\n", params.n_parallel);
@@ -2081,7 +2082,28 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
                 invalid_param = true;
                 break;
             }
-            sparams.api_key = argv[i];
+            sparams.api_keys.push_back(argv[i]);
+        }
+        else if (arg == "--api-key-file")
+        {
+            if (++i >= argc)
+            {
+                invalid_param = true;
+                break;
+            }
+            std::ifstream key_file(argv[i]);
+            if (!key_file) {
+                fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
+                invalid_param = true;
+                break;
+            }
+            std::string key;
+            while (std::getline(key_file, key)) {
+               if (key.size() > 0) {
+                   sparams.api_keys.push_back(key);
+               }
+            }
+            key_file.close();
         }
         else if (arg == "--timeout" || arg == "-to")
         {
@@ -2881,8 +2903,10 @@ int main(int argc, char **argv)
     log_data["hostname"] = sparams.hostname;
     log_data["port"] = std::to_string(sparams.port);
 
-    if (!sparams.api_key.empty()) {
-        log_data["api_key"] = "api_key: ****" + sparams.api_key.substr(sparams.api_key.length() - 4);
+    if (sparams.api_keys.size() == 1) {
+        log_data["api_key"] = "api_key: ****" + sparams.api_keys[0].substr(sparams.api_keys[0].length() - 4);
+    } else if (sparams.api_keys.size() > 1) {
+        log_data["api_key"] = "api_key: " + std::to_string(sparams.api_keys.size()) + " keys loaded";
     }
 
     LOG_INFO("HTTP server listening", log_data);
@@ -2912,7 +2936,7 @@ int main(int argc, char **argv)
     // Middleware for API key validation
     auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool {
         // If API key is not set, skip validation
-        if (sparams.api_key.empty()) {
+        if (sparams.api_keys.empty()) {
             return true;
         }
 
@@ -2921,7 +2945,7 @@ int main(int argc, char **argv)
         std::string prefix = "Bearer ";
         if (auth_header.substr(0, prefix.size()) == prefix) {
             std::string received_api_key = auth_header.substr(prefix.size());
-            if (received_api_key == sparams.api_key) {
+            if (std::find(sparams.api_keys.begin(), sparams.api_keys.end(), received_api_key) != sparams.api_keys.end()) {
                 return true; // API key is valid
             }
         }