]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : add optional API Key Authentication example (#4441)
authorShadovvBeast <redacted>
Fri, 15 Dec 2023 11:49:01 +0000 (13:49 +0200)
committerGitHub <redacted>
Fri, 15 Dec 2023 11:49:01 +0000 (13:49 +0200)
* Add API key authentication for enhanced server-client security

* server : to snake_case

---------

Co-authored-by: Georgi Gerganov <redacted>
examples/server/public/completion.js
examples/server/public/index.html
examples/server/server.cpp

index c281f0fbd55350b74ba62e7587ae58f19c8848fb..6e2b99565dc6e48abee02556e3231cfab1488712 100644 (file)
@@ -34,7 +34,8 @@ export async function* llama(prompt, params = {}, config = {}) {
     headers: {
       'Connection': 'keep-alive',
       'Content-Type': 'application/json',
-      'Accept': 'text/event-stream'
+      'Accept': 'text/event-stream',
+      ...(params.api_key ? {'Authorization': `Bearer ${params.api_key}`} : {})
     },
     signal: controller.signal,
   });
index 451fd4a3be6020a021de3a2d00a4b7d14145fb75..07d779d2008a201fbdae1a5bc0ee8e09f89006cb 100644 (file)
       grammar: '',
       n_probs: 0, // no completion_probabilities,
       image_data: [],
-      cache_prompt: true
+      cache_prompt: true,
+      api_key: ''
     })
 
     /* START: Support for storing prompt templates and parameters in browsers LocalStorage */
             <fieldset>
               ${IntField({ label: "Show Probabilities", max: 10, min: 0, name: "n_probs", value: params.value.n_probs })}
             </fieldset>
+            <fieldset>
+              <label for="api_key">API Key</label>
+              <input type="text" name="api_key" value="${params.value.api_key}" placeholder="Enter API key" oninput=${updateParams} />
+            </fieldset>
           </details>
         </form>
       `
index 39d1e83d1857512a6ef3efb9c874661180b04c15..5f93dcb66a4e23412f662df24601c90e2ad83d2e 100644 (file)
@@ -36,6 +36,7 @@ using json = nlohmann::json;
 struct server_params
 {
     std::string hostname = "127.0.0.1";
+    std::string api_key;
     std::string public_path = "examples/server/public";
     int32_t port = 8080;
     int32_t read_timeout = 600;
@@ -1953,6 +1954,7 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     printf("  --host                ip address to listen (default  (default: %s)\n", sparams.hostname.c_str());
     printf("  --port PORT           port to listen (default  (default: %d)\n", sparams.port);
     printf("  --path PUBLIC_PATH    path from which to serve static files (default %s)\n", sparams.public_path.c_str());
+    printf("  --api-key API_KEY     optional api key to enhance server security. If set, requests must include this key for access.\n");
     printf("  -to N, --timeout N    server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
     printf("  --embedding           enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
     printf("  -np N, --parallel N   number of slots for process requests (default: %d)\n", params.n_parallel);
@@ -2002,6 +2004,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
             }
             sparams.public_path = argv[i];
         }
+        else if (arg == "--api-key")
+        {
+            if (++i >= argc)
+            {
+                invalid_param = true;
+                break;
+            }
+            sparams.api_key = argv[i];
+        }
         else if (arg == "--timeout" || arg == "-to")
         {
             if (++i >= argc)
@@ -2669,6 +2680,32 @@ int main(int argc, char **argv)
 
     httplib::Server svr;
 
+    // Middleware for API key validation
+    auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool {
+        // If API key is not set, skip validation
+        if (sparams.api_key.empty()) {
+            return true;
+        }
+
+        // Check for API key in the header
+        auto auth_header = req.get_header_value("Authorization");
+        std::string prefix = "Bearer ";
+        if (auth_header.substr(0, prefix.size()) == prefix) {
+            std::string received_api_key = auth_header.substr(prefix.size());
+            if (received_api_key == sparams.api_key) {
+                return true; // API key is valid
+            }
+        }
+
+        // API key is invalid or not provided
+        res.set_content("Unauthorized: Invalid API Key", "text/plain");
+        res.status = 401; // Unauthorized
+
+        LOG_WARNING("Unauthorized: Invalid API Key", {});
+
+        return false;
+    };
+
     svr.set_default_headers({{"Server", "llama.cpp"},
                              {"Access-Control-Allow-Origin", "*"},
                              {"Access-Control-Allow-Headers", "content-type"}});
@@ -2711,8 +2748,11 @@ int main(int argc, char **argv)
                 res.set_content(data.dump(), "application/json");
             });
 
-    svr.Post("/completion", [&llama](const httplib::Request &req, httplib::Response &res)
+    svr.Post("/completion", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
             {
+                if (!validate_api_key(req, res)) {
+                    return;
+                }
                 json data = json::parse(req.body);
                 const int task_id = llama.request_completion(data, false, false, -1);
                 if (!json_value(data, "stream", false)) {
@@ -2799,8 +2839,11 @@ int main(int argc, char **argv)
             });
 
     // TODO: add mount point without "/v1" prefix -- how?
-    svr.Post("/v1/chat/completions", [&llama](const httplib::Request &req, httplib::Response &res)
+    svr.Post("/v1/chat/completions", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
             {
+                if (!validate_api_key(req, res)) {
+                    return;
+                }
                 json data = oaicompat_completion_params_parse(json::parse(req.body));
 
                 const int task_id = llama.request_completion(data, false, false, -1);
@@ -2869,8 +2912,11 @@ int main(int argc, char **argv)
                 }
             });
 
-    svr.Post("/infill", [&llama](const httplib::Request &req, httplib::Response &res)
+    svr.Post("/infill", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
             {
+                if (!validate_api_key(req, res)) {
+                    return;
+                }
                 json data = json::parse(req.body);
                 const int task_id = llama.request_completion(data, true, false, -1);
                 if (!json_value(data, "stream", false)) {
@@ -3005,11 +3051,15 @@ int main(int argc, char **argv)
 
     svr.set_error_handler([](const httplib::Request &, httplib::Response &res)
             {
+                if (res.status == 401)
+                {
+                    res.set_content("Unauthorized", "text/plain");
+                }
                 if (res.status == 400)
                 {
                     res.set_content("Invalid request", "text/plain");
                 }
-                else if (res.status != 500)
+                else if (res.status == 404)
                 {
                     res.set_content("File Not Found", "text/plain");
                     res.status = 404;
@@ -3032,11 +3082,15 @@ int main(int argc, char **argv)
     // to make it ctrl+clickable:
     LOG_TEE("\nllama server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port);
 
-    LOG_INFO("HTTP server listening", {
-                                          {"hostname", sparams.hostname},
-                                          {"port", sparams.port},
-                                      });
+    std::unordered_map<std::string, std::string> log_data;
+    log_data["hostname"] = sparams.hostname;
+    log_data["port"] = std::to_string(sparams.port);
+
+    if (!sparams.api_key.empty()) {
+        log_data["api_key"] = "api_key: ****" + sparams.api_key.substr(sparams.api_key.length() - 4);
+    }
 
+    LOG_INFO("HTTP server listening", log_data);
     // run the HTTP server in a thread - see comment below
     std::thread t([&]()
             {