server : replace sleep with condition variables (#4673)
author      Justine Tunney <redacted>
Fri, 29 Dec 2023 14:24:12 +0000 (06:24 -0800)
committer   GitHub <redacted>
Fri, 29 Dec 2023 14:24:12 +0000 (16:24 +0200)
The server currently schedules tasks using a sleep(5ms) busy loop. This
adds unnecessary latency, since most sleep implementations round up to
the system scheduling quantum (usually 10ms). Other libc sleep
implementations spin for smaller intervals, which makes the server's
busy loop consume all available CPU. Having explicit notify() / wait()
code also aids the readability of the server code.

See mozilla-Ocho/llamafile@711344b
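
For illustration only (not part of the patch): a minimal sketch of the
wait/notify pattern the change adopts, using hypothetical names. The
producer enqueues work under the mutex and notifies; the consumer blocks
in wait() with a predicate instead of polling with sleep().

    #include <condition_variable>
    #include <deque>
    #include <mutex>

    // Hypothetical queue illustrating the pattern; not the server's types.
    struct task_queue {
        std::mutex              mtx;
        std::condition_variable cv;
        std::deque<int>         tasks;

        // producer: enqueue under the lock, then wake one waiter
        void push(int task) {
            std::unique_lock<std::mutex> lock(mtx);
            tasks.push_back(task);
            cv.notify_one();
        }

        // consumer: block until the predicate holds; no sleep loop, and
        // the predicate form is immune to spurious wakeups
        int pop() {
            std::unique_lock<std::mutex> lock(mtx);
            cv.wait(lock, [&]{ return !tasks.empty(); });
            int task = tasks.front();
            tasks.pop_front();
            return task;
        }
    };

In the patch itself, condition_tasks uses notify_one() because a single
loop drains the task queue, while condition_results uses notify_all()
because several request threads may be waiting on results with
different task ids.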

examples/server/server.cpp

index 035eb24ac69324c5d4da4adc5fba3c3caac920e8..0aada8e28029cd9c704b521bee5e0584992f6d47 100644
@@ -25,6 +25,7 @@
 #include <thread>
 #include <mutex>
 #include <chrono>
+#include <condition_variable>
 
 #ifndef SERVER_VERBOSE
 #define SERVER_VERBOSE 1
@@ -541,7 +542,9 @@ struct llama_server_context
     std::vector<task_result> queue_results;
     std::vector<task_multi>  queue_multitasks;
     std::mutex mutex_tasks; // also guards id_gen, and queue_multitasks
+    std::condition_variable condition_tasks;
     std::mutex mutex_results;
+    std::condition_variable condition_results;
 
     ~llama_server_context()
     {
@@ -1169,7 +1172,7 @@ struct llama_server_context
 
     void send_error(task_server& task, std::string error)
     {
-        std::lock_guard<std::mutex> lock(mutex_results);
+        std::unique_lock<std::mutex> lock(mutex_results);
         task_result res;
         res.id = task.id;
         res.multitask_id = task.multitask_id;
@@ -1177,6 +1180,7 @@ struct llama_server_context
         res.error = true;
         res.result_json = { { "content", error } };
         queue_results.push_back(res);
+        condition_results.notify_all();
     }
 
     void add_multi_task(int id, std::vector<int>& sub_ids)
@@ -1186,6 +1190,7 @@ struct llama_server_context
         multi.id = id;
         std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
         queue_multitasks.push_back(multi);
+        condition_tasks.notify_one();
     }
 
     void update_multi_task(int multitask_id, int subtask_id, task_result& result)
@@ -1197,6 +1202,7 @@ struct llama_server_context
             {
                 multitask.subtasks_remaining.erase(subtask_id);
                 multitask.results.push_back(result);
+                condition_tasks.notify_one();
             }
         }
     }
@@ -1244,7 +1250,7 @@ struct llama_server_context
 
     void send_partial_response(llama_client_slot &slot, completion_token_output tkn)
     {
-        std::lock_guard<std::mutex> lock(mutex_results);
+        std::unique_lock<std::mutex> lock(mutex_results);
         task_result res;
         res.id = slot.task_id;
         res.multitask_id = slot.multitask_id;
@@ -1280,11 +1286,12 @@ struct llama_server_context
         }
 
         queue_results.push_back(res);
+        condition_results.notify_all();
     }
 
     void send_final_response(llama_client_slot &slot)
     {
-        std::lock_guard<std::mutex> lock(mutex_results);
+        std::unique_lock<std::mutex> lock(mutex_results);
         task_result res;
         res.id = slot.task_id;
         res.multitask_id = slot.multitask_id;
@@ -1340,11 +1347,12 @@ struct llama_server_context
         }
 
         queue_results.push_back(res);
+        condition_results.notify_all();
     }
 
     void send_embedding(llama_client_slot &slot)
     {
-        std::lock_guard<std::mutex> lock(mutex_results);
+        std::unique_lock<std::mutex> lock(mutex_results);
         task_result res;
         res.id = slot.task_id;
         res.multitask_id = slot.multitask_id;
@@ -1372,6 +1380,7 @@ struct llama_server_context
             };
         }
         queue_results.push_back(res);
+        condition_results.notify_all();
     }
 
     int request_completion(json data, bool infill, bool embedding, int multitask_id)
@@ -1395,6 +1404,7 @@ struct llama_server_context
 
         // otherwise, it's a single-prompt task, we actually queue it
         queue_tasks.push_back(task);
+        condition_tasks.notify_one();
         return task.id;
     }
 
@@ -1402,13 +1412,10 @@ struct llama_server_context
     {
         while (true)
         {
-            std::this_thread::sleep_for(std::chrono::microseconds(5));
-            std::lock_guard<std::mutex> lock(mutex_results);
-
-            if (queue_results.empty())
-            {
-                continue;
-            }
+            std::unique_lock<std::mutex> lock(mutex_results);
+            condition_results.wait(lock, [&]{
+                return !queue_results.empty();
+            });
 
             for (int i = 0; i < (int) queue_results.size(); i++)
             {
@@ -1504,12 +1511,13 @@ struct llama_server_context
 
     void request_cancel(int task_id)
     {
-        std::lock_guard<std::mutex> lock(mutex_tasks);
+        std::unique_lock<std::mutex> lock(mutex_tasks);
         task_server task;
         task.id = id_gen++;
         task.type = CANCEL_TASK;
         task.target_id = task_id;
         queue_tasks.push_back(task);
+        condition_tasks.notify_one();
     }
 
     int split_multiprompt_task(task_server& multiprompt_task)
@@ -1535,7 +1543,7 @@ struct llama_server_context
 
     void process_tasks()
     {
-        std::lock_guard<std::mutex> lock(mutex_tasks);
+        std::unique_lock<std::mutex> lock(mutex_tasks);
         while (!queue_tasks.empty())
         {
             task_server task = queue_tasks.front();
@@ -1607,6 +1615,7 @@ struct llama_server_context
 
                 std::lock_guard<std::mutex> lock(mutex_results);
                 queue_results.push_back(aggregate_result);
+                condition_results.notify_all();
 
                 queue_iterator = queue_multitasks.erase(queue_iterator);
             }
@@ -1637,8 +1646,10 @@ struct llama_server_context
                 LOG_TEE("all slots are idle and system prompt is empty, clear the KV cache\n");
                 kv_cache_clear();
             }
-            // avoid 100% usage of cpu all time
-            std::this_thread::sleep_for(std::chrono::milliseconds(5));
+            std::unique_lock<std::mutex> lock(mutex_tasks);
+            condition_tasks.wait(lock, [&]{
+                return !queue_tasks.empty();
+            });
         }
 
         for (llama_client_slot &slot : slots)