]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : default sampling changes + greedy update (#9897)
authorGeorgi Gerganov <redacted>
Mon, 21 Oct 2024 06:46:40 +0000 (09:46 +0300)
committerGitHub <redacted>
Mon, 21 Oct 2024 06:46:40 +0000 (09:46 +0300)
* llama : deprecate softmax sampler + fix dist sampler

ggml-ci

* tests : replace macros with functions

ggml-ci

* sampling : change temperature sampler logic

For t <= 0.0f, keep the max logit intact and set the rest to -inf

* cont : no need for special "greedy" logic

top-k == 1 is the same

* tests : init prob correctly

* llama : handle temp <= 0.0 in the temp_ext sampler too

ggml-ci

* cont : avoid extra loop in temperature sampler for sub-zero temp

ggml-ci

common/sampling.cpp
examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
examples/save-load-state/save-load-state.cpp
examples/speculative/speculative.cpp
include/llama.h
src/llama-sampling.cpp
tests/test-sampling.cpp

index 56cd0df6b81bc8adae41a1e1be13eb2ed2e24e6f..4ab3eface3384897b8af3d3971f7b858529d9a09 100644 (file)
@@ -171,60 +171,46 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                 params.penalize_nl,
                 params.ignore_eos));
 
-    if (params.temp > 0.0f) {
-        if (params.mirostat == 0) {
-            for (const auto & cnstr : params.samplers) {
-                switch (cnstr) {
-                    case COMMON_SAMPLER_TYPE_TOP_K:
-                        llama_sampler_chain_add(result->chain, llama_sampler_init_top_k    (params.top_k));
-                        break;
-                    case COMMON_SAMPLER_TYPE_TOP_P:
-                        llama_sampler_chain_add(result->chain, llama_sampler_init_top_p    (params.top_p, params.min_keep));
-                        break;
-                    case COMMON_SAMPLER_TYPE_MIN_P:
-                        llama_sampler_chain_add(result->chain, llama_sampler_init_min_p    (params.min_p, params.min_keep));
-                        break;
-                    case COMMON_SAMPLER_TYPE_XTC:
-                        llama_sampler_chain_add(result->chain, llama_sampler_init_xtc      (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
-                        break;
-                    case COMMON_SAMPLER_TYPE_TFS_Z:
-                        llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep));
-                        break;
-                    case COMMON_SAMPLER_TYPE_TYPICAL_P:
-                        llama_sampler_chain_add(result->chain, llama_sampler_init_typical  (params.typ_p, params.min_keep));
-                        break;
-                    case COMMON_SAMPLER_TYPE_TEMPERATURE:
-                        llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
-                        break;
-                    case COMMON_SAMPLER_TYPE_INFILL:
-                        llama_sampler_chain_add(result->chain, llama_sampler_init_infill   (model));
-                        break;
-                    default:
-                        GGML_ASSERT(false && "unknown sampler type");
-                }
+    if (params.mirostat == 0) {
+        for (const auto & cnstr : params.samplers) {
+            switch (cnstr) {
+                case COMMON_SAMPLER_TYPE_TOP_K:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_k    (params.top_k));
+                    break;
+                case COMMON_SAMPLER_TYPE_TOP_P:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_p    (params.top_p, params.min_keep));
+                    break;
+                case COMMON_SAMPLER_TYPE_MIN_P:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_min_p    (params.min_p, params.min_keep));
+                    break;
+                case COMMON_SAMPLER_TYPE_XTC:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_xtc      (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
+                    break;
+                case COMMON_SAMPLER_TYPE_TFS_Z:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep));
+                    break;
+                case COMMON_SAMPLER_TYPE_TYPICAL_P:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_typical  (params.typ_p, params.min_keep));
+                    break;
+                case COMMON_SAMPLER_TYPE_TEMPERATURE:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
+                    break;
+                case COMMON_SAMPLER_TYPE_INFILL:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_infill   (model));
+                    break;
+                default:
+                    GGML_ASSERT(false && "unknown sampler type");
             }
-            llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
-            llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
-        } else if (params.mirostat == 1) {
-            llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
-            llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(model), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
-        } else if (params.mirostat == 2) {
-            llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
-            llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
-        } else {
-            GGML_ASSERT(false && "unknown mirostat version");
         }
+        llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
+    } else if (params.mirostat == 1) {
+        llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
+        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(model), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
+    } else if (params.mirostat == 2) {
+        llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
+        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
     } else {
-        if (params.n_probs > 0) {
-            // some use cases require to sample greedily, but still obtain the probabilities of the top tokens
-            // ref: https://github.com/ggerganov/llama.cpp/pull/9605
-            //
-            // the following will not produce exactly the same probs as applyging softmax to the full vocabulary, but
-            // it is much faster, since we avoid sorting all tokens and should give a good approximation
-            llama_sampler_chain_add(result->chain, llama_sampler_init_top_k(params.n_probs));
-            llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
-        }
-        llama_sampler_chain_add(result->chain, llama_sampler_init_greedy());
+        GGML_ASSERT(false && "unknown mirostat version");
     }
 
     return result;
index dcd9803a2adc27609191f30cfacfa4fbfcac93c5..65cd4eb515c7f4b698e0abdfadf6a8c326a944da 100644 (file)
@@ -46,7 +46,6 @@ actor LlamaContext {
         let sparams = llama_sampler_chain_default_params()
         self.sampling = llama_sampler_chain_init(sparams)
         llama_sampler_chain_add(self.sampling, llama_sampler_init_temp(0.4))
-        llama_sampler_chain_add(self.sampling, llama_sampler_init_softmax())
         llama_sampler_chain_add(self.sampling, llama_sampler_init_dist(1234))
     }
 
index 5f60a86cbc2d2c43a2a0cdba903e85e1b1c0b013..8c49a52a661249ef64146728d7673c27d096158a 100644 (file)
@@ -42,7 +42,6 @@ int main(int argc, char ** argv) {
 
     llama_sampler * smpl = llama_sampler_chain_init(sparams);
 
-    llama_sampler_chain_add(smpl, llama_sampler_init_softmax());
     llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.sparams.seed));
 
     // tokenize prompt
@@ -107,7 +106,6 @@ int main(int argc, char ** argv) {
 
     llama_sampler * smpl2 = llama_sampler_chain_init(sparams);
 
-    llama_sampler_chain_add(smpl2, llama_sampler_init_softmax());
     llama_sampler_chain_add(smpl2, llama_sampler_init_dist(params.sparams.seed));
 
     printf("\nsecond run: %s", params.prompt.c_str());
@@ -171,7 +169,6 @@ int main(int argc, char ** argv) {
 
     llama_sampler * smpl3 = llama_sampler_chain_init(sparams);
 
-    llama_sampler_chain_add(smpl3, llama_sampler_init_softmax());
     llama_sampler_chain_add(smpl3, llama_sampler_init_dist(params.sparams.seed));
 
     printf("\nsingle seq run: %s", params.prompt.c_str());
index 8a64754151719673c1b2c1e7cc84012487b991b7..a40e755a26f843b5ce18d6b7f71320f7efc9b4c5 100644 (file)
@@ -185,8 +185,6 @@ int main(int argc, char ** argv) {
     // target model sampling context (reuse the llama_context's sampling instance)
     struct common_sampler * smpl = common_sampler_init(model_tgt, params.sparams);
 
-    struct llama_sampler * softmax = llama_sampler_init_softmax();
-
     // draft sequence data
     std::vector<seq_draft> drafts(n_seq_dft);
 
@@ -629,7 +627,6 @@ int main(int argc, char ** argv) {
         common_sampler_free(drafts[s].smpl);
     }
 
-    llama_sampler_free(softmax);
     llama_batch_free(batch_dft);
 
     llama_free(ctx_tgt);
index 2558e9267905bca0a937ed797c9f0ca83db32292..d4059c8dd04311ecd999d679cc3c9db9d165c9ff 100644 (file)
@@ -217,6 +217,7 @@ extern "C" {
 
     typedef struct llama_token_data_array {
         // TODO: consider SoA
+        // NOTE: this pointer can be modified by the samplers
         llama_token_data * data;
         size_t size;
         int64_t selected; // this is the index in the data array (i.e. not the token id)
@@ -1069,12 +1070,13 @@ extern "C" {
 
     // available samplers:
 
-    LLAMA_API struct llama_sampler * llama_sampler_init_greedy     (void);
-    LLAMA_API struct llama_sampler * llama_sampler_init_dist       (uint32_t seed);
+    LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void);
+    LLAMA_API struct llama_sampler * llama_sampler_init_dist  (uint32_t seed);
 
     /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
     /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
-    LLAMA_API struct llama_sampler * llama_sampler_init_softmax    (void);
+    DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_softmax    (void),
+        "will be removed in the future (see https://github.com/ggerganov/llama.cpp/pull/9896#discussion_r1800920915)");
 
     /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
     LLAMA_API struct llama_sampler * llama_sampler_init_top_k      (int32_t k);
@@ -1090,6 +1092,8 @@ extern "C" {
 
     /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
     LLAMA_API struct llama_sampler * llama_sampler_init_typical    (float   p, size_t min_keep);
+
+    /// #details Updates the logits l_i` = l_i/t. When t <= 0.0f, the maximum logit is kept at it's original value, the rest are set to -inf
     LLAMA_API struct llama_sampler * llama_sampler_init_temp       (float   t);
 
     /// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772.
index bd750c40ec65108197e8322533193fcbf45a23f1..d71516153cf827c936ec43c5de8a93c1b27bcb23 100644 (file)
@@ -63,6 +63,30 @@ static void llama_log_softmax(float * array, size_t size) {
 }
 */
 
+static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) {
+    if (temp <= 0.0f) {
+        // find the token with the highest logit and set the rest to -inf
+        size_t max_i = 0;
+        float  max_l = cur_p->data[0].logit;
+
+        for (size_t i = 1; i < cur_p->size; ++i) {
+            if (cur_p->data[i    ].logit > max_l) {
+                cur_p->data[max_i].logit = -INFINITY;
+                max_i = i;
+                max_l = cur_p->data[i].logit;
+            } else {
+                cur_p->data[i].logit = -INFINITY;
+            }
+        }
+
+        return;
+    }
+
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].logit /= temp;
+    }
+}
+
 static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
     GGML_ASSERT(cur_p->size > 0);
 
@@ -427,6 +451,9 @@ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*
 
 static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     auto * ctx = (llama_sampler_dist *) smpl->ctx;
+
+    llama_sampler_softmax_impl(cur_p);
+
     cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
 }
 
@@ -912,9 +939,8 @@ static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*
 
 static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     const auto * ctx = (llama_sampler_temp *) smpl->ctx;
-    for (size_t i = 0; i < cur_p->size; ++i) {
-        cur_p->data[i].logit /= ctx->temp;
-    }
+
+    llama_sampler_temp_impl(cur_p, ctx->temp);
 }
 
 static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) {
@@ -961,6 +987,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
     if (ctx->delta > 0) {
         const float min_temp = std::max(0.0f, ctx->temp - ctx->delta);
         const float max_temp = ctx->temp + ctx->delta;
+
         float exponent_val = ctx->exponent;
 
         // no need to do anything if there is only one (or zero) candidates
@@ -998,9 +1025,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
     #endif
 
         // Apply the dynamically calculated temperature scaling
-        for (size_t i = 0; i < cur_p->size; ++i) {
-            cur_p->data[i].logit /= dyn_temp;
-        }
+        llama_sampler_temp_impl(cur_p, dyn_temp);
 
         // Re-compute softmax probabilities after scaling logits with dynamic temperature
         const double max_l_double = cur_p->data[0].logit;
@@ -1024,9 +1049,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
         }
     #endif
     } else {
-        for (size_t i = 0; i < cur_p->size; ++i) {
-            cur_p->data[i].logit /= ctx->temp;
-        }
+        llama_sampler_temp_impl(cur_p, ctx->temp);
     }
 }
 
index 1372bdf13f2f608529c1b0b67e3ffa1c9410a65c..05600e6f54e9094196501b043f89ae37db902cef 100644 (file)
@@ -18,203 +18,176 @@ static void dump(const llama_token_data_array * cur_p) {
 
 #define DUMP(__cur_p) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__cur_p)); printf("-\n"); } while(0)
 
-#define APPLY(__cnstr, __cur_p) do { \
-    auto * cnstr = (__cnstr); \
-    llama_sampler_apply(cnstr, (__cur_p)); \
-    llama_sampler_free(cnstr); \
-} while(0)
+struct sampler_tester {
+    sampler_tester(size_t n_vocab) {
+        cur.reserve(n_vocab);
+        for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
+            const float logit = logf(token_id);
+            cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
+        }
 
-static void test_top_k(const std::vector<float> & probs, const std::vector<float> & expected_probs, int k) {
-    const size_t n_vocab = probs.size();
+        cur_p = llama_token_data_array { cur.data(), cur.size(), -1, false };
+    }
 
-    std::vector<llama_token_data> cur;
-    cur.reserve(n_vocab);
-    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
-        const float logit = logf(probs[token_id]);
-        cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
+    sampler_tester(const std::vector<float> & probs, const std::vector<float> & probs_expected) : probs_expected(probs_expected) {
+        cur.reserve(probs.size());
+        for (llama_token token_id = 0; token_id < (llama_token)probs.size(); token_id++) {
+            const float logit = logf(probs[token_id]);
+            cur.emplace_back(llama_token_data{token_id, logit, probs[token_id]});
+        }
+
+        cur_p = llama_token_data_array { cur.data(), cur.size(), -1, false };
     }
 
-    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
-    APPLY(llama_sampler_init_softmax(), &cur_p);
-    DUMP(&cur_p);
-    APPLY(llama_sampler_init_top_k(k), &cur_p);
-    DUMP(&cur_p);
-
-    GGML_ASSERT(cur_p.size == expected_probs.size());
-    for (size_t i = 0; i < cur_p.size; i++) {
-        GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-5);
+    void apply(llama_sampler * sampler) {
+        llama_sampler_apply(sampler, &cur_p);
+        llama_sampler_free(sampler);
     }
-}
 
-static void test_top_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
-    const size_t n_vocab = probs.size();
+    void check() {
+        GGML_ASSERT(cur_p.size == probs_expected.size());
+        for (size_t i = 0; i < cur_p.size; i++) {
+            GGML_ASSERT(fabs(cur_p.data[i].p - probs_expected[i]) < 1e-5);
+        }
+    }
+
+    llama_token_data_array cur_p;
+
+private:
+    const std::vector<float> probs_expected;
 
     std::vector<llama_token_data> cur;
-    cur.reserve(n_vocab);
-    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
-        const float logit = logf(probs[token_id]);
-        cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
-    }
+};
 
-    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
-    APPLY(llama_sampler_init_softmax(), &cur_p);
-    DUMP(&cur_p);
-    APPLY(llama_sampler_init_top_p(p, 1), &cur_p);
-    DUMP(&cur_p);
-
-    GGML_ASSERT(cur_p.size == expected_probs.size());
-    for (size_t i = 0; i < cur_p.size; i++) {
-        GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3);
-    }
+static void test_temp(const std::vector<float> & probs, const std::vector<float> & probs_expected, float temp) {
+    sampler_tester tester(probs, probs_expected);
+
+    DUMP(&tester.cur_p);
+    tester.apply(llama_sampler_init_temp(temp));
+    tester.apply(llama_sampler_init_dist(0));
+    DUMP(&tester.cur_p);
+
+    tester.check();
 }
 
-static void test_tfs(const std::vector<float> & probs, const std::vector<float> & expected_probs, float z) {
-    const size_t n_vocab = probs.size();
+static void test_temp_ext(const std::vector<float> & probs, const std::vector<float> & probs_expected, float temp, float delta, float exponent) {
+    sampler_tester tester(probs, probs_expected);
 
-    std::vector<llama_token_data> cur;
-    cur.reserve(n_vocab);
-    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
-        const float logit = logf(probs[token_id]);
-        cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
-    }
+    DUMP(&tester.cur_p);
+    tester.apply(llama_sampler_init_temp_ext(temp, delta, exponent));
+    tester.apply(llama_sampler_init_dist (0));
+    DUMP(&tester.cur_p);
 
-    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
-    DUMP(&cur_p);
-    APPLY(llama_sampler_init_tail_free(z, 1), &cur_p);
-    DUMP(&cur_p);
+    tester.check();
+}
 
-    GGML_ASSERT(cur_p.size == expected_probs.size());
-    for (size_t i = 0; i < cur_p.size; i++) {
-        GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3);
-    }
+static void test_top_k(const std::vector<float> & probs, const std::vector<float> & probs_expected, int k) {
+    sampler_tester tester(probs, probs_expected);
+
+    DUMP(&tester.cur_p);
+    tester.apply(llama_sampler_init_top_k(k));
+    tester.apply(llama_sampler_init_dist (0));
+    DUMP(&tester.cur_p);
+
+    tester.check();
 }
 
-static void test_min_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
-    const size_t n_vocab = probs.size();
+static void test_top_p(const std::vector<float> & probs, const std::vector<float> & probs_expected, float p) {
+    sampler_tester tester(probs, probs_expected);
 
-    std::vector<llama_token_data> cur;
-    cur.reserve(n_vocab);
-    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
-        const float logit = logf(probs[token_id]);
-        cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
-    }
+    DUMP(&tester.cur_p);
+    tester.apply(llama_sampler_init_top_p(p, 1));
+    tester.apply(llama_sampler_init_dist (0));
+    DUMP(&tester.cur_p);
 
-    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
-    DUMP(&cur_p);
-    APPLY(llama_sampler_init_min_p(p, 1), &cur_p);
-    DUMP(&cur_p);
-    APPLY(llama_sampler_init_softmax(), &cur_p);
-
-    GGML_ASSERT(cur_p.size == expected_probs.size());
-    for (size_t i = 0; i < cur_p.size; i++) {
-        GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3);
-    }
+    tester.check();
 }
 
-static void test_xtc(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p, float t) {
-    const size_t n_vocab = probs.size();
+static void test_tfs(const std::vector<float> & probs, const std::vector<float> & probs_expected, float z) {
+    sampler_tester tester(probs, probs_expected);
 
-    std::vector<llama_token_data> cur;
-    cur.reserve(n_vocab);
-    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
-        const float logit = logf(probs[token_id]);
-        cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
-    }
+    DUMP(&tester.cur_p);
+    tester.apply(llama_sampler_init_tail_free(z, 1));
+    DUMP(&tester.cur_p);
 
-    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
-    APPLY(llama_sampler_init_softmax(), &cur_p);
-    DUMP(&cur_p);
-    APPLY(llama_sampler_init_xtc(p, t, 0, 0), &cur_p);
-    DUMP(&cur_p);
-
-    GGML_ASSERT(cur_p.size == expected_probs.size());
-    for (size_t i = 0; i < cur_p.size; i++) {
-        GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-5);
-    }
+    tester.check();
 }
 
-static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
-    const size_t n_vocab = probs.size();
+static void test_min_p(const std::vector<float> & probs, const std::vector<float> & probs_expected, float p) {
+    sampler_tester tester(probs, probs_expected);
 
-    std::vector<llama_token_data> cur;
-    cur.reserve(n_vocab);
-    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
-        const float logit = logf(probs[token_id]);
-        cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
-    }
+    DUMP(&tester.cur_p);
+    tester.apply(llama_sampler_init_min_p(p, 1));
+    tester.apply(llama_sampler_init_dist (0));
+    DUMP(&tester.cur_p);
 
-    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
-    DUMP(&cur_p);
-    APPLY(llama_sampler_init_typical(p, 1), &cur_p);
-    DUMP(&cur_p);
+    tester.check();
+}
 
-    GGML_ASSERT(cur_p.size == expected_probs.size());
-    for (size_t i = 0; i < cur_p.size; i++) {
-        GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3);
-    }
+static void test_xtc(const std::vector<float> & probs, const std::vector<float> & probs_expected, float p, float t) {
+    sampler_tester tester(probs, probs_expected);
+
+    DUMP(&tester.cur_p);
+    tester.apply(llama_sampler_init_xtc(p, t, 0, 0));
+    DUMP(&tester.cur_p);
+
+    tester.check();
+}
+
+static void test_typical(const std::vector<float> & probs, const std::vector<float> & probs_expected, float p) {
+    sampler_tester tester(probs, probs_expected);
+
+    DUMP(&tester.cur_p);
+    tester.apply(llama_sampler_init_typical(p, 1));
+    DUMP(&tester.cur_p);
+
+    tester.check();
 }
 
 static void test_penalties(
     const std::vector<float> & probs, const std::vector<llama_token> & last_tokens,
-    const std::vector<float> & expected_probs, float repeat_penalty, float alpha_frequency, float alpha_presence
+    const std::vector<float> & probs_expected, float repeat_penalty, float alpha_frequency, float alpha_presence
 ) {
-    GGML_ASSERT(probs.size() == expected_probs.size());
+    GGML_ASSERT(probs.size() == probs_expected.size());
 
-    const size_t n_vocab = probs.size();
-
-    std::vector<llama_token_data> cur;
-    cur.reserve(n_vocab);
-    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
-        const float logit = logf(probs[token_id]);
-        cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
-    }
-
-    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+    sampler_tester tester(probs, probs_expected);
 
+    const size_t n_vocab = probs.size();
     auto * sampler = llama_sampler_init_penalties(n_vocab, LLAMA_TOKEN_NULL, LLAMA_TOKEN_NULL, last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, false, false);
 
     for (size_t i = 0; i < last_tokens.size(); i++) {
         llama_sampler_accept(sampler, last_tokens[i]);
     }
 
-    APPLY(llama_sampler_init_softmax(), &cur_p);
-    DUMP(&cur_p);
-    APPLY(sampler, &cur_p);
-    APPLY(llama_sampler_init_softmax(), &cur_p);
-    DUMP(&cur_p);
+    DUMP(&tester.cur_p);
+    tester.apply(sampler);
+    tester.apply(llama_sampler_init_dist(0));
+    DUMP(&tester.cur_p);
 
-    GGML_ASSERT(cur_p.size == expected_probs.size());
-    for (size_t i = 0; i < cur_p.size; i++) {
-        GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3);
-    }
+    tester.check();
 }
 
 static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p
 ) {
-    std::vector<llama_token_data> cur;
-    cur.reserve(n_vocab);
-    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
-        const float logit = logf(token_id);
-        cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
-    }
-
-    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+    sampler_tester tester(n_vocab);
 
           llama_token min_token_id = 0;
     const llama_token max_token_id = n_vocab-1;
 
     for (auto s : samplers_sequence) {
         switch (s){
-            case 'k': APPLY(llama_sampler_init_top_k(top_k), &cur_p); break;
+            case 'k': tester.apply(llama_sampler_init_top_k(top_k)); break;
             case 'f': GGML_ABORT("tail_free test not implemented");
             case 'y': GGML_ABORT("typical test not implemented");
-            case 'p': APPLY(llama_sampler_init_top_p(top_p, 1), &cur_p); break;
-            case 'm': APPLY(llama_sampler_init_min_p(min_p, 1), &cur_p); break;
+            case 'p': tester.apply(llama_sampler_init_top_p(top_p, 1)); break;
+            case 'm': tester.apply(llama_sampler_init_min_p(min_p, 1)); break;
             case 't': GGML_ABORT("temperature test not implemented");
             default : GGML_ABORT("Unknown sampler");
         }
 
-        APPLY(llama_sampler_init_softmax(), &cur_p); // make sure tokens are sorted for tests
+        tester.apply(llama_sampler_init_dist(0));
+
+        auto & cur_p = tester.cur_p;
 
         const int size = cur_p.size;
 
@@ -307,21 +280,26 @@ static void test_perf() {
     BENCH(llama_sampler_init_tail_free(0.5f, 1),                data, 32);
     BENCH(llama_sampler_init_typical  (0.5f, 1),                data, 32);
     BENCH(llama_sampler_init_xtc      (1.0f, 0.1f, 1, 1),       data, 32);
-    BENCH(llama_sampler_init_softmax  (),                       data, 32);
 }
 
 int main(void) {
     ggml_time_init();
 
-    test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 1);
-    test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 3);
+    test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f);
+    test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f);
+
+    test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f, 0.0f, 1.0f);
+    test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f, 0.0f, 1.0f);
+
+    test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 1);
+    test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 3);
     test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);
     test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0);
 
-    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 0);
-    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f}, 0.7f);
-    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 0.8f);
-    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1);
+    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 0);
+    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.571429f, 0.428571f}, 0.7f);
+    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 0.8f);
+    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f);
 
     test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/1.0f, 0.3f/1.0f, 0.2f/1.0f, 0.1f/1.0f}, 0.00f);
     test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/1.0f, 0.3f/1.0f, 0.2f/1.0f, 0.1f/1.0f}, 0.24f);