]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : clean-up completed tasks from waiting list (#9531)
authorGeorgi Gerganov <redacted>
Thu, 19 Sep 2024 09:44:53 +0000 (12:44 +0300)
committerGitHub <redacted>
Thu, 19 Sep 2024 09:44:53 +0000 (12:44 +0300)
ggml-ci

examples/server/server.cpp

index dce69f832e8bd1ea5915bd159c7598621b4e24d2..0ca9999940606d91f91bfb1353afc67630bdc8d2 100644 (file)
@@ -531,26 +531,38 @@ struct server_response {
 
     // add the id_task to the list of tasks waiting for response
     void add_waiting_task_id(int id_task) {
-        SRV_DBG("waiting for task id = %d\n", id_task);
+        SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size());
 
         std::unique_lock<std::mutex> lock(mutex_results);
         waiting_task_ids.insert(id_task);
     }
 
     void add_waiting_tasks(const std::vector<server_task> & tasks) {
-        for (const auto & t : tasks) {
-            add_waiting_task_id(t.id);
+        std::unique_lock<std::mutex> lock(mutex_results);
+
+        for (const auto & task : tasks) {
+            SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", task.id, (int) waiting_task_ids.size());
+            waiting_task_ids.insert(task.id);
         }
     }
 
     // when the request is finished, we can remove task associated with it
     void remove_waiting_task_id(int id_task) {
-        SRV_DBG("task id = %d is done\n", id_task);
+        SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
 
         std::unique_lock<std::mutex> lock(mutex_results);
         waiting_task_ids.erase(id_task);
     }
 
+    void remove_waiting_task_ids(const std::unordered_set<int> & id_tasks) {
+        std::unique_lock<std::mutex> lock(mutex_results);
+
+        for (const auto & id_task : id_tasks) {
+            SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
+            waiting_task_ids.erase(id_task);
+        }
+    }
+
     // This function blocks the thread until there is a response for one of the id_tasks
     server_task_result recv(const std::unordered_set<int> & id_tasks) {
         while (true) {
@@ -2774,6 +2786,8 @@ int main(int argc, char ** argv) {
             }, [&](const json & error_data) {
                 res_error(res, error_data);
             });
+
+            ctx_server.queue_results.remove_waiting_task_ids(task_ids);
         } else {
             const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) {
                 ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
@@ -2784,7 +2798,12 @@ int main(int argc, char ** argv) {
                 sink.done();
                 return false;
             };
-            res.set_chunked_content_provider("text/event-stream", chunked_content_provider);
+
+            auto on_complete = [task_ids, &ctx_server] (bool) {
+                ctx_server.queue_results.remove_waiting_task_ids(task_ids);
+            };
+
+            res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
         }
     };
 
@@ -2823,6 +2842,8 @@ int main(int argc, char ** argv) {
             }, [&](const json & error_data) {
                 res_error(res, error_data);
             });
+
+            ctx_server.queue_results.remove_waiting_task_ids(task_ids);
         } else {
             const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
                 ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
@@ -2844,7 +2865,12 @@ int main(int argc, char ** argv) {
                 sink.done();
                 return true;
             };
-            res.set_chunked_content_provider("text/event-stream", chunked_content_provider);
+
+            auto on_complete = [task_ids, &ctx_server] (bool) {
+                ctx_server.queue_results.remove_waiting_task_ids(task_ids);
+            };
+
+            res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
         }
     };
 
@@ -2953,6 +2979,8 @@ int main(int argc, char ** argv) {
                 res_error(res, error_data);
                 error = true;
             });
+
+            ctx_server.queue_results.remove_waiting_task_ids(task_ids);
         }
 
         if (error) {