]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server: support multiple generations from one prompt (OAI "n" option) (#17775)
authorXuan-Son Nguyen <redacted>
Sat, 6 Dec 2025 14:54:38 +0000 (15:54 +0100)
committerGitHub <redacted>
Sat, 6 Dec 2025 14:54:38 +0000 (15:54 +0100)
* backend support

* server: support multiple generations from one prompt (OAI "n" option)

* fix invalid batch

* format oai

* clean up

* disable ctx shift

* add test

* update comments

* fix style

* add n_cmpl to docs [no ci]

* allowing using both n_cmpl and n

tools/server/README.md
tools/server/server-common.cpp
tools/server/server-common.h
tools/server/server-context.cpp
tools/server/server-task.cpp
tools/server/server-task.h
tools/server/tests/unit/test_chat_completion.py

index cb2fbcf8eb74b06130de27453d2a7728c9c435e8..bf274db79d41e300cfdd346a7abd523a065046eb 100644 (file)
@@ -493,6 +493,8 @@ Note for `multimodal_data` in JSON object prompts. This should be an array of st
 `n_keep`: Specify the number of tokens from the prompt to retain when the context size is exceeded and tokens need to be discarded. The number excludes the BOS token.
 By default, this value is set to `0`, meaning no tokens are kept. Use `-1` to retain all tokens from the prompt.
 
+`n_cmpl`: Number of completions to generate from the current prompt. If input has multiple prompts, the output will have N prompts times `n_cmpl` entries.
+
 `stream`: Allows receiving each predicted token in real-time instead of waiting for the completion to finish (uses a different response format). To enable this, set to `true`.
 
 `stop`: Specify a JSON array of stopping strings.
index cfdd0c656f41c37b5edd5ef446de942c065c39a1..b403864e0eebbfd8462f1828d573d101aa129d0c 100644 (file)
@@ -494,6 +494,18 @@ int32_t server_tokens::process_chunk(
     return 0;
 }
 
+server_tokens server_tokens::clone() const {
+    server_tokens res;
+    res.has_mtmd = has_mtmd;
+    res.tokens   = tokens;
+    for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) {
+        size_t idx = it->first;
+        const mtmd::input_chunk_ptr & chunk = it->second;
+        res.map_idx_to_media[idx] = mtmd::input_chunk_ptr(mtmd_input_chunk_copy(chunk.get()));
+    }
+    return res;
+}
+
 //
 // tokenizer and input processing utils
 //
@@ -745,12 +757,6 @@ json oaicompat_completion_params_parse(const json & body) {
         llama_params["stop"] = json_value(body, "stop", json::array());
     }
 
-    // Handle "n" field
-    int n_choices = json_value(body, "n", 1);
-    if (n_choices != 1) {
-        throw std::runtime_error("Only one completion choice is allowed");
-    }
-
     // Handle "echo" field
     if (json_value(body, "echo", false)) {
         throw std::runtime_error("Only no echo is supported");
@@ -1049,12 +1055,6 @@ json oaicompat_chat_params_parse(
         llama_params["chat_parser"] = chat_params.parser;
     }
 
-    // Handle "n" field
-    int n_choices = json_value(body, "n", 1);
-    if (n_choices != 1) {
-        throw std::invalid_argument("Only one completion choice is allowed");
-    }
-
     // Handle "logprobs" field
     // TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future
     if (json_value(body, "logprobs", false)) {
index bb04e82b4f5fdfaec836dc4f1c51d26096a72220..0c4d84ffa06cf9f6afc115672fa21d45c4c01fa4 100644 (file)
@@ -215,6 +215,8 @@ public:
                 llama_pos pos,
                 int32_t seq_id,
                 size_t & n_tokens_out) const;
+
+    server_tokens clone() const;
 };
 
 
index f3f2edc0cc4f2e75e3e519cec3581af0e0fe7ab7..12a4e94e5d8bb2f51c4700956ff74ee6ab47f105 100644 (file)
@@ -35,7 +35,8 @@ constexpr int HTTP_POLLING_SECONDS = 1;
 // state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
 enum slot_state {
     SLOT_STATE_IDLE,
-    SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future
+    SLOT_STATE_WAIT_OTHER, // after assigning a task, but waiting for parent slot to process prompt
+    SLOT_STATE_STARTED,    // after assigning a task and about to process prompt
     SLOT_STATE_PROCESSING_PROMPT,
     SLOT_STATE_DONE_PROMPT,
     SLOT_STATE_GENERATING,
@@ -254,6 +255,15 @@ struct server_slot {
         generated_token_probs.push_back(token);
     }
 
+    // note: a slot can also be either a parent or a child
+    bool is_parent() const {
+        return is_processing() && task->n_children > 0;
+    }
+
+    bool is_child() const {
+        return is_processing() && task->id_parent >= 0;
+    }
+
     void release() {
         if (is_processing()) {
             GGML_ASSERT(task);
@@ -383,6 +393,17 @@ struct server_slot {
 
         return res;
     }
+
+    void copy_state_to(server_slot & other) const {
+        llama_memory_seq_rm(llama_get_memory(ctx), other.id, 0, -1);
+        llama_memory_seq_cp(llama_get_memory(ctx), id, other.id, 0, -1);
+        other.n_decoded   = n_decoded;
+        other.n_remaining = n_remaining;
+        other.i_batch     = i_batch;
+        other.n_prompt_tokens_cache     = n_prompt_tokens_cache;
+        other.n_prompt_tokens_processed = n_prompt_tokens_processed;
+        other.prompt = prompt.clone();
+    }
 };
 
 
@@ -1022,7 +1043,9 @@ struct server_context_impl {
 
         slot.task = std::make_unique<const server_task>(std::move(task));
 
-        slot.state = SLOT_STATE_STARTED;
+        slot.state = slot.is_child()
+            ? SLOT_STATE_WAIT_OTHER // wait for the parent to process prompt
+            : SLOT_STATE_STARTED;
 
         SLT_INF(slot, "%s", "processing task\n");
 
@@ -1684,6 +1707,12 @@ struct server_context_impl {
                     GGML_ABORT("not supported by multimodal");
                 }
 
+                if (slot.is_parent() || slot.is_child()) {
+                    send_error(slot, "context shift cannot be used for shared prompt", ERROR_TYPE_SERVER);
+                    slot.release();
+                    continue;
+                }
+
                 // Shift context
                 int n_keep = slot.task->params.n_keep < 0 ? slot.task->n_tokens() : slot.task->params.n_keep;
 
@@ -2308,6 +2337,26 @@ struct server_context_impl {
             n_batch = llama_n_batch(ctx);
 
             for (auto & slot : slots) {
+                // may need to copy state to other slots
+                if (slot.state == SLOT_STATE_DONE_PROMPT && slot.is_parent()) {
+                    std::vector<server_slot *> child_slots;
+                    for (auto & other : slots) {
+                        if (other.state == SLOT_STATE_WAIT_OTHER && slot.task->id == other.task->id_parent) {
+                            child_slots.push_back(&other);
+                        }
+                    }
+
+                    // we can only proceed if all child slots are having the correct tasks
+                    if (child_slots.size() == slot.task->n_children) {
+                        // copy state to the child slots
+                        for (auto & child : child_slots) {
+                            SLT_INF(slot, "copying state to child %d\n", child->id);
+                            slot.copy_state_to(*child);
+                            child->state = SLOT_STATE_DONE_PROMPT;
+                        }
+                    }
+                }
+
                 // optionally send prompt processing progress
                 if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) {
                     if (slot.task->params.stream && slot.task->params.return_progress) {
@@ -2593,11 +2642,12 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
         }
         tasks.reserve(inputs.size());
         states.reserve(inputs.size());
+        int idx = 0;
         for (size_t i = 0; i < inputs.size(); i++) {
             server_task task = server_task(type);
 
             task.id    = ctx_server.queue_tasks.get_new_id();
-            task.index = i;
+            task.index = idx++;
 
             task.tokens = std::move(inputs[i]);
             task.params = server_task::params_from_json_cmpl(
@@ -2612,6 +2662,18 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
             task.params.oaicompat_model   = ctx_server.model_name;
             states.push_back(task.params.oaicompat_chat_syntax);
 
+            if (task.params.n_cmpl > 1) {
+                task.n_children = task.params.n_cmpl - 1;
+                for (size_t j = 0; j < task.n_children; j++) {
+                    server_task child = task.create_child(
+                        task.id,
+                        ctx_server.queue_tasks.get_new_id(),
+                        idx++);
+                    states.push_back(child.params.oaicompat_chat_syntax);
+                    tasks.push_back(std::move(child));
+                }
+            }
+
             tasks.push_back(std::move(task));
         }
 
@@ -2638,8 +2700,21 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
                 GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
                 arr.push_back(res->to_json());
             }
-            // if single request, return single object instead of array
-            res->ok(arr.size() == 1 ? arr[0] : arr);
+            GGML_ASSERT(!arr.empty() && "empty results");
+            if (arr.size() == 1) {
+                // if single request, return single object instead of array
+                res->ok(arr[0]);
+            } else if (res_type == TASK_RESPONSE_TYPE_OAI_CHAT || res_type == TASK_RESPONSE_TYPE_OAI_CMPL) {
+                // if multiple results in OAI format, we need to re-format them
+                json & choices = arr[0]["choices"];
+                for (size_t i = 1; i < arr.size(); i++) {
+                    choices.push_back(std::move(arr[i]["choices"][0]));
+                }
+                res->ok(arr[0]);
+            } else {
+                // multi-results, non-OAI compat
+                res->ok(arr);
+            }
         }
     } else {
         // in streaming mode, the first error must be treated as non-stream response
index df066264778a8d1ada9de96e6dc2cc303199d0e1..c401f47a788ff6924074e95f7f824e5b45012aa6 100644 (file)
@@ -175,6 +175,7 @@ task_params server_task::params_from_json_cmpl(
     params.n_indent         = json_value(data,       "n_indent",           defaults.n_indent);
     params.n_keep           = json_value(data,       "n_keep",             defaults.n_keep);
     params.n_discard        = json_value(data,       "n_discard",          defaults.n_discard);
+    params.n_cmpl           = json_value(data,       "n_cmpl",             json_value(data, "n", 1));
     //params.t_max_prompt_ms  = json_value(data,       "t_max_prompt_ms",    defaults.t_max_prompt_ms); // TODO: implement
     params.t_max_predict_ms = json_value(data,       "t_max_predict_ms",   defaults.t_max_predict_ms);
     params.response_fields  = json_value(data,       "response_fields",    std::vector<std::string>());
@@ -453,6 +454,10 @@ task_params server_task::params_from_json_cmpl(
         }
     }
 
+    if (params.n_cmpl > params_base.n_parallel) {
+        throw std::runtime_error("n_cmpl cannot be greater than the number of slots, please increase -np");
+    }
+
     return params;
 }
 
@@ -664,7 +669,7 @@ json server_task_result_cmpl_final::to_json_oaicompat_chat() {
 
     json choice {
         {"finish_reason", finish_reason},
-        {"index", 0},
+        {"index", index},
         {"message", msg.to_json_oaicompat<json>()},
     };
 
@@ -1064,7 +1069,7 @@ json server_task_result_cmpl_partial::to_json_oaicompat_chat() {
             {"choices", json::array({
                 json {
                     {"finish_reason", nullptr},
-                    {"index", 0},
+                    {"index", index},
                     {"delta", delta},
                 },
             })},
index 8e7b9e3e310fb517e44d95d1b2dc97ea4209f064..4e4840fc83beb88d02ba1f463effed8e29d0daab 100644 (file)
@@ -53,6 +53,7 @@ struct task_params {
     int32_t n_discard =  0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
     int32_t n_predict = -1; // new tokens to predict
     int32_t n_indent  =  0; // minimum line indentation for the generated text in number of whitespace characters
+    int32_t n_cmpl    =  1; // number of completions to generate from this prompt
 
     int64_t t_max_prompt_ms  = -1; // TODO: implement
     int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
@@ -89,6 +90,10 @@ struct server_task {
     int id_target = -1;
     int id_slot   = -1;
 
+    // used by parallel sampling (multiple completions from same prompt)
+    size_t n_children =  0; // number of tasks reusing this prompt
+    int    id_parent  = -1;
+
     // used by SERVER_TASK_TYPE_INFERENCE
     task_params   params;
     server_tokens tokens;
@@ -130,6 +135,17 @@ struct server_task {
         }
         return ids;
     }
+
+    server_task create_child(int id_parent, int id_child, int idx) const {
+        server_task copy;
+        copy.id        = id_child;
+        copy.index     = idx;
+        copy.id_parent = id_parent;
+        copy.params    = params;
+        copy.type      = type;
+        copy.tokens    = tokens.clone();
+        return copy;
+    }
 };
 
 struct result_timings {
@@ -466,6 +482,14 @@ struct server_prompt {
     int n_tokens() const {
         return tokens.size();
     }
+
+    server_prompt clone() const {
+        return server_prompt {
+            tokens.clone(),
+            data,
+            checkpoints
+        };
+    }
 };
 
 struct server_prompt_cache {
index aa6229c93a50c8cc3867e2833665dfb63b214ed3..64f3158b986f72799d2e186810d382f3383bd4c4 100644 (file)
@@ -477,3 +477,22 @@ def test_return_progress(n_batch, batch_count, reuse_cache):
     assert last_progress["total"] > 0
     assert last_progress["processed"] == last_progress["total"]
     assert total_batch_count == batch_count
+
+
+def test_chat_completions_multiple_choices():
+    global server
+    server.start()
+    res = server.make_request("POST", "/chat/completions", data={
+        "max_tokens": 8,
+        "n": 2,
+        "messages": [
+            {"role": "system", "content": "Book"},
+            {"role": "user", "content": "What is the best book"},
+        ],
+    })
+    assert res.status_code == 200
+    assert len(res.body["choices"]) == 2
+    for choice in res.body["choices"]:
+        assert "assistant" == choice["message"]["role"]
+        assert match_regex("Suddenly", choice["message"]["content"])
+        assert choice["finish_reason"] == "length"