]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Server: reorganize some http logic (#5939)
authorXuan Son Nguyen <redacted>
Sat, 9 Mar 2024 10:27:53 +0000 (11:27 +0100)
committerGitHub <redacted>
Sat, 9 Mar 2024 10:27:53 +0000 (11:27 +0100)
* refactor static file handler

* use set_pre_routing_handler for validate_api_key

* merge embedding handlers

* correct http verb for endpoints

* fix embedding response

* fix test case CORS Options

* fix code style

examples/server/README.md
examples/server/server.cpp
examples/server/tests/features/security.feature
examples/server/tests/features/steps/steps.py

index bf8c450b60223827d6d3e571276a2446de6dc68d..3abb1abe3b92bffaa9656013a675e5b98114e15e 100644 (file)
@@ -42,7 +42,7 @@ see https://github.com/ggerganov/llama.cpp/issues/1437
 - `-to N`, `--timeout N`: Server read/write timeout in seconds. Default `600`.
 - `--host`: Set the hostname or ip address to listen. Default `127.0.0.1`.
 - `--port`: Set the port to listen. Default: `8080`.
-- `--path`: path from which to serve static files (default examples/server/public)
+- `--path`: path from which to serve static files (default: disabled)
 - `--api-key`: Set an api key for request authorization. By default the server responds to every request. With an api key set, the requests must have the Authorization header set with the api key as Bearer token. May be used multiple times to enable multiple valid keys.
 - `--api-key-file`: path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access. May be used in conjunction with `--api-key`'s.
 - `--embedding`: Enable embedding extraction, Default: disabled.
@@ -558,7 +558,7 @@ The HTTP server supports OAI-like API
 
 ### Extending or building alternative Web Front End
 
-The default location for the static files is `examples/server/public`. You can extend the front end by running the server binary with `--path` set to `./your-directory` and importing `/completion.js` to get access to the llamaComplete() method.
+You can extend the front end by running the server binary with `--path` set to `./your-directory` and importing `/completion.js` to get access to the llamaComplete() method.
 
 Read the documentation in `/completion.js` to see convenient ways to access llama.
 
index c3b87c846a6e25712b162aaaa2889da8cf865d16..6e0f8328cdf5a6e2c42f1f532d30b9b07a272528 100644 (file)
@@ -113,7 +113,7 @@ struct server_params {
     int32_t n_threads_http = -1;
 
     std::string hostname      = "127.0.0.1";
-    std::string public_path   = "examples/server/public";
+    std::string public_path   = "";
     std::string chat_template = "";
     std::string system_prompt = "";
 
@@ -2145,7 +2145,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
     printf("  --lora-base FNAME         optional model to use as a base for the layers modified by the LoRA adapter\n");
     printf("  --host                    ip address to listen (default  (default: %s)\n", sparams.hostname.c_str());
     printf("  --port PORT               port to listen (default  (default: %d)\n", sparams.port);
-    printf("  --path PUBLIC_PATH        path from which to serve static files (default %s)\n", sparams.public_path.c_str());
+    printf("  --path PUBLIC_PATH        path from which to serve static files (default: disabled)\n");
     printf("  --api-key API_KEY         optional api key to enhance server security. If set, requests must include this key for access.\n");
     printf("  --api-key-file FNAME      path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
@@ -2211,7 +2211,7 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
                 invalid_param = true;
                 break;
             }
-            sparams.api_keys.emplace_back(argv[i]);
+            sparams.api_keys.push_back(argv[i]);
         } else if (arg == "--api-key-file") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -2712,180 +2712,6 @@ int main(int argc, char ** argv) {
         res.set_header("Access-Control-Allow-Headers",     "*");
     });
 
-    svr->Get("/health", [&](const httplib::Request & req, httplib::Response & res) {
-        server_state current_state = state.load();
-        switch (current_state) {
-            case SERVER_STATE_READY:
-                {
-                    // request slots data using task queue
-                    server_task task;
-                    task.id   = ctx_server.queue_tasks.get_new_id();
-                    task.type = SERVER_TASK_TYPE_METRICS;
-                    task.id_target = -1;
-
-                    ctx_server.queue_results.add_waiting_task_id(task.id);
-                    ctx_server.queue_tasks.post(task);
-
-                    // get the result
-                    server_task_result result = ctx_server.queue_results.recv(task.id);
-                    ctx_server.queue_results.remove_waiting_task_id(task.id);
-
-                    const int n_idle_slots       = result.data["idle"];
-                    const int n_processing_slots = result.data["processing"];
-
-                    json health = {
-                        {"status",           "ok"},
-                        {"slots_idle",       n_idle_slots},
-                        {"slots_processing", n_processing_slots}
-                    };
-
-                    res.status = 200; // HTTP OK
-                    if (sparams.slots_endpoint && req.has_param("include_slots")) {
-                        health["slots"] = result.data["slots"];
-                    }
-
-                    if (n_idle_slots == 0) {
-                        health["status"] = "no slot available";
-                        if (req.has_param("fail_on_no_slot")) {
-                            res.status = 503; // HTTP Service Unavailable
-                        }
-                    }
-
-                    res.set_content(health.dump(), "application/json");
-                    break;
-                }
-            case SERVER_STATE_LOADING_MODEL:
-                {
-                    res.set_content(R"({"status": "loading model"})", "application/json");
-                    res.status = 503; // HTTP Service Unavailable
-                } break;
-            case SERVER_STATE_ERROR:
-                {
-                    res.set_content(R"({"status": "error", "error": "Model failed to load"})", "application/json");
-                    res.status = 500; // HTTP Internal Server Error
-                } break;
-        }
-    });
-
-    if (sparams.slots_endpoint) {
-        svr->Get("/slots", [&](const httplib::Request &, httplib::Response & res) {
-            // request slots data using task queue
-            server_task task;
-            task.id = ctx_server.queue_tasks.get_new_id();
-            task.id_multi  = -1;
-            task.id_target = -1;
-            task.type = SERVER_TASK_TYPE_METRICS;
-
-            ctx_server.queue_results.add_waiting_task_id(task.id);
-            ctx_server.queue_tasks.post(task);
-
-            // get the result
-            server_task_result result = ctx_server.queue_results.recv(task.id);
-            ctx_server.queue_results.remove_waiting_task_id(task.id);
-
-            res.set_content(result.data["slots"].dump(), "application/json");
-            res.status = 200; // HTTP OK
-        });
-    }
-
-    if (sparams.metrics_endpoint) {
-        svr->Get("/metrics", [&](const httplib::Request &, httplib::Response & res) {
-            // request slots data using task queue
-            server_task task;
-            task.id = ctx_server.queue_tasks.get_new_id();
-            task.id_multi  = -1;
-            task.id_target = -1;
-            task.type = SERVER_TASK_TYPE_METRICS;
-            task.data.push_back({{"reset_bucket", true}});
-
-            ctx_server.queue_results.add_waiting_task_id(task.id);
-            ctx_server.queue_tasks.post(task);
-
-            // get the result
-            server_task_result result = ctx_server.queue_results.recv(task.id);
-            ctx_server.queue_results.remove_waiting_task_id(task.id);
-
-            json data = result.data;
-
-            const uint64_t n_prompt_tokens_processed = data["n_prompt_tokens_processed"];
-            const uint64_t t_prompt_processing       = data["t_prompt_processing"];
-
-            const uint64_t n_tokens_predicted  = data["n_tokens_predicted"];
-            const uint64_t t_tokens_generation = data["t_tokens_generation"];
-
-            const int32_t kv_cache_used_cells = data["kv_cache_used_cells"];
-
-            // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names
-            json all_metrics_def = json {
-                {"counter", {{
-                        {"name",  "prompt_tokens_total"},
-                        {"help",  "Number of prompt tokens processed."},
-                        {"value",  (uint64_t) data["n_prompt_tokens_processed_total"]}
-                }, {
-                        {"name",  "prompt_seconds_total"},
-                        {"help",  "Prompt process time"},
-                        {"value",  (uint64_t) data["t_prompt_processing_total"] / 1.e3}
-                }, {
-                        {"name",  "tokens_predicted_total"},
-                        {"help",  "Number of generation tokens processed."},
-                        {"value",  (uint64_t) data["n_tokens_predicted_total"]}
-                }, {
-                        {"name",  "tokens_predicted_seconds_total"},
-                        {"help",  "Predict process time"},
-                        {"value",  (uint64_t) data["t_tokens_generation_total"] / 1.e3}
-                }}},
-                {"gauge", {{
-                        {"name",  "prompt_tokens_seconds"},
-                        {"help",  "Average prompt throughput in tokens/s."},
-                        {"value",  n_prompt_tokens_processed ? 1.e3 / t_prompt_processing * n_prompt_tokens_processed : 0.}
-                },{
-                        {"name",  "predicted_tokens_seconds"},
-                        {"help",  "Average generation throughput in tokens/s."},
-                        {"value",  n_tokens_predicted ? 1.e3 / t_tokens_generation * n_tokens_predicted : 0.}
-                },{
-                        {"name",  "kv_cache_usage_ratio"},
-                        {"help",  "KV-cache usage. 1 means 100 percent usage."},
-                        {"value",  1. * kv_cache_used_cells / params.n_ctx}
-                },{
-                        {"name",  "kv_cache_tokens"},
-                        {"help",  "KV-cache tokens."},
-                        {"value",  (uint64_t) data["kv_cache_tokens_count"]}
-                },{
-                        {"name",  "requests_processing"},
-                        {"help",  "Number of request processing."},
-                        {"value",  (uint64_t) data["processing"]}
-                },{
-                        {"name",  "requests_deferred"},
-                        {"help",  "Number of request deferred."},
-                        {"value",  (uint64_t) data["deferred"]}
-                }}}
-            };
-
-            std::stringstream prometheus;
-
-            for (const auto & el : all_metrics_def.items()) {
-                const auto & type        = el.key();
-                const auto & metrics_def = el.value();
-
-                for (const auto & metric_def : metrics_def) {
-                    const std::string name = metric_def["name"];
-                    const std::string help = metric_def["help"];
-
-                    auto value = json_value(metric_def, "value", 0.);
-                    prometheus << "# HELP llamacpp:" << name << " " << help  << "\n"
-                               << "# TYPE llamacpp:" << name << " " << type  << "\n"
-                               << "llamacpp:"        << name << " " << value << "\n";
-                }
-            }
-
-            const int64_t t_start = data["t_start"];
-            res.set_header("Process-Start-Time-Unix", std::to_string(t_start));
-
-            res.set_content(prometheus.str(), "text/plain; version=0.0.4");
-            res.status = 200; // HTTP OK
-        });
-    }
-
     svr->set_logger(log_server_request);
 
     svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
@@ -2925,16 +2751,14 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    // Set the base directory for serving static files
-    svr->set_base_dir(sparams.public_path);
-
     std::unordered_map<std::string, std::string> log_data;
 
     log_data["hostname"] = sparams.hostname;
     log_data["port"]     = std::to_string(sparams.port);
 
     if (sparams.api_keys.size() == 1) {
-        log_data["api_key"] = "api_key: ****" + sparams.api_keys[0].substr(sparams.api_keys[0].length() - 4);
+        auto key = sparams.api_keys[0];
+        log_data["api_key"] = "api_key: ****" + key.substr(std::max((int)(key.length() - 4), 0));
     } else if (sparams.api_keys.size() > 1) {
         log_data["api_key"] = "api_key: " + std::to_string(sparams.api_keys.size()) + " keys loaded";
     }
@@ -2959,13 +2783,37 @@ int main(int argc, char ** argv) {
         }
     }
 
-    // Middleware for API key validation
-    auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool {
+    //
+    // Middlewares
+    //
+
+    auto middleware_validate_api_key = [&sparams](const httplib::Request & req, httplib::Response & res) {
+        // TODO: should we apply API key to all endpoints, including "/health" and "/models"?
+        static const std::set<std::string> protected_endpoints = {
+            "/props",
+            "/completion",
+            "/completions",
+            "/v1/completions",
+            "/chat/completions",
+            "/v1/chat/completions",
+            "/infill",
+            "/tokenize",
+            "/detokenize",
+            "/embedding",
+            "/embeddings",
+            "/v1/embeddings",
+        };
+
         // If API key is not set, skip validation
         if (sparams.api_keys.empty()) {
             return true;
         }
 
+        // If path is not in protected_endpoints list, skip validation
+        if (protected_endpoints.find(req.path) == protected_endpoints.end()) {
+            return true;
+        }
+
         // Check for API key in the header
         auto auth_header = req.get_header_value("Authorization");
 
@@ -2978,6 +2826,8 @@ int main(int argc, char ** argv) {
         }
 
         // API key is invalid or not provided
+        // TODO: make another middleware for CORS related logic
+        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         res.set_content("Unauthorized: Invalid API Key", "text/plain; charset=utf-8");
         res.status = 401; // Unauthorized
 
@@ -2986,31 +2836,201 @@ int main(int argc, char ** argv) {
         return false;
     };
 
-    // this is only called if no index.html is found in the public --path
-    svr->Get("/", [](const httplib::Request &, httplib::Response & res) {
-        res.set_content(reinterpret_cast<const char*>(&index_html), index_html_len, "text/html; charset=utf-8");
-        return false;
+    // register server middlewares
+    svr->set_pre_routing_handler([&middleware_validate_api_key](const httplib::Request & req, httplib::Response & res) {
+        if (!middleware_validate_api_key(req, res)) {
+            return httplib::Server::HandlerResponse::Handled;
+        }
+        return httplib::Server::HandlerResponse::Unhandled;
     });
 
-    // this is only called if no index.js is found in the public --path
-    svr->Get("/index.js", [](const httplib::Request &, httplib::Response & res) {
-        res.set_content(reinterpret_cast<const char *>(&index_js), index_js_len, "text/javascript; charset=utf-8");
-        return false;
-    });
+    //
+    // Route handlers (or controllers)
+    //
 
-    // this is only called if no index.html is found in the public --path
-    svr->Get("/completion.js", [](const httplib::Request &, httplib::Response & res) {
-        res.set_content(reinterpret_cast<const char*>(&completion_js), completion_js_len, "application/javascript; charset=utf-8");
-        return false;
-    });
+    const auto handle_health = [&](const httplib::Request & req, httplib::Response & res) {
+        server_state current_state = state.load();
+        switch (current_state) {
+            case SERVER_STATE_READY:
+                {
+                    // request slots data using task queue
+                    server_task task;
+                    task.id   = ctx_server.queue_tasks.get_new_id();
+                    task.type = SERVER_TASK_TYPE_METRICS;
+                    task.id_target = -1;
 
-    // this is only called if no index.html is found in the public --path
-    svr->Get("/json-schema-to-grammar.mjs", [](const httplib::Request &, httplib::Response & res) {
-        res.set_content(reinterpret_cast<const char*>(&json_schema_to_grammar_mjs), json_schema_to_grammar_mjs_len, "application/javascript; charset=utf-8");
-        return false;
-    });
+                    ctx_server.queue_results.add_waiting_task_id(task.id);
+                    ctx_server.queue_tasks.post(task);
+
+                    // get the result
+                    server_task_result result = ctx_server.queue_results.recv(task.id);
+                    ctx_server.queue_results.remove_waiting_task_id(task.id);
+
+                    const int n_idle_slots       = result.data["idle"];
+                    const int n_processing_slots = result.data["processing"];
+
+                    json health = {
+                        {"status",           "ok"},
+                        {"slots_idle",       n_idle_slots},
+                        {"slots_processing", n_processing_slots}
+                    };
+
+                    res.status = 200; // HTTP OK
+                    if (sparams.slots_endpoint && req.has_param("include_slots")) {
+                        health["slots"] = result.data["slots"];
+                    }
 
-    svr->Get("/props", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
+                    if (n_idle_slots == 0) {
+                        health["status"] = "no slot available";
+                        if (req.has_param("fail_on_no_slot")) {
+                            res.status = 503; // HTTP Service Unavailable
+                        }
+                    }
+
+                    res.set_content(health.dump(), "application/json");
+                    break;
+                }
+            case SERVER_STATE_LOADING_MODEL:
+                {
+                    res.set_content(R"({"status": "loading model"})", "application/json");
+                    res.status = 503; // HTTP Service Unavailable
+                } break;
+            case SERVER_STATE_ERROR:
+                {
+                    res.set_content(R"({"status": "error", "error": "Model failed to load"})", "application/json");
+                    res.status = 500; // HTTP Internal Server Error
+                } break;
+        }
+    };
+
+    const auto handle_slots = [&](const httplib::Request &, httplib::Response & res) {
+        if (!sparams.slots_endpoint) {
+            res.status = 501;
+            res.set_content("This server does not support slots endpoint.", "text/plain; charset=utf-8");
+            return;
+        }
+
+        // request slots data using task queue
+        server_task task;
+        task.id = ctx_server.queue_tasks.get_new_id();
+        task.id_multi  = -1;
+        task.id_target = -1;
+        task.type = SERVER_TASK_TYPE_METRICS;
+
+        ctx_server.queue_results.add_waiting_task_id(task.id);
+        ctx_server.queue_tasks.post(task);
+
+        // get the result
+        server_task_result result = ctx_server.queue_results.recv(task.id);
+        ctx_server.queue_results.remove_waiting_task_id(task.id);
+
+        res.set_content(result.data["slots"].dump(), "application/json");
+        res.status = 200; // HTTP OK
+    };
+
+    const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
+        if (!sparams.metrics_endpoint) {
+            res.status = 501;
+            res.set_content("This server does not support metrics endpoint.", "text/plain; charset=utf-8");
+            return;
+        }
+
+        // request slots data using task queue
+        server_task task;
+        task.id = ctx_server.queue_tasks.get_new_id();
+        task.id_multi  = -1;
+        task.id_target = -1;
+        task.type = SERVER_TASK_TYPE_METRICS;
+        task.data.push_back({{"reset_bucket", true}});
+
+        ctx_server.queue_results.add_waiting_task_id(task.id);
+        ctx_server.queue_tasks.post(task);
+
+        // get the result
+        server_task_result result = ctx_server.queue_results.recv(task.id);
+        ctx_server.queue_results.remove_waiting_task_id(task.id);
+
+        json data = result.data;
+
+        const uint64_t n_prompt_tokens_processed = data["n_prompt_tokens_processed"];
+        const uint64_t t_prompt_processing       = data["t_prompt_processing"];
+
+        const uint64_t n_tokens_predicted  = data["n_tokens_predicted"];
+        const uint64_t t_tokens_generation = data["t_tokens_generation"];
+
+        const int32_t kv_cache_used_cells = data["kv_cache_used_cells"];
+
+        // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names
+        json all_metrics_def = json {
+            {"counter", {{
+                    {"name",  "prompt_tokens_total"},
+                    {"help",  "Number of prompt tokens processed."},
+                    {"value",  (uint64_t) data["n_prompt_tokens_processed_total"]}
+            }, {
+                    {"name",  "prompt_seconds_total"},
+                    {"help",  "Prompt process time"},
+                    {"value",  (uint64_t) data["t_prompt_processing_total"] / 1.e3}
+            }, {
+                    {"name",  "tokens_predicted_total"},
+                    {"help",  "Number of generation tokens processed."},
+                    {"value",  (uint64_t) data["n_tokens_predicted_total"]}
+            }, {
+                    {"name",  "tokens_predicted_seconds_total"},
+                    {"help",  "Predict process time"},
+                    {"value",  (uint64_t) data["t_tokens_generation_total"] / 1.e3}
+            }}},
+            {"gauge", {{
+                    {"name",  "prompt_tokens_seconds"},
+                    {"help",  "Average prompt throughput in tokens/s."},
+                    {"value",  n_prompt_tokens_processed ? 1.e3 / t_prompt_processing * n_prompt_tokens_processed : 0.}
+            },{
+                    {"name",  "predicted_tokens_seconds"},
+                    {"help",  "Average generation throughput in tokens/s."},
+                    {"value",  n_tokens_predicted ? 1.e3 / t_tokens_generation * n_tokens_predicted : 0.}
+            },{
+                    {"name",  "kv_cache_usage_ratio"},
+                    {"help",  "KV-cache usage. 1 means 100 percent usage."},
+                    {"value",  1. * kv_cache_used_cells / params.n_ctx}
+            },{
+                    {"name",  "kv_cache_tokens"},
+                    {"help",  "KV-cache tokens."},
+                    {"value",  (uint64_t) data["kv_cache_tokens_count"]}
+            },{
+                    {"name",  "requests_processing"},
+                    {"help",  "Number of request processing."},
+                    {"value",  (uint64_t) data["processing"]}
+            },{
+                    {"name",  "requests_deferred"},
+                    {"help",  "Number of request deferred."},
+                    {"value",  (uint64_t) data["deferred"]}
+            }}}
+        };
+
+        std::stringstream prometheus;
+
+        for (const auto & el : all_metrics_def.items()) {
+            const auto & type        = el.key();
+            const auto & metrics_def = el.value();
+
+            for (const auto & metric_def : metrics_def) {
+                const std::string name = metric_def["name"];
+                const std::string help = metric_def["help"];
+
+                auto value = json_value(metric_def, "value", 0.);
+                prometheus << "# HELP llamacpp:" << name << " " << help  << "\n"
+                            << "# TYPE llamacpp:" << name << " " << type  << "\n"
+                            << "llamacpp:"        << name << " " << value << "\n";
+            }
+        }
+
+        const int64_t t_start = data["t_start"];
+        res.set_header("Process-Start-Time-Unix", std::to_string(t_start));
+
+        res.set_content(prometheus.str(), "text/plain; version=0.0.4");
+        res.status = 200; // HTTP OK
+    };
+
+    const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         json data = {
             { "user_name",                   ctx_server.name_user.c_str() },
@@ -3020,13 +3040,10 @@ int main(int argc, char ** argv) {
         };
 
         res.set_content(data.dump(), "application/json; charset=utf-8");
-    });
+    };
 
-    const auto completions = [&ctx_server, &validate_api_key](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_completions = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-        if (!validate_api_key(req, res)) {
-            return;
-        }
 
         json data = json::parse(req.body);
 
@@ -3102,11 +3119,7 @@ int main(int argc, char ** argv) {
         }
     };
 
-    svr->Post("/completion", completions); // legacy
-    svr->Post("/completions", completions);
-    svr->Post("/v1/completions", completions);
-
-    svr->Get("/v1/models", [&params, &model_meta](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_models = [&params, &model_meta](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
 
         json models = {
@@ -3123,14 +3136,10 @@ int main(int argc, char ** argv) {
         };
 
         res.set_content(models.dump(), "application/json; charset=utf-8");
-    });
+    };
 
-    const auto chat_completions = [&ctx_server, &validate_api_key, &sparams](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_chat_completions = [&ctx_server, &sparams](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-        if (!validate_api_key(req, res)) {
-            return;
-        }
-
         json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), sparams.chat_template);
 
         const int id_task = ctx_server.queue_tasks.get_new_id();
@@ -3201,14 +3210,8 @@ int main(int argc, char ** argv) {
         }
     };
 
-    svr->Post("/chat/completions",    chat_completions);
-    svr->Post("/v1/chat/completions", chat_completions);
-
-    svr->Post("/infill", [&ctx_server, &validate_api_key](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_infill = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-        if (!validate_api_key(req, res)) {
-            return;
-        }
 
         json data = json::parse(req.body);
 
@@ -3266,13 +3269,9 @@ int main(int argc, char ** argv) {
 
             res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
         }
-    });
-
-    svr->Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) {
-        return res.set_content("", "application/json; charset=utf-8");
-    });
+    };
 
-    svr->Post("/tokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         const json body = json::parse(req.body);
 
@@ -3282,9 +3281,9 @@ int main(int argc, char ** argv) {
         }
         const json data = format_tokenizer_response(tokens);
         return res.set_content(data.dump(), "application/json; charset=utf-8");
-    });
+    };
 
-    svr->Post("/detokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         const json body = json::parse(req.body);
 
@@ -3296,9 +3295,9 @@ int main(int argc, char ** argv) {
 
         const json data = format_detokenized_response(content);
         return res.set_content(data.dump(), "application/json; charset=utf-8");
-    });
+    };
 
-    svr->Post("/embedding", [&params, &ctx_server](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_embeddings = [&params, &ctx_server](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         if (!params.embedding) {
             res.status = 501;
@@ -3307,94 +3306,114 @@ int main(int argc, char ** argv) {
         }
 
         const json body = json::parse(req.body);
+        bool is_openai = false;
 
-        json prompt;
-        if (body.count("content") != 0) {
-            prompt = body["content"];
+        // an input prompt can string or a list of tokens (integer)
+        std::vector<json> prompts;
+        if (body.count("input") != 0) {
+            is_openai = true;
+            if (body["input"].is_array()) {
+                // support multiple prompts
+                for (const json & elem : body["input"]) {
+                    prompts.push_back(elem);
+                }
+            } else {
+                // single input prompt
+                prompts.push_back(body["input"]);
+            }
+        } else if (body.count("content") != 0) {
+            // only support single prompt here
+            std::string content = body["content"];
+            prompts.push_back(content);
         } else {
-            prompt = "";
-        }
-
-        // create and queue the task
-        const int id_task = ctx_server.queue_tasks.get_new_id();
-
-        ctx_server.queue_results.add_waiting_task_id(id_task);
-        ctx_server.request_completion(id_task, -1, { {"prompt", prompt}, { "n_predict", 0} }, false, true);
-
-        // get the result
-        server_task_result result = ctx_server.queue_results.recv(id_task);
-        ctx_server.queue_results.remove_waiting_task_id(id_task);
-
-        // send the result
-        return res.set_content(result.data.dump(), "application/json; charset=utf-8");
-    });
-
-    svr->Post("/v1/embeddings", [&params, &ctx_server](const httplib::Request & req, httplib::Response & res) {
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-        if (!params.embedding) {
-            res.status = 501;
-            res.set_content("This server does not support embeddings. Start it with `--embeddings`", "text/plain; charset=utf-8");
-            return;
+            // TODO @ngxson : should return an error here
+            prompts.push_back("");
         }
 
-        const json body = json::parse(req.body);
-
-        json prompt;
-        if (body.count("input") != 0) {
-            prompt = body["input"];
-            if (prompt.is_array()) {
-                json data = json::array();
-
-                int i = 0;
-                for (const json & elem : prompt) {
-                    const int id_task = ctx_server.queue_tasks.get_new_id();
-
-                    ctx_server.queue_results.add_waiting_task_id(id_task);
-                    ctx_server.request_completion(id_task, -1, { {"prompt", elem}, { "n_predict", 0} }, false, true);
-
-                    // get the result
-                    server_task_result result = ctx_server.queue_results.recv(id_task);
-                    ctx_server.queue_results.remove_waiting_task_id(id_task);
-
-                    json embedding = json{
-                        {"embedding", json_value(result.data, "embedding", json::array())},
-                        {"index",     i++},
-                        {"object",    "embedding"}
-                    };
+        // process all prompts
+        json responses = json::array();
+        for (auto & prompt : prompts) {
+            // TODO @ngxson : maybe support multitask for this endpoint?
+            // create and queue the task
+            const int id_task = ctx_server.queue_tasks.get_new_id();
 
-                    data.push_back(embedding);
-                }
-
-                json result = format_embeddings_response_oaicompat(body, data);
+            ctx_server.queue_results.add_waiting_task_id(id_task);
+            ctx_server.request_completion(id_task, -1, { {"prompt", prompt}, { "n_predict", 0}}, false, true);
 
-                return res.set_content(result.dump(), "application/json; charset=utf-8");
+            // get the result
+            server_task_result result = ctx_server.queue_results.recv(id_task);
+            ctx_server.queue_results.remove_waiting_task_id(id_task);
+            responses.push_back(result.data);
+        }
+
+        // write JSON response
+        json root;
+        if (is_openai) {
+            json res_oai = json::array();
+            int i = 0;
+            for (auto & elem : responses) {
+                res_oai.push_back(json{
+                    {"embedding", json_value(elem, "embedding", json::array())},
+                    {"index",     i++},
+                    {"object",    "embedding"}
+                });
             }
+            root = format_embeddings_response_oaicompat(body, res_oai);
         } else {
-            prompt = "";
+            root = responses[0];
         }
+        return res.set_content(root.dump(), "application/json; charset=utf-8");
+    };
 
-        // create and queue the task
-        const int id_task = ctx_server.queue_tasks.get_new_id();
-
-        ctx_server.queue_results.add_waiting_task_id(id_task);
-        ctx_server.request_completion(id_task, -1, { {"prompt", prompt}, { "n_predict", 0}}, false, true);
+    //
+    // Router
+    //
 
-        // get the result
-        server_task_result result = ctx_server.queue_results.recv(id_task);
-        ctx_server.queue_results.remove_waiting_task_id(id_task);
-
-        json data = json::array({json{
-            {"embedding", json_value(result.data, "embedding", json::array())},
-            {"index",     0},
-            {"object",    "embedding"}
-        }}
-        );
+    // register static assets routes
+    if (!sparams.public_path.empty()) {
+        // Set the base directory for serving static files
+        svr->set_base_dir(sparams.public_path);
+    }
 
-        json root = format_embeddings_response_oaicompat(body, data);
+    // using embedded static files
+    auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
+        return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
+            res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
+            return false;
+        };
+    };
 
-        return res.set_content(root.dump(), "application/json; charset=utf-8");
+    svr->Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) {
+        // TODO @ngxson : I have no idea what it is... maybe this is redundant?
+        return res.set_content("", "application/json; charset=utf-8");
     });
+    svr->Get("/", handle_static_file(index_html, index_html_len, "text/html; charset=utf-8"));
+    svr->Get("/index.js", handle_static_file(index_js, index_js_len, "text/javascript; charset=utf-8"));
+    svr->Get("/completion.js", handle_static_file(completion_js, completion_js_len, "text/javascript; charset=utf-8"));
+    svr->Get("/json-schema-to-grammar.mjs", handle_static_file(
+        json_schema_to_grammar_mjs, json_schema_to_grammar_mjs_len, "text/javascript; charset=utf-8"));
+
+    // register API routes
+    svr->Get ("/health",              handle_health);
+    svr->Get ("/slots",               handle_slots);
+    svr->Get ("/metrics",             handle_metrics);
+    svr->Get ("/props",               handle_props);
+    svr->Get ("/v1/models",           handle_models);
+    svr->Post("/completion",          handle_completions); // legacy
+    svr->Post("/completions",         handle_completions);
+    svr->Post("/v1/completions",      handle_completions);
+    svr->Post("/chat/completions",    handle_chat_completions);
+    svr->Post("/v1/chat/completions", handle_chat_completions);
+    svr->Post("/infill",              handle_infill);
+    svr->Post("/embedding",           handle_embeddings); // legacy
+    svr->Post("/embeddings",          handle_embeddings);
+    svr->Post("/v1/embeddings",       handle_embeddings);
+    svr->Post("/tokenize",            handle_tokenize);
+    svr->Post("/detokenize",          handle_detokenize);
 
+    //
+    // Start the server
+    //
     if (sparams.n_threads_http < 1) {
         // +2 threads for monitoring endpoints
         sparams.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);
index 42a6709a53380e1c54b95601ba3a9f7daea31d1d..1d6aa40ea69856aa7623a698c47f8e55454a102f 100644 (file)
@@ -39,8 +39,9 @@ Feature: Security
 
 
   Scenario Outline: CORS Options
-    When an OPTIONS request is sent from <origin>
-    Then CORS header <cors_header> is set to <cors_header_value>
+    Given a user api key llama.cpp
+    When  an OPTIONS request is sent from <origin>
+    Then  CORS header <cors_header> is set to <cors_header_value>
 
     Examples: Headers
       | origin          | cors_header                      | cors_header_value |
index 0076f805be4d32696335b06e7a84a127e677cb47..14204850960c9a1ad178a0735c285aeb1da28b81 100644 (file)
@@ -582,8 +582,9 @@ async def step_detokenize(context):
 @async_run_until_complete
 async def step_options_request(context, origin):
     async with aiohttp.ClientSession() as session:
+        headers = {'Authorization': f'Bearer {context.user_api_key}', 'Origin': origin}
         async with session.options(f'{context.base_url}/v1/chat/completions',
-                                   headers={"Origin": origin}) as response:
+                                    headers=headers) as response:
             assert response.status == 200
             context.options_response = response