]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : use std::move whenever possible (#12936)
authorXuan-Son Nguyen <redacted>
Fri, 18 Apr 2025 17:58:12 +0000 (19:58 +0200)
committerGitHub <redacted>
Fri, 18 Apr 2025 17:58:12 +0000 (19:58 +0200)
* server : use std::move whenever possible

* use r-value ref

* Apply suggestions from code review

Co-authored-by: Georgi Gerganov <redacted>
* make task creation scoped

* restore std::move

* fix task_id not set correctly

* apply changes from suggestion

Co-authored-by: ggerganov <redacted>
---------

Co-authored-by: Georgi Gerganov <redacted>
examples/server/server.cpp

index d87bda1a0d5af97a79c967a44f3726d042d4dcc2..c580ec123299cebcbd61db562fa9f76a2f5044e5 100644 (file)
@@ -1552,29 +1552,30 @@ struct server_queue {
     std::condition_variable condition_tasks;
 
     // callback functions
-    std::function<void(server_task)> callback_new_task;
-    std::function<void(void)>        callback_update_slots;
+    std::function<void(server_task &&)> callback_new_task;
+    std::function<void(void)>           callback_update_slots;
 
     // Add a new task to the end of the queue
-    int post(server_task task, bool front = false) {
+    int post(server_task && task, bool front = false) {
         std::unique_lock<std::mutex> lock(mutex_tasks);
         GGML_ASSERT(task.id != -1);
         // if this is cancel task make sure to clean up pending tasks
         if (task.type == SERVER_TASK_TYPE_CANCEL) {
             cleanup_pending_task(task.id_target);
         }
-        QUE_DBG("new task, id = %d, front = %d\n", task.id, front);
+        const int task_id = task.id;
+        QUE_DBG("new task, id = %d, front = %d\n", task_id, front);
         if (front) {
             queue_tasks.push_front(std::move(task));
         } else {
             queue_tasks.push_back(std::move(task));
         }
         condition_tasks.notify_one();
-        return task.id;
+        return task_id;
     }
 
     // multi-task version of post()
-    int post(std::vector<server_task> & tasks, bool front = false) {
+    int post(std::vector<server_task> && tasks, bool front = false) {
         std::unique_lock<std::mutex> lock(mutex_tasks);
         for (auto & task : tasks) {
             if (task.id == -1) {
@@ -1596,7 +1597,7 @@ struct server_queue {
     }
 
     // Add a new task, but defer until one slot is available
-    void defer(server_task task) {
+    void defer(server_task && task) {
         std::unique_lock<std::mutex> lock(mutex_tasks);
         QUE_DBG("defer task, id = %d\n", task.id);
         queue_tasks_deferred.push_back(std::move(task));
@@ -1611,7 +1612,7 @@ struct server_queue {
     }
 
     // Register function to process a new task
-    void on_new_task(std::function<void(server_task)> callback) {
+    void on_new_task(std::function<void(server_task &&)> callback) {
         callback_new_task = std::move(callback);
     }
 
@@ -1660,7 +1661,7 @@ struct server_queue {
                     lock.unlock();
                     break;
                 }
-                server_task task = queue_tasks.front();
+                server_task task = std::move(queue_tasks.front());
                 queue_tasks.pop_front();
                 lock.unlock();
 
@@ -2004,7 +2005,7 @@ struct server_context {
 
             slot.reset();
 
-            slots.push_back(slot);
+            slots.push_back(std::move(slot));
         }
 
         default_generation_settings_for_props = slots[0].to_json();
@@ -2105,7 +2106,7 @@ struct server_context {
         return true;
     }
 
-    bool launch_slot_with_task(server_slot & slot, const server_task & task) {
+    bool launch_slot_with_task(server_slot & slot, server_task && task) {
         slot.reset();
         slot.id_task       = task.id;
         slot.index         = task.index;
@@ -2113,10 +2114,10 @@ struct server_context {
         slot.params        = std::move(task.params);
         slot.prompt_tokens = std::move(task.prompt_tokens);
 
-        if (!are_lora_equal(task.params.lora, slot.lora)) {
+        if (!are_lora_equal(slot.params.lora, slot.lora)) {
             // if lora is changed, we cannot reuse cached tokens
             slot.cache_tokens.clear();
-            slot.lora = task.params.lora;
+            slot.lora = slot.params.lora;
         }
 
         bool can_detokenize = can_be_detokenized(ctx, slot.prompt_tokens);
@@ -2547,10 +2548,10 @@ struct server_context {
             server_task task(SERVER_TASK_TYPE_CANCEL);
             task.id_target = id_task;
             queue_results.remove_waiting_task_id(id_task);
-            cancel_tasks.push_back(task);
+            cancel_tasks.push_back(std::move(task));
         }
         // push to beginning of the queue, so it has highest priority
-        queue_tasks.post(cancel_tasks, true);
+        queue_tasks.post(std::move(cancel_tasks), true);
     }
 
     // receive the results from task(s)
@@ -2637,7 +2638,7 @@ struct server_context {
     // Functions to process the task
     //
 
-    void process_single_task(server_task task) {
+    void process_single_task(server_task && task) {
         switch (task.type) {
             case SERVER_TASK_TYPE_COMPLETION:
             case SERVER_TASK_TYPE_INFILL:
@@ -2651,17 +2652,17 @@ struct server_context {
                     if (slot == nullptr) {
                         // if no slot is available, we defer this task for processing later
                         SRV_DBG("no slot is available, defer task, id_task = %d\n", task.id);
-                        queue_tasks.defer(task);
+                        queue_tasks.defer(std::move(task));
                         break;
                     }
                     if (slot->is_processing()) {
                         // if requested slot is unavailable, we defer this task for processing later
                         SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
-                        queue_tasks.defer(task);
+                        queue_tasks.defer(std::move(task));
                         break;
                     }
 
-                    if (!launch_slot_with_task(*slot, task)) {
+                    if (!launch_slot_with_task(*slot, std::move(task))) {
                         SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
                         break;
                     }
@@ -2740,7 +2741,7 @@ struct server_context {
                     if (slot->is_processing()) {
                         // if requested slot is unavailable, we defer this task for processing later
                         SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
-                        queue_tasks.defer(task);
+                        queue_tasks.defer(std::move(task));
                         break;
                     }
 
@@ -2776,7 +2777,7 @@ struct server_context {
                     if (slot->is_processing()) {
                         // if requested slot is unavailable, we defer this task for processing later
                         SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
-                        queue_tasks.defer(task);
+                        queue_tasks.defer(std::move(task));
                         break;
                     }
 
@@ -2819,7 +2820,7 @@ struct server_context {
                     if (slot->is_processing()) {
                         // if requested slot is unavailable, we defer this task for processing later
                         SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
-                        queue_tasks.defer(task);
+                        queue_tasks.defer(std::move(task));
                         break;
                     }
 
@@ -2871,7 +2872,7 @@ struct server_context {
 
             server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE);
             task.id = queue_tasks.get_new_id();
-            queue_tasks.post(task);
+            queue_tasks.post(std::move(task));
         }
 
         // apply context-shift if needed
@@ -3633,14 +3634,17 @@ int main(int argc, char ** argv) {
         }
 
         // request slots data using task queue
-        server_task task(SERVER_TASK_TYPE_METRICS);
-        task.id = ctx_server.queue_tasks.get_new_id();
-        ctx_server.queue_results.add_waiting_task_id(task.id);
-        ctx_server.queue_tasks.post(task, true); // high-priority task
+        int task_id = ctx_server.queue_tasks.get_new_id();
+        {
+            server_task task(SERVER_TASK_TYPE_METRICS);
+            task.id = task_id;
+            ctx_server.queue_results.add_waiting_task_id(task_id);
+            ctx_server.queue_tasks.post(std::move(task), true); // high-priority task
+        }
 
         // get the result
-        server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
-        ctx_server.queue_results.remove_waiting_task_id(task.id);
+        server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
+        ctx_server.queue_results.remove_waiting_task_id(task_id);
 
         if (result->is_error()) {
             res_error(res, result->to_json());
@@ -3669,16 +3673,17 @@ int main(int argc, char ** argv) {
         }
 
         // request slots data using task queue
-        server_task task(SERVER_TASK_TYPE_METRICS);
-        task.id = ctx_server.queue_tasks.get_new_id();
-        task.metrics_reset_bucket = true;
-
-        ctx_server.queue_results.add_waiting_task_id(task.id);
-        ctx_server.queue_tasks.post(task, true); // high-priority task
+        int task_id = ctx_server.queue_tasks.get_new_id();
+        {
+            server_task task(SERVER_TASK_TYPE_METRICS);
+            task.id = task_id;
+            ctx_server.queue_results.add_waiting_task_id(task_id);
+            ctx_server.queue_tasks.post(std::move(task), true); // high-priority task
+        }
 
         // get the result
-        server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
-        ctx_server.queue_results.remove_waiting_task_id(task.id);
+        server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
+        ctx_server.queue_results.remove_waiting_task_id(task_id);
 
         if (result->is_error()) {
             res_error(res, result->to_json());
@@ -3775,17 +3780,20 @@ int main(int argc, char ** argv) {
         }
         std::string filepath = params.slot_save_path + filename;
 
-        server_task task(SERVER_TASK_TYPE_SLOT_SAVE);
-        task.id = ctx_server.queue_tasks.get_new_id();
-        task.slot_action.slot_id  = id_slot;
-        task.slot_action.filename = filename;
-        task.slot_action.filepath = filepath;
+        int task_id = ctx_server.queue_tasks.get_new_id();
+        {
+            server_task task(SERVER_TASK_TYPE_SLOT_SAVE);
+            task.id = task_id;
+            task.slot_action.slot_id  = id_slot;
+            task.slot_action.filename = filename;
+            task.slot_action.filepath = filepath;
 
-        ctx_server.queue_results.add_waiting_task_id(task.id);
-        ctx_server.queue_tasks.post(task);
+            ctx_server.queue_results.add_waiting_task_id(task_id);
+            ctx_server.queue_tasks.post(std::move(task));
+        }
 
-        server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
-        ctx_server.queue_results.remove_waiting_task_id(task.id);
+        server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
+        ctx_server.queue_results.remove_waiting_task_id(task_id);
 
         if (result->is_error()) {
             res_error(res, result->to_json());
@@ -3804,17 +3812,20 @@ int main(int argc, char ** argv) {
         }
         std::string filepath = params.slot_save_path + filename;
 
-        server_task task(SERVER_TASK_TYPE_SLOT_RESTORE);
-        task.id = ctx_server.queue_tasks.get_new_id();
-        task.slot_action.slot_id  = id_slot;
-        task.slot_action.filename = filename;
-        task.slot_action.filepath = filepath;
+        int task_id = ctx_server.queue_tasks.get_new_id();
+        {
+            server_task task(SERVER_TASK_TYPE_SLOT_RESTORE);
+            task.id = task_id;
+            task.slot_action.slot_id  = id_slot;
+            task.slot_action.filename = filename;
+            task.slot_action.filepath = filepath;
 
-        ctx_server.queue_results.add_waiting_task_id(task.id);
-        ctx_server.queue_tasks.post(task);
+            ctx_server.queue_results.add_waiting_task_id(task_id);
+            ctx_server.queue_tasks.post(std::move(task));
+        }
 
-        server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
-        ctx_server.queue_results.remove_waiting_task_id(task.id);
+        server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
+        ctx_server.queue_results.remove_waiting_task_id(task_id);
 
         if (result->is_error()) {
             res_error(res, result->to_json());
@@ -3826,15 +3837,18 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
-        server_task task(SERVER_TASK_TYPE_SLOT_ERASE);
-        task.id = ctx_server.queue_tasks.get_new_id();
-        task.slot_action.slot_id = id_slot;
+        int task_id = ctx_server.queue_tasks.get_new_id();
+        {
+            server_task task(SERVER_TASK_TYPE_SLOT_ERASE);
+            task.id = task_id;
+            task.slot_action.slot_id = id_slot;
 
-        ctx_server.queue_results.add_waiting_task_id(task.id);
-        ctx_server.queue_tasks.post(task);
+            ctx_server.queue_results.add_waiting_task_id(task_id);
+            ctx_server.queue_tasks.post(std::move(task));
+        }
 
-        server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
-        ctx_server.queue_results.remove_waiting_task_id(task.id);
+        server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
+        ctx_server.queue_results.remove_waiting_task_id(task_id);
 
         if (result->is_error()) {
             res_error(res, result->to_json());
@@ -3938,9 +3952,10 @@ int main(int argc, char ** argv) {
         }
 
         auto completion_id = gen_chatcmplid();
-        std::vector<server_task> tasks;
-
+        std::unordered_set<int> task_ids;
         try {
+            std::vector<server_task> tasks;
+
             const auto & prompt = data.at("prompt");
             // TODO: this log can become very long, put it behind a flag or think about a more compact format
             //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
@@ -3955,9 +3970,9 @@ int main(int argc, char ** argv) {
 
                 task.prompt_tokens    = std::move(tokenized_prompts[i]);
                 task.params           = server_task::params_from_json_cmpl(
-                                            ctx_server.ctx,
-                                            ctx_server.params_base,
-                                            data);
+                        ctx_server.ctx,
+                        ctx_server.params_base,
+                        data);
                 task.id_selected_slot = json_value(data, "id_slot", -1);
 
                 // OAI-compat
@@ -3965,18 +3980,18 @@ int main(int argc, char ** argv) {
                 task.params.oaicompat_cmpl_id         = completion_id;
                 // oaicompat_model is already populated by params_from_json_cmpl
 
-                tasks.push_back(task);
+                tasks.push_back(std::move(task));
             }
+
+            task_ids = server_task::get_list_id(tasks);
+            ctx_server.queue_results.add_waiting_tasks(tasks);
+            ctx_server.queue_tasks.post(std::move(tasks));
         } catch (const std::exception & e) {
             res_error(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
             return;
         }
 
-        ctx_server.queue_results.add_waiting_tasks(tasks);
-        ctx_server.queue_tasks.post(tasks);
-
         bool stream = json_value(data, "stream", false);
-        const auto task_ids = server_task::get_list_id(tasks);
 
         if (!stream) {
             ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
@@ -4268,6 +4283,7 @@ int main(int argc, char ** argv) {
         // create and queue the task
         json responses = json::array();
         bool error = false;
+        std::unordered_set<int> task_ids;
         {
             std::vector<server_task> tasks;
             for (size_t i = 0; i < tokenized_prompts.size(); i++) {
@@ -4280,27 +4296,26 @@ int main(int argc, char ** argv) {
                 // OAI-compat
                 task.params.oaicompat = oaicompat;
 
-                tasks.push_back(task);
+                tasks.push_back(std::move(task));
             }
 
+            task_ids = server_task::get_list_id(tasks);
             ctx_server.queue_results.add_waiting_tasks(tasks);
-            ctx_server.queue_tasks.post(tasks);
-
-            // get the result
-            std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
+            ctx_server.queue_tasks.post(std::move(tasks));
+        }
 
-            ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
-                for (auto & res : results) {
-                    GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
-                    responses.push_back(res->to_json());
-                }
-            }, [&](const json & error_data) {
-                res_error(res, error_data);
-                error = true;
-            }, req.is_connection_closed);
+        // get the result
+        ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
+            for (auto & res : results) {
+                GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
+                responses.push_back(res->to_json());
+            }
+        }, [&](const json & error_data) {
+            res_error(res, error_data);
+            error = true;
+        }, req.is_connection_closed);
 
-            ctx_server.queue_results.remove_waiting_task_ids(task_ids);
-        }
+        ctx_server.queue_results.remove_waiting_task_ids(task_ids);
 
         if (error) {
             return;
@@ -4367,6 +4382,7 @@ int main(int argc, char ** argv) {
         // create and queue the task
         json responses = json::array();
         bool error = false;
+        std::unordered_set<int> task_ids;
         {
             std::vector<server_task> tasks;
             std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true);
@@ -4376,26 +4392,24 @@ int main(int argc, char ** argv) {
                 task.id            = ctx_server.queue_tasks.get_new_id();
                 task.index         = i;
                 task.prompt_tokens = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]);
-                tasks.push_back(task);
+                tasks.push_back(std::move(task));
             }
 
+            task_ids = server_task::get_list_id(tasks);
             ctx_server.queue_results.add_waiting_tasks(tasks);
-            ctx_server.queue_tasks.post(tasks);
-
-            // get the result
-            std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
-
-            ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
-                for (auto & res : results) {
-                    GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
-                    responses.push_back(res->to_json());
-                }
-            }, [&](const json & error_data) {
-                res_error(res, error_data);
-                error = true;
-            }, req.is_connection_closed);
+            ctx_server.queue_tasks.post(std::move(tasks));
         }
 
+        ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
+            for (auto & res : results) {
+                GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
+                responses.push_back(res->to_json());
+            }
+        }, [&](const json & error_data) {
+            res_error(res, error_data);
+            error = true;
+        }, req.is_connection_closed);
+
         if (error) {
             return;
         }
@@ -4431,14 +4445,19 @@ int main(int argc, char ** argv) {
             res_error(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST));
             return;
         }
-        server_task task(SERVER_TASK_TYPE_SET_LORA);
-        task.id = ctx_server.queue_tasks.get_new_id();
-        task.set_lora = parse_lora_request(ctx_server.params_base.lora_adapters, body);
-        ctx_server.queue_results.add_waiting_task_id(task.id);
-        ctx_server.queue_tasks.post(task);
 
-        server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
-        ctx_server.queue_results.remove_waiting_task_id(task.id);
+        int task_id = ctx_server.queue_tasks.get_new_id();
+        {
+            server_task task(SERVER_TASK_TYPE_SET_LORA);
+            task.id = task_id;
+            task.set_lora = parse_lora_request(ctx_server.params_base.lora_adapters, body);
+            ctx_server.queue_results.add_waiting_task_id(task_id);
+            ctx_server.queue_tasks.post(std::move(task));
+        }
+
+        // get the result
+        server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
+        ctx_server.queue_results.remove_waiting_task_id(task_id);
 
         if (result->is_error()) {
             res_error(res, result->to_json());
@@ -4582,8 +4601,8 @@ int main(int argc, char ** argv) {
         common_chat_templates_source(ctx_server.chat_templates.get()),
         common_chat_format_example(ctx_server.chat_templates.get(), ctx_server.params_base.use_jinja).c_str());
 
-    ctx_server.queue_tasks.on_new_task([&ctx_server](const server_task & task) {
-        ctx_server.process_single_task(task);
+    ctx_server.queue_tasks.on_new_task([&ctx_server](server_task && task) {
+        ctx_server.process_single_task(std::move(task));
     });
 
     ctx_server.queue_tasks.on_update_slots([&ctx_server]() {