]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server: fix n_cmpl not skipping processing prompt (#18663)
authorXuan-Son Nguyen <redacted>
Fri, 9 Jan 2026 23:00:41 +0000 (00:00 +0100)
committerGitHub <redacted>
Fri, 9 Jan 2026 23:00:41 +0000 (00:00 +0100)
* server: fix n_cmpl not skipping processing

* fix infinite loop on empty batch

* cont : init child samplers + modify child logic

* cont : cleanup

* cont : improve n_cmpl logic

- launch the parent task first so it finds the slot with best cache
- parent task waits for child tasks to be launched
- when a child task finishes - remove its cache

* cont : remove redundant function

* cont : reduce parent checks

* fix : nullptr task dereference

---------

Co-authored-by: Georgi Gerganov <redacted>
tools/server/server-context.cpp
tools/server/server-task.h

index 324c3af30c14c5b34d1aa8fbad40ac8a57946d0b..af6e053424381c45fbd272c05feba5dfceeeb82e 100644 (file)
@@ -79,6 +79,8 @@ struct server_slot {
 
     common_speculative * spec = nullptr;
 
+    // TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state
+    //       see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837
     std::unique_ptr<const server_task> task;
     std::unique_ptr<const server_task> task_prev; // used for debugging
 
@@ -153,7 +155,7 @@ struct server_slot {
 
     common_sampler_ptr smpl;
 
-    llama_token sampled; // in speculative mode, this is the last accepted token
+    llama_token  sampled; // in speculative mode, this is the last accepted token
     llama_tokens drafted;
 
     // stats
@@ -201,12 +203,46 @@ struct server_slot {
         alora_invocation_start = -1;
     }
 
+    // remove cached prompt + tokens
+    void clear(bool allow_processing) {
+        if (!allow_processing) {
+            GGML_ASSERT(!is_processing());
+        }
+
+        SLT_INF(*this, "clearing slot with %zu tokens\n", prompt.tokens.size());
+
+        llama_memory_seq_rm(llama_get_memory(ctx), id, -1, -1);
+        prompt.tokens.clear();
+    }
+
+    void init_sampler() const {
+        const int64_t t_start = ggml_time_us();
+
+        common_sampler_reset(smpl.get());
+
+        int n_text = 0;
+
+        for (int i = 0; i < (int) prompt.tokens.size(); i++) {
+            const llama_token id = prompt.tokens[i];
+
+            if (id != LLAMA_TOKEN_NULL) {
+                common_sampler_accept(smpl.get(), id, false);
+                n_text++;
+            }
+        }
+
+        SLT_INF(*this, "init sampler, took %0.2f ms, tokens: text = %d, total = %d\n",
+                (ggml_time_us() - t_start) / 1000.0, n_text, (int) prompt.tokens.size());
+    }
+
+    // TODO: move to server_task
     bool need_embd() const {
         GGML_ASSERT(task);
 
         return server_task_type_need_embd(task->type);
     }
 
+    // TODO: move to server_task
     bool need_logits() const {
         GGML_ASSERT(task);
 
@@ -258,10 +294,13 @@ struct server_slot {
             SLT_WRN(*this, "%s", "slot is not processing\n");
             return;
         }
+
         generated_token_probs.push_back(token);
     }
 
     int get_n_draft_max() const {
+        GGML_ASSERT(task);
+
         if (!can_speculate()) {
             return 0;
         }
@@ -287,12 +326,14 @@ struct server_slot {
     }
 
     // note: a slot can also be either a parent or a child
+    // TODO: move to server_task
     bool is_parent() const {
-        return is_processing() && task->n_children > 0;
+        return task->n_children > 0;
     }
 
+    // TODO: move to server_task
     bool is_child() const {
-        return is_processing() && task->id_parent >= 0;
+        return task->id_parent >= 0;
     }
 
     void release() {
@@ -301,10 +342,16 @@ struct server_slot {
 
             SLT_INF(*this, "stop processing: n_tokens = %d, truncated = %d\n", prompt.n_tokens(), truncated);
 
-            t_last_used = ggml_time_us();
+            t_last_used        =  ggml_time_us();
             t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
+
             state = SLOT_STATE_IDLE;
 
+            // do not keep context of the child slots - the parent's context is enough
+            if (is_child()) {
+                clear(false);
+            }
+
             task_prev = std::move(task);
             task.reset();
 
@@ -425,14 +472,22 @@ struct server_slot {
     }
 
     void copy_state_to(server_slot & other) const {
-        llama_memory_seq_rm(llama_get_memory(ctx), other.id, 0, -1);
-        llama_memory_seq_cp(llama_get_memory(ctx), id, other.id, 0, -1);
+        GGML_ASSERT(state == SLOT_STATE_DONE_PROMPT);
+
+        llama_memory_seq_rm(llama_get_memory(ctx), other.id,     -1, -1);
+        llama_memory_seq_cp(llama_get_memory(ctx), id, other.id, -1, -1);
+
         other.n_decoded   = n_decoded;
         other.n_remaining = n_remaining;
         other.i_batch     = i_batch;
+
+        other.t_start_process_prompt    = t_start_process_prompt;
+        other.t_prompt_processing       = t_prompt_processing;
         other.n_prompt_tokens_cache     = n_prompt_tokens_cache;
         other.n_prompt_tokens_processed = n_prompt_tokens_processed;
+
         other.prompt = prompt.clone();
+        other.init_sampler();
     }
 };
 
@@ -745,6 +800,7 @@ private:
         }
 
         slots.clear();
+
         for (int i = 0; i < params_base.n_parallel; i++) {
             server_slot slot;
 
@@ -993,7 +1049,7 @@ private:
                 ret->prompt_save(*prompt_cache);
 
                 if (!ret->prompt_load(*prompt_cache, task.tokens)) {
-                    clear_slot(*ret);
+                    ret->clear(false);
                 }
 
                 prompt_cache->update();
@@ -1005,17 +1061,6 @@ private:
         return ret;
     }
 
-    void clear_slot(server_slot & slot, bool allow_processing = false) const {
-        if (!allow_processing) {
-            GGML_ASSERT(!slot.is_processing());
-        }
-
-        SLT_WRN(slot, "clearing slot with %zu tokens\n", slot.prompt.tokens.size());
-
-        llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
-        slot.prompt.tokens.clear();
-    }
-
     // return true if at least one slot has been cleared
     // TODO: improve logic
     //       - smarter decision which slot to clear (LRU or longest prompt?)
@@ -1036,7 +1081,7 @@ private:
             if (slot.prompt.n_tokens() > 0) {
                 SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());
 
-                clear_slot(slot);
+                slot.clear(false);
 
                 res = true;
 
@@ -1182,7 +1227,7 @@ private:
             ? SLOT_STATE_WAIT_OTHER // wait for the parent to process prompt
             : SLOT_STATE_STARTED;
 
-        SLT_INF(slot, "%s", "processing task\n");
+        SLT_INF(slot, "processing task, is_child = %d\n", slot.is_child());
 
         return true;
     }
@@ -1819,7 +1864,7 @@ private:
                     // Erase token cache
                     const size_t n_erased = slot->prompt.tokens.size();
 
-                    clear_slot(*slot);
+                    slot->clear(false);
 
                     auto res = std::make_unique<server_task_result_slot_erase>();
                     res->id       = task.id;
@@ -2053,8 +2098,29 @@ private:
                     continue;
                 }
 
+                // check if this is a child slot
+                if (slot.state == SLOT_STATE_WAIT_OTHER) {
+                    SLT_DBG(slot, "%s", "waiting for parent slot to complete\n");
+                    continue;
+                }
+
                 // this slot still has a prompt to be processed
                 if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
+                    // wait for all children to be launched
+                    if (slot.is_parent()) {
+                        int n_launched = 0;
+                        for (auto & other : slots) {
+                            if (other.is_processing() && other.is_child() && other.task->id_parent == slot.task->id) {
+                                ++n_launched;
+                            }
+                        }
+
+                        if (n_launched < slot.task->n_children) {
+                            SLT_DBG(slot, "waiting for children to be launched, n_children = %d, n_launched = %d\n", slot.task->n_children, n_launched);
+                            continue;
+                        }
+                    }
+
                     const auto & input_tokens = slot.task->tokens;
 
                     // TODO: maybe move branch to outside of this loop in the future
@@ -2355,7 +2421,7 @@ private:
                     if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
                         SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);
 
-                        clear_slot(slot, /*allow_processing=*/true);
+                        slot.clear(true);
 
                         // there is no common part left
                         slot.n_prompt_tokens_cache = 0;
@@ -2455,16 +2521,6 @@ private:
 
                         GGML_ASSERT(batch.n_tokens > 0);
 
-                        common_sampler_reset(slot.smpl.get());
-
-                        // Process all prompt tokens through sampler system
-                        for (int i = 0; i < slot.task->n_tokens(); ++i) {
-                            llama_token id = input_tokens[i];
-                            if (id != LLAMA_TOKEN_NULL) {
-                                common_sampler_accept(slot.smpl.get(), id, false);
-                            }
-                        }
-
                         // extract the logits only for the last token
                         batch.logits[batch.n_tokens - 1] = true;
 
@@ -2473,6 +2529,8 @@ private:
 
                         SLT_INF(slot, "prompt done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);
 
+                        slot.init_sampler();
+
                         const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
                         const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
 
@@ -2519,11 +2577,6 @@ private:
             }
         }
 
-        if (batch.n_tokens == 0) {
-            SRV_WRN("%s", "no tokens to decode\n");
-            return;
-        }
-
         SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
 
         if (slot_batched) {
@@ -2540,6 +2593,10 @@ private:
             llama_set_embeddings(ctx, slot_batched->need_embd());
         }
 
+        if (batch.n_tokens == 0) {
+            SRV_WRN("%s", "no tokens to decode\n");
+        }
+
         int32_t i_next = 0;
 
         // process the created batch of tokens
@@ -2591,7 +2648,7 @@ private:
 
                                 // note: it's complicated to keep track of how much of the current batch has been
                                 //       processed before the error occurred, so we simply clear the entire context
-                                clear_slot(slot);
+                                slot.clear(false);
                             }
                         }
 
@@ -2615,27 +2672,34 @@ private:
             // on successful decode, restore the original batch size
             n_batch = llama_n_batch(ctx);
 
+            // handle `n_cmpl > 1` tasks - when the main prompt is processed, activate all child tasks too
             for (auto & slot : slots) {
-                // may need to copy state to other slots
                 if (slot.state == SLOT_STATE_DONE_PROMPT && slot.is_parent()) {
-                    std::vector<server_slot *> child_slots;
+                    SLT_INF(slot, "parent task prompt done, n_children = %d\n", slot.task->n_children);
+
+                    std::vector<server_slot *> children;
                     for (auto & other : slots) {
                         if (other.state == SLOT_STATE_WAIT_OTHER && slot.task->id == other.task->id_parent) {
-                            child_slots.push_back(&other);
+                            children.push_back(&other);
                         }
                     }
 
                     // we can only proceed if all child slots are having the correct tasks
-                    if (child_slots.size() == slot.task->n_children) {
+                    if (slot.task->n_children == (int) children.size()) {
                         // copy state to the child slots
-                        for (auto & child : child_slots) {
-                            SLT_INF(slot, "copying state to child %d\n", child->id);
+                        for (auto & child : children) {
+                            SLT_INF(slot, " - copying state to child %d\n", child->id);
+
+                            GGML_ASSERT(child->state == SLOT_STATE_WAIT_OTHER);
+
                             slot.copy_state_to(*child);
                             child->state = SLOT_STATE_DONE_PROMPT;
                         }
                     }
                 }
+            }
 
+            for (auto & slot : slots) {
                 // optionally send prompt processing progress
                 if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) {
                     if (slot.task->params.stream && slot.task->params.return_progress) {
@@ -2720,7 +2784,7 @@ private:
                     continue;
                 }
 
-                size_t n_draft = slot.drafted.size();
+                const size_t n_draft = slot.drafted.size();
 
                 // the accepted tokens from the speculation
                 const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx, slot.i_batch_dft, slot.drafted);
@@ -2923,9 +2987,11 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
             task.params.oaicompat_cmpl_id = completion_id;
             task.params.oaicompat_model   = meta->model_name;
 
+            // prepare child tasks
             if (task.params.n_cmpl > 1) {
                 task.n_children = task.params.n_cmpl - 1;
-                for (size_t j = 0; j < task.n_children; j++) {
+
+                for (int j = 0; j < task.n_children; j++) {
                     server_task child = task.create_child(task.id, rd.get_new_id());
 
                     // use different sampling seed for each child
@@ -2938,7 +3004,8 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
                 }
             }
 
-            tasks.push_back(std::move(task));
+            // note: the parent task always launches first
+            tasks.insert(tasks.begin(), std::move(task));
         }
 
         rd.post_tasks(std::move(tasks));
index ead149118214c2f8b268a2f39d358361f79dc9d5..cf08fced631a01c623dbec02bac38aa3ee478692 100644 (file)
@@ -121,8 +121,8 @@ struct server_task {
     int id_slot   = -1;
 
     // used by parallel sampling (multiple completions from same prompt)
-    size_t n_children =  0; // number of tasks reusing this prompt
-    int    id_parent  = -1;
+    int n_children =  0; // number of tasks reusing this prompt
+    int id_parent  = -1;
 
     // used by SERVER_TASK_TYPE_INFERENCE
     task_params   params;
@@ -173,11 +173,13 @@ struct server_task {
 
     server_task create_child(int id_parent, int id_child) const {
         server_task copy;
+
         copy.id        = id_child;
         copy.id_parent = id_parent;
         copy.params    = params;
         copy.type      = type;
         copy.tokens    = tokens.clone();
+
         return copy;
     }