]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : new sampling algorithms (#1126)
authorIvan Stepanov <redacted>
Sat, 29 Apr 2023 05:34:41 +0000 (08:34 +0300)
committerGitHub <redacted>
Sat, 29 Apr 2023 05:34:41 +0000 (08:34 +0300)
* Sample interface, new samplers.

New samplers:
- locally typical sampling
- tail free sampling
- frequency and presence penalty
- mirostat

Ignore EOS fix: -inf should be used.

* mirostat

* Added --logit-bias and --no-penalize-nl, removed std::span

* Use C++11, clarify llama API documentation, rename Mirostat parameters to --mirostat_lr and --mirostat_ent, add temperature sampling for Mirostat, simplify Mirostat sampling API parameters (removed N and *k)

Use C++11, clarify llama API documentation, rename Mirostat parameters to --mirostat_lr and --mirostat_ent, add temperature sampling for Mirostat, simplify Mirostat sampling API parameters (removed N and *k)

* Save and load example adjust

* Tests

* Windows build fix

* Windows test fix

examples/common.cpp
examples/common.h
examples/main/main.cpp
examples/save-load-state/save-load-state.cpp
llama.cpp
llama.h
tests/CMakeLists.txt
tests/test-sampling.cpp [new file with mode: 0644]

index 9f10dc268558bc0f572bc0b91d680c73cfd40b32..6c712c713db9b315c9b997c51b648044ef707789 100644 (file)
@@ -6,6 +6,8 @@
 #include <string>
 #include <iterator>
 #include <algorithm>
+#include <sstream>
+#include <iostream>
 
 #if defined (_WIN32)
 #include <fcntl.h>
@@ -114,6 +116,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.temp = std::stof(argv[i]);
+        } else if (arg == "--tfs") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.tfs_z = std::stof(argv[i]);
+        } else if (arg == "--typical") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.typical_p = std::stof(argv[i]);
         } else if (arg == "--repeat_last_n") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -126,6 +140,36 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.repeat_penalty = std::stof(argv[i]);
+        } else if (arg == "--frequency_penalty") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.frequency_penalty = std::stof(argv[i]);
+        } else if (arg == "--presence_penalty") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.presence_penalty = std::stof(argv[i]);
+        } else if (arg == "--mirostat") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.mirostat = std::stoi(argv[i]);
+        } else if (arg == "--mirostat_lr") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.mirostat_eta = std::stof(argv[i]);
+        } else if (arg == "--mirostat_ent") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.mirostat_tau = std::stof(argv[i]);
         } else if (arg == "-b" || arg == "--batch_size") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -185,7 +229,28 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
         } else if (arg == "--perplexity") {
             params.perplexity = true;
         } else if (arg == "--ignore-eos") {
-            params.ignore_eos = true;
+            params.logit_bias[llama_token_eos()] = -INFINITY;
+        } else if (arg == "--no-penalize-nl") {
+            params.penalize_nl = false;
+        } else if (arg == "-l" || arg == "--logit-bias") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            std::stringstream ss(argv[i]);
+            llama_token key;
+            char sign;
+            std::string value_str;
+            try {
+                if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) {
+                    params.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
+                } else {
+                    throw std::exception();
+                }
+            } catch (const std::exception &e) {
+                invalid_param = true;
+                break;
+            }
         } else if (arg == "--n_parts") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -240,12 +305,26 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "  -f FNAME, --file FNAME\n");
     fprintf(stderr, "                        prompt file to start generation.\n");
     fprintf(stderr, "  -n N, --n_predict N   number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict);
-    fprintf(stderr, "  --top_k N             top-k sampling (default: %d)\n", params.top_k);
-    fprintf(stderr, "  --top_p N             top-p sampling (default: %.1f)\n", (double)params.top_p);
-    fprintf(stderr, "  --repeat_last_n N     last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n);
-    fprintf(stderr, "  --repeat_penalty N    penalize repeat sequence of tokens (default: %.1f)\n", (double)params.repeat_penalty);
+    fprintf(stderr, "  --top_k N             top-k sampling (default: %d, 0 = disabled)\n", params.top_k);
+    fprintf(stderr, "  --top_p N             top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
+    fprintf(stderr, "  --tfs N               tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z);
+    fprintf(stderr, "  --typical N           locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)params.typical_p);
+    fprintf(stderr, "  --repeat_last_n N     last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", params.repeat_last_n);
+    fprintf(stderr, "  --repeat_penalty N    penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)params.repeat_penalty);
+    fprintf(stderr, "  --presence_penalty N  repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)params.presence_penalty);
+    fprintf(stderr, "  --frequency_penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)params.frequency_penalty);
+    fprintf(stderr, "  --mirostat N          use Mirostat sampling.\n");
+    fprintf(stderr, "                        Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n");
+    fprintf(stderr, "                        (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", params.mirostat);
+    fprintf(stderr, "  --mirostat_lr N       Mirostat learning rate, parameter eta (default: %.1f)\n", (double)params.mirostat_eta);
+    fprintf(stderr, "  --mirostat_ent N      Mirostat target entropy, parameter tau (default: %.1f)\n", (double)params.mirostat_tau);
+    fprintf(stderr, "  -l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS\n");
+    fprintf(stderr, "                        modifies the likelihood of token appearing in the completion,\n");
+    fprintf(stderr, "                        i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n");
+    fprintf(stderr, "                        or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n");
     fprintf(stderr, "  -c N, --ctx_size N    size of the prompt context (default: %d)\n", params.n_ctx);
-    fprintf(stderr, "  --ignore-eos          ignore end of stream token and continue generating\n");
+    fprintf(stderr, "  --ignore-eos          ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
+    fprintf(stderr, "  --no-penalize-nl      do not penalize newline token\n");
     fprintf(stderr, "  --memory_f32          use f32 instead of f16 for memory key+value\n");
     fprintf(stderr, "  --temp N              temperature (default: %.1f)\n", (double)params.temp);
     fprintf(stderr, "  --n_parts N           number of model parts (default: -1 = determine from dimensions)\n");
index 9d3697d793eff644824c3bdccde9a1d618944835..14e6b1ba7c113563d7eeacee752308c978cbfce3 100644 (file)
@@ -8,6 +8,7 @@
 #include <vector>
 #include <random>
 #include <thread>
+#include <unordered_map>
 
 //
 // CLI argument parsing
@@ -17,17 +18,25 @@ struct gpt_params {
     int32_t seed          = -1;   // RNG seed
     int32_t n_threads     = std::min(4, (int32_t) std::thread::hardware_concurrency());
     int32_t n_predict     = 128;  // new tokens to predict
-    int32_t repeat_last_n = 64;   // last n tokens to penalize
     int32_t n_parts       = -1;   // amount of model parts (-1 = determine from model dimensions)
     int32_t n_ctx         = 512;  // context size
     int32_t n_batch       = 512;  // batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_keep        = 0;    // number of tokens to keep from initial prompt
 
     // sampling parameters
-    int32_t top_k = 40;
-    float   top_p = 0.95f;
-    float   temp  = 0.80f;
-    float   repeat_penalty  = 1.10f;
+    std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
+    int32_t top_k = 0;              // <= 0 to use vocab size
+    float   top_p = 1.0f;           // 1.0 = disabled
+    float   tfs_z = 1.0f;           // 1.0 = disabled
+    float   typical_p = 1.0f;       // 1.0 = disabled
+    float   temp = 1.0f;            // 1.0 = disabled
+    float   repeat_penalty  = 1.0f; // 1.0 = disabled
+    int32_t repeat_last_n = -1;     // last n tokens to penalize (0 = disable penalty, -1 = context size)
+    float   frequency_penalty = 0.0f; // 0.0 = disabled
+    float   presence_penalty = 0.0f;  // 0.0 = disabled
+    int     mirostat = 0;           // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+    float   mirostat_tau = 5.0f;    // target entropy
+    float   mirostat_eta = 0.1f;    // learning rate
 
     std::string model  = "models/lamma-7B/ggml-model.bin"; // model path
     std::string prompt = "";
@@ -47,7 +56,7 @@ struct gpt_params {
     bool interactive_first = false; // wait for user input immediately
 
     bool instruct          = false; // instruction mode (used for Alpaca models)
-    bool ignore_eos        = false; // do not stop generating after eos
+    bool penalize_nl       = true;  // consider newlines as a repeatable token
     bool perplexity        = false; // compute perplexity over the prompt
     bool use_mmap          = true;  // use mmap for faster loads
     bool use_mlock         = false; // use mlock to keep model in memory
index fda65574fad7af147cb610544e0d46e7f3fae5f1..674920b8a04c53cdb9955c417b7051dee15a76d8 100644 (file)
@@ -276,8 +276,8 @@ int main(int argc, char ** argv) {
             fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
         }
     }
-    fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n",
-        params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
+    fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
+            params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau);
     fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
     fprintf(stderr, "\n\n");
 
@@ -387,10 +387,19 @@ int main(int argc, char ** argv) {
 
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
             // out of user input, sample next token
-            const int32_t top_k          = params.top_k;
-            const float   top_p          = params.top_p;
             const float   temp           = params.temp;
+            const int32_t top_k          = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
+            const float   top_p          = params.top_p;
+            const float   tfs_z          = params.tfs_z;
+            const float   typical_p      = params.typical_p;
+            const int32_t repeat_last_n  = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
             const float   repeat_penalty = params.repeat_penalty;
+            const float   alpha_presence = params.presence_penalty;
+            const float   alpha_frequency = params.frequency_penalty;
+            const int     mirostat       = params.mirostat;
+            const float   mirostat_tau   = params.mirostat_tau;
+            const float   mirostat_eta   = params.mirostat_eta;
+            const bool    penalize_nl   = params.penalize_nl;
 
             // optionally save the session on first sample (for faster prompt loading next time)
             if (!path_session.empty() && need_to_save_session) {
@@ -402,14 +411,58 @@ int main(int argc, char ** argv) {
 
             {
                 auto logits = llama_get_logits(ctx);
+                auto n_vocab = llama_n_vocab(ctx);
 
-                if (params.ignore_eos) {
-                    logits[llama_token_eos()] = 0;
+                // Apply params.logit_bias map
+                for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
+                    logits[it->first] += it->second;
                 }
 
-                id = llama_sample_top_p_top_k(ctx,
-                        last_n_tokens.data() + n_ctx - params.repeat_last_n,
-                        params.repeat_last_n, top_k, top_p, temp, repeat_penalty);
+                std::vector<llama_token_data> candidates;
+                candidates.reserve(n_vocab);
+                for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+                    candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+                }
+
+                llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+
+                // Apply penalties
+                float nl_logit = logits[llama_token_nl()];
+                auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
+                llama_sample_repetition_penalty(ctx, &candidates_p,
+                    last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
+                    last_n_repeat, repeat_penalty);
+                llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
+                    last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
+                    last_n_repeat, alpha_frequency, alpha_presence);
+                if (!penalize_nl) {
+                    logits[llama_token_nl()] = nl_logit;
+                }
+
+                if (temp <= 0) {
+                    // Greedy sampling
+                    id = llama_sample_token_greedy(ctx, &candidates_p);
+                } else {
+                    if (mirostat == 1) {
+                        static float mirostat_mu = 2.0f * mirostat_tau;
+                        const int mirostat_m = 100;
+                        llama_sample_temperature(ctx, &candidates_p, temp);
+                        id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
+                    } else if (mirostat == 2) {
+                        static float mirostat_mu = 2.0f * mirostat_tau;
+                        llama_sample_temperature(ctx, &candidates_p, temp);
+                        id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
+                    } else {
+                        // Temperature sampling
+                        llama_sample_top_k(ctx, &candidates_p, top_k);
+                        llama_sample_tail_free(ctx, &candidates_p, tfs_z);
+                        llama_sample_typical(ctx, &candidates_p, typical_p);
+                        llama_sample_top_p(ctx, &candidates_p, top_p);
+                        llama_sample_temperature(ctx, &candidates_p, temp);
+                        id = llama_sample_token(ctx, &candidates_p);
+                    }
+                }
+                // printf("`%d`", candidates_p.size);
 
                 last_n_tokens.erase(last_n_tokens.begin());
                 last_n_tokens.push_back(id);
index 39aa7f82cae5c6ed6d3cd936455dde154d2aab8f..07dfa2c74ed07c39152b0f710c308c64b20574a5 100644 (file)
@@ -64,14 +64,15 @@ int main(int argc, char ** argv) {
     // first run
     printf("\n%s", params.prompt.c_str());
     for (auto i = 0; i < params.n_predict; i++) {
-        auto next_token = llama_sample_top_p_top_k(
-            ctx,
-            &last_n_tokens_data.back() - params.repeat_last_n,
-            params.repeat_last_n,
-            40,
-            1.0,
-            1.0,
-            1.1);
+        auto logits = llama_get_logits(ctx);
+        auto n_vocab = llama_n_vocab(ctx);
+        std::vector<llama_token_data> candidates;
+        candidates.reserve(n_vocab);
+        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+            candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+        }
+        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+        auto next_token = llama_sample_token(ctx, &candidates_p);
         auto next_token_str = llama_token_to_str(ctx, next_token);
         last_n_tokens_data.push_back(next_token);
         printf("%s", next_token_str);
@@ -106,14 +107,15 @@ int main(int argc, char ** argv) {
 
     // second run
     for (auto i = 0; i < params.n_predict; i++) {
-        auto next_token = llama_sample_top_p_top_k(
-            ctx2,
-            &last_n_tokens_data.back() - params.repeat_last_n,
-            params.repeat_last_n,
-            40,
-            1.0,
-            1.0,
-            1.1);
+        auto logits = llama_get_logits(ctx2);
+        auto n_vocab = llama_n_vocab(ctx2);
+        std::vector<llama_token_data> candidates;
+        candidates.reserve(n_vocab);
+        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+            candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+        }
+        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+        auto next_token = llama_sample_token(ctx2, &candidates_p);
         auto next_token_str = llama_token_to_str(ctx2, next_token);
         last_n_tokens_data.push_back(next_token);
         printf("%s", next_token_str);
index 4699e5cf1de7c4299ce1ddd64c317727b6de5166..1032fb9fa9363c99908a976a39acb73dabe01c70 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -28,6 +28,7 @@
 #include <atomic>
 #include <mutex>
 #include <sstream>
+#include <numeric>
 
 #define LLAMA_USE_SCRATCH
 #define LLAMA_MAX_SCRATCH_BUFFERS 16
@@ -1475,109 +1476,402 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co
 // sampling
 //
 
-static void sample_top_k(std::vector<std::pair<float, llama_vocab::id>> & logits_id, int top_k) {
-    // find the top k tokens
-    std::partial_sort(
-            logits_id.begin(),
-            logits_id.begin() + top_k, logits_id.end(),
-            [](const std::pair<float, llama_vocab::id> & a, const std::pair<float, llama_vocab::id> & b) {
-        return a.first > b.first;
-    });
+void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) {
+    assert(candidates->size > 0);
+
+    const int64_t t_start_sample_us = ggml_time_us();
 
-    logits_id.resize(top_k);
+    // Sort the logits in descending order
+    if (!candidates->sorted) {
+        std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
+            return a.logit > b.logit;
+        });
+        candidates->sorted = true;
+    }
+
+    float max_l = candidates->data[0].logit;
+    float cum_sum = 0.0f;
+    for (size_t i = 0; i < candidates->size; ++i) {
+        float p = expf(candidates->data[i].logit - max_l);
+        candidates->data[i].p = p;
+        cum_sum += p;
+    }
+    for (size_t i = 0; i < candidates->size; ++i) {
+        candidates->data[i].p /= cum_sum;
+    }
+
+    if (ctx) {
+        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+    }
 }
 
-static llama_vocab::id llama_sample_top_p_top_k(
-        llama_context & lctx,
-        const std::vector<llama_vocab::id> & last_n_tokens,
-        int top_k,
-        float top_p,
-        float temp,
-        float repeat_penalty) {
-    auto & rng = lctx.rng;
-
-    const int n_logits = lctx.model.hparams.n_vocab;
-
-    const auto & logits = lctx.logits;
-    const auto * plogits = logits.data() + logits.size() - n_logits;
-
-    if (temp <= 0) {
-        // select the token with the highest logit directly
-        float max_logit = plogits[0];
-        llama_vocab::id max_id = 0;
-
-        for (int i = 1; i < n_logits; ++i) {
-            if (plogits[i] > max_logit) {
-                max_logit = plogits[i];
-                max_id = i;
-            }
+void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep) {
+    const int64_t t_start_sample_us = ggml_time_us();
+
+    k = std::max(k, (int) min_keep);
+    k = std::min(k, (int) candidates->size);
+
+    // Sort scores in descending order
+    if (!candidates->sorted) {
+        auto comp = [](const llama_token_data & a, const llama_token_data & b) {
+            return a.logit > b.logit;
+        };
+        if (k == (int) candidates->size) {
+            std::sort(candidates->data, candidates->data + candidates->size, comp);
+        } else {
+            std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp);
         }
-        return max_id;
+        candidates->sorted = true;
     }
+    candidates->size = k;
 
-    std::vector<std::pair<float, llama_vocab::id>> logits_id;
-    logits_id.reserve(n_logits);
+    if (ctx) {
+        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+    }
+}
 
-    {
-        const float scale = 1.0f/temp;
-        for (int i = 0; i < n_logits; ++i) {
-            // repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858)
-            // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
-            if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) {
-                // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
-                if (plogits[i] < 0.0f) {
-                    logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i));
-                } else {
-                    logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i));
-                }
-            } else {
-                logits_id.push_back(std::make_pair(plogits[i]*scale, i));
-            }
+void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
+    if (p >= 1.0f) {
+        return;
+    }
+
+    const int64_t t_start_sample_us = ggml_time_us();
+
+    llama_sample_softmax(ctx, candidates);
+
+    // Compute the cumulative probabilities
+    float cum_sum = 0.0f;
+    size_t last_idx = candidates->size;
+
+    for (size_t i = 0; i < candidates->size; ++i) {
+        cum_sum += candidates->data[i].p;
+
+        // Check if the running sum is greater than p or if we have kept at least min_keep tokens
+        if (cum_sum > p && i >= min_keep) {
+            last_idx = i;
+            break;
         }
     }
 
-    sample_top_k(logits_id, top_k > 0 ? std::min(top_k, n_logits) : n_logits);
+    // Resize the output vector to keep only the top-p tokens
+    candidates->size = last_idx;
 
-    // compute probs for the top k tokens
-    std::vector<float> probs;
-    probs.reserve(logits_id.size());
+    if (ctx) {
+        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+    }
+}
 
-    float maxl = logits_id[0].first;
-    double sum = 0.0;
-    for (const auto & kv : logits_id) {
-        const float p = expf(kv.first - maxl);
-        probs.push_back(p);
-        sum += p;
+void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
+    if (z >= 1.0f || candidates->size <= 2) {
+        return;
     }
 
-    // normalize the probs
-    for (auto & p : probs) {
-        p /= sum;
+    const int64_t t_start_sample_us = ggml_time_us();
+
+    llama_sample_softmax(nullptr, candidates);
+
+    // Compute the first and second derivatives
+    std::vector<float> first_derivatives(candidates->size - 1);
+    std::vector<float> second_derivatives(candidates->size - 2);
+
+    for (size_t i = 0; i < first_derivatives.size(); ++i) {
+        first_derivatives[i] = candidates->data[i].p - candidates->data[i + 1].p;
+    }
+    for (size_t i = 0; i < second_derivatives.size(); ++i) {
+        second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1];
     }
 
-    if (top_p < 1.0) {
-        double cumsum = 0.0;
-        for (int i = 0; i < (int) probs.size(); i++) {
-            cumsum += probs[i];
-            if (cumsum >= top_p) {
-                probs.resize(i + 1);
-                logits_id.resize(i + 1);
-                break;
-            }
+    // Calculate absolute value of second derivatives
+    for (size_t i = 0; i < second_derivatives.size(); ++i) {
+        second_derivatives[i] = abs(second_derivatives[i]);
+    }
+
+    // Normalize the second derivatives
+    float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f);
+    for (float & value : second_derivatives) {
+        value /= second_derivatives_sum;
+    }
+
+    float cum_sum = 0.0f;
+    size_t last_idx = candidates->size;
+    for (size_t i = 0; i < second_derivatives.size(); ++i) {
+        cum_sum += second_derivatives[i];
+
+        // Check if the running sum is greater than z or if we have kept at least min_keep tokens
+        if (cum_sum > z && i >= min_keep) {
+            last_idx = i;
+            break;
         }
     }
 
-    //printf("\n");
-    //for (int i = 0; i < (int) 10; i++) {
-    //    printf("%d: '%s' %f\n", i, lctx.vocab.id_to_token.at(logits_id[i].second).tok.c_str(), probs[i]);
-    //}
-    //printf("\n\n");
-    //exit(0);
+    // Resize the output vector to keep only the tokens above the tail location
+    candidates->size = last_idx;
+
+    if (ctx) {
+        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+    }
+}
+
+
+void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
+    // Reference implementation:
+    // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
+    if (p >= 1.0f) {
+        return;
+    }
+
+    const int64_t t_start_sample_us = ggml_time_us();
+
+    // Compute the softmax of logits and calculate entropy
+    llama_sample_softmax(nullptr, candidates);
+
+    float entropy = 0.0f;
+    for (size_t i = 0; i < candidates->size; ++i) {
+        entropy += -candidates->data[i].p * logf(candidates->data[i].p);
+    }
+
+    // Compute the absolute difference between negative log probability and entropy for each candidate
+    std::vector<float> shifted_scores;
+    for (size_t i = 0; i < candidates->size; ++i) {
+        float shifted_score = fabsf(-logf(candidates->data[i].p) - entropy);
+        shifted_scores.push_back(shifted_score);
+    }
+
+    // Sort tokens based on the shifted_scores and their corresponding indices
+    std::vector<size_t> indices(candidates->size);
+    std::iota(indices.begin(), indices.end(), 0);
+
+    std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) {
+        return shifted_scores[a] < shifted_scores[b];
+    });
+
+    // Compute the cumulative probabilities
+    float cum_sum = 0.0f;
+    size_t last_idx = indices.size();
+
+    for (size_t i = 0; i < indices.size(); ++i) {
+        size_t idx = indices[i];
+        cum_sum += candidates->data[idx].p;
+
+        // Check if the running sum is greater than typical or if we have kept at least min_keep tokens
+        if (cum_sum > p && i >= min_keep - 1) {
+            last_idx = i + 1;
+            break;
+        }
+    }
+
+    // Resize the output vector to keep only the locally typical tokens
+    std::vector<llama_token_data> new_candidates;
+    for (size_t i = 0; i < last_idx; ++i) {
+        size_t idx = indices[i];
+        new_candidates.push_back(candidates->data[idx]);
+    }
+
+    // Replace the data in candidates with the new_candidates data
+    std::copy(new_candidates.begin(), new_candidates.end(), candidates->data);
+    candidates->size = new_candidates.size();
+
+    if (ctx) {
+        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+    }
+}
+
+void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) {
+    const int64_t t_start_sample_us = ggml_time_us();
+
+    for (size_t i = 0; i < candidates_p->size; ++i) {
+        candidates_p->data[i].logit /= temp;
+    }
+
+    if (ctx) {
+        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+    }
+}
+
+void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, llama_token * last_tokens, size_t last_tokens_size, float penalty) {
+    if (last_tokens_size == 0 || penalty == 1.0f) {
+        return;
+    }
+
+    const int64_t t_start_sample_us = ggml_time_us();
+
+    for (size_t i = 0; i < candidates->size; ++i) {
+        auto token_iter = std::find(last_tokens, last_tokens + last_tokens_size, candidates->data[i].id);
+        if (token_iter == last_tokens + last_tokens_size) {
+            continue;
+        }
+
+        // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
+        // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
+        if (candidates->data[i].logit <= 0) {
+            candidates->data[i].logit *= penalty;
+        } else {
+            candidates->data[i].logit /= penalty;
+        }
+    }
+
+    candidates->sorted = false;
+
+    if (ctx) {
+        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+    }
+}
+
+void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, llama_token * last_tokens_p, size_t last_tokens_size, float alpha_frequency, float alpha_presence) {
+    if (last_tokens_size == 0 || (alpha_frequency == 0.0f && alpha_presence == 0.0f)) {
+        return;
+    }
+
+    const int64_t t_start_sample_us = ggml_time_us();
+
+    // Create a frequency map to count occurrences of each token in last_tokens
+    std::unordered_map<llama_token, int> token_count;
+    for (size_t i = 0; i < last_tokens_size; ++i) {
+        token_count[last_tokens_p[i]]++;
+    }
+
+    // Apply frequency and presence penalties to the candidates
+    for (size_t i = 0; i < candidates->size; ++i) {
+        auto token_iter = token_count.find(candidates->data[i].id);
+        if (token_iter == token_count.end()) {
+            continue;
+        }
+
+        int count = token_iter->second;
+        candidates->data[i].logit -= float(count) * alpha_frequency + float(count > 0) * alpha_presence;
+    }
+
+    candidates->sorted = false;
+
+    if (ctx) {
+        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+    }
+}
+
+
+llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu) {
+    assert(ctx);
+    auto N = float(llama_n_vocab(ctx));
+    int64_t t_start_sample_us;
+    t_start_sample_us = ggml_time_us();
+
+    llama_sample_softmax(nullptr, candidates);
+
+    // Estimate s_hat using the most probable m tokens
+    float s_hat = 0.0;
+    float sum_ti_bi = 0.0;
+    float sum_ti_sq = 0.0;
+    for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) {
+        float t_i = logf(float(i + 2) / float(i + 1));
+        float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p);
+        sum_ti_bi += t_i * b_i;
+        sum_ti_sq += t_i * t_i;
+    }
+    s_hat = sum_ti_bi / sum_ti_sq;
+
+    // Compute k from the estimated s_hat and target surprise value
+    float epsilon_hat = s_hat - 1;
+    float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(N, -epsilon_hat)), 1 / s_hat);
+
+    // Sample the next word X using top-k sampling
+    llama_sample_top_k(nullptr, candidates, int(k));
+    if (ctx) {
+        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+    }
+    llama_token X = llama_sample_token(ctx, candidates);
+    t_start_sample_us = ggml_time_us();
+
+    // Compute error as the difference between observed surprise and target surprise value
+    size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
+        return candidate.id == X;
+    }));
+    float observed_surprise = -log2f(candidates->data[X_idx].p);
+    float e = observed_surprise - tau;
+
+    // Update mu using the learning rate and error
+    *mu = *mu - eta * e;
+
+    if (ctx) {
+        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+        ctx->n_sample++;
+    }
+    return X;
+}
+
+llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) {
+    assert(ctx);
+    int64_t t_start_sample_us;
+    t_start_sample_us = ggml_time_us();
+
+    llama_sample_softmax(ctx, candidates);
+
+    // Truncate the words with surprise values greater than mu
+    candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
+        return -log2f(candidate.p) > *mu;
+    }));
+
+    // Normalize the probabilities of the remaining words
+    llama_sample_softmax(ctx, candidates);
+
+    // Sample the next word X from the remaining words
+    if (ctx) {
+        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+    }
+    llama_token X = llama_sample_token(ctx, candidates);
+    t_start_sample_us = ggml_time_us();
+
+    // Compute error as the difference between observed surprise and target surprise value
+    size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
+        return candidate.id == X;
+    }));
+    float observed_surprise = -log2f(candidates->data[X_idx].p);
+    float e = observed_surprise - tau;
+
+    // Update mu using the learning rate and error
+    *mu = *mu - eta * e;
+
+    if (ctx) {
+        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+    }
+    return X;
+}
+
+llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates) {
+    const int64_t t_start_sample_us = ggml_time_us();
+
+    // Find max element
+    auto max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
+        return a.logit < b.logit;
+    });
+
+    llama_token result = max_iter->id;
+    if (ctx) {
+        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+        ctx->n_sample++;
+    }
+    return result;
+}
+
+llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
+    assert(ctx);
+    const int64_t t_start_sample_us = ggml_time_us();
+    llama_sample_softmax(nullptr, candidates);
+
+    std::vector<float> probs;
+    probs.reserve(candidates->size);
+    for (size_t i = 0; i < candidates->size; ++i) {
+        probs.push_back(candidates->data[i].p);
+    }
 
     std::discrete_distribution<> dist(probs.begin(), probs.end());
+    auto & rng = ctx->rng;
     int idx = dist(rng);
 
-    return logits_id[idx].second;
+    llama_token result = candidates->data[idx].id;
+
+    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+    ctx->n_sample++;
+    return result;
 }
 
 //
@@ -2348,33 +2642,8 @@ llama_token llama_token_eos() {
     return 2;
 }
 
-llama_token llama_sample_top_p_top_k(
-          llama_context * ctx,
-      const llama_token * last_n_tokens_data,
-                    int   last_n_tokens_size,
-                    int   top_k,
-                  float   top_p,
-                  float   temp,
-                  float   repeat_penalty) {
-    const int64_t t_start_sample_us = ggml_time_us();
-
-    llama_token result = 0;
-
-    // TODO: avoid this ...
-    const auto last_n_tokens = std::vector<llama_token>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_size);
-
-    result = llama_sample_top_p_top_k(
-            *ctx,
-            last_n_tokens,
-            top_k,
-            top_p,
-            temp,
-            repeat_penalty);
-
-    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-    ctx->n_sample++;
-
-    return result;
+llama_token llama_token_nl() {
+    return 13;
 }
 
 
diff --git a/llama.h b/llama.h
index 936c521393c7383f9fe9fe19ca1a2b880a6b83f7..34a8f5b3ca52cde035c4e7ee90cdd1a8b33f1e22 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -39,12 +39,16 @@ extern "C" {
 
     typedef struct llama_token_data {
         llama_token id;  // token id
-
+        float logit; // log-odds of the token
         float p;     // probability of the token
-        float plog;  // log probability of the token
-
     } llama_token_data;
 
+    typedef struct llama_token_data_array {
+        llama_token_data * data;
+        size_t size;
+        bool sorted;
+    } llama_token_data_array;
+
     typedef void (*llama_progress_callback)(float progress, void *ctx);
 
     struct llama_context_params {
@@ -181,16 +185,52 @@ extern "C" {
     // Special tokens
     LLAMA_API llama_token llama_token_bos();
     LLAMA_API llama_token llama_token_eos();
+    LLAMA_API llama_token llama_token_nl();
+
+    // Sampling functions
+
+    /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
+    LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, llama_token * last_tokens, size_t last_tokens_size, float penalty);
+
+    /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
+    LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence);
+
+    /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
+    LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);
+
+    /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
+    LLAMA_API void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep = 1);
+
+    /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
+    LLAMA_API void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1);
+
+    /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
+    LLAMA_API void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep = 1);
+
+    /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
+    LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1);
+    LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);
+
+    /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
+    /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
+    /// @param tau  The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
+    /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
+    /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
+    /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
+    LLAMA_API llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu);
+
+    /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
+    /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
+    /// @param tau  The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
+    /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
+    /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
+    LLAMA_API llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu);
+
+    /// @details Selects the token with the highest probability.
+    LLAMA_API llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates);
 
-    // TODO: improve the last_n_tokens interface ?
-    LLAMA_API llama_token llama_sample_top_p_top_k(
-       struct llama_context * ctx,
-          const llama_token * last_n_tokens_data,
-                        int   last_n_tokens_size,
-                        int   top_k,
-                      float   top_p,
-                      float   temp,
-                      float   repeat_penalty);
+    /// @details Randomly selects a token from the candidates based on their probabilities.
+    LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates);
 
     // Performance information
     LLAMA_API void llama_print_timings(struct llama_context * ctx);
index 81eadbc4db0a4e344f5f83d99b862297637292cd..645648585ab3dd58e14698f7fc8cba68cec3f08f 100644 (file)
@@ -8,4 +8,5 @@ endfunction()
 # llama_add_test(test-double-float.c) # SLOW
 llama_add_test(test-quantize-fns.cpp)
 llama_add_test(test-quantize-perf.cpp)
+llama_add_test(test-sampling.cpp)
 llama_add_test(test-tokenizer-0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab.bin)
diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp
new file mode 100644 (file)
index 0000000..7eee4f6
--- /dev/null
@@ -0,0 +1,199 @@
+#include "llama.h"
+#include "ggml.h"
+#include <cassert>
+#include <cmath>
+#include <numeric>
+#include <cassert>
+#include <iostream>
+#include <vector>
+#include <algorithm>
+
+
+void dump(const llama_token_data_array * candidates) {
+    for (size_t i = 0; i < candidates->size; i++) {
+        printf("%d: %f (%f)\n", candidates->data[i].id, candidates->data[i].p, candidates->data[i].logit);
+    }
+}
+
+#define DUMP(__candidates) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__candidates)); printf("-\n"); } while(0)
+
+
+void test_top_k(const std::vector<float> & probs,
+                const std::vector<float> & expected_probs,
+                int k) {
+    size_t n_vocab = probs.size();
+    std::vector<llama_token_data> candidates;
+    candidates.reserve(n_vocab);
+    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
+        float logit = log(probs[token_id]);
+        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
+    }
+
+    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+    llama_sample_softmax(nullptr, &candidates_p);
+    DUMP(&candidates_p);
+    llama_sample_top_k(nullptr, &candidates_p, k);
+    DUMP(&candidates_p);
+
+    assert(candidates_p.size == expected_probs.size());
+    for (size_t i = 0; i < candidates_p.size; i++) {
+        assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-5);
+    }
+}
+
+
+void test_top_p(const std::vector<float> & probs,
+                const std::vector<float> & expected_probs,
+                float p) {
+
+    size_t n_vocab = probs.size();
+    std::vector<llama_token_data> candidates;
+    candidates.reserve(n_vocab);
+    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
+        float logit = log(probs[token_id]);
+        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
+    }
+
+    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+    llama_sample_softmax(nullptr, &candidates_p);
+    DUMP(&candidates_p);
+    llama_sample_top_p(nullptr, &candidates_p, p);
+    DUMP(&candidates_p);
+
+    assert(candidates_p.size == expected_probs.size());
+    for (size_t i = 0; i < candidates_p.size; i++) {
+        assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
+    }
+}
+
+
+void test_tfs(const std::vector<float> & probs,
+                const std::vector<float> & expected_probs,
+                float z) {
+    size_t n_vocab = probs.size();
+    std::vector<llama_token_data> candidates;
+    candidates.reserve(n_vocab);
+    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
+        float logit = log(probs[token_id]);
+        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
+    }
+
+    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+    DUMP(&candidates_p);
+    llama_sample_tail_free(nullptr, &candidates_p, z);
+    DUMP(&candidates_p);
+
+    assert(candidates_p.size == expected_probs.size());
+    for (size_t i = 0; i < candidates_p.size; i++) {
+        assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
+    }
+}
+
+
+void test_typical(const std::vector<float> & probs,
+                const std::vector<float> & expected_probs,
+                float p) {
+    size_t n_vocab = probs.size();
+    std::vector<llama_token_data> candidates;
+    candidates.reserve(n_vocab);
+    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
+        float logit = log(probs[token_id]);
+        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
+    }
+
+    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+    DUMP(&candidates_p);
+    llama_sample_typical(nullptr, &candidates_p, p);
+    DUMP(&candidates_p);
+
+    assert(candidates_p.size == expected_probs.size());
+    for (size_t i = 0; i < candidates_p.size; i++) {
+        assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
+    }
+}
+
+
+void test_repetition_penalty(
+                const std::vector<float> & probs,
+                const std::vector<llama_token> & last_tokens,
+                const std::vector<float> & expected_probs,
+                float penalty) {
+    assert(probs.size() == expected_probs.size());
+
+    size_t n_vocab = probs.size();
+    std::vector<llama_token_data> candidates;
+    candidates.reserve(n_vocab);
+    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
+        float logit = log(probs[token_id]);
+        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
+    }
+
+    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+    llama_sample_softmax(nullptr, &candidates_p);
+    DUMP(&candidates_p);
+    llama_sample_repetition_penalty(nullptr, &candidates_p, (llama_token *)last_tokens.data(), last_tokens.size(), penalty);
+    llama_sample_softmax(nullptr, &candidates_p);
+    DUMP(&candidates_p);
+
+    assert(candidates_p.size == expected_probs.size());
+    for (size_t i = 0; i < candidates_p.size; i++) {
+        assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-6);
+    }
+}
+
+
+void test_frequency_presence_penalty(
+                const std::vector<float> & probs,
+                const std::vector<llama_token> & last_tokens,
+                const std::vector<float> & expected_probs,
+                float alpha_frequency, float alpha_presence) {
+    assert(probs.size() == expected_probs.size());
+
+    size_t n_vocab = probs.size();
+    std::vector<llama_token_data> candidates;
+    candidates.reserve(n_vocab);
+    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
+        float logit = log(probs[token_id]);
+        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
+    }
+
+    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+    llama_sample_softmax(nullptr, &candidates_p);
+    // DUMP(&candidates_p);
+    llama_sample_frequency_and_presence_penalties(nullptr, &candidates_p, (llama_token *)last_tokens.data(), last_tokens.size(), alpha_frequency, alpha_presence);
+    llama_sample_softmax(nullptr, &candidates_p);
+    // DUMP(&candidates_p);
+
+    assert(candidates_p.size == expected_probs.size());
+    for (size_t i = 0; i < candidates_p.size; i++) {
+        assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
+    }
+}
+
+int main(void) {
+    ggml_time_init();
+
+    test_top_k({0.1, 0.2, 0.3, 0.4}, {0.4}, 1);
+    test_top_k({0.1, 0.2, 0.3, 0.4}, {0.4, 0.3, 0.2}, 3);
+
+    test_top_p({0.1, 0.2, 0.3, 0.4}, {0.4}, 0);
+    test_top_p({0.1, 0.2, 0.3, 0.4}, {0.4, 0.3}, 0.7);
+    test_top_p({0.1, 0.2, 0.3, 0.4}, {0.4, 0.3, 0.2, 0.1}, 1);
+
+    test_tfs({0.1, 0.15, 0.2, 0.25, 0.3}, {0.3}, 0.25);
+    test_tfs({0.1, 0.15, 0.2, 0.25, 0.3}, {0.3, 0.25}, 0.75);
+    test_tfs({0.1, 0.15, 0.2, 0.25, 0.3}, {0.3, 0.25}, 0.99);
+
+    test_typical({0.97, 0.01, 0.01, 0.01}, {0.97}, 0.5);
+    test_typical({0.4, 0.2, 0.2, 0.2}, {0.2, 0.2, 0.2}, 0.5);
+
+    test_repetition_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0}, {0.25, 0.25, 0.25, 0.25, 0}, 50.0);
+    test_repetition_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2}, {0.5, 0.5, 0, 0, 0}, 50.0);
+    test_repetition_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2, 0, 0}, {0.5, 0.5, 0, 0, 0}, 50.0);
+
+    test_frequency_presence_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0},             {0.249997, 0.249997, 0.249997, 0.249997, 0.000011}, 5.0, 5.0);
+    test_frequency_presence_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2},       {0.499966, 0.499966, 0.000023, 0.000023, 0.000023}, 5.0, 5.0);
+    test_frequency_presence_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2, 0, 0}, {0.499977, 0.499977, 0.000023, 0.000023, 0.000000}, 5.0, 5.0);
+
+    printf("OK\n");
+}