]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : refactor middleware and /health endpoint (#9056)
authorXuan Son Nguyen <redacted>
Fri, 16 Aug 2024 15:19:05 +0000 (17:19 +0200)
committerGitHub <redacted>
Fri, 16 Aug 2024 15:19:05 +0000 (17:19 +0200)
* server : refactor middleware and /health endpoint

* move "fail_on_no_slot" to /slots

* Update examples/server/server.cpp

Co-authored-by: Georgi Gerganov <redacted>
* fix server tests

* fix CI

* update server docs

---------

Co-authored-by: Georgi Gerganov <redacted>
examples/server/README.md
examples/server/server.cpp
examples/server/tests/features/steps/steps.py

index e17595fe87f2549eede6114ee0a9623b25fdc627..930ae15f64d8b685cc66aaf5b8160e76f5dd3279 100644 (file)
@@ -368,15 +368,16 @@ node index.js
 
 ## API Endpoints
 
-### GET `/health`: Returns the current state of the server
+### GET `/health`: Returns heath check result
 
-  - 503 -> `{"status": "loading model"}` if the model is still being loaded.
-  - 500 -> `{"status": "error"}` if the model failed to load.
-  - 200 -> `{"status": "ok", "slots_idle": 1, "slots_processing": 2 }` if the model is successfully loaded and the server is ready for further requests mentioned below.
-  - 200 -> `{"status": "no slot available", "slots_idle": 0, "slots_processing": 32}` if no slots are currently available.
-  - 503 -> `{"status": "no slot available", "slots_idle": 0, "slots_processing": 32}` if the query parameter `fail_on_no_slot` is provided and no slots are currently available.
+**Response format**
 
-  If the query parameter `include_slots` is passed, `slots` field will contain internal slots data except if `--slots-endpoint-disable` is set.
+- HTTP status code 503
+  - Body: `{"error": {"code": 503, "message": "Loading model", "type": "unavailable_error"}}`
+  - Explanation: the model is still being loaded.
+- HTTP status code 200
+  - Body: `{"status": "ok" }`
+  - Explanation: the model is successfully loaded and the server is ready.
 
 ### POST `/completion`: Given a `prompt`, it returns the predicted completion.
 
@@ -639,10 +640,16 @@ Given a ChatML-formatted json description in `messages`, it returns the predicte
     }'
     ```
 
-### GET `/slots`: Returns the current slots processing state. Can be disabled with `--slots-endpoint-disable`.
+### GET `/slots`: Returns the current slots processing state
+
+This endpoint can be disabled with `--no-slots`
+
+If query param `?fail_on_no_slot=1` is set, this endpoint will respond with status code 503 if there is no available slots.
 
 **Response format**
 
+Example:
+
 ```json
 [
     {
@@ -702,7 +709,13 @@ Given a ChatML-formatted json description in `messages`, it returns the predicte
 ]
 ```
 
-### GET `/metrics`: Prometheus compatible metrics exporter endpoint if `--metrics` is enabled:
+Possible values for `slot[i].state` are:
+- `0`: SLOT_STATE_IDLE
+- `1`: SLOT_STATE_PROCESSING
+
+### GET `/metrics`: Prometheus compatible metrics exporter
+
+This endpoint is only accessible if `--metrics` is set.
 
 Available metrics:
 - `llamacpp:prompt_tokens_total`: Number of prompt tokens processed.
@@ -767,6 +780,10 @@ Available metrics:
 
 ### GET `/lora-adapters`: Get list of all LoRA adapters
 
+This endpoint returns the loaded LoRA adapters. You can add adapters using `--lora` when starting the server, for example: `--lora my_adapter_1.gguf --lora my_adapter_2.gguf ...`
+
+By default, all adapters will be loaded with scale set to 1. To initialize all adapters scale to 0, add `--lora-init-without-apply`
+
 If an adapter is disabled, the scale will be set to 0.
 
 **Response format**
index e073f5813d459f0a5f60548fde0e0871bc4072f9..ce711eadd29acfb678cac35c96a8a6ff23b962d9 100644 (file)
@@ -15,6 +15,8 @@
 // Change JSON_ASSERT from assert() to GGML_ASSERT:
 #define JSON_ASSERT GGML_ASSERT
 #include "json.hpp"
+// mime type for sending response
+#define MIMETYPE_JSON "application/json; charset=utf-8"
 
 // auto generated files (update with ./deps.sh)
 #include "colorthemes.css.hpp"
@@ -67,7 +69,6 @@ enum slot_command {
 enum server_state {
     SERVER_STATE_LOADING_MODEL,  // Server is starting up, model not fully loaded yet
     SERVER_STATE_READY,          // Server is ready and model is loaded
-    SERVER_STATE_ERROR           // An error occurred, load_model failed
 };
 
 enum server_task_type {
@@ -695,6 +696,7 @@ struct server_context {
 
         add_bos_token = llama_add_bos_token(model);
         has_eos_token = !llama_add_eos_token(model);
+
         return true;
     }
 
@@ -2555,19 +2557,19 @@ int main(int argc, char ** argv) {
     svr->set_default_headers({{"Server", "llama.cpp"}});
 
     // CORS preflight
-    svr->Options(R"(.*)", [](const httplib::Request & req, httplib::Response & res) {
-        res.set_header("Access-Control-Allow-Origin",      req.get_header_value("Origin"));
+    svr->Options(R"(.*)", [](const httplib::Request &, httplib::Response & res) {
+        // Access-Control-Allow-Origin is already set by middleware
         res.set_header("Access-Control-Allow-Credentials", "true");
         res.set_header("Access-Control-Allow-Methods",     "POST");
         res.set_header("Access-Control-Allow-Headers",     "*");
-        return res.set_content("", "application/json; charset=utf-8");
+        return res.set_content("", "text/html"); // blank response, no data
     });
 
     svr->set_logger(log_server_request);
 
     auto res_error = [](httplib::Response & res, json error_data) {
         json final_response {{"error", error_data}};
-        res.set_content(final_response.dump(), "application/json; charset=utf-8");
+        res.set_content(final_response.dump(), MIMETYPE_JSON);
         res.status = json_value(error_data, "code", 500);
     };
 
@@ -2597,11 +2599,6 @@ int main(int argc, char ** argv) {
     svr->set_read_timeout (params.timeout_read);
     svr->set_write_timeout(params.timeout_write);
 
-    if (!svr->bind_to_port(params.hostname, params.port)) {
-        fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", params.hostname.c_str(), params.port);
-        return 1;
-    }
-
     std::unordered_map<std::string, std::string> log_data;
 
     log_data["hostname"] = params.hostname;
@@ -2617,35 +2614,6 @@ int main(int argc, char ** argv) {
     // Necessary similarity of prompt for slot selection
     ctx_server.slot_prompt_similarity = params.slot_prompt_similarity;
 
-    // load the model
-    if (!ctx_server.load_model(params)) {
-        state.store(SERVER_STATE_ERROR);
-        return 1;
-    } else {
-        ctx_server.init();
-        state.store(SERVER_STATE_READY);
-    }
-
-    LOG_INFO("model loaded", {});
-
-    const auto model_meta = ctx_server.model_meta();
-
-    // if a custom chat template is not supplied, we will use the one that comes with the model (if any)
-    if (params.chat_template.empty()) {
-        if (!ctx_server.validate_model_chat_template()) {
-            LOG_WARNING("The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
-            params.chat_template = "chatml";
-        }
-    }
-
-    // print sample chat example to make it clear which template is used
-    {
-        LOG_INFO("chat template", {
-            {"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)},
-            {"built_in",     params.chat_template.empty()},
-        });
-    }
-
     //
     // Middlewares
     //
@@ -2689,8 +2657,6 @@ int main(int argc, char ** argv) {
         }
 
         // API key is invalid or not provided
-        // TODO: make another middleware for CORS related logic
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         res_error(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION));
 
         LOG_WARNING("Unauthorized: Invalid API Key", {});
@@ -2698,8 +2664,21 @@ int main(int argc, char ** argv) {
         return false;
     };
 
+    auto middleware_server_state = [&res_error, &state](const httplib::Request &, httplib::Response & res) {
+        server_state current_state = state.load();
+        if (current_state == SERVER_STATE_LOADING_MODEL) {
+            res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
+            return false;
+        }
+        return true;
+    };
+
     // register server middlewares
-    svr->set_pre_routing_handler([&middleware_validate_api_key](const httplib::Request & req, httplib::Response & res) {
+    svr->set_pre_routing_handler([&middleware_validate_api_key, &middleware_server_state](const httplib::Request & req, httplib::Response & res) {
+        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+        if (!middleware_server_state(req, res)) {
+            return httplib::Server::HandlerResponse::Handled;
+        }
         if (!middleware_validate_api_key(req, res)) {
             return httplib::Server::HandlerResponse::Handled;
         }
@@ -2710,62 +2689,15 @@ int main(int argc, char ** argv) {
     // Route handlers (or controllers)
     //
 
-    const auto handle_health = [&](const httplib::Request & req, httplib::Response & res) {
-        server_state current_state = state.load();
-        switch (current_state) {
-            case SERVER_STATE_READY:
-                {
-                    // request slots data using task queue
-                    server_task task;
-                    task.id   = ctx_server.queue_tasks.get_new_id();
-                    task.type = SERVER_TASK_TYPE_METRICS;
-                    task.id_target = -1;
-
-                    ctx_server.queue_results.add_waiting_task_id(task.id);
-                    ctx_server.queue_tasks.post(task);
-
-                    // get the result
-                    server_task_result result = ctx_server.queue_results.recv(task.id);
-                    ctx_server.queue_results.remove_waiting_task_id(task.id);
-
-                    const int n_idle_slots       = result.data.at("idle");
-                    const int n_processing_slots = result.data.at("processing");
-
-                    json health = {
-                        {"status",           "ok"},
-                        {"slots_idle",       n_idle_slots},
-                        {"slots_processing", n_processing_slots}
-                    };
-
-                    res.status = 200; // HTTP OK
-                    if (params.endpoint_slots && req.has_param("include_slots")) {
-                        health["slots"] = result.data.at("slots");
-                    }
-
-                    if (n_idle_slots == 0) {
-                        health["status"] = "no slot available";
-                        if (req.has_param("fail_on_no_slot")) {
-                            res.status = 503; // HTTP Service Unavailable
-                        }
-                    }
-
-                    res.set_content(health.dump(), "application/json");
-                    break;
-                }
-            case SERVER_STATE_LOADING_MODEL:
-                {
-                    res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
-                } break;
-            case SERVER_STATE_ERROR:
-                {
-                    res_error(res, format_error_response("Model failed to load", ERROR_TYPE_SERVER));
-                } break;
-        }
+    const auto handle_health = [&](const httplib::Request &, httplib::Response & res) {
+        // error and loading states are handled by middleware
+        json health = {{"status", "ok"}};
+        res.set_content(health.dump(), "application/json");
     };
 
-    const auto handle_slots = [&](const httplib::Request &, httplib::Response & res) {
+    const auto handle_slots = [&](const httplib::Request & req, httplib::Response & res) {
         if (!params.endpoint_slots) {
-            res_error(res, format_error_response("This server does not support slots endpoint.", ERROR_TYPE_NOT_SUPPORTED));
+            res_error(res, format_error_response("This server does not support slots endpoint. Start it without `--no-slots`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
 
@@ -2783,13 +2715,22 @@ int main(int argc, char ** argv) {
         server_task_result result = ctx_server.queue_results.recv(task.id);
         ctx_server.queue_results.remove_waiting_task_id(task.id);
 
-        res.set_content(result.data.at("slots").dump(), "application/json");
+        // optionally return "fail_on_no_slot" error
+        const int n_idle_slots = result.data.at("idle");
+        if (req.has_param("fail_on_no_slot")) {
+            if (n_idle_slots == 0) {
+                res_error(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE));
+                return;
+            }
+        }
+
+        res.set_content(result.data.at("slots").dump(), MIMETYPE_JSON);
         res.status = 200; // HTTP OK
     };
 
     const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
         if (!params.endpoint_metrics) {
-            res_error(res, format_error_response("This server does not support metrics endpoint.", ERROR_TYPE_NOT_SUPPORTED));
+            res_error(res, format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
 
@@ -2914,7 +2855,7 @@ int main(int argc, char ** argv) {
         if (result.error) {
             res_error(res, result.data);
         } else {
-            res.set_content(result.data.dump(), "application/json");
+            res.set_content(result.data.dump(), MIMETYPE_JSON);
         }
     };
 
@@ -2944,7 +2885,7 @@ int main(int argc, char ** argv) {
         if (result.error) {
             res_error(res, result.data);
         } else {
-            res.set_content(result.data.dump(), "application/json");
+            res.set_content(result.data.dump(), MIMETYPE_JSON);
         }
     };
 
@@ -2964,13 +2905,11 @@ int main(int argc, char ** argv) {
         if (result.error) {
             res_error(res, result.data);
         } else {
-            res.set_content(result.data.dump(), "application/json");
+            res.set_content(result.data.dump(), MIMETYPE_JSON);
         }
     };
 
     const auto handle_slots_action = [&res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-
         std::string id_slot_str = req.path_params.at("id_slot");
         int id_slot;
 
@@ -2994,7 +2933,7 @@ int main(int argc, char ** argv) {
         }
     };
 
-    const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_props = [&ctx_server](const httplib::Request &, httplib::Response & res) {
         std::string template_key = "tokenizer.chat_template", curr_tmpl;
         int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0);
         if (tlen > 0) {
@@ -3003,7 +2942,6 @@ int main(int argc, char ** argv) {
                 curr_tmpl = std::string(curr_tmpl_buf.data(), tlen);
             }
         }
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         json data = {
             { "system_prompt",               ctx_server.system_prompt.c_str() },
             { "default_generation_settings", ctx_server.default_generation_settings_for_props },
@@ -3011,7 +2949,7 @@ int main(int argc, char ** argv) {
             { "chat_template",               curr_tmpl.c_str() }
         };
 
-        res.set_content(data.dump(), "application/json; charset=utf-8");
+        res.set_content(data.dump(), MIMETYPE_JSON);
     };
 
     const auto handle_completions = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
@@ -3020,8 +2958,6 @@ int main(int argc, char ** argv) {
             return;
         }
 
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-
         json data = json::parse(req.body);
 
         const int id_task = ctx_server.queue_tasks.get_new_id();
@@ -3032,7 +2968,7 @@ int main(int argc, char ** argv) {
         if (!json_value(data, "stream", false)) {
             server_task_result result = ctx_server.queue_results.recv(id_task);
             if (!result.error && result.stop) {
-                res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
+                res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
             } else {
                 res_error(res, result.data);
             }
@@ -3095,9 +3031,7 @@ int main(int argc, char ** argv) {
         }
     };
 
-    const auto handle_models = [&params, &model_meta](const httplib::Request & req, httplib::Response & res) {
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-
+    const auto handle_models = [&params, &ctx_server](const httplib::Request &, httplib::Response & res) {
         json models = {
             {"object", "list"},
             {"data", {
@@ -3106,12 +3040,12 @@ int main(int argc, char ** argv) {
                      {"object",   "model"},
                      {"created",  std::time(0)},
                      {"owned_by", "llamacpp"},
-                     {"meta",     model_meta}
+                     {"meta",     ctx_server.model_meta()}
                  },
              }}
         };
 
-        res.set_content(models.dump(), "application/json; charset=utf-8");
+        res.set_content(models.dump(), MIMETYPE_JSON);
     };
 
     const auto handle_chat_completions = [&ctx_server, &params, &res_error](const httplib::Request & req, httplib::Response & res) {
@@ -3119,8 +3053,6 @@ int main(int argc, char ** argv) {
             res_error(res, format_error_response("This server does not support chat completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
-
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
 
         const int id_task = ctx_server.queue_tasks.get_new_id();
@@ -3135,7 +3067,7 @@ int main(int argc, char ** argv) {
             if (!result.error && result.stop) {
                 json result_oai = format_final_response_oaicompat(data, result.data, completion_id);
 
-                res.set_content(result_oai.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
+                res.set_content(result_oai.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
             } else {
                 res_error(res, result.data);
             }
@@ -3197,8 +3129,6 @@ int main(int argc, char ** argv) {
             return;
         }
 
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-
         json data = json::parse(req.body);
 
         const int id_task = ctx_server.queue_tasks.get_new_id();
@@ -3209,7 +3139,7 @@ int main(int argc, char ** argv) {
         if (!json_value(data, "stream", false)) {
             server_task_result result = ctx_server.queue_results.recv(id_task);
             if (!result.error && result.stop) {
-                res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
+                res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
             } else {
                 res_error(res, result.data);
             }
@@ -3257,7 +3187,6 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         const json body = json::parse(req.body);
 
         std::vector<llama_token> tokens;
@@ -3266,11 +3195,10 @@ int main(int argc, char ** argv) {
             tokens = ctx_server.tokenize(body.at("content"), add_special);
         }
         const json data = format_tokenizer_response(tokens);
-        return res.set_content(data.dump(), "application/json; charset=utf-8");
+        return res.set_content(data.dump(), MIMETYPE_JSON);
     };
 
     const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         const json body = json::parse(req.body);
 
         std::string content;
@@ -3280,12 +3208,10 @@ int main(int argc, char ** argv) {
         }
 
         const json data = format_detokenized_response(content);
-        return res.set_content(data.dump(), "application/json; charset=utf-8");
+        return res.set_content(data.dump(), MIMETYPE_JSON);
     };
 
     const auto handle_embeddings = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-
         const json body = json::parse(req.body);
         bool is_openai = false;
 
@@ -3331,11 +3257,10 @@ int main(int argc, char ** argv) {
         json root = is_openai
             ? format_embeddings_response_oaicompat(body, responses)
             : responses[0];
-        return res.set_content(root.dump(), "application/json; charset=utf-8");
+        return res.set_content(root.dump(), MIMETYPE_JSON);
     };
 
-    const auto handle_lora_adapters_list = [&](const httplib::Request & req, httplib::Response & res) {
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+    const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
         json result = json::array();
         for (size_t i = 0; i < ctx_server.lora_adapters.size(); ++i) {
             auto & la = ctx_server.lora_adapters[i];
@@ -3345,13 +3270,11 @@ int main(int argc, char ** argv) {
                 {"scale", la.scale},
             });
         }
-        res.set_content(result.dump(), "application/json");
+        res.set_content(result.dump(), MIMETYPE_JSON);
         res.status = 200; // HTTP OK
     };
 
     const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-
         const std::vector<json> body = json::parse(req.body);
         int max_idx = ctx_server.lora_adapters.size();
 
@@ -3379,7 +3302,7 @@ int main(int argc, char ** argv) {
         server_task_result result = ctx_server.queue_results.recv(id_task);
         ctx_server.queue_results.remove_waiting_task_id(id_task);
 
-        res.set_content(result.data.dump(), "application/json");
+        res.set_content(result.data.dump(), MIMETYPE_JSON);
         res.status = 200; // HTTP OK
     };
 
@@ -3455,35 +3378,75 @@ int main(int argc, char ** argv) {
     log_data["n_threads_http"] =  std::to_string(params.n_threads_http);
     svr->new_task_queue = [&params] { return new httplib::ThreadPool(params.n_threads_http); };
 
-    LOG_INFO("HTTP server listening", log_data);
+    // clean up function, to be called before exit
+    auto clean_up = [&svr]() {
+        svr->stop();
+        llama_backend_free();
+    };
 
-    // run the HTTP server in a thread - see comment below
-    std::thread t([&]() {
-        if (!svr->listen_after_bind()) {
-            state.store(SERVER_STATE_ERROR);
-            return 1;
+    // bind HTTP listen port, run the HTTP server in a thread
+    if (!svr->bind_to_port(params.hostname, params.port)) {
+        LOG_ERROR("couldn't bind HTTP server socket", {
+            {"hostname", params.hostname},
+            {"port", params.port},
+        });
+        clean_up();
+        LOG_ERROR("exiting due to HTTP server error", {});
+        return 1;
+    }
+    std::thread t([&]() { svr->listen_after_bind(); });
+    svr->wait_until_ready();
+
+    LOG_INFO("HTTP server is listening", log_data);
+
+    // load the model
+    LOG_INFO("loading model", log_data);
+    if (!ctx_server.load_model(params)) {
+        clean_up();
+        t.join();
+        LOG_ERROR("exiting due to model loading error", {});
+        return 1;
+    } else {
+        ctx_server.init();
+        state.store(SERVER_STATE_READY);
+
+        LOG_INFO("model loaded", {});
+
+        // if a custom chat template is not supplied, we will use the one that comes with the model (if any)
+        if (params.chat_template.empty()) {
+            if (!ctx_server.validate_model_chat_template()) {
+                LOG_WARNING("The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
+                params.chat_template = "chatml";
+            }
         }
 
-        return 0;
-    });
+        // print sample chat example to make it clear which template is used
+        {
+            LOG_INFO("chat template", {
+                {"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)},
+                {"built_in",     params.chat_template.empty()},
+            });
+        }
 
-    ctx_server.queue_tasks.on_new_task(std::bind(
-        &server_context::process_single_task, &ctx_server, std::placeholders::_1));
-    ctx_server.queue_tasks.on_finish_multitask(std::bind(
-        &server_context::on_finish_multitask, &ctx_server, std::placeholders::_1));
-    ctx_server.queue_tasks.on_update_slots(std::bind(
-        &server_context::update_slots, &ctx_server));
-    ctx_server.queue_results.on_multitask_update(std::bind(
-        &server_queue::update_multitask,
-        &ctx_server.queue_tasks,
-        std::placeholders::_1,
-        std::placeholders::_2,
-        std::placeholders::_3
-    ));
-
-    shutdown_handler = [&](int) {
-        ctx_server.queue_tasks.terminate();
-    };
+        ctx_server.queue_tasks.on_new_task(std::bind(
+            &server_context::process_single_task, &ctx_server, std::placeholders::_1));
+        ctx_server.queue_tasks.on_finish_multitask(std::bind(
+            &server_context::on_finish_multitask, &ctx_server, std::placeholders::_1));
+        ctx_server.queue_tasks.on_update_slots(std::bind(
+            &server_context::update_slots, &ctx_server));
+        ctx_server.queue_results.on_multitask_update(std::bind(
+            &server_queue::update_multitask,
+            &ctx_server.queue_tasks,
+            std::placeholders::_1,
+            std::placeholders::_2,
+            std::placeholders::_3
+        ));
+
+        shutdown_handler = [&](int) {
+            ctx_server.queue_tasks.terminate();
+        };
+        ctx_server.queue_tasks.start_loop();
+    }
 
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
     struct sigaction sigint_action;
@@ -3499,12 +3462,8 @@ int main(int argc, char ** argv) {
     SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
 #endif
 
-    ctx_server.queue_tasks.start_loop();
-
-    svr->stop();
+    clean_up();
     t.join();
 
-    llama_backend_free();
-
     return 0;
 }
index 6705a34fc469650b64f3a4b911af9adc355db644..1ba7b60b69c46f79dcb88df6f12e3556c9887895 100644 (file)
@@ -205,27 +205,20 @@ def step_start_server(context):
 async def step_wait_for_the_server_to_be_started(context, expecting_status: Literal['healthy', 'ready', 'idle', 'busy'] | str):
     match expecting_status:
         case 'healthy':
-            await wait_for_health_status(context, context.base_url, 200, 'ok',
-                                         timeout=30)
+            await wait_for_slots_status(context, context.base_url, 200,
+                                        timeout=30)
 
         case 'ready' | 'idle':
-            await wait_for_health_status(context, context.base_url, 200, 'ok',
-                                         timeout=30,
-                                         params={'fail_on_no_slot': 0, 'include_slots': 0},
-                                         slots_idle=context.n_slots,
-                                         slots_processing=0,
-                                         expected_slots=[{'id': slot_id, 'state': 0}
-                                                         for slot_id in
-                                                         range(context.n_slots if context.n_slots else 1)])
+            await wait_for_slots_status(context, context.base_url, 200,
+                                        timeout=30,
+                                        params={'fail_on_no_slot': 1},
+                                        slots_idle=context.n_slots,
+                                        slots_processing=0)
         case 'busy':
-            await wait_for_health_status(context, context.base_url, 503,
-                                         'no slot available',
-                                         params={'fail_on_no_slot': 0, 'include_slots': 0},
-                                         slots_idle=0,
-                                         slots_processing=context.n_slots,
-                                         expected_slots=[{'id': slot_id, 'state': 1}
-                                                         for slot_id in
-                                                         range(context.n_slots if context.n_slots else 1)])
+            await wait_for_slots_status(context, context.base_url, 503,
+                                        params={'fail_on_no_slot': 1},
+                                        slots_idle=0,
+                                        slots_processing=context.n_slots)
         case _:
             assert False, "unknown status"
 
@@ -1187,17 +1180,15 @@ async def gather_tasks_results(context):
     return n_completions
 
 
-async def wait_for_health_status(context,
-                                 base_url,
-                                 expected_http_status_code,
-                                 expected_health_status,
-                                 timeout=3,
-                                 params=None,
-                                 slots_idle=None,
-                                 slots_processing=None,
-                                 expected_slots=None):
+async def wait_for_slots_status(context,
+                                base_url,
+                                expected_http_status_code,
+                                timeout=3,
+                                params=None,
+                                slots_idle=None,
+                                slots_processing=None):
     if context.debug:
-        print(f"Starting checking for health for expected_health_status={expected_health_status}")
+        print(f"Starting checking for health for expected_http_status_code={expected_http_status_code}")
     interval = 0.5
     counter = 0
     if 'GITHUB_ACTIONS' in os.environ:
@@ -1205,26 +1196,19 @@ async def wait_for_health_status(context,
 
     async with aiohttp.ClientSession() as session:
         while True:
-            async with await session.get(f'{base_url}/health', params=params) as health_response:
-                status_code = health_response.status
-                health = await health_response.json()
+            async with await session.get(f'{base_url}/slots', params=params) as slots_response:
+                status_code = slots_response.status
+                slots = await slots_response.json()
                 if context.debug:
-                    print(f"HEALTH - response for expected health status='{expected_health_status}' on "
-                          f"'{base_url}/health'?{params} is {health}\n")
-                if (status_code == expected_http_status_code
-                        and health['status'] == expected_health_status
-                        and (slots_idle is None or health['slots_idle'] == slots_idle)
-                        and (slots_processing is None or health['slots_processing'] == slots_processing)):
-                    if expected_slots is not None:
-                        assert_slots_status(health['slots'], expected_slots)
-                    return
-                if (status_code == expected_http_status_code
-                        and health['status'] == expected_health_status
-                        and (slots_idle is None or health['slots_idle'] == slots_idle)
-                        and (slots_processing is None or health['slots_processing'] == slots_processing)):
-                    if expected_slots is not None:
-                        assert_slots_status(health['slots'], expected_slots)
+                    print(f"slots responses {slots}\n")
+                if status_code == 503 and status_code == expected_http_status_code:
                     return
+                if status_code == 200 and status_code == expected_http_status_code:
+                    n_slots_idle = sum(1 if slot["state"] == 0 else 0 for slot in slots)
+                    n_slots_processing = sum(1 if slot["state"] != 0 else 0 for slot in slots)
+                    if ((slots_idle is None or slots_idle == n_slots_idle)
+                        and (slots_processing is None or slots_processing == n_slots_processing)):
+                        return
             await asyncio.sleep(interval)
 
             counter += interval
@@ -1238,7 +1222,7 @@ async def wait_for_health_status(context,
                         if n_completions > 0:
                             return
 
-                assert False, f'{expected_health_status} timeout exceeded {counter}s>={timeout}'
+                assert False, f'slots check timeout exceeded {counter}s>={timeout}'
 
 
 def assert_embeddings(embeddings):