]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : refactored the task processing logic (#5065)
authorXuan Son Nguyen <redacted>
Fri, 26 Jan 2024 12:42:20 +0000 (13:42 +0100)
committerGitHub <redacted>
Fri, 26 Jan 2024 12:42:20 +0000 (14:42 +0200)
* server: add llama_server_queue struct

* server: add llama_server_response_event

* server: add comments

* server: move all mutexes away from server.cpp

* server: correct multitask response

* server: only add back deferred tasks when one slot is available

* server: fix a race condition cause by "request_completion"

Makefile
examples/server/CMakeLists.txt
examples/server/oai.hpp [new file with mode: 0644]
examples/server/server.cpp
examples/server/utils.hpp [new file with mode: 0644]

index 82c89e87a9927d077107893d836f04d91d2b1166..b8858b412433e1cf9b9588170ef9af8c9535dd08 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -619,7 +619,7 @@ embedding: examples/embedding/embedding.cpp                   ggml.o llama.o $(C
 save-load-state: examples/save-load-state/save-load-state.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
        $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
 
-server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp examples/llava/clip.cpp examples/llava/clip.h common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
+server: examples/server/server.cpp examples/server/oai.hpp examples/server/utils.hpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp examples/llava/clip.cpp examples/llava/clip.h common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
        $(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS) $(LWINSOCK2) -Wno-cast-qual
 
 gguf: examples/gguf/gguf.cpp ggml.o $(OBJS)
index 81709e4484c9f71690f7102ce9ba72579007ee15..cc13b2d6306520bbbdab7a0253cd433fa1a476f0 100644 (file)
@@ -1,7 +1,7 @@
 set(TARGET server)
 option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON)
 include_directories(${CMAKE_CURRENT_SOURCE_DIR})
-add_executable(${TARGET} server.cpp json.hpp httplib.h)
+add_executable(${TARGET} server.cpp oai.hpp utils.hpp json.hpp httplib.h)
 install(TARGETS ${TARGET} RUNTIME)
 target_compile_definitions(${TARGET} PRIVATE
     SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>
diff --git a/examples/server/oai.hpp b/examples/server/oai.hpp
new file mode 100644 (file)
index 0000000..bc5db6e
--- /dev/null
@@ -0,0 +1,208 @@
+#pragma once
+
+#include <string>
+#include <vector>
+#include <set>
+#include <mutex>
+#include <condition_variable>
+#include <unordered_map>
+
+#include "json.hpp"
+#include "utils.hpp"
+
+#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
+
+using json = nlohmann::json;
+
+inline static json oaicompat_completion_params_parse(
+    const json &body /* openai api json semantics */)
+{
+    json llama_params;
+
+    llama_params["__oaicompat"] = true;
+
+    // Map OpenAI parameters to llama.cpp parameters
+    //
+    // For parameters that are defined by the OpenAI documentation (e.g.
+    // temperature), we explicitly specify OpenAI's intended default; we
+    // need to do that because sometimes OpenAI disagrees with llama.cpp
+    //
+    // https://platform.openai.com/docs/api-reference/chat/create
+    llama_sampling_params default_sparams;
+    llama_params["model"]             = json_value(body, "model", std::string("unknown"));
+    llama_params["prompt"]            = format_chatml(body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt'
+    llama_params["cache_prompt"]      = json_value(body, "cache_prompt", false);
+    llama_params["temperature"]       = json_value(body, "temperature", 0.0);
+    llama_params["top_k"]             = json_value(body, "top_k", default_sparams.top_k);
+    llama_params["top_p"]             = json_value(body, "top_p", 1.0);
+    llama_params["n_predict"]         = json_value(body, "max_tokens", -1);
+    llama_params["logit_bias"]        = json_value(body, "logit_bias",json::object());
+    llama_params["frequency_penalty"] = json_value(body, "frequency_penalty", 0.0);
+    llama_params["presence_penalty"]  = json_value(body, "presence_penalty", 0.0);
+    llama_params["seed"]              = json_value(body, "seed", LLAMA_DEFAULT_SEED);
+    llama_params["stream"]            = json_value(body, "stream", false);
+    llama_params["mirostat"]          = json_value(body, "mirostat", default_sparams.mirostat);
+    llama_params["mirostat_tau"]      = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
+    llama_params["mirostat_eta"]      = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
+    llama_params["penalize_nl"]       = json_value(body, "penalize_nl", default_sparams.penalize_nl);
+    llama_params["typical_p"]         = json_value(body, "typical_p", default_sparams.typical_p);
+    llama_params["repeat_last_n"]     = json_value(body, "repeat_last_n", default_sparams.penalty_last_n);
+    llama_params["ignore_eos"]        = json_value(body, "ignore_eos", false);
+    llama_params["tfs_z"]             = json_value(body, "tfs_z", default_sparams.tfs_z);
+
+    if (body.count("grammar") != 0) {
+        llama_params["grammar"] = json_value(body, "grammar", json::object());
+    }
+
+    // Handle 'stop' field
+    if (body.contains("stop") && body["stop"].is_string()) {
+        llama_params["stop"] = json::array({body["stop"].get<std::string>()});
+    } else {
+        llama_params["stop"] = json_value(body, "stop", json::array());
+    }
+
+    // Ensure there is ChatML-specific end sequence among stop words
+    llama_params["stop"].push_back("<|im_end|>");
+
+    return llama_params;
+}
+
+inline static json format_final_response_oaicompat(const json &request, const task_result &response, bool streaming = false)
+{
+    json result = response.result_json;
+
+    bool stopped_word        = result.count("stopped_word") != 0;
+    bool stopped_eos         = json_value(result, "stopped_eos", false);
+    int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
+    int num_prompt_tokens    = json_value(result, "tokens_evaluated", 0);
+    std::string content      = json_value(result, "content", std::string(""));
+
+    std::string finish_reason = "length";
+    if (stopped_word || stopped_eos) {
+        finish_reason = "stop";
+    }
+
+    json choices =
+        streaming ? json::array({json{{"finish_reason", finish_reason},
+                                        {"index", 0},
+                                        {"delta", json::object()}}})
+                  : json::array({json{{"finish_reason", finish_reason},
+                                        {"index", 0},
+                                        {"message", json{{"content", content},
+                                                         {"role", "assistant"}}}}});
+
+    std::time_t t = std::time(0);
+
+    json res =
+        json{{"choices", choices},
+            {"created", t},
+            {"model",
+                json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
+            {"object", streaming ? "chat.completion.chunk" : "chat.completion"},
+            {"usage",
+                json{{"completion_tokens", num_tokens_predicted},
+                     {"prompt_tokens",     num_prompt_tokens},
+                     {"total_tokens",      num_tokens_predicted + num_prompt_tokens}}},
+            {"id", gen_chatcmplid()}};
+
+    if (server_verbose) {
+        res["__verbose"] = result;
+    }
+
+    if (result.contains("completion_probabilities")) {
+        res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
+    }
+
+    return res;
+}
+
+// return value is vector as there is one case where we might need to generate two responses
+inline static std::vector<json> format_partial_response_oaicompat(const task_result &response) {
+    json result = response.result_json;
+
+    if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
+        return std::vector<json>({response.result_json});
+    }
+
+    bool first = json_value(result, "oaicompat_token_ctr", 0) == 0;
+    std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
+
+    bool stopped_word   = json_value(result, "stopped_word", false);
+    bool stopped_eos    = json_value(result, "stopped_eos", false);
+    bool stopped_limit  = json_value(result, "stopped_limit", false);
+    std::string content = json_value(result, "content", std::string(""));
+
+    std::string finish_reason;
+    if (stopped_word || stopped_eos) {
+        finish_reason = "stop";
+    }
+    if (stopped_limit) {
+        finish_reason = "length";
+    }
+
+    std::time_t t = std::time(0);
+
+    json choices;
+
+    if (!finish_reason.empty()) {
+        choices = json::array({json{{"finish_reason", finish_reason},
+                                    {"index", 0},
+                                    {"delta", json::object()}}});
+    } else {
+        if (first) {
+            if (content.empty()) {
+                choices = json::array({json{{"finish_reason", nullptr},
+                                            {"index", 0},
+                                            {"delta", json{{"role", "assistant"}}}}});
+            } else {
+                // We have to send this as two updates to conform to openai behavior
+                json initial_ret = json{{"choices", json::array({json{
+                                        {"finish_reason", nullptr},
+                                        {"index", 0},
+                                        {"delta", json{
+                                            {"role", "assistant"}
+                                        }}}})},
+                            {"created", t},
+                            {"id", gen_chatcmplid()},
+                            {"model", modelname},
+                            {"object", "chat.completion.chunk"}};
+
+                json second_ret = json{
+                            {"choices", json::array({json{{"finish_reason", nullptr},
+                                                            {"index", 0},
+                                                            {"delta", json{
+                                                            {"content", content}}}
+                                                            }})},
+                            {"created", t},
+                            {"id", gen_chatcmplid()},
+                            {"model", modelname},
+                            {"object", "chat.completion.chunk"}};
+
+                return std::vector<json>({initial_ret, second_ret});
+            }
+        } else {
+            // Some idiosyncrasy in task processing logic makes several trailing calls
+            // with empty content, we ignore these at the calee site.
+            if (content.empty()) {
+                return std::vector<json>({json::object()});
+            }
+
+            choices = json::array({json{
+                {"finish_reason", nullptr},
+                {"index", 0},
+                {"delta",
+                json{
+                    {"content", content},
+                }},
+            }});
+        }
+    }
+
+    json ret = json{{"choices", choices},
+                    {"created", t},
+                    {"id", gen_chatcmplid()},
+                    {"model", modelname},
+                    {"object", "chat.completion.chunk"}};
+
+    return std::vector<json>({ret});
+}
index 0462fbd24739b7b214b0ca26565e835a48a89a38..39283613256eda655a85408900233654e48f69fe 100644 (file)
@@ -1,6 +1,8 @@
 #include "common.h"
 #include "llama.h"
 #include "grammar-parser.h"
+#include "utils.hpp"
+#include "oai.hpp"
 
 #include "../llava/clip.h"
 
 
 #include <cstddef>
 #include <thread>
-#include <mutex>
 #include <chrono>
 #include <condition_variable>
 #include <atomic>
 
-#ifndef SERVER_VERBOSE
-#define SERVER_VERBOSE 1
-#endif
-
-#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
-
 using json = nlohmann::json;
 
 struct server_params
@@ -46,197 +41,7 @@ struct server_params
     int32_t write_timeout = 600;
 };
 
-static bool server_verbose = false;
-
-#if SERVER_VERBOSE != 1
-#define LOG_VERBOSE(MSG, ...)
-#else
-#define LOG_VERBOSE(MSG, ...)                                            \
-    do                                                                   \
-    {                                                                    \
-        if (server_verbose)                                              \
-        {                                                                \
-            server_log("VERBOSE", __func__, __LINE__, MSG, __VA_ARGS__); \
-        }                                                                \
-    } while (0)
-#endif
-
-#define LOG_ERROR(  MSG, ...) server_log("ERROR",   __func__, __LINE__, MSG, __VA_ARGS__)
-#define LOG_WARNING(MSG, ...) server_log("WARNING", __func__, __LINE__, MSG, __VA_ARGS__)
-#define LOG_INFO(   MSG, ...) server_log("INFO",    __func__, __LINE__, MSG, __VA_ARGS__)
-
-json oaicompat_completion_params_parse(const json &body);
-std::string format_chatml(std::vector<json> messages);
-
-
-//
-// base64 utils (TODO: move to common in the future)
-//
-
-static const std::string base64_chars =
-             "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
-             "abcdefghijklmnopqrstuvwxyz"
-             "0123456789+/";
-
-static inline bool is_base64(uint8_t c)
-{
-    return (isalnum(c) || (c == '+') || (c == '/'));
-}
-
-static std::vector<uint8_t> base64_decode(const std::string & encoded_string)
-{
-    int i = 0;
-    int j = 0;
-    int in_ = 0;
-
-    int in_len = encoded_string.size();
-
-    uint8_t char_array_4[4];
-    uint8_t char_array_3[3];
-
-    std::vector<uint8_t> ret;
-
-    while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_]))
-    {
-        char_array_4[i++] = encoded_string[in_]; in_++;
-        if (i == 4)
-        {
-            for (i = 0; i <4; i++)
-            {
-                char_array_4[i] = base64_chars.find(char_array_4[i]);
-            }
-
-            char_array_3[0] = ((char_array_4[0]      ) << 2) + ((char_array_4[1] & 0x30) >> 4);
-            char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
-            char_array_3[2] = ((char_array_4[2] & 0x3) << 6) +   char_array_4[3];
-
-            for (i = 0; (i < 3); i++)
-            {
-                ret.push_back(char_array_3[i]);
-            }
-            i = 0;
-        }
-    }
-
-    if (i)
-    {
-        for (j = i; j <4; j++)
-        {
-            char_array_4[j] = 0;
-        }
-
-        for (j = 0; j <4; j++)
-        {
-            char_array_4[j] = base64_chars.find(char_array_4[j]);
-        }
-
-        char_array_3[0] = ((char_array_4[0]      ) << 2) + ((char_array_4[1] & 0x30) >> 4);
-        char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
-        char_array_3[2] = ((char_array_4[2] & 0x3) << 6) +   char_array_4[3];
-
-        for (j = 0; (j < i - 1); j++)
-        {
-            ret.push_back(char_array_3[j]);
-        }
-    }
-
-    return ret;
-}
-
-//
-// parallel
-//
-
-enum server_state {
-    SERVER_STATE_LOADING_MODEL,  // Server is starting up, model not fully loaded yet
-    SERVER_STATE_READY,          // Server is ready and model is loaded
-    SERVER_STATE_ERROR           // An error occurred, load_model failed
-};
-
-enum task_type {
-    TASK_TYPE_COMPLETION,
-    TASK_TYPE_CANCEL,
-};
-
-struct task_server {
-    int id;
-    int target_id;
-    task_type type;
-    json data;
-    bool infill_mode = false;
-    bool embedding_mode = false;
-    int multitask_id = -1;
-};
-
-struct task_result {
-    int id;
-    int multitask_id = -1;
-    bool stop;
-    bool error;
-    json result_json;
-};
-
-struct task_multi {
-    int id;
-    std::set<int> subtasks_remaining{};
-    std::vector<task_result> results{};
-};
-
-// TODO: can become bool if we can't find use of more states
-enum slot_state
-{
-    IDLE,
-    PROCESSING,
-};
-
-enum slot_command
-{
-    NONE,
-    LOAD_PROMPT,
-    RELEASE,
-};
-
-struct slot_params
-{
-    bool stream       = true;
-    bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt
-
-    uint32_t seed      = -1; // RNG seed
-    int32_t  n_keep    =  0; // number of tokens to keep from initial prompt
-    int32_t  n_predict = -1; // new tokens to predict
-
-    std::vector<std::string> antiprompt;
-
-    json input_prefix;
-    json input_suffix;
-};
-
-struct slot_image
-{
-    int32_t id;
-
-    bool request_encode_image = false;
-    float * image_embedding = nullptr;
-    int32_t image_tokens = 0;
-
-    clip_image_u8 * img_data;
-
-    std::string prefix_prompt; // before of this image
-};
-
-// completion token output with probabilities
-struct completion_token_output
-{
-    struct token_prob
-    {
-        llama_token tok;
-        float prob;
-    };
-
-    std::vector<token_prob> probs;
-    llama_token tok;
-    std::string text_to_send;
-};
+bool server_verbose = false;
 
 static size_t common_part(const std::vector<llama_token> &a, const std::vector<llama_token> &b)
 {
@@ -292,28 +97,6 @@ static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end)
     return ret;
 }
 
-static void server_log(const char *level, const char *function, int line,
-                       const char *message, const nlohmann::ordered_json &extra)
-{
-    nlohmann::ordered_json log
-    {
-        {"timestamp", time(nullptr)},
-        {"level",     level},
-        {"function",  function},
-        {"line",      line},
-        {"message",   message},
-    };
-
-    if (!extra.empty())
-    {
-        log.merge_patch(extra);
-    }
-
-    const std::string str = log.dump(-1, ' ', false, json::error_handler_t::replace);
-    printf("%.*s\n", (int)str.size(), str.data());
-    fflush(stdout);
-}
-
 // format incomplete utf-8 multibyte character for output
 static std::string tokens_to_output_formatted_string(const llama_context *ctx, const llama_token token)
 {
@@ -355,15 +138,6 @@ static json probs_vector_to_json(const llama_context *ctx, const std::vector<com
     return out;
 }
 
-template <typename T>
-static T json_value(const json &body, const std::string &key, const T &default_value)
-{
-    // Fallback null to default value
-    return body.contains(key) && !body.at(key).is_null()
-        ? body.value(key, default_value)
-        : default_value;
-}
-
 struct llama_client_slot
 {
     int id;
@@ -491,7 +265,7 @@ struct llama_client_slot
     }
 
     void release() {
-        if (state == IDLE || state == PROCESSING)
+        if (state == PROCESSING)
         {
             t_token_generation = (ggml_time_us() - t_start_genereration) / 1e3;
             command = RELEASE;
@@ -539,7 +313,6 @@ struct llama_server_context
     bool all_slots_are_idle = false;
     bool add_bos_token      = true;
 
-    int32_t id_gen;
     int32_t n_ctx;  // total context for all clients / slots
 
     // system prompt
@@ -554,13 +327,8 @@ struct llama_server_context
     // slots / clients
     std::vector<llama_client_slot> slots;
 
-    std::vector<task_server> queue_tasks;
-    std::vector<task_result> queue_results;
-    std::vector<task_multi>  queue_multitasks;
-    std::mutex mutex_tasks; // also guards id_gen, and queue_multitasks
-    std::condition_variable condition_tasks;
-    std::mutex mutex_results;
-    std::condition_variable condition_results;
+    llama_server_queue queue_tasks;
+    llama_server_response queue_results;
 
     ~llama_server_context()
     {
@@ -619,8 +387,6 @@ struct llama_server_context
     }
 
     void initialize() {
-        id_gen = 0;
-
         // create slots
         all_slots_are_idle = true;
 
@@ -1183,39 +949,13 @@ struct llama_server_context
     void send_error(task_server& task, const std::string &error)
     {
         LOG_TEE("task %i - error: %s\n", task.id, error.c_str());
-        std::unique_lock<std::mutex> lock(mutex_results);
         task_result res;
         res.id = task.id;
         res.multitask_id = task.multitask_id;
         res.stop = false;
         res.error = true;
         res.result_json = { { "content", error } };
-        queue_results.push_back(res);
-        condition_results.notify_all();
-    }
-
-    void add_multi_task(int id, std::vector<int>& sub_ids)
-    {
-        std::lock_guard<std::mutex> lock(mutex_tasks);
-        task_multi multi;
-        multi.id = id;
-        std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
-        queue_multitasks.push_back(multi);
-        condition_tasks.notify_one();
-    }
-
-    void update_multi_task(int multitask_id, int subtask_id, task_result& result)
-    {
-        std::lock_guard<std::mutex> lock(mutex_tasks);
-        for (auto& multitask : queue_multitasks)
-        {
-            if (multitask.id == multitask_id)
-            {
-                multitask.subtasks_remaining.erase(subtask_id);
-                multitask.results.push_back(result);
-                condition_tasks.notify_one();
-            }
-        }
+        queue_results.send(res);
     }
 
     json get_model_props()
@@ -1261,7 +1001,6 @@ struct llama_server_context
 
     void send_partial_response(llama_client_slot &slot, completion_token_output tkn)
     {
-        std::unique_lock<std::mutex> lock(mutex_results);
         task_result res;
         res.id = slot.task_id;
         res.multitask_id = slot.multitask_id;
@@ -1296,13 +1035,11 @@ struct llama_server_context
             res.result_json["model"] = slot.oaicompat_model;
         }
 
-        queue_results.push_back(res);
-        condition_results.notify_all();
+        queue_results.send(res);
     }
 
     void send_final_response(llama_client_slot &slot)
     {
-        std::unique_lock<std::mutex> lock(mutex_results);
         task_result res;
         res.id = slot.task_id;
         res.multitask_id = slot.multitask_id;
@@ -1351,22 +1088,11 @@ struct llama_server_context
             res.result_json["model"] = slot.oaicompat_model;
         }
 
-        queue_results.push_back(res);
-        condition_results.notify_all();
-
-        // done with results, unlock
-        lock.unlock();
-
-        // parent multitask, if any, needs to be updated
-        if (slot.multitask_id != -1)
-        {
-            update_multi_task(slot.multitask_id, slot.task_id, res);
-        }
+        queue_results.send(res);
     }
 
     void send_embedding(llama_client_slot &slot)
     {
-        std::unique_lock<std::mutex> lock(mutex_results);
         task_result res;
         res.id = slot.task_id;
         res.multitask_id = slot.multitask_id;
@@ -1393,15 +1119,13 @@ struct llama_server_context
                 {"embedding", embedding },
             };
         }
-        queue_results.push_back(res);
-        condition_results.notify_all();
+        queue_results.send(res);
     }
 
-    int request_completion(json data, bool infill, bool embedding, int multitask_id)
+    void request_completion(int task_id, json data, bool infill, bool embedding, int multitask_id)
     {
-        std::unique_lock<std::mutex> lock(mutex_tasks);
         task_server task;
-        task.id = id_gen++;
+        task.id = task_id;
         task.target_id = 0;
         task.data = std::move(data);
         task.infill_mode = infill;
@@ -1412,47 +1136,11 @@ struct llama_server_context
         // when a completion task's prompt array is not a singleton, we split it into multiple requests
         if (task.data.count("prompt") && task.data.at("prompt").size() > 1)
         {
-            lock.unlock(); // entering new func scope
-            return split_multiprompt_task(task);
+            split_multiprompt_task(task_id, task);
         }
 
         // otherwise, it's a single-prompt task, we actually queue it
-        queue_tasks.push_back(task);
-        condition_tasks.notify_one();
-        return task.id;
-    }
-
-    task_result next_result(int task_id)
-    {
-        while (true)
-        {
-            std::unique_lock<std::mutex> lock(mutex_results);
-            condition_results.wait(lock, [&]{
-                return !queue_results.empty();
-            });
-
-            for (int i = 0; i < (int) queue_results.size(); i++)
-            {
-                // for now, tasks that have associated parent multitasks just get erased once multitask picks up the result
-                if (queue_results[i].multitask_id == task_id)
-                {
-                    update_multi_task(task_id, queue_results[i].id, queue_results[i]);
-                    queue_results.erase(queue_results.begin() + i);
-                    continue;
-                }
-
-                if (queue_results[i].id == task_id)
-                {
-                    assert(queue_results[i].multitask_id == -1);
-                    task_result res = queue_results[i];
-                    queue_results.erase(queue_results.begin() + i);
-                    return res;
-                }
-            }
-        }
-
-        // never reached
-        //return task_result{-1, false, false, {}};
+        queue_tasks.post(task);
     }
 
     // for multiple images processing
@@ -1525,150 +1213,117 @@ struct llama_server_context
 
     void request_cancel(int task_id)
     {
-        std::unique_lock<std::mutex> lock(mutex_tasks);
         task_server task;
-        task.id = id_gen++;
         task.type = TASK_TYPE_CANCEL;
         task.target_id = task_id;
-        queue_tasks.push_back(task);
-        condition_tasks.notify_one();
+        queue_tasks.post(task);
     }
 
-    int split_multiprompt_task(task_server& multiprompt_task)
+    void split_multiprompt_task(int multitask_id, task_server& multiprompt_task)
     {
         int prompt_count = multiprompt_task.data.at("prompt").size();
         assert(prompt_count > 1);
 
-        int multitask_id = id_gen++;
+        // generate all the ID for subtask
         std::vector<int> subtask_ids(prompt_count);
         for (int i = 0; i < prompt_count; i++)
+        {
+            subtask_ids[i] = queue_tasks.get_new_id();
+        }
+
+        // queue up the multitask so we can track its subtask progression
+        queue_tasks.add_multitask(multitask_id, subtask_ids);
+
+        // add subtasks
+        for (int i = 0; i < prompt_count; i++)
         {
             json subtask_data = multiprompt_task.data;
             subtask_data["prompt"] = subtask_data["prompt"][i];
 
             // subtasks inherit everything else (infill mode, embedding mode, etc.)
-            subtask_ids[i] = request_completion(subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
+            request_completion(subtask_ids[i], subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
         }
-
-        // queue up the multitask so we can track its subtask progression
-        add_multi_task(multitask_id, subtask_ids);
-        return multitask_id;
     }
 
-    void process_tasks()
+    void process_single_task(task_server& task)
     {
-        std::unique_lock<std::mutex> lock(mutex_tasks);
-        std::vector<task_server> deferred_tasks;
-        while (!queue_tasks.empty())
+        switch (task.type)
         {
-            task_server task = queue_tasks.front();
-            queue_tasks.erase(queue_tasks.begin());
-            switch (task.type)
-            {
-                case TASK_TYPE_COMPLETION: {
-                    llama_client_slot *slot = get_slot(json_value(task.data, "slot_id", -1));
-                    if (slot == nullptr)
-                    {
-                        // if no slot is available, we defer this task for processing later
-                        deferred_tasks.push_back(task);
+            case TASK_TYPE_COMPLETION: {
+                llama_client_slot *slot = get_slot(json_value(task.data, "slot_id", -1));
+                if (slot == nullptr)
+                {
+                    // if no slot is available, we defer this task for processing later
+                    LOG_VERBOSE("no slot is available", {});
+                    queue_tasks.defer(task);
+                    break;
+                }
+
+                if (task.data.contains("system_prompt"))
+                {
+                    if (!all_slots_are_idle) {
+                        send_error(task, "system prompt can only be updated when all slots are idle");
                         break;
                     }
+                    process_system_prompt_data(task.data["system_prompt"]);
 
-                    if (task.data.contains("system_prompt"))
+                    // reset cache_tokens for all slots
+                    for (llama_client_slot &slot : slots)
                     {
-                        if (!all_slots_are_idle) {
-                            send_error(task, "system prompt can only be updated when all slots are idle");
-                            break;
-                        }
-                        process_system_prompt_data(task.data["system_prompt"]);
-
-                        // reset cache_tokens for all slots
-                        for (llama_client_slot &slot : slots)
-                        {
-                            slot.cache_tokens.clear();
-                        }
+                        slot.cache_tokens.clear();
                     }
+                }
 
-                    slot->reset();
+                slot->reset();
 
-                    slot->infill       = task.infill_mode;
-                    slot->embedding    = task.embedding_mode;
-                    slot->task_id      = task.id;
-                    slot->multitask_id = task.multitask_id;
+                slot->infill       = task.infill_mode;
+                slot->embedding    = task.embedding_mode;
+                slot->task_id      = task.id;
+                slot->multitask_id = task.multitask_id;
 
-                    if (!launch_slot_with_data(slot, task.data))
+                if (!launch_slot_with_data(slot, task.data))
+                {
+                    // send error result
+                    send_error(task, "internal_error");
+                    break;
+                }
+            } break;
+            case TASK_TYPE_CANCEL: { // release slot linked with the task id
+                for (auto & slot : slots)
+                {
+                    if (slot.task_id == task.target_id)
                     {
-                        // send error result
-                        send_error(task, "internal_error");
+                        slot.release();
                         break;
                     }
-                } break;
-                case TASK_TYPE_CANCEL: { // release slot linked with the task id
-                    for (auto & slot : slots)
-                    {
-                        if (slot.task_id == task.target_id)
-                        {
-                            slot.release();
-                            break;
-                        }
-                    }
-                } break;
-            }
+                }
+            } break;
+            case TASK_TYPE_NEXT_RESPONSE: {
+                // do nothing
+            } break;
         }
+    }
 
-        // add all the deferred tasks back the the queue
-        for (task_server &task : deferred_tasks)
-        {
-            queue_tasks.push_back(task);
-        }
+    void on_finish_multitask(task_multi& multitask)
+    {
+        // all subtasks done == multitask is done
+        task_result result;
+        result.id = multitask.id;
+        result.stop = true;
+        result.error = false;
 
-        // remove finished multitasks from the queue of multitasks, and add the corresponding result to the result queue
-        std::vector<task_result> agg_results;
-        auto queue_iterator = queue_multitasks.begin();
-        while (queue_iterator != queue_multitasks.end())
+        // collect json results into one json result
+        std::vector<json> result_jsons;
+        for (auto& subres : multitask.results)
         {
-            if (queue_iterator->subtasks_remaining.empty())
-            {
-                // all subtasks done == multitask is done
-                task_result aggregate_result;
-                aggregate_result.id = queue_iterator->id;
-                aggregate_result.stop = true;
-                aggregate_result.error = false;
-
-                // collect json results into one json result
-                std::vector<json> result_jsons;
-                for (auto& subres : queue_iterator->results)
-                {
-                    result_jsons.push_back(subres.result_json);
-                    aggregate_result.error = aggregate_result.error && subres.error;
-                }
-                aggregate_result.result_json = json{ "results", result_jsons };
-
-
-                agg_results.push_back(aggregate_result);
-
-                condition_results.notify_all();
-
-                queue_iterator = queue_multitasks.erase(queue_iterator);
-            }
-            else
-            {
-                ++queue_iterator;
-            }
+            result_jsons.push_back(subres.result_json);
+            result.error = result.error && subres.error;
         }
-
-        // done with tasks, unlock
-        lock.unlock();
-
-        // copy aggregate results of complete multi-tasks to the results queue
-        std::lock_guard<std::mutex> lock_results(mutex_results);
-        queue_results.insert(queue_results.end(), agg_results.begin(), agg_results.end());
+        result.result_json = json{ { "results", result_jsons } };
+        queue_results.send(result);
     }
 
     bool update_slots() {
-        // attend tasks
-        process_tasks();
-
         if (system_need_update)
         {
             LOG_TEE("updating system prompt\n");
@@ -1684,10 +1339,12 @@ struct llama_server_context
                 LOG_TEE("all slots are idle and system prompt is empty, clear the KV cache\n");
                 kv_cache_clear();
             }
-            std::unique_lock<std::mutex> lock(mutex_tasks);
-            condition_tasks.wait(lock, [&]{
-                return !queue_tasks.empty();
-            });
+            return true;
+        } else {
+            task_server task;
+            task.type = TASK_TYPE_NEXT_RESPONSE;
+            task.target_id = -1;
+            queue_tasks.post(task);
         }
 
         for (llama_client_slot &slot : slots)
@@ -1732,6 +1389,7 @@ struct llama_server_context
                 slot.t_last_used = ggml_time_us();
 
                 LOG_TEE("slot %d released (%d tokens in cache)\n", slot.id, (int) slot.cache_tokens.size());
+                queue_tasks.notify_slot_changed();
 
                 continue;
             }
@@ -1997,6 +1655,10 @@ struct llama_server_context
         }
         return true;
     }
+
+    void run_on_all_tasks_finished() {
+        update_slots();
+    }
 };
 
 static void server_print_usage(const char *argv0, const gpt_params &params,
@@ -2541,239 +2203,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
     }
 }
 
-static std::string random_string()
-{
-    static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz");
-
-    std::random_device rd;
-    std::mt19937 generator(rd());
-
-    std::string result(32, ' ');
-
-    for (int i = 0; i < 32; ++i) {
-        result[i] = str[generator() % str.size()];
-    }
-
-    return result;
-}
-
-static std::string gen_chatcmplid()
-{
-    std::stringstream chatcmplid;
-    chatcmplid << "chatcmpl-" << random_string();
-    return chatcmplid.str();
-}
-
-std::string format_chatml(std::vector<json> messages)
-{
-    std::ostringstream chatml_msgs;
-
-    for (auto it = messages.begin(); it != messages.end(); ++it) {
-        chatml_msgs << "<|im_start|>"
-                    << json_value(*it, "role",    std::string("user")) << '\n';
-        chatml_msgs << json_value(*it, "content", std::string(""))
-                    << "<|im_end|>\n";
-    }
-
-    chatml_msgs << "<|im_start|>assistant" << '\n';
-
-    return chatml_msgs.str();
-}
-
 /* llama.cpp completion api semantics */
-json oaicompat_completion_params_parse(
-    const json &body /* openai api json semantics */)
-{
-    json llama_params;
-
-    llama_params["__oaicompat"] = true;
-
-    // Map OpenAI parameters to llama.cpp parameters
-    //
-    // For parameters that are defined by the OpenAI documentation (e.g.
-    // temperature), we explicitly specify OpenAI's intended default; we
-    // need to do that because sometimes OpenAI disagrees with llama.cpp
-    //
-    // https://platform.openai.com/docs/api-reference/chat/create
-    llama_sampling_params default_sparams;
-    llama_params["model"]             = json_value(body, "model", std::string("unknown"));
-    llama_params["prompt"]            = format_chatml(body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt'
-    llama_params["cache_prompt"]      = json_value(body, "cache_prompt", false);
-    llama_params["temperature"]       = json_value(body, "temperature", 0.0);
-    llama_params["top_k"]             = json_value(body, "top_k", default_sparams.top_k);
-    llama_params["top_p"]             = json_value(body, "top_p", 1.0);
-    llama_params["n_predict"]         = json_value(body, "max_tokens", -1);
-    llama_params["logit_bias"]        = json_value(body, "logit_bias",json::object());
-    llama_params["frequency_penalty"] = json_value(body, "frequency_penalty", 0.0);
-    llama_params["presence_penalty"]  = json_value(body, "presence_penalty", 0.0);
-    llama_params["seed"]              = json_value(body, "seed", LLAMA_DEFAULT_SEED);
-    llama_params["stream"]            = json_value(body, "stream", false);
-    llama_params["mirostat"]          = json_value(body, "mirostat", default_sparams.mirostat);
-    llama_params["mirostat_tau"]      = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
-    llama_params["mirostat_eta"]      = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
-    llama_params["penalize_nl"]       = json_value(body, "penalize_nl", default_sparams.penalize_nl);
-    llama_params["typical_p"]         = json_value(body, "typical_p", default_sparams.typical_p);
-    llama_params["repeat_last_n"]     = json_value(body, "repeat_last_n", default_sparams.penalty_last_n);
-    llama_params["ignore_eos"]        = json_value(body, "ignore_eos", false);
-    llama_params["tfs_z"]             = json_value(body, "tfs_z", default_sparams.tfs_z);
-
-    if (body.count("grammar") != 0) {
-        llama_params["grammar"] = json_value(body, "grammar", json::object());
-    }
-
-    // Handle 'stop' field
-    if (body.contains("stop") && body["stop"].is_string()) {
-        llama_params["stop"] = json::array({body["stop"].get<std::string>()});
-    } else {
-        llama_params["stop"] = json_value(body, "stop", json::array());
-    }
-
-    // Ensure there is ChatML-specific end sequence among stop words
-    llama_params["stop"].push_back("<|im_end|>");
-
-    return llama_params;
-}
-
-static json format_final_response_oaicompat(const json &request, const task_result &response, bool streaming = false)
-{
-    json result = response.result_json;
-
-    bool stopped_word        = result.count("stopped_word") != 0;
-    bool stopped_eos         = json_value(result, "stopped_eos", false);
-    int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
-    int num_prompt_tokens    = json_value(result, "tokens_evaluated", 0);
-    std::string content      = json_value(result, "content", std::string(""));
-
-    std::string finish_reason = "length";
-    if (stopped_word || stopped_eos) {
-        finish_reason = "stop";
-    }
-
-    json choices =
-        streaming ? json::array({json{{"finish_reason", finish_reason},
-                                        {"index", 0},
-                                        {"delta", json::object()}}})
-                  : json::array({json{{"finish_reason", finish_reason},
-                                        {"index", 0},
-                                        {"message", json{{"content", content},
-                                                         {"role", "assistant"}}}}});
-
-    std::time_t t = std::time(0);
-
-    json res =
-        json{{"choices", choices},
-            {"created", t},
-            {"model",
-                json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
-            {"object", streaming ? "chat.completion.chunk" : "chat.completion"},
-            {"usage",
-                json{{"completion_tokens", num_tokens_predicted},
-                     {"prompt_tokens",     num_prompt_tokens},
-                     {"total_tokens",      num_tokens_predicted + num_prompt_tokens}}},
-            {"id", gen_chatcmplid()}};
-
-    if (server_verbose) {
-        res["__verbose"] = result;
-    }
-
-    if (result.contains("completion_probabilities")) {
-        res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
-    }
-
-    return res;
-}
-
-// return value is vector as there is one case where we might need to generate two responses
-static std::vector<json> format_partial_response_oaicompat(const task_result &response) {
-    json result = response.result_json;
-
-    if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
-        return std::vector<json>({response.result_json});
-    }
-
-    bool first = json_value(result, "oaicompat_token_ctr", 0) == 0;
-    std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
-
-    bool stopped_word   = json_value(result, "stopped_word", false);
-    bool stopped_eos    = json_value(result, "stopped_eos", false);
-    bool stopped_limit  = json_value(result, "stopped_limit", false);
-    std::string content = json_value(result, "content", std::string(""));
-
-    std::string finish_reason;
-    if (stopped_word || stopped_eos) {
-        finish_reason = "stop";
-    }
-    if (stopped_limit) {
-        finish_reason = "length";
-    }
-
-    std::time_t t = std::time(0);
-
-    json choices;
-
-    if (!finish_reason.empty()) {
-        choices = json::array({json{{"finish_reason", finish_reason},
-                                    {"index", 0},
-                                    {"delta", json::object()}}});
-    } else {
-        if (first) {
-            if (content.empty()) {
-                choices = json::array({json{{"finish_reason", nullptr},
-                                            {"index", 0},
-                                            {"delta", json{{"role", "assistant"}}}}});
-            } else {
-                // We have to send this as two updates to conform to openai behavior
-                json initial_ret = json{{"choices", json::array({json{
-                                        {"finish_reason", nullptr},
-                                        {"index", 0},
-                                        {"delta", json{
-                                            {"role", "assistant"}
-                                        }}}})},
-                            {"created", t},
-                            {"id", gen_chatcmplid()},
-                            {"model", modelname},
-                            {"object", "chat.completion.chunk"}};
-
-                json second_ret = json{
-                            {"choices", json::array({json{{"finish_reason", nullptr},
-                                                            {"index", 0},
-                                                            {"delta", json{
-                                                            {"content", content}}}
-                                                            }})},
-                            {"created", t},
-                            {"id", gen_chatcmplid()},
-                            {"model", modelname},
-                            {"object", "chat.completion.chunk"}};
-
-                return std::vector<json>({initial_ret, second_ret});
-            }
-        } else {
-            // Some idiosyncrasy in task processing logic makes several trailing calls
-            // with empty content, we ignore these at the calee site.
-            if (content.empty()) {
-                return std::vector<json>({json::object()});
-            }
-
-            choices = json::array({json{
-                {"finish_reason", nullptr},
-                {"index", 0},
-                {"delta",
-                json{
-                    {"content", content},
-                }},
-            }});
-        }
-    }
-
-    json ret = json{{"choices", choices},
-                    {"created", t},
-                    {"id", gen_chatcmplid()},
-                    {"model", modelname},
-                    {"object", "chat.completion.chunk"}};
-
-    return std::vector<json>({ret});
-}
-
 static json format_partial_response(
     llama_server_context &llama, llama_client_slot *slot, const std::string &content, const std::vector<completion_token_output> &probs
 ) {
@@ -3069,10 +2499,12 @@ int main(int argc, char **argv)
                     return;
                 }
                 json data = json::parse(req.body);
-                const int task_id = llama.request_completion(data, false, false, -1);
+                const int task_id = llama.queue_tasks.get_new_id();
+                llama.queue_results.add_waiting_task_id(task_id);
+                llama.request_completion(task_id, data, false, false, -1);
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
-                    task_result result = llama.next_result(task_id);
+                    task_result result = llama.queue_results.recv(task_id);
                     if (!result.error && result.stop) {
                         res.set_content(result.result_json.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
                     }
@@ -3080,14 +2512,14 @@ int main(int argc, char **argv)
                     {
                         res.status = 404;
                         res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
-                        return;
                     }
+                    llama.queue_results.remove_waiting_task_id(task_id);
                 } else {
                     const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink)
                     {
                         while (true)
                         {
-                            task_result result = llama.next_result(task_id);
+                            task_result result = llama.queue_results.recv(task_id);
                             if (!result.error) {
                                 const std::string str =
                                     "data: " +
@@ -3098,6 +2530,7 @@ int main(int argc, char **argv)
                                 });
                                 if (!sink.write(str.c_str(), str.size()))
                                 {
+                                    llama.queue_results.remove_waiting_task_id(task_id);
                                     return false;
                                 }
                                 if (result.stop) {
@@ -3113,11 +2546,14 @@ int main(int argc, char **argv)
                                 });
                                 if (!sink.write(str.c_str(), str.size()))
                                 {
+                                    llama.queue_results.remove_waiting_task_id(task_id);
                                     return false;
                                 }
                                 break;
                             }
                         }
+
+                        llama.queue_results.remove_waiting_task_id(task_id);
                         sink.done();
                         return true;
                     };
@@ -3126,6 +2562,7 @@ int main(int argc, char **argv)
                     {
                         // cancel
                         llama.request_cancel(task_id);
+                        llama.queue_results.remove_waiting_task_id(task_id);
                     };
 
                     res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
@@ -3162,11 +2599,13 @@ int main(int argc, char **argv)
                 }
                 json data = oaicompat_completion_params_parse(json::parse(req.body));
 
-                const int task_id = llama.request_completion(data, false, false, -1);
+                const int task_id = llama.queue_tasks.get_new_id();
+                llama.queue_results.add_waiting_task_id(task_id);
+                llama.request_completion(task_id, data, false, false, -1);
 
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
-                    task_result result = llama.next_result(task_id);
+                    task_result result = llama.queue_results.recv(task_id);
 
                     if (!result.error && result.stop) {
                         json oaicompat_result = format_final_response_oaicompat(data, result);
@@ -3177,12 +2616,12 @@ int main(int argc, char **argv)
                     } else {
                         res.status = 500;
                         res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
-                        return;
                     }
+                    llama.queue_results.remove_waiting_task_id(task_id);
                 } else {
                     const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink &sink) {
                         while (true) {
-                            task_result llama_result = llama.next_result(task_id);
+                            task_result llama_result = llama.queue_results.recv(task_id);
                             if (!llama_result.error) {
                                 std::vector<json> result_array = format_partial_response_oaicompat( llama_result);
 
@@ -3195,6 +2634,7 @@ int main(int argc, char **argv)
                                             "\n\n";
                                         LOG_VERBOSE("data stream", {{"to_send", str}});
                                         if (!sink.write(str.c_str(), str.size())) {
+                                            llama.queue_results.remove_waiting_task_id(task_id);
                                             return false;
                                         }
                                     }
@@ -3210,18 +2650,21 @@ int main(int argc, char **argv)
                                     "\n\n";
                                 LOG_VERBOSE("data stream", {{"to_send", str}});
                                 if (!sink.write(str.c_str(), str.size())) {
+                                    llama.queue_results.remove_waiting_task_id(task_id);
                                     return false;
                                 }
                                 break;
                             }
                         }
                         sink.done();
+                        llama.queue_results.remove_waiting_task_id(task_id);
                         return true;
                     };
 
                     auto on_complete = [task_id, &llama](bool) {
                         // cancel request
                         llama.request_cancel(task_id);
+                        llama.queue_results.remove_waiting_task_id(task_id);
                     };
 
                     res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
@@ -3235,10 +2678,12 @@ int main(int argc, char **argv)
                     return;
                 }
                 json data = json::parse(req.body);
-                const int task_id = llama.request_completion(data, true, false, -1);
+                const int task_id = llama.queue_tasks.get_new_id();
+                llama.queue_results.add_waiting_task_id(task_id);
+                llama.request_completion(task_id, data, true, false, -1);
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
-                    task_result result = llama.next_result(task_id);
+                    task_result result = llama.queue_results.recv(task_id);
                     if (!result.error && result.stop)
                     {
                         res.set_content(result.result_json.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
@@ -3247,13 +2692,13 @@ int main(int argc, char **argv)
                     {
                         res.status = 404;
                         res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
-                        return;
                     }
+                    llama.queue_results.remove_waiting_task_id(task_id);
                 } else {
                     const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink) {
                         while (true)
                         {
-                            task_result result = llama.next_result(task_id);
+                            task_result result = llama.queue_results.recv(task_id);
                             if (!result.error) {
                                 const std::string str =
                                 "data: " +
@@ -3264,6 +2709,7 @@ int main(int argc, char **argv)
                                 });
                                 if (!sink.write(str.c_str(), str.size()))
                                 {
+                                    llama.queue_results.remove_waiting_task_id(task_id);
                                     return false;
                                 }
                                 if (result.stop)
@@ -3277,8 +2723,8 @@ int main(int argc, char **argv)
                             }
                         }
 
+                        llama.queue_results.remove_waiting_task_id(task_id);
                         sink.done();
-
                         return true;
                     };
 
@@ -3352,23 +2798,46 @@ int main(int argc, char **argv)
                     image_data = "";
                 }
 
-                const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, false, true, -1);
-                task_result result = llama.next_result(task_id);
+                // create and queue the task
+                const int task_id = llama.queue_tasks.get_new_id();
+                llama.queue_results.add_waiting_task_id(task_id);
+                llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, false, true, -1);
+
+                // get the result
+                task_result result = llama.queue_results.recv(task_id);
+                llama.queue_results.remove_waiting_task_id(task_id);
+
+                // send the result
                 return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
             });
 
     // GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!?
     //     "Bus error: 10" - this is on macOS, it does not crash on Linux
     //std::thread t2([&]()
-    {
+    /*{
         bool running = true;
         while (running)
         {
             running = llama.update_slots();
         }
-    }
+    }*/
     //);
 
+    llama.queue_tasks.on_new_task(std::bind(
+        &llama_server_context::process_single_task, &llama, std::placeholders::_1));
+    llama.queue_tasks.on_finish_multitask(std::bind(
+        &llama_server_context::on_finish_multitask, &llama, std::placeholders::_1));
+    llama.queue_tasks.on_all_tasks_finished(std::bind(
+        &llama_server_context::run_on_all_tasks_finished, &llama));
+    llama.queue_results.on_multitask_update(std::bind(
+        &llama_server_queue::update_multitask,
+        &llama.queue_tasks,
+        std::placeholders::_1,
+        std::placeholders::_2,
+        std::placeholders::_3
+    ));
+    llama.queue_tasks.start_loop();
+
     t.join();
 
     llama_backend_free();
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
new file mode 100644 (file)
index 0000000..e2b6065
--- /dev/null
@@ -0,0 +1,507 @@
+#pragma once
+
+#include <string>
+#include <vector>
+#include <set>
+#include <mutex>
+#include <condition_variable>
+#include <unordered_map>
+
+#include "json.hpp"
+
+#include "../llava/clip.h"
+
+using json = nlohmann::json;
+
+extern bool server_verbose;
+
+#ifndef SERVER_VERBOSE
+#define SERVER_VERBOSE 1
+#endif
+
+#if SERVER_VERBOSE != 1
+#define LOG_VERBOSE(MSG, ...)
+#else
+#define LOG_VERBOSE(MSG, ...)                                            \
+    do                                                                   \
+    {                                                                    \
+        if (server_verbose)                                              \
+        {                                                                \
+            server_log("VERBOSE", __func__, __LINE__, MSG, __VA_ARGS__); \
+        }                                                                \
+    } while (0)
+#endif
+
+#define LOG_ERROR(  MSG, ...) server_log("ERROR",   __func__, __LINE__, MSG, __VA_ARGS__)
+#define LOG_WARNING(MSG, ...) server_log("WARNING", __func__, __LINE__, MSG, __VA_ARGS__)
+#define LOG_INFO(   MSG, ...) server_log("INFO",    __func__, __LINE__, MSG, __VA_ARGS__)
+
+//
+// parallel
+//
+
+enum server_state {
+    SERVER_STATE_LOADING_MODEL,  // Server is starting up, model not fully loaded yet
+    SERVER_STATE_READY,          // Server is ready and model is loaded
+    SERVER_STATE_ERROR           // An error occurred, load_model failed
+};
+
+enum task_type {
+    TASK_TYPE_COMPLETION,
+    TASK_TYPE_CANCEL,
+    TASK_TYPE_NEXT_RESPONSE
+};
+
+struct task_server {
+    int id = -1; // to be filled by llama_server_queue
+    int target_id;
+    task_type type;
+    json data;
+    bool infill_mode = false;
+    bool embedding_mode = false;
+    int multitask_id = -1;
+};
+
+struct task_result {
+    int id;
+    int multitask_id = -1;
+    bool stop;
+    bool error;
+    json result_json;
+};
+
+struct task_multi {
+    int id;
+    std::set<int> subtasks_remaining{};
+    std::vector<task_result> results{};
+};
+
+// TODO: can become bool if we can't find use of more states
+enum slot_state
+{
+    IDLE,
+    PROCESSING,
+};
+
+enum slot_command
+{
+    NONE,
+    LOAD_PROMPT,
+    RELEASE,
+};
+
+struct slot_params
+{
+    bool stream       = true;
+    bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt
+
+    uint32_t seed      = -1; // RNG seed
+    int32_t  n_keep    =  0; // number of tokens to keep from initial prompt
+    int32_t  n_predict = -1; // new tokens to predict
+
+    std::vector<std::string> antiprompt;
+
+    json input_prefix;
+    json input_suffix;
+};
+
+struct slot_image
+{
+    int32_t id;
+
+    bool request_encode_image = false;
+    float * image_embedding = nullptr;
+    int32_t image_tokens = 0;
+
+    clip_image_u8 * img_data;
+
+    std::string prefix_prompt; // before of this image
+};
+
+// completion token output with probabilities
+struct completion_token_output
+{
+    struct token_prob
+    {
+        llama_token tok;
+        float prob;
+    };
+
+    std::vector<token_prob> probs;
+    llama_token tok;
+    std::string text_to_send;
+};
+
+static inline void server_log(const char *level, const char *function, int line,
+                       const char *message, const nlohmann::ordered_json &extra)
+{
+    nlohmann::ordered_json log
+    {
+        {"timestamp", time(nullptr)},
+        {"level",     level},
+        {"function",  function},
+        {"line",      line},
+        {"message",   message},
+    };
+
+    if (!extra.empty())
+    {
+        log.merge_patch(extra);
+    }
+
+    const std::string str = log.dump(-1, ' ', false, json::error_handler_t::replace);
+    printf("%.*s\n", (int)str.size(), str.data());
+    fflush(stdout);
+}
+
+//
+// server utils
+//
+
+template <typename T>
+static T json_value(const json &body, const std::string &key, const T &default_value)
+{
+    // Fallback null to default value
+    return body.contains(key) && !body.at(key).is_null()
+        ? body.value(key, default_value)
+        : default_value;
+}
+
+inline std::string format_chatml(std::vector<json> messages)
+{
+    std::ostringstream chatml_msgs;
+
+    for (auto it = messages.begin(); it != messages.end(); ++it) {
+        chatml_msgs << "<|im_start|>"
+                    << json_value(*it, "role",    std::string("user")) << '\n';
+        chatml_msgs << json_value(*it, "content", std::string(""))
+                    << "<|im_end|>\n";
+    }
+
+    chatml_msgs << "<|im_start|>assistant" << '\n';
+
+    return chatml_msgs.str();
+}
+
+//
+// work queue utils
+//
+
+struct llama_server_queue {
+    int id = 0;
+    std::mutex mutex_tasks;
+    // queues
+    std::vector<task_server> queue_tasks;
+    std::vector<task_server> queue_tasks_deferred;
+    std::vector<task_multi> queue_multitasks;
+    std::condition_variable condition_tasks;
+    // callback functions
+    std::function<void(task_server&)> callback_new_task;
+    std::function<void(task_multi&)> callback_finish_multitask;
+    std::function<void(void)> callback_all_task_finished;
+
+    // Add a new task to the end of the queue
+    int post(task_server task) {
+        std::unique_lock<std::mutex> lock(mutex_tasks);
+        if (task.id == -1) {
+            task.id = id++;
+        }
+        queue_tasks.push_back(std::move(task));
+        condition_tasks.notify_one();
+        return task.id;
+    }
+
+    // Add a new task, but defer until one slot is available
+    void defer(task_server task) {
+        std::unique_lock<std::mutex> lock(mutex_tasks);
+        queue_tasks_deferred.push_back(std::move(task));
+    }
+
+    // Get the next id for creating anew task
+    int get_new_id() {
+        std::unique_lock<std::mutex> lock(mutex_tasks);
+        return id++;
+    }
+
+    // Register function to process a new task
+    void on_new_task(std::function<void(task_server&)> callback) {
+        callback_new_task = callback;
+    }
+
+    // Register function to process a multitask
+    void on_finish_multitask(std::function<void(task_multi&)> callback) {
+        callback_finish_multitask = callback;
+    }
+
+    // Register the function to be called when the batch of tasks is finished
+    void on_all_tasks_finished(std::function<void(void)> callback) {
+        callback_all_task_finished = callback;
+    }
+
+    // Call when the state of one slot is changed
+    void notify_slot_changed() {
+        // move deferred tasks back to main loop
+        std::unique_lock<std::mutex> lock(mutex_tasks);
+        for (auto & task : queue_tasks_deferred) {
+            queue_tasks.push_back(std::move(task));
+        }
+        queue_tasks_deferred.clear();
+    }
+
+    // Start the main loop. This call is blocking
+    void start_loop() {
+        while (true) {
+            // new task arrived
+            LOG_VERBOSE("have new task", {});
+            {
+                while (true)
+                {
+                    std::unique_lock<std::mutex> lock(mutex_tasks);
+                    if (queue_tasks.empty()) {
+                        lock.unlock();
+                        break;
+                    }
+                    task_server task = queue_tasks.front();
+                    queue_tasks.erase(queue_tasks.begin());
+                    lock.unlock();
+                    LOG_VERBOSE("callback_new_task", {});
+                    callback_new_task(task);
+                }
+                LOG_VERBOSE("callback_all_task_finished", {});
+                // process and update all the multitasks
+                auto queue_iterator = queue_multitasks.begin();
+                while (queue_iterator != queue_multitasks.end())
+                {
+                    if (queue_iterator->subtasks_remaining.empty())
+                    {
+                        // all subtasks done == multitask is done
+                        task_multi current_multitask = *queue_iterator;
+                        callback_finish_multitask(current_multitask);
+                        // remove this multitask
+                        queue_iterator = queue_multitasks.erase(queue_iterator);
+                    }
+                    else
+                    {
+                        ++queue_iterator;
+                    }
+                }
+                // all tasks in the current loop is finished
+                callback_all_task_finished();
+            }
+            LOG_VERBOSE("wait for new task", {});
+            // wait for new task
+            {
+                std::unique_lock<std::mutex> lock(mutex_tasks);
+                if (queue_tasks.empty()) {
+                    condition_tasks.wait(lock, [&]{
+                        return !queue_tasks.empty();
+                    });
+                }
+            }
+        }
+    }
+
+    //
+    // functions to manage multitasks
+    //
+
+    // add a multitask by specifying the id of all subtask (subtask is a task_server)
+    void add_multitask(int multitask_id, std::vector<int>& sub_ids)
+    {
+        std::lock_guard<std::mutex> lock(mutex_tasks);
+        task_multi multi;
+        multi.id = multitask_id;
+        std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
+        queue_multitasks.push_back(multi);
+    }
+
+    // updatethe remaining subtasks, while appending results to multitask
+    void update_multitask(int multitask_id, int subtask_id, task_result& result)
+    {
+        std::lock_guard<std::mutex> lock(mutex_tasks);
+        for (auto& multitask : queue_multitasks)
+        {
+            if (multitask.id == multitask_id)
+            {
+                multitask.subtasks_remaining.erase(subtask_id);
+                multitask.results.push_back(result);
+            }
+        }
+    }
+};
+
+struct llama_server_response {
+    typedef std::function<void(int, int, task_result&)> callback_multitask_t;
+    callback_multitask_t callback_update_multitask;
+    // for keeping track of all tasks waiting for the result
+    std::set<int> waiting_task_ids;
+    // the main result queue
+    std::vector<task_result> queue_results;
+    std::mutex mutex_results;
+    std::condition_variable condition_results;
+
+    void add_waiting_task_id(int task_id) {
+        std::unique_lock<std::mutex> lock(mutex_results);
+        waiting_task_ids.insert(task_id);
+    }
+
+    void remove_waiting_task_id(int task_id) {
+        std::unique_lock<std::mutex> lock(mutex_results);
+        waiting_task_ids.erase(task_id);
+    }
+
+    // This function blocks the thread until there is a response for this task_id
+    task_result recv(int task_id) {
+        while (true)
+        {
+            std::unique_lock<std::mutex> lock(mutex_results);
+            condition_results.wait(lock, [&]{
+                return !queue_results.empty();
+            });
+            LOG_VERBOSE("condition_results unblock", {});
+
+            for (int i = 0; i < (int) queue_results.size(); i++)
+            {
+                if (queue_results[i].id == task_id)
+                {
+                    assert(queue_results[i].multitask_id == -1);
+                    task_result res = queue_results[i];
+                    queue_results.erase(queue_results.begin() + i);
+                    return res;
+                }
+            }
+        }
+
+        // should never reach here
+    }
+
+    // Register the function to update multitask
+    void on_multitask_update(callback_multitask_t callback) {
+        callback_update_multitask = callback;
+    }
+
+    // Send a new result to a waiting task_id
+    void send(task_result result) {
+        std::unique_lock<std::mutex> lock(mutex_results);
+        LOG_VERBOSE("send new result", {});
+        for (auto& task_id : waiting_task_ids) {
+            // LOG_TEE("waiting task id %i \n", task_id);
+            // for now, tasks that have associated parent multitasks just get erased once multitask picks up the result
+            if (result.multitask_id == task_id)
+            {
+                LOG_VERBOSE("callback_update_multitask", {});
+                callback_update_multitask(task_id, result.id, result);
+                continue;
+            }
+
+            if (result.id == task_id)
+            {
+                LOG_VERBOSE("queue_results.push_back", {});
+                queue_results.push_back(result);
+                condition_results.notify_one();
+                return;
+            }
+        }
+    }
+};
+
+//
+// base64 utils (TODO: move to common in the future)
+//
+
+static const std::string base64_chars =
+             "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+             "abcdefghijklmnopqrstuvwxyz"
+             "0123456789+/";
+
+static inline bool is_base64(uint8_t c)
+{
+    return (isalnum(c) || (c == '+') || (c == '/'));
+}
+
+static inline std::vector<uint8_t> base64_decode(const std::string & encoded_string)
+{
+    int i = 0;
+    int j = 0;
+    int in_ = 0;
+
+    int in_len = encoded_string.size();
+
+    uint8_t char_array_4[4];
+    uint8_t char_array_3[3];
+
+    std::vector<uint8_t> ret;
+
+    while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_]))
+    {
+        char_array_4[i++] = encoded_string[in_]; in_++;
+        if (i == 4)
+        {
+            for (i = 0; i <4; i++)
+            {
+                char_array_4[i] = base64_chars.find(char_array_4[i]);
+            }
+
+            char_array_3[0] = ((char_array_4[0]      ) << 2) + ((char_array_4[1] & 0x30) >> 4);
+            char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
+            char_array_3[2] = ((char_array_4[2] & 0x3) << 6) +   char_array_4[3];
+
+            for (i = 0; (i < 3); i++)
+            {
+                ret.push_back(char_array_3[i]);
+            }
+            i = 0;
+        }
+    }
+
+    if (i)
+    {
+        for (j = i; j <4; j++)
+        {
+            char_array_4[j] = 0;
+        }
+
+        for (j = 0; j <4; j++)
+        {
+            char_array_4[j] = base64_chars.find(char_array_4[j]);
+        }
+
+        char_array_3[0] = ((char_array_4[0]      ) << 2) + ((char_array_4[1] & 0x30) >> 4);
+        char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
+        char_array_3[2] = ((char_array_4[2] & 0x3) << 6) +   char_array_4[3];
+
+        for (j = 0; (j < i - 1); j++)
+        {
+            ret.push_back(char_array_3[j]);
+        }
+    }
+
+    return ret;
+}
+
+//
+// random string / id
+//
+
+static std::string random_string()
+{
+    static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz");
+
+    std::random_device rd;
+    std::mt19937 generator(rd());
+
+    std::string result(32, ' ');
+
+    for (int i = 0; i < 32; ++i) {
+        result[i] = str[generator() % str.size()];
+    }
+
+    return result;
+}
+
+static std::string gen_chatcmplid()
+{
+    std::stringstream chatcmplid;
+    chatcmplid << "chatcmpl-" << random_string();
+    return chatcmplid.str();
+}