]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : refactor multitask handling (#9274)
authorXuan Son Nguyen <redacted>
Mon, 2 Sep 2024 15:11:51 +0000 (17:11 +0200)
committerGitHub <redacted>
Mon, 2 Sep 2024 15:11:51 +0000 (17:11 +0200)
* server : remove multitask from server_task

* refactor completions handler

* fix embeddings

* use res_ok everywhere

* small change for handle_slots_action

* use unordered_set everywhere

* (try) fix test

* no more "mutable" lambda

* Apply suggestions from code review

Co-authored-by: Georgi Gerganov <redacted>
* use deque

---------

Co-authored-by: Georgi Gerganov <redacted>
examples/server/server.cpp
examples/server/tests/features/parallel.feature
examples/server/tests/features/steps/steps.py
examples/server/tests/features/wrong_usages.feature
examples/server/utils.hpp

index cc938e80d6a6d3da4922393c229c1aad1343564e..109dbc023efe0f6b9b1163bf1e8a10e3f2c521ad 100644 (file)
@@ -5,13 +5,6 @@
 #include "llama.h"
 #include "grammar-parser.h"
 
-#ifndef NDEBUG
-// crash the server in debug mode, otherwise send an http 500 error
-#define CPPHTTPLIB_NO_EXCEPTIONS 1
-#endif
-// increase max payload length to allow use of larger context size
-#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
-#include "httplib.h"
 // Change JSON_ASSERT from assert() to GGML_ASSERT:
 #define JSON_ASSERT GGML_ASSERT
 #include "json.hpp"
 #include <chrono>
 #include <condition_variable>
 #include <cstddef>
-#include <set>
 #include <mutex>
 #include <thread>
 #include <signal.h>
 #include <memory>
+#include <unordered_set>
+#include <unordered_map>
+#include <deque>
 
 using json = nlohmann::ordered_json;
 
@@ -82,21 +77,33 @@ enum server_task_type {
     SERVER_TASK_TYPE_SET_LORA,
 };
 
+enum server_task_cmpl_type {
+    SERVER_TASK_CMPL_TYPE_NORMAL,
+    SERVER_TASK_CMPL_TYPE_EMBEDDING,
+    SERVER_TASK_CMPL_TYPE_INFILL,
+};
+
 struct server_task {
     int id        = -1; // to be filled by server_queue
-    int id_multi  = -1;
-    int id_target = -1;
+    int id_target = -1; // used by SERVER_TASK_TYPE_CANCEL
 
     server_task_type type;
     json data;
 
-    bool infill    = false;
-    bool embedding = false;
+    server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
+
+    // utility function
+    static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
+        std::unordered_set<int> ids(tasks.size());
+        for (size_t i = 0; i < tasks.size(); i++) {
+            ids.insert(tasks[i].id);
+        }
+        return ids;
+    }
 };
 
 struct server_task_result {
     int id       = -1;
-    int id_multi = -1;
 
     json data;
 
@@ -104,13 +111,6 @@ struct server_task_result {
     bool error;
 };
 
-struct server_task_multi {
-    int id = -1;
-
-    std::set<int> subtasks_remaining;
-    std::vector<server_task_result> results;
-};
-
 struct slot_params {
     bool stream       = true;
     bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt
@@ -128,7 +128,9 @@ struct slot_params {
 struct server_slot {
     int id;
     int id_task = -1;
-    int id_multi = -1;
+
+    // the index relative to completion multi-task request
+    size_t index = 0;
 
     struct slot_params params;
 
@@ -158,8 +160,7 @@ struct server_slot {
     std::vector<llama_token> cache_tokens;
     std::vector<completion_token_output> generated_token_probs;
 
-    bool infill         = false;
-    bool embedding      = false;
+    server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
     bool has_next_token = true;
     bool truncated      = false;
     bool stopped_eos    = false;
@@ -204,7 +205,7 @@ struct server_slot {
         n_past             = 0;
         n_sent_text        = 0;
         n_sent_token_probs = 0;
-        infill             = false;
+        cmpl_type          = SERVER_TASK_CMPL_TYPE_NORMAL;
         ga_i               = 0;
         n_past_se          = 0;
 
@@ -383,38 +384,56 @@ struct server_queue {
     bool running;
 
     // queues
-    std::vector<server_task> queue_tasks;
-    std::vector<server_task> queue_tasks_deferred;
-
-    std::vector<server_task_multi> queue_multitasks;
+    std::deque<server_task> queue_tasks;
+    std::deque<server_task> queue_tasks_deferred;
 
     std::mutex mutex_tasks;
     std::condition_variable condition_tasks;
 
     // callback functions
-    std::function<void(server_task       &)> callback_new_task;
-    std::function<void(server_task_multi &)> callback_finish_multitask;
-    std::function<void(void)>                callback_update_slots;
+    std::function<void(server_task&)> callback_new_task;
+    std::function<void(void)>         callback_update_slots;
 
     // Add a new task to the end of the queue
-    int post(server_task task) {
+    int post(server_task task, bool front = false) {
         std::unique_lock<std::mutex> lock(mutex_tasks);
         if (task.id == -1) {
             task.id = id++;
             LOG_VERBOSE("new task id", {{"new_id", task.id}});
         }
-        queue_tasks.push_back(std::move(task));
+        if (front) {
+            queue_tasks.push_front(std::move(task));
+        } else {
+            queue_tasks.push_back(std::move(task));
+        }
         condition_tasks.notify_one();
         return task.id;
     }
 
+    // multi-task version of post()
+    int post(std::vector<server_task> & tasks, bool front = false) {
+        for (auto & task : tasks) {
+            if (task.id == -1) {
+                task.id = id++;
+                LOG_VERBOSE("new task id", {{"new_id", task.id}});
+            }
+            if (front) {
+                queue_tasks.push_front(std::move(task));
+            } else {
+                queue_tasks.push_back(std::move(task));
+            }
+        }
+        condition_tasks.notify_one();
+        return 0;
+    }
+
     // Add a new task, but defer until one slot is available
     void defer(server_task task) {
         std::unique_lock<std::mutex> lock(mutex_tasks);
         queue_tasks_deferred.push_back(std::move(task));
     }
 
-    // Get the next id for creating anew task
+    // Get the next id for creating a new task
     int get_new_id() {
         std::unique_lock<std::mutex> lock(mutex_tasks);
         int new_id = id++;
@@ -427,11 +446,6 @@ struct server_queue {
         callback_new_task = std::move(callback);
     }
 
-    // Register function to process a multitask when it is finished
-    void on_finish_multitask(std::function<void(server_task_multi&)> callback) {
-        callback_finish_multitask = std::move(callback);
-    }
-
     // Register the function to be called when all slots data is ready to be processed
     void on_update_slots(std::function<void(void)> callback) {
         callback_update_slots = std::move(callback);
@@ -480,22 +494,6 @@ struct server_queue {
                 callback_new_task(task);
             }
 
-            LOG_VERBOSE("update_multitasks", {});
-
-            // check if we have any finished multitasks
-            auto queue_iterator = queue_multitasks.begin();
-            while (queue_iterator != queue_multitasks.end()) {
-                if (queue_iterator->subtasks_remaining.empty()) {
-                    // all subtasks done == multitask is done
-                    server_task_multi current_multitask = *queue_iterator;
-                    callback_finish_multitask(current_multitask);
-                    // remove this multitask
-                    queue_iterator = queue_multitasks.erase(queue_iterator);
-                } else {
-                    ++queue_iterator;
-                }
-            }
-
             // all tasks in the current loop is processed, slots data is now ready
             LOG_VERBOSE("callback_update_slots", {});
 
@@ -516,38 +514,11 @@ struct server_queue {
             }
         }
     }
-
-    //
-    // functions to manage multitasks
-    //
-
-    // add a multitask by specifying the id of all subtask (subtask is a server_task)
-    void add_multitask(int id_multi, std::vector<int> & sub_ids) {
-        std::lock_guard<std::mutex> lock(mutex_tasks);
-        server_task_multi multi;
-        multi.id = id_multi;
-        std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
-        queue_multitasks.push_back(multi);
-    }
-
-    // updatethe remaining subtasks, while appending results to multitask
-    void update_multitask(int id_multi, int id_sub, server_task_result & result) {
-        std::lock_guard<std::mutex> lock(mutex_tasks);
-        for (auto & multitask : queue_multitasks) {
-            if (multitask.id == id_multi) {
-                multitask.subtasks_remaining.erase(id_sub);
-                multitask.results.push_back(result);
-            }
-        }
-    }
 };
 
 struct server_response {
-    typedef std::function<void(int, int, server_task_result &)> callback_multitask_t;
-    callback_multitask_t callback_update_multitask;
-
     // for keeping track of all tasks waiting for the result
-    std::set<int> waiting_task_ids;
+    std::unordered_set<int> waiting_task_ids;
 
     // the main result queue
     std::vector<server_task_result> queue_results;
@@ -563,6 +534,12 @@ struct server_response {
         waiting_task_ids.insert(id_task);
     }
 
+    void add_waiting_tasks(const std::vector<server_task> & tasks) {
+        for (const auto & t : tasks) {
+            add_waiting_task_id(t.id);
+        }
+    }
+
     // when the request is finished, we can remove task associated with it
     void remove_waiting_task_id(int id_task) {
         LOG_VERBOSE("remove waiting for task id", {{"id_task", id_task}});
@@ -571,8 +548,8 @@ struct server_response {
         waiting_task_ids.erase(id_task);
     }
 
-    // This function blocks the thread until there is a response for this id_task
-    server_task_result recv(int id_task) {
+    // This function blocks the thread until there is a response for one of the id_tasks
+    server_task_result recv(const std::unordered_set<int> & id_tasks) {
         while (true) {
             std::unique_lock<std::mutex> lock(mutex_results);
             condition_results.wait(lock, [&]{
@@ -580,8 +557,7 @@ struct server_response {
             });
 
             for (int i = 0; i < (int) queue_results.size(); i++) {
-                if (queue_results[i].id == id_task) {
-                    assert(queue_results[i].id_multi == -1);
+                if (id_tasks.find(queue_results[i].id) != id_tasks.end()) {
                     server_task_result res = queue_results[i];
                     queue_results.erase(queue_results.begin() + i);
                     return res;
@@ -592,28 +568,21 @@ struct server_response {
         // should never reach here
     }
 
-    // Register the function to update multitask
-    void on_multitask_update(callback_multitask_t callback) {
-        callback_update_multitask = std::move(callback);
+    // single-task version of recv()
+    server_task_result recv(int id_task) {
+        std::unordered_set<int> id_tasks = {id_task};
+        return recv(id_tasks);
     }
 
     // Send a new result to a waiting id_task
-    void send(server_task_result result) {
+    void send(server_task_result result) {
         LOG_VERBOSE("send new result", {{"id_task", result.id}});
 
         std::unique_lock<std::mutex> lock(mutex_results);
         for (const auto & id_task : waiting_task_ids) {
-            // LOG_TEE("waiting task id %i \n", id_task);
-            // for now, tasks that have associated parent multitasks just get erased once multitask picks up the result
-            if (result.id_multi == id_task) {
-                LOG_VERBOSE("callback_update_multitask", {{"id_task", id_task}});
-                callback_update_multitask(id_task, result.id, result);
-                continue;
-            }
-
             if (result.id == id_task) {
                 LOG_VERBOSE("queue_results.push_back", {{"id_task", id_task}});
-                queue_results.push_back(result);
+                queue_results.push_back(std::move(result));
                 condition_results.notify_all();
                 return;
             }
@@ -966,7 +935,7 @@ struct server_context {
         slot.params.input_suffix = json_value(data, "input_suffix", default_params.input_suffix);
 
         // get prompt
-        if (!task.infill) {
+        if (task.cmpl_type != SERVER_TASK_CMPL_TYPE_INFILL) {
             const auto & prompt = data.find("prompt");
             if (prompt == data.end()) {
                 send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);
@@ -1359,23 +1328,21 @@ struct server_context {
     }
 
     void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
-        send_error(task.id, task.id_multi, error, type);
+        send_error(task.id, error, type);
     }
 
     void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
-        send_error(slot.id_task, slot.id_multi, error, type);
+        send_error(slot.id_task, error, type);
     }
 
-    void send_error(const int id_task, const int id_multi, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
+    void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
         LOG_ERROR("task error", {
-            {"id_multi", id_multi},
             {"id_task", id_task},
             {"error", error},
         });
 
         server_task_result res;
         res.id       = id_task;
-        res.id_multi = id_multi;
         res.stop     = false;
         res.error    = true;
         res.data     = format_error_response(error, type);
@@ -1386,14 +1353,14 @@ struct server_context {
     void send_partial_response(server_slot & slot, completion_token_output tkn) {
         server_task_result res;
         res.id       = slot.id_task;
-        res.id_multi = slot.id_multi;
         res.error    = false;
         res.stop     = false;
         res.data     = json {
             {"content",    tkn.text_to_send},
             {"stop",       false},
             {"id_slot",    slot.id},
-            {"multimodal", false}
+            {"multimodal", false},
+            {"index",      slot.index},
         };
 
         if (slot.sparams.n_probs > 0) {
@@ -1423,7 +1390,6 @@ struct server_context {
     void send_final_response(const server_slot & slot) {
         server_task_result res;
         res.id       = slot.id_task;
-        res.id_multi = slot.id_multi;
         res.error    = false;
         res.stop     = true;
         res.data     = json {
@@ -1441,7 +1407,8 @@ struct server_context {
             {"stopped_limit",       slot.stopped_limit},
             {"stopping_word",       slot.stopping_word},
             {"tokens_cached",       slot.n_past},
-            {"timings",             slot.get_formated_timings()}
+            {"timings",             slot.get_formated_timings()},
+            {"index",               slot.index},
         };
 
         if (slot.sparams.n_probs > 0) {
@@ -1473,7 +1440,6 @@ struct server_context {
     void send_embedding(const server_slot & slot, const llama_batch & batch) {
         server_task_result res;
         res.id       = slot.id_task;
-        res.id_multi = slot.id_multi;
         res.error    = false;
         res.stop     = true;
 
@@ -1508,83 +1474,128 @@ struct server_context {
 
             res.data = json {
                 {"embedding", embd_res},
+                {"index",     slot.index},
             };
         }
 
         queue_results.send(res);
     }
 
-    void request_completion(int id_task, int id_multi, json data, bool infill, bool embedding) {
-        server_task task;
-        task.id        = id_task;
-        task.id_multi  = id_multi;
-        task.id_target = 0;
-        task.data      = std::move(data);
-        task.infill    = infill;
-        task.embedding = embedding;
-        task.type      = SERVER_TASK_TYPE_COMPLETION;
-
-        // when a completion task's prompt array is not a singleton, we split it into multiple requests
-        // otherwise, it's a single-prompt task, we actually queue it
-        // if there's numbers in the prompt array it will be treated as an array of tokens
-        if (task.data.count("prompt") != 0 && task.data.at("prompt").size() > 1) {
-            bool numbers = false;
-            for (const auto & e : task.data.at("prompt")) {
-                if (e.is_number()) {
-                    numbers = true;
-                    break;
-                }
-            }
+    //
+    // Functions to create new task(s) and receive result(s)
+    //
 
-            // NOTE: split_multiprompt_task() does not handle a mix of strings and numbers,
-            // it will completely stall the server. I don't know where the bug for this is.
-            //
-            // if there are numbers, it needs to be treated like a single prompt,
-            // queue_tasks handles a mix of strings and numbers just fine.
-            if (numbers) {
-                queue_tasks.post(task);
+    std::vector<server_task> create_tasks_cmpl(json data, server_task_cmpl_type cmpl_type) {
+        std::vector<server_task> tasks;
+        auto create_task = [&](json & task_data, bool replace_prompt, json prompt) {
+            server_task task;
+            task.id        = queue_tasks.get_new_id();
+            task.cmpl_type = cmpl_type;
+            task.type      = SERVER_TASK_TYPE_COMPLETION;
+            if (replace_prompt) {
+                task.data  = task_data;
+                task.data["prompt"] = prompt;
             } else {
-                split_multiprompt_task(id_task, task);
+                task.data  = std::move(task_data);
             }
-        } else {
-            queue_tasks.post(task);
+            tasks.push_back(std::move(task));
+        };
+
+        static constexpr const char * error_msg = "\"prompt\" must be a string, an array of token ids or an array of prompts";
+        if (!data.contains("prompt")) {
+            throw std::runtime_error(error_msg);
         }
-    }
 
-    void request_cancel(int id_task) {
-        server_task task;
-        task.type      = SERVER_TASK_TYPE_CANCEL;
-        task.id_target = id_task;
+        json prompt = data.at("prompt");
+
+        // if the prompt is a singleton (i.e. a string or a list of tokens), we only need to create single task
+        if (prompt.is_string() || json_is_array_of_numbers(prompt)) {
+            data["index"] = 0;
+            create_task(data, false, nullptr);
+        }
+        // otherwise, it's a multiple-prompt task, we break it into smaller tasks
+        else if (prompt.is_array()) {
+            std::vector<json> prompts = prompt;
+            for (size_t i = 0; i < prompts.size(); i++) {
+                const auto & e = prompts[i];
+                if (e.is_string() || json_is_array_of_numbers(e)) {
+                    data["index"] = i;
+                    create_task(data, true, e);
+                } else {
+                    throw std::runtime_error(error_msg);
+                }
+            }
+        }
+        // invalid case
+        else {
+            throw std::runtime_error(error_msg);
+        }
 
-        queue_tasks.post(task);
+        return tasks;
     }
 
-    void split_multiprompt_task(int id_multi, const server_task & multiprompt_task) {
-        const int prompt_count = multiprompt_task.data.at("prompt").size();
-        if (prompt_count <= 1) {
-            send_error(multiprompt_task, "error while handling multiple prompts");
-            return;
-        }
+    void cancel_tasks(const std::unordered_set<int> & id_tasks) {
+        std::vector<server_task> cancel_tasks;
+        cancel_tasks.reserve(id_tasks.size());
+        for (const auto & id_task : id_tasks) {
+            LOG_VERBOSE("cancel task", {{"id_task", id_task}});
+            server_task task;
+            task.type      = SERVER_TASK_TYPE_CANCEL;
+            task.id_target = id_task;
+            cancel_tasks.push_back(task);
+            queue_results.remove_waiting_task_id(id_task);
+        }
+        // push to beginning of the queue, so it has highest priority
+        queue_tasks.post(cancel_tasks, true);
+    }
+
+    // receive the results from task(s) created by create_tasks_cmpl
+    void receive_cmpl_results(const std::unordered_set<int> & id_tasks, std::function<void(std::vector<server_task_result>&)> result_handler, std::function<void(json)> error_handler) {
+        // TODO: currently, there is no way to detect the client has cancelled the request
+        std::vector<server_task_result> results(id_tasks.size());
+        for (size_t i = 0; i < id_tasks.size(); i++) {
+            server_task_result result = queue_results.recv(id_tasks);
+
+            if (result.error) {
+                error_handler(result.data);
+                cancel_tasks(id_tasks);
+                break;
+            }
 
-        // generate all the ID for subtask
-        std::vector<int> subtask_ids(prompt_count);
-        for (int i = 0; i < prompt_count; i++) {
-            subtask_ids[i] = queue_tasks.get_new_id();
+            size_t idx = result.data["index"];
+            results[idx] = result;
         }
+        result_handler(results);
+    }
 
-        // queue up the multitask so we can track its subtask progression
-        queue_tasks.add_multitask(id_multi, subtask_ids);
+    // receive the results from task(s) created by create_tasks_cmpl, in stream mode
+    void receive_cmpl_results_stream(const std::unordered_set<int> & id_tasks, std::function<bool(server_task_result&)> result_handler, std::function<void(json)> error_handler) {
+        size_t n_finished = 0;
+        while (true) {
+            server_task_result result = queue_results.recv(id_tasks);
+            if (!result_handler(result)) {
+                cancel_tasks(id_tasks);
+                break;
+            }
 
-        // add subtasks
-        for (int i = 0; i < prompt_count; i++) {
-            json subtask_data = multiprompt_task.data;
-            subtask_data["prompt"] = subtask_data.at("prompt")[i];
+            if (result.error) {
+                error_handler(result.data);
+                cancel_tasks(id_tasks);
+                break;
+            }
 
-            // subtasks inherit everything else (infill mode, embedding mode, etc.)
-            request_completion(subtask_ids[i], id_multi, subtask_data, multiprompt_task.infill, multiprompt_task.embedding);
+            if (result.stop) {
+                if (++n_finished == id_tasks.size()) {
+                    break;
+                }
+            }
         }
     }
 
+    //
+    // Functions to process the task
+    //
+
     void process_single_task(const server_task & task) {
         switch (task.type) {
             case SERVER_TASK_TYPE_COMPLETION:
@@ -1630,9 +1641,8 @@ struct server_context {
                     slot->reset();
 
                     slot->id_task   = task.id;
-                    slot->id_multi  = task.id_multi;
-                    slot->infill    = task.infill;
-                    slot->embedding = task.embedding;
+                    slot->cmpl_type = task.cmpl_type;
+                    slot->index     = json_value(task.data, "index", 0);
 
                     if (!launch_slot_with_task(*slot, task)) {
                         LOG_ERROR("error while launching slot", task.data);
@@ -1699,7 +1709,6 @@ struct server_context {
 
                     server_task_result res;
                     res.id       = task.id;
-                    res.id_multi = task.id_multi;
                     res.stop     = true;
                     res.error    = false;
                     res.data     = {
@@ -1861,26 +1870,6 @@ struct server_context {
         }
     }
 
-    void on_finish_multitask(const server_task_multi & multitask) {
-        // all subtasks done == multitask is done
-        server_task_result result;
-        result.id    = multitask.id;
-        result.stop  = true;
-        result.error = false;
-
-        // collect json results into one json result
-        std::vector<json> result_jsons;
-        for (const auto & subres : multitask.results) {
-            result_jsons.push_back(subres.data);
-            result.error = result.error && subres.error;
-        }
-        result.data = json {
-            { "results", result_jsons }
-        };
-
-        queue_results.send(result);
-    }
-
     void update_slots() {
         if (system_need_update) {
             system_prompt_update();
@@ -2038,7 +2027,7 @@ struct server_context {
                         slot.t_start_process_prompt = ggml_time_us();
                         slot.t_start_generation = 0;
 
-                        if (slot.infill) {
+                        if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_INFILL) {
                             const bool add_bos = llama_add_bos_token(model);
                             bool suff_rm_leading_spc = true;
                             if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
@@ -2101,7 +2090,7 @@ struct server_context {
                             continue;
                         }
 
-                        if (slot.embedding) {
+                        if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
                             // this prompt is too large to process - discard it
                             if (slot.n_prompt_tokens > n_ubatch) {
                                 slot.state = SLOT_STATE_PROCESSING;
@@ -2184,7 +2173,7 @@ struct server_context {
                         slot.n_prompt_tokens_processed = 0;
                     }
 
-                    if (slot.embedding) {
+                    if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
                         // cannot fit the prompt in the current batch - will try next iter
                         if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
                             continue;
@@ -2192,7 +2181,7 @@ struct server_context {
                     }
 
                     // check that we are in the right batch_type, if not defer the slot
-                    bool slot_type = slot.embedding ? 1 : 0;
+                    bool slot_type = slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ? 1 : 0;
                     if (batch_type == -1) {
                         batch_type = slot_type;
                     } else if (batch_type != slot_type) {
@@ -2385,7 +2374,7 @@ struct server_context {
                 }
 
                 // prompt evaluated for embedding
-                if (slot.embedding) {
+                if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
                     send_embedding(slot, batch_view);
                     slot.release();
                     slot.i_batch = -1;
@@ -2576,6 +2565,11 @@ int main(int argc, char ** argv) {
         res.status = json_value(error_data, "code", 500);
     };
 
+    auto res_ok = [](httplib::Response & res, json data) {
+        res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
+        res.status = 200;
+    };
+
     svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
         std::string message;
         try {
@@ -2623,7 +2617,7 @@ int main(int argc, char ** argv) {
 
     auto middleware_validate_api_key = [&params, &res_error](const httplib::Request & req, httplib::Response & res) {
         // TODO: should we apply API key to all endpoints, including "/health" and "/models"?
-        static const std::set<std::string> protected_endpoints = {
+        static const std::unordered_set<std::string> protected_endpoints = {
             "/props",
             "/completion",
             "/completions",
@@ -2695,7 +2689,7 @@ int main(int argc, char ** argv) {
     const auto handle_health = [&](const httplib::Request &, httplib::Response & res) {
         // error and loading states are handled by middleware
         json health = {{"status", "ok"}};
-        res.set_content(health.dump(), "application/json");
+        res_ok(res, health);
     };
 
     const auto handle_slots = [&](const httplib::Request & req, httplib::Response & res) {
@@ -2707,8 +2701,6 @@ int main(int argc, char ** argv) {
         // request slots data using task queue
         server_task task;
         task.id = ctx_server.queue_tasks.get_new_id();
-        task.id_multi  = -1;
-        task.id_target = -1;
         task.type = SERVER_TASK_TYPE_METRICS;
 
         ctx_server.queue_results.add_waiting_task_id(task.id);
@@ -2727,8 +2719,7 @@ int main(int argc, char ** argv) {
             }
         }
 
-        res.set_content(result.data.at("slots").dump(), MIMETYPE_JSON);
-        res.status = 200; // HTTP OK
+        res_ok(res, result.data.at("slots"));
     };
 
     const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
@@ -2740,7 +2731,6 @@ int main(int argc, char ** argv) {
         // request slots data using task queue
         server_task task;
         task.id = ctx_server.queue_tasks.get_new_id();
-        task.id_multi  = -1;
         task.id_target = -1;
         task.type = SERVER_TASK_TYPE_METRICS;
         task.data.push_back({{"reset_bucket", true}});
@@ -2832,7 +2822,7 @@ int main(int argc, char ** argv) {
         res.status = 200; // HTTP OK
     };
 
-    const auto handle_slots_save = [&ctx_server, &res_error, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
+    const auto handle_slots_save = [&ctx_server, &res_error, &res_ok, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
         json request_data = json::parse(req.body);
         std::string filename = request_data.at("filename");
         if (!fs_validate_filename(filename)) {
@@ -2858,11 +2848,11 @@ int main(int argc, char ** argv) {
         if (result.error) {
             res_error(res, result.data);
         } else {
-            res.set_content(result.data.dump(), MIMETYPE_JSON);
+            res_ok(res, result.data);
         }
     };
 
-    const auto handle_slots_restore = [&ctx_server, &res_error, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
+    const auto handle_slots_restore = [&ctx_server, &res_error, &res_ok, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
         json request_data = json::parse(req.body);
         std::string filename = request_data.at("filename");
         if (!fs_validate_filename(filename)) {
@@ -2888,11 +2878,11 @@ int main(int argc, char ** argv) {
         if (result.error) {
             res_error(res, result.data);
         } else {
-            res.set_content(result.data.dump(), MIMETYPE_JSON);
+            res_ok(res, result.data);
         }
     };
 
-    const auto handle_slots_erase = [&ctx_server, &res_error](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
+    const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
         server_task task;
         task.type = SERVER_TASK_TYPE_SLOT_ERASE;
         task.data = {
@@ -2908,11 +2898,16 @@ int main(int argc, char ** argv) {
         if (result.error) {
             res_error(res, result.data);
         } else {
-            res.set_content(result.data.dump(), MIMETYPE_JSON);
+            res_ok(res, result.data);
         }
     };
 
-    const auto handle_slots_action = [&res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_slots_action = [&params, &res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
+        if (params.slot_save_path.empty()) {
+            res_error(res, format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED));
+            return;
+        }
+
         std::string id_slot_str = req.path_params.at("id_slot");
         int id_slot;
 
@@ -2936,7 +2931,7 @@ int main(int argc, char ** argv) {
         }
     };
 
-    const auto handle_props = [&ctx_server](const httplib::Request &, httplib::Response & res) {
+    const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
         std::string template_key = "tokenizer.chat_template", curr_tmpl;
         int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0);
         if (tlen > 0) {
@@ -2952,244 +2947,128 @@ int main(int argc, char ** argv) {
             { "chat_template",               curr_tmpl.c_str() }
         };
 
-        res.set_content(data.dump(), MIMETYPE_JSON);
+        res_ok(res, data);
     };
 
-    const auto handle_completions = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) {
         if (ctx_server.params.embedding) {
             res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
 
-        json data = json::parse(req.body);
-
-        const int id_task = ctx_server.queue_tasks.get_new_id();
-
-        ctx_server.queue_results.add_waiting_task_id(id_task);
-        ctx_server.request_completion(id_task, -1, data, false, false);
-
-        if (!json_value(data, "stream", false)) {
-            server_task_result result = ctx_server.queue_results.recv(id_task);
-            if (!result.error && result.stop) {
-                res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
-            } else {
-                res_error(res, result.data);
-            }
-
-            ctx_server.queue_results.remove_waiting_task_id(id_task);
-        } else {
-            const auto chunked_content_provider = [id_task, &ctx_server](size_t, httplib::DataSink & sink) {
-                while (true) {
-                    server_task_result result = ctx_server.queue_results.recv(id_task);
-                    if (!result.error) {
-                        const std::string str =
-                            "data: " +
-                            result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
-                            "\n\n";
-
-                        LOG_VERBOSE("data stream", {
-                            { "to_send", str }
-                        });
-
-                        if (!sink.write(str.c_str(), str.size())) {
-                            ctx_server.queue_results.remove_waiting_task_id(id_task);
-                            return false;
-                        }
-
-                        if (result.stop) {
-                            break;
-                        }
-                    } else {
-                        const std::string str =
-                            "error: " +
-                            result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
-                            "\n\n";
-
-                        LOG_VERBOSE("data stream", {
-                            { "to_send", str }
-                        });
+        std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, cmpl_type);
+        ctx_server.queue_results.add_waiting_tasks(tasks);
+        ctx_server.queue_tasks.post(tasks);
 
-                        if (!sink.write(str.c_str(), str.size())) {
-                            ctx_server.queue_results.remove_waiting_task_id(id_task);
-                            return false;
-                        }
+        bool stream = json_value(data, "stream", false);
+        const auto task_ids = server_task::get_list_id(tasks);
 
-                        break;
+        if (!stream) {
+            ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
+                if (results.size() == 1) {
+                    // single result
+                    res_ok(res, results[0].data);
+                } else {
+                    // multiple results (multitask)
+                    json arr = json::array();
+                    for (const auto & res : results) {
+                        arr.push_back(res.data);
                     }
+                    res_ok(res, arr);
                 }
-
-                ctx_server.queue_results.remove_waiting_task_id(id_task);
+            }, [&](json error_data) {
+                res_error(res, error_data);
+            });
+        } else {
+            const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) {
+                ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result result) -> bool {
+                    return server_sent_event(sink, "data", result.data);
+                }, [&](json error_data) {
+                    server_sent_event(sink, "error", error_data);
+                });
                 sink.done();
-
-                return true;
-            };
-
-            auto on_complete = [id_task, &ctx_server] (bool) {
-                // cancel
-                ctx_server.request_cancel(id_task);
-                ctx_server.queue_results.remove_waiting_task_id(id_task);
+                return false;
             };
-
-            res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
+            res.set_chunked_content_provider("text/event-stream", chunked_content_provider);
         }
     };
 
-    const auto handle_models = [&params, &ctx_server](const httplib::Request &, httplib::Response & res) {
-        json models = {
-            {"object", "list"},
-            {"data", {
-                 {
-                     {"id",       params.model_alias},
-                     {"object",   "model"},
-                     {"created",  std::time(0)},
-                     {"owned_by", "llamacpp"},
-                     {"meta",     ctx_server.model_meta()}
-                 },
-             }}
-        };
+    const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
+        json data = json::parse(req.body);
+        return handle_completions_generic(SERVER_TASK_CMPL_TYPE_NORMAL, data, res);
+    };
 
-        res.set_content(models.dump(), MIMETYPE_JSON);
+    const auto handle_infill = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
+        json data = json::parse(req.body);
+        return handle_completions_generic(SERVER_TASK_CMPL_TYPE_INFILL, data, res);
     };
 
-    const auto handle_chat_completions = [&ctx_server, &params, &res_error](const httplib::Request & req, httplib::Response & res) {
+    // TODO: maybe merge this function with "handle_completions_generic"
+    const auto handle_chat_completions = [&ctx_server, &params, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
         if (ctx_server.params.embedding) {
-            res_error(res, format_error_response("This server does not support chat completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
-        json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
 
-        const int id_task = ctx_server.queue_tasks.get_new_id();
+        json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
 
-        ctx_server.queue_results.add_waiting_task_id(id_task);
-        ctx_server.request_completion(id_task, -1, data, false, false);
+        std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL);
+        ctx_server.queue_results.add_waiting_tasks(tasks);
+        ctx_server.queue_tasks.post(tasks);
 
+        bool stream = json_value(data, "stream", false);
+        const auto task_ids = server_task::get_list_id(tasks);
         const auto completion_id = gen_chatcmplid();
-        if (!json_value(data, "stream", false)) {
-            server_task_result result = ctx_server.queue_results.recv(id_task);
-
-            if (!result.error && result.stop) {
-                json result_oai = format_final_response_oaicompat(data, result.data, completion_id);
 
-                res.set_content(result_oai.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
-            } else {
-                res_error(res, result.data);
-            }
-            ctx_server.queue_results.remove_waiting_task_id(id_task);
+        if (!stream) {
+            ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
+                // multitask is never support in chat completion, there is only one result
+                json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id);
+                res_ok(res, result_oai);
+            }, [&](json error_data) {
+                res_error(res, error_data);
+            });
         } else {
-            const auto chunked_content_provider = [id_task, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
-                while (true) {
-                    server_task_result result = ctx_server.queue_results.recv(id_task);
-                    if (!result.error) {
-                        std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
-
-                        for (auto it = result_array.begin(); it != result_array.end(); ++it) {
-                            if (!it->empty()) {
-                                const std::string str =
-                                    "data: " +
-                                    it->dump(-1, ' ', false, json::error_handler_t::replace) +
-                                    "\n\n";
-                                LOG_VERBOSE("data stream", {{"to_send", str}});
-                                if (!sink.write(str.c_str(), str.size())) {
-                                    ctx_server.queue_results.remove_waiting_task_id(id_task);
-                                    return false;
-                                }
-                            }
+            const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
+                ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result result) -> bool {
+                    std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
+                    for (auto & event_data : result_array) {
+                        if (event_data.empty()) {
+                            continue; // skip the stop token
                         }
-                        if (result.stop) {
-                            break;
+                        if (!server_sent_event(sink, "data", event_data)) {
+                            return false; // connection is closed
                         }
-                    } else {
-                        const std::string str =
-                            "error: " +
-                            result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
-                            "\n\n";
-                        LOG_VERBOSE("data stream", {{"to_send", str}});
-                        if (!sink.write(str.c_str(), str.size())) {
-                            ctx_server.queue_results.remove_waiting_task_id(id_task);
-                            return false;
-                        }
-                        break;
                     }
-                }
+                    return true; // ok
+                }, [&](json error_data) {
+                    server_sent_event(sink, "error", error_data);
+                });
                 sink.done();
-                ctx_server.queue_results.remove_waiting_task_id(id_task);
                 return true;
             };
-
-            auto on_complete = [id_task, &ctx_server](bool) {
-                // cancel request
-                ctx_server.request_cancel(id_task);
-                ctx_server.queue_results.remove_waiting_task_id(id_task);
-            };
-
-            res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
+            res.set_chunked_content_provider("text/event-stream", chunked_content_provider);
         }
     };
 
-    const auto handle_infill = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
-        if (ctx_server.params.embedding) {
-            res_error(res, format_error_response("This server does not support infill. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
-            return;
-        }
-
-        json data = json::parse(req.body);
-
-        const int id_task = ctx_server.queue_tasks.get_new_id();
-
-        ctx_server.queue_results.add_waiting_task_id(id_task);
-        ctx_server.request_completion(id_task, -1, data, true, false);
-
-        if (!json_value(data, "stream", false)) {
-            server_task_result result = ctx_server.queue_results.recv(id_task);
-            if (!result.error && result.stop) {
-                res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
-            } else {
-                res_error(res, result.data);
-            }
-
-            ctx_server.queue_results.remove_waiting_task_id(id_task);
-        } else {
-            const auto chunked_content_provider = [id_task, &ctx_server](size_t, httplib::DataSink & sink) {
-                while (true) {
-                    server_task_result result = ctx_server.queue_results.recv(id_task);
-                    if (!result.error) {
-                        const std::string str =
-                            "data: " +
-                            result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
-                            "\n\n";
-
-                        LOG_VERBOSE("data stream", {
-                            { "to_send", str }
-                        });
-
-                        if (!sink.write(str.c_str(), str.size())) {
-                            ctx_server.queue_results.remove_waiting_task_id(id_task);
-                            return false;
-                        }
-
-                        if (result.stop) {
-                            break;
-                        }
-                    } else {
-                        break;
-                    }
-                }
-
-                ctx_server.queue_results.remove_waiting_task_id(id_task);
-                sink.done();
-
-                return true;
-            };
-
-            auto on_complete = [id_task, &ctx_server] (bool) {
-                ctx_server.request_cancel(id_task);
-            };
+    const auto handle_models = [&params, &ctx_server](const httplib::Request &, httplib::Response & res) {
+        json models = {
+            {"object", "list"},
+            {"data", {
+                 {
+                     {"id",       params.model_alias},
+                     {"object",   "model"},
+                     {"created",  std::time(0)},
+                     {"owned_by", "llamacpp"},
+                     {"meta",     ctx_server.model_meta()}
+                 },
+             }}
+        };
 
-            res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
-        }
+        res.set_content(models.dump(), MIMETYPE_JSON);
     };
 
-    const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_tokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
         const json body = json::parse(req.body);
 
         std::vector<llama_token> tokens;
@@ -3198,10 +3077,10 @@ int main(int argc, char ** argv) {
             tokens = ctx_server.tokenize(body.at("content"), add_special);
         }
         const json data = format_tokenizer_response(tokens);
-        return res.set_content(data.dump(), MIMETYPE_JSON);
+        res_ok(res, data);
     };
 
-    const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_detokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
         const json body = json::parse(req.body);
 
         std::string content;
@@ -3211,10 +3090,10 @@ int main(int argc, char ** argv) {
         }
 
         const json data = format_detokenized_response(content);
-        return res.set_content(data.dump(), MIMETYPE_JSON);
+        res_ok(res, data);
     };
 
-    const auto handle_embeddings = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
         const json body = json::parse(req.body);
         bool is_openai = false;
 
@@ -3232,35 +3111,35 @@ int main(int argc, char ** argv) {
         }
 
         // create and queue the task
-        json responses;
+        json responses = json::array();
+        bool error = false;
         {
-            const int id_task = ctx_server.queue_tasks.get_new_id();
-            ctx_server.queue_results.add_waiting_task_id(id_task);
-            ctx_server.request_completion(id_task, -1, {{"prompt", prompt}}, false, true);
+            std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING);
+            ctx_server.queue_results.add_waiting_tasks(tasks);
+            ctx_server.queue_tasks.post(tasks);
 
             // get the result
-            server_task_result result = ctx_server.queue_results.recv(id_task);
-            ctx_server.queue_results.remove_waiting_task_id(id_task);
-            if (!result.error) {
-                if (result.data.count("results")) {
-                    // result for multi-task
-                    responses = result.data.at("results");
-                } else {
-                    // result for single task
-                    responses = std::vector<json>{result.data};
+            std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
+
+            ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
+                for (const auto & res : results) {
+                    responses.push_back(res.data);
                 }
-            } else {
-                // error received, ignore everything else
-                res_error(res, result.data);
-                return;
-            }
+            }, [&](json error_data) {
+                res_error(res, error_data);
+                error = true;
+            });
+        }
+
+        if (error) {
+            return;
         }
 
         // write JSON response
         json root = is_openai
             ? format_embeddings_response_oaicompat(body, responses)
             : responses[0];
-        return res.set_content(root.dump(), MIMETYPE_JSON);
+        res_ok(res, root);
     };
 
     const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
@@ -3273,7 +3152,7 @@ int main(int argc, char ** argv) {
                 {"scale", la.scale},
             });
         }
-        res.set_content(result.dump(), MIMETYPE_JSON);
+        res_ok(res, result);
         res.status = 200; // HTTP OK
     };
 
@@ -3305,7 +3184,7 @@ int main(int argc, char ** argv) {
         server_task_result result = ctx_server.queue_results.recv(id_task);
         ctx_server.queue_results.remove_waiting_task_id(id_task);
 
-        res.set_content(result.data.dump(), MIMETYPE_JSON);
+        res_ok(res, result.data);
         res.status = 200; // HTTP OK
     };
 
@@ -3366,10 +3245,7 @@ int main(int argc, char ** argv) {
     svr->Post("/lora-adapters",       handle_lora_adapters_apply);
     // Save & load slots
     svr->Get ("/slots",               handle_slots);
-    if (!params.slot_save_path.empty()) {
-        // only enable slot endpoints if slot_save_path is set
-        svr->Post("/slots/:id_slot",  handle_slots_action);
-    }
+    svr->Post("/slots/:id_slot",      handle_slots_action);
 
     //
     // Start the server
@@ -3433,17 +3309,8 @@ int main(int argc, char ** argv) {
 
         ctx_server.queue_tasks.on_new_task(std::bind(
             &server_context::process_single_task, &ctx_server, std::placeholders::_1));
-        ctx_server.queue_tasks.on_finish_multitask(std::bind(
-            &server_context::on_finish_multitask, &ctx_server, std::placeholders::_1));
         ctx_server.queue_tasks.on_update_slots(std::bind(
             &server_context::update_slots, &ctx_server));
-        ctx_server.queue_results.on_multitask_update(std::bind(
-            &server_queue::update_multitask,
-            &ctx_server.queue_tasks,
-            std::placeholders::_1,
-            std::placeholders::_2,
-            std::placeholders::_3
-        ));
 
         shutdown_handler = [&](int) {
             ctx_server.queue_tasks.terminate();
index 6cd306a2bcf7c9b5a7eb56550628c8a2e82e5f2c..96a96d6f8f7d30e830a4306f149c3afb21e08c2c 100644 (file)
@@ -52,8 +52,8 @@ Feature: Parallel
     Then all prompts are predicted with <n_predict> tokens
     Examples:
       | streaming | n_predict |
-      | disabled  | 128       |
-      | enabled   | 64        |
+      | disabled  | 200       |
+      | enabled   | 200       |
 
   Scenario Outline: Multi users OAI completions compatibility no v1
     Given a system prompt You are a writer.
index 1ba7b60b69c46f79dcb88df6f12e3556c9887895..1864a694fc94a6299e761cdebbd6704ff996cfa3 100644 (file)
@@ -818,7 +818,7 @@ async def concurrent_requests(context, f_completion, *args, **kwargs):
     for prompt_no in range(context.n_prompts):
         shifted_args = [context.prompts.pop(), seeds[prompt_no], *args]
         context.concurrent_tasks.append(asyncio.create_task(f_completion(*shifted_args, **kwargs)))
-    await asyncio.sleep(0.1)
+    await asyncio.sleep(0.01)
 
 
 @step('the slot {slot_id:d} is saved with filename "{filename}"')
index cf14b3b44e03b308f889f794563148ff6a925b32..61d5f315e15670e60f22516d96380c3cb0be9c2e 100644 (file)
@@ -8,9 +8,12 @@ Feature: Wrong usage of llama.cpp server
   Scenario: Infinite loop
     Given a server listening on localhost:8080
     And   a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
+    And   42 as server seed
+    And   2048 KV cache size
     # Uncomment below to fix the issue
     #And   64 server max tokens to predict
     Then  the server is starting
+    Then  the server is healthy
     Given a prompt:
       """
       Go to: infinite loop
index e6a1f069723ec5f33a15325df4688a2e4a175b36..edfce65b634e061ce6faa41e07b232e54891edf2 100644 (file)
@@ -3,6 +3,14 @@
 #include "llama.h"
 #include "common.h"
 
+#ifndef NDEBUG
+// crash the server in debug mode, otherwise send an http 500 error
+#define CPPHTTPLIB_NO_EXCEPTIONS 1
+#endif
+// increase max payload length to allow use of larger context size
+#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
+#include "httplib.h"
+
 // Change JSON_ASSERT from assert() to GGML_ASSERT:
 #define JSON_ASSERT GGML_ASSERT
 #include "json.hpp"
@@ -279,6 +287,18 @@ static size_t find_partial_stop_string(const std::string &stop, const std::strin
     return std::string::npos;
 }
 
+static bool json_is_array_of_numbers(json data) {
+    if (data.is_array()) {
+        for (const auto & e : data) {
+            if (!e.is_number()) {
+                return false;
+            }
+        }
+        return true;
+    }
+    return false;
+}
+
 // TODO: reuse llama_detokenize
 template <class Iter>
 static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
@@ -343,6 +363,19 @@ static json probs_vector_to_json(const llama_context * ctx, const std::vector<co
     return out;
 }
 
+static bool server_sent_event(httplib::DataSink & sink, const char * event, json & data) {
+    const std::string str =
+        std::string(event) + ": " +
+        data.dump(-1, ' ', false, json::error_handler_t::replace) +
+        "\n\n";
+
+    LOG_VERBOSE("data stream", {
+        { "to_send", str }
+    });
+
+    return sink.write(str.c_str(), str.size());
+}
+
 //
 // OAI utils
 //