static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_dist *) smpl->ctx;
- // sorting is not necessary here
- llama_sampler_softmax_impl(cur_p, false);
+ // edge cases
+ if (cur_p->size == 0) {
+ cur_p->selected = -1;
+ return;
+ }
+
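+ // default to the first token; the sampling loop below overrides this when size > 1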
+ cur_p->selected = 0;
+
+ if (cur_p->size == 1) {
+ cur_p->data[0].p = 1.0f;
+ return;
+ }
+
+ // max logit for numerical stability
+ float max_l = cur_p->data[0].logit;
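+ // when the candidates are sorted (descending by logit), data[0] already
+ // holds the max and the scan below can be skipped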
+ if (!cur_p->sorted) {
+ for (size_t i = 1; i < cur_p->size; ++i) {
+ max_l = std::max(max_l, cur_p->data[i].logit);
+ }
+ }
+
+ // apply softmax to obtain the probabilities
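+ // subtracting max_l keeps expf() from overflowing and does not change the
+ // result: exp(l - m) / sum_j exp(l_j - m) == softmax(l)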
+ double sum_cum = 0.0;
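+ // accumulating in double reduces rounding error across a large vocab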
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ float p = expf(cur_p->data[i].logit - max_l);
+ cur_p->data[i].p = p;
+ sum_cum += p;
+ }
+
+#if 1
+ // sample from the obtained probabilities and normalize the probs in a single pass
+ // this is ~3x faster on Mac with full gpt-oss vocab than the version below
+ //
+ std::uniform_real_distribution<double> dist(0.0, 1.0);
+ const double rnd = dist(ctx->rng);
+
+ double sum_run = 0.0;
+ const double sum_tgt = sum_cum*rnd;
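+ // scaling the random draw by the unnormalized sum is equivalent to
+ // inverse-CDF sampling from the normalized distribution, so no per-token
+ // division is needed before selecting the token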
+
+ bool found = false;
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ if (!found) {
+ // accumulate probs until we reach the target sum
+ sum_run += cur_p->data[i].p;
+ if (sum_run >= sum_tgt) {
+ cur_p->selected = i;
+ found = true;
+ }
+ }
+
+ // normalize probs
+ cur_p->data[i].p /= sum_cum;
+ }
+
+ // fallback to the last token (should not happen, but guard against floating-point rounding)
+ assert(found);
+ if (!found) {
+ cur_p->selected = cur_p->size - 1;
+ }
+#else
+ // kept for clarity: same result as above, but uses one pass for normalization and a separate pass for sampling
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ cur_p->data[i].p /= sum_cum;
+ }
cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
+#endif
}

static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {