]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
spec : remove check rate (#19377)
authorSascha Rogmann <redacted>
Mon, 9 Feb 2026 13:30:50 +0000 (14:30 +0100)
committerGitHub <redacted>
Mon, 9 Feb 2026 13:30:50 +0000 (15:30 +0200)
* spec: remove parameter spec-ngram-check-rate

* spec : renamed statistics vars

* spec : add n_call_begin, n_call_accept

* spec : don't enable key-map-stats

common/arg.cpp
common/common.h
common/ngram-map.cpp
common/ngram-map.h
common/speculative.cpp
docs/speculative.md
tools/server/server-task.cpp

index 5fbc9022c02934a6d033508ec614fbb891729b7d..9c85696ebdbd85284df285152d22846dbab048f6 100644 (file)
@@ -3437,16 +3437,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.speculative.ngram_size_m = value;
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}));
-    add_opt(common_arg(
-        {"--spec-ngram-check-rate"}, "N",
-        string_format("ngram check rate for ngram-simple/ngram-map speculative decoding (default: %d)", params.speculative.ngram_check_rate),
-        [](common_params & params, int value) {
-            if (value < 1) {
-                throw std::invalid_argument("ngram check rate must be at least 1");
-            }
-            params.speculative.ngram_check_rate = value;
-        }
-    ).set_examples({LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"--spec-ngram-min-hits"}, "N",
         string_format("minimum hits for ngram-map speculative decoding (default: %d)", params.speculative.ngram_min_hits),
index 398ebb09601c3745543bc98c72ee8da8f5db05d1..b284244530a19cdf40d604a6eb4ae3b12a15fd0c 100644 (file)
@@ -269,7 +269,6 @@ struct common_params_speculative {
 
     uint16_t ngram_size_n     = 12; // ngram size for lookup
     uint16_t ngram_size_m     = 48; // mgram size for speculative tokens
-    uint16_t ngram_check_rate =  1; // check rate for ngram lookup
     uint16_t ngram_min_hits   =  1; // minimum hits at ngram/mgram lookup for mgram to be proposed
 
     std::shared_ptr<common_ngram_mod> ngram_mod;
index c5b8fc75ed8bf6dfc2e3898758d1fc1b6fd81c82..2b876a6e991b444e073f332c4605f306996bb0e2 100644 (file)
@@ -231,10 +231,9 @@ void common_ngram_map_draft(common_ngram_map & map,
         GGML_ABORT("%s: cur_len exceeds UINT32_MAX: %zu", __func__, cur_len);
     }
 
-    // Only check every check_rate tokens to save compute
-    // i.e., perform check if (cur_len - idx_last_check) >= check_rate
-    if (map.idx_last_check + map.check_rate > cur_len) {
-        return;
+    if (map.idx_last_check  > cur_len) {
+        // Should not happen because of common_ngram_map_begin().
+        GGML_ABORT("%s: map.idx_last_check > cur_len: %zu > %zu", __func__, map.idx_last_check, cur_len);
     }
     map.idx_last_check = cur_len;
 
index 9668bd5a7c5458170e3b51496c3a11fab7747ab6..41b95304497b4bfc829466d6053ae2b4ce5c38cf 100644 (file)
@@ -24,7 +24,6 @@
 struct common_ngram_simple_config {
     uint16_t   size_ngram;      // size of n-grams to lookup in self-mode
     uint16_t   size_mgram;      // size of m-grams to draft in self-mode
-    uint16_t   check_rate;      // check for speculative decoding without draft model for each check_rate token
 };
 
 // Searches for a n-gram in the history and checks whether a draft sequence should be generated.
@@ -66,15 +65,14 @@ struct common_ngram_map {
     bool key_only;       // true if only key n-grams are used, no values.
 
     std::vector<common_ngram_map_key> keys; // key n-grams which occur several times in token-history
-    uint16_t check_rate; // check for speculative decoding without draft model for each check_rate token
     uint16_t min_hits;   // minimum number of key hits to consider a draft
 
-    bool     show_key_map_stats = false; // true, if statitics of the key_map should be printed.
+    bool     show_key_map_stats = false; // true, if statistics of the key_map should be printed.
 
     common_ngram_map(uint16_t sz_key, uint16_t sz_value, bool only_keys,
-                     uint16_t check_rate, uint16_t min_hits)
+                     uint16_t min_hits)
         : size_key(sz_key), size_value(sz_value), key_only(only_keys),
-          check_rate(check_rate), min_hits(min_hits) {
+          min_hits(min_hits) {
         key_map.resize(COMMON_NGRAM_HASH_MAP_SIZE); // 2^18 hash entries, 0 entries if key_map shouldn't be used
     }
 
index 84d2556cebaa95ef371d31e362c7c13ece626a71..3e68c38e49cc972d06f0cb1568aff054b781038c 100644 (file)
@@ -113,13 +113,14 @@ static bool common_speculative_are_compatible(
 struct common_speculative_state {
     const enum common_speculative_type type;
 
-    // TODO: rename to n_call_draft, n_gen_drafts, n_acc_drafts, n_gen_tokens, n_acc_tokens
-    // TODO: add n_call_begin, n_call_accept
-    size_t drafts_call_count       = 0; // number of times this implementation was called.
-    size_t drafts_generated_count  = 0; // number of times a draft or part was generated by this implementation.
-    size_t drafts_accepted_count   = 0; // number of times a draft or part was accepted by the target model.
-    size_t drafts_generated_tokens = 0; // number of tokens generated by this implementation.
-    size_t drafts_accepted_tokens  = 0; // number of tokens accepted by the target model.
+    size_t n_call_begin  = 0; // number of times this implementation was called for refresh.
+    size_t n_call_draft  = 0; // number of times this implementation was called for generation.
+    size_t n_call_accept = 0; // number of times this implementation was called for accumulation.
+
+    size_t n_gen_drafts = 0; // number of times a draft or part was generated by this implementation.
+    size_t n_acc_drafts = 0; // number of times a draft or part was accepted by the target model.
+    size_t n_gen_tokens = 0; // number of tokens generated by this implementation.
+    size_t n_acc_tokens = 0; // number of tokens accepted by the target model.
 
     // TODO: track performance of most recent calls
     const bool gen_perf = true; // whether to generate performance stats.
@@ -465,8 +466,6 @@ struct common_speculative_state_eagle3 : public common_speculative_state {
 struct common_speculative_state_ngram_simple : public common_speculative_state {
     common_ngram_simple_config config;
 
-    uint16_t check_id = 0; // used to control the frequency of generating drafts
-
     common_speculative_state_ngram_simple(
             enum common_speculative_type type,
             common_ngram_simple_config config)
@@ -481,11 +480,6 @@ struct common_speculative_state_ngram_simple : public common_speculative_state {
             const llama_tokens & prompt_tgt,
             llama_token id_last,
             llama_tokens & result) override {
-        ++check_id;
-        if (check_id < config.check_rate) {
-            return;
-        }
-        check_id = 0;
 
         result = common_ngram_simple_draft(config, prompt_tgt, id_last);
         GGML_UNUSED(params);
@@ -752,10 +746,9 @@ static common_ngram_map get_common_ngram_map(const common_speculative_config & c
     uint16_t size_key   = config.params.ngram_size_n;
     uint16_t size_value = config.params.ngram_size_m;
     bool     key_only   = (config.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K);
-    uint16_t check_rate = config.params.ngram_check_rate;
     uint16_t min_hits   = config.params.ngram_min_hits;
 
-    return common_ngram_map(size_key, size_value, key_only, check_rate, min_hits);
+    return common_ngram_map(size_key, size_value, key_only, min_hits);
 }
 
 static common_speculative_state_ngram_cache create_state_ngram_cache(
@@ -931,12 +924,10 @@ common_speculative * common_speculative_init(
 
                 uint16_t ngram_size_key   = ngram_map.size_key;
                 uint16_t mgram_size_value = ngram_map.size_value;
-                uint16_t check_rate       = ngram_map.check_rate;
 
                 auto config_simple = common_ngram_simple_config {
                     /* .size_ngram      = */ ngram_size_key,
-                    /* .size_mgram      = */ mgram_size_value,
-                    /* .check_rate      = */ check_rate
+                    /* .size_mgram      = */ mgram_size_value
                 };
                 auto state = std::make_unique<common_speculative_state_ngram_simple>(
                     /* .type            = */ config.type,
@@ -997,6 +988,7 @@ void common_speculative_begin(common_speculative * spec, const llama_tokens & pr
     for (auto & impl : spec->impls) {
         common_time_meas tm(impl->t_begin_us, !impl->gen_perf);
         impl->begin(prompt);
+        impl->n_call_begin++;
     }
 }
 
@@ -1013,17 +1005,17 @@ llama_tokens common_speculative_draft(
         {
             common_time_meas tm(impl->t_draft_us, !impl->gen_perf);
             impl->draft(params, prompt_tgt, id_last, result);
-            impl->drafts_call_count++;
+            impl->n_call_draft++;
         }
 
         if (!result.empty()) {
             LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__,
                     common_speculative_type_to_str(impl.get()->type).c_str(), prompt_tgt.size(),
-                    impl.get()->drafts_call_count, result.size());
+                    impl.get()->n_call_draft, result.size());
 
             spec->curr_impl = impl.get(); // set current implementation for stats
-            impl->drafts_generated_count++;
-            impl->drafts_generated_tokens += result.size();
+            impl->n_gen_drafts++;
+            impl->n_gen_tokens += result.size();
 
             break; // We have a draft, so break out of the loop and return it.
         }
@@ -1044,11 +1036,12 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) {
     {
         common_time_meas tm(impl->t_accept_us, !impl->gen_perf);
         if (n_accepted > 0) {
-            impl->drafts_accepted_count++;
-            impl->drafts_accepted_tokens += n_accepted;
+            impl->n_acc_drafts++;
+            impl->n_acc_tokens += n_accepted;
         }
 
         impl->accept(n_accepted);
+        impl->n_call_accept++;
     }
 }
 
@@ -1069,13 +1062,13 @@ void common_speculative_print_stats(const common_speculative * spec) {
             str_perf = "";
         }
 
-        LOG_INF("statistics %s: #calls = %zu, #gen drafts = %zu, #acc drafts = %zu, #gen tokens = %zu, #acc tokens = %zu%s\n",
+        LOG_INF("statistics %s: #calls(b,g,a) = %zu %zu %zu, #gen drafts = %zu, #acc drafts = %zu, #gen tokens = %zu, #acc tokens = %zu%s\n",
                 common_speculative_type_to_str(impl->type).c_str(),
-                impl->drafts_call_count,
-                impl->drafts_generated_count,
-                impl->drafts_accepted_count,
-                impl->drafts_generated_tokens,
-                impl->drafts_accepted_tokens,
+                impl->n_call_begin, impl->n_call_draft, impl->n_call_accept,
+                impl->n_gen_drafts,
+                impl->n_acc_drafts,
+                impl->n_gen_tokens,
+                impl->n_acc_tokens,
                 str_perf.c_str());
     }
 }
index 03afab5b41e1fc77ef69ef709d3515d4d915171c..29da332875f0a2b5595f739c481b6dd54d82a35c 100644 (file)
@@ -119,8 +119,6 @@ If a draft model is combined with a draftless decoding the draftless decoding ha
                                         of lookup n-gram (default: 12)
 --spec-ngram-size-m N                   ngram size M for ngram-simple/ngram-map speculative decoding, length
                                         of draft m-gram (default: 48)
---spec-ngram-check-rate N               ngram check rate for ngram-simple/ngram-map speculative decoding
-                                        (default: 1)
 --spec-ngram-min-hits N                 minimum hits for ngram-map speculative decoding (default: 1)
 ```
 
@@ -153,10 +151,6 @@ Sets the size M of the draft m-gram for n-gram map based speculative decoding.
 The m-gram size determines how many tokens to draft when a match is found.
 Larger values can provide more speedup but may reduce acceptance rate.
 
-### `--spec-ngram-check-rate R`
-
-This option aims at performance if the n-gram lookup in history is to costly. A lookup will be executed at every R tokens (default is 1, every token).
-
 ### `--spec-ngram-min-hits H`
 
 This option defines how often a key has to appear in the token history to be used as a draft (default is 1).
@@ -175,7 +169,12 @@ draft acceptance rate = 0.70312 (   90 accepted /   128 generated)
 statistics ngram_mod: #calls = 810, #gen drafts = 15, #acc drafts = 15, #gen tokens = 960, #acc tokens = 730, dur(b,g,a) = 0.149, 0.347, 0.005 ms
 ```
 
-- `#calls`: number of calls of this implementations
+```
+statistics ngram_map_k: #calls(b,g,a) = 6 1690 26, #gen drafts = 26, #acc drafts = 26, #gen tokens = 1248, #acc tokens = 968, dur(b,g,a) = 2.234, 1.427, 0.016 ms
+```
+
+
+- `#calls(b,g,a)`: number of calls of begin (new prompt), generation and accumulation of this implementations
 - `#gen drafts`: number of drafts generated by this implementation
 - `#acc drafts`: number of drafts accepted (partially) by the main model
 - `#gen tokens`: number of tokens generated by this implementation (including rejected tokens)
index 2d25db63b74e13be0146d03e1646ea0b8e47fd6d..a137427c69c02d4a91df8a00f82bcc49addd5901 100644 (file)
@@ -80,7 +80,6 @@ json task_params::to_json(bool only_metrics) const {
             {"speculative.type",          common_speculative_type_to_str(speculative.type)},
             {"speculative.ngram_size_n",  speculative.ngram_size_n},
             {"speculative.ngram_size_m",  speculative.ngram_size_m},
-            {"speculative.ngram_c_rate",  speculative.ngram_check_rate},
             {"speculative.ngram_m_hits",  speculative.ngram_min_hits},
             {"timings_per_token",         timings_per_token},
             {"post_sampling_probs",       post_sampling_probs},
@@ -144,7 +143,6 @@ json task_params::to_json(bool only_metrics) const {
         {"speculative.type",          common_speculative_type_to_str(speculative.type)},
         {"speculative.ngram_size_n",  speculative.ngram_size_n},
         {"speculative.ngram_size_m",  speculative.ngram_size_m},
-        {"speculative.ngram_c_rate",  speculative.ngram_check_rate},
         {"speculative.ngram_m_hits",  speculative.ngram_min_hits},
         {"timings_per_token",         timings_per_token},
         {"post_sampling_probs",       post_sampling_probs},
@@ -257,12 +255,10 @@ task_params server_task::params_from_json_cmpl(
 
     params.speculative.ngram_size_n     = json_value(data, "speculative.ngram_size_n", defaults.speculative.ngram_size_n);
     params.speculative.ngram_size_m     = json_value(data, "speculative.ngram_size_m", defaults.speculative.ngram_size_m);
-    params.speculative.ngram_check_rate = json_value(data, "speculative.ngram_c_rate", defaults.speculative.ngram_check_rate);
     params.speculative.ngram_min_hits   = json_value(data, "speculative.ngram_m_hits", defaults.speculative.ngram_min_hits);
 
     params.speculative.ngram_size_n     = std::max(std::min(1, (int) params.speculative.ngram_size_n),     1024);
     params.speculative.ngram_size_m     = std::max(std::min(1, (int) params.speculative.ngram_size_m),     1024);
-    params.speculative.ngram_check_rate = std::max(std::min(1, (int) params.speculative.ngram_check_rate), 1024);
     params.speculative.ngram_min_hits   = std::max(std::min(1, (int) params.speculative.ngram_min_hits),   1024);
 
     // Use OpenAI API logprobs only if n_probs wasn't provided