]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
sampling : optimize dist sampler (#15704)
authorGeorgi Gerganov <redacted>
Wed, 3 Sep 2025 15:16:26 +0000 (18:16 +0300)
committerGitHub <redacted>
Wed, 3 Sep 2025 15:16:26 +0000 (18:16 +0300)
ggml-ci

src/llama-sampling.cpp

index e8c0fc3418bf3b2d0a4e8e55f86a8a7e461e1dd0..2186f827bf54307731d1fbb57e9b38b380415f94 100644 (file)
@@ -604,10 +604,73 @@ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*
 static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     auto * ctx = (llama_sampler_dist *) smpl->ctx;
 
-    // sorting is not necessary here
-    llama_sampler_softmax_impl(cur_p, false);
+    // edge cases
+    if (cur_p->size == 0) {
+        cur_p->selected = -1;
+        return;
+    }
+
+    cur_p->selected = 0;
+
+    if (cur_p->size == 1) {
+        cur_p->data[0].p = 1.0f;
+        return;
+    }
+
+    // max logit for numerical stability
+    float max_l = cur_p->data[0].logit;
+    if (!cur_p->sorted) {
+        for (size_t i = 1; i < cur_p->size; ++i) {
+            max_l = std::max(max_l, cur_p->data[i].logit);
+        }
+    }
+
+    // apply softmax to obtain the probabilities
+    double sum_cum = 0.0f;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        float p = expf(cur_p->data[i].logit - max_l);
+        cur_p->data[i].p = p;
+        sum_cum += p;
+    }
+
+#if 1
+    // sample from the obtained probabilities and normalize the probs in a single pass
+    // this is ~3x faster on Mac with full gpt-oss vocab than the version below
+    //
+    std::uniform_real_distribution<double> dist(0.0f, 1.0f);
+    const double rnd = dist(ctx->rng);
+
+          double sum_run = 0.0f;
+    const double sum_tgt = sum_cum*rnd;
+
+    bool found = false;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        if (!found) {
+            // accumulate probs until we reach the target sum
+            sum_run += cur_p->data[i].p;
+            if (sum_run >= sum_tgt) {
+                cur_p->selected = i;
+                found = true;
+            }
+        }
+
+        // normalize probs
+        cur_p->data[i].p /= sum_cum;
+    }
+
+    // fallback to the last token (don't think this can happen)
+    assert(found);
+    if (!found) {
+        cur_p->selected = cur_p->size - 1;
+    }
+#else
+    // for clarity, this is the same as above but does one pass for normalization and one extra pass for sampling
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].p /= sum_cum;
+    }
 
     cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
+#endif
 }
 
 static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {