double predicted_per_token_ms;
double predicted_per_second;
+ // Optional speculative metrics - only included when > 0
+ int32_t draft_n = 0;
+ int32_t draft_n_accepted = 0;
+
// Serialize the timing stats to JSON. Speculative-decoding fields
// (draft_n / draft_n_accepted) are emitted only when at least one draft
// token was generated, keeping the payload backward compatible for
// clients that do not know about speculative decoding.
json to_json() const {
- return {
+ json base = {
{"prompt_n", prompt_n},
{"prompt_ms", prompt_ms},
{"prompt_per_token_ms", prompt_per_token_ms},
{"predicted_per_token_ms", predicted_per_token_ms},
{"predicted_per_second", predicted_per_second},
};
+
+ // draft_n > 0 is the "speculation actually ran" signal set by the
+ // caller that populates these fields (see the n_draft_total guard).
+ if (draft_n > 0) {
+ base["draft_n"] = draft_n;
+ base["draft_n_accepted"] = draft_n_accepted;
+ }
+
+ return base;
}
};
std::function<void(int)> callback_on_release;
+ // Speculative decoding stats
+ int32_t n_draft_total = 0; // Total draft tokens generated
+ int32_t n_draft_accepted = 0; // Draft tokens actually accepted
+
// Reset per-request slot state so the slot can be reused for the next
// task: clears generated token buffers and the speculative-decoding
// counters accumulated during the previous request.
void reset() {
SLT_DBG(*this, "%s", "\n");
generated_tokens.clear();
generated_token_probs.clear();
+
+ // clear speculative decoding stats so acceptance-rate reporting
+ // never mixes counts from a previous request
+ n_draft_total = 0;
+ n_draft_accepted = 0;
}
bool is_non_causal() const {
timings.predicted_per_token_ms = t_token_generation / n_decoded;
timings.predicted_per_second = 1e3 / t_token_generation * n_decoded;
+ // Add speculative metrics
+ if (n_draft_total > 0) {
+ timings.draft_n = n_draft_total;
+ timings.draft_n_accepted = n_draft_accepted;
+ }
+
return timings;
}
t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second,
t_token_generation, n_decoded, t_gen, n_gen_second,
t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded);
+
+ if (n_draft_total > 0) {
+ const float draft_ratio = (float) n_draft_accepted / n_draft_total;
+ SLT_INF(*this,
+ "\n"
+ "draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n",
+ draft_ratio, n_draft_accepted, n_draft_total
+ );
+ }
}
json to_json() const {
llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
+ // keep track of total number of tokens generated in the draft
+ slot.n_draft_total += draft.size();
+
// ignore small drafts
if (slot.params.speculative.n_min > (int) draft.size()) {
SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
slot.n_past += ids.size();
slot.n_decoded += ids.size();
+ // update how many tokens out of draft was accepted
+ slot.n_draft_accepted += ids.size() - 1;
+
slot.cache_tokens.push_back(id);
slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);