]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : include speculative decoding stats when timings_per_token is enabled (#12603)
authorBenson Wong <redacted>
Fri, 28 Mar 2025 08:05:44 +0000 (01:05 -0700)
committerGitHub <redacted>
Fri, 28 Mar 2025 08:05:44 +0000 (10:05 +0200)
* Include speculative decoding stats when timings_per_token is true

New fields added to the `timings` object:

  - draft_n           : number of draft tokens generated
  - draft_accepted_n  : number of draft tokens accepted
  - draft_accept_ratio: ratio of accepted/generated

* Remove redundant draft_accept_ratio var

* add draft acceptance rate to server console output

examples/server/server.cpp

index 77dd316d9d68911969601ee2c09a277e5d138b6d..17a292da153c1f4f2dcb5b1df50c2419cae01f86 100644 (file)
@@ -489,8 +489,12 @@ struct result_timings {
     double predicted_per_token_ms;
     double predicted_per_second;
 
+    // Optional speculative metrics - only included when > 0
+    int32_t draft_n = 0;
+    int32_t draft_n_accepted = 0;
+
     json to_json() const {
-        return {
+        json base = {
             {"prompt_n",               prompt_n},
             {"prompt_ms",              prompt_ms},
             {"prompt_per_token_ms",    prompt_per_token_ms},
@@ -501,6 +505,13 @@ struct result_timings {
             {"predicted_per_token_ms", predicted_per_token_ms},
             {"predicted_per_second",   predicted_per_second},
         };
+
+        if (draft_n > 0) {
+            base["draft_n"] = draft_n;
+            base["draft_n_accepted"] = draft_n_accepted;
+        }
+
+        return base;
     }
 };
 
@@ -1299,6 +1310,10 @@ struct server_slot {
 
     std::function<void(int)> callback_on_release;
 
+    // Speculative decoding stats
+    int32_t n_draft_total = 0;      // Total draft tokens generated
+    int32_t n_draft_accepted = 0;   // Draft tokens actually accepted
+
     void reset() {
         SLT_DBG(*this, "%s", "\n");
 
@@ -1315,6 +1330,10 @@ struct server_slot {
 
         generated_tokens.clear();
         generated_token_probs.clear();
+
+        // clear speculative decoding stats
+        n_draft_total = 0;
+        n_draft_accepted = 0;
     }
 
     bool is_non_causal() const {
@@ -1381,6 +1400,12 @@ struct server_slot {
         timings.predicted_per_token_ms = t_token_generation / n_decoded;
         timings.predicted_per_second = 1e3 / t_token_generation * n_decoded;
 
+        // Add speculative metrics
+        if (n_draft_total > 0) {
+            timings.draft_n = n_draft_total;
+            timings.draft_n_accepted = n_draft_accepted;
+        }
+
         return timings;
     }
 
@@ -1428,6 +1453,15 @@ struct server_slot {
                 t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second,
                 t_token_generation, n_decoded, t_gen, n_gen_second,
                 t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded);
+
+        if (n_draft_total > 0) {
+            const float draft_ratio = (float) n_draft_accepted / n_draft_total;
+            SLT_INF(*this,
+                    "\n"
+                    "draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n",
+                    draft_ratio, n_draft_accepted, n_draft_total
+            );
+        }
     }
 
     json to_json() const {
@@ -3290,6 +3324,9 @@ struct server_context {
 
                 llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
 
+                // keep track of total number of tokens generated in the draft
+                slot.n_draft_total += draft.size();
+
                 // ignore small drafts
                 if (slot.params.speculative.n_min > (int) draft.size()) {
                     SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
@@ -3315,6 +3352,9 @@ struct server_context {
                 slot.n_past    += ids.size();
                 slot.n_decoded += ids.size();
 
+                // update how many tokens out of draft was accepted
+                slot.n_draft_accepted += ids.size() - 1;
+
                 slot.cache_tokens.push_back(id);
                 slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);