speculative : implement stochastic speculative sampling (#5625)
author    Minsoo Cheong <redacted>
Mon, 4 Mar 2024 18:24:00 +0000 (03:24 +0900)
committer GitHub <redacted>
Mon, 4 Mar 2024 18:24:00 +0000 (20:24 +0200)
* (WIP) Implement stochastic speculative decoding

* sample from residual distribution on draft accept failure

* fix #5657: force greedy sampling with probs when temp is 0

* remove p_accept parameter

* fix style

* remove unused variables

* add srand() in speculative.cpp

* replace use of rand() with mt19937 sampling

* fixes based on review (@JohannesGaessler)

* fix random generation of r

* randomly select next sequence to verify + fix bug in memory freeing

* fix bug in active_seqs sync

* fix uniform int distribution initialization

* remove warnings from comparison between int and size_t

* check grammar in `llama_sample_probability_distribution_impl`

* remove malloc code by utilizing vectors

* add PR link to README

common/common.cpp
common/common.h
common/sampling.cpp
common/sampling.h
examples/speculative/README.md
examples/speculative/speculative.cpp
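
The patch implements the standard stochastic speculative sampling rule: a drafted token x is accepted with probability min(1, p_tgt(x) / p_dft(x)); on rejection, the target distribution is replaced by the normalized residual max(0, p_tgt - p_dft) and the replacement token is drawn from it. Below is a minimal, self-contained C++ sketch of that rule only; it is not part of the patch, and the function name and probability vectors are illustrative.

#include <algorithm>
#include <random>
#include <vector>

// Accept the drafted token with probability min(1, p_tgt/p_dft); on rejection,
// resample from the normalized residual max(0, p_tgt - p_dft).
static int verify_draft_token(int drafted_id,
                              std::vector<float> & p_tgt,       // target probs, indexed by token id
                              const std::vector<float> & p_dft, // draft probs, indexed by token id
                              std::mt19937 & rng) {
    std::uniform_real_distribution<float> u_dist(0.0f, 1.0f);
    if (p_dft[drafted_id] > 0.0f &&
        u_dist(rng) <= p_tgt[drafted_id] / p_dft[drafted_id]) {
        return drafted_id; // accepted: keep the drafted token
    }
    // rejected: build the residual distribution and sample the replacement from it
    for (size_t i = 0; i < p_tgt.size(); ++i) {
        p_tgt[i] = std::max(0.0f, p_tgt[i] - p_dft[i]);
    }
    std::discrete_distribution<int> residual(p_tgt.begin(), p_tgt.end());
    return residual(rng); // discrete_distribution renormalizes the weights internally
}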

index dbe7e9229b770ed30ff9305aa52ee18f346dd227..036a981349a690d201034830a612ee653248a21b 100644 (file)
@@ -513,12 +513,6 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_sequences = std::stoi(argv[i]);
-        } else if (arg == "--p-accept" || arg == "-pa") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params.p_accept = std::stof(argv[i]);
         } else if (arg == "--p-split" || arg == "-ps") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -1044,7 +1038,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --chunks N            max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
     printf("  -np N, --parallel N   number of parallel sequences to decode (default: %d)\n", params.n_parallel);
     printf("  -ns N, --sequences N  number of sequences to decode (default: %d)\n", params.n_sequences);
-    printf("  -pa N, --p-accept N   speculative decoding accept probability (default: %.1f)\n", (double)params.p_accept);
     printf("  -ps N, --p-split N    speculative decoding split probability (default: %.1f)\n", (double)params.p_split);
     printf("  -cb, --cont-batching  enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
     printf("  --mmproj MMPROJ_FILE  path to a multimodal projector file for LLaVA. see examples/llava/README.md\n");
index b2868833be3e657d50dedf7ed824ecc38245407a..977ce419ff93bc6d9096fdb29e18f8c8ea97ab5e 100644 (file)
@@ -53,11 +53,10 @@ struct gpt_params {
     int32_t n_ctx                 = 512;   // context size
     int32_t n_batch               = 512;   // batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_keep                = 0;     // number of tokens to keep from initial prompt
-    int32_t n_draft               = 8;     // number of tokens to draft during speculative decoding
+    int32_t n_draft               = 5;     // number of tokens to draft during speculative decoding
     int32_t n_chunks              = -1;    // max number of chunks to process (-1 = unlimited)
     int32_t n_parallel            = 1;     // number of parallel sequences to decode
     int32_t n_sequences           = 1;     // number of sequences to decode
-    float   p_accept              = 0.5f;  // speculative decoding accept probability
     float   p_split               = 0.1f;  // speculative decoding split probability
     int32_t n_gpu_layers          = -1;    // number of layers to store in VRAM (-1 - use default)
     int32_t n_gpu_layers_draft    = -1;    // number of layers to store in VRAM for the draft model (-1 - use default)
index e67096bea693262d0b4a5d244c37d52d48ae5453..823031febc7e2aec80af69534fb0d0715d0e3edf 100644 (file)
@@ -295,6 +295,77 @@ static llama_token llama_sampling_sample_impl(
     return id;
 }
 
+static llama_token_data_array llama_sample_probability_distribution_impl(
+                  struct llama_sampling_context * ctx_sampling,
+                  struct llama_context * ctx_main,
+                  struct llama_context * ctx_cfg,
+                  const int idx) {
+    const llama_sampling_params & params = ctx_sampling->params;
+
+    const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+
+    const int32_t penalty_last_n  = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
+    const float   penalty_repeat  = params.penalty_repeat;
+    const float   penalty_freq    = params.penalty_freq;
+    const float   penalty_present = params.penalty_present;
+    const bool    penalize_nl     = params.penalize_nl;
+
+    auto & prev = ctx_sampling->prev;
+    auto & cur  = ctx_sampling->cur;
+
+    // Get a pointer to the logits
+    float * logits = llama_get_logits_ith(ctx_main, idx);
+
+    // Declare original_logits at the beginning of the function scope
+    std::vector<float> original_logits;
+
+    // apply params.logit_bias map
+    for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
+        logits[it->first] += it->second;
+    }
+
+    if (ctx_cfg) {
+        float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
+        llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
+    }
+
+    cur.clear();
+
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+    }
+
+    llama_token_data_array cur_p = { cur.data(), cur.size(), false };
+
+    // apply penalties
+    const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
+    const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
+    if (penalty_tokens_used_size) {
+        const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
+
+        llama_sample_repetition_penalties(ctx_main, &cur_p,
+                penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
+                penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
+
+        if (!penalize_nl) {
+            for (size_t idx = 0; idx < cur_p.size; idx++) {
+                if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
+                    cur_p.data[idx].logit = nl_logit;
+                    break;
+                }
+            }
+        }
+    }
+
+    // apply grammar checks
+    if (ctx_sampling->grammar != NULL) {
+        llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
+    }
+
+    llama_sample_softmax(ctx_main, &cur_p);
+    return cur_p;
+}
+
 llama_token llama_sampling_sample(
                   struct llama_sampling_context * ctx_sampling,
                   struct llama_context * ctx_main,
@@ -304,6 +375,14 @@ llama_token llama_sampling_sample(
     return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
 }
 
+llama_token_data_array llama_sampling_probability_distribution(
+                  struct llama_sampling_context * ctx_sampling,
+                  struct llama_context * ctx_main,
+                  struct llama_context * ctx_cfg,
+                  const int idx) {
+    return llama_sample_probability_distribution_impl(ctx_sampling,ctx_main, ctx_cfg, idx);
+}
+
 void llama_sampling_accept(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,
index 95d8753942b40a654b9afdbbe427ac8c8a9b8b79..48b2459d1f9447814b6551d6b31707a2a8f79aeb 100644 (file)
@@ -131,6 +131,13 @@ llama_token llama_sampling_sample(
         struct llama_context * ctx_cfg,
         int idx = 0);
 
+// returns the probability that token of given id will be sampled
+llama_token_data_array llama_sampling_probability_distribution(
+        struct llama_sampling_context * ctx_sampling,
+        struct llama_context * ctx_main,
+        struct llama_context * ctx_cfg,
+        int idx = 0);
+
 void llama_sampling_accept(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,
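
As a rough usage sketch of the helper declared above (not part of the patch): it returns the full softmaxed distribution for a given batch index, so a caller can look up the probability assigned to a particular token id, which is what the speculative example does during verification. The surrounding names (ctx_sampling, ctx_tgt, i_batch, some_token_id) are assumed to be in scope.

// assumed in scope: ctx_sampling, ctx_tgt, a valid batch index i_batch, and some_token_id to look up
llama_token_data_array dist = llama_sampling_probability_distribution(ctx_sampling, ctx_tgt, NULL, i_batch);

float p = 0.0f;
for (size_t i = 0; i < dist.size; ++i) {
    if (dist.data[i].id == some_token_id) {
        p = dist.data[i].p; // probability after softmax
        break;
    }
}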
index 814efa592d94fabca1b184d8f699c557558b5b07..a6608c5fe8e3adf7fdaa2fda4e822df2631b033f 100644 (file)
@@ -6,3 +6,4 @@ More info:
 
 - https://github.com/ggerganov/llama.cpp/pull/2926
 - https://github.com/ggerganov/llama.cpp/pull/3624
+- https://github.com/ggerganov/llama.cpp/pull/5625
index 3848791d475ad3e248c1aedcb7efae6824684d1a..85bc0a762ad080bc0b5a45230881f1178bd3e3d7 100644 (file)
@@ -5,6 +5,7 @@
 #include <cstdio>
 #include <string>
 #include <vector>
+#include <set>
 
 #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE  100
 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
@@ -18,6 +19,7 @@ struct seq_draft {
     std::vector<int> i_batch_tgt;
 
     std::vector<llama_token> tokens;
+    std::vector<std::vector<llama_token_data>> dists;
 
     struct llama_sampling_context * ctx_sampling;
 };
@@ -37,12 +39,15 @@ int main(int argc, char ** argv) {
     // max number of parallel drafting sequences (i.e. tree branches)
     const int n_seq_dft = params.n_parallel;
 
-    // probability threshold for accepting a token from the draft model
-    const float p_accept = params.p_accept;
-
     // probability threshold for splitting a draft branch (only for n_seq_dft > 1)
     const float p_split  = params.p_split;
 
+    if (params.seed == LLAMA_DEFAULT_SEED) {
+        params.seed = time(NULL);
+    }
+    std::default_random_engine rng(params.seed);
+    std::uniform_real_distribution<> u_dist;
+
 #ifndef LOG_DISABLE_LOGS
     log_set_target(log_filename_generator("speculative", "log"));
     LOG_TEE("Log start\n");
@@ -166,7 +171,9 @@ int main(int argc, char ** argv) {
     std::vector<seq_draft> drafts(n_seq_dft);
 
     params.sparams.grammar.clear(); // the draft samplers will copy the target sampler's grammar
-    params.sparams.temp = -1.0f;    // force greedy sampling with probs for the draft model
+    if (params.sparams.temp == 0) {
+        params.sparams.temp = -1.0f; // force greedy sampling with probs for the draft model
+    }
 
     for (int s = 0; s < n_seq_dft; ++s) {
         drafts[s].ctx_sampling = llama_sampling_init(params.sparams);
@@ -182,12 +189,15 @@ int main(int argc, char ** argv) {
     drafts[0].i_batch_tgt[0] = 0;
 
     while (true) {
+        std::set<int> active_seqs = {};
+
         // print current draft sequences
         for (int s = 0; s < n_seq_dft; ++s) {
             if (!drafts[s].active) {
                 continue;
             }
 
+            active_seqs.insert(s);
             const auto & tokens = drafts[s].tokens;
 
             LOG("draft %d: %s\n", s, LOG_TOKENS_TOSTR_PRETTY(ctx_dft, tokens).c_str());
@@ -196,48 +206,156 @@ int main(int argc, char ** argv) {
         int i_dft  = 0;
         int s_keep = 0;
 
+        llama_token token_id;
+        std::string token_str;
+
+        // loop until we fail to accept a drafted token or we run out of drafted tokens
         while (true) {
-            LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
 
-            // sample from the target model
-            llama_token id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
+            // check if the target token matches any of the drafts
+            // for stochastic sampling, attempt to match the token with the drafted tokens
+            {
+                bool accept = false;
+                if (params.sparams.temp > 0) {
+                    // stochastic verification
+
+                    llama_token_data_array dist_tgt = llama_sampling_probability_distribution(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
+                    float p_tgt = 0, p_dft = 0;
+
+                    // GGML_ASSERT(dist_tgt.size() == dist_dft.size());
+
+                    while (active_seqs.size() > 0) {
+                        // randomly select a sequence to verify from active sequences
+                        std::uniform_int_distribution<u_int> u_int_dist(0, active_seqs.size() - 1);
+                        int s = *std::next(active_seqs.begin(), u_int_dist(rng));
+                        if (i_dft >= (int) drafts[s].tokens.size()) {
+                            drafts[s].active = false;
+                            active_seqs.erase(s);
+                            continue;
+                        }
+                        if (accept) {
+                            // if we already accepted a token, we can skip the rest
+                            if (drafts[s].tokens[i_dft] != drafts[s_keep].tokens[i_dft]) {
+                                drafts[s].active = false;
+                                active_seqs.erase(s);
+                            }
+                            continue;
+                        }
+                        LOG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size());
+                        float r = u_dist(rng);
+                        llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), true };
+                        // acquire the token probabilities assigned by the draft and target models
+                        for (size_t i = 0; i < dist_tgt.size; i++) {
+                            if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) {
+                                p_tgt = dist_tgt.data[i].p;
+                            }
+                            if (dist_dft.data[i].id == drafts[s].tokens[i_dft]) {
+                                p_dft = dist_dft.data[i].p;
+                            }
+                            if (p_tgt && p_dft) {
+                                break;
+                            }
+                        }
+                        LOG("r = %f, p_dft = %f, p_tgt = %f\n", r, p_dft, p_tgt);
+                        if (r <= p_tgt / p_dft) {
+                            s_keep = s;
+                            accept = true;
+                            token_id = drafts[s].tokens[i_dft];
+                            token_str = llama_token_to_piece(ctx_tgt, token_id);
+                            llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
+
+                            LOG("draft token %d of sequence %d (%d, '%s') accepted\n", i_dft, s, token_id, token_str.c_str());
+                            break;
+                        } else {
+                            LOG("draft token %d of sequence %d (%d, '%s') rejected\n", i_dft, s, drafts[s].tokens[i_dft], llama_token_to_piece(ctx_tgt, drafts[s].tokens[i_dft]).c_str());
+                            drafts[s].active = false;
+
+                            // calculate residual probability
+                            GGML_ASSERT(dist_tgt.sorted);
+                            GGML_ASSERT(dist_dft.sorted);
+                            float sum_probs = 0.0f;
+
+                            // sort dist by id
+                            std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) {
+                                return a.id < b.id;
+                            });
+                            std::sort(dist_dft.data, dist_dft.data + dist_dft.size, [](const llama_token_data &a, const llama_token_data &b) {
+                                return a.id < b.id;
+                            });
+
+                            for (size_t i = 0; i < dist_tgt.size; i++) {
+                                dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p - dist_dft.data[i].p);
+                                sum_probs += dist_tgt.data[i].p;
+                            }
+                            for (size_t i = 0; i < dist_tgt.size; i++) {
+                                dist_tgt.data[i].p /= sum_probs;
+                            }
 
-            llama_sampling_accept(ctx_sampling, ctx_tgt, id, true);
+                            // sort dist_tgt by p desc
+                            std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) {
+                                return a.p > b.p;
+                            });
+                        }
 
-            //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
+                        active_seqs.erase(s);
+                        for(int i = 0; i < n_seq_dft; i++) {
+                            if (i == s) {
+                                continue;
+                            }
+                            if (drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) {
+                                // synchronize active status for sequences with the same drafted token
+                                drafts[i].active = drafts[i].active && accept;
+                                if (!drafts[i].active) {
+                                    active_seqs.erase(s);
+                                }
+                            }
+                        }
+                    }
 
-            const std::string token_str = llama_token_to_piece(ctx_tgt, id);
+                    if (!accept) {
+                        // all drafted tokens were rejected
+                        // sample from the target model
+                        LOG("all drafted tokens were rejected, sampling from residual distribution\n");
+                        token_id = llama_sample_token(ctx_tgt, &dist_tgt);
+                        llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
+                        token_str = llama_token_to_piece(ctx_tgt, token_id);
+                    }
 
-            if (!params.use_color) {
-                printf("%s", token_str.c_str());
-            }
+                } else {
+                    // greedy verification
 
-            if (id == llama_token_eos(model_tgt)) {
-                has_eos = true;
-            }
+                    // sample from the target model
+                    LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
+                    token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
 
-            ++n_predict;
+                    llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
 
-            // check if the target token matches any of the drafts
-            {
-                bool matches = false;
+                    //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
 
-                for (int s = 0; s < n_seq_dft; ++s) {
-                    if (!drafts[s].active) {
-                        continue;
-                    }
+                    token_str = llama_token_to_piece(ctx_tgt, token_id);
 
-                    if (i_dft < (int) drafts[s].tokens.size() && id == drafts[s].tokens[i_dft]) {
-                        LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, id, token_str.c_str());
+                    for (int s = 0; s < n_seq_dft; ++s) {
+                        if (!drafts[s].active) {
+                            continue;
+                        }
 
-                        s_keep = s;
-                        matches = true;
-                    } else {
-                        drafts[s].active = false;
+                        if (i_dft < (int) drafts[s].tokens.size() && token_id == drafts[s].tokens[i_dft]) {
+                            LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, token_id, token_str.c_str());
+
+                            s_keep = s;
+                            accept = true;
+                        } else {
+                            drafts[s].active = false;
+                        }
                     }
                 }
 
-                if (matches) {
+                if (token_id == llama_token_eos(model_tgt)) {
+                    has_eos = true;
+                }
+                ++n_predict;
+
+                if (accept) {
                     ++n_accept;
                     ++n_past_tgt;
                     ++n_past_dft;
@@ -245,17 +363,21 @@ int main(int argc, char ** argv) {
                     if (params.use_color) {
                         // Color token according to its origin sequence
                         printf("\u001b[%dm%s\u001b[37m", (36 - s_keep % 6), token_str.c_str());
-                        fflush(stdout);
+                    } else {
+                        printf("%s", token_str.c_str());
                     }
+                    fflush(stdout);
                     continue;
+                } else {
+                    printf("%s", token_str.c_str());
+                    fflush(stdout);
+                    break;
                 }
             }
-            if (params.use_color) {
-                printf("%s", token_str.c_str());
-            }
-            fflush(stdout);
+        }
 
-            LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str());
+        {
+            LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", token_id, token_str.c_str());
 
             // TODO: simplify
             {
@@ -275,21 +397,21 @@ int main(int argc, char ** argv) {
                 drafts[s].active = false;
                 drafts[s].tokens.clear();
                 drafts[s].i_batch_tgt.clear();
+                drafts[s].dists.clear();
             }
             // note: will be erased after the speculation phase
-            drafts[0].tokens.push_back(id);
+            drafts[0].tokens.push_back(token_id);
+            drafts[0].dists.push_back(std::vector<llama_token_data>());
             drafts[0].i_batch_tgt.push_back(0);
 
             llama_batch_clear(batch_dft);
-            llama_batch_add  (batch_dft, id, n_past_dft, { 0 }, true);
+            llama_batch_add  (batch_dft, token_id, n_past_dft, { 0 }, true);
 
             llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
             // LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
-            llama_decode         (ctx_dft, batch_dft);
+            llama_decode(ctx_dft, batch_dft);
 
             ++n_past_dft;
-
-            break;
         }
 
         if (n_predict > params.n_predict || has_eos) {
@@ -334,12 +456,6 @@ int main(int argc, char ** argv) {
                             k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str());
                 }
 
-                if (cur_p[0].p < p_accept) {
-                    LOG("stopping drafting for seq %3d, probability too low: %.3f < %.3f\n", s, cur_p[0].p, p_accept);
-                    drafts[s].drafting = false;
-                    continue;
-                }
-
                 std::vector<int> sa(1, s);
 
                 // attempt to split the branch if the probability is high enough
@@ -367,6 +483,7 @@ int main(int argc, char ** argv) {
                         drafts[n_seq_cur].skip     = true;
 
                         drafts[n_seq_cur].tokens      = drafts[s].tokens;
+                        drafts[n_seq_cur].dists       = drafts[s].dists;
                         drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft;
                         drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;
 
@@ -389,6 +506,8 @@ int main(int argc, char ** argv) {
                     llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true);
 
                     drafts[s].tokens.push_back(id);
+                    // save cur_p.data into drafts[s].dists
+                    drafts[s].dists.push_back(cur_p);
 
                     // add unique drafted tokens to the target batch
                     drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);
@@ -440,6 +559,7 @@ int main(int argc, char ** argv) {
             }
 
             drafts[s].tokens.erase(drafts[s].tokens.begin());
+            drafts[s].dists.erase(drafts[s].dists.begin());
         }
     }
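
The verification loop above picks the sequence to verify uniformly at random from the set of still-active draft sequences. A standalone sketch of that selection step, assuming a std::set of sequence ids and a seeded std::default_random_engine as in the patch:

#include <iterator>
#include <random>
#include <set>

// pick one active draft sequence id uniformly at random for verification
static int pick_active_seq(const std::set<int> & active_seqs, std::default_random_engine & rng) {
    std::uniform_int_distribution<int> u_int_dist(0, (int) active_seqs.size() - 1);
    return *std::next(active_seqs.begin(), u_int_dist(rng));
}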