// helpers
-llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl) {
- return &gsmpl->cur_p;
+llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort) {
+ auto * res = &gsmpl->cur_p;
+
+ if (do_sort && !res->sorted) {
+ // remember the selected token before sorting
+ const llama_token id = res->data[res->selected].id;
+
+ std::sort(res->data, res->data + res->size, [](const llama_token_data & a, const llama_token_data & b) {
+ return a.p > b.p;
+ });
+
+ // restore the selected token after sorting
+ for (size_t i = 0; i < res->size; ++i) {
+ if (res->data[i].id == id) {
+ res->selected = i;
+ break;
+ }
+ }
+
+ res->sorted = true;
+ }
+
+ return res;
}
llama_token common_sampler_last(const struct common_sampler * gsmpl) {
// helpers
// access the internal list of current candidate tokens
-llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl);
+// if do_sort == true, the candidates are guaranteed to be sorted afterwards (in descending order of probability)
+// the .sorted flag of the result indicates whether the returned candidates are sorted
+llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort);
// get the last accepted token
llama_token common_sampler_last(const struct common_sampler * gsmpl);
common_sampler_sample(smpl, ctx_dft, 0, true);
- const auto * cur_p = common_sampler_get_candidates(smpl);
+ const auto * cur_p = common_sampler_get_candidates(smpl, true);
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
// stochastic verification
common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true);
- auto & dist_tgt = *common_sampler_get_candidates(smpl);
+ auto & dist_tgt = *common_sampler_get_candidates(smpl, true);
float p_tgt = 0.0f;
float p_dft = 0.0f;
common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft, true);
- const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl);
+ const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl, true);
for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p->size); ++k) {
LOG_DBG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",
llama_token_data * data;
size_t size;
int64_t selected; // this is the index in the data array (i.e. not the token id)
- bool sorted;
+ bool sorted; // note: do not assume the data is sorted - always check this flag
} llama_token_data_array;
typedef bool (*llama_progress_callback)(float progress, void * user_data);
LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void);
LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed);
- /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
- /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
- DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void),
- "will be removed in the future (see https://github.com/ggml-org/llama.cpp/pull/9896#discussion_r1800920915)");
-
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
/// Setting k <= 0 makes this a noop
LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k);
std::vector<T> data;
};
+// writes result in res, does not mutate cur
+static void llama_token_data_array_partial_sort(const llama_token_data_array & cur, int npartial, std::vector<llama_token_data> & res) {
+ static const auto comp = [](const llama_token_data & a, const llama_token_data & b) {
+ return a.logit > b.logit;
+ };
+
+ constexpr int nbuckets = 128;
+ constexpr float bucket_low = -10.0f;
+ constexpr float bucket_high = 10.0f;
+ constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
+ constexpr float bucket_inter = -bucket_low * bucket_scale;
+
+ std::vector<int> bucket_idx;
+ std::vector<int> histo(nbuckets, 0);
+
+ std::vector<llama_token_data*> bucket_ptrs;
+
+ bucket_idx.reserve(cur.size);
+
+ for (int i = 0; i < (int)cur.size; ++i) {
+ const float val = cur.data[i].logit;
+ int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
+ ib = std::max(0, std::min(nbuckets - 1, ib));
+ bucket_idx.push_back(ib);
+ ++histo[ib];
+ }
+ int nhave = 0;
+ int ib = nbuckets - 1;
+ for ( ; ib >= 0; --ib) {
+ nhave += histo[ib];
+ if (nhave >= npartial) {
+ break;
+ }
+ }
+ res.resize(nhave);
+ auto * ptr = res.data();
+ bucket_ptrs.reserve(nbuckets - ib);
+ for (int j = nbuckets - 1; j >= ib; --j) {
+ bucket_ptrs.push_back(ptr);
+ ptr += histo[j];
+ }
+ for (int i = 0; i < (int)cur.size; ++i) {
+ int j = bucket_idx[i];
+ if (j >= ib) {
+ *bucket_ptrs[nbuckets - 1 - j]++ = cur.data[i];
+ }
+ }
+
+ ptr = res.data();
+ int ndone = 0;
+ for (int j = nbuckets - 1; j > ib; --j) {
+ std::sort(ptr, ptr + histo[j], comp);
+ ptr += histo[j];
+ ndone += histo[j];
+ }
+ std::partial_sort(ptr, ptr + npartial - ndone, ptr + histo[ib], comp);
+}
+
+// reduces the size of cur_p to npartial, keeping only the top npartial elements
+static void llama_token_data_array_partial_sort_inplace(llama_token_data_array * cur_p, int npartial) {
+ static const auto comp = [](const llama_token_data & a, const llama_token_data & b) {
+ return a.logit > b.logit;
+ };
+
+ if (npartial <= 128) {
+ std::partial_sort(cur_p->data, cur_p->data + npartial, cur_p->data + cur_p->size, comp);
+
+ cur_p->size = npartial;
+ cur_p->sorted = true;
+
+ return;
+ }
+
+ std::vector<llama_token_data> tmp;
+
+ llama_token_data_array_partial_sort(*cur_p, npartial, tmp);
+
+ std::copy(tmp.data(), tmp.data() + npartial, cur_p->data);
+
+ cur_p->size = npartial;
+ cur_p->sorted = true;
+}
+
static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) {
// iterator for the probabilities
#ifdef __GNUC__
}
}
-static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
+static void llama_sampler_softmax_impl(llama_token_data_array * cur_p, bool do_sort) {
GGML_ASSERT(cur_p->size > 0);
- // Sort the logits in descending order
- if (!cur_p->sorted) {
- std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
- return a.logit > b.logit;
- });
- cur_p->sorted = true;
+ // Sort the logits in descending order if requested
+ if (do_sort && !cur_p->sorted) {
+ llama_token_data_array_partial_sort_inplace(cur_p, cur_p->size);
}
float max_l = cur_p->data[0].logit;
+ if (!cur_p->sorted) {
+ for (size_t i = 1; i < cur_p->size; ++i) {
+ max_l = std::max(max_l, cur_p->data[i].logit);
+ }
+ }
+
float cum_sum = 0.0f;
for (size_t i = 0; i < cur_p->size; ++i) {
}
static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) {
- // TODO: move bucket sort to separate function so that top_p/typical/softmax first is equally fast
// if (k >= (int32_t)cur_p->size) {
// return;
// }
// Sort scores in descending order
if (!cur_p->sorted) {
- auto comp = [](const llama_token_data & a, const llama_token_data & b) {
- return a.logit > b.logit;
- };
- if (k <= 128) {
- std::partial_sort(cur_p->data, cur_p->data + k, cur_p->data + cur_p->size, comp);
- } else {
- constexpr int nbuckets = 128;
- constexpr float bucket_low = -10.0f;
- constexpr float bucket_high = 10.0f;
- constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
- constexpr float bucket_inter = -bucket_low * bucket_scale;
-
- std::vector<int> bucket_idx(cur_p->size);
- std::vector<int> histo(nbuckets, 0);
-
- for (int i = 0; i < (int)cur_p->size; ++i) {
- const float val = cur_p->data[i].logit;
- int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
- ib = std::max(0, std::min(nbuckets - 1, ib));
- bucket_idx[i] = ib;
- ++histo[ib];
- }
- int nhave = 0;
- int ib = nbuckets - 1;
- for ( ; ib >= 0; --ib) {
- nhave += histo[ib];
- if (nhave >= k) {
- break;
- }
- }
- std::vector<llama_token_data> tmp_tokens(nhave);
- auto * ptr = tmp_tokens.data();
- std::vector<llama_token_data*> bucket_ptrs;
- bucket_ptrs.reserve(nbuckets - ib);
- for (int j = nbuckets - 1; j >= ib; --j) {
- bucket_ptrs.push_back(ptr);
- ptr += histo[j];
- }
- for (int i = 0; i < (int)cur_p->size; ++i) {
- int j = bucket_idx[i];
- if (j >= ib) {
- *bucket_ptrs[nbuckets - 1 - j]++ = cur_p->data[i];
- }
- }
-
- ptr = tmp_tokens.data();
- int ndone = 0;
- for (int j = nbuckets - 1; j > ib; --j) {
- std::sort(ptr, ptr + histo[j], comp);
- ptr += histo[j];
- ndone += histo[j];
- }
- std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
-
- std::memcpy(cur_p->data, tmp_tokens.data(), k*sizeof(llama_token_data));
-
- }
- cur_p->sorted = true;
+ llama_token_data_array_partial_sort_inplace(cur_p, k);
}
cur_p->size = k;
static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_dist *) smpl->ctx;
- llama_sampler_softmax_impl(cur_p);
+ // sorting is not necessary here
+ llama_sampler_softmax_impl(cur_p, false);
cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
}
);
}
-// softmax
-
-static const char * llama_sampler_softmax_name(const struct llama_sampler * /*smpl*/) {
- return "softmax";
-}
-
-static void llama_sampler_softmax_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
- llama_sampler_softmax_impl(cur_p);
-}
-
-static struct llama_sampler_i llama_sampler_softmax_i = {
- /* .name = */ llama_sampler_softmax_name,
- /* .accept = */ nullptr,
- /* .apply = */ llama_sampler_softmax_apply,
- /* .reset = */ nullptr,
- /* .clone = */ nullptr,
- /* .free = */ nullptr,
-};
-
-struct llama_sampler * llama_sampler_init_softmax() {
- return llama_sampler_init(
- /* .iface = */ &llama_sampler_softmax_i,
- /* .ctx = */ nullptr
- );
-}
-
// top-k
struct llama_sampler_top_k {
}
static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
- const auto * ctx = (llama_sampler_top_k *) smpl->ctx;
+ auto * ctx = (llama_sampler_top_k *) smpl->ctx;
llama_sampler_top_k_impl(cur_p, ctx->k);
}
struct llama_sampler_top_p {
const float p;
const size_t min_keep;
+
+ std::vector<llama_token_data> buf_sort;
};
static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl*/) {
}
static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
- const auto * ctx = (llama_sampler_top_p *) smpl->ctx;
+ auto * ctx = (llama_sampler_top_p *) smpl->ctx;
if (ctx->p >= 1.0f) {
return;
}
- llama_sampler_softmax_impl(cur_p);
+ llama_sampler_softmax_impl(cur_p, false);
+
+ size_t k = cur_p->size;
+ auto * pdata = cur_p->data;
+
+ auto & buf_sort = ctx->buf_sort;
+
+ // if not sorted, try adaptive top-k sorting
+ if (!cur_p->sorted && cur_p->size > 1024) {
+ k = std::min<size_t>(256, cur_p->size);
+ llama_token_data_array_partial_sort(*cur_p, k, buf_sort);
+ pdata = buf_sort.data();
+ } else if (!cur_p->sorted) {
+ // small candidates -> sort inplace
+ llama_token_data_array_partial_sort_inplace(cur_p, k);
+ }
// Compute the cumulative probabilities
float cum_sum = 0.0f;
size_t last_idx = cur_p->size;
for (size_t i = 0; i < cur_p->size; ++i) {
- cum_sum += cur_p->data[i].p;
+ cum_sum += pdata[i].p;
// Check if the running sum is at least p or if we have kept at least min_keep tokens
// we set the last index to i+1 to indicate that the current iterate should be included in the set
last_idx = i + 1;
break;
}
+
+ // we exceeded the current top-k heuristic -> increase k and continue
+ if (!cur_p->sorted && i == k - 1) {
+ k = cur_p->size;
+ llama_token_data_array_partial_sort(*cur_p, k, buf_sort);
+ pdata = buf_sort.data();
+ }
}
// Resize the output vector to keep only the top-p tokens
+ if (!cur_p->sorted) {
+ std::copy(buf_sort.data(), buf_sort.data() + last_idx, cur_p->data);
+ cur_p->sorted = true;
+ }
+
cur_p->size = last_idx;
}
/* .ctx = */ new llama_sampler_top_p {
/* .p = */ p,
/* .min_keep = */ min_keep,
+ /* .buf_sort = */ {},
}
);
}
}
static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
- const auto * ctx = (llama_sampler_min_p *) smpl->ctx;
+ auto * ctx = (llama_sampler_min_p *) smpl->ctx;
if (ctx->p <= 0.0f || !cur_p->size) {
return;
// if we have enough values the operation was a success
if (!filtered_tokens.empty() && filtered_tokens.size() >= ctx->min_keep) {
- memcpy(cur_p->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
+ std::copy(filtered_tokens.begin(), filtered_tokens.end(), cur_p->data);
cur_p->size = filtered_tokens.size();
min_p_applied = true;
}
if (!min_p_applied) {
// Sort the logits in descending order
if (!cur_p->sorted) {
- std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
- return a.logit > b.logit;
- });
- cur_p->sorted = true;
+ llama_token_data_array_partial_sort_inplace(cur_p, cur_p->size);
}
const float min_logit = cur_p->data[0].logit + logf(ctx->p); // min logit for p_i >= p * p_max
}
static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
- const auto * ctx = (llama_sampler_typical *) smpl->ctx;
+ auto * ctx = (llama_sampler_typical *) smpl->ctx;
// Reference implementation:
// https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
}
// Compute the softmax of logits and calculate entropy
- llama_sampler_softmax_impl(cur_p);
+ llama_sampler_softmax_impl(cur_p, true);
float entropy = 0.0f;
for (size_t i = 0; i < cur_p->size; ++i) {
}
static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
- const auto * ctx = (llama_sampler_temp_ext *) smpl->ctx;
+ auto * ctx = (llama_sampler_temp_ext *) smpl->ctx;
if (ctx->delta > 0) {
const float min_temp = std::max(0.0f, ctx->temp - ctx->delta);
const float max_temp = ctx->temp + ctx->delta;
// Calculate maximum possible entropy
float max_entropy = -logf(1.0f / cur_p->size);
- llama_sampler_softmax_impl(cur_p);
+ llama_sampler_softmax_impl(cur_p, true);
// Calculate entropy of the softmax probabilities
float entropy = 0.0f;
const uint32_t seed;
uint32_t seed_cur;
- std::mt19937 rng;
+ std::mt19937 rng;
};
static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
std::uniform_real_distribution<float> distribution(0.0f, 1.0f);
float chance = distribution(ctx->rng);
- if (chance > ctx->probability) return;
+ if (chance > ctx->probability) {
+ return;
+ }
- // in case it's not sorted/recalculated yet
- llama_sampler_softmax_impl(cur_p);
+ llama_sampler_softmax_impl(cur_p, true);
int pos_last = 0;
for (size_t i = 0; i < cur_p->size; ++i) {
if (cur_p->data[i].p >= ctx->threshold) {
pos_last = i;
- } else break;
+ } else {
+ break;
+ }
}
if (cur_p->size - pos_last >= ctx->min_keep && pos_last > 0) {
float mu;
- std::mt19937 rng;
+ std::mt19937 rng;
};
static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*smpl*/) {
static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
- llama_sampler_softmax_impl(cur_p);
+ llama_sampler_softmax_impl(cur_p, true);
// Estimate s_hat using the most probable m tokens
float s_hat = 0.0;
float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->n_vocab, -epsilon_hat)), 1 / s_hat);
llama_sampler_top_k_impl(cur_p, std::max(int(k), 1));
- llama_sampler_softmax_impl(cur_p);
+
+ llama_sampler_softmax_impl(cur_p, true);
const int idx = llama_sample_dist(cur_p, ctx->rng);
static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
- llama_sampler_softmax_impl(cur_p);
+ llama_sampler_softmax_impl(cur_p, true);
// Truncate the words with surprise values greater than mu
cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) {
}
// Normalize the probabilities of the remaining words
- llama_sampler_softmax_impl(cur_p);
+ llama_sampler_softmax_impl(cur_p, true);
const int idx = llama_sample_dist(cur_p, ctx->rng);
trigger_pattern += std::regex_replace(trigger_words[i], special_chars, "\\$0");
}
trigger_pattern += ")[\\s\\S]*";
- auto trigger_pattern_c = trigger_pattern.c_str();
+ const auto * trigger_pattern_c = trigger_pattern.c_str();
trigger_patterns = &trigger_pattern_c;
num_trigger_patterns = 1;
}
}
static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
- const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
+ auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
if (ctx->n <= 0.0f || cur_p->size <= 1) {
return;
}
float std = valid_count > 0 ? sqrt(acc/valid_count) : 0;
- //apply mask
+ // apply mask
for (size_t i = 0; i < cur_p->size; ++i) {
if (cur_p->data[i].logit < max - (ctx->n * std)) {
cur_p->data[i].logit = -INFINITY;
}
}
- llama_sampler_softmax_impl(cur_p);
+
+ llama_sampler_softmax_impl(cur_p, true);
}
static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl) {
{
const int last = last_n_repeat - 1;
- int rt = 0, lt = 0;
+
+ int rt = 0;
+ int lt = 0;
for (int k = 1; k < last_n_repeat; ++k) {
if (k > rt) {
/* .free = */ llama_sampler_dry_free,
};
-struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
- int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0);
+struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t n_ctx_train, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
+ int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? n_ctx_train : std::max(dry_penalty_last_n, 0);
std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers;
const int MAX_CHAR_LEN = 40;
const int MAX_SEQ_LEN = 20;
return llama_sampler_init(
/* .iface = */ &llama_sampler_dry_i,
/* .ctx = */ new llama_sampler_dry {
- /* .total_context_size = */ context_size,
+ /* .total_context_size = */ n_ctx_train,
/* .dry_multiplier = */ dry_multiplier,
/* .dry_base = */ dry_base,
/* .dry_allowed_length = */ dry_allowed_length,
static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_infill *) smpl->ctx;
- llama_sampler_softmax_impl(cur_p);
+ llama_sampler_softmax_impl(cur_p, true);
#if defined(GGML_DEBUG_SAMPLER_INFILL)
#define LOG_DBG_CUR LLAMA_LOG_DEBUG
sampler_tester tester(n_vocab);
llama_token min_token_id = 0;
- const llama_token max_token_id = n_vocab-1;
+ const llama_token max_token_id = n_vocab - 1;
for (auto s : samplers_sequence) {
- switch (s){
+ switch (s) {
case 'k': tester.apply(llama_sampler_init_top_k(top_k)); break;
case 'y': GGML_ABORT("typical test not implemented");
case 'p': tester.apply(llama_sampler_init_top_p(top_p, 1)); break;
}
GGML_ASSERT(size == expected_size);
- GGML_ASSERT(cur_p.data[0].id == max_token_id);
- GGML_ASSERT(cur_p.data[expected_size-1].id == min_token_id);
+ GGML_ASSERT(!cur_p.sorted || cur_p.data[0].id == max_token_id);
+ GGML_ASSERT(!cur_p.sorted || cur_p.data[expected_size-1].id == min_token_id);
} else if (s == 'm') {
- int expected_size = ceilf((1.0f-min_p) * n_vocab);
+ int expected_size = ceilf((1.0f - min_p) * n_vocab);
expected_size = std::max(expected_size, 1);
expected_size = std::min(expected_size, size);
min_token_id = std::min(min_token_id, (llama_token)(n_vocab - 1));
GGML_ASSERT(size == expected_size);
- GGML_ASSERT(cur_p.data[0].id == max_token_id);
- GGML_ASSERT(cur_p.data[expected_size-1].id == min_token_id);
+ GGML_ASSERT(!cur_p.sorted || cur_p.data[0].id == max_token_id);
+ GGML_ASSERT(!cur_p.sorted || cur_p.data[expected_size-1].id == min_token_id);
} else {
GGML_ABORT("fatal error");
}
}
- printf("Sampler queue %3s OK with n_vocab=%05zu top_k=%05d top_p=%f min_p=%f\n",
+ printf("Sampler queue %3s OK with n_vocab=%05zu top_k=%5d top_p=%f min_p=%f\n",
samplers_sequence.c_str(), n_vocab, top_k, top_p, min_p);
}
int main(void) {
ggml_time_init();
- test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f);
- test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f);
+ test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {0.1f, 0.2f, 0.3f, 0.4f}, 1.0f);
+ test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {0.0f, 0.0f, 0.0f, 1.0f}, 0.0f);
- test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f, 0.0f, 1.0f);
- test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f, 0.0f, 1.0f);
+ test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {0.1f, 0.2f, 0.3f, 0.4f}, 1.0f, 0.0f, 1.0f);
+ test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {0.0f, 0.0f, 0.0f, 1.0f}, 0.0f, 0.0f, 1.0f);
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 1);
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 3);
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);
- test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0);
+ test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.1f, 0.2f, 0.3f, 0.4f}, 0);
test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 0);
test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.571429f, 0.428571f}, 0.7f);
test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 0.8f);
- test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f);
-
- test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/1.0f, 0.3f/1.0f, 0.2f/1.0f, 0.1f/1.0f}, 0.00f);
- test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/1.0f, 0.3f/1.0f, 0.2f/1.0f, 0.1f/1.0f}, 0.24f);
- test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.9f, 0.3f/0.9f, 0.2f/0.9f}, 0.26f);
- test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.9f, 0.3f/0.9f, 0.2f/0.9f}, 0.49f);
- test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.7f, 0.3f/0.7f}, 0.51f);
- test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.7f, 0.3f/0.7f}, 0.74f);
+ test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.1f, 0.2f, 0.3f, 0.4f}, 1.0f);
+
+ test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.1f/1.0f, 0.2f/1.0f, 0.3f/1.0f, 0.4f/1.0f}, 0.00f);
+ test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.1f/1.0f, 0.2f/1.0f, 0.3f/1.0f, 0.4f/1.0f}, 0.24f);
+ test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.2f/0.9f, 0.3f/0.9f, 0.4f/0.9f}, 0.26f);
+ test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.2f/0.9f, 0.3f/0.9f, 0.4f/0.9f}, 0.49f);
+ test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.3f/0.7f, 0.4f/0.7f}, 0.51f);
+ test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.3f/0.7f, 0.4f/0.7f}, 0.74f);
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 0.76f);
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 1.00f);
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 1.05f);
test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f);
test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f);
- test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f, 0.0f, 0.0f);
- test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);
- test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);
+ test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0, 0.25f, 0.25f, 0.25f, 0.25f}, 50.0f, 0.0f, 0.0f);
+ test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0, 0, 0, 0.5f, 0.5f}, 50.0f, 0.0f, 0.0f);
+ test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0, 0, 0, 0.5f, 0.5f}, 50.0f, 0.0f, 0.0f);
- test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f);
- test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f);
- test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f);
+ test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.000011f, 0.249997f, 0.249997f, 0.249997f, 0.249997f}, 1.0f, 5.0f, 5.0f);
+ test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.000023f, 0.000023f, 0.000023f, 0.499966f, 0.499966f}, 1.0f, 5.0f, 5.0f);
+ test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.000000f, 0.000023f, 0.000023f, 0.499977f, 0.499977f}, 1.0f, 5.0f, 5.0f);
test_dry({0.25f, 0.25f, 0.25f, 0.25f}, {0, 1}, {0.25f, 0.25f, 0.25f, 0.25f}, 1.0f, 1.1f, 2, 4, {});
- test_dry({0.25f, 0.25f, 0.25f, 0.25f}, {0, 1, 2, 0, 1}, {0.296923f, 0.296923f, 0.296923f, 0.109232f}, 1.0f, 1.1f, 2, 5, {});
+ test_dry({0.25f, 0.25f, 0.25f, 0.25f}, {0, 1, 2, 0, 1}, {0.296923f, 0.296923f, 0.109232f, 0.296923f}, 1.0f, 1.1f, 2, 5, {});
test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 3, 4, 0, 1}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 1.1f, 2, 6, {{3}});
- test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 1}, {0.241818f, 0.241818f, 0.241818f, 0.241818f, 0.032727f}, 2.0f, 1.1f, 2, 5, {});
+ test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 1}, {0.241818f, 0.241818f, 0.032727f, 0.241818f, 0.241818f}, 2.0f, 1.1f, 2, 5, {});
test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 3, 4, 0, 1}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 1.1f, 4, 7, {});
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.571429f, 0.428571f, 0.0f, 0.0f}, 1.00f);
- test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0.00f); // top_n_sigma == 0 now represents a no-op rather than greedy decoding as of PR#13345
+ test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.1f, 0.2f, 0.3f, 0.4f}, 0.00f); // top_n_sigma == 0 now represents a no-op rather than greedy decoding as of PR#13345
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 3.00f);
test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f);
test_sampler_queue(10000, "m", 10000, 1.0f, 1e-12);
test_sampler_queue(10000, "k", 100, 1.0000f, 1.0f);
- test_sampler_queue(10000, "p", 10000, 0.0002f, 1.0f);
+ test_sampler_queue(10000, "p", 10000, 0.0003f, 1.0f);
test_sampler_queue(10000, "p", 10000, 0.8000f, 1.0f);
test_sampler_queue(10000, "m", 10000, 1.0000f, 9997.9f/9999.0f);
test_sampler_queue(10000, "m", 10000, 1.0000f, 0.1f);
return slot.has_next_token; // continue
}
- void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) {
+ void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const {
size_t n_probs = slot.params.sampling.n_probs;
size_t n_vocab = llama_vocab_n_tokens(vocab);
+
if (post_sampling) {
- const auto * cur_p = common_sampler_get_candidates(slot.smpl);
+ const auto * cur_p = common_sampler_get_candidates(slot.smpl, true);
const size_t max_probs = cur_p->size;
// set probability for sampled token
codes.push_back(new_token_id);
- const auto * cands = common_sampler_get_candidates(smpl[i]);
+ const auto * cands = common_sampler_get_candidates(smpl[i], false);
// is it an end of generation? -> mark the stream as finished
if (llama_vocab_is_eog(vocab, new_token_id) || n_decode == n_predict) {