]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : implement prompt processing progress report in stream mode (#15827)
authorXuan-Son Nguyen <redacted>
Sat, 6 Sep 2025 11:35:04 +0000 (18:35 +0700)
committerGitHub <redacted>
Sat, 6 Sep 2025 11:35:04 +0000 (13:35 +0200)
* server : implement `return_progress`

* add timings.cache_n

* add progress.time_ms

* add test

* fix test for chat/completions

* readme: add docs on timings

* use ggml_time_us

Co-authored-by: Georgi Gerganov <redacted>
---------

Co-authored-by: Georgi Gerganov <redacted>
tools/server/README.md
tools/server/server.cpp
tools/server/tests/unit/test_chat_completion.py

index b0527f3cbea285dacae84ed3dd71827d9c07e2b1..73b4cc6f03a28ec895223be1e01866c26772f7db 100644 (file)
@@ -512,6 +512,8 @@ These words will not be included in the completion, so make sure to add them to
 
 `timings_per_token`: Include prompt processing and text generation speed information in each response.  Default: `false`
 
+`return_progress`: Include prompt processing progress in `stream` mode. The progress will be contained inside `prompt_progress` with 3 values: `total`, `cache` and `processed`. The overall progress is `processed/total`, while the actual timed progress is `(processed-cache)/(total-cache)`. Default: `false`
+
 `post_sampling_probs`: Returns the probabilities of top `n_probs` tokens after applying sampling chain.
 
 `response_fields`: A list of response fields, for example: `"response_fields": ["content", "generation_settings/n_predict"]`. If the specified field is missing, it will simply be omitted from the response without triggering an error. Note that fields with a slash will be unnested; for example, `generation_settings/n_predict` will move the field `n_predict` from the `generation_settings` object to the root of the response and give it a new name.
@@ -1276,6 +1278,34 @@ curl http://localhost:8080/v1/chat/completions \
 
 **See our [Function calling](../../docs/function-calling.md) docs** for more details, supported native tool call styles (generic tool call style is used as fallback) / examples of use.
 
+*Timings and context usage*
+
+The response contains a `timings` object, for example:
+
+```js
+{
+  "choices": [],
+  "created": 1757141666,
+  "id": "chatcmpl-ecQULm0WqPrftUqjPZO1CFYeDjGZNbDu",
+  // ...
+  "timings": {
+    "cache_n": 236, // number of prompt tokens reused from cache
+    "prompt_n": 1, // number of prompt tokens being processed
+    "prompt_ms": 30.958,
+    "prompt_per_token_ms": 30.958,
+    "prompt_per_second": 32.301828283480845,
+    "predicted_n": 35, // number of predicted tokens
+    "predicted_ms": 661.064,
+    "predicted_per_token_ms": 18.887542857142858,
+    "predicted_per_second": 52.94494935437416
+  }
+}
+```
+
+This provides information on the performance of the server. It also allows calculating the current context usage.
+
+The total number of tokens in context is equal to `prompt_n + cache_n + predicted_n`
+
 ### POST `/v1/embeddings`: OpenAI-compatible embeddings API
 
 This endpoint requires that the model uses a pooling different than type `none`. The embeddings are normalized using the Eucledian norm.
index 73fc43bada5437a5a9c765ecbf7b01cc36763982..70aa18756465f0b0ade4c466bd4ca1c94ea279c1 100644 (file)
@@ -110,9 +110,10 @@ static bool server_task_type_need_logits(server_task_type task_type) {
 }
 
 struct slot_params {
-    bool stream        = true;
-    bool cache_prompt  = true; // remember the prompt to avoid reprocessing all prompt
-    bool return_tokens = false;
+    bool stream          = true;
+    bool cache_prompt    = true; // remember the prompt to avoid reprocessing all prompt
+    bool return_tokens   = false;
+    bool return_progress = false;
 
     int32_t n_keep    =  0; // number of tokens to keep from initial prompt
     int32_t n_discard =  0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
@@ -307,11 +308,11 @@ struct server_task {
 
         // enabling this will output extra debug information in the HTTP responses from the server
         params.verbose           = params_base.verbosity > 9;
-        params.timings_per_token = json_value(data, "timings_per_token", false);
 
         params.stream           = json_value(data, "stream",             false);
         params.cache_prompt     = json_value(data, "cache_prompt",       true);
         params.return_tokens    = json_value(data, "return_tokens",      false);
+        params.return_progress  = json_value(data, "return_progress",    false);
         params.n_predict        = json_value(data, "n_predict",          json_value(data, "max_tokens", defaults.n_predict));
         params.n_indent         = json_value(data, "n_indent",           defaults.n_indent);
         params.n_keep           = json_value(data, "n_keep",             defaults.n_keep);
@@ -608,6 +609,8 @@ struct server_task {
 };
 
 struct result_timings {
+    int32_t cache_n = -1;
+
     int32_t prompt_n = -1;
     double prompt_ms;
     double prompt_per_token_ms;
@@ -624,6 +627,8 @@ struct result_timings {
 
     json to_json() const {
         json base = {
+            {"cache_n",                cache_n},
+
             {"prompt_n",               prompt_n},
             {"prompt_ms",              prompt_ms},
             {"prompt_per_token_ms",    prompt_per_token_ms},
@@ -644,6 +649,22 @@ struct result_timings {
     }
 };
 
+struct result_prompt_progress {
+    int32_t total = 0;
+    int32_t cache = 0;
+    int32_t processed = 0;
+    int64_t time_ms = 0;
+
+    json to_json() const {
+        return json {
+            {"total",     total},
+            {"cache",     cache},
+            {"processed", processed},
+            {"time_ms",   time_ms},
+        };
+    }
+};
+
 struct server_task_result {
     int id           = -1;
     int id_slot      = -1;
@@ -999,8 +1020,10 @@ struct server_task_result_cmpl_partial : server_task_result {
     int32_t n_prompt_tokens;
 
     bool post_sampling_probs;
+    bool is_progress = false;
     completion_token_output prob_output;
     result_timings timings;
+    result_prompt_progress progress;
 
     // OAI-compat fields
     bool            verbose   = false;
@@ -1045,6 +1068,9 @@ struct server_task_result_cmpl_partial : server_task_result {
         if (timings.prompt_n > 0) {
             res.push_back({"timings", timings.to_json()});
         }
+        if (is_progress) {
+            res.push_back({"prompt_progress", progress.to_json()});
+        }
         if (!prob_output.probs.empty()) {
             res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs);
         }
@@ -1082,6 +1108,9 @@ struct server_task_result_cmpl_partial : server_task_result {
         if (timings.prompt_n >= 0) {
             res.push_back({"timings", timings.to_json()});
         }
+        if (is_progress) {
+            res.push_back({"prompt_progress", progress.to_json()});
+        }
 
         return res;
     }
@@ -1109,7 +1138,7 @@ struct server_task_result_cmpl_partial : server_task_result {
             });
         };
         // We have to send an initial update to conform to openai behavior
-        if (first) {
+        if (first || is_progress) {
             add_delta({
                 {"role", "assistant"},
                 {"content", nullptr},
@@ -1121,16 +1150,20 @@ struct server_task_result_cmpl_partial : server_task_result {
         }
 
         if (!deltas.empty()) {
-            GGML_ASSERT(deltas[deltas.size() - 1].at("choices").size() >= 1);
+            auto & last_json = deltas[deltas.size() - 1];
+            GGML_ASSERT(last_json.at("choices").size() >= 1);
 
             if (prob_output.probs.size() > 0) {
-                deltas[deltas.size() - 1].at("choices").at(0)["logprobs"] = json {
+                last_json.at("choices").at(0)["logprobs"] = json {
                     {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
                 };
             }
 
             if (timings.prompt_n >= 0) {
-                deltas[deltas.size() - 1].push_back({"timings", timings.to_json()});
+                last_json.push_back({"timings", timings.to_json()});
+            }
+            if (is_progress) {
+                last_json.push_back({"prompt_progress", progress.to_json()});
             }
         }
 
@@ -1404,6 +1437,7 @@ struct server_slot {
 
     // n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated
     int32_t n_prompt_tokens           = 0;
+    int32_t n_prompt_tokens_cache     = 0;
     int32_t n_prompt_tokens_processed = 0;
 
     // input prompt tokens
@@ -1456,7 +1490,9 @@ struct server_slot {
     void reset() {
         SLT_DBG(*this, "%s", "\n");
 
-        n_prompt_tokens    = 0;
+        n_prompt_tokens       = 0;
+        n_prompt_tokens_cache = 0;
+
         last_nl_pos        = 0;
         generated_text     = "";
         has_new_line       = false;
@@ -1547,6 +1583,8 @@ struct server_slot {
 
     result_timings get_timings() const {
         result_timings timings;
+        timings.cache_n = n_prompt_tokens_cache;
+
         timings.prompt_n = n_prompt_tokens_processed;
         timings.prompt_ms = t_prompt_processing;
         timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed;
@@ -2520,7 +2558,7 @@ struct server_context {
 
             slot.add_token(result);
             if (slot.params.stream) {
-                send_partial_response(slot, result);
+                send_partial_response(slot, result, false);
             }
         }
 
@@ -2712,13 +2750,24 @@ struct server_context {
         return true;
     }
 
-    void send_partial_response(server_slot & slot, const completion_token_output & tkn) {
+    void send_partial_response(server_slot & slot, const completion_token_output & tkn, bool is_progress) {
         auto res = std::make_unique<server_task_result_cmpl_partial>();
 
-        res->id      = slot.id_task;
-        res->index   = slot.index;
-        res->content = tkn.text_to_send;
-        res->tokens  = { tkn.tok };
+        res->id    = slot.id_task;
+        res->index = slot.index;
+
+        if (is_progress) {
+            res->is_progress        = true;
+            res->progress.total     = slot.n_prompt_tokens;
+            res->progress.cache     = slot.n_prompt_tokens_cache;
+            res->progress.processed = slot.cache_tokens.size();
+            res->progress.time_ms   = (ggml_time_us() - slot.t_start_process_prompt / 1000);
+        } else {
+            res->content = tkn.text_to_send;
+            res->tokens  = { tkn.tok };
+
+            slot.update_chat_msg(res->oaicompat_msg_diffs);
+        }
 
         res->n_decoded           = slot.n_decoded;
         res->n_prompt_tokens     = slot.n_prompt_tokens;
@@ -2729,8 +2778,6 @@ struct server_context {
         res->oaicompat_model       = slot.params.oaicompat_model;
         res->oaicompat_cmpl_id     = slot.params.oaicompat_cmpl_id;
 
-        slot.update_chat_msg(res->oaicompat_msg_diffs);
-
         // populate res.probs_output
         if (slot.params.sampling.n_probs > 0) {
             res->prob_output = tkn; // copy the token probs
@@ -3557,6 +3604,7 @@ struct server_context {
                             slot.n_past--;
                         }
 
+                        slot.n_prompt_tokens_cache     = slot.n_past;
                         slot.n_prompt_tokens_processed = 0;
                     }
 
@@ -3573,7 +3621,8 @@ struct server_context {
                         llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
 
                         // there is no common part left
-                        slot.n_past = 0;
+                        slot.n_past                = 0;
+                        slot.n_prompt_tokens_cache = 0;
                     }
 
                     SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
@@ -3767,6 +3816,13 @@ struct server_context {
             n_batch = llama_n_batch(ctx);
 
             for (auto & slot : slots) {
+                // optionally send prompt processing progress
+                if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) {
+                    if (slot.params.stream && slot.params.return_progress) {
+                        send_partial_response(slot, {}, true);
+                    }
+                }
+
                 if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
                     continue; // continue loop of slots
                 }
index 22dbfdd9b4a8901091d8be5882320d1eaaa48ea4..53421d1b57351bbe2d7b1fbb9cc4b1dcdfc9546d 100644 (file)
@@ -402,3 +402,51 @@ def test_context_size_exceeded():
     assert server.n_ctx is not None
     assert server.n_slots is not None
     assert res.body["error"]["n_ctx"] == server.n_ctx // server.n_slots
+
+
+@pytest.mark.parametrize(
+    "n_batch,batch_count,reuse_cache",
+    [
+        (64, 15, False),
+        (64, 1, True),
+    ]
+)
+def test_return_progresssss(n_batch, batch_count, reuse_cache):
+    global server
+    server.n_batch = n_batch
+    server.n_ctx = 2048
+    server.n_slots = 1
+    server.start()
+    def make_cmpl_request():
+        return server.make_stream_request("POST", "/chat/completions", data={
+            "max_tokens": 10,
+            "messages": [
+                {"role": "user", "content": "This is a test" * 100},
+            ],
+            "stream": True,
+            "return_progress": True,
+        })
+    if reuse_cache:
+        # make a first request to populate the cache
+        res0 = make_cmpl_request()
+        for _ in res0:
+            pass # discard the output
+
+    res = make_cmpl_request()
+    last_progress = None
+    total_batch_count = 0
+    for data in res:
+        cur_progress = data.get("prompt_progress", None)
+        if cur_progress is None:
+            continue
+        if last_progress is not None:
+            assert cur_progress["total"] == last_progress["total"]
+            assert cur_progress["cache"] == last_progress["cache"]
+            assert cur_progress["processed"] > last_progress["processed"]
+        total_batch_count += 1
+        last_progress = cur_progress
+
+    assert last_progress is not None
+    assert last_progress["total"] > 0
+    assert last_progress["processed"] == last_progress["total"]
+    assert total_batch_count == batch_count