#include <cstring>
#include <ctime>
#include <cfloat>
+#include <cmath>
#include <numeric>
#include <random>
#include <unordered_map>
-static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng, std::vector<float> & probs) {
-#if 1
- probs.resize(cur_p->size);
- for (size_t i = 0; i < cur_p->size; ++i) {
- probs[i] = cur_p->data[i].p;
- }
-
- std::discrete_distribution<size_t> dist(probs.begin(), probs.end());
-#else
- // avoid the copy with a custom iterator
+static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) {
+ // iterator for the probabilities
+#ifdef __GNUC__
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-local-typedefs"
+#endif
struct probs_iterator {
typedef std::input_iterator_tag iterator_category;
typedef float value_type;
typedef float * pointer;
typedef float & reference;
- typedef size_t difference_type;
+ typedef ptrdiff_t difference_type;
- const llama_token_data_array * data;
- size_t i;
+ const llama_token_data * data;
- bool operator==(const probs_iterator & other) const { return data + i == other.data + other.i; }
- bool operator!=(const probs_iterator & other) const { return data + i != other.data + other.i; }
- float operator*() const { return data->data[i].p; }
- probs_iterator & operator++() { ++i; return *this; }
- probs_iterator operator++(int) { probs_iterator tmp = *this; ++i; return tmp; }
+ bool operator==(const probs_iterator & other) const { return data == other.data; }
+ bool operator!=(const probs_iterator & other) const { return data != other.data; }
+ const float & operator*() const { return data->p; }
+ probs_iterator & operator++() { ++data; return *this; }
+ probs_iterator operator++(int) { probs_iterator tmp = *this; ++data; return tmp; }
};
- #pragma GCC diagnostic pop
-
- std::discrete_distribution<size_t> dist(probs_iterator{cur_p, 0}, probs_iterator{cur_p, cur_p->size});
- GGML_UNUSED(probs);
+#ifdef __GNUC__
+ #pragma GCC diagnostic pop
#endif
+ std::discrete_distribution<int> dist(probs_iterator{cur_p->data}, probs_iterator{cur_p->data + cur_p->size});
+
return dist(rng);
}
+/*
static void llama_log_softmax(float * array, size_t size) {
float max_l = *std::max_element(array, array + size);
float sum = 0.f;
array[i] = logf(array[i] / sum);
}
}
+*/
static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
GGML_ASSERT(cur_p->size > 0);
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
}
- llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+ llama_token_data_array cur_p = {
+ /* .data = */ cur.data(),
+ /* .size = */ cur.size(),
+ /* .selected = */ -1,
+ /* .sorted = */ false,
+ };
llama_sampler_apply(smpl, &cur_p);
- return cur_p.data[cur_p.selected].id;
+ GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
+
+ auto token = cur_p.data[cur_p.selected].id;
+
+ llama_sampler_accept(smpl, token);
+
+ return token;
}
// sampler chain
-static struct llama_sampler_i llama_sampler_chain_i = {
- /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "chain"; },
- /* .accept = */ [](struct llama_sampler * smpl, llama_token token) {
- auto * chain = (llama_sampler_chain *) smpl->ctx;
+static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
+ return "chain";
+}
- time_meas tm(chain->t_sample_us, chain->params.no_perf);
+static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token token) {
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
- for (auto * smpl : chain->samplers) {
- llama_sampler_accept(smpl, token);
- }
+ time_meas tm(chain->t_sample_us, chain->params.no_perf);
- chain->n_sample++;
- },
- /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
- auto * chain = (llama_sampler_chain *) smpl->ctx;
+ for (auto * smpl : chain->samplers) {
+ llama_sampler_accept(smpl, token);
+ }
- time_meas tm(chain->t_sample_us, chain->params.no_perf);
+ chain->n_sample++;
+}
- for (auto * smpl : chain->samplers) {
- llama_sampler_apply(smpl, cur_p);
- }
- },
- /* .reset = */ [](struct llama_sampler * smpl) {
- auto * chain = (llama_sampler_chain *) smpl->ctx;
+static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
- for (auto * smpl : chain->samplers) {
- llama_sampler_reset(smpl);
- }
+ time_meas tm(chain->t_sample_us, chain->params.no_perf);
- chain->t_sample_us = 0;
- chain->n_sample = 0;
- },
- /* .clone = */ [](const struct llama_sampler * smpl) {
- const auto * chain_src = (const llama_sampler_chain *) smpl->ctx;
+ for (auto * smpl : chain->samplers) {
+ llama_sampler_apply(smpl, cur_p);
+ }
+}
- auto * result = llama_sampler_chain_init(chain_src->params);
+static void llama_sampler_chain_reset(struct llama_sampler * smpl) {
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
- for (auto * smpl : chain_src->samplers) {
- llama_sampler_chain_add(result, llama_sampler_clone(smpl));
- }
+ for (auto * smpl : chain->samplers) {
+ llama_sampler_reset(smpl);
+ }
- return result;
- },
- /* .free = */ [](struct llama_sampler * smpl) {
- auto * chain = (llama_sampler_chain *) smpl->ctx;
+ chain->t_sample_us = 0;
+ chain->n_sample = 0;
+}
- for (auto * smpl : chain->samplers) {
- llama_sampler_free(smpl);
- }
+static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampler * smpl) {
+ const auto * chain_src = (const llama_sampler_chain *) smpl->ctx;
+
+ auto * result = llama_sampler_chain_init(chain_src->params);
+
+ for (auto * smpl : chain_src->samplers) {
+ llama_sampler_chain_add(result, llama_sampler_clone(smpl));
+ }
+
+ return result;
+}
+
+static void llama_sampler_chain_free(struct llama_sampler * smpl) {
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
- delete chain;
- },
+ for (auto * smpl : chain->samplers) {
+ llama_sampler_free(smpl);
+ }
+
+ delete chain;
+}
+
+static struct llama_sampler_i llama_sampler_chain_i = {
+ /* .name = */ llama_sampler_chain_name,
+ /* .accept = */ llama_sampler_chain_accept,
+ /* .apply = */ llama_sampler_chain_apply,
+ /* .reset = */ llama_sampler_chain_reset,
+ /* .clone = */ llama_sampler_chain_clone,
+ /* .free = */ llama_sampler_chain_free,
};
struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
const uint32_t seed;
std::mt19937 rng;
-
- std::vector<float> probs; // work array
};
static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) {
static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_dist *) smpl->ctx;
- cur_p->selected = llama_sample_dist(cur_p, ctx->rng, ctx->probs);
+ cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
}
static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {
/* .ctx = */ new llama_sampler_dist {
/* .seed = */ seed,
/* .rng = */ std::mt19937(seed),
- /* .probs = */ {},
},
};
}
float mu;
std::mt19937 rng;
-
- std::vector<float> probs;
};
static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*smpl*/) {
llama_sampler_top_k_impl(cur_p, std::max(int(k), 1));
llama_sampler_softmax_impl(cur_p);
- const int idx = llama_sample_dist(cur_p, ctx->rng, ctx->probs);
+ const int idx = llama_sample_dist(cur_p, ctx->rng);
cur_p->selected = idx;
/* .m = */ m,
/* .mu = */ 2.0f*tau,
/* .rng = */ std::mt19937(seed),
- /* .probs = */ {},
},
};
}
float mu;
std::mt19937 rng;
-
- std::vector<float> probs;
};
static const char * llama_sampler_mirostat_v2_name(const struct llama_sampler * /*smpl*/) {
// Normalize the probabilities of the remaining words
llama_sampler_softmax_impl(cur_p);
- const int idx = llama_sample_dist(cur_p, ctx->rng, ctx->probs);
+ const int idx = llama_sample_dist(cur_p, ctx->rng);
cur_p->selected = idx;
/* .eta = */ eta,
/* .mu = */ 2.0f*tau,
/* .rng = */ std::mt19937(seed),
- /* .probs = */ {},
},
};
}
static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;
+ if (ctx->logit_bias.empty()) {
+ return;
+ }
+
ctx->to_search.clear();
// update the candidates that have not been shuffled in the vocabulary (i.e. idx == id)
}
}
+ if (ctx->to_search.empty()) {
+ return;
+ }
+
// search for the remaining candidates that were not found in the previous step
for (size_t i = 0; i < cur_p->size; ++i) {
for (const auto & lb : ctx->to_search) {