]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server: Add "tokens per second" information in the backend (#10548)
authorhaopeng <redacted>
Mon, 2 Dec 2024 13:45:54 +0000 (21:45 +0800)
committerGitHub <redacted>
Mon, 2 Dec 2024 13:45:54 +0000 (14:45 +0100)
* add cmake rvv support

* add timings

* remove space

* update readme

* fix

* fix code

* remove empty line

* add test

---------

Co-authored-by: Xuan Son Nguyen <redacted>
common/common.h
examples/server/README.md
examples/server/server.cpp
examples/server/tests/unit/test_chat_completion.py
examples/server/utils.hpp

index 9b1508a15fb43840bca968bc2374dc2e674871e6..0373fd3ead49ee60a3fb6b01f1e963283861a860 100644 (file)
@@ -133,6 +133,7 @@ struct common_params_sampling {
     bool    penalize_nl        = false; // consider newlines as a repeatable token
     bool    ignore_eos         = false;
     bool    no_perf            = false; // disable performance metrics
+    bool    timing_per_token   = false;
 
     std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"};     // default sequence breakers for DRY
 
index 877768c8b0bd2f901317a6f1c55efa541011352a..45ffb547fcbccb6b9fcd98290dcf68e960f3a4d4 100644 (file)
@@ -416,6 +416,8 @@ node index.js
 
     `samplers`: The order the samplers should be applied in. An array of strings representing sampler type names. If a sampler is not set, it will not be used. If a sampler is specified more than once, it will be applied multiple times. Default: `["dry", "top_k", "typ_p", "top_p", "min_p", "xtc", "temperature"]` - these are all the available values.
 
+    `timings_per_token`: Include prompt processing and text generation speed information in each response.  Default: `false`
+
 **Response format**
 
 - Note: When using streaming mode (`stream`), only `content` and `stop` will be returned until end of completion.
index 1c765f0ea22ecce854a1a5af1ca8e767c91ccd09..8eca14b86d517589b922ab03c295b587054810ef 100644 (file)
@@ -177,6 +177,8 @@ struct server_slot {
     bool stopped_word   = false;
     bool stopped_limit  = false;
 
+    bool timings_per_token = false;
+
     bool oaicompat = false;
 
     std::string oaicompat_model;
@@ -882,6 +884,8 @@ struct server_context {
             slot.oaicompat_model = "";
         }
 
+        slot.timings_per_token       = json_value(data, "timings_per_token",  false);
+
         slot.params.stream           = json_value(data, "stream",             false);
         slot.params.cache_prompt     = json_value(data, "cache_prompt",       true);
         slot.params.n_predict        = json_value(data, "n_predict",          json_value(data, "max_tokens", defaults.n_predict));
@@ -1279,6 +1283,7 @@ struct server_context {
             {"speculative.n_max",         slot.params.speculative.n_max},
             {"speculative.n_min",         slot.params.speculative.n_min},
             {"speculative.p_min",         slot.params.speculative.p_min},
+            {"timings_per_token",         slot.timings_per_token},
         };
     }
 
@@ -1336,6 +1341,10 @@ struct server_context {
             res.data["model"] = slot.oaicompat_model;
         }
 
+        if (slot.timings_per_token) {
+            res.data["timings"] = slot.get_formated_timings();
+        }
+
         queue_results.send(res);
     }
 
@@ -2274,12 +2283,17 @@ struct server_context {
                 common_sampler_accept(slot.smpl, id, true);
 
                 slot.n_decoded += 1;
+
+                const int64_t t_current = ggml_time_us();
+
                 if (slot.n_decoded == 1) {
-                    slot.t_start_generation = ggml_time_us();
+                    slot.t_start_generation = t_current;
                     slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
                     metrics.on_prompt_eval(slot);
                 }
 
+                slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3;
+
                 completion_token_output result;
                 result.tok = id;
 
index 1048d6fcaf500ad0bd20747f6b277afd2cb75400..8a439f9ef0f29d5807e741824022187a24690c98 100644 (file)
@@ -146,3 +146,20 @@ def test_invalid_chat_completion_req(messages):
     })
     assert res.status_code == 400 or res.status_code == 500
     assert "error" in res.body
+
+
+def test_chat_completion_with_timings_per_token():
+    global server
+    server.start()
+    res = server.make_stream_request("POST", "/chat/completions", data={
+        "max_tokens": 10,
+        "messages": [{"role": "user", "content": "test"}],
+        "stream": True,
+        "timings_per_token": True,
+    })
+    for data in res:
+        assert "timings" in data
+        assert "prompt_per_second" in data["timings"]
+        assert "predicted_per_second" in data["timings"]
+        assert "predicted_n" in data["timings"]
+        assert data["timings"]["predicted_n"] <= 10
index 1665e9dc37db6bec86639cdd6d4a6224af5ab17a..e4451532c9d0cedbaf7a5c277c58aeab0b76edef 100644 (file)
@@ -650,6 +650,10 @@ static json format_final_response_oaicompat(const json & request, const json & r
         res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
     }
 
+    if (result.contains("timings")) {
+        res.push_back({"timings", json_value(result, "timings", json::object())});
+    }
+
     return res;
 }
 
@@ -740,6 +744,11 @@ static std::vector<json> format_partial_response_oaicompat(const json & result,
         {"model",   modelname},
         {"object",  "chat.completion.chunk"}
     };
+
+    if (result.contains("timings")) {
+        ret.push_back({"timings", json_value(result, "timings", json::object())});
+    }
+
     if (!finish_reason.empty()) {
         int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
         int num_prompt_tokens    = json_value(result, "tokens_evaluated", 0);