return id;
}
+static llama_token_data_array llama_sample_probability_distribution_impl(
+ struct llama_sampling_context * ctx_sampling,
+ struct llama_context * ctx_main,
+ struct llama_context * ctx_cfg,
+ const int idx) {
+ const llama_sampling_params & params = ctx_sampling->params;
+
+ const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+
+ const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
+ const float penalty_repeat = params.penalty_repeat;
+ const float penalty_freq = params.penalty_freq;
+ const float penalty_present = params.penalty_present;
+ const bool penalize_nl = params.penalize_nl;
+
+ auto & prev = ctx_sampling->prev;
+ auto & cur = ctx_sampling->cur;
+
+ // Get a pointer to the logits
+ float * logits = llama_get_logits_ith(ctx_main, idx);
+
+ // Declare original_logits at the beginning of the function scope
+ std::vector<float> original_logits;
+
+ // apply params.logit_bias map
+ for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
+ logits[it->first] += it->second;
+ }
+
+ if (ctx_cfg) {
+ float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
+ llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
+ }
+
+ cur.clear();
+
+ for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+ cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+ }
+
+ llama_token_data_array cur_p = { cur.data(), cur.size(), false };
+
+ // apply penalties
+ const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
+ const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
+ if (penalty_tokens_used_size) {
+ const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
+
+ llama_sample_repetition_penalties(ctx_main, &cur_p,
+ penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
+ penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
+
+ if (!penalize_nl) {
+ for (size_t idx = 0; idx < cur_p.size; idx++) {
+ if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
+ cur_p.data[idx].logit = nl_logit;
+ break;
+ }
+ }
+ }
+ }
+
+ // apply grammar checks
+ if (ctx_sampling->grammar != NULL) {
+ llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
+ }
+
+ llama_sample_softmax(ctx_main, &cur_p);
+ return cur_p;
+}
+
llama_token llama_sampling_sample(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
}
+llama_token_data_array llama_sampling_probability_distribution(
+ struct llama_sampling_context * ctx_sampling,
+ struct llama_context * ctx_main,
+ struct llama_context * ctx_cfg,
+ const int idx) {
+ return llama_sample_probability_distribution_impl(ctx_sampling,ctx_main, ctx_cfg, idx);
+}
+
void llama_sampling_accept(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
#include <cstdio>
#include <string>
#include <vector>
+#include <set>
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
std::vector<int> i_batch_tgt;
std::vector<llama_token> tokens;
+ std::vector<std::vector<llama_token_data>> dists;
struct llama_sampling_context * ctx_sampling;
};
// max number of parallel drafting sequences (i.e. tree branches)
const int n_seq_dft = params.n_parallel;
- // probability threshold for accepting a token from the draft model
- const float p_accept = params.p_accept;
-
// probability threshold for splitting a draft branch (only for n_seq_dft > 1)
const float p_split = params.p_split;
+ if (params.seed == LLAMA_DEFAULT_SEED) {
+ params.seed = time(NULL);
+ }
+ std::default_random_engine rng(params.seed);
+ std::uniform_real_distribution<> u_dist;
+
#ifndef LOG_DISABLE_LOGS
log_set_target(log_filename_generator("speculative", "log"));
LOG_TEE("Log start\n");
std::vector<seq_draft> drafts(n_seq_dft);
params.sparams.grammar.clear(); // the draft samplers will copy the target sampler's grammar
- params.sparams.temp = -1.0f; // force greedy sampling with probs for the draft model
+ if (params.sparams.temp == 0) {
+ params.sparams.temp = -1.0f; // force greedy sampling with probs for the draft model
+ }
for (int s = 0; s < n_seq_dft; ++s) {
drafts[s].ctx_sampling = llama_sampling_init(params.sparams);
drafts[0].i_batch_tgt[0] = 0;
while (true) {
+ std::set<int> active_seqs = {};
+
// print current draft sequences
for (int s = 0; s < n_seq_dft; ++s) {
if (!drafts[s].active) {
continue;
}
+ active_seqs.insert(s);
const auto & tokens = drafts[s].tokens;
LOG("draft %d: %s\n", s, LOG_TOKENS_TOSTR_PRETTY(ctx_dft, tokens).c_str());
int i_dft = 0;
int s_keep = 0;
+ llama_token token_id;
+ std::string token_str;
+
+ // loop until we fail to accept a drafted token or we run out of drafted tokens
while (true) {
- LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
- // sample from the target model
- llama_token id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
+ // check if the target token matches any of the drafts
+ // for stochastic sampling, attempt to match the token with the drafted tokens
+ {
+ bool accept = false;
+ if (params.sparams.temp > 0) {
+ // stochastic verification
+
+ llama_token_data_array dist_tgt = llama_sampling_probability_distribution(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
+ float p_tgt = 0, p_dft = 0;
+
+ // GGML_ASSERT(dist_tgt.size() == dist_dft.size());
+
+ while (active_seqs.size() > 0) {
+ // randomly select a sequence to verify from active sequences
+ std::uniform_int_distribution<u_int> u_int_dist(0, active_seqs.size() - 1);
+ int s = *std::next(active_seqs.begin(), u_int_dist(rng));
+ if (i_dft >= (int) drafts[s].tokens.size()) {
+ drafts[s].active = false;
+ active_seqs.erase(s);
+ continue;
+ }
+ if (accept) {
+ // if we already accepted a token, we can skip the rest
+ if (drafts[s].tokens[i_dft] != drafts[s_keep].tokens[i_dft]) {
+ drafts[s].active = false;
+ active_seqs.erase(s);
+ }
+ continue;
+ }
+ LOG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size());
+ float r = u_dist(rng);
+ llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), true };
+ // acquire the token probabilities assigned by the draft and target models
+ for (size_t i = 0; i < dist_tgt.size; i++) {
+ if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) {
+ p_tgt = dist_tgt.data[i].p;
+ }
+ if (dist_dft.data[i].id == drafts[s].tokens[i_dft]) {
+ p_dft = dist_dft.data[i].p;
+ }
+ if (p_tgt && p_dft) {
+ break;
+ }
+ }
+ LOG("r = %f, p_dft = %f, p_tgt = %f\n", r, p_dft, p_tgt);
+ if (r <= p_tgt / p_dft) {
+ s_keep = s;
+ accept = true;
+ token_id = drafts[s].tokens[i_dft];
+ token_str = llama_token_to_piece(ctx_tgt, token_id);
+ llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
+
+ LOG("draft token %d of sequence %d (%d, '%s') accepted\n", i_dft, s, token_id, token_str.c_str());
+ break;
+ } else {
+ LOG("draft token %d of sequence %d (%d, '%s') rejected\n", i_dft, s, drafts[s].tokens[i_dft], llama_token_to_piece(ctx_tgt, drafts[s].tokens[i_dft]).c_str());
+ drafts[s].active = false;
+
+ // calculate residual probability
+ GGML_ASSERT(dist_tgt.sorted);
+ GGML_ASSERT(dist_dft.sorted);
+ float sum_probs = 0.0f;
+
+ // sort dist by id
+ std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) {
+ return a.id < b.id;
+ });
+ std::sort(dist_dft.data, dist_dft.data + dist_dft.size, [](const llama_token_data &a, const llama_token_data &b) {
+ return a.id < b.id;
+ });
+
+ for (size_t i = 0; i < dist_tgt.size; i++) {
+ dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p - dist_dft.data[i].p);
+ sum_probs += dist_tgt.data[i].p;
+ }
+ for (size_t i = 0; i < dist_tgt.size; i++) {
+ dist_tgt.data[i].p /= sum_probs;
+ }
- llama_sampling_accept(ctx_sampling, ctx_tgt, id, true);
+ // sort dist_tgt by p desc
+ std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) {
+ return a.p > b.p;
+ });
+ }
- //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
+ active_seqs.erase(s);
+ for(int i = 0; i < n_seq_dft; i++) {
+ if (i == s) {
+ continue;
+ }
+ if (drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) {
+ // synchronize active status for sequences with the same drafted token
+ drafts[i].active = drafts[i].active && accept;
+ if (!drafts[i].active) {
+ active_seqs.erase(s);
+ }
+ }
+ }
+ }
- const std::string token_str = llama_token_to_piece(ctx_tgt, id);
+ if (!accept) {
+ // all drafted tokens were rejected
+ // sample from the target model
+ LOG("all drafted tokens were rejected, sampling from residual distribution\n");
+ token_id = llama_sample_token(ctx_tgt, &dist_tgt);
+ llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
+ token_str = llama_token_to_piece(ctx_tgt, token_id);
+ }
- if (!params.use_color) {
- printf("%s", token_str.c_str());
- }
+ } else {
+ // greedy verification
- if (id == llama_token_eos(model_tgt)) {
- has_eos = true;
- }
+ // sample from the target model
+ LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
+ token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
- ++n_predict;
+ llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
- // check if the target token matches any of the drafts
- {
- bool matches = false;
+ //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
- for (int s = 0; s < n_seq_dft; ++s) {
- if (!drafts[s].active) {
- continue;
- }
+ token_str = llama_token_to_piece(ctx_tgt, token_id);
- if (i_dft < (int) drafts[s].tokens.size() && id == drafts[s].tokens[i_dft]) {
- LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, id, token_str.c_str());
+ for (int s = 0; s < n_seq_dft; ++s) {
+ if (!drafts[s].active) {
+ continue;
+ }
- s_keep = s;
- matches = true;
- } else {
- drafts[s].active = false;
+ if (i_dft < (int) drafts[s].tokens.size() && token_id == drafts[s].tokens[i_dft]) {
+ LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, token_id, token_str.c_str());
+
+ s_keep = s;
+ accept = true;
+ } else {
+ drafts[s].active = false;
+ }
}
}
- if (matches) {
+ if (token_id == llama_token_eos(model_tgt)) {
+ has_eos = true;
+ }
+ ++n_predict;
+
+ if (accept) {
++n_accept;
++n_past_tgt;
++n_past_dft;
if (params.use_color) {
// Color token according to its origin sequence
printf("\u001b[%dm%s\u001b[37m", (36 - s_keep % 6), token_str.c_str());
- fflush(stdout);
+ } else {
+ printf("%s", token_str.c_str());
}
+ fflush(stdout);
continue;
+ } else {
+ printf("%s", token_str.c_str());
+ fflush(stdout);
+ break;
}
}
- if (params.use_color) {
- printf("%s", token_str.c_str());
- }
- fflush(stdout);
+ }
- LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str());
+ {
+ LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", token_id, token_str.c_str());
// TODO: simplify
{
drafts[s].active = false;
drafts[s].tokens.clear();
drafts[s].i_batch_tgt.clear();
+ drafts[s].dists.clear();
}
// note: will be erased after the speculation phase
- drafts[0].tokens.push_back(id);
+ drafts[0].tokens.push_back(token_id);
+ drafts[0].dists.push_back(std::vector<llama_token_data>());
drafts[0].i_batch_tgt.push_back(0);
llama_batch_clear(batch_dft);
- llama_batch_add (batch_dft, id, n_past_dft, { 0 }, true);
+ llama_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true);
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
// LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
- llama_decode (ctx_dft, batch_dft);
+ llama_decode(ctx_dft, batch_dft);
++n_past_dft;
-
- break;
}
if (n_predict > params.n_predict || has_eos) {
k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str());
}
- if (cur_p[0].p < p_accept) {
- LOG("stopping drafting for seq %3d, probability too low: %.3f < %.3f\n", s, cur_p[0].p, p_accept);
- drafts[s].drafting = false;
- continue;
- }
-
std::vector<int> sa(1, s);
// attempt to split the branch if the probability is high enough
drafts[n_seq_cur].skip = true;
drafts[n_seq_cur].tokens = drafts[s].tokens;
+ drafts[n_seq_cur].dists = drafts[s].dists;
drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft;
drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;
llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true);
drafts[s].tokens.push_back(id);
+ // save cur_p.data into drafts[s].dists
+ drafts[s].dists.push_back(cur_p);
// add unique drafted tokens to the target batch
drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);
}
drafts[s].tokens.erase(drafts[s].tokens.begin());
+ drafts[s].dists.erase(drafts[s].dists.begin());
}
}