]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
spec : add self‑speculative decoding (no draft model required) + refactor (#18471)
authorSascha Rogmann <redacted>
Wed, 28 Jan 2026 17:42:42 +0000 (18:42 +0100)
committerGitHub <redacted>
Wed, 28 Jan 2026 17:42:42 +0000 (19:42 +0200)
* server: introduce self-speculative decoding

* server: moved self-call into speculative.cpp

* can_speculate() includes self-speculation

Co-authored-by: Georgi Gerganov <redacted>
* server: can_speculate() tests self-spec

* server: replace can_speculate() with slot.can_speculate()

Co-authored-by: Sigbjørn Skjæret <redacted>
* common: use %zu format specifier for size_t in logging

Co-authored-by: Sigbjørn Skjæret <redacted>
* server: can_speculate() requires a task instance

* common: ngram map, config self-speculative decoding

* common: add enum common_speculative_type

* common: add vector of speculative states

* common: add option --spec-draftless

* server: cleanup (remove slot.batch_spec, rename)

* common: moved self-spec impl to ngram-map

* common: cleanup (use common_speculative_state_draft)

* spec : refactor

* cont : naming

* spec: remove --spec-config

* doc: (draftless) speculative decoding

* common: print performance in spec decoding

* minor : cleanup

* common : better names

* minor : cleanup + fix build

* minor: comments

* CODEOWNERS: add common/ngram-map.* (#18471)

* common : rename speculative.draftless_type -> speculative.type

* ngram-map : fix uninitialized values

* ngram-map : take into account the input can become shorter

* ngram-map : revert len check for now

* arg : change `--spec-draftless` -> `--spec-type`

* spec : add common_speculative_state::accept()

* spec : refactor + add common_speculative_begin()

* spec : fix begin() call with mtmd

* spec : additional refactor + remove common_speculative_params

---------

Co-authored-by: Georgi Gerganov <redacted>
Co-authored-by: Sigbjørn Skjæret <redacted>
19 files changed:
CODEOWNERS
common/CMakeLists.txt
common/arg.cpp
common/common.cpp
common/common.h
common/ngram-cache.cpp
common/ngram-cache.h
common/ngram-map.cpp [new file with mode: 0644]
common/ngram-map.h [new file with mode: 0644]
common/speculative.cpp
common/speculative.h
docs/speculative.md [new file with mode: 0644]
examples/lookup/lookup-create.cpp
examples/lookup/lookup-stats.cpp
examples/lookup/lookup.cpp
examples/speculative-simple/speculative-simple.cpp
examples/speculative/speculative.cpp
tools/server/server-context.cpp
tools/server/server-task.cpp

index 6086abb564d7215059dcded14c70cc88953e0864..e573a3d2e631a2edddb6413214642cea36288904 100644 (file)
@@ -18,6 +18,7 @@
 /common/jinja/                          @ngxson @CISC @aldehir
 /common/llguidance.*                    @ggerganov
 /common/log.*                           @ggerganov
+/common/ngram-map.*                     @srogmann
 /common/peg-parser.*                    @aldehir
 /common/sampling.*                      @ggerganov
 /common/speculative.*                   @ggerganov
index ae02c0bd77f33300e925d2de5a3c7f3ed624e275..3bc7bc6210bc9b025af9744a1624fe9e833223e5 100644 (file)
@@ -73,6 +73,8 @@ add_library(${TARGET} STATIC
     log.h
     ngram-cache.cpp
     ngram-cache.h
+    ngram-map.cpp
+    ngram-map.h
     peg-parser.cpp
     peg-parser.h
     preset.cpp
index cd3a1b639704092eede2afba8d75bf05aa236299..a685c418bfcff6fd107594101cbc415c969f1ddc 100644 (file)
@@ -6,6 +6,7 @@
 #include "json-schema-to-grammar.h"
 #include "log.h"
 #include "sampling.h"
+#include "speculative.h"
 #include "preset.h"
 
 // fix problem with std::min and std::max
@@ -579,14 +580,14 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
             params.mmproj = res.mmproj;
         }
         // only download mmproj if the current example is using it
-        for (auto & ex : mmproj_examples) {
+        for (const auto & ex : mmproj_examples) {
             if (ctx_arg.ex == ex) {
                 common_params_handle_model(params.mmproj,    params.hf_token, params.offline);
                 break;
             }
         }
-        common_params_handle_model(params.speculative.model, params.hf_token, params.offline);
-        common_params_handle_model(params.vocoder.model,     params.hf_token, params.offline);
+        common_params_handle_model(params.speculative.mparams_dft, params.hf_token, params.offline);
+        common_params_handle_model(params.vocoder.model,           params.hf_token, params.offline);
     }
 
     // model is required (except for server)
@@ -1216,16 +1217,16 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         {"-lcs", "--lookup-cache-static"}, "FNAME",
         "path to static lookup cache to use for lookup decoding (not updated by generation)",
         [](common_params & params, const std::string & value) {
-            params.lookup_cache_static = value;
+            params.speculative.lookup_cache_static = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_LOOKUP}));
+    ).set_examples({LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"-lcd", "--lookup-cache-dynamic"}, "FNAME",
         "path to dynamic lookup cache to use for lookup decoding (updated by generation)",
         [](common_params & params, const std::string & value) {
-            params.lookup_cache_dynamic = value;
+            params.speculative.lookup_cache_dynamic = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_LOOKUP}));
+    ).set_examples({LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"-c", "--ctx-size"}, "N",
         string_format("size of the prompt context (default: %d, 0 = loaded from model)", params.n_ctx),
@@ -2563,7 +2564,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         {"-hfd", "-hfrd", "--hf-repo-draft"}, "<user>/<model>[:quant]",
         "Same as --hf-repo, but for the draft model (default: unused)",
         [](common_params & params, const std::string & value) {
-            params.speculative.model.hf_repo = value;
+            params.speculative.mparams_dft.hf_repo = value;
         }
     ).set_env("LLAMA_ARG_HFD_REPO"));
     add_opt(common_arg(
@@ -3384,7 +3385,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         {"-md", "--model-draft"}, "FNAME",
         "draft model for speculative decoding (default: unused)",
         [](common_params & params, const std::string & value) {
-            params.speculative.model.path = value;
+            params.speculative.mparams_dft.path = value;
         }
     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_MODEL_DRAFT"));
     add_opt(common_arg(
@@ -3394,6 +3395,66 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.speculative.replacements.push_back({ tgt, dft });
         }
     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
+    add_opt(common_arg(
+        {"--spec-type"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v]",
+        string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n",
+            common_speculative_type_to_str(params.speculative.type).c_str()),
+        [](common_params & params, const std::string & value) {
+            if (value == "none") {
+                params.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
+            } else if (value == "ngram-cache") {
+                params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_CACHE;
+            } else if (value == "ngram-simple") {
+                params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE;
+            } else if (value == "ngram-map-k") {
+                params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K;
+            } else if (value == "ngram-map-k4v") {
+                params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V;
+            } else {
+                throw std::invalid_argument("unknown speculative decoding type without draft model");
+            }
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}));
+    add_opt(common_arg(
+        {"--spec-ngram-size-n"}, "N",
+        string_format("ngram size N for ngram-simple/ngram-map speculative decoding, length of lookup n-gram (default: %d)", params.speculative.ngram_size_n),
+        [](common_params & params, int value) {
+            if (value < 1 || value > 1024) {
+                throw std::invalid_argument("ngram size N must be between 1 and 1024 inclusive");
+            }
+            params.speculative.ngram_size_n = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}));
+    add_opt(common_arg(
+        {"--spec-ngram-size-m"}, "N",
+        string_format("ngram size M for ngram-simple/ngram-map speculative decoding, length of draft m-gram (default: %d)", params.speculative.ngram_size_m),
+        [](common_params & params, int value) {
+            if (value < 1 || value > 1024) {
+                throw std::invalid_argument("ngram size M must be between 1 and 1024 inclusive");
+            }
+            params.speculative.ngram_size_m = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}));
+    add_opt(common_arg(
+        {"--spec-ngram-check-rate"}, "N",
+        string_format("ngram check rate for ngram-simple/ngram-map speculative decoding (default: %d)", params.speculative.ngram_check_rate),
+        [](common_params & params, int value) {
+            if (value < 1) {
+                throw std::invalid_argument("ngram check rate must be at least 1");
+            }
+            params.speculative.ngram_check_rate = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}));
+    add_opt(common_arg(
+        {"--spec-ngram-min-hits"}, "N",
+        string_format("minimum hits for ngram-map speculative decoding (default: %d)", params.speculative.ngram_min_hits),
+        [](common_params & params, int value) {
+            if (value < 1) {
+                throw std::invalid_argument("ngram min hits must be at least 1");
+            }
+            params.speculative.ngram_min_hits = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"-ctkd", "--cache-type-k-draft"}, "TYPE",
         string_format(
@@ -3620,8 +3681,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params) {
             params.model.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF";
             params.model.hf_file = "qwen2.5-coder-7b-q8_0.gguf";
-            params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
-            params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
+            params.speculative.mparams_dft.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
+            params.speculative.mparams_dft.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
             params.port = 8012;
             params.n_ubatch = 1024;
             params.n_batch = 1024;
@@ -3636,8 +3697,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params) {
             params.model.hf_repo = "ggml-org/Qwen2.5-Coder-14B-Q8_0-GGUF";
             params.model.hf_file = "qwen2.5-coder-14b-q8_0.gguf";
-            params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
-            params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
+            params.speculative.mparams_dft.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
+            params.speculative.mparams_dft.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
             params.port = 8012;
             params.n_ubatch = 1024;
             params.n_batch = 1024;
index 26250abb6c815d816cf79c779f3a4fe234dd3443..3aa396127ceadcad0d197d5574792d6c7a402f9a 100644 (file)
@@ -1097,7 +1097,10 @@ common_init_result::common_init_result(common_params & params) :
     if (params.fit_params) {
         LOG_INF("%s: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on\n", __func__);
         llama_params_fit(params.model.path.c_str(), &mparams, &cparams,
-            params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target.data(), params.fit_params_min_ctx,
+            params.tensor_split,
+            params.tensor_buft_overrides.data(),
+            params.fit_params_target.data(),
+            params.fit_params_min_ctx,
             params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
     }
 
@@ -1208,10 +1211,6 @@ std::vector<llama_adapter_lora_ptr> & common_init_result::lora() {
     return pimpl->lora;
 }
 
-void common_init_result::free_context() {
-    pimpl->context.reset();
-}
-
 common_init_result_ptr common_init_from_params(common_params & params) {
     common_init_result_ptr res(new common_init_result(params));
 
index 21c11f457d4a19c168b50465e36232f262c56fbb..fd3ab8cd18018f6f98bffd17c8eea22d505bcef7 100644 (file)
@@ -164,6 +164,16 @@ enum common_params_sampling_config : uint64_t {
     COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA    = 1 << 11,
 };
 
+enum common_speculative_type {
+    COMMON_SPECULATIVE_TYPE_NONE,          // no speculative decoding
+    COMMON_SPECULATIVE_TYPE_DRAFT,         // draft model
+    COMMON_SPECULATIVE_TYPE_EAGLE3,        // eagle draft model
+    COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,  // simple self-speculative decoding
+    COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,   // self-speculative decoding with n-gram keys only
+    COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values
+    COMMON_SPECULATIVE_TYPE_NGRAM_CACHE,   // self-speculative decoding with 3-level n-gram cache
+    COMMON_SPECULATIVE_TYPE_COUNT          // number of types, unknown type
+};
 
 // sampling parameters
 struct common_params_sampling {
@@ -243,16 +253,35 @@ struct common_params_model {
 };
 
 struct common_params_speculative {
-    std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
+    common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE; // type of speculative decoding
 
-    int32_t n_ctx        =     0; // draft context size
-    int32_t n_max        =    16; // maximum number of tokens to draft during speculative decoding
-    int32_t n_min        =     0; // minimum number of draft tokens to use for speculative decoding
-    int32_t n_gpu_layers =    -1; // number of layers to store in VRAM for the draft model (-1 - use default)
-    float   p_split      =  0.1f; // speculative decoding split probability
-    float   p_min        = 0.75f; // minimum speculative decoding probability (greedy)
-    std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
-    std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
+    // general-purpose speculative decoding parameters
+
+    int32_t n_max   = 16; // maximum number of tokens to draft during speculative decoding
+    int32_t n_min   = 0; // minimum number of draft tokens to use for speculative decoding
+    float   p_split = 0.1f; // speculative decoding split probability
+    float   p_min   = 0.75f; // minimum speculative decoding probability (greedy)
+
+    // ngram-based speculative decoding
+
+    uint16_t ngram_size_n     = 12; // ngram size for lookup
+    uint16_t ngram_size_m     = 48; // mgram size for speculative tokens
+    uint16_t ngram_check_rate =  1; // check rate for ngram lookup
+    uint16_t ngram_min_hits   =  1; // minimum hits at ngram/mgram lookup for mgram to be proposed
+
+    std::string lookup_cache_static;  // path of static ngram cache file for lookup decoding           // NOLINT
+    std::string lookup_cache_dynamic; // path of dynamic ngram cache file for lookup decoding          // NOLINT
+
+    // draft-model speculative decoding
+
+    struct common_params_model mparams_dft;
+
+    llama_model * model_dft = nullptr; // a llama_model that can be shared by multiple speculative contexts
+
+    llama_context_params cparams_dft; // these are the parameters for the draft llama_context
+
+    int32_t n_ctx        = 0;  // draft context size
+    int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
 
     ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
     ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
@@ -260,7 +289,14 @@ struct common_params_speculative {
     struct cpu_params cpuparams;
     struct cpu_params cpuparams_batch;
 
-    struct common_params_model model;
+    std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
+
+    std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
+    std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
+
+    bool has_dft() const {
+        return !mparams_dft.path.empty() || !mparams_dft.hf_repo.empty();
+    }
 };
 
 struct common_params_vocoder {
@@ -378,8 +414,6 @@ struct common_params {
     std::string path_prompt_cache    = ""; // path to file for saving/loading prompt eval state             // NOLINT
     std::string input_prefix         = ""; // string to prefix user inputs with                             // NOLINT
     std::string input_suffix         = ""; // string to suffix user inputs with                             // NOLINT
-    std::string lookup_cache_static  = ""; // path of static ngram cache file for lookup decoding           // NOLINT
-    std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding          // NOLINT
     std::string logits_file          = ""; // file for saving *all* logits                                  // NOLINT
 
     // llama-debug specific options
@@ -575,10 +609,6 @@ struct common_params {
     // return false from callback to abort model loading or true to continue
     llama_progress_callback load_progress_callback = NULL;
     void *                  load_progress_callback_user_data = NULL;
-
-    bool has_speculative() const {
-        return !speculative.model.path.empty() || !speculative.model.hf_repo.empty();
-    }
 };
 
 // call once at the start of a program if it uses libcommon
@@ -714,8 +744,6 @@ struct common_init_result {
 
     std::vector<llama_adapter_lora_ptr> & lora();
 
-    void free_context();
-
 private:
     struct impl;
     std::unique_ptr<impl> pimpl;
index d1a4d84c40f1c742ece7213139b1325ac62a8edc..dce54b3647490bf606e2408bf924e8c10a50c8e5 100644 (file)
@@ -192,12 +192,12 @@ void common_ngram_cache_draft(
             break;
         }
 
-        LOG(" - draft candidate: token=%d\n", drafted_token);
+        LOG_DBG(" - draft candidate: token=%d\n", drafted_token);
         draft.push_back(drafted_token);
     }
 }
 
-void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & filename) {
+void common_ngram_cache_save(common_ngram_cache & ngram_cache, const std::string & filename) {
     std::ofstream file_out(filename, std::ios::binary);
     for (std::pair<common_ngram, common_ngram_cache_part> item : ngram_cache) {
         const common_ngram      ngram        = item.first;
@@ -217,10 +217,9 @@ void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & fil
             file_out.write(reinterpret_cast<const char *>(&count), sizeof(int32_t));
         }
     }
-
 }
 
-common_ngram_cache common_ngram_cache_load(std::string & filename) {
+common_ngram_cache common_ngram_cache_load(const std::string & filename) {
     std::ifstream hashmap_file(filename, std::ios::binary);
     if (!hashmap_file) {
         throw std::ifstream::failure("Unable to open file " + filename);
index dfe012abe493dc03d2cda367572adc96d61f7b7c..6e7cfea966dfb4847ac9a03932c58504c3ad88ef 100644 (file)
@@ -88,12 +88,12 @@ void common_ngram_cache_draft(
 // Save an ngram cache to a file.
 // ngram_cache: the ngram cache to save.
 // filename:    the path under which to save the ngram cache.
-void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & filename);
+void common_ngram_cache_save(common_ngram_cache & ngram_cache, const std::string & filename);
 
 // Load an ngram cache saved with common_ngram_cache_save.
 // filename: the path from which to load the ngram cache.
 // returns:  an ngram cache containing the information saved to filename.
-common_ngram_cache common_ngram_cache_load(std::string & filename);
+common_ngram_cache common_ngram_cache_load(const std::string & filename);
 
 // Merge two ngram caches.
 // ngram_cache_target: the ngram cache to which to add the information from ngram_cache_add.
diff --git a/common/ngram-map.cpp b/common/ngram-map.cpp
new file mode 100644 (file)
index 0000000..930e7a3
--- /dev/null
@@ -0,0 +1,367 @@
+#include "common.h"
+#include "log.h"
+#include "ngram-map.h"
+
+#include <cinttypes>
+#include <cstdint>
+#include <cstdio>
+#include <sstream>
+
+// n-gram simple
+//
+
+/**
+ * Perform speculative generation using the model's own token history.
+ * Searches for a matching pattern in the token history and returns draft tokens.
+ *
+ * @param state     Current state of this implementation
+ * @param tokens    Token history to search in
+ * @param sampled   Last sampled token
+ * @return Vector of draft tokens, empty if no matching pattern is found
+ */
+llama_tokens common_ngram_simple_draft(
+        common_ngram_simple_state & state,
+        const llama_tokens & tokens, llama_token sampled) {
+
+    // Simple implementation of self-speculative decoding without a draft model.
+    //
+    const size_t cur_len = tokens.size();
+    // Only check every check_rate tokens to save compute
+    // i.e., perform check if (cur_len - idx_last_check) >= check_rate
+    if (state.idx_last_check + state.config.check_rate > cur_len) {
+        llama_tokens draft_tokens;
+        return draft_tokens;
+    }
+
+    size_t n_draft_min = state.config.size_ngram; // size of n-gram to lookup in token history
+    size_t n_draft_max = state.config.size_mgram; // the m-gram following the found n-gram is used for draft
+
+    // vector for tokens we want to verify.
+    // return empty vector if there is no match.
+    llama_tokens draft_tokens;
+
+    // We need at least n_draft_min + n_draft_max + 1 tokens.
+    if (cur_len <= static_cast<size_t>(n_draft_min + n_draft_max + 1)) {
+        return draft_tokens;
+    }
+
+    // pattern search
+    llama_tokens pattern;
+    pattern.reserve(n_draft_min);
+    for (size_t j = cur_len - n_draft_min + 1; j < cur_len; ++j) {
+        pattern.push_back(tokens[j]);
+    }
+    pattern.push_back(sampled); // add the last token to the pattern
+
+    // We do a search in the token history.
+    state.idx_last_check = cur_len;
+
+    size_t match_pos = 0; // we ignore position 0, position 0 == no match
+                          // search backwards, but skip the current match (we are currently there)
+    for (size_t j = cur_len - n_draft_min - 1; j > 0; --j) {
+        bool match = true;
+        for (size_t k = 0; k < pattern.size(); ++k) {
+            if (tokens[j + k] != pattern[k]) {
+                match = false;
+                break;
+            }
+        }
+        if (match) {
+            match_pos = j;
+            break;
+        }
+    }
+    if (match_pos == 0) {
+        return draft_tokens;
+    }
+
+    const size_t copy_max = std::min(
+            n_draft_max,
+            cur_len - (match_pos + n_draft_min)
+            );
+    if (copy_max < n_draft_min) {
+        return draft_tokens;
+    }
+    LOG_DBG("%s: #tokens = %zu: found matching pattern at pos %zu, length %zu, draft length %zu\n",
+            __func__, cur_len,
+            match_pos, pattern.size(), copy_max);
+
+    draft_tokens.reserve(copy_max);
+    for (size_t j = 0; j < copy_max; ++j) {
+        draft_tokens.push_back(tokens[match_pos + n_draft_min + j]);
+    }
+    return draft_tokens;
+}
+
+
+// n-gram map
+//
+
+// maximum number of counted values of a ngram map value.
+#define COMMON_NGRAM_MAX_VALUE_COUNT 16380
+
+static std::string common_tokens_to_str(const llama_tokens & inp, size_t start, size_t length);
+
+void common_ngram_map_draft(common_ngram_map & map,
+        const llama_tokens & inp, llama_token sampled,
+        llama_tokens & draft) {
+    // reset last key and value.
+    map.last_draft_created   = false;
+    map.last_draft_key_idx   = 0;
+    map.last_draft_value_idx = 0;
+
+    const size_t cur_len = inp.size();
+    const uint16_t n = map.size_key;
+    const uint16_t m = map.size_value;
+    if (cur_len < static_cast<size_t>(2 * n + m)) {
+        return;
+    }
+
+    // Only check every check_rate tokens to save compute
+    // i.e., perform check if (cur_len - idx_last_check) >= check_rate
+    if (map.idx_last_check + map.check_rate > cur_len) {
+        return;
+    }
+    map.idx_last_check = cur_len;
+
+    // search pattern, the key n-gram
+    std::vector<llama_token> key_tokens;
+    key_tokens.reserve(n);
+    for (size_t j = cur_len - n + 1; j < cur_len; ++j) {
+        key_tokens.push_back(inp[j]);
+    }
+    key_tokens.push_back(sampled);
+
+    // search for the key in the map
+    size_t match_pos = 0;
+    for (size_t j = cur_len - n - m - 1; j > 0; --j) {
+        bool match = true;
+        for (size_t k = 0; k < n; ++k) {
+            if (inp[j + k] != key_tokens[k]) {
+                match = false;
+                break;
+            }
+        }
+        if (match) {
+           match_pos = j;
+           break;
+        }
+    }
+    if (match_pos > 0) {
+        LOG_INF("%s: cur_len = %zu, n = %d, m = %d, sz_tkns = %zu, sampled = %d, match_pos = %zu\n", __func__,
+            cur_len, n, m, key_tokens.size(), sampled, match_pos);
+    }
+
+    if (match_pos == 0) {
+        return;
+    }
+
+    // We have a match, now we look for the statistics of the key.
+    size_t key_offset = map.keys.size(); // offset in the map
+    // We iterate through the std::vector<common_ngram_map_key> map->keys.
+    for (size_t i = 0; i < map.keys.size(); ++i) {
+        bool match = true;
+        for (size_t j = 0; j < n; ++j) {
+            if (inp[map.keys[i].key_idx + j] != key_tokens[j]) {
+                match = false;
+                break;
+            }
+        }
+        if (match) {
+            key_offset = i;
+            break;
+        }
+    }
+    if (key_offset == map.keys.size()) {
+        // We create a new key-entry, it will get offset key_offset.
+        common_ngram_map_key new_key;
+        new_key.key_idx = match_pos;
+        new_key.stat_idx = 0;
+        new_key.key_num = 0;
+        for (int i = 0; i < COMMON_NGRAM_MAX_VALUES; ++i) {
+            new_key.values[i].value_num = 0;
+            new_key.values[i].n_accepted = m;
+        }
+        map.keys.push_back(new_key);
+    }
+
+    // our key n-gram:
+    common_ngram_map_key & curr_key = map.keys[key_offset];
+
+    // update number of key hits
+    curr_key.key_num = (uint16_t) std::min((int) map.keys[key_offset].key_num + 1,
+            (int) COMMON_NGRAM_MAX_VALUE_COUNT);
+
+    if (map.key_only) {
+        // simple mode:
+        // Fill in the draft with the m tokens following the key.
+        // We work with value values[0] only.
+        int n_draft_tokens = std::min((int) m, (int) curr_key.values[0].n_accepted);
+
+        for (int i = 0; i < n_draft_tokens; ++i) {
+            draft.push_back(inp[match_pos + n + i]);
+        }
+
+        LOG_INF("%s: key_offset = %zu, key_num = %d, draft.size = %zu\n", __func__,
+                key_offset, curr_key.key_num, draft.size());
+
+        map.last_draft_created   = false;
+        map.last_draft_key_idx   = key_offset;
+        map.last_draft_value_idx = 0; // value 0 is used for simple mode
+        return;
+    }
+
+    if (curr_key.key_num < map.min_hits) {
+        // not enough hits to consider this a good draft
+        LOG_DBG("%s: key_offset = %zu, key_num = %d, min_hits = %d, no draft\n", __func__,
+                key_offset, curr_key.key_num, map.min_hits);
+        return;
+    }
+
+    // complex mode: examine the different m-grams after this key n-gram.
+    //
+
+    // determine all (max COMMON_NGRAM_MAX_VALUES) m-grams after the key n-gram.
+    for (size_t i = curr_key.stat_idx; i <= match_pos; ++i) {
+        // begins the key n-gram at index i?
+        bool match_key = true;
+        for (size_t k = 0; k < n; ++k) {
+            if (inp[i + k] != key_tokens[k]) {
+                match_key = false;
+                break;
+            }
+        }
+        if (!match_key) {
+            continue;
+        }
+
+        // Do we haven a existing value m-gram or a new one after the key at index i?
+        size_t idx_begin_value_key = i + n;
+        int idx_value = -1;
+        for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
+            size_t idx_begin_value_v = curr_key.values[v].value_idx;
+            if (idx_begin_value_v == 0) {
+                // We found an empty value slot => we found a new value m-gram after the key n-gram.
+                curr_key.values[v].value_idx = idx_begin_value_key;
+                curr_key.values[v].value_num = 0;
+                curr_key.values[v].n_accepted = m;
+                idx_value = v;
+                break;
+            }
+            bool match = true;
+            for (size_t j = 0; j < m; ++j) {
+                if (inp[idx_begin_value_key + j] != inp[idx_begin_value_v + j]) {
+                    match = false;
+                    break;
+                }
+            }
+            if (match) {
+                // We found an existing value m-gram after the key n-gram.
+                idx_value = v;
+                break;
+            }
+        }
+        if (idx_value >= 0) {
+            // We found a value m-gram of the key n-gram.
+            curr_key.values[idx_value].value_num = (uint16_t) std::min((int) curr_key.values[idx_value].value_num + 1,
+                    (int) COMMON_NGRAM_MAX_VALUE_COUNT);
+        }
+    }
+    // the statistics are updated up to match_pos.
+    curr_key.stat_idx = match_pos;
+
+    // Do we have a value we could use for the draft?
+    uint16_t max_occur = 0;
+    int slot_max = 0;
+    for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
+        uint16_t curr_occur = curr_key.values[v].value_num;
+        if (curr_occur > max_occur) {
+            max_occur = curr_occur;
+            slot_max = v;
+        }
+    }
+    // What is sum of the other occurences?
+    uint32_t sum_occur = 0;
+    for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
+        if (v == slot_max) {
+            continue;
+        }
+        uint16_t curr_occur = curr_key.values[v].value_num;
+        sum_occur += curr_occur;
+    }
+
+    LOG_INF("%s: key_offset = %zu, max_occur = %d, sum_occur = %d, slot_max = %d [%zu/%d, %zu/%d, %zu/%d, %zu/%d]\n", __func__,
+            key_offset,
+            max_occur, sum_occur, slot_max,
+            curr_key.values[0].value_idx, curr_key.values[0].value_num,
+            curr_key.values[1].value_idx, curr_key.values[1].value_num,
+            curr_key.values[2].value_idx, curr_key.values[2].value_num,
+            curr_key.values[3].value_idx, curr_key.values[3].value_num
+        );
+    // Print the tokens of the four values (if idx != 0), use LOG_INF
+    for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
+        if (curr_key.values[v].value_idx != 0) {
+            LOG_INF("%s: value[%d] = %s\n", __func__, v, common_tokens_to_str(inp, curr_key.values[v].value_idx, m).c_str());
+        }
+    }
+
+    if (sum_occur > 0 && max_occur < 3 * sum_occur) {
+        // The most frequent value is not much more frequent than the other values.
+        // We do not use the draft.
+        return;
+    }
+
+    // We use the most frequent value values[slot_max] for the draft.
+    // Fill in the draft with the m tokens following the key.
+    int n_draft_tokens = std::min((int) m, (int) curr_key.values[slot_max].n_accepted);
+
+    for (int i = 0; i < n_draft_tokens; ++i) {
+        draft.push_back(inp[match_pos + n + i]);
+    }
+
+    LOG_INF("%s: key_offset = %zu, slot_max = %d, key_num = %d, draft.size = %zu\n", __func__,
+            key_offset, slot_max,
+            curr_key.key_num, draft.size());
+
+    map.last_draft_created   = true;
+    map.last_draft_key_idx   = key_offset;
+    map.last_draft_value_idx = slot_max; // value used for draft generation.
+}
+
+void common_ngram_map_accept(common_ngram_map & map, uint16_t n_accepted) {
+    if (!map.last_draft_created) {
+        return;
+    }
+
+    // find the key and its chosen value.
+    const size_t key_idx = map.last_draft_key_idx;
+    const size_t val_idx = map.last_draft_value_idx;
+
+    // find key corresponding to key_idx.
+    common_ngram_map_key & curr_key = map.keys[key_idx];
+    // find value corresponding to val_idx.
+    struct common_ngram_map_value & curr_value = curr_key.values[val_idx]; // value used for draft generation.
+
+    // update the value statistics
+    LOG_INF("common_ngram_map_send_accepted: n_accepted = %d, prev value_num = %d\n",
+            n_accepted, curr_value.n_accepted);
+    curr_value.n_accepted = n_accepted;
+}
+
+// Helper functions.
+//
+
+// Print the values of a sublist of `llama_tokens & inp` to a string in the form [v0, v1, v2, ...].
+std::string common_tokens_to_str(const llama_tokens & inp, size_t start, size_t length) {
+    std::ostringstream oss;
+    oss << '[';
+    for (size_t i = 0; i < length; ++i) {
+        if (i > 0) {
+            oss << ", ";
+        }
+        oss << inp[start + i];
+    }
+    oss << ']';
+    return oss.str();
+}
+
diff --git a/common/ngram-map.h b/common/ngram-map.h
new file mode 100644 (file)
index 0000000..bf91883
--- /dev/null
@@ -0,0 +1,105 @@
+#pragma once
+//
+// common/ngram-map.h: structures used to manage a map from n-grams to a list of m-grams
+//
+// These structures are used to do a lookup of n-grams followed by m-grams in token history.
+//
+// There are two algorithms implemented:
+// 1. ngram_simple: lookup of n-grams followed by m-grams in token history.
+// 2. ngram_map: lookup of n-grams followed by m-grams in token history using a map.
+//    The map is a vector of key n-grams, and for each key n-gram there is a list of value m-grams.
+//
+
+#include "llama.h"
+
+#include <vector>
+
+// n-gram simple
+//
+
+// config of n-gram simple.
+struct common_ngram_simple_config {
+    uint16_t   size_ngram;      // size of n-grams to lookup in self-mode
+    uint16_t   size_mgram;      // size of m-grams to draft in self-mode
+    uint16_t   check_rate;      // check for speculative decoding without draft model for each check_rate token
+};
+
+// current state (and config) of n-gram simple.
+struct common_ngram_simple_state {
+    common_ngram_simple_config config;
+
+    size_t idx_last_check = 0; // index of last check in context history (mutable)
+
+    common_ngram_simple_state(const common_ngram_simple_config & config)
+        : config(config) {}
+};
+
+// Searches for a n-gram in the history and checks whether a draft sequence should be generated.
+// state:              the ngram simple state to search in.
+// inp:                the tokens generated so far.
+// sampled:            the token that was just sampled.
+// draft:              vector to store the draft tokens, initially empty.
+llama_tokens common_ngram_simple_draft(
+        common_ngram_simple_state & state,
+        const llama_tokens & tokens, llama_token sampled);
+
+
+// n-gram map
+//
+
+// maximum number of m-gram values stored for each key n-gram.
+#define COMMON_NGRAM_MAX_VALUES 4
+
+// statistics of a m-gram after a known n-gram
+struct common_ngram_map_value {
+    size_t   value_idx = 0;  // index of value m-gram in token-history (0 if unused)
+    uint16_t value_num = 0;  // number of occurences of this value m-gram after the key n-gram (0 in an unused values-slot)
+    int16_t n_accepted = -1;  // number of accepted tokens at last draft (-1 if unused)
+};
+
+// statistics of a n-gram
+struct common_ngram_map_key {
+    size_t   key_idx;   // index of key n-gram in token-history
+    size_t   stat_idx;  // index of last token of stastistics computation (key_num, values)
+
+    uint16_t key_num;   // number of occurences of this key n-gram in token-history
+    common_ngram_map_value values[COMMON_NGRAM_MAX_VALUES]; // some known values after the key
+};
+
+// map from n-grams to following m-grams in token-history
+struct common_ngram_map {
+    uint16_t size_key;   // size of key n-grams
+    uint16_t size_value; // size of value m-grams
+
+    bool key_only;       // true if only key n-grams are used, no values.
+
+    // first draft: vector only, no map.
+    std::vector<common_ngram_map_key> keys; // key n-grams which occur several times in token-history
+    uint16_t check_rate; // check for speculative decoding without draft model for each check_rate token
+    uint16_t min_hits;   // minimum number of key hits to consider a draft
+
+    common_ngram_map(uint16_t sz_key, uint16_t sz_value, bool only_keys,
+                     uint16_t check_rate, uint16_t min_hits)
+        : size_key(sz_key), size_value(sz_value), key_only(only_keys),
+          check_rate(check_rate), min_hits(min_hits) {}
+
+    bool     last_draft_created   = false; // true if a draft was created at last call.
+    size_t   last_draft_key_idx   = 0; // index of last key used for draft generation.
+    uint16_t last_draft_value_idx = 0; // index of last value used for draft generation.
+
+    size_t   idx_last_check       = 0; // index of last check in context history
+};
+
+
+// Searches for the n-gram in the history and checks whether a draft sequence should be generated.
+// map:                the ngram map to search in.
+// inp:                the tokens generated so far.
+// sampled:            the token that was just sampled.
+// draft:              vector to store the draft tokens, initially empty.
+void common_ngram_map_draft(
+    common_ngram_map & map,
+    const llama_tokens & inp, llama_token sampled,
+    llama_tokens & draft);
+
+// Update the statistics of a value after a draft was processed.
+void common_ngram_map_accept(common_ngram_map & map, uint16_t n_accepted);
index 3e83b0964c8550a68699f6a65d3675c850bcff2e..3f314b5d57865d43ddb5eb9de545cb0f637fd293 100644 (file)
@@ -1,99 +1,54 @@
 #include "speculative.h"
 
+#include "common.h"
 #include "ggml.h"
 #include "llama.h"
 #include "log.h"
-#include "common.h"
+#include "ngram-cache.h"
+#include "ngram-map.h"
 #include "sampling.h"
 
-#include <cstring>
 #include <algorithm>
+#include <cstring>
+#include <iomanip>
 #include <map>
 
 #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE  128
 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
 
-struct common_speculative {
-    struct llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
-    struct llama_context * ctx_dft;
-    struct common_sampler * smpl;
-
-    llama_batch batch;
-    llama_tokens prompt_dft;
-    bool vocab_dft_compatible = true; // whether retokenization is needed
-    std::map<std::string, std::string> tgt_dft_replacements = {};
+const std::vector<enum common_speculative_type> common_speculative_types = {
+    COMMON_SPECULATIVE_TYPE_NONE,
+    COMMON_SPECULATIVE_TYPE_DRAFT,
+    COMMON_SPECULATIVE_TYPE_EAGLE3,
+    COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,
+    COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,
+    COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V,
+    COMMON_SPECULATIVE_TYPE_NGRAM_CACHE
 };
 
-struct common_speculative * common_speculative_init(
-        struct llama_context * ctx_tgt,
-        struct llama_context * ctx_dft) {
-    auto * result = new common_speculative {
-        /* .ctx_tgt    = */ ctx_tgt,
-        /* .ctx_dft    = */ ctx_dft,
-        /* .smpl       = */ nullptr,
-        /* .batch      = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
-        /* .prompt_dft = */ {},
-        /* .vocab_dft_compatible = */ false,
-    };
-
-    // TODO: optimize or pass from outside?
-#if 0
-    {
-        common_params_sampling params;
-        params.no_perf = false;
-
-        params.top_k = 40;
-        params.top_p = 0.9;
-
-        params.samplers = {
-            COMMON_SAMPLER_TYPE_TOP_K,
-            COMMON_SAMPLER_TYPE_TOP_P,
-            COMMON_SAMPLER_TYPE_INFILL,
-        };
-
-        result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
-    }
-#else
-    {
-        common_params_sampling params;
-        params.no_perf = false;
-
-        params.top_k = 10;
-
-        params.samplers = {
-            COMMON_SAMPLER_TYPE_TOP_K,
-        };
-
-        result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
-    }
-#endif
-
-    result->vocab_dft_compatible = common_speculative_are_compatible(ctx_tgt, ctx_dft);
-    LOG_DBG("vocab_dft_compatible = %d\n", result->vocab_dft_compatible);
-
-    return result;
-}
-
-void common_speculative_free(struct common_speculative * spec) {
-    if (spec == nullptr) {
-        return;
-    }
-
-    common_sampler_free(spec->smpl);
-
-    llama_batch_free(spec->batch);
+const std::map<std::string, enum common_speculative_type> common_speculative_type_from_name_map = {
+    {"none",          COMMON_SPECULATIVE_TYPE_NONE},
+    {"draft",         COMMON_SPECULATIVE_TYPE_DRAFT},
+    {"eagle3",        COMMON_SPECULATIVE_TYPE_EAGLE3},
+    {"ngram_simple",  COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
+    {"ngram_map_k",   COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
+    {"ngram_map_k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
+    {"ngram_cache",   COMMON_SPECULATIVE_TYPE_NGRAM_CACHE}
+};
 
-    delete spec;
-}
+struct common_speculative_config {
+    common_speculative_type type;
+    common_params_speculative params;
 
-bool common_speculative_are_compatible(
-    const struct llama_context * ctx_tgt,
-    const struct llama_context * ctx_dft) {
-    const struct llama_model * model_tgt = llama_get_model(ctx_tgt);
-    const struct llama_model * model_dft = llama_get_model(ctx_dft);
+    common_speculative_config(common_speculative_type t,
+            const common_params_speculative & p = common_params_speculative{}) : type(t), params(p) {}
+};
 
-    const struct llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
-    const struct llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);
+static bool common_speculative_are_compatible(
+    const llama_model * model_tgt,
+    const llama_model * model_dft) {
+    const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
+    const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);
 
     const bool vocab_type_tgt = llama_vocab_type(vocab_tgt);
     LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt);
@@ -134,11 +89,12 @@ bool common_speculative_are_compatible(
         for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
             const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i);
             const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);
+
             if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
                 LOG_DBG("%s: draft model vocab must match target model to use speculation but ", __func__);
                 LOG_DBG("token %d content differs - target '%s', draft '%s'\n", i,
-                        common_token_to_piece(ctx_tgt, i).c_str(),
-                        common_token_to_piece(ctx_dft, i).c_str());
+                        common_token_to_piece(vocab_tgt, i).c_str(),
+                        common_token_to_piece(vocab_dft, i).c_str());
                 return false;
             }
         }
@@ -147,215 +103,779 @@ bool common_speculative_are_compatible(
     return true;
 }
 
-void common_speculative_add_replacement_tgt_dft(
-        struct common_speculative * spec,
-        const char *source, const char *dest) {
-    spec->tgt_dft_replacements[source] = dest;
-}
+// state of an implementation of speculative decoding
+//
+// each implementation has a unique type and a state that is implementation-specific
+// in a subclass of common_speculative_state
+struct common_speculative_state {
+    const enum common_speculative_type type;
+
+    size_t drafts_call_count       = 0; // number of times this implementation was called.
+    size_t drafts_generated_count  = 0; // number of times a draft or part was generated by this implementation.
+    size_t drafts_accepted_count   = 0; // number of times a draft or part was accepted by the target model.
+    size_t drafts_generated_tokens = 0; // number of tokens generated by this implementation.
+    size_t drafts_accepted_tokens  = 0; // number of tokens accepted by the target model.
+
+    // TODO: track performance of most recent calls
+    const bool gen_perf = true; // whether to generate performance stats.
+
+    int64_t gen_duration_us = 0; // total time spent in this implementation in microseconds.
+
+    common_speculative_state(enum common_speculative_type type) : type(type) {}
+
+    virtual ~common_speculative_state() = default;
+
+    virtual void begin(const llama_tokens & prompt) = 0;
+
+    virtual void draft(
+            const common_params_speculative & params,
+            const llama_tokens & prompt_tgt,
+            llama_token id_last,
+            llama_tokens & result) = 0;
+
+    virtual void accept(uint16_t n_accepted) = 0;
+};
+
+struct common_speculative_state_draft : public common_speculative_state {
+    llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
+    llama_context * ctx_dft;
+
+    common_sampler * smpl;
+
+    llama_batch  batch;
+    llama_tokens prompt_dft;
 
-static std::string replace_to_dft(
-        struct common_speculative * spec,
-        const std::string& input) {
-    std::string result = input;
-    for (const auto & pair : spec->tgt_dft_replacements) {
-        size_t pos = result.find(pair.first);
-        while (pos != std::string::npos) {
-            result.replace(pos, pair.first.length(), pair.second);
-            pos = result.find(pair.first, pos + pair.second.length());
+    bool vocab_cmpt = true; // whether retokenization is needed
+    std::unordered_map<std::string, std::string> vocab_map;
+
+    common_speculative_state_draft(
+            enum common_speculative_type type,
+            llama_context * ctx_tgt,
+            llama_context * ctx_dft,
+            const std::vector<std::pair<std::string, std::string>> & replacements)
+        : common_speculative_state(type)
+        , ctx_tgt(ctx_tgt)
+        , ctx_dft(ctx_dft)
+    {
+        batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
+        smpl = nullptr;
+
+        // TODO: optimize or pass from outside?
+        // {
+        //     common_params_sampling params;
+        //     params.no_perf = false;
+        //
+        //     params.top_k = 40;
+        //     params.top_p = 0.9;
+        //
+        //     params.samplers = {
+        //         COMMON_SAMPLER_TYPE_TOP_K,
+        //         COMMON_SAMPLER_TYPE_TOP_P,
+        //         COMMON_SAMPLER_TYPE_INFILL,
+        //     };
+        //
+        //     result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
+        // }
+        {
+            common_params_sampling params;
+            params.no_perf = false;
+            params.top_k = 10;
+            params.samplers = {
+                COMMON_SAMPLER_TYPE_TOP_K,
+            };
+
+            smpl = common_sampler_init(llama_get_model(ctx_dft), params);
         }
-    }
-    return result;
-}
 
-static std::string replace_to_tgt(
-        struct common_speculative * spec,
-        const std::string& input) {
-    std::string result = input;
-    for (const auto& pair : spec->tgt_dft_replacements) {
-        size_t pos = result.find(pair.second);
-        while (pos != std::string::npos) {
-            result.replace(pos, pair.second.length(), pair.first);
-            pos = result.find(pair.second, pos + pair.first.length());
+        vocab_cmpt = common_speculative_are_compatible(llama_get_model(ctx_tgt), llama_get_model(ctx_dft));
+        LOG_DBG("vocab_cmpt = %d\n", vocab_cmpt);
+
+        if (!vocab_cmpt) {
+            LOG_WRN("the target and draft vocabs are not compatible - tokens will be translated between the two\n");
+
+            for (const auto & pair : replacements) {
+                vocab_map[pair.first] = pair.second;
+            }
         }
     }
-    return result;
-}
 
+    ~common_speculative_state_draft() override {
+        llama_perf_context_print(ctx_dft);
 
-llama_tokens common_speculative_gen_draft(
-        struct common_speculative * spec,
-        struct common_speculative_params params,
-        const llama_tokens & prompt_tgt_main_model, // specified in target model vocab
-        llama_token id_last) {
-    auto & batch  = spec->batch;
-    auto & ctx_tgt = spec->ctx_tgt;
-    auto & ctx_dft = spec->ctx_dft;
-    auto & smpl   = spec->smpl;
-    auto & prompt_dft = spec->prompt_dft;
-
-    auto * mem_dft = llama_get_memory(ctx_dft);
-
-    int reuse_i = 0;
-    int reuse_n = 0;
-
-    const int n_ctx = llama_n_ctx(ctx_dft) - params.n_draft;
-
-    llama_tokens prompt_tgt_draft_model;
-    if (!spec->vocab_dft_compatible) {
-        std::string text;
-        text = common_detokenize(ctx_tgt, prompt_tgt_main_model, true);
-        text = replace_to_dft(spec, text);
-        LOG_DBG("%s: main->draft detokenized string: '%s'\n", __func__, text.c_str());
-        prompt_tgt_draft_model = common_tokenize(ctx_dft, text, false, true);
-
-        // convert id_last to draft vocab. llama_detokenize is called directly to avoid an allocation
-        const auto * model_tgt = llama_get_model(ctx_tgt);
-        const auto * vocab_tgt = llama_model_get_vocab(model_tgt);
-
-        int32_t n_chars = llama_detokenize(vocab_tgt, &id_last, 1, nullptr, 0, false, false);
-        GGML_ASSERT(n_chars < 0 && "failed to detokenize id_last");
-        text.resize(-n_chars);
-        llama_detokenize(vocab_tgt, &id_last, 1, text.data(), text.size(), false, false);
-        text = replace_to_dft(spec, text);
-
-        LOG_DBG("main->draft detokenized id_last(%d): '%s'\n", id_last, text.c_str());
-        id_last = common_tokenize(ctx_dft, text, false, true)[0];
-    }
-    // prompt_tgt's tokens will always be compatible with ctx_dft
-    const llama_tokens &prompt_tgt =
-        spec->vocab_dft_compatible ? prompt_tgt_main_model : prompt_tgt_draft_model;
-
-    const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);
-
-    // reuse as much as possible from the old draft context
-    // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
-    for (int i = 0; i < (int) prompt_dft.size(); ++i) {
-        int cur = 0;
-        while (i_start + cur < (int) prompt_tgt.size() &&
-               i       + cur < (int) prompt_dft.size() &&
-               prompt_tgt[i_start + cur] == prompt_dft[i + cur]) {
-            cur++;
+        llama_free(ctx_dft);
+
+        common_sampler_free(smpl);
+
+        llama_batch_free(batch);
+    }
+
+    void begin(const llama_tokens & prompt) override {
+        GGML_UNUSED(prompt);
+    }
+
+    void draft(
+            const common_params_speculative & params,
+            const llama_tokens & prompt_tgt,
+            llama_token id_last,
+            llama_tokens & result) override {
+        auto * spec = this;
+
+        auto & batch      = spec->batch;
+        auto & ctx_tgt    = spec->ctx_tgt;
+        auto & ctx_dft    = spec->ctx_dft;
+        auto & smpl       = spec->smpl;
+        auto & prompt_dft = spec->prompt_dft;
+
+        auto * mem_dft = llama_get_memory(ctx_dft);
+
+        int reuse_i = 0;
+        int reuse_n = 0;
+
+        const int n_ctx = llama_n_ctx(ctx_dft) - params.n_max;
+
+        llama_tokens prompt_cnv;
+        if (!spec->vocab_cmpt) {
+            std::string text;
+
+            text = common_detokenize(ctx_tgt, prompt_tgt, true);
+            text = replace_to_dft(text);
+
+            LOG_DBG("%s: main->draft detokenized string: '%s'\n", __func__, text.c_str());
+
+            prompt_cnv = common_tokenize(ctx_dft, text, false, true);
+
+            // convert id_last to draft vocab. llama_detokenize is called directly to avoid an allocation
+            const auto * model_tgt = llama_get_model(ctx_tgt);
+            const auto * vocab_tgt = llama_model_get_vocab(model_tgt);
+
+            int32_t n_chars = llama_detokenize(vocab_tgt, &id_last, 1, nullptr, 0, false, false);
+            GGML_ASSERT(n_chars < 0 && "failed to detokenize id_last");
+
+            text.resize(-n_chars);
+            llama_detokenize(vocab_tgt, &id_last, 1, text.data(), text.size(), false, false);
+            text = replace_to_dft(text);
+
+            LOG_DBG("main->draft detokenized id_last(%d): '%s'\n", id_last, text.c_str());
+            id_last = common_tokenize(ctx_dft, text, false, true)[0];
         }
 
-        if ((cur >= params.n_reuse || n_ctx >= (int) prompt_tgt.size()) && cur > reuse_n) {
-            reuse_i = i;
-            reuse_n = cur;
+        const llama_tokens & prompt_cur = spec->vocab_cmpt ? prompt_tgt : prompt_cnv;
+
+        const int i_start = std::max<int>(0, (int) prompt_cur.size() - n_ctx);
+
+        // reuse as much as possible from the old draft context
+        // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
+        for (int i = 0; i < (int) prompt_dft.size(); ++i) {
+            int cur = 0;
+            while (i_start + cur < (int) prompt_cur.size() &&
+                    i       + cur < (int) prompt_dft.size() &&
+                    prompt_cur[i_start + cur] == prompt_dft[i + cur]) {
+                cur++;
+            }
+
+            if ((cur >= 256 || n_ctx >= (int) prompt_cur.size()) && cur > reuse_n) {
+                reuse_i = i;
+                reuse_n = cur;
+            }
         }
-    }
 
-    LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt_dft.size());
+        LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt_dft.size());
 
-    llama_tokens result;
-    result.reserve(params.n_draft);
-
-    if (reuse_n == 0) {
-        llama_memory_clear(mem_dft, false);
-        prompt_dft.clear();
-    } else {
-        // this happens when a previous draft has been discarded (for example, due to being too small), but the
-        // target model agreed with it. in this case, we simply pass back the previous results to save compute
-        if (reuse_i + reuse_n < (int) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) {
-            for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) {
-                result.push_back(prompt_dft[i]);
-
-                if (params.n_draft <= (int) result.size()) {
-                    break;
+        result.clear();
+        result.reserve(params.n_max);
+
+        if (reuse_n == 0) {
+            llama_memory_clear(mem_dft, false);
+            prompt_dft.clear();
+        } else {
+            // this happens when a previous draft has been discarded (for example, due to being too small), but the
+            // target model agreed with it. in this case, we simply pass back the previous results to save compute
+            if (reuse_i + reuse_n < (int) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) {
+                for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) {
+                    result.push_back(prompt_dft[i]);
+
+                    if (params.n_max <= (int) result.size()) {
+                        break;
+                    }
                 }
+
+                return;
+            }
+
+            if (reuse_i > 0) {
+                llama_memory_seq_rm (mem_dft, 0, 0, reuse_i);
+                llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i);
+
+                prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i);
             }
 
-            return result;
+            if (reuse_n < (int) prompt_dft.size()) {
+                llama_memory_seq_rm (mem_dft, 0, reuse_n, -1);
+                prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end());
+            }
         }
 
-        if (reuse_i > 0) {
-            llama_memory_seq_rm (mem_dft, 0, 0, reuse_i);
-            llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i);
+        // prepare a batch to evaluate any new tokens in the prompt
+        common_batch_clear(batch);
+
+        for (size_t i = i_start + reuse_n; i < prompt_cur.size(); ++i) {
+            //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_cur[i]);
+            common_batch_add(batch, prompt_cur[i], i - i_start, { 0 }, false);
 
-            prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i);
+            prompt_dft.push_back(prompt_cur[i]);
         }
 
-        if (reuse_n < (int) prompt_dft.size()) {
-            llama_memory_seq_rm (mem_dft, 0, reuse_n, -1);
-            prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end());
+        // we should rarely end-up here during normal decoding
+        if (batch.n_tokens > 0) {
+            //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
+
+            llama_decode(ctx_dft, batch);
+        }
+
+        const llama_pos n_past = prompt_dft.size();
+
+        LOG_DBG("%s: n_past = %d\n", __func__, n_past);
+
+        common_batch_clear(batch);
+        common_batch_add  (batch, id_last, n_past, { 0 }, true);
+
+        prompt_dft.push_back(id_last);
+
+        LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str());
+
+        llama_decode(ctx_dft, batch);
+
+        common_sampler_reset(smpl);
+
+        // sample n_draft tokens from the draft model
+        for (int i = 0; i < params.n_max; ++i) {
+            common_batch_clear(batch);
+
+            common_sampler_sample(smpl, ctx_dft, 0, true);
+
+            const auto * cur_p = common_sampler_get_candidates(smpl, true);
+
+            for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
+                LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
+                        k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
+            }
+
+            // add drafted token for each sequence
+            const llama_token id = cur_p->data[0].id;
+
+            common_sampler_accept(smpl, id, true);
+
+            result.push_back(id);
+
+            if (params.n_max <= (int) result.size()) {
+                break;
+            }
+
+            // only collect very high-confidence draft tokens
+            if (cur_p->data[0].p < params.p_min) {
+                break;
+            }
+
+            common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
+
+            // evaluate the drafted tokens on the draft model
+            llama_decode(ctx_dft, batch);
+
+            prompt_dft.push_back(id);
+        }
+
+        if (!spec->vocab_cmpt) {
+            std::string detokenized = common_detokenize(ctx_dft, result, true);
+            detokenized = replace_to_tgt(detokenized);
+            LOG_DBG("draft->main detokenized string: '%s'\n", detokenized.c_str());
+            result = common_tokenize(ctx_tgt, detokenized, false, true);
+            if (result.size() > (size_t)params.n_max) {
+                result.resize(params.n_max);
+            }
         }
     }
 
-    // prepare a batch to evaluate any new tokens in the prompt
-    common_batch_clear(batch);
+    void accept(uint16_t n_accepted) override {
+        // noop
+        GGML_UNUSED(n_accepted);
+    }
 
-    for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) {
-        //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
-        common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);
+    std::string replace_to_dft(const std::string & input) const {
+        std::string result = input;
 
-        prompt_dft.push_back(prompt_tgt[i]);
+        for (const auto & pair : this->vocab_map) {
+            size_t pos = result.find(pair.first);
+            while (pos != std::string::npos) {
+                result.replace(pos, pair.first.length(), pair.second);
+                pos = result.find(pair.first, pos + pair.second.length());
+            }
+        }
+
+        return result;
     }
 
-    // we should rarely end-up here during normal decoding
-    if (batch.n_tokens > 0) {
-        //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
+    std::string replace_to_tgt(const std::string & input) const {
+        std::string result = input;
 
-        llama_decode(ctx_dft, batch);
+        for (const auto & pair : this->vocab_map) {
+            size_t pos = result.find(pair.second);
+            while (pos != std::string::npos) {
+                result.replace(pos, pair.second.length(), pair.first);
+                pos = result.find(pair.second, pos + pair.first.length());
+            }
+        }
+
+        return result;
+    }
+};
+
+struct common_speculative_state_eagle3 : public common_speculative_state {
+    common_speculative_state_eagle3(enum common_speculative_type type) : common_speculative_state(type) {}
+
+    void begin(const llama_tokens & prompt) override {
+        GGML_UNUSED(prompt);
+    }
+
+    void draft(
+            const common_params_speculative & params,
+            const llama_tokens & prompt_tgt,
+            llama_token id_last,
+            llama_tokens & draft_tokens) override {
+        // TODO: implement
+        GGML_UNUSED(params);
+        GGML_UNUSED(prompt_tgt);
+        GGML_UNUSED(id_last);
+        GGML_UNUSED(draft_tokens);
     }
 
-    const llama_pos n_past = prompt_dft.size();
+    void accept(uint16_t n_accepted) override {
+        // noop
+        GGML_UNUSED(n_accepted);
+    }
+};
 
-    LOG_DBG("%s: n_past = %d\n", __func__, n_past);
+// state of self-speculation (simple implementation, not ngram-map)
+struct common_speculative_state_ngram_simple : public common_speculative_state {
+    common_ngram_simple_state state;
 
-    common_batch_clear(batch);
-    common_batch_add  (batch, id_last, n_past, { 0 }, true);
+    common_speculative_state_ngram_simple(
+            enum common_speculative_type type,
+            common_ngram_simple_state state)
+        : common_speculative_state(type), state(state) {}
 
-    prompt_dft.push_back(id_last);
+    void begin(const llama_tokens & prompt) override {
+        GGML_UNUSED(prompt);
+    }
 
-    LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str());
+    void draft(
+            const common_params_speculative & params,
+            const llama_tokens & prompt_tgt,
+            llama_token id_last,
+            llama_tokens & result) override {
+        result = common_ngram_simple_draft(state, prompt_tgt, id_last);
+        GGML_UNUSED(params);
+    }
 
-    llama_decode(ctx_dft, batch);
+    void accept(uint16_t n_accepted) override {
+        // noop
+        GGML_UNUSED(n_accepted);
+    }
+};
 
-    common_sampler_reset(smpl);
+struct common_speculative_state_ngram_map_k : public common_speculative_state {
+    // draft ngram map for speculative decoding without draft model
+    common_ngram_map map;
 
-    // sample n_draft tokens from the draft model
-    for (int i = 0; i < params.n_draft; ++i) {
-        common_batch_clear(batch);
+    common_speculative_state_ngram_map_k(
+            enum common_speculative_type type,
+            common_ngram_map map)
+        : common_speculative_state(type), map(std::move(map)) {}
 
-        common_sampler_sample(smpl, ctx_dft, 0, true);
+    void begin(const llama_tokens & prompt) override {
+        GGML_UNUSED(prompt);
+    }
+
+    void draft(
+            const common_params_speculative & params,
+            const llama_tokens & prompt_tgt,
+            llama_token id_last,
+            llama_tokens & result) override {
+        common_ngram_map_draft(map, prompt_tgt, id_last, result);
+        GGML_UNUSED(params);
+    }
 
-        const auto * cur_p = common_sampler_get_candidates(smpl, true);
+    void accept(uint16_t n_accepted) override {
+        common_ngram_map_accept(map, n_accepted);
+    }
+};
 
-        for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
-            LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
-                    k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
+struct common_speculative_state_ngram_cache : public common_speculative_state {
+    uint16_t n_draft;
+    bool save_dynamic;
+    bool save_static;
+
+    common_ngram_cache ngram_cache_context;
+    common_ngram_cache ngram_cache_dynamic;
+    common_ngram_cache ngram_cache_static;
+
+    size_t cache_size = 0; // number of tokens in n-gram cache
+
+    common_speculative_state_ngram_cache(
+            const enum common_speculative_type type,
+            const std::string & path_static,
+            const std::string & path_dynamic,
+            uint16_t            n_draft,
+            bool                save_dynamic,
+            bool                save_static)
+        : common_speculative_state(type)
+        , n_draft(n_draft)
+        , save_dynamic(save_dynamic)
+        , save_static(save_static)
+    {
+        if (!path_static.empty()) {
+            try {
+                ngram_cache_static = common_ngram_cache_load(path_static);
+            } catch (...) {
+                LOG_ERR("failed to open static lookup cache: %s", path_static.c_str());
+                GGML_ABORT("Couldn't read static lookup cache");
+            }
         }
 
-        // add drafted token for each sequence
-        const llama_token id = cur_p->data[0].id;
+        if (!path_dynamic.empty()) {
+            try {
+                ngram_cache_dynamic = common_ngram_cache_load(path_dynamic);
+            } catch (...) {
+                LOG_ERR("failed to open dynamic lookup cache: %s", path_dynamic.c_str());
+                GGML_ABORT("Couldn't read dynamic lookup cache");
+            }
+        }
+    }
 
-        common_sampler_accept(smpl, id, true);
+    void begin(const llama_tokens & prompt) override {
+        GGML_UNUSED(prompt);
+    }
 
-        result.push_back(id);
+    void draft(
+            const common_params_speculative & params,
+            const llama_tokens & prompt_tgt,
+            llama_token id_last,
+            llama_tokens & result) override {
+        GGML_UNUSED(params);
+
+        if (cache_size < prompt_tgt.size() + 1) {
+            llama_tokens tokens_new;
+            tokens_new.reserve(prompt_tgt.size() + 1 - cache_size);
+            for (size_t j = cache_size; j < prompt_tgt.size(); ++j) {
+                tokens_new.push_back(prompt_tgt[j]);
+            }
+            tokens_new.push_back(id_last); // add the last token
 
-        if (params.n_draft <= (int) result.size()) {
-            break;
+            // Update context ngram cache with new prompt_tgt:
+            common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX,
+                    tokens_new, tokens_new.size(), false);
+            cache_size = prompt_tgt.size() + 1;
         }
 
-        // only collect very high-confidence draft tokens
-        if (cur_p->data[0].p < params.p_min) {
-            break;
+        llama_tokens inp;
+        inp.reserve(prompt_tgt.size() + 1);
+        for (size_t j = 0; j < prompt_tgt.size(); ++j) {
+            inp.push_back(prompt_tgt[j]);
         }
+        inp.push_back(id_last);
 
-        common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
+        result.push_back(id_last);
 
-        // evaluate the drafted tokens on the draft model
-        llama_decode(ctx_dft, batch);
+        common_ngram_cache_draft(inp, result, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX,
+                ngram_cache_context,
+                ngram_cache_dynamic,
+                ngram_cache_static);
+
+        if (result.size() > 0) {
+            // delete first token in result (which is the id_last token)
+            result.erase(result.begin());
+        }
+    }
 
-        prompt_dft.push_back(id);
+    void accept(uint16_t n_accepted) override {
+        // TODO: noop
+        GGML_UNUSED(n_accepted);
     }
+};
+
+struct common_speculative {
+    std::vector<std::unique_ptr<common_speculative_state>> impls; // list of implementations to use and their states
+    common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats)
+};
+
+static common_ngram_map get_common_ngram_map(const common_speculative_config & config) {
+    uint16_t size_key   = config.params.ngram_size_n;
+    uint16_t size_value = config.params.ngram_size_m;
+    bool     key_only   = (config.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K);
+    uint16_t check_rate = config.params.ngram_check_rate;
+    uint16_t min_hits   = config.params.ngram_min_hits;
+
+    return common_ngram_map(size_key, size_value, key_only, check_rate, min_hits);
+}
+
+static common_speculative_state_ngram_cache create_state_ngram_cache(
+        const std::string & path_static, const std::string & path_dynamic,
+        const common_speculative_config & config) {
+    uint16_t n_draft = 8; // TODO get from config?
 
-    if (!spec->vocab_dft_compatible) {
-        std::string detokenized = common_detokenize(ctx_dft, result, true);
-        detokenized = replace_to_tgt(spec, detokenized);
-        LOG_DBG("draft->main detokenized string: '%s'\n", detokenized.c_str());
-        result = common_tokenize(ctx_tgt, detokenized, false, true);
-        if (result.size() > (size_t)params.n_draft) {
-            result.resize(params.n_draft);
+    // TODO bool param in common/common.h to set save_static/save_dynamic?
+    bool save_static = false;
+    bool save_dynamic = false;
+
+    common_speculative_state_ngram_cache state(config.type, path_static, path_dynamic, n_draft, save_static, save_dynamic);
+
+    return state;
+}
+
+std::string common_speculative_type_name_str() {
+    std::string result;
+    for (size_t i = 0; i < common_speculative_types.size(); i++) {
+        if (i > 0) {
+            result += ", ";
         }
+        result += common_speculative_type_to_str(common_speculative_types[i]);
     }
     return result;
 }
+
+std::string common_speculative_type_to_str(enum common_speculative_type type) {
+    switch (type) {
+        case COMMON_SPECULATIVE_TYPE_NONE:          return "none";
+        case COMMON_SPECULATIVE_TYPE_DRAFT:         return "draft";
+        case COMMON_SPECULATIVE_TYPE_EAGLE3:        return "eagle3";
+        case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE:  return "ngram_simple";
+        case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K:   return "ngram_map_k";
+        case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram_map_k4v";
+        case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE:   return "ngram_cache";
+        default:                                    return "unknown";
+    }
+}
+
+enum common_speculative_type common_speculative_type_from_name(const std::string & name) {
+    const auto it = common_speculative_type_from_name_map.find(name);
+    if (it == common_speculative_type_from_name_map.end()) {
+        return COMMON_SPECULATIVE_TYPE_COUNT;
+    }
+    return it->second;
+}
+
+// initialization of the speculative decoding system
+//
+common_speculative * common_speculative_init(
+        const common_params_speculative & params,
+              llama_context             * ctx_tgt) {
+    llama_context * ctx_dft = nullptr;
+    if (params.model_dft) {
+        ctx_dft = llama_init_from_model(params.model_dft, params.cparams_dft);
+        if (ctx_dft == nullptr) {
+            LOG_ERR("%s", "failed to create draft context\n");
+            return nullptr;
+        }
+    }
+
+    // Compute the implementations to use based on the config and their order of preference
+    std::vector<common_speculative_config> configs = {}; // list of speculative configs to try
+    {
+        bool has_draft = !params.mparams_dft.path.empty();
+        bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3
+
+        bool has_ngram_cache   = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE);
+        bool has_ngram_simple  = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE);
+        bool has_ngram_map_k   = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K);
+        bool has_ngram_map_k4v = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V);
+
+        // In a more complex implementation we could use the same implementation but with different parameters.
+        // This was initially used in PR-18471 but removed to simplify the code.
+        if (has_ngram_simple) {
+            // This implementation can guess a lot of tokens without any draft model.
+            configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, params));
+        }
+        if (has_ngram_map_k) {
+            configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, params));
+        }
+        if (has_ngram_map_k4v) {
+            // This implementation can guess tokens with high acceptance rate but is more expensive.
+            configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, params));
+        }
+        if (has_ngram_cache) {
+            configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, params));
+        }
+        if (has_draft) {
+            configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT, params));
+        }
+        if (has_draft_eagle3) {
+            configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_EAGLE3, params));
+        }
+    }
+
+    std::vector<std::unique_ptr<common_speculative_state>> impls = {};
+
+    for (const common_speculative_config & config : configs) {
+        LOG_DBG("%s: adding implementation %s\n", __func__, common_speculative_type_to_str(config.type).c_str());
+        switch (config.type) {
+            case COMMON_SPECULATIVE_TYPE_NONE:
+                break;
+            case COMMON_SPECULATIVE_TYPE_DRAFT: {
+                impls.push_back(std::make_unique<common_speculative_state_draft>(config.type,
+                    /* .ctx_tgt      = */ ctx_tgt,
+                    /* .ctx_dft      = */ ctx_dft,
+                    /* .replacements = */ params.replacements
+                ));
+                break;
+            }
+            case COMMON_SPECULATIVE_TYPE_EAGLE3: {
+                impls.push_back(std::make_unique<common_speculative_state_eagle3>(config.type));
+                break;
+            }
+            case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: {
+                common_ngram_map ngram_map = get_common_ngram_map(config);
+
+                uint16_t ngram_size_key   = ngram_map.size_key;
+                uint16_t mgram_size_value = ngram_map.size_value;
+                uint16_t check_rate       = ngram_map.check_rate;
+
+                auto config_simple = common_ngram_simple_config{
+                    /* .size_ngram      = */ ngram_size_key,
+                    /* .size_mgram      = */ mgram_size_value,
+                    /* .check_rate      = */ check_rate
+                };
+                auto state = std::make_unique<common_speculative_state_ngram_simple>(
+                    /* .type            = */ config.type,
+                    /* .state           = */ common_ngram_simple_state(config_simple)
+                );
+                impls.push_back(std::move(state));
+                break;
+            }
+            case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K:
+            case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: {
+                impls.push_back(std::make_unique<common_speculative_state_ngram_map_k>(
+                    (config.type),
+                    get_common_ngram_map(config)
+                ));
+                break;
+            }
+            case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: {
+                auto state = create_state_ngram_cache(
+                        params.lookup_cache_static, params.lookup_cache_dynamic, config);
+                impls.push_back(std::make_unique<common_speculative_state_ngram_cache>(state));
+                break;
+            }
+            default:
+                break;
+        }
+    }
+
+    if (impls.empty()) {
+        LOG_WRN("%s", "no implementations specified for speculative decoding\n");
+        return nullptr;
+    }
+
+    auto * result = new common_speculative {
+        /* .impls = */ std::move(impls)
+    };
+
+    return result;
+}
+
+void common_speculative_free(common_speculative * spec) {
+    if (spec == nullptr) {
+        return;
+    }
+
+    delete spec;
+}
+
+void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt) {
+    if (spec == nullptr) {
+        return;
+    }
+
+    for (auto & impl : spec->impls) {
+        impl->begin(prompt);
+    }
+}
+
+llama_tokens common_speculative_draft(
+        common_speculative * spec,
+        const common_params_speculative & params,
+        const llama_tokens & prompt_tgt, // specified in target model vocab
+        llama_token id_last) {
+    llama_tokens result;
+
+    spec->curr_impl = nullptr; // reset current implementation
+
+    for (auto & impl : spec->impls) {
+        {
+            const int64_t t_start_us = impl->gen_perf ? ggml_time_us() : 0;
+
+            impl->draft(params, prompt_tgt, id_last, result);
+
+            const int64_t t_now_us = impl->gen_perf ? ggml_time_us() : 0;
+
+            impl->drafts_call_count++;
+            impl->gen_duration_us += t_now_us - t_start_us; // accumulate duration for this implementation
+        }
+
+        if (!result.empty()) {
+            LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__,
+                    common_speculative_type_to_str(impl.get()->type).c_str(),
+                    prompt_tgt.size(),
+                    impl.get()->drafts_call_count, result.size());
+
+            spec->curr_impl = impl.get(); // set current implementation for stats
+            impl->drafts_generated_count++;
+            impl->drafts_generated_tokens += result.size();
+
+            break; // We have a draft, so break out of the loop and return it.
+        }
+    }
+
+    return result;
+}
+
+void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) {
+    if (n_accepted == 0) {
+        return;
+    }
+
+    common_speculative_state * impl = spec->curr_impl;
+
+    GGML_ASSERT(impl);
+
+    if (n_accepted > 0) {
+        impl->drafts_accepted_count++;
+        impl->drafts_accepted_tokens += n_accepted;
+    }
+
+    impl->accept(n_accepted);
+}
+
+void common_speculative_print_stats(const common_speculative * spec) {
+    if (spec == nullptr) {
+        return;
+    }
+
+    for (const auto & impl : spec->impls) {
+        std::string str_perf;
+        if (impl->gen_perf) {
+            std::ostringstream oss;
+            oss << std::fixed << std::setprecision(3) << impl->gen_duration_us / 1000.0;
+            str_perf = ", dur = " + oss.str() + " ms";
+        } else {
+            str_perf = "";
+        }
+
+        LOG_INF("statistics %s: #calls = %zu, #gen drafts = %zu, #acc drafts = %zu, #gen tokens = %zu, #acc tokens = %zu%s\n",
+                common_speculative_type_to_str(impl->type).c_str(),
+                impl->drafts_call_count,
+                impl->drafts_generated_count,
+                impl->drafts_accepted_count,
+                impl->drafts_generated_tokens,
+                impl->drafts_accepted_tokens,
+                str_perf.c_str());
+    }
+}
index e69d7aaa1eb00b06669f17d17ced98803f26e2ca..9e1888e4be0b3aec53c3ebf5ccdf5ea7fee0c1b3 100644 (file)
@@ -5,31 +5,33 @@
 
 struct common_speculative;
 
-struct common_speculative_params {
-    int n_draft = 16;  // max drafted tokens
-    int n_reuse = 256;
+// comma separated list of all types
+std::string common_speculative_type_name_str();
 
-    float p_min = 0.75f; // min probability required to accept a token in the draft
-};
+// convert string to type
+enum common_speculative_type common_speculative_type_from_name(const std::string & name);
 
-struct common_speculative * common_speculative_init(
-        struct llama_context * ctx_tgt,
-        struct llama_context * ctx_dft
-);
+// convert type to string
+std::string common_speculative_type_to_str(enum common_speculative_type type);
 
-void common_speculative_free(struct common_speculative * spec);
+common_speculative * common_speculative_init(
+        const common_params_speculative & params,
+              llama_context             * ctx_tgt);
 
-bool common_speculative_are_compatible(
-        const struct llama_context * ctx_tgt,
-        const struct llama_context * ctx_dft);
+void common_speculative_free(common_speculative * spec);
 
-void common_speculative_add_replacement_tgt_dft(
-        struct common_speculative * spec,
-        const char *source, const char *dest);
+// optionally call once at the beginning of a new generation
+void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt);
 
 // sample up to n_draft tokens and add them to the batch using the draft model
-llama_tokens common_speculative_gen_draft(
-               struct common_speculative * spec,
-        struct common_speculative_params   params,
-                      const llama_tokens & prompt,
-                             llama_token   id_last);
+llama_tokens common_speculative_draft(
+                     common_speculative * spec,
+        const common_params_speculative & params,
+                     const llama_tokens & prompt,
+                            llama_token   id_last);
+
+// informs the speculative decoder that n_accepted tokens were accepted by the target model
+void common_speculative_accept(common_speculative * spec, uint16_t n_accepted);
+
+// print statistics about the speculative decoding
+void common_speculative_print_stats(const common_speculative * spec);
diff --git a/docs/speculative.md b/docs/speculative.md
new file mode 100644 (file)
index 0000000..8281eaa
--- /dev/null
@@ -0,0 +1,120 @@
+# Speculative Decoding
+
+llama.cpp supports speculative decoding, a technique that can significantly accelerate token generation by predicting multiple tokens ahead of the main model.
+
+[Speculative decoding](https://en.wikipedia.org/wiki/Transformer_(deep_learning)#Speculative_decoding) leverages the fact that computing n tokens in a batch (as in prompt processing) is more efficient than computing n sequentially (as in response generation). By generating draft tokens quickly and then verifying them with the target model in a single batch, this approach can achieve substantial speedups when the draft predictions are frequently correct.
+
+## Implementations
+
+The `llama-server` application supports several implementations of speculative decoding:
+
+### Draft Model (`draft`)
+
+A much smaller model (called the _draft model_) generates drafts.
+A draft model is the most used approach in speculative decoding.
+
+### n-gram Cache (`ngram-cache`)
+
+An n-gram is a sequence of n tokens. The n-gram cache implementation maintains statistics about short n-gram sequences.
+A draft is computed using probabilities derived from these statistics. External statistics can also be loaded from files for improved accuracy.
+
+See:
+
+- #5479, #6828, #6848
+
+### n-gram Map (`ngram-simple`, `ngram-map-*`)
+
+These implementations search the token history for patterns and use matching sequences as draft candidates.
+They require no additional model but rely on patterns that have already appeared in the generated text.
+An example to use this approach can be the rewriting of source code by a LLM.
+
+#### n-gram Map (`ngram-simple`)
+
+This implementation looks for the last n-gram in history that matches the current n-gram and creates a draft using the m tokens following the matched n-gram. It is the simplest self-speculative approach with minimal overhead.
+
+#### n-gram Map Key (`ngram-map-k`)
+
+This implementation looks for the current n-gram of size n (called the _key_) in the token history. If the key n-gram is followed by the same m tokens (called the _mgram_) multiple times, it creates a draft using these m tokens. This approach requires a minimum number of occurrences (argument `--spec-ngram-min-hits`) before generating drafts.
+
+The number of accepted tokens is stored for each used n-gram.
+
+#### n-gram Map Key-4-Values (`ngram-map-k4v`)
+
+This experimental implementation looks for the current n-gram of size n (called the _key_) in the token history. For each key, up to four _values_ (n-grams of size m, called _mgrams_) are tracked. An internal statistic counts the occurrences of each mgram after the key n-gram. If one mgram is significantly more frequent than the others, it is used as the draft.
+
+The number of accepted tokens is stored for each used n-gram.
+
+**Example:** Server options to be used if there are a lot of longer repetitions.
+```bash
+llama-server [...] --spec-type ngram-map-k4v --spec-ngram-size-n 8 --spec-ngram-size-m 8 --spec-ngram-min-hits 2
+```
+
+
+## Command-Line Options
+
+If a draft model is combined with a draftless decoding the draftless decoding has higher precedence.
+
+```
+--spec-type [none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v]
+                                        type of speculative decoding to use when no draft model is provided
+                                        (default: none)
+--spec-ngram-size-n N                   ngram size N for ngram-simple/ngram-map speculative decoding, length
+                                        of lookup n-gram (default: 12)
+--spec-ngram-size-m N                   ngram size M for ngram-simple/ngram-map speculative decoding, length
+                                        of draft m-gram (default: 48)
+--spec-ngram-check-rate N               ngram check rate for ngram-simple/ngram-map speculative decoding
+                                        (default: 1)
+--spec-ngram-min-hits N                 minimum hits for ngram-map speculative decoding (default: 1)
+```
+
+### `--spec-type TYPE`
+
+Specifies a type of speculative decoding without draft model.
+
+| Type | Description |
+|------|-------------|
+| `none` | No speculative decoding (default) |
+| `ngram-cache` | Use n-gram cache lookup |
+| `ngram-simple` | Use simple n-gram pattern matching |
+| `ngram-map-k` | Use n-gram pattern matching with n-gram-keys |
+| `ngram-map-k4v` | Use n-gram pattern matching with n-gram-keys and up to four m-gram values (experimental) |
+
+**Example:** Server-instance used to refactor source code.
+```bash
+./llama-server [...] --spec-type ngram-simple
+```
+
+### `--spec-ngram-size-n N`
+
+Sets the size N of the lookup n-gram for n-gram map based speculative decoding.
+The n-gram size N determines how many tokens in a row to look back when searching for matching patterns.
+
+### `--spec-ngram-size-m M`
+
+Sets the size M of the draft m-gram for n-gram map based speculative decoding.
+The m-gram size determines how many tokens to draft when a match is found.
+Larger values can provide more speedup but may reduce acceptance rate.
+
+### `--spec-ngram-check-rate R`
+
+This option aims at performance if the n-gram lookup in history is to costly. A lookup will be executed at every R tokens (default is 1, every token).
+
+### `--spec-ngram-min-hits H`
+
+This option defines how often a key has to appear in the token history to be used as a draft (default is 1).
+
+## Statistics
+Each speculative decoding implementation prints statistics.
+
+```
+draft acceptance rate = 0.57576 (  171 accepted /   297 generated)
+statistics ngram_simple: #calls = 15, #gen drafts = 5, #acc drafts = 5, #gen tokens = 187, #acc tokens = 73
+statistics draft: #calls = 10, #gen drafts = 10, #acc drafts = 10, #gen tokens = 110, #acc tokens = 98
+```
+
+- `#calls`: number of calls of this implementations
+- `#gen drafts`: number of drafts generated by this implementation
+- `#acc drafts`: number of drafts accepted (partially) by the main model
+- `#gen tokens`: number of tokens generated by this implementation (including rejected tokens)
+- `#acc tokens`: number of tokens accepted by the main model
+
index bb94a8fe06d6a36f4bf0fd672caef49449b20674..f7b6ea1b190b9f1a37ea554498dd17d36a0c9661 100644 (file)
@@ -32,9 +32,9 @@ int main(int argc, char ** argv){
 
     common_ngram_cache ngram_cache;
     common_ngram_cache_update(ngram_cache, LLAMA_NGRAM_STATIC, LLAMA_NGRAM_STATIC, inp, inp.size(), true);
-    fprintf(stderr, "%s: hashing done, writing file to %s\n", __func__, params.lookup_cache_static.c_str());
+    fprintf(stderr, "%s: hashing done, writing file to %s\n", __func__, params.speculative.lookup_cache_static.c_str());
 
-    common_ngram_cache_save(ngram_cache, params.lookup_cache_static);
+    common_ngram_cache_save(ngram_cache, params.speculative.lookup_cache_static);
 
     return 0;
 }
index 135f6fcab9515d4ae427ecd5a42bb9a0d6d4bbc5..ae28b2e6e86d8a3c957976f317371180abf52003 100644 (file)
@@ -46,18 +46,18 @@ int main(int argc, char ** argv){
     {
         const int64_t t_start_draft_us = ggml_time_us();
 
-        if (!params.lookup_cache_static.empty()) {
+        if (!params.speculative.lookup_cache_static.empty()) {
             try {
-                ngram_cache_static = common_ngram_cache_load(params.lookup_cache_static);
+                ngram_cache_static = common_ngram_cache_load(params.speculative.lookup_cache_static);
             } catch (std::ifstream::failure const &) {
-                LOG_ERR("failed to open static lookup cache: %s", params.lookup_cache_static.c_str());
+                LOG_ERR("failed to open static lookup cache: %s", params.speculative.lookup_cache_static.c_str());
                 exit(1);
             }
         }
 
-        if (!params.lookup_cache_dynamic.empty()) {
+        if (!params.speculative.lookup_cache_dynamic.empty()) {
             try {
-                ngram_cache_dynamic = common_ngram_cache_load(params.lookup_cache_dynamic);
+                ngram_cache_dynamic = common_ngram_cache_load(params.speculative.lookup_cache_dynamic);
             } catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program
         }
 
index 27f159940a42317f9a5a46d9473f9236eba328b8..8e73138a5f2421493f82a7ccad165dfd064c5444 100644 (file)
@@ -51,18 +51,18 @@ int main(int argc, char ** argv){
         const int64_t t_start_draft_us = ggml_time_us();
         common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, inp.size(), false);
 
-        if (!params.lookup_cache_static.empty()) {
+        if (!params.speculative.lookup_cache_static.empty()) {
             try {
-                ngram_cache_static = common_ngram_cache_load(params.lookup_cache_static);
+                ngram_cache_static = common_ngram_cache_load(params.speculative.lookup_cache_static);
             } catch (std::ifstream::failure const &) {
-                LOG_ERR("failed to open static lookup cache: %s", params.lookup_cache_static.c_str());
+                LOG_ERR("failed to open static lookup cache: %s", params.speculative.lookup_cache_static.c_str());
                 exit(1);
             }
         }
 
-        if (!params.lookup_cache_dynamic.empty()) {
+        if (!params.speculative.lookup_cache_dynamic.empty()) {
             try {
-                ngram_cache_dynamic = common_ngram_cache_load(params.lookup_cache_dynamic);
+                ngram_cache_dynamic = common_ngram_cache_load(params.speculative.lookup_cache_dynamic);
             } catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program
         }
 
@@ -210,7 +210,7 @@ int main(int argc, char ** argv){
 
     // Update dynamic ngram cache with context ngram cache and save it to disk:
     common_ngram_cache_merge(ngram_cache_dynamic, ngram_cache_context);
-    common_ngram_cache_save(ngram_cache_dynamic, params.lookup_cache_dynamic);
+    common_ngram_cache_save(ngram_cache_dynamic, params.speculative.lookup_cache_dynamic);
 
     LOG("\n\n");
 
index 8141052a22768892305cb1eab493962ccaff3d83..d8b1f5a480cd09a29e24966a868232c0913fe5fb 100644 (file)
@@ -24,7 +24,7 @@ int main(int argc, char ** argv) {
 
     common_init();
 
-    if (params.speculative.model.path.empty()) {
+    if (params.speculative.mparams_dft.path.empty()) {
         LOG_ERR("%s: --model-draft is required\n", __func__);
         return 1;
     }
@@ -34,10 +34,8 @@ int main(int argc, char ** argv) {
     llama_numa_init(params.numa);
 
     llama_model * model_tgt = NULL;
-    //llama_model * model_dft = NULL;
 
     llama_context * ctx_tgt = NULL;
-    llama_context * ctx_dft = NULL;
 
     // load the target model
     auto llama_init_tgt = common_init_from_params(params);
@@ -48,26 +46,38 @@ int main(int argc, char ** argv) {
     const llama_vocab * vocab = llama_model_get_vocab(model_tgt);
 
     // load the draft model
-    params.devices      = params.speculative.devices;
-    params.model        = params.speculative.model;
-    params.n_ctx        = params.speculative.n_ctx;
-    params.n_batch      = params.speculative.n_ctx > 0 ? params.speculative.n_ctx : params.n_batch;
-    params.n_gpu_layers = params.speculative.n_gpu_layers;
-
-    if (params.speculative.cpuparams.n_threads > 0) {
-        params.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
-    }
+    llama_model_ptr model_dft;
+
+    // TODO: simplify this logic
+    {
+        const auto & params_spec = params.speculative;
 
-    params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
-    params.tensor_buft_overrides     = params.speculative.tensor_buft_overrides;
+        auto params_dft = params;
 
-    auto llama_init_dft = common_init_from_params(params);
+        params_dft.n_parallel   = 1;
+        params_dft.n_ctx        = params_spec.n_ctx;
+        params_dft.n_batch      = llama_n_ctx_seq(ctx_tgt);
+        params_dft.devices      = params_spec.devices;
+        params_dft.model        = params_spec.mparams_dft;
+        params_dft.n_gpu_layers = params_spec.n_gpu_layers;
+
+        if (params_spec.cpuparams.n_threads > 0) {
+            params_dft.cpuparams.n_threads       = params.speculative.cpuparams.n_threads;
+            params_dft.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
+        }
 
-    //model_dft = llama_init_dft->model();
-    ctx_dft   = llama_init_dft->context();
+        params_dft.tensor_buft_overrides = params.speculative.tensor_buft_overrides;
+
+        auto mparams_dft = common_model_params_to_llama(params_dft);
+
+        model_dft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft));
+        if (model_dft == nullptr) {
+            LOG_ERR("failed to load draft model, '%s'\n", params_dft.model.path.c_str());
+            return 1;
+        }
 
-    if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) {
-        LOG_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params.speculative.model.path.c_str(), params.model.path.c_str());
+        params.speculative.model_dft = model_dft.get();
+        params.speculative.cparams_dft = common_context_params_to_llama(params_dft);
     }
 
     // Tokenize the prompt
@@ -92,12 +102,6 @@ int main(int argc, char ** argv) {
         LOG("%s", common_token_to_piece(ctx_tgt, id).c_str());
     }
 
-    // how many tokens to draft each time
-    int n_draft     = params.speculative.n_max;
-    int n_draft_min = params.speculative.n_min;
-
-    float p_min = params.speculative.p_min;
-
     int n_predict = 0;
     int n_drafted = 0;
     int n_accept  = 0;
@@ -127,15 +131,11 @@ int main(int argc, char ** argv) {
     int n_past = inp.size() - 1;
 
     // init the speculator
-    struct common_speculative_params params_spec;
-    params_spec.n_draft = n_draft;
-    params_spec.n_reuse = llama_n_ctx(ctx_dft) - n_draft;
-    params_spec.p_min   = p_min;
-
-    struct common_speculative * spec = common_speculative_init(ctx_tgt, ctx_dft);
-    for (auto &pair : params.speculative.replacements) {
-        common_speculative_add_replacement_tgt_dft(spec, pair.first.c_str(), pair.second.c_str());
-    }
+    const auto & params_spec = params.speculative;
+
+    struct common_speculative * spec = common_speculative_init(params.speculative, ctx_tgt);
+
+    common_speculative_begin(spec, prompt_tgt);
 
     llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
 
@@ -151,7 +151,7 @@ int main(int argc, char ** argv) {
         // offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens
         // from a cache or lookup tables.
         //
-        llama_tokens draft = common_speculative_gen_draft(spec, params_spec, prompt_tgt, id_last);
+        llama_tokens draft = common_speculative_draft(spec, params_spec, prompt_tgt, id_last);
 
         //LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str());
 
@@ -162,7 +162,7 @@ int main(int argc, char ** argv) {
         // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
         {
             // do not waste time on small drafts
-            if (draft.size() < (size_t) n_draft_min) {
+            if (draft.size() < (size_t) params_spec.n_min) {
                 draft.clear();
             }
 
@@ -240,7 +240,7 @@ int main(int argc, char ** argv) {
     LOG_INF("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict  / ((t_dec_end - t_dec_start) / 1e6f));
 
     LOG_INF("\n");
-    LOG_INF("n_draft   = %d\n", n_draft);
+    LOG_INF("n_draft   = %d\n", params_spec.n_max);
     LOG_INF("n_predict = %d\n", n_predict);
     LOG_INF("n_drafted = %d\n", n_drafted);
     LOG_INF("n_accept  = %d\n", n_accept);
@@ -249,8 +249,6 @@ int main(int argc, char ** argv) {
     LOG_INF("\n");
     LOG_INF("draft:\n\n");
 
-    llama_perf_context_print(ctx_dft);
-
     LOG_INF("\n");
     LOG_INF("target:\n\n");
     common_perf_print(ctx_tgt, smpl);
index 89d3249431e853f6a028264eda56f87e62f9cdc2..3e5cf5f46b55a746092b8c47f37af6fee5e26eae 100644 (file)
@@ -46,7 +46,7 @@ int main(int argc, char ** argv) {
 
     common_init();
 
-    if (params.speculative.model.path.empty()) {
+    if (params.speculative.mparams_dft.path.empty()) {
         LOG_ERR("%s: --model-draft is required\n", __func__);
         return 1;
     }
@@ -78,7 +78,7 @@ int main(int argc, char ** argv) {
 
     // load the draft model
     params.devices = params.speculative.devices;
-    params.model = params.speculative.model;
+    params.model = params.speculative.mparams_dft;
     params.n_gpu_layers = params.speculative.n_gpu_layers;
     if (params.speculative.cpuparams.n_threads > 0) {
         params.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
index 73cb4c75b3e738053025786d512eb29f80f6b0ae..1ca4e3cc0e9a48c893477ebe2b995c27bc7cb6e6 100644 (file)
@@ -48,11 +48,8 @@ enum server_state {
 struct server_slot {
     int id;
 
-    llama_batch batch_spec = {};
-
     // TODO: change to unique_ptrs for consistency:
     llama_context * ctx = nullptr;
-    llama_context * ctx_dft = nullptr;
 
     // multimodal
     mtmd_context * mctx = nullptr;
@@ -259,7 +256,7 @@ struct server_slot {
     }
 
     bool can_speculate() const {
-        return ctx_dft;
+        return !!spec;
     }
 
     void add_token(const completion_token_output & token) {
@@ -295,6 +292,7 @@ struct server_slot {
             SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, task->params.speculative.n_min);
             n_draft_max = 0;
         }
+
         return n_draft_max;
     }
 
@@ -397,6 +395,8 @@ struct server_slot {
                     draft_ratio, n_draft_accepted, n_draft_total
             );
         }
+
+        common_speculative_print_stats(spec);
     }
 
     json to_json(bool only_metrics = false) const {
@@ -553,18 +553,13 @@ private:
 
     // note: keep these alive - they determine the lifetime of the model, context, etc.
     common_init_result_ptr llama_init;
-    common_init_result_ptr llama_init_dft;
 
     llama_context * ctx = nullptr;
 
-    bool vocab_dft_compatible = true;
-
-    llama_model * model_dft = nullptr;
-
-    llama_context_params cparams_dft;
-
     llama_batch batch {};
 
+    llama_model_ptr model_dft;
+
     bool add_bos_token  = true;
 
     int32_t n_ctx; // total context for all clients / slots
@@ -597,13 +592,8 @@ private:
 
         // Clear any sampling context
         for (server_slot & slot : slots) {
-            llama_free(slot.ctx_dft);
-            slot.ctx_dft = nullptr;
-
             common_speculative_free(slot.spec);
             slot.spec = nullptr;
-
-            llama_batch_free(slot.batch_spec);
         }
 
         llama_batch_free(batch);
@@ -648,44 +638,39 @@ private:
 
         add_bos_token = llama_vocab_get_add_bos(vocab);
 
-        if (params_base.has_speculative()) {
-            SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str());
+        if (params_base.speculative.has_dft()) {
+            SRV_INF("loading draft model '%s'\n", params_base.speculative.mparams_dft.path.c_str());
+
+            const auto & params_spec = params_base.speculative;
 
             auto params_dft = params_base;
 
-            params_dft.devices      = params_base.speculative.devices;
-            params_dft.model        = params_base.speculative.model;
-            params_dft.n_ctx        = params_base.speculative.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_base.speculative.n_ctx;
-            params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
             params_dft.n_parallel   = 1;
-            params_dft.cache_type_k = params_base.speculative.cache_type_k;
-            params_dft.cache_type_v = params_base.speculative.cache_type_v;
+            params_dft.n_ctx        = params_spec.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_spec.n_ctx;
+            params_dft.n_batch      = llama_n_ctx_seq(ctx);
+            params_dft.devices      = params_spec.devices;
+            params_dft.model        = params_spec.mparams_dft;
+            params_dft.n_gpu_layers = params_spec.n_gpu_layers;
+            params_dft.cache_type_k = params_spec.cache_type_k;
+            params_dft.cache_type_v = params_spec.cache_type_v;
 
-            params_dft.cpuparams.n_threads = params_base.speculative.cpuparams.n_threads;
-            params_dft.cpuparams_batch.n_threads = params_base.speculative.cpuparams_batch.n_threads;
-            params_dft.tensor_buft_overrides = params_base.speculative.tensor_buft_overrides;
+            if (params_spec.cpuparams.n_threads > 0) {
+                params_dft.cpuparams.n_threads       = params_spec.cpuparams.n_threads;
+                params_dft.cpuparams_batch.n_threads = params_spec.cpuparams_batch.n_threads;
+            }
 
-            llama_init_dft = common_init_from_params(params_dft);
+            params_dft.tensor_buft_overrides = params_spec.tensor_buft_overrides;
 
-            model_dft = llama_init_dft->model();
+            auto mparams_dft = common_model_params_to_llama(params_dft);
 
+            model_dft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft));
             if (model_dft == nullptr) {
-                SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.path.c_str());
+                SRV_ERR("failed to load draft model, '%s'\n", params_dft.model.path.c_str());
                 return false;
             }
 
-            vocab_dft_compatible = common_speculative_are_compatible(ctx, llama_init_dft->context());
-            if (!vocab_dft_compatible) {
-                SRV_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str());
-            }
-
-            const int n_ctx_dft = llama_n_ctx(llama_init_dft->context());
-
-            cparams_dft = common_context_params_to_llama(params_dft);
-            cparams_dft.n_batch = n_ctx_dft;
-
-            // the context is not needed - we will create one for each slot
-            llama_init_dft->free_context();
+            params_base.speculative.model_dft = model_dft.get();
+            params_base.speculative.cparams_dft = common_context_params_to_llama(params_dft);
         }
 
         std::string & mmproj_path = params_base.mmproj.path;
@@ -695,6 +680,7 @@ private:
             }
 
             mtmd_context_params mparams = mtmd_context_params_default();
+
             mparams.use_gpu          = params_base.mmproj_use_gpu;
             mparams.print_timings    = false;
             mparams.n_threads        = params_base.cpuparams.n_threads;
@@ -702,6 +688,7 @@ private:
             mparams.warmup           = params_base.warmup;
             mparams.image_min_tokens = params_base.image_min_tokens;
             mparams.image_max_tokens = params_base.image_max_tokens;
+
             mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams);
             if (mctx == nullptr) {
                 SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str());
@@ -718,11 +705,6 @@ private:
                 params_base.n_cache_reuse = 0;
                 SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled");
             }
-
-            if (params_base.has_speculative()) {
-                SRV_ERR("%s\n", "err: speculative decode is not supported by multimodal");
-                return false;
-            }
         }
 
         if (!llama_memory_can_shift(llama_get_memory(ctx))) {
@@ -757,29 +739,24 @@ private:
         for (int i = 0; i < params_base.n_parallel; i++) {
             server_slot slot;
 
-            slot.id = i;
-            slot.ctx = ctx;
+            slot.id    = i;
+            slot.ctx   = ctx;
             slot.n_ctx = n_ctx_slot;
-            slot.mctx = mctx;
-            slot.prompt.tokens.has_mtmd = mctx != nullptr;
-
-            if (model_dft) {
-                slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
 
-                // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK]
-                slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
-                if (slot.ctx_dft == nullptr) {
-                    SRV_ERR("%s", "failed to create draft context\n");
-                    return false;
-                }
+            slot.mctx                   = mctx;
+            slot.prompt.tokens.has_mtmd = mctx != nullptr;
 
-                slot.spec = common_speculative_init(slot.ctx, slot.ctx_dft);
-                if (slot.spec == nullptr) {
-                    SRV_ERR("%s", "failed to create speculator\n");
-                    return false;
-                }
-                for (auto & pair : params_base.speculative.replacements) {
-                    common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str());
+            // try speculative decoding
+            {
+                slot.spec = common_speculative_init(params_base.speculative, slot.ctx);
+                if (slot.spec) {
+                    if (mctx) {
+                        SRV_ERR("%s\n", "speculative decoding is not supported with multimodal");
+                        return false;
+                    }
+                    SRV_WRN("%s", "speculative decoding context initialized\n");
+                } else {
+                    SRV_WRN("%s", "speculative decoding context not initialized\n");
                 }
             }
 
@@ -1059,7 +1036,7 @@ private:
         return res;
     }
 
-    std::vector<common_adapter_lora_info> construct_lora_list(const std::map<int, float> & config) {
+    std::vector<common_adapter_lora_info> construct_lora_list(const std::map<int, float> & config) const {
         std::vector<common_adapter_lora_info> output = params_base.lora_adapters; // copy
         for (size_t i = 0; i < output.size(); ++i) {
             auto it = config.find(i);
@@ -1162,7 +1139,7 @@ private:
             backend_sampling &= task.params.sampling.backend_sampling;
 
             // TODO: speculative decoding requires multiple samples per batch - not supported yet
-            backend_sampling &= !(slot.ctx_dft && task.params.speculative.n_max > 0);
+            backend_sampling &= !(slot.spec && task.params.speculative.n_max > 0);
 
             // TODO: getting post/pre sampling logits is not yet supported with backend sampling
             backend_sampling &= !need_logits;
@@ -1179,14 +1156,6 @@ private:
             slot.smpl.reset();
         }
 
-        // initialize draft batch
-        // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK]
-        if (slot.ctx_dft) {
-            llama_batch_free(slot.batch_spec);
-
-            slot.batch_spec = llama_batch_init(task.params.speculative.n_max + 1, 0, 1);
-        }
-
         slot.task = std::make_unique<const server_task>(std::move(task));
 
         slot.state = slot.task->is_child()
@@ -2059,19 +2028,23 @@ private:
             // generate draft tokens in speculative decoding mode
             // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
             //       perform the speculative drafting for all sequences at the same time in a single batch
-            int n_draft_max = slot.get_n_draft_max();
+            const int n_draft_max = slot.get_n_draft_max();
             if (n_draft_max > 0) {
                 if (mctx) {
                     // we should never reach this, as speculative is automatically disabled if mmproj is loaded
                     GGML_ABORT("not supported by multimodal");
                 }
 
-                struct common_speculative_params params_spec;
-                params_spec.n_draft = n_draft_max;
-                params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max;
-                params_spec.p_min   = slot.task->params.speculative.p_min;
                 const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
-                llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);
+
+                const auto & params_spec = slot.task->params.speculative;
+
+                llama_tokens draft = common_speculative_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);
+
+                if (draft.size() > (size_t) n_draft_max) {
+                    SLT_WRN(slot, "draft size %d exceeds max %d, truncating\n", (int) draft.size(), n_draft_max);
+                    draft.resize(n_draft_max);
+                }
 
                 // add the sampled token to the batch
                 slot.i_batch_dft.push_back(batch.n_tokens);
@@ -2742,6 +2715,10 @@ private:
 
                     // prompt evaluated for next-token prediction
                     slot.state = SLOT_STATE_GENERATING;
+
+                    if (slot.can_speculate()) {
+                        common_speculative_begin(slot.spec, slot.prompt.tokens.get_text_tokens());
+                    }
                 } else if (slot.state != SLOT_STATE_GENERATING) {
                     continue; // continue loop of slots
                 }
@@ -2813,6 +2790,9 @@ private:
                 // update how many tokens out of those tested were accepted
                 slot.n_draft_accepted += ids.size() - 1;
 
+                // inform the speculative decoding about the number of accepted tokens
+                common_speculative_accept(slot.spec, ids.size() - 1);
+
                 // rollback to the state before sampling the draft tokens
                 slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft);
 
index 799e341d373a40268bfef150ae2a4dc316de21d4..2d25db63b74e13be0146d03e1646ea0b8e47fd6d 100644 (file)
@@ -5,6 +5,7 @@
 #include "llama.h"
 #include "chat.h"
 #include "sampling.h"
+#include "speculative.h"
 #include "json-schema-to-grammar.h"
 
 using json = nlohmann::ordered_json;
@@ -76,6 +77,11 @@ json task_params::to_json(bool only_metrics) const {
             {"speculative.n_max",         speculative.n_max},
             {"speculative.n_min",         speculative.n_min},
             {"speculative.p_min",         speculative.p_min},
+            {"speculative.type",          common_speculative_type_to_str(speculative.type)},
+            {"speculative.ngram_size_n",  speculative.ngram_size_n},
+            {"speculative.ngram_size_m",  speculative.ngram_size_m},
+            {"speculative.ngram_c_rate",  speculative.ngram_check_rate},
+            {"speculative.ngram_m_hits",  speculative.ngram_min_hits},
             {"timings_per_token",         timings_per_token},
             {"post_sampling_probs",       post_sampling_probs},
             {"backend_sampling",          sampling.backend_sampling},
@@ -135,6 +141,11 @@ json task_params::to_json(bool only_metrics) const {
         {"speculative.n_max",         speculative.n_max},
         {"speculative.n_min",         speculative.n_min},
         {"speculative.p_min",         speculative.p_min},
+        {"speculative.type",          common_speculative_type_to_str(speculative.type)},
+        {"speculative.ngram_size_n",  speculative.ngram_size_n},
+        {"speculative.ngram_size_m",  speculative.ngram_size_m},
+        {"speculative.ngram_c_rate",  speculative.ngram_check_rate},
+        {"speculative.ngram_m_hits",  speculative.ngram_min_hits},
         {"timings_per_token",         timings_per_token},
         {"post_sampling_probs",       post_sampling_probs},
         {"backend_sampling",          sampling.backend_sampling},
@@ -242,6 +253,18 @@ task_params server_task::params_from_json_cmpl(
     params.speculative.n_min = std::max(params.speculative.n_min, 0);
     params.speculative.n_max = std::max(params.speculative.n_max, 0);
 
+    params.speculative.type = common_speculative_type_from_name(json_value(data, "speculative.type", common_speculative_type_to_str(defaults.speculative.type)));
+
+    params.speculative.ngram_size_n     = json_value(data, "speculative.ngram_size_n", defaults.speculative.ngram_size_n);
+    params.speculative.ngram_size_m     = json_value(data, "speculative.ngram_size_m", defaults.speculative.ngram_size_m);
+    params.speculative.ngram_check_rate = json_value(data, "speculative.ngram_c_rate", defaults.speculative.ngram_check_rate);
+    params.speculative.ngram_min_hits   = json_value(data, "speculative.ngram_m_hits", defaults.speculative.ngram_min_hits);
+
+    params.speculative.ngram_size_n     = std::max(std::min(1, (int) params.speculative.ngram_size_n),     1024);
+    params.speculative.ngram_size_m     = std::max(std::min(1, (int) params.speculative.ngram_size_m),     1024);
+    params.speculative.ngram_check_rate = std::max(std::min(1, (int) params.speculative.ngram_check_rate), 1024);
+    params.speculative.ngram_min_hits   = std::max(std::min(1, (int) params.speculative.ngram_min_hits),   1024);
+
     // Use OpenAI API logprobs only if n_probs wasn't provided
     if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){
         params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs);