server : (refactor) no more json in server_task input (#10691)
author Xuan Son Nguyen <redacted>
Sat, 7 Dec 2024 19:21:09 +0000 (20:21 +0100)
committer GitHub <redacted>
Sat, 7 Dec 2024 19:21:09 +0000 (20:21 +0100)
* server : (refactor) no more json in server_task input

* add test for slots endpoint

* add tests for /props and /slots

* remove task inf_type

* fix CI by adding safe_json_to_str

* add "model_path" to /props

* update readme

examples/server/README.md
examples/server/server.cpp
examples/server/tests/unit/test_basic.py
examples/server/tests/unit/test_chat_completion.py
examples/server/tests/utils.py
examples/server/utils.hpp

diff --git a/examples/server/README.md b/examples/server/README.md
index 0bab40a82250caf199d6d81dc545e97b2ba7b318..117c52d3f13104d73f217ac9c88ca45978e28cec 100644
@@ -687,12 +687,14 @@ This endpoint is public (no API key check). By default, it is read-only. To make
     }
   },
   "total_slots": 1,
+  "model_path": "../models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf",
   "chat_template": "..."
 }
 ```
 
 - `default_generation_settings` - the default generation settings for the `/completion` endpoint, which has the same fields as the `generation_settings` response object from the `/completion` endpoint.
 - `total_slots` - the total number of slots for processing requests (defined by the `--parallel` option)
+- `model_path` - the path to the model file (same as the `-m` argument)
 - `chat_template` - the model's original Jinja2 prompt template
 
 ### POST `/props`: Change server global properties.
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 1ce8fbae23d5c8732f2e03ca70c3cd93d8bf6bd1..1c21e55aaa011d7d8462ed151cfa29bd03924d26 100644
@@ -54,7 +54,10 @@ enum server_state {
 };
 
 enum server_task_type {
-    SERVER_TASK_TYPE_INFERENCE,
+    SERVER_TASK_TYPE_COMPLETION,
+    SERVER_TASK_TYPE_EMBEDDING,
+    SERVER_TASK_TYPE_RERANK,
+    SERVER_TASK_TYPE_INFILL,
     SERVER_TASK_TYPE_CANCEL,
     SERVER_TASK_TYPE_NEXT_RESPONSE,
     SERVER_TASK_TYPE_METRICS,
@@ -64,13 +67,6 @@ enum server_task_type {
     SERVER_TASK_TYPE_SET_LORA,
 };
 
-enum server_task_inf_type {
-    SERVER_TASK_INF_TYPE_COMPLETION,
-    SERVER_TASK_INF_TYPE_EMBEDDING,
-    SERVER_TASK_INF_TYPE_RERANK,
-    SERVER_TASK_INF_TYPE_INFILL,
-};
-
 // https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
 enum error_type {
     ERROR_TYPE_INVALID_REQUEST,
@@ -82,28 +78,6 @@ enum error_type {
     ERROR_TYPE_NOT_SUPPORTED, // custom error
 };
 
-struct server_task {
-    int id        = -1; // to be filled by server_queue
-    int id_target = -1; // used by SERVER_TASK_TYPE_CANCEL
-
-    llama_tokens prompt_tokens;
-    server_task_type type;
-
-    // TODO @ngxson : we should get rid of json type here
-    json data;
-
-    server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
-
-    // utility function
-    static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
-        std::unordered_set<int> ids(tasks.size());
-        for (size_t i = 0; i < tasks.size(); i++) {
-            ids.insert(tasks[i].id);
-        }
-        return ids;
-    }
-};
-
 struct slot_params {
     bool stream       = true;
     bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
@@ -118,6 +92,7 @@ struct slot_params {
 
     std::vector<std::string> antiprompt;
     bool timings_per_token = false;
+    bool ignore_eos = false;
 
     struct common_params_sampling sampling;
     struct common_params_speculative speculative;
@@ -167,7 +142,7 @@ struct slot_params {
             {"n_discard",                 n_discard},
             {"ignore_eos",                sampling.ignore_eos},
             {"stream",                    stream},
-            //{"logit_bias",                sampling.logit_bias},
+            {"logit_bias",                format_logit_bias(sampling.logit_bias)},
             {"n_probs",                   sampling.n_probs},
             {"min_keep",                  sampling.min_keep},
             {"grammar",                   sampling.grammar},
@@ -180,6 +155,209 @@ struct slot_params {
     }
 };
 
+struct server_task {
+    int id    = -1; // to be filled by server_queue
+    int index = -1; // used when there are multiple prompts (batch request)
+
+    server_task_type type;
+
+    // used by SERVER_TASK_TYPE_CANCEL
+    int id_target = -1;
+
+    // used by SERVER_TASK_TYPE_INFERENCE
+    slot_params  params;
+    llama_tokens prompt_tokens;
+    int id_selected_slot = -1;
+
+    // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE
+    struct slot_action {
+        int slot_id;
+        std::string filename;
+        std::string filepath;
+    };
+    slot_action slot_action;
+
+    // used by SERVER_TASK_TYPE_METRICS
+    bool metrics_reset_bucket = false;
+
+    server_task(server_task_type type) : type(type) {}
+
+    static slot_params params_from_json_cmpl(
+            const llama_model * model,
+            const common_params & params_base,
+            const json & data) {
+        slot_params params;
+
+        // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
+        slot_params defaults;
+        defaults.sampling    = params_base.sampling;
+        defaults.speculative = params_base.speculative;
+
+        // enabling this will output extra debug information in the HTTP responses from the server
+        params.verbose           = params_base.verbosity > 9;
+        params.timings_per_token = json_value(data, "timings_per_token", false);
+
+        params.stream           = json_value(data, "stream",             false);
+        params.cache_prompt     = json_value(data, "cache_prompt",       true);
+        params.n_predict        = json_value(data, "n_predict",          json_value(data, "max_tokens", defaults.n_predict));
+        params.n_indent         = json_value(data, "n_indent",           defaults.n_indent);
+        params.n_keep           = json_value(data, "n_keep",             defaults.n_keep);
+        params.n_discard        = json_value(data, "n_discard",          defaults.n_discard);
+      //params.t_max_prompt_ms  = json_value(data, "t_max_prompt_ms",    defaults.t_max_prompt_ms); // TODO: implement
+        params.t_max_predict_ms = json_value(data, "t_max_predict_ms",   defaults.t_max_predict_ms);
+
+        params.sampling.top_k              = json_value(data, "top_k",              defaults.sampling.top_k);
+        params.sampling.top_p              = json_value(data, "top_p",              defaults.sampling.top_p);
+        params.sampling.min_p              = json_value(data, "min_p",              defaults.sampling.min_p);
+        params.sampling.xtc_probability    = json_value(data, "xtc_probability",    defaults.sampling.xtc_probability);
+        params.sampling.xtc_threshold      = json_value(data, "xtc_threshold",      defaults.sampling.xtc_threshold);
+        params.sampling.typ_p              = json_value(data, "typical_p",          defaults.sampling.typ_p);
+        params.sampling.temp               = json_value(data, "temperature",        defaults.sampling.temp);
+        params.sampling.dynatemp_range     = json_value(data, "dynatemp_range",     defaults.sampling.dynatemp_range);
+        params.sampling.dynatemp_exponent  = json_value(data, "dynatemp_exponent",  defaults.sampling.dynatemp_exponent);
+        params.sampling.penalty_last_n     = json_value(data, "repeat_last_n",      defaults.sampling.penalty_last_n);
+        params.sampling.penalty_repeat     = json_value(data, "repeat_penalty",     defaults.sampling.penalty_repeat);
+        params.sampling.penalty_freq       = json_value(data, "frequency_penalty",  defaults.sampling.penalty_freq);
+        params.sampling.penalty_present    = json_value(data, "presence_penalty",   defaults.sampling.penalty_present);
+        params.sampling.dry_multiplier     = json_value(data, "dry_multiplier",     defaults.sampling.dry_multiplier);
+        params.sampling.dry_base           = json_value(data, "dry_base",           defaults.sampling.dry_base);
+        params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length);
+        params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n);
+        params.sampling.mirostat           = json_value(data, "mirostat",           defaults.sampling.mirostat);
+        params.sampling.mirostat_tau       = json_value(data, "mirostat_tau",       defaults.sampling.mirostat_tau);
+        params.sampling.mirostat_eta       = json_value(data, "mirostat_eta",       defaults.sampling.mirostat_eta);
+        params.sampling.penalize_nl        = json_value(data, "penalize_nl",        defaults.sampling.penalize_nl);
+        params.sampling.seed               = json_value(data, "seed",               defaults.sampling.seed);
+        params.sampling.n_probs            = json_value(data, "n_probs",            defaults.sampling.n_probs);
+        params.sampling.min_keep           = json_value(data, "min_keep",           defaults.sampling.min_keep);
+
+        params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
+        params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
+        params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
+
+        params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min);
+        params.speculative.n_min = std::max(params.speculative.n_min, 2);
+        params.speculative.n_max = std::max(params.speculative.n_max, 0);
+
+        if (params.sampling.dry_base < 1.0f) {
+           params.sampling.dry_base = defaults.sampling.dry_base;
+        }
+
+        // sequence breakers for DRY
+        {
+            // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format
+            // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
+
+            if (data.contains("dry_sequence_breakers")) {
+                params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
+                if (params.sampling.dry_sequence_breakers.empty()) {
+                    throw std::runtime_error("Error: dry_sequence_breakers must be a non-empty array of strings");
+                }
+            }
+        }
+
+        // process "json_schema" and "grammar"
+        if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
+            throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
+        }
+        if (data.contains("json_schema") && !data.contains("grammar")) {
+            try {
+                auto schema                  = json_value(data, "json_schema", json::object());
+                params.sampling.grammar = json_schema_to_grammar(schema);
+            } catch (const std::exception & e) {
+                throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
+            }
+        } else {
+            params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
+        }
+
+        {
+            params.sampling.logit_bias.clear();
+            params.ignore_eos = json_value(data, "ignore_eos", false);
+
+            const auto & logit_bias = data.find("logit_bias");
+            if (logit_bias != data.end() && logit_bias->is_array()) {
+                const int n_vocab = llama_n_vocab(model);
+                for (const auto & el : *logit_bias) {
+                    // TODO: we may want to throw errors here, in case "el" is incorrect
+                    if (el.is_array() && el.size() == 2) {
+                        float bias;
+                        if (el[1].is_number()) {
+                            bias = el[1].get<float>();
+                        } else if (el[1].is_boolean() && !el[1].get<bool>()) {
+                            bias = -INFINITY;
+                        } else {
+                            continue;
+                        }
+
+                        if (el[0].is_number_integer()) {
+                            llama_token tok = el[0].get<llama_token>();
+                            if (tok >= 0 && tok < n_vocab) {
+                                params.sampling.logit_bias.push_back({tok, bias});
+                            }
+                        } else if (el[0].is_string()) {
+                            auto toks = common_tokenize(model, el[0].get<std::string>(), false);
+                            for (auto tok : toks) {
+                                params.sampling.logit_bias.push_back({tok, bias});
+                            }
+                        }
+                    }
+                }
+            }
+        }
+
+        {
+            params.antiprompt.clear();
+
+            const auto & stop = data.find("stop");
+            if (stop != data.end() && stop->is_array()) {
+                for (const auto & word : *stop) {
+                    if (!word.empty()) {
+                        params.antiprompt.push_back(word);
+                    }
+                }
+            }
+        }
+
+        {
+            const auto & samplers = data.find("samplers");
+            if (samplers != data.end()) {
+                if (samplers->is_array()) {
+                    std::vector<std::string> sampler_names;
+                    for (const auto & name : *samplers) {
+                        if (name.is_string()) {
+                            sampler_names.emplace_back(name);
+                        }
+                    }
+                    params.sampling.samplers = common_sampler_types_from_names(sampler_names, false);
+                } else if (samplers->is_string()){
+                    std::string sampler_string;
+                    for (const auto & name : *samplers) {
+                        sampler_string += name;
+                    }
+                    params.sampling.samplers = common_sampler_types_from_chars(sampler_string);
+                }
+            } else {
+                params.sampling.samplers = defaults.sampling.samplers;
+            }
+        }
+
+        std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias;
+        params.oaicompat_model = json_value(data, "model", model_name);
+
+        return params;
+    }
+
+    // utility function
+    static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
+        std::unordered_set<int> ids(tasks.size());
+        for (size_t i = 0; i < tasks.size(); i++) {
+            ids.insert(tasks[i].id);
+        }
+        return ids;
+    }
+};
+
 struct result_timings {
     int32_t prompt_n = -1;
     double prompt_ms;
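
For orientation, the hunk above replaces the old `json data` field of `server_task` with typed members; an HTTP handler is now expected to parse the request body itself and post a fully-populated task. Below is a rough sketch of that call pattern, assembled only from identifiers that appear elsewhere in this diff; it is not a verbatim excerpt, and error handling is reduced to comments.

```cpp
json data = json::parse(req.body);  // request body, parsed once in the HTTP thread

// tokenization also happens in the HTTP thread (the llama_tokenize API is thread-safe)
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, data.at("prompt"), true, true);

server_task task(SERVER_TASK_TYPE_COMPLETION);
task.id               = ctx_server.queue_tasks.get_new_id();
task.index            = 0;  // position of this prompt within a batch request
task.prompt_tokens    = std::move(tokenized_prompts[0]);
task.params           = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.params_base, data);
task.id_selected_slot = json_value(data, "id_slot", -1);

ctx_server.queue_results.add_waiting_task_id(task.id);
ctx_server.queue_tasks.post(task);
```

Any parse failure in `params_from_json_cmpl()` surfaces as a `std::runtime_error`, which the handler converts into an `ERROR_TYPE_INVALID_REQUEST` response (see `handle_completions_generic` further down).
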
@@ -191,7 +369,7 @@ struct result_timings {
     double predicted_per_token_ms;
     double predicted_per_second;
 
-    json to_json() {
+    json to_json() const {
         return {
             {"prompt_n",               prompt_n},
             {"prompt_ms",              prompt_ms},
@@ -623,7 +801,8 @@ struct server_task_result_metrics : server_task_result {
     uint64_t n_decode_total     = 0;
     uint64_t n_busy_slots_total = 0;
 
-    // TODO: get rid of this json object and use to_json() instead
+    // while we can also use std::vector<server_slot> this requires copying the slot object which can be quite messy
+    // therefore, we use json to temporarily store the slot.to_json() result
     json slots_data = json::array();
 
     virtual json to_json() override {
@@ -708,6 +887,9 @@ struct server_slot {
     int id;
     int id_task = -1;
 
+    // only used for completion/embedding/infill/rerank
+    server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;
+
     llama_batch batch_spec = {};
 
     llama_context * ctx = nullptr;
@@ -746,8 +928,6 @@ struct server_slot {
     llama_tokens cache_tokens;
     std::vector<completion_token_output> generated_token_probs;
 
-    server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
-
     bool has_next_token = true;
     bool has_new_line   = false;
     bool truncated      = false;
@@ -787,11 +967,15 @@ struct server_slot {
         n_past             = 0;
         n_sent_text        = 0;
         n_sent_token_probs = 0;
-        inf_type           = SERVER_TASK_INF_TYPE_COMPLETION;
+        task_type          = SERVER_TASK_TYPE_COMPLETION;
 
         generated_token_probs.clear();
     }
 
+    bool is_non_causal() const {
+        return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK;
+    }
+
     bool has_budget(const common_params & global_params) {
         if (params.n_predict == -1 && global_params.n_predict == -1) {
             return true; // limitless
@@ -903,6 +1087,7 @@ struct server_slot {
             {"n_ctx",         n_ctx},
             {"speculative",   can_speculate()},
             {"is_processing", is_processing()},
+            {"non_causal",    is_non_causal()},
             {"params",        params.to_json()},
             {"prompt",        common_detokenize(ctx, prompt_tokens)},
             {"next_token",
@@ -988,9 +1173,7 @@ struct server_queue {
     // Add a new task to the end of the queue
     int post(server_task task, bool front = false) {
         std::unique_lock<std::mutex> lock(mutex_tasks);
-        if (task.id == -1) {
-            task.id = id++;
-        }
+        GGML_ASSERT(task.id != -1);
         QUE_DBG("new task, id = %d, front = %d\n", task.id, front);
         if (front) {
             queue_tasks.push_front(std::move(task));
@@ -1468,104 +1651,14 @@ struct server_context {
     }
 
     bool launch_slot_with_task(server_slot & slot, const server_task & task) {
-        // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
-        slot_params defaults;
-        defaults.sampling    = params_base.sampling;
-        defaults.speculative = params_base.speculative;
+        slot.reset();
+        slot.id_task       = task.id;
+        slot.index         = task.index;
+        slot.task_type     = task.type;
+        slot.params        = std::move(task.params);
+        slot.prompt_tokens = std::move(task.prompt_tokens);
 
-        const auto & data = task.data;
-
-        if (data.count("__oaicompat") != 0) {
-            std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias;
-            slot.params.oaicompat         = true;
-            slot.params.oaicompat_chat    = json_value(data, "__oaicompat_chat", false);
-            slot.params.oaicompat_model   = json_value(data, "model", model_name);
-            slot.params.oaicompat_cmpl_id = json_value(data, "completion_id", std::string());
-        } else {
-            slot.params.oaicompat         = false;
-        }
-
-
-        // enabling this will output extra debug information in the HTTP responses from the server
-        slot.params.verbose           = params_base.verbosity > 9;
-        slot.params.timings_per_token = json_value(data, "timings_per_token", false);
-
-        slot.params.stream           = json_value(data, "stream",             false);
-        slot.params.cache_prompt     = json_value(data, "cache_prompt",       true);
-        slot.params.n_predict        = json_value(data, "n_predict",          json_value(data, "max_tokens", defaults.n_predict));
-        slot.params.n_indent         = json_value(data, "n_indent",           defaults.n_indent);
-        slot.params.n_keep           = json_value(data, "n_keep",             defaults.n_keep);
-        slot.params.n_discard        = json_value(data, "n_discard",          defaults.n_discard);
-      //slot.params.t_max_prompt_ms  = json_value(data, "t_max_prompt_ms",    defaults.t_max_prompt_ms); // TODO: implement
-        slot.params.t_max_predict_ms = json_value(data, "t_max_predict_ms",   defaults.t_max_predict_ms);
-
-        slot.params.sampling.top_k              = json_value(data, "top_k",              defaults.sampling.top_k);
-        slot.params.sampling.top_p              = json_value(data, "top_p",              defaults.sampling.top_p);
-        slot.params.sampling.min_p              = json_value(data, "min_p",              defaults.sampling.min_p);
-        slot.params.sampling.xtc_probability    = json_value(data, "xtc_probability",    defaults.sampling.xtc_probability);
-        slot.params.sampling.xtc_threshold      = json_value(data, "xtc_threshold",      defaults.sampling.xtc_threshold);
-        slot.params.sampling.typ_p              = json_value(data, "typical_p",          defaults.sampling.typ_p);
-        slot.params.sampling.temp               = json_value(data, "temperature",        defaults.sampling.temp);
-        slot.params.sampling.dynatemp_range     = json_value(data, "dynatemp_range",     defaults.sampling.dynatemp_range);
-        slot.params.sampling.dynatemp_exponent  = json_value(data, "dynatemp_exponent",  defaults.sampling.dynatemp_exponent);
-        slot.params.sampling.penalty_last_n     = json_value(data, "repeat_last_n",      defaults.sampling.penalty_last_n);
-        slot.params.sampling.penalty_repeat     = json_value(data, "repeat_penalty",     defaults.sampling.penalty_repeat);
-        slot.params.sampling.penalty_freq       = json_value(data, "frequency_penalty",  defaults.sampling.penalty_freq);
-        slot.params.sampling.penalty_present    = json_value(data, "presence_penalty",   defaults.sampling.penalty_present);
-        slot.params.sampling.dry_multiplier     = json_value(data, "dry_multiplier",     defaults.sampling.dry_multiplier);
-        slot.params.sampling.dry_base           = json_value(data, "dry_base",           defaults.sampling.dry_base);
-        slot.params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length);
-        slot.params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n);
-        slot.params.sampling.mirostat           = json_value(data, "mirostat",           defaults.sampling.mirostat);
-        slot.params.sampling.mirostat_tau       = json_value(data, "mirostat_tau",       defaults.sampling.mirostat_tau);
-        slot.params.sampling.mirostat_eta       = json_value(data, "mirostat_eta",       defaults.sampling.mirostat_eta);
-        slot.params.sampling.penalize_nl        = json_value(data, "penalize_nl",        defaults.sampling.penalize_nl);
-        slot.params.sampling.seed               = json_value(data, "seed",               defaults.sampling.seed);
-        slot.params.sampling.n_probs            = json_value(data, "n_probs",            defaults.sampling.n_probs);
-        slot.params.sampling.min_keep           = json_value(data, "min_keep",           defaults.sampling.min_keep);
-
-        slot.params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
-        slot.params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
-        slot.params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
-
-        slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min);
-        slot.params.speculative.n_min = std::max(slot.params.speculative.n_min, 2);
-        slot.params.speculative.n_max = std::max(slot.params.speculative.n_max, 0);
-
-        if (slot.params.sampling.dry_base < 1.0f) {
-           slot.params.sampling.dry_base = defaults.sampling.dry_base;
-        }
-
-        // sequence breakers for DRY
-        {
-            // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format
-            // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
-
-            if (data.contains("dry_sequence_breakers")) {
-                slot.params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
-                if (slot.params.sampling.dry_sequence_breakers.empty()) {
-                    send_error(task, "Error: dry_sequence_breakers must be a non-empty array of strings", ERROR_TYPE_INVALID_REQUEST);
-                    return false;
-                }
-            }
-        }
-
-        // process "json_schema" and "grammar"
-        if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
-            send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST);
-            return false;
-        }
-        if (data.contains("json_schema") && !data.contains("grammar")) {
-            try {
-                auto schema                  = json_value(data, "json_schema", json::object());
-                slot.params.sampling.grammar = json_schema_to_grammar(schema);
-            } catch (const std::exception & e) {
-                send_error(task, std::string("\"json_schema\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
-                return false;
-            }
-        } else {
-            slot.params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
-        }
+        SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
 
         if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
             // Might be better to reject the request with a 400 ?
@@ -1573,78 +1666,8 @@ struct server_context {
             SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.n_predict, slot.n_predict);
         }
 
-        {
-            slot.params.sampling.logit_bias.clear();
-
-            if (json_value(data, "ignore_eos", false) && has_eos_token) {
-                slot.params.sampling.logit_bias.push_back({llama_token_eos(model), -INFINITY});
-            }
-
-            const auto & logit_bias = data.find("logit_bias");
-            if (logit_bias != data.end() && logit_bias->is_array()) {
-                const int n_vocab = llama_n_vocab(model);
-                for (const auto & el : *logit_bias) {
-                    // TODO: we may want to throw errors here, in case "el" is incorrect
-                    if (el.is_array() && el.size() == 2) {
-                        float bias;
-                        if (el[1].is_number()) {
-                            bias = el[1].get<float>();
-                        } else if (el[1].is_boolean() && !el[1].get<bool>()) {
-                            bias = -INFINITY;
-                        } else {
-                            continue;
-                        }
-
-                        if (el[0].is_number_integer()) {
-                            llama_token tok = el[0].get<llama_token>();
-                            if (tok >= 0 && tok < n_vocab) {
-                                slot.params.sampling.logit_bias.push_back({tok, bias});
-                            }
-                        } else if (el[0].is_string()) {
-                            auto toks = common_tokenize(model, el[0].get<std::string>(), false);
-                            for (auto tok : toks) {
-                                slot.params.sampling.logit_bias.push_back({tok, bias});
-                            }
-                        }
-                    }
-                }
-            }
-        }
-
-        {
-            slot.params.antiprompt.clear();
-
-            const auto & stop = data.find("stop");
-            if (stop != data.end() && stop->is_array()) {
-                for (const auto & word : *stop) {
-                    if (!word.empty()) {
-                        slot.params.antiprompt.push_back(word);
-                    }
-                }
-            }
-        }
-
-        {
-            const auto & samplers = data.find("samplers");
-            if (samplers != data.end()) {
-                if (samplers->is_array()) {
-                    std::vector<std::string> sampler_names;
-                    for (const auto & name : *samplers) {
-                        if (name.is_string()) {
-                            sampler_names.emplace_back(name);
-                        }
-                    }
-                    slot.params.sampling.samplers = common_sampler_types_from_names(sampler_names, false);
-                } else if (samplers->is_string()){
-                    std::string sampler_string;
-                    for (const auto & name : *samplers) {
-                        sampler_string += name;
-                    }
-                    slot.params.sampling.samplers = common_sampler_types_from_chars(sampler_string);
-                }
-            } else {
-                slot.params.sampling.samplers = defaults.sampling.samplers;
-            }
+        if (slot.params.ignore_eos && has_eos_token) {
+            slot.params.sampling.logit_bias.push_back({llama_token_eos(model), -INFINITY});
         }
 
         {
@@ -2020,82 +2043,13 @@ struct server_context {
     // Functions to create new task(s) and receive result(s)
     //
 
-    // break the input "prompt" into multiple tasks if needed, then format and tokenize the input prompt(s)
-    std::vector<server_task> create_tasks_inference(json data, server_task_inf_type inf_type) {
-        std::vector<server_task> tasks;
-        auto create_task = [&](json & task_data, llama_tokens & prompt_tokens) {
-            SRV_DBG("create task, n_tokens = %d\n", (int) prompt_tokens.size());
-
-            server_task task;
-            task.id            = queue_tasks.get_new_id();
-            task.inf_type      = inf_type;
-            task.type          = SERVER_TASK_TYPE_INFERENCE;
-            task.data          = task_data;
-            task.prompt_tokens = std::move(prompt_tokens);
-            tasks.push_back(std::move(task));
-        };
-
-        static constexpr const char * error_msg = "\"prompt\" must be a string, an array of token ids or an array of prompts";
-        if (!data.contains("prompt")) {
-            throw std::runtime_error(error_msg);
-        }
-
-        // because llama_tokenize api is thread-safe, we can tokenize the prompt from HTTP thread
-        bool add_special = inf_type != SERVER_TASK_INF_TYPE_RERANK && inf_type != SERVER_TASK_INF_TYPE_INFILL;
-        std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx, data.at("prompt"), add_special, true);
-        switch (inf_type) {
-            case SERVER_TASK_INF_TYPE_RERANK:
-                {
-                    // prompts[0] is the question
-                    // the rest are the answers/documents
-                    GGML_ASSERT(tokenized_prompts.size() > 1);
-                    SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) tokenized_prompts.size() - 1);
-                    for (size_t i = 1; i < tokenized_prompts.size(); i++) {
-                        data["index"] = i - 1;
-                        auto tokens = format_rerank(model, tokenized_prompts[0], tokenized_prompts[i]);
-                        create_task(data, tokens);
-                    }
-                } break;
-            case SERVER_TASK_INF_TYPE_INFILL:
-                {
-                    SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
-                    for (size_t i = 0; i < tokenized_prompts.size(); i++) {
-                        data["index"] = i;
-                        auto tokens = format_infill(
-                            ctx,
-                            data.at("input_prefix"),
-                            data.at("input_suffix"),
-                            data.at("input_extra"),
-                            params_base.n_batch,
-                            params_base.n_predict,
-                            slots[0].n_ctx, // TODO: there should be a better way
-                            params_base.spm_infill,
-                            tokenized_prompts[i]
-                        );
-                        create_task(data, tokens);
-                    }
-                } break;
-            default:
-                {
-                    SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
-                    for (size_t i = 0; i < tokenized_prompts.size(); i++) {
-                        data["index"] = i;
-                        create_task(data, tokenized_prompts[i]);
-                    }
-                }
-        }
-
-        return tasks;
-    }
-
     void cancel_tasks(const std::unordered_set<int> & id_tasks) {
         std::vector<server_task> cancel_tasks;
         cancel_tasks.reserve(id_tasks.size());
         for (const auto & id_task : id_tasks) {
             SRV_WRN("cancel task, id_task = %d\n", id_task);
 
-            server_task task;
-            task.type      = SERVER_TASK_TYPE_CANCEL;
+            server_task task(SERVER_TASK_TYPE_CANCEL);
             task.id_target = id_task;
             cancel_tasks.push_back(task);
             queue_results.remove_waiting_task_id(id_task);
@@ -2104,7 +2058,7 @@ struct server_context {
         queue_tasks.post(cancel_tasks, true);
     }
 
-    // receive the results from task(s) created by create_tasks_inference
+    // receive the results from task(s)
     void receive_multi_results(
             const std::unordered_set<int> & id_tasks,
             const std::function<void(std::vector<server_task_result_ptr>&)> & result_handler,
@@ -2131,7 +2085,7 @@ struct server_context {
         result_handler(results);
     }
 
-    // receive the results from task(s) created by create_tasks_inference, in stream mode
+    // receive the results from task(s), in stream mode
     void receive_cmpl_results_stream(
             const std::unordered_set<int> & id_tasks,
             const std::function<bool(server_task_result_ptr&)> & result_handler,
@@ -2166,9 +2120,12 @@ struct server_context {
 
     void process_single_task(server_task task) {
         switch (task.type) {
-            case SERVER_TASK_TYPE_INFERENCE:
+            case SERVER_TASK_TYPE_COMPLETION:
+            case SERVER_TASK_TYPE_INFILL:
+            case SERVER_TASK_TYPE_EMBEDDING:
+            case SERVER_TASK_TYPE_RERANK:
                 {
-                    const int id_slot = json_value(task.data, "id_slot", -1);
+                    const int id_slot = task.id_selected_slot;
 
                     server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task);
 
@@ -2185,13 +2142,6 @@ struct server_context {
                         break;
                     }
 
-                    slot->reset();
-
-                    slot->id_task       = task.id;
-                    slot->inf_type      = task.inf_type;
-                    slot->index         = json_value(task.data, "index", 0);
-                    slot->prompt_tokens = std::move(task.prompt_tokens);
-
                     if (!launch_slot_with_task(*slot, task)) {
                         SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
                         break;
@@ -2255,14 +2205,14 @@ struct server_context {
                     res->n_decode_total          = metrics.n_decode_total;
                     res->n_busy_slots_total      = metrics.n_busy_slots_total;
 
-                    if (json_value(task.data, "reset_bucket", false)) {
+                    if (task.metrics_reset_bucket) {
                         metrics.reset_bucket();
                     }
                     queue_results.send(std::move(res));
                 } break;
             case SERVER_TASK_TYPE_SLOT_SAVE:
                 {
-                    int id_slot = task.data.at("id_slot");
+                    int id_slot = task.slot_action.slot_id;
                     server_slot * slot = get_slot_by_id(id_slot);
                     if (slot == nullptr) {
                         send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
@@ -2278,8 +2228,8 @@ struct server_context {
                     const size_t token_count = slot->cache_tokens.size();
                     const int64_t t_start = ggml_time_us();
 
-                    std::string filename = task.data.at("filename");
-                    std::string filepath = task.data.at("filepath");
+                    std::string filename = task.slot_action.filename;
+                    std::string filepath = task.slot_action.filepath;
 
                     const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count);
 
@@ -2298,7 +2248,7 @@ struct server_context {
                 } break;
             case SERVER_TASK_TYPE_SLOT_RESTORE:
                 {
-                    int id_slot = task.data.at("id_slot");
+                    int id_slot = task.slot_action.slot_id;
                     server_slot * slot = get_slot_by_id(id_slot);
                     if (slot == nullptr) {
                         send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
@@ -2313,8 +2263,8 @@ struct server_context {
 
                     const int64_t t_start = ggml_time_us();
 
-                    std::string filename = task.data.at("filename");
-                    std::string filepath = task.data.at("filepath");
+                    std::string filename = task.slot_action.filename;
+                    std::string filepath = task.slot_action.filepath;
 
                     slot->cache_tokens.resize(slot->n_ctx);
                     size_t token_count = 0;
@@ -2341,7 +2291,7 @@ struct server_context {
                 } break;
             case SERVER_TASK_TYPE_SLOT_ERASE:
                 {
-                    int id_slot = task.data.at("id_slot");
+                    int id_slot = task.slot_action.slot_id;
                     server_slot * slot = get_slot_by_id(id_slot);
                     if (slot == nullptr) {
                         send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
@@ -2400,10 +2350,8 @@ struct server_context {
         {
             SRV_DBG("%s", "posting NEXT_RESPONSE\n");
 
-            server_task task;
-            task.type      = SERVER_TASK_TYPE_NEXT_RESPONSE;
-            task.id_target = -1;
-
+            server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE);
+            task.id = queue_tasks.get_new_id();
             queue_tasks.post(task);
         }
 
@@ -2517,7 +2465,7 @@ struct server_context {
                             continue;
                         }
 
-                        if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
+                        if (slot.is_non_causal()) {
                             if (slot.n_prompt_tokens > n_ubatch) {
                                 slot.release();
                                 send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
@@ -2632,7 +2580,7 @@ struct server_context {
                     }
 
                     // non-causal tasks require to fit the entire prompt in the physical batch
-                    if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
+                    if (slot.is_non_causal()) {
                         // cannot fit the prompt in the current batch - will try next iter
                         if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
                             continue;
@@ -2640,10 +2588,7 @@ struct server_context {
                     }
 
                     // check that we are in the right batch_type, if not defer the slot
-                    const bool slot_type =
-                        slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING ||
-                        slot.inf_type == SERVER_TASK_INF_TYPE_RERANK     ? 1 : 0;
-
+                    int slot_type = slot.is_non_causal();
                     if (batch_type == -1) {
                         batch_type = slot_type;
                     } else if (batch_type != slot_type) {
@@ -2760,7 +2705,7 @@ struct server_context {
                 }
 
                 if (slot.state == SLOT_STATE_DONE_PROMPT) {
-                    if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING) {
+                    if (slot.task_type == SERVER_TASK_TYPE_EMBEDDING) {
                         // prompt evaluated for embedding
                         send_embedding(slot, batch_view);
                         slot.release();
@@ -2768,7 +2713,7 @@ struct server_context {
                         continue; // continue loop of slots
                     }
 
-                    if (slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
+                    if (slot.task_type == SERVER_TASK_TYPE_RERANK) {
                         send_rerank(slot, batch_view);
                         slot.release();
                         slot.i_batch = -1;
@@ -2998,12 +2943,12 @@ int main(int argc, char ** argv) {
 
     auto res_error = [](httplib::Response & res, const json & error_data) {
         json final_response {{"error", error_data}};
-        res.set_content(final_response.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
+        res.set_content(safe_json_to_str(final_response), MIMETYPE_JSON);
         res.status = json_value(error_data, "code", 500);
     };
 
     auto res_ok = [](httplib::Response & res, const json & data) {
-        res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
+        res.set_content(safe_json_to_str(data), MIMETYPE_JSON);
         res.status = 200;
     };
 
@@ -3140,10 +3085,8 @@ int main(int argc, char ** argv) {
         }
 
         // request slots data using task queue
-        server_task task;
+        server_task task(SERVER_TASK_TYPE_METRICS);
         task.id = ctx_server.queue_tasks.get_new_id();
-        task.type = SERVER_TASK_TYPE_METRICS;
-
         ctx_server.queue_results.add_waiting_task_id(task.id);
         ctx_server.queue_tasks.post(task, true); // high-priority task
 
@@ -3178,11 +3121,9 @@ int main(int argc, char ** argv) {
         }
 
         // request slots data using task queue
-        server_task task;
+        server_task task(SERVER_TASK_TYPE_METRICS);
         task.id = ctx_server.queue_tasks.get_new_id();
-        task.id_target = -1;
-        task.type = SERVER_TASK_TYPE_METRICS;
-        task.data.push_back({{"reset_bucket", true}});
+        task.metrics_reset_bucket = true;
 
         ctx_server.queue_results.add_waiting_task_id(task.id);
         ctx_server.queue_tasks.post(task, true); // high-priority task
@@ -3286,19 +3227,17 @@ int main(int argc, char ** argv) {
         }
         std::string filepath = params.slot_save_path + filename;
 
-        server_task task;
-        task.type = SERVER_TASK_TYPE_SLOT_SAVE;
-        task.data = {
-            { "id_slot", id_slot },
-            { "filename", filename },
-            { "filepath", filepath },
-        };
+        server_task task(SERVER_TASK_TYPE_SLOT_SAVE);
+        task.id = ctx_server.queue_tasks.get_new_id();
+        task.slot_action.slot_id  = id_slot;
+        task.slot_action.filename = filename;
+        task.slot_action.filepath = filepath;
 
-        const int id_task = ctx_server.queue_tasks.post(task);
-        ctx_server.queue_results.add_waiting_task_id(id_task);
+        ctx_server.queue_results.add_waiting_task_id(task.id);
+        ctx_server.queue_tasks.post(task);
 
-        server_task_result_ptr result = ctx_server.queue_results.recv(id_task);
-        ctx_server.queue_results.remove_waiting_task_id(id_task);
+        server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
+        ctx_server.queue_results.remove_waiting_task_id(task.id);
 
         if (result->is_error()) {
             res_error(res, result->to_json());
@@ -3317,19 +3256,17 @@ int main(int argc, char ** argv) {
         }
         std::string filepath = params.slot_save_path + filename;
 
-        server_task task;
-        task.type = SERVER_TASK_TYPE_SLOT_RESTORE;
-        task.data = {
-            { "id_slot", id_slot },
-            { "filename", filename },
-            { "filepath", filepath },
-        };
+        server_task task(SERVER_TASK_TYPE_SLOT_RESTORE);
+        task.id = ctx_server.queue_tasks.get_new_id();
+        task.slot_action.slot_id  = id_slot;
+        task.slot_action.filename = filename;
+        task.slot_action.filepath = filepath;
 
-        const int id_task = ctx_server.queue_tasks.post(task);
-        ctx_server.queue_results.add_waiting_task_id(id_task);
+        ctx_server.queue_results.add_waiting_task_id(task.id);
+        ctx_server.queue_tasks.post(task);
 
-        server_task_result_ptr result = ctx_server.queue_results.recv(id_task);
-        ctx_server.queue_results.remove_waiting_task_id(id_task);
+        server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
+        ctx_server.queue_results.remove_waiting_task_id(task.id);
 
         if (result->is_error()) {
             res_error(res, result->to_json());
@@ -3341,17 +3278,15 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
-        server_task task;
-        task.type = SERVER_TASK_TYPE_SLOT_ERASE;
-        task.data = {
-            { "id_slot", id_slot },
-        };
+        server_task task(SERVER_TASK_TYPE_SLOT_ERASE);
+        task.id = ctx_server.queue_tasks.get_new_id();
+        task.slot_action.slot_id = id_slot;
 
-        const int id_task = ctx_server.queue_tasks.post(task);
-        ctx_server.queue_results.add_waiting_task_id(id_task);
+        ctx_server.queue_results.add_waiting_task_id(task.id);
+        ctx_server.queue_tasks.post(task);
 
-        server_task_result_ptr result = ctx_server.queue_results.recv(id_task);
-        ctx_server.queue_results.remove_waiting_task_id(id_task);
+        server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
+        ctx_server.queue_results.remove_waiting_task_id(task.id);
 
         if (result->is_error()) {
             res_error(res, result->to_json());
@@ -3392,9 +3327,11 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
+        // this endpoint is publicly available, please only return what is safe to be exposed
         json data = {
             { "default_generation_settings", ctx_server.default_generation_settings_for_props },
             { "total_slots",                 ctx_server.params_base.n_parallel },
+            { "model_path",                  ctx_server.params_base.model },
             { "chat_template",               llama_get_chat_template(ctx_server.model) },
         };
 
@@ -3417,17 +3354,47 @@ int main(int argc, char ** argv) {
     // handle completion-like requests (completion, chat, infill)
     // we can optionally provide a custom format for partial results and final results
     const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](
-            server_task_inf_type inf_type,
+            server_task_type type,
             json & data,
             httplib::Response & res,
-            bool oai_compat = false) {
+            bool oaicompat = false,
+            bool oaicompat_chat = false) {
+        GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
+
         if (ctx_server.params_base.embedding) {
             res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
 
-        data["completion_id"] = gen_chatcmplid();
-        std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, inf_type);
+        auto completion_id = gen_chatcmplid();
+        std::vector<server_task> tasks;
+
+        try {
+            std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, data.at("prompt"), true, true);
+            tasks.reserve(tokenized_prompts.size());
+            for (size_t i = 0; i < tokenized_prompts.size(); i++) {
+                server_task task = server_task(type);
+
+                task.id    = ctx_server.queue_tasks.get_new_id();
+                task.index = i;
+
+                task.prompt_tokens    = std::move(tokenized_prompts[i]);
+                task.params           = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.params_base, data);
+                task.id_selected_slot = json_value(data, "id_slot", -1);
+
+                // OAI-compat
+                task.params.oaicompat           = oaicompat;
+                task.params.oaicompat_chat      = oaicompat_chat;
+                task.params.oaicompat_cmpl_id   = completion_id;
+                // oaicompat_model is already populated by params_from_json_cmpl
+
+                tasks.push_back(task);
+            }
+        } catch (const std::exception & e) {
+            res_error(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
+            return;
+        }
+
         ctx_server.queue_results.add_waiting_tasks(tasks);
         ctx_server.queue_tasks.post(tasks);
 
@@ -3453,7 +3420,7 @@ int main(int argc, char ** argv) {
 
             ctx_server.queue_results.remove_waiting_task_ids(task_ids);
         } else {
-            const auto chunked_content_provider = [task_ids, &ctx_server, oai_compat](size_t, httplib::DataSink & sink) {
+            const auto chunked_content_provider = [task_ids, &ctx_server, oaicompat](size_t, httplib::DataSink & sink) {
                 ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool {
                     json res_json = result->to_json();
                     if (res_json.is_array()) {
@@ -3469,7 +3436,7 @@ int main(int argc, char ** argv) {
                 }, [&](const json & error_data) {
                     server_sent_event(sink, "error", error_data);
                 });
-                if (oai_compat) {
+                if (oaicompat) {
                     static const std::string ev_done = "data: [DONE]\n\n";
                     sink.write(ev_done.data(), ev_done.size());
                 }
@@ -3487,7 +3454,12 @@ int main(int argc, char ** argv) {
 
     const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
         json data = json::parse(req.body);
-        return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res);
+        return handle_completions_generic(
+            SERVER_TASK_TYPE_COMPLETION,
+            data,
+            res,
+            /* oaicompat */ false,
+            /* oaicompat_chat */ false);
     };
 
     const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
@@ -3537,7 +3509,7 @@ int main(int argc, char ** argv) {
         }
         data["input_extra"] = input_extra; // default to empty array if it's not exist
 
-        return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res);
+        return handle_completions_generic(SERVER_TASK_TYPE_INFILL, data, res);
     };
 
     const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
@@ -3547,11 +3519,15 @@ int main(int argc, char ** argv) {
         }
 
         json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
-        data["__oaicompat_chat"] = true;
-        return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res, true);
+        return handle_completions_generic(
+            SERVER_TASK_TYPE_COMPLETION,
+            data,
+            res,
+            /* oaicompat */ true,
+            /* oaicompat_chat */ true);
     };
 
-    const auto handle_models = [&params, &ctx_server](const httplib::Request &, httplib::Response & res) {
+    const auto handle_models = [&params, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
         json models = {
             {"object", "list"},
             {"data", {
@@ -3565,7 +3541,7 @@ int main(int argc, char ** argv) {
              }}
         };
 
-        res.set_content(models.dump(), MIMETYPE_JSON);
+        res_ok(res, models);
     };
 
     const auto handle_tokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
@@ -3642,7 +3618,16 @@ int main(int argc, char ** argv) {
         json responses = json::array();
         bool error = false;
         {
-            std::vector<server_task> tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_EMBEDDING);
+            std::vector<server_task> tasks;
+            std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, /* add_special */ false, true);
+            for (size_t i = 0; i < tokenized_prompts.size(); i++) {
+                server_task task   = server_task(SERVER_TASK_TYPE_EMBEDDING);
+                task.id            = ctx_server.queue_tasks.get_new_id();
+                task.index         = i;
+                task.prompt_tokens = std::move(tokenized_prompts[i]);
+                tasks.push_back(task);
+            }
+
             ctx_server.queue_results.add_waiting_tasks(tasks);
             ctx_server.queue_tasks.post(tasks);
 
@@ -3669,7 +3654,7 @@ int main(int argc, char ** argv) {
         // write JSON response
         json root = oaicompat
             ? format_embeddings_response_oaicompat(body, responses)
-            : responses[0];
+            : responses.size() == 1 ? responses[0] : json(responses);
         res_ok(res, root);
     };
 
@@ -3708,20 +3693,23 @@ int main(int argc, char ** argv) {
             return;
         }
 
-        // construct prompt object: array of ["query", "doc0", "doc1", ...]
-        json prompt;
-        prompt.push_back(query);
-        for (const auto & doc : documents) {
-            prompt.push_back(doc);
-        }
-
-        LOG_DBG("rerank prompt: %s\n", prompt.dump().c_str());
+        llama_tokens tokenized_query = tokenize_input_prompts(ctx_server.ctx, query, /* add_special */ false, true)[0];
 
         // create and queue the task
         json responses = json::array();
         bool error = false;
         {
-            std::vector<server_task> tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_RERANK);
+            std::vector<server_task> tasks;
+            std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts(ctx_server.ctx, documents, /* add_special */ false, true);
+            tasks.reserve(tokenized_docs.size());
+            for (size_t i = 0; i < tokenized_docs.size(); i++) {
+                server_task task   = server_task(SERVER_TASK_TYPE_RERANK);
+                task.id            = ctx_server.queue_tasks.get_new_id();
+                task.index         = i;
+                task.prompt_tokens = format_rerank(ctx_server.model, tokenized_query, tokenized_docs[i]);
+                tasks.push_back(task);
+            }
+
             ctx_server.queue_results.add_waiting_tasks(tasks);
             ctx_server.queue_tasks.post(tasks);
 
@@ -3782,13 +3770,13 @@ int main(int argc, char ** argv) {
             }
         }
 
-        server_task task;
-        task.type = SERVER_TASK_TYPE_SET_LORA;
-        const int id_task = ctx_server.queue_tasks.post(task);
-        ctx_server.queue_results.add_waiting_task_id(id_task);
+        server_task task(SERVER_TASK_TYPE_SET_LORA);
+        task.id = ctx_server.queue_tasks.get_new_id();
+        ctx_server.queue_results.add_waiting_task_id(task.id);
+        ctx_server.queue_tasks.post(task);
 
-        server_task_result_ptr result = ctx_server.queue_results.recv(id_task);
-        ctx_server.queue_results.remove_waiting_task_id(id_task);
+        server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
+        ctx_server.queue_results.remove_waiting_task_id(task.id);
 
         if (result->is_error()) {
             res_error(res, result->to_json());
diff --git a/examples/server/tests/unit/test_basic.py b/examples/server/tests/unit/test_basic.py
index d82d54a5a6f47efe00c08f497c66b6ccbe7ffcf3..1d512401685fb37acac73117d29acad7b4812c60 100644
@@ -22,7 +22,12 @@ def test_server_props():
     server.start()
     res = server.make_request("GET", "/props")
     assert res.status_code == 200
+    assert ".gguf" in res.body["model_path"]
     assert res.body["total_slots"] == server.n_slots
+    default_val = res.body["default_generation_settings"]
+    assert server.n_ctx is not None and server.n_slots is not None
+    assert default_val["n_ctx"] == server.n_ctx / server.n_slots
+    assert default_val["params"]["seed"] == server.seed
 
 
 def test_server_models():
@@ -33,6 +38,31 @@ def test_server_models():
     assert len(res.body["data"]) == 1
     assert res.body["data"][0]["id"] == server.model_alias
 
+
+def test_server_slots():
+    global server
+
+    # without slots endpoint enabled, this should return error
+    server.server_slots = False
+    server.start()
+    res = server.make_request("GET", "/slots")
+    assert res.status_code == 501 # ERROR_TYPE_NOT_SUPPORTED
+    assert "error" in res.body
+    server.stop()
+
+    # with slots endpoint enabled, this should return slots info
+    server.server_slots = True
+    server.n_slots = 2
+    server.start()
+    res = server.make_request("GET", "/slots")
+    assert res.status_code == 200
+    assert len(res.body) == server.n_slots
+    assert server.n_ctx is not None and server.n_slots is not None
+    assert res.body[0]["n_ctx"] == server.n_ctx / server.n_slots
+    assert "params" in res.body[0]
+    assert res.body[0]["params"]["seed"] == server.seed
+
+
 def test_load_split_model():
     global server
     server.model_hf_repo = "ggml-org/models"
diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py
index f13c6c4ca4bd37d24ba78e87b590877f4d76c585..6573cc17f7b87843bdc389ee2e0f95065fcf1e43 100644
@@ -30,6 +30,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
         ],
     })
     assert res.status_code == 200
+    assert "cmpl" in res.body["id"] # make sure the completion id has the expected format
     assert res.body["model"] == model if model is not None else server.model_alias
     assert res.body["usage"]["prompt_tokens"] == n_prompt
     assert res.body["usage"]["completion_tokens"] == n_predicted
@@ -59,9 +60,13 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
         "stream": True,
     })
     content = ""
+    last_cmpl_id = None
     for data in res:
         choice = data["choices"][0]
         assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
+        if last_cmpl_id is None:
+            last_cmpl_id = data["id"]
+        assert last_cmpl_id == data["id"] # make sure the completion id is the same for all events in the stream
         if choice["finish_reason"] in ["stop", "length"]:
             assert data["usage"]["prompt_tokens"] == n_prompt
             assert data["usage"]["completion_tokens"] == n_predicted
diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py
index e17a05ff6902a271778c4844689675030b1e04e5..69215eaa4ebb70eb087c04191ff97a0ee1ae0f35 100644
@@ -64,6 +64,7 @@ class ServerProcess:
     server_embeddings: bool | None = False
     server_reranking: bool | None = False
     server_metrics: bool | None = False
+    server_slots: bool | None = False
     draft: int | None = None
     api_key: str | None = None
     response_format: str | None = None
@@ -91,7 +92,6 @@ class ServerProcess:
         else:
             server_path = "../../../build/bin/llama-server"
         server_args = [
-            "--slots",  # requires to get slot status via /slots endpoint
             "--host",
             self.server_host,
             "--port",
@@ -129,6 +129,8 @@ class ServerProcess:
             server_args.append("--reranking")
         if self.server_metrics:
             server_args.append("--metrics")
+        if self.server_slots:
+            server_args.append("--slots")
         if self.model_alias:
             server_args.extend(["--alias", self.model_alias])
         if self.n_ctx:
@@ -181,7 +183,7 @@ class ServerProcess:
         start_time = time.time()
         while time.time() - start_time < timeout_seconds:
             try:
-                response = self.make_request("GET", "/slots", headers={
+                response = self.make_request("GET", "/health", headers={
                     "Authorization": f"Bearer {self.api_key}" if self.api_key else None
                 })
                 if response.status_code == 200:
@@ -224,7 +226,7 @@ class ServerProcess:
         result.headers = dict(response.headers)
         result.status_code = response.status_code
         result.body = response.json() if parse_body else None
-        print("Response from server", result.body)
+        print("Response from server", json.dumps(result.body, indent=2))
         return result
 
     def make_stream_request(
@@ -245,7 +247,7 @@ class ServerProcess:
                 break
             elif line.startswith('data: '):
                 data = json.loads(line[6:])
-                print("Partial response from server", data)
+                print("Partial response from server", json.dumps(data, indent=2))
                 yield data
 
 
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index c9fe7d966b2095e320e2d864aa6659a907a30c96..8f545aea52dc495dae2ceb02bc533feda360eac0 100644
@@ -164,6 +164,9 @@ static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, con
     } else {
         throw std::runtime_error("\"prompt\" must be a string, an list of tokens, a list of mixed strings & tokens, or a list of prompts");
     }
+    if (result.empty()) {
+        throw std::runtime_error("\"prompt\" must not be empty");
+    }
     return result;
 }
 
@@ -496,8 +499,6 @@ static json oaicompat_completion_params_parse(
     const std::string & chat_template) {
     json llama_params;
 
-    llama_params["__oaicompat"] = true;
-
     // Apply chat template to the list of messages
     llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"));
 
@@ -648,3 +649,18 @@ static json format_detokenized_response(const std::string & content) {
         {"content", content}
     };
 }
+
+static json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias) {
+    json data = json::array();
+    for (const auto & lb : logit_bias) {
+        data.push_back(json{
+            {"bias", lb.bias},
+            {"token", lb.token},
+        });
+    }
+    return data;
+}
+
+static std::string safe_json_to_str(json data) {
+    return data.dump(-1, ' ', false, json::error_handler_t::replace);
+}
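
For context on the CI fix mentioned in the commit message: `dump()` with the default (strict) error handler throws `json::type_error` on invalid UTF-8, while `error_handler_t::replace` substitutes U+FFFD, so a response containing a partially generated multi-byte character cannot bring down the HTTP handler. A minimal, self-contained sketch of that behaviour (assuming nlohmann/json; the server aliases `json` to `nlohmann::ordered_json`, but plain `nlohmann::json` behaves the same here):

```cpp
#include <string>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

// same body as the helper added above
static std::string safe_json_to_str(json data) {
    return data.dump(-1, ' ', false, json::error_handler_t::replace);
}

int main() {
    // "\xC3" is the first byte of a two-byte UTF-8 sequence with no continuation byte,
    // e.g. a generated token that was cut off mid-character while streaming.
    json j = {{"content", std::string("\xC3")}};

    // j.dump() with the default error handler would throw json::type_error (316);
    // the replace handler emits U+FFFD for the invalid byte instead.
    std::string s = safe_json_to_str(j);

    return s.empty() ? 1 : 0;
}
```

This is why `res_error` and `res_ok` in server.cpp were switched from calling `dump()` directly to `safe_json_to_str()`.
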