]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : implement credentialed CORS (#4514)
authorLaura <redacted>
Thu, 11 Jan 2024 18:02:48 +0000 (19:02 +0100)
committerGitHub <redacted>
Thu, 11 Jan 2024 18:02:48 +0000 (20:02 +0200)
* Implement credentialed CORS according to MDN

* Fix syntax error

* Move validate_api_key up so it is defined before its first usage

examples/server/server.cpp

index 345004fa15f5b172a3d047050eebf918d37575c7..031824e145411c601a4d25c19a0e0485ab94979a 100644 (file)
@@ -2822,9 +2822,15 @@ int main(int argc, char **argv)
 
     std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};
 
-    svr.set_default_headers({{"Server", "llama.cpp"},
-                             {"Access-Control-Allow-Origin", "*"},
-                             {"Access-Control-Allow-Headers", "content-type"}});
+    svr.set_default_headers({{"Server", "llama.cpp"}});
+
+    // CORS preflight
+    svr.Options(R"(.*)", [](const httplib::Request &req, httplib::Response &res) {
+        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+        res.set_header("Access-Control-Allow-Credentials", "true");
+        res.set_header("Access-Control-Allow-Methods", "POST");
+        res.set_header("Access-Control-Allow-Headers", "*");
+    });
 
     svr.Get("/health", [&](const httplib::Request&, httplib::Response& res) {
         server_state current_state = state.load();
@@ -2987,9 +2993,9 @@ int main(int argc, char **argv)
                 return false;
             });
 
-    svr.Get("/props", [&llama](const httplib::Request & /*req*/, httplib::Response &res)
+    svr.Get("/props", [&llama](const httplib::Request & req, httplib::Response &res)
             {
-                res.set_header("Access-Control-Allow-Origin", "*");
+                res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
                 json data = {
                     { "user_name",      llama.name_user.c_str() },
                     { "assistant_name", llama.name_assistant.c_str() }
@@ -2999,6 +3005,7 @@ int main(int argc, char **argv)
 
     svr.Post("/completion", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
             {
+                res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
                 if (!validate_api_key(req, res)) {
                     return;
                 }
@@ -3066,8 +3073,9 @@ int main(int argc, char **argv)
                 }
             });
 
-    svr.Get("/v1/models", [&params](const httplib::Request&, httplib::Response& res)
+    svr.Get("/v1/models", [&params](const httplib::Request& req, httplib::Response& res)
             {
+                res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
                 std::time_t t = std::time(0);
 
                 json models = {
@@ -3085,9 +3093,11 @@ int main(int argc, char **argv)
                 res.set_content(models.dump(), "application/json; charset=utf-8");
             });
 
+
     // TODO: add mount point without "/v1" prefix -- how?
     svr.Post("/v1/chat/completions", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
             {
+                res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
                 if (!validate_api_key(req, res)) {
                     return;
                 }
@@ -3161,6 +3171,7 @@ int main(int argc, char **argv)
 
     svr.Post("/infill", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
             {
+                res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
                 if (!validate_api_key(req, res)) {
                     return;
                 }
@@ -3233,6 +3244,7 @@ int main(int argc, char **argv)
 
     svr.Post("/tokenize", [&llama](const httplib::Request &req, httplib::Response &res)
             {
+                res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
                 const json body = json::parse(req.body);
                 std::vector<llama_token> tokens;
                 if (body.count("content") != 0)
@@ -3245,6 +3257,7 @@ int main(int argc, char **argv)
 
     svr.Post("/detokenize", [&llama](const httplib::Request &req, httplib::Response &res)
             {
+                res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
                 const json body = json::parse(req.body);
                 std::string content;
                 if (body.count("tokens") != 0)
@@ -3259,6 +3272,7 @@ int main(int argc, char **argv)
 
     svr.Post("/embedding", [&llama](const httplib::Request &req, httplib::Response &res)
             {
+                res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
                 const json body = json::parse(req.body);
                 json prompt;
                 if (body.count("content") != 0)