]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : refactor slot input data, move tokenizer to HTTP thread (#10023)
authorXuan Son Nguyen <redacted>
Thu, 24 Oct 2024 19:51:22 +0000 (21:51 +0200)
committerGitHub <redacted>
Thu, 24 Oct 2024 19:51:22 +0000 (21:51 +0200)
* server : refactor slot input data, move tokenizer to HTTP thread

* move prompt_tokens.empty() check

* fix incorrect if branch

* fix infinite generation loop

* bring back infill validation

* add infill test

* try fixing format_infill

* fix test

* remove redundant code

* rename completion to inference

* update docs

* use llama_tokens everywhere

examples/server/README.md
examples/server/server.cpp
examples/server/tests/features/infill.feature [new file with mode: 0644]
examples/server/tests/features/steps/steps.py
examples/server/utils.hpp

index 09f1aa249ab1fc84d9786eb47a67ed298503e57c..8f00fcc79329305cfc068137c8a03cad3253c26d 100644 (file)
@@ -319,6 +319,18 @@ node index.js
       - The prompt is a string or an array with the first element given as a string
       - The model's `tokenizer.ggml.add_bos_token` metadata is `true`
 
+    These input shapes and data type are allowed for `prompt`:
+
+      - Single string: `"string"`
+      - Single sequence of tokens: `[12, 34, 56]`
+      - Mixed tokens and strings: `[12, 34, "string", 56, 78]`
+
+    Multiple prompts are also supported. In this case, the completion result will be an array.
+
+      - Only strings: `["string1", "string2"]`
+      - Strings and sequences of tokens: `["string1", [12, 34, 56]]`
+      - Mixed types: `[[12, 34, "string", 56, 78], [12, 34, 56], "string"]`
+
     `temperature`: Adjust the randomness of the generated text. Default: `0.8`
 
     `dynatemp_range`: Dynamic temperature range. The final temperature will be in the range of `[temperature - dynatemp_range; temperature + dynatemp_range]` Default: `0.0`, which is disabled.
index 51f30ffeab9808263e9c00c13651c309d1482770..58f93694f684655b2a811df21e601b68676fbd67 100644 (file)
 #include <unordered_map>
 #include <unordered_set>
 
-#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
-#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
-#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
-#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
-
-#define SRV_INF(fmt, ...) LOG_INF("srv  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-#define SRV_WRN(fmt, ...) LOG_WRN("srv  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-#define SRV_ERR(fmt, ...) LOG_ERR("srv  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-#define SRV_DBG(fmt, ...) LOG_DBG("srv  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-
-#define QUE_INF(fmt, ...) LOG_INF("que  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-#define QUE_WRN(fmt, ...) LOG_WRN("que  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-#define QUE_ERR(fmt, ...) LOG_ERR("que  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-#define QUE_DBG(fmt, ...) LOG_DBG("que  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-
 using json = nlohmann::ordered_json;
 
 enum stop_type {
@@ -68,6 +53,7 @@ enum stop_type {
 // state diagram: https://github.com/ggerganov/llama.cpp/pull/9283
 enum slot_state {
     SLOT_STATE_IDLE,
+    SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future
     SLOT_STATE_PROCESSING_PROMPT,
     SLOT_STATE_DONE_PROMPT,
     SLOT_STATE_GENERATING,
@@ -79,7 +65,7 @@ enum server_state {
 };
 
 enum server_task_type {
-    SERVER_TASK_TYPE_COMPLETION,
+    SERVER_TASK_TYPE_INFERENCE,
     SERVER_TASK_TYPE_CANCEL,
     SERVER_TASK_TYPE_NEXT_RESPONSE,
     SERVER_TASK_TYPE_METRICS,
@@ -89,21 +75,22 @@ enum server_task_type {
     SERVER_TASK_TYPE_SET_LORA,
 };
 
-enum server_task_cmpl_type {
-    SERVER_TASK_CMPL_TYPE_NORMAL,
-    SERVER_TASK_CMPL_TYPE_EMBEDDING,
-    SERVER_TASK_CMPL_TYPE_RERANK,
-    SERVER_TASK_CMPL_TYPE_INFILL,
+enum server_task_inf_type {
+    SERVER_TASK_INF_TYPE_COMPLETION,
+    SERVER_TASK_INF_TYPE_EMBEDDING,
+    SERVER_TASK_INF_TYPE_RERANK,
+    SERVER_TASK_INF_TYPE_INFILL,
 };
 
 struct server_task {
     int id        = -1; // to be filled by server_queue
     int id_target = -1; // used by SERVER_TASK_TYPE_CANCEL
 
+    llama_tokens prompt_tokens;
     server_task_type type;
     json data;
 
-    server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
+    server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
 
     // utility function
     static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
@@ -161,26 +148,20 @@ struct server_slot {
     int32_t i_batch     = -1;
     int32_t n_predict   = -1; // TODO: disambiguate from params.n_predict
 
+    // n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated
     int32_t n_prompt_tokens           = 0;
     int32_t n_prompt_tokens_processed = 0;
 
-    json prompt; // can be either a string, array of strings or array of token ids
-
-    json input_prefix;
-    json input_suffix;
-    json input_extra;
-
-    // when a task is submitted, we first tokenize the prompt and store it here
-    std::vector<llama_token> prompt_tokens;
-    std::vector<llama_token> extra_tokens;
+    // input prompt tokens
+    llama_tokens prompt_tokens;
 
     size_t last_nl_pos = 0;
 
     std::string generated_text;
-    std::vector<llama_token> cache_tokens;
+    llama_tokens cache_tokens;
     std::vector<completion_token_output> generated_token_probs;
 
-    server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
+    server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
 
     bool has_next_token = true;
     bool has_new_line   = false;
@@ -229,7 +210,7 @@ struct server_slot {
         n_past             = 0;
         n_sent_text        = 0;
         n_sent_token_probs = 0;
-        cmpl_type          = SERVER_TASK_CMPL_TYPE_NORMAL;
+        inf_type           = SERVER_TASK_INF_TYPE_COMPLETION;
 
         generated_token_probs.clear();
     }
@@ -734,42 +715,6 @@ struct server_context {
         metrics.init();
     }
 
-    std::vector<llama_token> tokenize(const json & json_prompt, bool add_special, bool parse_special) const {
-        // If `add_bos` is true, we only add BOS, when json_prompt is a string,
-        // or the first element of the json_prompt array is a string.
-        std::vector<llama_token> prompt_tokens;
-
-        if (json_prompt.is_array()) {
-            bool first = true;
-            for (const auto & p : json_prompt) {
-                if (p.is_string()) {
-                    auto s = p.template get<std::string>();
-
-                    std::vector<llama_token> p;
-                    if (first) {
-                        p = common_tokenize(ctx, s, add_special, parse_special);
-                        first = false;
-                    } else {
-                        p = common_tokenize(ctx, s, false, parse_special);
-                    }
-
-                    prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
-                } else {
-                    if (first) {
-                        first = false;
-                    }
-
-                    prompt_tokens.push_back(p.template get<llama_token>());
-                }
-            }
-        } else {
-            auto s = json_prompt.template get<std::string>();
-            prompt_tokens = common_tokenize(ctx, s, add_special, parse_special);
-        }
-
-        return prompt_tokens;
-    }
-
     server_slot * get_slot_by_id(int id) {
         for (server_slot & slot : slots) {
             if (slot.id == id) {
@@ -794,22 +739,16 @@ struct server_context {
                     continue;
                 }
 
-                // skip the slot if it does not contains prompt
-                if (!slot.prompt.is_string()) {
+                // skip the slot if it does not contains cached tokens
+                if (slot.prompt_tokens.empty()) {
                     continue;
                 }
 
-                // current slot's prompt
-                std::string slot_prompt = slot.prompt.get<std::string>();
-
-                // length of the current slot's prompt
-                int slot_prompt_len = slot_prompt.size();
-
                 // length of the Longest Common Prefix between the current slot's prompt and the input prompt
-                int lcp_len = longest_common_prefix(slot_prompt, prompt);
+                int lcp_len = longest_common_prefix(slot.cache_tokens, slot.prompt_tokens);
 
                 // fraction of the common substring length compared to the current slot's prompt length
-                similarity = static_cast<float>(lcp_len) / slot_prompt_len;
+                similarity = static_cast<float>(lcp_len) / static_cast<int>(slot.prompt_tokens.size());
 
                 // select the current slot if the criteria match
                 if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity) {
@@ -914,57 +853,6 @@ struct server_context {
             SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.n_predict, slot.n_predict);
         }
 
-        // infill
-        slot.input_prefix = json_value(data, "input_prefix", json());
-        slot.input_suffix = json_value(data, "input_suffix", json());
-        slot.input_extra  = json_value(data, "input_extra",  json());
-
-        SLT_DBG(slot, "extra_context chunks: %d\n", (int) slot.input_extra.size());
-        for (const auto & chunk : slot.input_extra) {
-            // { "text": string, "filename": string }
-            if (!chunk.contains("text") || !chunk["text"].is_string()) {
-                send_error(task, "extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST);
-                return false;
-            }
-
-            // filename is optional
-            if (chunk.contains("filename") && !chunk["filename"].is_string()) {
-                send_error(task, "extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST);
-                return false;
-            }
-
-            SLT_DBG(slot, "extra_context chunk in file '%s':\n%s\n", chunk.value("filename", "").c_str(), chunk.value("text", "").c_str());
-        }
-
-        // get prompt
-        {
-            const auto & prompt = data.find("prompt");
-            if (prompt == data.end()) {
-                send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);
-                return false;
-            }
-
-            if ((prompt->is_string()) ||
-                (prompt->is_array() &&  prompt->size() == 1 && prompt->at(0).is_string()) ||
-                (prompt->is_array() && !prompt->empty()     && prompt->at(0).is_number_integer())) {
-                slot.prompt = *prompt;
-            } else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
-                slot.prompt = prompt->at(0);
-            } else if (prompt->is_array() && prompt->size() > 1) {
-                // array of strings
-                for (const auto & el : *prompt) {
-                    if (!el.is_string()) {
-                        send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
-                        return false;
-                    }
-                }
-                slot.prompt = *prompt;
-            } else {
-                send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
-                return false;
-            }
-        }
-
         {
             slot.sparams.logit_bias.clear();
 
@@ -1044,8 +932,7 @@ struct server_context {
             }
         }
 
-        slot.state = SLOT_STATE_PROCESSING_PROMPT;
-        slot.prompt_tokens.clear();
+        slot.state = SLOT_STATE_STARTED;
 
         SLT_INF(slot, "%s", "processing task\n");
 
@@ -1297,7 +1184,7 @@ struct server_context {
         };
 
         if (slot.sparams.n_probs > 0) {
-            const std::vector<llama_token> to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
+            const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
             const size_t probs_pos      = std::min(slot.n_sent_token_probs,                       slot.generated_token_probs.size());
             const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());
 
@@ -1333,7 +1220,7 @@ struct server_context {
             {"tokens_predicted",    slot.n_decoded},
             {"tokens_evaluated",    slot.n_prompt_tokens},
             {"generation_settings", get_formated_generation(slot)},
-            {"prompt",              slot.prompt},
+            {"prompt",              common_detokenize(ctx, slot.prompt_tokens)},
             {"has_new_line",        slot.has_new_line},
             {"truncated",           slot.truncated},
             {"stopped_eos",         slot.stopped_eos},
@@ -1348,7 +1235,7 @@ struct server_context {
         if (slot.sparams.n_probs > 0) {
             std::vector<completion_token_output> probs;
             if (!slot.params.stream && slot.stopped_word) {
-                const std::vector<llama_token> stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);
+                const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);
 
                 size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
                 probs = std::vector<completion_token_output>(
@@ -1457,19 +1344,17 @@ struct server_context {
     // Functions to create new task(s) and receive result(s)
     //
 
-    std::vector<server_task> create_tasks_cmpl(json data, server_task_cmpl_type cmpl_type) {
+    // break the input "prompt" into multiple tasks if needed, then format and tokenize the input prompt(s)
+    std::vector<server_task> create_tasks_inference(json data, server_task_inf_type inf_type) {
         std::vector<server_task> tasks;
-        auto create_task = [&](json & task_data, bool replace_prompt, json prompt) {
+        auto create_task = [&](json & task_data, llama_tokens & prompt_tokens) {
+            SRV_DBG("create task, n_tokens = %d\n", (int) prompt_tokens.size());
             server_task task;
-            task.id        = queue_tasks.get_new_id();
-            task.cmpl_type = cmpl_type;
-            task.type      = SERVER_TASK_TYPE_COMPLETION;
-            if (replace_prompt) {
-                task.data  = task_data;
-                task.data["prompt"] = std::move(prompt);
-            } else {
-                task.data  = std::move(task_data);
-            }
+            task.id            = queue_tasks.get_new_id();
+            task.inf_type      = inf_type;
+            task.type          = SERVER_TASK_TYPE_INFERENCE;
+            task.data          = task_data;
+            task.prompt_tokens = std::move(prompt_tokens);
             tasks.push_back(std::move(task));
         };
 
@@ -1478,41 +1363,49 @@ struct server_context {
             throw std::runtime_error(error_msg);
         }
 
-        json prompt = data.at("prompt");
-
-        // if the prompt is a singleton (i.e. a string or a list of tokens), we only need to create single task
-        if (prompt.is_string() || json_is_array_of_numbers(prompt)) {
-            data["index"] = 0;
-            create_task(data, false, nullptr);
-        } else if (prompt.is_array()) {
-            // otherwise, it's a multiple-prompt task, we break it into smaller tasks
-            std::vector<json> prompts = prompt;
-            if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
-                // prompts[0] is the question
-                // the rest are the answers/documents
-                SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) prompts.size() - 1);
-                for (size_t i = 1; i < prompts.size(); i++) {
-                    json qd;
-                    qd.push_back(prompts[0]);
-                    qd.push_back(prompts[i]);
-                    data["index"] = i - 1;
-                    create_task(data, true, qd);
-                }
-            } else {
-                SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) prompts.size());
-                for (size_t i = 0; i < prompts.size(); i++) {
-                    const auto & e = prompts[i];
-                    if (e.is_string() || json_is_array_of_numbers(e)) {
+        // because llama_tokenize api is thread-safe, we can tokenize the prompt from HTTP thread
+        bool add_special = inf_type != SERVER_TASK_INF_TYPE_RERANK && inf_type != SERVER_TASK_INF_TYPE_INFILL;
+        std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx, data.at("prompt"), add_special, true);
+        switch (inf_type) {
+            case SERVER_TASK_INF_TYPE_RERANK:
+                {
+                    // prompts[0] is the question
+                    // the rest are the answers/documents
+                    GGML_ASSERT(tokenized_prompts.size() > 1);
+                    SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) tokenized_prompts.size() - 1);
+                    for (size_t i = 1; i < tokenized_prompts.size(); i++) {
+                        data["index"] = i - 1;
+                        auto tokens = format_rerank(model, tokenized_prompts[0], tokenized_prompts[i]);
+                        create_task(data, tokens);
+                    }
+                } break;
+            case SERVER_TASK_INF_TYPE_INFILL:
+                {
+                    SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
+                    for (size_t i = 0; i < tokenized_prompts.size(); i++) {
                         data["index"] = i;
-                        create_task(data, true, e);
-                    } else {
-                        throw std::runtime_error(error_msg);
+                        auto tokens = format_infill(
+                            ctx,
+                            data.at("input_prefix"),
+                            data.at("input_suffix"),
+                            data.at("input_extra"),
+                            params.n_batch,
+                            params.n_predict,
+                            slots[0].n_ctx, // TODO: there should be a better way
+                            params.spm_infill,
+                            tokenized_prompts[i]
+                        );
+                        create_task(data, tokens);
+                    }
+                } break;
+            default:
+                {
+                    SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
+                    for (size_t i = 0; i < tokenized_prompts.size(); i++) {
+                        data["index"] = i;
+                        create_task(data, tokenized_prompts[i]);
                     }
                 }
-            }
-        } else {
-            // invalid case
-            throw std::runtime_error(error_msg);
         }
 
         return tasks;
@@ -1534,7 +1427,7 @@ struct server_context {
         queue_tasks.post(cancel_tasks, true);
     }
 
-    // receive the results from task(s) created by create_tasks_cmpl
+    // receive the results from task(s) created by create_tasks_inference
     void receive_cmpl_results(
             const std::unordered_set<int> & id_tasks,
             const std::function<void(std::vector<server_task_result>&)> & result_handler,
@@ -1558,7 +1451,7 @@ struct server_context {
         result_handler(results);
     }
 
-    // receive the results from task(s) created by create_tasks_cmpl, in stream mode
+    // receive the results from task(s) created by create_tasks_inference, in stream mode
     void receive_cmpl_results_stream(
             const std::unordered_set<int> & id_tasks, const
             std::function<bool(server_task_result&)> & result_handler, const
@@ -1591,7 +1484,7 @@ struct server_context {
 
     void process_single_task(const server_task & task) {
         switch (task.type) {
-            case SERVER_TASK_TYPE_COMPLETION:
+            case SERVER_TASK_TYPE_INFERENCE:
                 {
                     const int id_slot = json_value(task.data, "id_slot", -1);
 
@@ -1623,9 +1516,10 @@ struct server_context {
 
                     slot->reset();
 
-                    slot->id_task   = task.id;
-                    slot->cmpl_type = task.cmpl_type;
-                    slot->index     = json_value(task.data, "index", 0);
+                    slot->id_task       = task.id;
+                    slot->inf_type      = task.inf_type;
+                    slot->index         = json_value(task.data, "index", 0);
+                    slot->prompt_tokens = std::move(task.prompt_tokens);
 
                     if (!launch_slot_with_task(*slot, task)) {
                         SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
@@ -1658,7 +1552,7 @@ struct server_context {
                         slot_data["id"]         = slot.id;
                         slot_data["id_task"]    = slot.id_task;
                         slot_data["state"]      = slot.state;
-                        slot_data["prompt"]     = slot.prompt;
+                        slot_data["prompt"]     = common_detokenize(ctx, slot.prompt_tokens);
                         slot_data["next_token"] = {
                             {"has_next_token", slot.has_next_token},
                             {"has_new_line",   slot.has_new_line},
@@ -1785,9 +1679,6 @@ struct server_context {
                     }
                     slot->cache_tokens.resize(token_count);
 
-                    // TODO: maybe detokenize the slot->cache_tokens instead?
-                    slot->prompt = string_format("[restored %d tokens from file]", (int) token_count);
-
                     const int64_t t_end = ggml_time_us();
                     const double t_restore_ms = (t_end - t_start) / 1000.0;
 
@@ -1954,142 +1845,18 @@ struct server_context {
         if (params.cont_batching || batch.n_tokens == 0) {
             for (auto & slot : slots) {
                 // this slot still has a prompt to be processed
-                if (slot.state == SLOT_STATE_PROCESSING_PROMPT) {
+                if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
                     auto & prompt_tokens = slot.prompt_tokens;
 
-                    // we haven't tokenized the prompt yet - do it now:
-                    if (prompt_tokens.empty()) {
-                        SLT_INF(slot, "tokenizing prompt, len = %d\n", (int) slot.prompt.size());
-
+                    // TODO: maybe move branch to outside of this loop in the future
+                    if (slot.state == SLOT_STATE_STARTED) {
                         slot.t_start_process_prompt = ggml_time_us();
                         slot.t_start_generation = 0;
-
-                        switch (slot.cmpl_type) {
-                            case SERVER_TASK_CMPL_TYPE_NORMAL:
-                            case SERVER_TASK_CMPL_TYPE_EMBEDDING:
-                                {
-                                    prompt_tokens = tokenize(slot.prompt, llama_add_bos_token(model), true);
-                                } break;
-                            case SERVER_TASK_CMPL_TYPE_RERANK:
-                                {
-                                    // require slot.prompt to be array of 2 strings
-                                    if (!slot.prompt.is_array() || slot.prompt.size() != 2) {
-                                        SLT_ERR(slot, "%s", "invalid prompt for rerank task\n");
-                                        slot.release();
-                                        send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST);
-                                        continue;
-                                    }
-
-                                    // prompt: [BOS]query[EOS][SEP]doc[EOS]
-                                    prompt_tokens.clear();
-                                    prompt_tokens.push_back(llama_token_bos(model));
-                                    {
-                                        const auto part = tokenize(slot.prompt[0], false, false);
-                                        prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
-                                    }
-                                    prompt_tokens.push_back(llama_token_eos(model));
-                                    prompt_tokens.push_back(llama_token_sep(model));
-                                    {
-                                        const auto part = tokenize(slot.prompt[1], false, false);
-                                        prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
-                                    }
-                                    prompt_tokens.push_back(llama_token_eos(model));
-                                } break;
-                            case SERVER_TASK_CMPL_TYPE_INFILL:
-                                {
-                                    // TODO: optimize this block by reducing memory allocations and movement
-
-                                    // use FIM repo-level pattern:
-                                    // ref: https://arxiv.org/pdf/2409.12186
-                                    //
-                                    // [FIM_REP]myproject
-                                    // [FIM_SEP]filename0
-                                    // extra chunk 0
-                                    // [FIM_SEP]filename1
-                                    // extra chunk 1
-                                    // ...
-                                    // [FIM_SEP]filename
-                                    // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt
-                                    //
-                                    auto tokens_prefix = tokenize(slot.input_prefix, false, false);
-                                    auto tokens_suffix = tokenize(slot.input_suffix, false, false);
-                                    auto tokens_prompt = tokenize(slot.prompt,       false, false);
-
-                                    slot.extra_tokens.clear();
-                                    if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) {
-                                        static const auto k_fim_repo = tokenize("myproject\n", false, false);
-
-                                        slot.extra_tokens.push_back(llama_token_fim_rep(model));
-                                        slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
-                                    }
-
-                                    for (const auto & chunk : slot.input_extra) {
-                                        // { "text": string, "filename": string }
-                                        const std::string text     = chunk.value("text", "");
-                                        const std::string filename = chunk.value("filename", "tmp");
-
-                                        if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
-                                            const auto k_fim_file = tokenize(filename + "\n", false, false);
-
-                                            slot.extra_tokens.insert(slot.extra_tokens.end(), llama_token_fim_sep(model));
-                                            slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
-                                        } else {
-                                            // chunk separator in binary form to avoid confusing the AI
-                                            static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
-                                            static const auto k_chunk_prefix_tokens = tokenize(k_chunk_prefix_str, false, false);
-
-                                            slot.extra_tokens.insert(slot.extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end());
-                                        }
-
-                                        const auto chunk_tokens = tokenize(text, false, false);
-                                        slot.extra_tokens.insert(slot.extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end());
-                                    }
-
-                                    if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
-                                        // TODO: current filename
-                                        static const auto k_fim_file = tokenize("filename\n", false, false);
-
-                                        slot.extra_tokens.insert(slot.extra_tokens.end(), llama_token_fim_sep(model));
-                                        slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
-                                    }
-
-                                    // for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
-                                    const int n_suffix_take = std::min<int>(tokens_suffix.size(),   (n_batch/4));
-                                    const int n_prefix_take = std::min<int>(tokens_prefix.size(), 3*(n_batch/4) - 3);
-
-                                    // fill the rest of the context with extra chunks
-                                    const int n_extra_take = std::min<int>(std::max<int>(0, slot.n_ctx - (n_batch) - 2*slot.n_predict), slot.extra_tokens.size());
-
-                                    tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take);
-                                    tokens_suffix.resize(n_suffix_take);
-
-                                    tokens_prefix.insert(tokens_prefix.begin(), llama_token_fim_pre(model));
-                                    tokens_prefix.insert(tokens_prefix.end(),   tokens_prompt.begin(), tokens_prompt.end());
-                                    tokens_suffix.insert(tokens_suffix.begin(), llama_token_fim_suf(model));
-
-                                    auto embd_inp = params.spm_infill ? tokens_suffix : tokens_prefix;
-                                    auto embd_end = params.spm_infill ? tokens_prefix : tokens_suffix;
-
-                                    if (llama_add_bos_token(model)) {
-                                        embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
-                                    }
-
-                                    SLT_DBG(slot, "extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", slot.n_ctx, n_extra_take, (int) slot.extra_tokens.size());
-
-                                    // put the extra context before the FIM prefix
-                                    embd_inp.insert(embd_inp.begin(), slot.extra_tokens.end() - n_extra_take, slot.extra_tokens.end());
-
-                                    embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
-                                    embd_inp.push_back(llama_token_fim_mid(model));
-
-                                    prompt_tokens = std::move(embd_inp);
-                                } break;
-                        }
-
                         slot.n_past = 0;
                         slot.n_prompt_tokens = prompt_tokens.size();
+                        slot.state = SLOT_STATE_PROCESSING_PROMPT;
 
-                        SLT_INF(slot, "prompt tokenized, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
+                        SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
 
                         // print prompt tokens (for debugging)
                         if (1) {
@@ -2114,7 +1881,7 @@ struct server_context {
                             continue;
                         }
 
-                        if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+                        if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
                             // this prompt is too large to process - discard it
                             if (slot.n_prompt_tokens > n_ubatch) {
                                 slot.release();
@@ -2144,7 +1911,7 @@ struct server_context {
                                 const int n_block_size = n_left / 2;
                                 const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
 
-                                std::vector<llama_token> new_tokens(
+                                llama_tokens new_tokens(
                                         prompt_tokens.begin(),
                                         prompt_tokens.begin() + slot.params.n_keep);
 
@@ -2225,7 +1992,7 @@ struct server_context {
                     }
 
                     // non-causal tasks require to fit the entire prompt in the physical batch
-                    if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+                    if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
                         // cannot fit the prompt in the current batch - will try next iter
                         if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
                             continue;
@@ -2234,8 +2001,8 @@ struct server_context {
 
                     // check that we are in the right batch_type, if not defer the slot
                     const bool slot_type =
-                        slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ||
-                        slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK     ? 1 : 0;
+                        slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING ||
+                        slot.inf_type == SERVER_TASK_INF_TYPE_RERANK     ? 1 : 0;
 
                     if (batch_type == -1) {
                         batch_type = slot_type;
@@ -2353,7 +2120,7 @@ struct server_context {
                 }
 
                 if (slot.state == SLOT_STATE_DONE_PROMPT) {
-                    if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
+                    if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING) {
                         // prompt evaluated for embedding
                         send_embedding(slot, batch_view);
                         slot.release();
@@ -2361,7 +2128,7 @@ struct server_context {
                         continue; // continue loop of slots
                     }
 
-                    if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+                    if (slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
                         send_rerank(slot, batch_view);
                         slot.release();
                         slot.i_batch = -1;
@@ -2915,13 +2682,13 @@ int main(int argc, char ** argv) {
         res_ok(res, {{ "success", true }});
     };
 
-    const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) {
+    const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_inf_type inf_type, json & data, httplib::Response & res) {
         if (ctx_server.params.embedding || ctx_server.params.reranking) {
             res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
 
-        std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, cmpl_type);
+        std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, inf_type);
         ctx_server.queue_results.add_waiting_tasks(tasks);
         ctx_server.queue_tasks.post(tasks);
 
@@ -2967,10 +2734,11 @@ int main(int argc, char ** argv) {
 
     const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
         json data = json::parse(req.body);
-        return handle_completions_generic(SERVER_TASK_CMPL_TYPE_NORMAL, data, res);
+        return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res);
     };
 
     const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
+        // check model compatibility
         std::string err;
         if (llama_token_fim_pre(ctx_server.model) == LLAMA_TOKEN_NULL) {
             err += "prefix token is missing. ";
@@ -2981,14 +2749,42 @@ int main(int argc, char ** argv) {
         if (llama_token_fim_mid(ctx_server.model) == LLAMA_TOKEN_NULL) {
             err += "middle token is missing. ";
         }
-
         if (!err.empty()) {
             res_error(res, format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
 
         json data = json::parse(req.body);
-        return handle_completions_generic(SERVER_TASK_CMPL_TYPE_INFILL, data, res);
+
+        // validate input
+        if (!data.contains("input_prefix")) {
+            res_error(res, format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST));
+        }
+
+        if (!data.contains("input_suffix")) {
+            res_error(res, format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST));
+        }
+
+        if (data.contains("input_extra") && !data.at("input_extra").is_array()) {
+            res_error(res, format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST));
+            return;
+        }
+        json input_extra = json_value(data, "input_extra", json::array());
+        for (const auto & chunk : input_extra) {
+            // { "text": string, "filename": string }
+            if (!chunk.contains("text") || !chunk.at("text").is_string()) {
+                res_error(res, format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST));
+                return;
+            }
+            // filename is optional
+            if (chunk.contains("filename") && !chunk.at("filename").is_string()) {
+                res_error(res, format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST));
+                return;
+            }
+        }
+        data["input_extra"] = input_extra; // default to empty array if it's not exist
+
+        return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res);
     };
 
     // TODO: maybe merge this function with "handle_completions_generic"
@@ -3000,7 +2796,7 @@ int main(int argc, char ** argv) {
 
         json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
 
-        std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL);
+        std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, SERVER_TASK_INF_TYPE_COMPLETION);
         ctx_server.queue_results.add_waiting_tasks(tasks);
         ctx_server.queue_tasks.post(tasks);
 
@@ -3073,7 +2869,7 @@ int main(int argc, char ** argv) {
             const bool add_special = json_value(body, "add_special", false);
             const bool with_pieces = json_value(body, "with_pieces", false);
 
-            std::vector<llama_token> tokens = ctx_server.tokenize(body.at("content"), add_special, true);
+            llama_tokens tokens = tokenize_mixed(ctx_server.ctx, body.at("content"), add_special, true);
 
             if (with_pieces) {
                 for (const auto& token : tokens) {
@@ -3110,7 +2906,7 @@ int main(int argc, char ** argv) {
 
         std::string content;
         if (body.count("tokens") != 0) {
-            const std::vector<llama_token> tokens = body.at("tokens");
+            const llama_tokens tokens = body.at("tokens");
             content = tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend());
         }
 
@@ -3144,7 +2940,7 @@ int main(int argc, char ** argv) {
         json responses = json::array();
         bool error = false;
         {
-            std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING);
+            std::vector<server_task> tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_EMBEDDING);
             ctx_server.queue_results.add_waiting_tasks(tasks);
             ctx_server.queue_tasks.post(tasks);
 
@@ -3221,7 +3017,7 @@ int main(int argc, char ** argv) {
         json responses = json::array();
         bool error = false;
         {
-            std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK);
+            std::vector<server_task> tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_RERANK);
             ctx_server.queue_results.add_waiting_tasks(tasks);
             ctx_server.queue_tasks.post(tasks);
 
diff --git a/examples/server/tests/features/infill.feature b/examples/server/tests/features/infill.feature
new file mode 100644 (file)
index 0000000..a0bbfef
--- /dev/null
@@ -0,0 +1,36 @@
+@llama.cpp
+@infill
+Feature: llama.cpp server
+
+  # The current model is made by adding FIM tokens to the existing stories260K
+  # We may want to use a better model in the future, maybe something like SmolLM 360M
+
+  Background: Server startup
+    Given a server listening on localhost:8080
+    And   a model file tinyllamas/stories260K-infill.gguf from HF repo ggml-org/models
+    And   a model file test-model-infill.gguf
+    And   a model alias tinyllama-infill
+    And   42 as server seed
+    And   1024 as batch size
+    And   1024 as ubatch size
+    And   2048 KV cache size
+    And   64 max tokens to predict
+    And   0.0 temperature
+    Then  the server is starting
+    Then  the server is healthy
+
+  Scenario: Infill without input_extra
+    Given a prompt "Complete this"
+    And   an infill input extra none none
+    And   an infill input prefix "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n    int n_threads = llama_"
+    And   an infill input suffix "}\n"
+    And   an infill request with no api error
+    Then  64 tokens are predicted matching One|day|she|saw|big|scary|bird
+
+  Scenario: Infill with input_extra
+    Given a prompt "Complete this"
+    And   an infill input extra "llama.h" "LLAMA_API int32_t llama_n_threads();\n"
+    And   an infill input prefix "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n    int n_threads = llama_"
+    And   an infill input suffix "}\n"
+    And   an infill request with no api error
+    Then  64 tokens are predicted matching cuts|Jimmy|mom|came|into|the|room"
index 540a2ecd563746cb3bee72e241d17fc340024e4a..2e418d8aa571b620e0c07b4a666da64926d630ac 100644 (file)
@@ -80,6 +80,11 @@ def step_server_config(context, server_fqdn: str, server_port: str):
     context.lora_file = None
     context.disable_ctx_shift = False
 
+    # infill
+    context.infill_input_extra = None
+    context.infill_input_suffix = ''
+    context.infill_input_prefix = ''
+
     context.tasks_result = []
     context.concurrent_tasks = []
     context.prompts = []
@@ -291,6 +296,28 @@ async def step_request_completion(context, api_error: Literal['raised'] | str):
         assert completion == api_error_code, f"completion must be an {api_error_code} status code: {completion}"
 
 
+@step('an infill request with {api_error} api error')
+@async_run_until_complete
+async def step_request_completion(context, api_error: Literal['raised'] | str):
+    if api_error != 'no':
+        raise ValueError(f'api_error={api_error} is not yet implemented')
+    payload = {
+        "prompt": context.prompts[0],
+        "input_suffix": context.infill_input_suffix,
+        "input_prefix": context.infill_input_prefix,
+        "n_predict": context.n_predict,
+        "seed": context.seed,
+        "temperature": context.temperature,
+    }
+    if context.infill_input_extra is not None:
+        payload['input_extra'] = context.infill_input_extra
+    async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
+        async with session.post(f'{context.base_url}/infill',
+                                json=payload) as response:
+            assert response.status == 200
+            context.tasks_result = [await response.json()]
+
+
 @step('{predicted_n:d} tokens are predicted matching {re_content}')
 def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
     context.completion = context.tasks_result.pop()
@@ -539,6 +566,25 @@ def step_a_prompt_prompt(context, prompt):
     context.n_prompts = len(context.prompts)
 
 
+# TODO: allow this to be repeated
+@step('an infill input extra {filename} {text}')
+def step_infill_input_extra(context, filename, text):
+    if filename == 'none':
+        context.infill_input_extra = None
+    else:
+        context.infill_input_extra = [{'filename': filename, 'text': text}]
+
+
+@step('an infill input suffix {text}')
+def step_infill_input_suffix(context, text):
+    context.infill_input_suffix = text
+
+
+@step('an infill input prefix {text}')
+def step_infill_input_prefix(context, text):
+    context.infill_input_prefix = text
+
+
 @step('{num_prompts:d} prompts {prompt} with seed {seed:d}')
 def step_many_prompts(context, num_prompts, prompt, seed):
     if context.seed is None:
index 69519ef95b2d913bea680d8ba00eb246daca95d8..81124206241851e1df27a03706f1bdc0b8841cad 100644 (file)
 #define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
 
 using json = nlohmann::ordered_json;
+using llama_tokens = std::vector<llama_token>;
+
+#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
+#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
+#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
+#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
+
+#define SRV_INF(fmt, ...) LOG_INF("srv  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define SRV_WRN(fmt, ...) LOG_WRN("srv  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define SRV_ERR(fmt, ...) LOG_ERR("srv  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define SRV_DBG(fmt, ...) LOG_DBG("srv  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+
+#define QUE_INF(fmt, ...) LOG_INF("que  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define QUE_WRN(fmt, ...) LOG_WRN("que  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define QUE_ERR(fmt, ...) LOG_ERR("que  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define QUE_DBG(fmt, ...) LOG_DBG("que  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
 
 // https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
 enum error_type {
@@ -52,9 +68,235 @@ static T json_value(const json & body, const std::string & key, const T & defaul
 }
 
 //
-// chat template utils
+// tokenizer and input processing utils
 //
 
+static bool json_is_array_of_numbers(const json & data) {
+    if (data.is_array()) {
+        for (const auto & e : data) {
+            if (!e.is_number_integer()) {
+                return false;
+            }
+        }
+        return true;
+    }
+    return false;
+}
+
+// is array having BOTH numbers & strings?
+static bool json_is_array_of_mixed_numbers_strings(const json & data) {
+    bool seen_string = false;
+    bool seen_number = false;
+    if (data.is_array()) {
+        for (const auto & e : data) {
+            seen_string |= e.is_string();
+            seen_number |= e.is_number_integer();
+            if (seen_number && seen_string) {
+                return true;
+            }
+        }
+    }
+    return false;
+}
+
+/**
+ * this handles 2 cases:
+ * - only string, example: "string"
+ * - mixed string and tokens, example: [12, 34, "string", 56, 78]
+ */
+static llama_tokens tokenize_mixed(const llama_context * ctx, const json & json_prompt, bool add_special, bool parse_special) {
+    // If `add_bos` is true, we only add BOS, when json_prompt is a string,
+    // or the first element of the json_prompt array is a string.
+    llama_tokens prompt_tokens;
+
+    if (json_prompt.is_array()) {
+        bool first = true;
+        for (const auto & p : json_prompt) {
+            if (p.is_string()) {
+                auto s = p.template get<std::string>();
+
+                llama_tokens p;
+                if (first) {
+                    p = common_tokenize(ctx, s, add_special, parse_special);
+                    first = false;
+                } else {
+                    p = common_tokenize(ctx, s, false, parse_special);
+                }
+
+                prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
+            } else {
+                if (first) {
+                    first = false;
+                }
+
+                prompt_tokens.push_back(p.template get<llama_token>());
+            }
+        }
+    } else {
+        auto s = json_prompt.template get<std::string>();
+        prompt_tokens = common_tokenize(ctx, s, add_special, parse_special);
+    }
+
+    return prompt_tokens;
+}
+
+/**
+ * break the input "prompt" object into multiple prompt if needed, then tokenize them
+ * this supports these cases:
+ * - "prompt": "string"
+ * - "prompt": [12, 34, 56]
+ * - "prompt": [12, 34, "string", 56, 78]
+ * and multiple prompts (multi-tasks):
+ * - "prompt": ["string1", "string2"]
+ * - "prompt": ["string1", [12, 34, 56]]
+ * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]]
+ */
+static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, const json & json_prompt, bool add_special, bool parse_special) {
+    std::vector<llama_tokens> result;
+    if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) {
+        // string or mixed
+        result.push_back(tokenize_mixed(ctx, json_prompt, add_special, parse_special));
+    } else if (json_is_array_of_numbers(json_prompt)) {
+        // array of tokens
+        result.push_back(json_prompt.get<llama_tokens>());
+    } else if (json_prompt.is_array()) {
+        // array of prompts
+        result.reserve(json_prompt.size());
+        for (const auto & p : json_prompt) {
+            if (p.is_string() || json_is_array_of_mixed_numbers_strings(p)) {
+                result.push_back(tokenize_mixed(ctx, p, add_special, parse_special));
+            } else if (json_is_array_of_numbers(p)) {
+                // array of tokens
+                result.push_back(p.get<llama_tokens>());
+            } else {
+                throw std::runtime_error("element of \"prompt\" must be a string, an list of tokens, or a list of mixed strings & tokens");
+            }
+        }
+    } else {
+        throw std::runtime_error("\"prompt\" must be a string, an list of tokens, a list of mixed strings & tokens, or a list of prompts");
+    }
+    return result;
+}
+
+//
+// template utils
+//
+
+// format rerank task: [BOS]query[EOS][SEP]doc[EOS]
+static llama_tokens format_rerank(const struct llama_model * model, const llama_tokens & query, const llama_tokens & doc) {
+    llama_tokens result;
+    result.reserve(doc.size() + query.size() + 4);
+    result.push_back(llama_token_bos(model));
+    result.insert(result.end(), query.begin(), query.end());
+    result.push_back(llama_token_eos(model));
+    result.push_back(llama_token_sep(model));
+    result.insert(result.end(), doc.begin(), doc.end());
+    result.push_back(llama_token_eos(model));
+    return result;
+}
+
+// format infill task
+static llama_tokens format_infill(
+        const llama_context * ctx,
+        const json & input_prefix,
+        const json & input_suffix,
+        const json & input_extra,
+        const int n_batch,
+        const int n_predict,
+        const int n_ctx,
+        const bool spm_infill,
+        const llama_tokens & tokens_prompt
+    ) {
+    // TODO: optimize this block by reducing memory allocations and movement
+
+    // use FIM repo-level pattern:
+    // ref: https://arxiv.org/pdf/2409.12186
+    //
+    // [FIM_REP]myproject
+    // [FIM_SEP]filename0
+    // extra chunk 0
+    // [FIM_SEP]filename1
+    // extra chunk 1
+    // ...
+    // [FIM_SEP]filename
+    // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt
+    //
+    llama_tokens extra_tokens;
+    extra_tokens.reserve(n_ctx);
+
+    auto model = llama_get_model(ctx);
+    auto tokens_prefix = tokenize_mixed(ctx, input_prefix, false, false);
+    auto tokens_suffix = tokenize_mixed(ctx, input_suffix, false, false);
+
+    if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) {
+        // TODO: make project name an input
+        static const auto k_fim_repo = common_tokenize(ctx, "myproject\n", false, false);
+
+        extra_tokens.push_back(llama_token_fim_rep(model));
+        extra_tokens.insert(extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
+    }
+    for (const auto & chunk : input_extra) {
+        // { "text": string, "filename": string }
+        const std::string text     = json_value(chunk, "text",     std::string());
+        const std::string filename = json_value(chunk, "filename", std::string("tmp"));
+
+        if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
+            const auto k_fim_file = common_tokenize(ctx, filename + "\n", false, false);
+
+            extra_tokens.insert(extra_tokens.end(), llama_token_fim_sep(model));
+            extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
+        } else {
+            // chunk separator in binary form to avoid confusing the AI
+            static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
+            static const auto k_chunk_prefix_tokens = common_tokenize(ctx, k_chunk_prefix_str, false, false);
+
+            extra_tokens.insert(extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end());
+        }
+
+        const auto chunk_tokens = common_tokenize(ctx, text, false, false);
+        extra_tokens.insert(extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end());
+    }
+
+    if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
+        // TODO: current filename
+        static const auto k_fim_file = common_tokenize(ctx, "filename\n", false, false);
+
+        extra_tokens.insert(extra_tokens.end(), llama_token_fim_sep(model));
+        extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
+    }
+
+    // for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
+    const int n_suffix_take = std::min<int>(tokens_suffix.size(),   (n_batch/4));
+    const int n_prefix_take = std::min<int>(tokens_prefix.size(), 3*(n_batch/4) - 3);
+
+    // fill the rest of the context with extra chunks
+    const int n_extra_take = std::min<int>(std::max<int>(0, n_ctx - (n_batch) - 2*n_predict), extra_tokens.size());
+
+    tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take);
+    tokens_suffix.resize(n_suffix_take);
+
+    tokens_prefix.insert(tokens_prefix.begin(), llama_token_fim_pre(model));
+    tokens_prefix.insert(tokens_prefix.end(),   tokens_prompt.begin(), tokens_prompt.end());
+    tokens_suffix.insert(tokens_suffix.begin(), llama_token_fim_suf(model));
+
+    auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix;
+    auto embd_end = spm_infill ? tokens_prefix : tokens_suffix;
+
+    if (llama_add_bos_token(model)) {
+        embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
+    }
+
+    SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size());
+
+    // put the extra context before the FIM prefix
+    embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end());
+
+    embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
+    embd_inp.push_back(llama_token_fim_mid(model));
+
+    return embd_inp;
+}
+
 // Format given chat. If tmpl is empty, we take the template from model metadata
 inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) {
     std::vector<common_chat_msg> chat;
@@ -229,18 +471,6 @@ static size_t find_partial_stop_string(const std::string &stop, const std::strin
     return std::string::npos;
 }
 
-static bool json_is_array_of_numbers(const json & data) {
-    if (data.is_array()) {
-        for (const auto & e : data) {
-            if (!e.is_number()) {
-                return false;
-            }
-        }
-        return true;
-    }
-    return false;
-}
-
 // TODO: reuse llama_detokenize
 template <class Iter>
 static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {