]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : implement universal assisted decoding (#12635)
authorg2mt <redacted>
Thu, 31 Jul 2025 12:25:23 +0000 (05:25 -0700)
committerGitHub <redacted>
Thu, 31 Jul 2025 12:25:23 +0000 (14:25 +0200)
* llama-server : implement universal assisted decoding

* Erase prompt tail for kv-cache

* set vocab_dft_compatible in common_speculative

* rename ctx_main to ctx_tgt

* move vocab_dft_compatible to spec struct

* clear mem_dft, remove mem

* detokenize id_last for incompatible models

* update comment

* add --spec-replace flag

* accept special tokens when translating between draft/main models

* Escape spec-replace

* clamp draft result to size to params.n_draft

* fix comment

* clean up code

* restore old example

* log common_speculative_are_compatible in speculative example

* fix

* Update common/speculative.cpp

Co-authored-by: Georgi Gerganov <redacted>
* Update common/speculative.cpp

Co-authored-by: Georgi Gerganov <redacted>
* Update common/speculative.cpp

Co-authored-by: Georgi Gerganov <redacted>
---------

Co-authored-by: Georgi Gerganov <redacted>
common/arg.cpp
common/common.h
common/speculative.cpp
common/speculative.h
examples/speculative-simple/speculative-simple.cpp
tools/server/server.cpp

index 74137d2db959d4beb87bfb42623a0477b6e9d162..7744fd6c48876078d55d54467f601dc292d2ae8d 100644 (file)
@@ -977,6 +977,10 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
         for (auto & seq_breaker : params.sampling.dry_sequence_breakers) {
             string_process_escapes(seq_breaker);
         }
+        for (auto & pair : params.speculative.replacements) {
+            string_process_escapes(pair.first);
+            string_process_escapes(pair.second);
+        }
     }
 
     if (!params.kv_overrides.empty()) {
@@ -3249,6 +3253,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.speculative.model.path = value;
         }
     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODEL_DRAFT"));
+    add_opt(common_arg(
+        {"--spec-replace"}, "TARGET", "DRAFT",
+        "translate the string in TARGET into DRAFT if the draft model and main model are not compatible",
+        [](common_params & params, const std::string & tgt, const std::string & dft) {
+            params.speculative.replacements.push_back({ tgt, dft });
+        }
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"-ctkd", "--cache-type-k-draft"}, "TYPE",
         string_format(
index 38129b99d511ff4b9f8b9826e9e7fc4fe8ba0c64..f5acf37ff9fd7aad93c3693395c72da4945fd641 100644 (file)
@@ -201,6 +201,7 @@ struct common_params_speculative {
     int32_t n_gpu_layers =    -1; // number of layers to store in VRAM for the draft model (-1 - use default)
     float   p_split      =  0.1f; // speculative decoding split probability
     float   p_min        = 0.75f; // minimum speculative decoding probability (greedy)
+    std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
 
     ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
     ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
index 843bd1ddbdbd79090a9a119c84e5f06bdac394be..262b2c23e720f7ce9cea80774e6c7c681e172a02 100644 (file)
@@ -1,30 +1,39 @@
 #include "speculative.h"
 
+#include "ggml.h"
+#include "llama.h"
 #include "log.h"
 #include "common.h"
 #include "sampling.h"
 
 #include <cstring>
 #include <algorithm>
+#include <map>
 
 #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE  128
 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
 
 struct common_speculative {
-    struct llama_context * ctx;
+    struct llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
+    struct llama_context * ctx_dft;
     struct common_sampler * smpl;
 
     llama_batch batch;
-    llama_tokens prompt;
+    llama_tokens prompt_dft;
+    bool vocab_dft_compatible = true; // whether retokenization is needed
+    std::map<std::string, std::string> tgt_dft_replacements = {};
 };
 
 struct common_speculative * common_speculative_init(
+        struct llama_context * ctx_tgt,
         struct llama_context * ctx_dft) {
     auto * result = new common_speculative {
-        /* .ctx    = */ ctx_dft,
-        /* .smpl   = */ nullptr,
-        /* .batch  = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
-        /* .prompt = */ {},
+        /* .ctx_tgt    = */ ctx_tgt,
+        /* .ctx_dft    = */ ctx_dft,
+        /* .smpl       = */ nullptr,
+        /* .batch      = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
+        /* .prompt_dft = */ {},
+        /* .vocab_dft_compatible = */ false,
     };
 
     // TODO: optimize or pass from outside?
@@ -59,6 +68,9 @@ struct common_speculative * common_speculative_init(
     }
 #endif
 
+    result->vocab_dft_compatible = common_speculative_are_compatible(ctx_tgt, ctx_dft);
+    LOG_DBG("vocab_dft_compatible = %d\n", result->vocab_dft_compatible);
+
     return result;
 }
 
@@ -75,8 +87,8 @@ void common_speculative_free(struct common_speculative * spec) {
 }
 
 bool common_speculative_are_compatible(
-        const struct llama_context * ctx_tgt,
-        const struct llama_context * ctx_dft) {
+    const struct llama_context * ctx_tgt,
+    const struct llama_context * ctx_dft) {
     const struct llama_model * model_tgt = llama_get_model(ctx_tgt);
     const struct llama_model * model_dft = llama_get_model(ctx_dft);
 
@@ -90,31 +102,32 @@ bool common_speculative_are_compatible(
     LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
 
     if (vocab_type_tgt != vocab_type_dft) {
-        LOG_ERR("%s: draft model vocab type must match target model to use speculation but "
-                     "vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt);
+        LOG_DBG("%s: draft model vocab type must match target model to use speculation but ", __func__);
+        LOG_DBG("vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
         return false;
     }
 
-    if (llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
+    if (
+        llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
         llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
         llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) ||
-        llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)) {
-        LOG_ERR("%s: draft vocab special tokens must match target vocab to use speculation\n", __func__);
-        LOG_ERR("%s: tgt: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_tgt), llama_vocab_get_add_bos(vocab_tgt), llama_vocab_eos(vocab_tgt), llama_vocab_get_add_eos(vocab_tgt));
-        LOG_ERR("%s: dft: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_dft), llama_vocab_get_add_bos(vocab_dft), llama_vocab_eos(vocab_dft), llama_vocab_get_add_eos(vocab_dft));
+        llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)
+    ) {
+        LOG_DBG("%s: draft model special tokens must match target model to use speculation\n", __func__);
         return false;
     }
 
     {
         const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt);
         const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft);
-
-        const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft);
+        const int vocab_diff  = n_vocab_tgt > n_vocab_dft
+            ? n_vocab_tgt - n_vocab_dft
+            : n_vocab_dft - n_vocab_tgt;
 
         if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
-            LOG_ERR("%s: draft model vocab must closely match target model to use speculation but "
-                         "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
-                    __func__, n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
+            LOG_DBG("%s: draft model vocab must closely match target model to use speculation but ", __func__);
+            LOG_DBG("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
+                    n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
             return false;
         }
 
@@ -122,8 +135,8 @@ bool common_speculative_are_compatible(
             const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i);
             const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);
             if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
-                LOG_ERR("%s: draft vocab vocab must match target vocab to use speculation but "
-                             "token %d content differs - target '%s', draft '%s'\n", __func__, i,
+                LOG_DBG("%s: draft model vocab must match target model to use speculation but ", __func__);
+                LOG_DBG("token %d content differs - target '%s', draft '%s'\n", i,
                         common_token_to_piece(ctx_tgt, i).c_str(),
                         common_token_to_piece(ctx_dft, i).c_str());
                 return false;
@@ -134,32 +147,93 @@ bool common_speculative_are_compatible(
     return true;
 }
 
+void common_speculative_add_replacement_tgt_dft(
+        struct common_speculative * spec,
+        const char *source, const char *dest) {
+    spec->tgt_dft_replacements[source] = dest;
+}
+
+static std::string replace_to_dft(
+        struct common_speculative * spec,
+        const std::string& input) {
+    std::string result = input;
+    for (const auto & pair : spec->tgt_dft_replacements) {
+        size_t pos = result.find(pair.first);
+        while (pos != std::string::npos) {
+            result.replace(pos, pair.first.length(), pair.second);
+            pos = result.find(pair.first, pos + pair.second.length());
+        }
+    }
+    return result;
+}
+
+static std::string replace_to_tgt(
+        struct common_speculative * spec,
+        const std::string& input) {
+    std::string result = input;
+    for (const auto& pair : spec->tgt_dft_replacements) {
+        size_t pos = result.find(pair.second);
+        while (pos != std::string::npos) {
+            result.replace(pos, pair.second.length(), pair.first);
+            pos = result.find(pair.second, pos + pair.first.length());
+        }
+    }
+    return result;
+}
+
+
 llama_tokens common_speculative_gen_draft(
         struct common_speculative * spec,
         struct common_speculative_params params,
-        const llama_tokens & prompt_tgt,
+        const llama_tokens & prompt_tgt_main_model, // specified in target model vocab
         llama_token id_last) {
     auto & batch  = spec->batch;
-    auto & ctx    = spec->ctx;
+    auto & ctx_tgt = spec->ctx_tgt;
+    auto & ctx_dft = spec->ctx_dft;
     auto & smpl   = spec->smpl;
-    auto & prompt = spec->prompt;
+    auto & prompt_dft = spec->prompt_dft;
 
-    auto * mem = llama_get_memory(ctx);
+    auto * mem_dft = llama_get_memory(ctx_dft);
 
     int reuse_i = 0;
     int reuse_n = 0;
 
-    const int n_ctx = llama_n_ctx(ctx) - params.n_draft;
+    const int n_ctx = llama_n_ctx(ctx_dft) - params.n_draft;
+
+    llama_tokens prompt_tgt_draft_model;
+    if (!spec->vocab_dft_compatible) {
+        std::string text;
+        text = common_detokenize(ctx_tgt, prompt_tgt_main_model, true);
+        text = replace_to_dft(spec, text);
+        LOG_DBG("%s: main->draft detokenized string: '%s'\n", __func__, text.c_str());
+        prompt_tgt_draft_model = common_tokenize(ctx_dft, text, false, true);
+
+        // convert id_last to draft vocab. llama_detokenize is called directly to avoid an allocation
+        const auto * model_tgt = llama_get_model(ctx_tgt);
+        const auto * vocab_tgt = llama_model_get_vocab(model_tgt);
+
+        int32_t n_chars = llama_detokenize(vocab_tgt, &id_last, 1, nullptr, 0, false, false);
+        GGML_ASSERT(n_chars < 0 && "failed to detokenize id_last");
+        text.resize(-n_chars);
+        llama_detokenize(vocab_tgt, &id_last, 1, text.data(), text.size(), false, false);
+        text = replace_to_dft(spec, text);
+
+        LOG_DBG("main->draft detokenized id_last(%d): '%s'\n", id_last, text.c_str());
+        id_last = common_tokenize(ctx_dft, text, false, true)[0];
+    }
+    // prompt_tgt's tokens will always be compatible with ctx_dft
+    const llama_tokens &prompt_tgt =
+        spec->vocab_dft_compatible ? prompt_tgt_main_model : prompt_tgt_draft_model;
 
     const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);
 
     // reuse as much as possible from the old draft context
     // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
-    for (int i = 0; i < (int) prompt.size(); ++i) {
+    for (int i = 0; i < (int) prompt_dft.size(); ++i) {
         int cur = 0;
         while (i_start + cur < (int) prompt_tgt.size() &&
-               i       + cur < (int) prompt.size() &&
-               prompt_tgt[i_start + cur] == prompt[i + cur]) {
+               i       + cur < (int) prompt_dft.size() &&
+               prompt_tgt[i_start + cur] == prompt_dft[i + cur]) {
             cur++;
         }
 
@@ -169,21 +243,20 @@ llama_tokens common_speculative_gen_draft(
         }
     }
 
-    LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt.size());
+    LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt_dft.size());
 
     llama_tokens result;
     result.reserve(params.n_draft);
 
     if (reuse_n == 0) {
-        llama_memory_clear(mem, false);
-
-        prompt.clear();
+        llama_memory_clear(mem_dft, false);
+        prompt_dft.clear();
     } else {
         // this happens when a previous draft has been discarded (for example, due to being too small), but the
         // target model agreed with it. in this case, we simply pass back the previous results to save compute
-        if (reuse_i + reuse_n < (int) prompt.size() && prompt[reuse_i + reuse_n] == id_last) {
-            for (int i = reuse_i + reuse_n + 1; i < (int) prompt.size(); ++i) {
-                result.push_back(prompt[i]);
+        if (reuse_i + reuse_n < (int) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) {
+            for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) {
+                result.push_back(prompt_dft[i]);
 
                 if (params.n_draft <= (int) result.size()) {
                     break;
@@ -194,16 +267,15 @@ llama_tokens common_speculative_gen_draft(
         }
 
         if (reuse_i > 0) {
-            llama_memory_seq_rm (mem, 0, 0, reuse_i);
-            llama_memory_seq_add(mem, 0, reuse_i, -1, -reuse_i);
+            llama_memory_seq_rm (mem_dft, 0, 0, reuse_i);
+            llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i);
 
-            prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
+            prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i);
         }
 
-        if (reuse_n < (int) prompt.size()) {
-            llama_memory_seq_rm (mem, 0, reuse_n, -1);
-
-            prompt.erase(prompt.begin() + reuse_n, prompt.end());
+        if (reuse_n < (int) prompt_dft.size()) {
+            llama_memory_seq_rm (mem_dft, 0, reuse_n, -1);
+            prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end());
         }
     }
 
@@ -214,28 +286,28 @@ llama_tokens common_speculative_gen_draft(
         //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
         common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);
 
-        prompt.push_back(prompt_tgt[i]);
+        prompt_dft.push_back(prompt_tgt[i]);
     }
 
     // we should rarely end-up here during normal decoding
     if (batch.n_tokens > 0) {
         //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
 
-        llama_decode(ctx, batch);
+        llama_decode(ctx_dft, batch);
     }
 
-    const llama_pos n_past = prompt.size();
+    const llama_pos n_past = prompt_dft.size();
 
     LOG_DBG("%s: n_past = %d\n", __func__, n_past);
 
     common_batch_clear(batch);
     common_batch_add  (batch, id_last, n_past, { 0 }, true);
 
-    prompt.push_back(id_last);
+    prompt_dft.push_back(id_last);
 
-    //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str());
+    LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str());
 
-    llama_decode(ctx, batch);
+    llama_decode(ctx_dft, batch);
 
     common_sampler_reset(smpl);
 
@@ -243,13 +315,13 @@ llama_tokens common_speculative_gen_draft(
     for (int i = 0; i < params.n_draft; ++i) {
         common_batch_clear(batch);
 
-        common_sampler_sample(smpl, ctx, 0, true);
+        common_sampler_sample(smpl, ctx_dft, 0, true);
 
         const auto * cur_p = common_sampler_get_candidates(smpl);
 
         for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
             LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
-                    k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
+                    k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
         }
 
         // add drafted token for each sequence
@@ -271,10 +343,19 @@ llama_tokens common_speculative_gen_draft(
         common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
 
         // evaluate the drafted tokens on the draft model
-        llama_decode(ctx, batch);
+        llama_decode(ctx_dft, batch);
 
-        prompt.push_back(id);
+        prompt_dft.push_back(id);
     }
 
+    if (!spec->vocab_dft_compatible) {
+        std::string detokenized = common_detokenize(ctx_dft, result, true);
+        detokenized = replace_to_tgt(spec, detokenized);
+        LOG_DBG("draft->main detokenized string: '%s'\n", detokenized.c_str());
+        result = common_tokenize(ctx_tgt, detokenized, false, true);
+        if (result.size() > (size_t)params.n_draft) {
+            result.resize(params.n_draft);
+        }
+    }
     return result;
 }
index 2b51a70ca1f72ca56ffbcaddfadea33ccd489cbb..e69d7aaa1eb00b06669f17d17ced98803f26e2ca 100644 (file)
@@ -12,7 +12,10 @@ struct common_speculative_params {
     float p_min = 0.75f; // min probability required to accept a token in the draft
 };
 
-struct common_speculative * common_speculative_init(struct llama_context * ctx_dft);
+struct common_speculative * common_speculative_init(
+        struct llama_context * ctx_tgt,
+        struct llama_context * ctx_dft
+);
 
 void common_speculative_free(struct common_speculative * spec);
 
@@ -20,6 +23,10 @@ bool common_speculative_are_compatible(
         const struct llama_context * ctx_tgt,
         const struct llama_context * ctx_dft);
 
+void common_speculative_add_replacement_tgt_dft(
+        struct common_speculative * spec,
+        const char *source, const char *dest);
+
 // sample up to n_draft tokens and add them to the batch using the draft model
 llama_tokens common_speculative_gen_draft(
                struct common_speculative * spec,
index 99196c9d047e475eea513055e5c9fdec899b151f..722cd7f40f088b1a535493c7c82fe13f4ecd1df5 100644 (file)
@@ -65,7 +65,7 @@ int main(int argc, char ** argv) {
     ctx_dft   = llama_init_dft.context.get();
 
     if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) {
-        return 1;
+        LOG_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params.speculative.model.path.c_str(), params.model.path.c_str());
     }
 
     // Tokenize the prompt
@@ -130,7 +130,10 @@ int main(int argc, char ** argv) {
     params_spec.n_reuse = llama_n_ctx(ctx_dft) - n_draft;
     params_spec.p_min   = p_min;
 
-    struct common_speculative * spec = common_speculative_init(ctx_dft);
+    struct common_speculative * spec = common_speculative_init(ctx_tgt, ctx_dft);
+    for (auto &pair : params.speculative.replacements) {
+        common_speculative_add_replacement_tgt_dft(spec, pair.first.c_str(), pair.second.c_str());
+    }
 
     llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
 
index 9a9b0444746f14cb72d4b88aefd8a46a7529c678..35d6610428efd56aebd52df0a57f9cf011a9255b 100644 (file)
@@ -1929,6 +1929,7 @@ struct server_context {
     mtmd_context * mctx = nullptr;
 
     const llama_vocab * vocab = nullptr;
+    bool vocab_dft_compatible = true;
 
     llama_model * model_dft = nullptr;
 
@@ -2019,10 +2020,9 @@ struct server_context {
                 return false;
             }
 
-            if (!common_speculative_are_compatible(ctx, llama_init_dft.context.get())) {
-                SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str());
-
-                return false;
+            vocab_dft_compatible = common_speculative_are_compatible(ctx, llama_init_dft.context.get());
+            if (!vocab_dft_compatible) {
+                SRV_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str());
             }
 
             const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get());
@@ -2112,11 +2112,14 @@ struct server_context {
                     return;
                 }
 
-                slot.spec = common_speculative_init(slot.ctx_dft);
+                slot.spec = common_speculative_init(slot.ctx, slot.ctx_dft);
                 if (slot.spec == nullptr) {
                     SRV_ERR("%s", "failed to create speculator\n");
                     return;
                 }
+                for (auto &pair : params_base.speculative.replacements) {
+                    common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str());
+                }
             }
 
             SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);