]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : add single-client multi-prompt support (#4232)
authorZiad Ben Hadj-Alouane <redacted>
Thu, 30 Nov 2023 22:25:04 +0000 (17:25 -0500)
committerGitHub <redacted>
Thu, 30 Nov 2023 22:25:04 +0000 (00:25 +0200)
* * add multiprompt support

* * cleanup

* * more cleanup

* * remove atomicity of id_gen, and change lock_guard to unique_lock on completion requests

* * remove all references to mutex_multitasks

* Update examples/server/server.cpp

Co-authored-by: Jared Van Bortel <redacted>
* Update examples/server/server.cpp

Co-authored-by: Jared Van Bortel <redacted>
* Update examples/server/server.cpp

Co-authored-by: Jared Van Bortel <redacted>
* Update examples/server/server.cpp

Co-authored-by: Jared Van Bortel <redacted>
* * change to set

---------

Co-authored-by: Jared Van Bortel <redacted>
examples/server/server.cpp

index 50f124b13e8492336efa5e7bdf755f736f5505b1..5edb3678efe09283eacca729c152f601dc91fcd9 100644 (file)
@@ -155,15 +155,23 @@ struct task_server {
     json data;
     bool infill_mode = false;
     bool embedding_mode = false;
+    int multitask_id = -1;
 };
 
 struct task_result {
     int id;
+    int multitask_id = -1;
     bool stop;
     bool error;
     json result_json;
 };
 
+struct task_multi {
+    int id;
+    std::set<int> subtasks_remaining{};
+    std::vector<task_result> results{};
+};
+
 // TODO: can become bool if we can't find use of more states
 enum slot_state
 {
@@ -406,6 +414,9 @@ struct llama_client_slot
     double t_prompt_processing; // ms
     double t_token_generation; // ms
 
+    // multitasks
+    int multitask_id = -1;
+
     void reset() {
         num_prompt_tokens      = 0;
         generated_text         = "";
@@ -529,7 +540,8 @@ struct llama_server_context
 
     std::vector<task_server> queue_tasks;
     std::vector<task_result> queue_results;
-    std::mutex mutex_tasks;
+    std::vector<task_multi>  queue_multitasks;
+    std::mutex mutex_tasks; // also guards id_gen, and queue_multitasks
     std::mutex mutex_results;
 
     ~llama_server_context()
@@ -1112,17 +1124,40 @@ struct llama_server_context
         return slot.images.size() > 0;
     }
 
-    void send_error(int id, std::string error)
+    void send_error(task_server& task, std::string error)
     {
         std::lock_guard<std::mutex> lock(mutex_results);
         task_result res;
-        res.id = id;
+        res.id = task.id;
+        res.multitask_id = task.multitask_id;
         res.stop = false;
         res.error = true;
         res.result_json = { { "content", error } };
         queue_results.push_back(res);
     }
 
+    void add_multi_task(int id, std::vector<int>& sub_ids)
+    {
+        std::lock_guard<std::mutex> lock(mutex_tasks);
+        task_multi multi;
+        multi.id = id;
+        std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
+        queue_multitasks.push_back(multi);
+    }
+
+    void update_multi_task(int multitask_id, int subtask_id, task_result& result)
+    {
+        std::lock_guard<std::mutex> lock(mutex_tasks);
+        for (auto& multitask : queue_multitasks)
+        {
+            if (multitask.id == multitask_id)
+            {
+                multitask.subtasks_remaining.erase(subtask_id);
+                multitask.results.push_back(result);
+            }
+        }
+    }
+
     json get_model_props()
     {
         return get_formated_generation(slots[0]);
@@ -1167,6 +1202,7 @@ struct llama_server_context
         std::lock_guard<std::mutex> lock(mutex_results);
         task_result res;
         res.id = slot.task_id;
+        res.multitask_id = slot.multitask_id;
         res.error = false;
         res.stop = false;
 
@@ -1206,6 +1242,7 @@ struct llama_server_context
         std::lock_guard<std::mutex> lock(mutex_results);
         task_result res;
         res.id = slot.task_id;
+        res.multitask_id = slot.multitask_id;
         res.error = false;
         res.stop = true;
 
@@ -1251,6 +1288,12 @@ struct llama_server_context
             res.result_json["model"] = slot.oaicompat_model;
         }
 
+        // parent multitask, if any, needs to be updated
+        if (slot.multitask_id != -1)
+        {
+            update_multi_task(slot.multitask_id, slot.task_id, res);
+        }
+
         queue_results.push_back(res);
     }
 
@@ -1259,6 +1302,7 @@ struct llama_server_context
         std::lock_guard<std::mutex> lock(mutex_results);
         task_result res;
         res.id = slot.task_id;
+        res.multitask_id = slot.multitask_id;
         res.error = false;
         res.stop = true;
 
@@ -1285,9 +1329,9 @@ struct llama_server_context
         queue_results.push_back(res);
     }
 
-    int request_completion(json data, bool infill, bool embedding)
+    int request_completion(json data, bool infill, bool embedding, int multitask_id)
     {
-        std::lock_guard<std::mutex> lock(mutex_tasks);
+        std::unique_lock<std::mutex> lock(mutex_tasks);
         task_server task;
         task.id = id_gen++;
         task.target_id = 0;
@@ -1295,6 +1339,16 @@ struct llama_server_context
         task.infill_mode = infill;
         task.embedding_mode = embedding;
         task.type = COMPLETION_TASK;
+        task.multitask_id = multitask_id;
+
+        // when a completion task's prompt array is not a singleton, we split it into multiple requests
+        if (task.data.at("prompt").size() > 1)
+        {
+            lock.unlock(); // entering new func scope
+            return split_multiprompt_task(task);
+        }
+
+        // otherwise, it's a single-prompt task, we actually queue it
         queue_tasks.push_back(task);
         return task.id;
     }
@@ -1313,8 +1367,17 @@ struct llama_server_context
 
             for (int i = 0; i < (int) queue_results.size(); i++)
             {
+                // for now, tasks that have associated parent multitasks just get erased once multitask picks up the result
+                if (queue_results[i].multitask_id == task_id)
+                {
+                    update_multi_task(task_id, queue_results[i].id, queue_results[i]);
+                    queue_results.erase(queue_results.begin() + i);
+                    continue;
+                }
+
                 if (queue_results[i].id == task_id)
                 {
+                    assert(queue_results[i].multitask_id == -1);
                     task_result res = queue_results[i];
                     queue_results.erase(queue_results.begin() + i);
                     return res;
@@ -1404,6 +1467,27 @@ struct llama_server_context
         queue_tasks.push_back(task);
     }
 
+    int split_multiprompt_task(task_server& multiprompt_task)
+    {
+        auto prompt_count = multiprompt_task.data.at("prompt").size();
+        assert(prompt_count > 1);
+
+        int multitask_id = id_gen++;
+        std::vector<int> subtask_ids(prompt_count);
+        for (int i = 0; i < prompt_count; i++)
+        {
+            json subtask_data = multiprompt_task.data;
+            subtask_data["prompt"] = subtask_data["prompt"][i];
+
+            // subtasks inherit everything else (infill mode, embedding mode, etc.)
+            subtask_ids[i] = request_completion(subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
+        }
+
+        // queue up the multitask so we can track its subtask progression
+        add_multi_task(multitask_id, subtask_ids);
+        return multitask_id;
+    }
+
     void process_tasks()
     {
         std::lock_guard<std::mutex> lock(mutex_tasks);
@@ -1419,7 +1503,7 @@ struct llama_server_context
                     {
                         LOG_TEE("slot unavailable\n");
                         // send error result
-                        send_error(task.id, "slot unavailable");
+                        send_error(task, "slot unavailable");
                         return;
                     }
 
@@ -1433,11 +1517,12 @@ struct llama_server_context
                     slot->infill = task.infill_mode;
                     slot->embedding = task.embedding_mode;
                     slot->task_id = task.id;
+                    slot->multitask_id = task.multitask_id;
 
                     if (!launch_slot_with_data(slot, task.data))
                     {
                         // send error result
-                        send_error(task.id, "internal_error");
+                        send_error(task, "internal_error");
                         break;
                     }
                 } break;
@@ -1453,6 +1538,38 @@ struct llama_server_context
                 } break;
             }
         }
+
+        // remove finished multitasks from the queue of multitasks, and add the corresponding result to the result queue
+        auto queue_iterator = queue_multitasks.begin();
+        while (queue_iterator != queue_multitasks.end())
+        {
+            if (queue_iterator->subtasks_remaining.empty())
+            {
+                // all subtasks done == multitask is done
+                task_result aggregate_result;
+                aggregate_result.id = queue_iterator->id;
+                aggregate_result.stop = true;
+                aggregate_result.error = false;
+
+                // collect json results into one json result
+                std::vector<json> result_jsons;
+                for (auto& subres : queue_iterator->results)
+                {
+                    result_jsons.push_back(subres.result_json);
+                    aggregate_result.error = aggregate_result.error && subres.error;
+                }
+                aggregate_result.result_json = json{ "results", result_jsons };
+
+                std::lock_guard<std::mutex> lock(mutex_results);
+                queue_results.push_back(aggregate_result);
+
+                queue_iterator = queue_multitasks.erase(queue_iterator);
+            }
+            else
+            {
+                ++queue_iterator;
+            }
+        }
     }
 
     bool update_slots() {
@@ -2596,7 +2713,7 @@ int main(int argc, char **argv)
     svr.Post("/completion", [&llama](const httplib::Request &req, httplib::Response &res)
             {
                 json data = json::parse(req.body);
-                const int task_id = llama.request_completion(data, false, false);
+                const int task_id = llama.request_completion(data, false, false, -1);
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
                     task_result result = llama.next_result(task_id);
@@ -2685,7 +2802,7 @@ int main(int argc, char **argv)
             {
                 json data = oaicompat_completion_params_parse(json::parse(req.body));
 
-                const int task_id = llama.request_completion(data, false, false);
+                const int task_id = llama.request_completion(data, false, false, -1);
 
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
@@ -2754,7 +2871,7 @@ int main(int argc, char **argv)
     svr.Post("/infill", [&llama](const httplib::Request &req, httplib::Response &res)
             {
                 json data = json::parse(req.body);
-                const int task_id = llama.request_completion(data, true, false);
+                const int task_id = llama.request_completion(data, true, false, -1);
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
                     task_result result = llama.next_result(task_id);
@@ -2858,7 +2975,7 @@ int main(int argc, char **argv)
                 {
                     prompt = "";
                 }
-                const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0} }, false, true);
+                const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0} }, false, true, -1);
                 task_result result = llama.next_result(task_id);
                 return res.set_content(result.result_json.dump(), "application/json");
             });