]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
speculative : PoC for speeding-up inference via speculative sampling (#2926)
authorGeorgi Gerganov <redacted>
Sun, 3 Sep 2023 12:12:08 +0000 (15:12 +0300)
committerGitHub <redacted>
Sun, 3 Sep 2023 12:12:08 +0000 (15:12 +0300)
* speculative : initial example

* speculative : print encoding speed

* speculative : add --draft CLI arg

common/common.cpp
common/common.h
examples/CMakeLists.txt
examples/main/main.cpp
examples/speculative/CMakeLists.txt [new file with mode: 0644]
examples/speculative/speculative.cpp [new file with mode: 0644]

index a1c3dc7805361a6fae3d359b50e8642e8bfde8be..313821375df020f76b61f242b87c6a0eddfc748e 100644 (file)
@@ -305,6 +305,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_keep = std::stoi(argv[i]);
+        } else if (arg == "--draft") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.n_draft = std::stoi(argv[i]);
         } else if (arg == "--chunks") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -317,6 +323,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.model = argv[i];
+        } else if (arg == "-md" || arg == "--model-draft") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.model_draft = argv[i];
         } else if (arg == "-a" || arg == "--alias") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -638,6 +650,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, "  --hellaswag           compute HellaSwag score over random tasks from datafile supplied with -f\n");
     fprintf(stdout, "  --hellaswag-tasks N   number of tasks to use when computing the HellaSwag score (default: %zu)\n", params.hellaswag_tasks);
     fprintf(stdout, "  --keep N              number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
+    fprintf(stdout, "  --draft N             number of tokens to draft for speculative decoding (default: %d)\n", params.n_draft);
     fprintf(stdout, "  --chunks N            max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
     if (llama_mlock_supported()) {
         fprintf(stdout, "  --mlock               force system to keep model in RAM rather than swapping or compressing\n");
@@ -669,6 +682,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, "  --lora-base FNAME     optional model to use as a base for the layers modified by the LoRA adapter\n");
     fprintf(stdout, "  -m FNAME, --model FNAME\n");
     fprintf(stdout, "                        model path (default: %s)\n", params.model.c_str());
+    fprintf(stdout, "  -md FNAME, --model-draft FNAME\n");
+    fprintf(stdout, "                        draft model for speculative decoding (default: %s)\n", params.model.c_str());
     fprintf(stdout, "  -ld LOGDIR, --logdir LOGDIR\n");
     fprintf(stdout, "                        path under which to save YAML logs (no logging if unset)\n");
     fprintf(stdout, "\n");
@@ -832,6 +847,130 @@ std::string llama_detokenize_bpe(llama_context * ctx, const std::vector<llama_to
     return result;
 }
 
+//
+// Sampling utils
+//
+
+llama_token llama_sample_token(
+                  struct llama_context * ctx,
+                  struct llama_context * ctx_guidance,
+                  struct llama_grammar * grammar,
+               const struct gpt_params & params,
+        const std::vector<llama_token> & last_tokens,
+         std::vector<llama_token_data> & candidates,
+                                   int   idx) {
+    const int n_ctx   = llama_n_ctx(ctx);
+    const int n_vocab = llama_n_vocab(ctx);
+
+    const float   temp            = params.temp;
+    const int32_t top_k           = params.top_k <= 0 ? n_vocab : params.top_k;
+    const float   top_p           = params.top_p;
+    const float   tfs_z           = params.tfs_z;
+    const float   typical_p       = params.typical_p;
+    const int32_t repeat_last_n   = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
+    const float   repeat_penalty  = params.repeat_penalty;
+    const float   alpha_presence  = params.presence_penalty;
+    const float   alpha_frequency = params.frequency_penalty;
+    const int     mirostat        = params.mirostat;
+    const float   mirostat_tau    = params.mirostat_tau;
+    const float   mirostat_eta    = params.mirostat_eta;
+    const bool    penalize_nl     = params.penalize_nl;
+
+    llama_token id = 0;
+
+    float * logits = llama_get_logits(ctx) + idx * n_vocab;
+
+    // Apply params.logit_bias map
+    for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
+        logits[it->first] += it->second;
+    }
+
+    candidates.clear();
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+    }
+
+    llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
+
+    if (ctx_guidance) {
+        llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale);
+    }
+
+    // apply penalties
+    if (!last_tokens.empty()) {
+        const float nl_logit = logits[llama_token_nl(ctx)];
+        const int last_n_repeat = std::min(std::min((int)last_tokens.size(), repeat_last_n), n_ctx);
+
+        llama_sample_repetition_penalty(ctx, &cur_p,
+                last_tokens.data() + last_tokens.size() - last_n_repeat,
+                last_n_repeat, repeat_penalty);
+        llama_sample_frequency_and_presence_penalties(ctx, &cur_p,
+                last_tokens.data() + last_tokens.size() - last_n_repeat,
+                last_n_repeat, alpha_frequency, alpha_presence);
+
+        if (!penalize_nl) {
+            for (size_t idx = 0; idx < cur_p.size; idx++) {
+                if (cur_p.data[idx].id == llama_token_nl(ctx)) {
+                    cur_p.data[idx].logit = nl_logit;
+                    break;
+                }
+            }
+        }
+    }
+
+    if (grammar != NULL) {
+        llama_sample_grammar(ctx, &cur_p, grammar);
+    }
+
+    if (temp <= 0) {
+        // Greedy sampling
+        id = llama_sample_token_greedy(ctx, &cur_p);
+    } else {
+        if (mirostat == 1) {
+            static float mirostat_mu = 2.0f * mirostat_tau;
+            const int mirostat_m = 100;
+            llama_sample_temperature(ctx, &cur_p, temp);
+            id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
+        } else if (mirostat == 2) {
+            static float mirostat_mu = 2.0f * mirostat_tau;
+            llama_sample_temperature(ctx, &cur_p, temp);
+            id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu);
+        } else {
+            // Temperature sampling
+            llama_sample_top_k      (ctx, &cur_p, top_k, 1);
+            llama_sample_tail_free  (ctx, &cur_p, tfs_z, 1);
+            llama_sample_typical    (ctx, &cur_p, typical_p, 1);
+            llama_sample_top_p      (ctx, &cur_p, top_p, 1);
+            llama_sample_temperature(ctx, &cur_p, temp);
+
+            {
+                const int n_top = 10;
+                LOG("top %d candidates:\n", n_top);
+
+                for (int i = 0; i < n_top; i++) {
+                    const llama_token id = cur_p.data[i].id;
+                    LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p);
+                }
+            }
+
+            id = llama_sample_token(ctx, &cur_p);
+
+            LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str());
+        }
+    }
+    // printf("`%d`", candidates_p.size);
+
+    if (grammar != NULL) {
+        llama_grammar_accept_token(ctx, grammar, id);
+    }
+
+    return id;
+}
+
+//
+// YAML utils
+//
+
 // returns true if successful, false otherwise
 bool create_directory_with_parents(const std::string & path) {
 #ifdef _WIN32
@@ -1070,6 +1209,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "mirostat_lr: %f # default: 0.1\n", params.mirostat_eta);
     fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false");
     fprintf(stream, "model: %s # default: models/7B/ggml-model.bin\n", params.model.c_str());
+    fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str());
     fprintf(stream, "mtest: %s # default: false\n", params.mem_test ? "true" : "false");
     fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false");
     fprintf(stream, "n_gpu_layers: %d # default: 0\n", params.n_gpu_layers);
index 5a379688ee52928ea89a0f17eabdf466a4d87dac..105fb09e4924da1e4e8c5b9ff385baec9e4f04bb 100644 (file)
@@ -32,6 +32,7 @@ struct gpt_params {
     int32_t n_ctx                           = 512;  // context size
     int32_t n_batch                         = 512;  // batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_keep                          = 0;    // number of tokens to keep from initial prompt
+    int32_t n_draft                         = 16;   // number of tokens to draft during speculative decoding
     int32_t n_chunks                        = -1;   // max number of chunks to process (-1 = unlimited)
     int32_t n_gpu_layers                    = 0;    // number of layers to store in VRAM
     int32_t main_gpu                        = 0;    // the GPU that is used for scratch and small tensors
@@ -63,6 +64,7 @@ struct gpt_params {
     float       cfg_scale         = 1.f;   // How strong is guidance
 
     std::string model             = "models/7B/ggml-model-f16.gguf"; // model path
+    std::string model_draft       = "";                              // draft model for speculative decoding
     std::string model_alias       = "unknown"; // model alias
     std::string prompt            = "";
     std::string path_prompt_cache = "";  // path to file for saving/loading prompt eval state
@@ -156,6 +158,40 @@ std::string llama_detokenize_bpe(
                          llama_context * ctx,
         const std::vector<llama_token> & tokens);
 
+//
+// Sampling utils
+//
+
+// this is a common sampling function used across the examples for convenience
+// it can serve as a starting point for implementing your own sampling function
+//
+// required:
+//  - ctx:    context to use for sampling
+//  - params: sampling parameters
+//
+// optional:
+//  - ctx_guidance:  context to use for classifier-free guidance, ignore if NULL
+//  - grammar:       grammar to use for sampling, ignore if NULL
+//  - last_tokens:   needed for repetition penalty, ignore if empty
+//  - idx:           sample from llama_get_logits(ctx) + idx * n_vocab
+//
+// returns:
+//  - token:      sampled token
+//  - candidates: vector of candidate tokens
+//
+llama_token llama_sample_token(
+                  struct llama_context * ctx,
+                  struct llama_context * ctx_guidance,
+                  struct llama_grammar * grammar,
+               const struct gpt_params & params,
+        const std::vector<llama_token> & last_tokens,
+         std::vector<llama_token_data> & candidates,
+                                   int   idx = 0);
+
+//
+// YAML utils
+//
+
 bool create_directory_with_parents(const std::string & path);
 void dump_vector_float_yaml(FILE * stream, const char * prop_name, const std::vector<float> & data);
 void dump_vector_int_yaml(FILE * stream, const char * prop_name, const std::vector<int> & data);
index 6e65eb0876c7e361019712894b13486ec8b4f503..884c4276422ebf44db08737d210aaf599296d9f2 100644 (file)
@@ -23,6 +23,7 @@ else()
     add_subdirectory(train-text-from-scratch)
     add_subdirectory(convert-llama2c-to-ggml)
     add_subdirectory(simple)
+    add_subdirectory(speculative)
     add_subdirectory(embd-input)
     add_subdirectory(llama-bench)
     add_subdirectory(beam-search)
index db98312ca1abac2cd3ac1d43e47237d7d680760e..922b9a9807bb735ab41aa0be73174908af126a08 100644 (file)
@@ -116,7 +116,7 @@ int main(int argc, char ** argv) {
 #ifndef LOG_DISABLE_LOGS
     log_set_target(log_filename_generator("main", "log"));
     LOG_TEE("Log start\n");
-    log_dump_cmdline(argc,argv);
+    log_dump_cmdline(argc, argv);
 #endif // LOG_DISABLE_LOGS
 
     // TODO: Dump params ?
@@ -425,8 +425,9 @@ int main(int argc, char ** argv) {
     LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
     LOG_TEE("\n\n");
 
+    struct llama_grammar * grammar = NULL;
     grammar_parser::parse_state parsed_grammar;
-    llama_grammar *             grammar = NULL;
+
     if (!params.grammar.empty()) {
         parsed_grammar = grammar_parser::parse(params.grammar.c_str());
         // will be empty (default) if there are parse errors
@@ -450,8 +451,8 @@ int main(int argc, char ** argv) {
     }
 
     // TODO: replace with ring-buffer
-    std::vector<llama_token> last_n_tokens(n_ctx);
-    std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
+    std::vector<llama_token> last_tokens(n_ctx);
+    std::fill(last_tokens.begin(), last_tokens.end(), 0);
 
     if (params.interactive) {
         const char *control_message;
@@ -492,6 +493,11 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> embd;
     std::vector<llama_token> embd_guidance;
 
+    const int n_vocab = llama_n_vocab(ctx);
+
+    std::vector<llama_token_data> candidates;
+    candidates.reserve(n_vocab);
+
     while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
         // predict
         if (embd.size() > 0) {
@@ -529,8 +535,8 @@ int main(int argc, char ** argv) {
 
                 LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
 
-                // insert n_left/2 tokens at the start of embd from last_n_tokens
-                embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size());
+                // insert n_left/2 tokens at the start of embd from last_tokens
+                embd.insert(embd.begin(), last_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_tokens.end() - embd.size());
 
                 LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
 
@@ -629,20 +635,6 @@ int main(int argc, char ** argv) {
         embd_guidance.clear();
 
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
-            const float   temp            = params.temp;
-            const int32_t top_k           = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
-            const float   top_p           = params.top_p;
-            const float   tfs_z           = params.tfs_z;
-            const float   typical_p       = params.typical_p;
-            const int32_t repeat_last_n   = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
-            const float   repeat_penalty  = params.repeat_penalty;
-            const float   alpha_presence  = params.presence_penalty;
-            const float   alpha_frequency = params.frequency_penalty;
-            const int     mirostat        = params.mirostat;
-            const float   mirostat_tau    = params.mirostat_tau;
-            const float   mirostat_eta    = params.mirostat_eta;
-            const bool    penalize_nl     = params.penalize_nl;
-
             // optionally save the session on first sample (for faster prompt loading next time)
             if (!path_session.empty() && need_to_save_session && !params.prompt_cache_ro) {
                 need_to_save_session = false;
@@ -651,98 +643,12 @@ int main(int argc, char ** argv) {
                 LOG("saved session to %s\n", path_session.c_str());
             }
 
-            llama_token id = 0;
-
-            {
-                auto logits  = llama_get_logits(ctx);
-                auto n_vocab = llama_n_vocab(ctx);
-
-                // Apply params.logit_bias map
-                for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
-                    logits[it->first] += it->second;
-                }
-
-                std::vector<llama_token_data> candidates;
-                candidates.reserve(n_vocab);
-                for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-                    candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
-                }
-
-                llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
-
-                if (ctx_guidance) {
-                    llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale);
-                }
-
-                // Apply penalties
-                float nl_logit = logits[llama_token_nl(ctx)];
-                auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
-                llama_sample_repetition_penalty(ctx, &cur_p,
-                    last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
-                    last_n_repeat, repeat_penalty);
-                llama_sample_frequency_and_presence_penalties(ctx, &cur_p,
-                    last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
-                    last_n_repeat, alpha_frequency, alpha_presence);
-                if (!penalize_nl) {
-                    for (size_t idx = 0; idx < cur_p.size; idx++) {
-                        if (cur_p.data[idx].id == llama_token_nl(ctx)) {
-                            cur_p.data[idx].logit = nl_logit;
-                            break;
-                        }
-                    }
-                }
-
-                if (grammar != NULL) {
-                    llama_sample_grammar(ctx, &cur_p, grammar);
-                }
-
-                if (temp <= 0) {
-                    // Greedy sampling
-                    id = llama_sample_token_greedy(ctx, &cur_p);
-                } else {
-                    if (mirostat == 1) {
-                        static float mirostat_mu = 2.0f * mirostat_tau;
-                        const int mirostat_m = 100;
-                        llama_sample_temperature(ctx, &cur_p, temp);
-                        id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
-                    } else if (mirostat == 2) {
-                        static float mirostat_mu = 2.0f * mirostat_tau;
-                        llama_sample_temperature(ctx, &cur_p, temp);
-                        id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu);
-                    } else {
-                        // Temperature sampling
-                        llama_sample_top_k      (ctx, &cur_p, top_k, 1);
-                        llama_sample_tail_free  (ctx, &cur_p, tfs_z, 1);
-                        llama_sample_typical    (ctx, &cur_p, typical_p, 1);
-                        llama_sample_top_p      (ctx, &cur_p, top_p, 1);
-                        llama_sample_temperature(ctx, &cur_p, temp);
-
-                        {
-                            const int n_top = 10;
-                            LOG("top %d candidates:\n", n_top);
-
-                            for (int i = 0; i < n_top; i++) {
-                                const llama_token id = cur_p.data[i].id;
-                                LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p);
-                            }
-                        }
-
-                        id = llama_sample_token(ctx, &cur_p);
+            const llama_token id = llama_sample_token(ctx, ctx_guidance, grammar, params, last_tokens, candidates);
 
-                        LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str());
-                    }
-                }
-                // printf("`%d`", candidates_p.size);
+            last_tokens.erase(last_tokens.begin());
+            last_tokens.push_back(id);
 
-                if (grammar != NULL) {
-                    llama_grammar_accept_token(ctx, grammar, id);
-                }
-
-                last_n_tokens.erase(last_n_tokens.begin());
-                last_n_tokens.push_back(id);
-
-                LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, last_n_tokens));
-            }
+            LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, last_tokens));
 
             embd.push_back(id);
 
@@ -758,8 +664,8 @@ int main(int argc, char ** argv) {
             LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed);
             while ((int) embd_inp.size() > n_consumed) {
                 embd.push_back(embd_inp[n_consumed]);
-                last_n_tokens.erase(last_n_tokens.begin());
-                last_n_tokens.push_back(embd_inp[n_consumed]);
+                last_tokens.erase(last_tokens.begin());
+                last_tokens.push_back(embd_inp[n_consumed]);
                 ++n_consumed;
                 if ((int) embd.size() >= params.n_batch) {
                     break;
@@ -792,7 +698,7 @@ int main(int argc, char ** argv) {
             // check for reverse prompt
             if (params.antiprompt.size()) {
                 std::string last_output;
-                for (auto id : last_n_tokens) {
+                for (auto id : last_tokens) {
                     last_output += llama_token_to_piece(ctx, id);
                 }
 
@@ -823,7 +729,7 @@ int main(int argc, char ** argv) {
             }
 
             // deal with end of text token in interactive mode
-            if (last_n_tokens.back() == llama_token_eos(ctx)) {
+            if (last_tokens.back() == llama_token_eos(ctx)) {
                 LOG("found EOS token\n");
 
                 if (params.interactive) {
@@ -925,7 +831,7 @@ int main(int argc, char ** argv) {
                     if (grammar != NULL) {
                         llama_grammar_free(grammar);
 
-                        std::vector<const llama_grammar_element *> grammar_rules( parsed_grammar.c_rules());
+                        std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
                         grammar = llama_grammar_init(
                             grammar_rules.data(), grammar_rules.size(),
                             parsed_grammar.symbol_ids.at("root"));
diff --git a/examples/speculative/CMakeLists.txt b/examples/speculative/CMakeLists.txt
new file mode 100644 (file)
index 0000000..6c5c945
--- /dev/null
@@ -0,0 +1,8 @@
+set(TARGET speculative)
+add_executable(${TARGET} speculative.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
+if(TARGET BUILD_INFO)
+  add_dependencies(${TARGET} BUILD_INFO)
+endif()
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
new file mode 100644 (file)
index 0000000..f0400c1
--- /dev/null
@@ -0,0 +1,234 @@
+#ifndef _GNU_SOURCE
+#define _GNU_SOURCE
+#endif
+
+#include "build-info.h"
+
+#include "common.h"
+#include "llama.h"
+
+#include <cmath>
+#include <cstdio>
+#include <string>
+#include <vector>
+
+int main(int argc, char ** argv) {
+    gpt_params params;
+
+    if (gpt_params_parse(argc, argv, params) == false) {
+        return 1;
+    }
+
+    if (params.model_draft.empty()) {
+        fprintf(stderr, "%s: error: --model-draft is required\n", __func__);
+        return 1;
+    }
+
+#ifndef LOG_DISABLE_LOGS
+    log_set_target(log_filename_generator("speculative", "log"));
+    LOG_TEE("Log start\n");
+    log_dump_cmdline(argc, argv);
+#endif // LOG_DISABLE_LOGS
+
+    // init llama.cpp
+    llama_backend_init(params.numa);
+
+    llama_model * model_tgt = NULL;
+    llama_model * model_dft = NULL;
+
+    llama_context * ctx_tgt = NULL;
+    llama_context * ctx_dft = NULL;
+
+    // load the target model
+    params.perplexity = true; // HACK: enable logits_all = true
+    std::tie(model_tgt, ctx_tgt) = llama_init_from_gpt_params(params);
+
+    // load the draft model
+    params.model = params.model_draft;
+    std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
+
+    // tokenize the prompt
+    std::vector<llama_token> inp;
+    inp = ::llama_tokenize(ctx_tgt, params.prompt, true);
+
+    const int max_context_size     = llama_n_ctx(ctx_tgt);
+    const int max_tokens_list_size = max_context_size - 4;
+
+    if ((int) inp.size() > max_tokens_list_size) {
+        fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size);
+        return 1;
+    }
+
+    fprintf(stderr, "\n\n");
+
+    for (auto id : inp) {
+        fprintf(stderr, "%s", llama_token_to_piece(ctx_tgt, id).c_str());
+    }
+
+    fflush(stderr);
+
+    const int n_input = inp.size();
+
+    const auto t_enc_start = ggml_time_us();
+
+    // eval the prompt with both models
+    llama_eval(ctx_tgt,  inp.data(), int(inp.size() - 1), 0, params.n_threads);
+    llama_eval(ctx_tgt, &inp.back(),      1, inp.size() - 1, params.n_threads);
+    llama_eval(ctx_dft,  inp.data(),     int(inp.size()), 0, params.n_threads);
+
+    const auto t_enc_end = ggml_time_us();
+
+    // the 2 models should have the same vocab
+    const int n_ctx   = llama_n_ctx(ctx_tgt);
+    const int n_vocab = llama_n_vocab(ctx_tgt);
+    //GGML_ASSERT(n_vocab == llama_n_vocab(ctx_dft));
+
+    // how many tokens to draft each time
+    const int n_draft = params.n_draft;
+
+    int n_predict = 0;
+    int n_drafted = 0;
+    int n_accept  = 0;
+
+    int n_past_tgt = inp.size();
+    int n_past_dft = inp.size();
+
+    std::vector<llama_token> drafted;
+
+    std::vector<llama_token> last_tokens(n_ctx);
+    std::fill(last_tokens.begin(), last_tokens.end(), 0);
+
+    for (auto & id : inp) {
+        last_tokens.erase(last_tokens.begin());
+        last_tokens.push_back(id);
+    }
+
+    std::vector<llama_token_data> candidates;
+    candidates.reserve(n_vocab);
+
+    // used to determine end of generation
+    bool has_eos = false;
+
+    const auto t_dec_start = ggml_time_us();
+
+    while (true) {
+        LOG("drafted: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_dft, drafted));
+
+        // sample from the drafted tokens if any
+        int i_dft = 0;
+        while (true) {
+            const llama_token id = llama_sample_token(ctx_tgt, NULL, NULL, params, last_tokens, candidates, i_dft);
+
+            last_tokens.erase(last_tokens.begin());
+            last_tokens.push_back(id);
+
+            //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, last_tokens));
+
+            const std::string token_str = llama_token_to_piece(ctx_tgt, id);
+            printf("%s", token_str.c_str());
+            fflush(stdout);
+
+            if (id == llama_token_eos(ctx_tgt)) {
+                has_eos = true;
+            }
+
+            ++n_predict;
+
+            if (i_dft < (int) drafted.size() && id == drafted[i_dft]) {
+                LOG("drafted token %d accepted\n", id);
+                ++n_accept;
+                ++n_past_tgt;
+                ++n_past_dft;
+                ++i_dft;
+
+                continue;
+            }
+
+            // the drafted token was rejected or we are out of drafted tokens
+            llama_eval(ctx_dft, &id, 1, n_past_dft, params.n_threads);
+            ++n_past_dft;
+
+            drafted.clear();
+            drafted.push_back(id);
+
+            break;
+        }
+
+        if (n_predict > params.n_predict || has_eos) {
+            break;
+        }
+
+        // sample n_draft tokens from the draft model picking the best token
+        int n_past_cur = n_past_dft;
+        for (int i = 0; i < n_draft; ++i) {
+            float * logits = llama_get_logits(ctx_dft);
+
+            candidates.clear();
+            for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+                candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+            }
+
+            llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
+
+            // computes softmax and sorts the candidates
+            llama_sample_softmax(ctx_dft, &cur_p);
+
+            for (int i = 0; i < 3; ++i) {
+                LOG(" - draft candidate %d: %d (%.3f)\n", i, cur_p.data[i].id, cur_p.data[i].p);
+            }
+
+            // too low probability, stop drafting
+            if (cur_p.data[0].p < 2*cur_p.data[1].p) {
+                break;
+            }
+
+            drafted.push_back(cur_p.data[0].id);
+            ++n_drafted;
+
+            if (i < n_draft - 1) {
+                // evaluate the drafted token on the draft model
+                llama_eval(ctx_dft, &drafted.back(), 1, n_past_cur, params.n_threads);
+                ++n_past_cur;
+            }
+        }
+
+        // evaluate the target model on the drafted tokens
+        llama_eval(ctx_tgt, drafted.data(), drafted.size(), n_past_tgt, params.n_threads);
+        ++n_past_tgt;
+
+        drafted.erase(drafted.begin());
+    }
+
+    auto t_dec_end = ggml_time_us();
+
+    LOG_TEE("\n\n");
+
+    LOG_TEE("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input,   (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
+    LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
+
+    // TODO: make sure these numbers are computed correctly
+    LOG_TEE("\n");
+    LOG_TEE("n_draft   = %d\n", n_draft);
+    LOG_TEE("n_predict = %d\n", n_predict);
+    LOG_TEE("n_drafted = %d\n", n_drafted);
+    LOG_TEE("n_accept  = %d\n", n_accept);
+    LOG_TEE("accept    = %.3f%%\n", 100.0f * n_accept / n_drafted);
+
+    LOG_TEE("\ndraft:\n");
+    llama_print_timings(ctx_dft);
+
+    LOG_TEE("\ntarget:\n");
+    llama_print_timings(ctx_tgt);
+
+    llama_free(ctx_tgt);
+    llama_free_model(model_tgt);
+
+    llama_free(ctx_dft);
+    llama_free_model(model_dft);
+
+    llama_backend_free();
+
+    fprintf(stderr, "\n\n");
+
+    return 0;
+}