]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
server : graceful shutdown, atomic server state, and health endpoint Improvements...
authorSacha Arbonel <redacted>
Mon, 16 Jun 2025 08:14:26 +0000 (10:14 +0200)
committerGitHub <redacted>
Mon, 16 Jun 2025 08:14:26 +0000 (10:14 +0200)
* feat(server): implement graceful shutdown and server state management

* refactor(server): use lambda capture by reference in server.cpp

examples/server/server.cpp

index df5088396927698ea6e8ede9b29e93f1c78dfe65..8b6c5a96720047fcc393ebe5ac09df91400f54d6 100644 (file)
 #include <string>
 #include <thread>
 #include <vector>
+#include <memory>
+#include <csignal>
+#include <atomic>
+#include <functional>
+#include <cstdlib>
+#if defined (_WIN32)
+#include <windows.h>
+#endif
 
 using namespace httplib;
 using json = nlohmann::ordered_json;
 
+enum server_state {
+    SERVER_STATE_LOADING_MODEL,  // Server is starting up, model not fully loaded yet
+    SERVER_STATE_READY,          // Server is ready and model is loaded
+};
+
 namespace {
 
 // output formats
@@ -27,6 +40,20 @@ const std::string srt_format    = "srt";
 const std::string vjson_format  = "verbose_json";
 const std::string vtt_format    = "vtt";
 
+std::function<void(int)> shutdown_handler;
+std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;
+
+inline void signal_handler(int signal) {
+    if (is_terminating.test_and_set()) {
+        // in case it hangs, we can force terminate the server by hitting Ctrl+C twice
+        // this is for better developer experience, we can remove when the server is stable enough
+        fprintf(stderr, "Received second interrupt, terminating immediately.\n");
+        exit(1);
+    }
+
+    shutdown_handler(signal);
+}
+
 struct server_params
 {
     std::string hostname = "127.0.0.1";
@@ -654,6 +681,9 @@ int main(int argc, char ** argv) {
         }
     }
 
+    std::unique_ptr<httplib::Server> svr = std::make_unique<httplib::Server>();
+    std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};
+
     struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
 
     if (ctx == nullptr) {
@@ -663,9 +693,10 @@ int main(int argc, char ** argv) {
 
     // initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
     whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
+    state.store(SERVER_STATE_READY);
+
 
-    Server svr;
-    svr.set_default_headers({{"Server", "whisper.cpp"},
+    svr->set_default_headers({{"Server", "whisper.cpp"},
                              {"Access-Control-Allow-Origin", "*"},
                              {"Access-Control-Allow-Headers", "content-type, authorization"}});
 
@@ -744,15 +775,15 @@ int main(int argc, char ** argv) {
     whisper_params default_params = params;
 
     // this is only called if no index.html is found in the public --path
-    svr.Get(sparams.request_path + "/", [&default_content](const Request &, Response &res){
+    svr->Get(sparams.request_path + "/", [&](const Request &, Response &res){
         res.set_content(default_content, "text/html");
         return false;
     });
 
-    svr.Options(sparams.request_path + sparams.inference_path, [&](const Request &, Response &){
+    svr->Options(sparams.request_path + sparams.inference_path, [&](const Request &, Response &){
     });
 
-    svr.Post(sparams.request_path + sparams.inference_path, [&](const Request &req, Response &res){
+    svr->Post(sparams.request_path + sparams.inference_path, [&](const Request &req, Response &res){
         // acquire whisper model mutex lock
         std::lock_guard<std::mutex> lock(whisper_mutex);
 
@@ -1068,8 +1099,9 @@ int main(int argc, char ** argv) {
         // reset params to their defaults
         params = default_params;
     });
-    svr.Post(sparams.request_path + "/load", [&](const Request &req, Response &res){
+    svr->Post(sparams.request_path + "/load", [&](const Request &req, Response &res){
         std::lock_guard<std::mutex> lock(whisper_mutex);
+        state.store(SERVER_STATE_LOADING_MODEL);
         if (!req.has_file("model"))
         {
             fprintf(stderr, "error: no 'model' field in the request\n");
@@ -1101,18 +1133,25 @@ int main(int argc, char ** argv) {
         // initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
         whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
 
+        state.store(SERVER_STATE_READY);
         const std::string success = "Load was successful!";
         res.set_content(success, "application/text");
 
         // check if the model is in the file system
     });
 
-    svr.Get(sparams.request_path + "/health", [&](const Request &, Response &res){
-        const std::string health_response = "{\"status\":\"ok\"}";
-        res.set_content(health_response, "application/json");
+    svr->Get(sparams.request_path + "/health", [&](const Request &, Response &res){
+        server_state current_state = state.load();
+        if (current_state == SERVER_STATE_READY) {
+            const std::string health_response = "{\"status\":\"ok\"}";
+            res.set_content(health_response, "application/json");
+        } else {
+            res.set_content("{\"status\":\"loading model\"}", "application/json");
+            res.status = 503;
+        }
     });
 
-    svr.set_exception_handler([](const Request &, Response &res, std::exception_ptr ep) {
+    svr->set_exception_handler([](const Request &, Response &res, std::exception_ptr ep) {
         const char fmt[] = "500 Internal Server Error\n%s";
         char buf[BUFSIZ];
         try {
@@ -1126,7 +1165,7 @@ int main(int argc, char ** argv) {
         res.status = 500;
     });
 
-    svr.set_error_handler([](const Request &req, Response &res) {
+    svr->set_error_handler([](const Request &req, Response &res) {
         if (res.status == 400) {
             res.set_content("Invalid request", "text/plain");
         } else if (res.status != 500) {
@@ -1136,10 +1175,10 @@ int main(int argc, char ** argv) {
     });
 
     // set timeouts and change hostname and port
-    svr.set_read_timeout(sparams.read_timeout);
-    svr.set_write_timeout(sparams.write_timeout);
+    svr->set_read_timeout(sparams.read_timeout);
+    svr->set_write_timeout(sparams.write_timeout);
 
-    if (!svr.bind_to_port(sparams.hostname, sparams.port))
+    if (!svr->bind_to_port(sparams.hostname, sparams.port))
     {
         fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n",
                 sparams.hostname.c_str(), sparams.port);
@@ -1147,18 +1186,50 @@ int main(int argc, char ** argv) {
     }
 
     // Set the base directory for serving static files
-    svr.set_base_dir(sparams.public_path);
+    svr->set_base_dir(sparams.public_path);
 
     // to make it ctrl+clickable:
     printf("\nwhisper server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port);
 
-    if (!svr.listen_after_bind())
-    {
-        return 1;
-    }
+    shutdown_handler = [&](int signal) {
+        printf("\nCaught signal %d, shutting down gracefully...\n", signal);
+        if (svr) {
+            svr->stop();
+        }
+    };
+
+#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
+    struct sigaction sigint_action;
+    sigint_action.sa_handler = signal_handler;
+    sigemptyset (&sigint_action.sa_mask);
+    sigint_action.sa_flags = 0;
+    sigaction(SIGINT, &sigint_action, NULL);
+    sigaction(SIGTERM, &sigint_action, NULL);
+#elif defined (_WIN32)
+    auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
+        return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
+    };
+    SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
+#endif
+
+    // clean up function, to be called before exit
+    auto clean_up = [&]() {
+        whisper_print_timings(ctx);
+        whisper_free(ctx);
+    };
+
+    std::thread t([&] {
+        if (!svr->listen_after_bind()) {
+            fprintf(stderr, "error: server listen failed\n");
+        }
+    });
+
+    svr->wait_until_ready();
+
+    t.join();
+
 
-    whisper_print_timings(ctx);
-    whisper_free(ctx);
+    clean_up();
 
     return 0;
 }