]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
sampling: add Top-nσ sampler (#11223)
authorVinesh Janarthanan <redacted>
Thu, 13 Feb 2025 06:45:57 +0000 (00:45 -0600)
committerGitHub <redacted>
Thu, 13 Feb 2025 06:45:57 +0000 (08:45 +0200)
* initial sampling changes:

* completed top nsigma sampler implementation

* apply parameter to only llama-cli

* updated readme

* added tests and fixed nsigma impl

* cleaned up pr

* format

* format

* format

* removed commented tests

* cleanup pr and remove explicit floats

* added top-k sampler to improve performance

* changed sigma to float

* fixed string format to float

* Update src/llama-sampling.cpp

Co-authored-by: Georgi Gerganov <redacted>
* Update common/sampling.cpp

Co-authored-by: Georgi Gerganov <redacted>
* Update src/llama-sampling.cpp

Co-authored-by: Georgi Gerganov <redacted>
* Update src/llama-sampling.cpp

Co-authored-by: Georgi Gerganov <redacted>
* Update src/llama-sampling.cpp

Co-authored-by: Georgi Gerganov <redacted>
* Update src/llama-sampling.cpp

Co-authored-by: Georgi Gerganov <redacted>
* added llama_sampler_init

---------

Co-authored-by: Georgi Gerganov <redacted>
common/arg.cpp
common/common.h
common/sampling.cpp
examples/main/README.md
include/llama.h
src/llama-sampling.cpp
tests/test-sampling.cpp

index f7af09c7166cd1430132de45df40137e4b506f43..f3027af4db4d85cf28e6a0dc99a858e01e424b83 100644 (file)
@@ -946,6 +946,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.sampling.min_p = std::stof(value);
         }
     ).set_sparam());
+    add_opt(common_arg(
+        {"--top-nsigma"}, "N",
+        string_format("top-n-sigma sampling (default: %.1f, -1.0 = disabled)", params.sampling.top_n_sigma),
+        [](common_params & params, const std::string & value) {
+            params.sampling.top_n_sigma = std::stof(value);
+        }
+    ).set_examples({LLAMA_EXAMPLE_MAIN}).set_sparam());
     add_opt(common_arg(
         {"--xtc-probability"}, "N",
         string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability),
index 177c98f0fdfd112f16cc6abfa4b8c0fd4ad2b7d1..17f5ea8683d690554e1ce01f86c7da8067caf8db 100644 (file)
@@ -140,6 +140,7 @@ struct common_params_sampling {
     int32_t dry_allowed_length = 2;     // tokens extending repetitions beyond this receive penalty
     int32_t dry_penalty_last_n = -1;    // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
     int32_t mirostat           = 0;     // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+    float   top_n_sigma        = -1.00f;// -1.0 = disabled
     float   mirostat_tau       = 5.00f; // target entropy
     float   mirostat_eta       = 0.10f; // learning rate
     bool    ignore_eos         = false;
index e4b21ca1011dddbd1da2a6f8ba6013beaee43605..21e15ee8413a356ed468b4411c3af27c6d251650 100644 (file)
@@ -134,11 +134,11 @@ std::string common_params_sampling::print() const {
     snprintf(result, sizeof(result),
             "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
             "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
-            "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
+            "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n"
             "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
             penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
             dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
-            top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
+            top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp,
             mirostat, mirostat_eta, mirostat_tau);
 
     return std::string(result);
@@ -188,45 +188,51 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                 params.logit_bias.data()));
 
     if (params.mirostat == 0) {
-        for (const auto & cnstr : params.samplers) {
-            switch (cnstr) {
-                case COMMON_SAMPLER_TYPE_DRY:
-                    {
-                        std::vector<const char *> c_breakers;
-                        c_breakers.reserve(params.dry_sequence_breakers.size());
-                        for (const auto & str : params.dry_sequence_breakers) {
-                            c_breakers.push_back(str.c_str());
+        if (params.top_n_sigma >= 0) {
+            llama_sampler_chain_add(result->chain, llama_sampler_init_top_k        (params.top_k));
+            llama_sampler_chain_add(result->chain, llama_sampler_init_temp         (params.temp));
+            llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma  (params.top_n_sigma));
+        } else {
+            for (const auto & cnstr : params.samplers) {
+                switch (cnstr) {
+                    case COMMON_SAMPLER_TYPE_DRY:
+                        {
+                            std::vector<const char *> c_breakers;
+                            c_breakers.reserve(params.dry_sequence_breakers.size());
+                            for (const auto & str : params.dry_sequence_breakers) {
+                                c_breakers.push_back(str.c_str());
+                            }
+
+                            llama_sampler_chain_add(result->chain, llama_sampler_init_dry      (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
                         }
-
-                        llama_sampler_chain_add(result->chain, llama_sampler_init_dry      (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
-                    }
-                    break;
-                case COMMON_SAMPLER_TYPE_TOP_K:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_k    (params.top_k));
-                    break;
-                case COMMON_SAMPLER_TYPE_TOP_P:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_p    (params.top_p, params.min_keep));
-                    break;
-                case COMMON_SAMPLER_TYPE_MIN_P:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_min_p    (params.min_p, params.min_keep));
-                    break;
-                case COMMON_SAMPLER_TYPE_XTC:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_xtc      (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
-                    break;
-                case COMMON_SAMPLER_TYPE_TYPICAL_P:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_typical  (params.typ_p, params.min_keep));
-                    break;
-                case COMMON_SAMPLER_TYPE_TEMPERATURE:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
-                    break;
-                case COMMON_SAMPLER_TYPE_INFILL:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_infill   (vocab));
-                    break;
-                case COMMON_SAMPLER_TYPE_PENALTIES:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
-                    break;
-                default:
-                    GGML_ASSERT(false && "unknown sampler type");
+                        break;
+                    case COMMON_SAMPLER_TYPE_TOP_K:
+                        llama_sampler_chain_add(result->chain, llama_sampler_init_top_k    (params.top_k));
+                        break;
+                    case COMMON_SAMPLER_TYPE_TOP_P:
+                        llama_sampler_chain_add(result->chain, llama_sampler_init_top_p    (params.top_p, params.min_keep));
+                        break;
+                    case COMMON_SAMPLER_TYPE_MIN_P:
+                        llama_sampler_chain_add(result->chain, llama_sampler_init_min_p    (params.min_p, params.min_keep));
+                        break;
+                    case COMMON_SAMPLER_TYPE_XTC:
+                        llama_sampler_chain_add(result->chain, llama_sampler_init_xtc      (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
+                        break;
+                    case COMMON_SAMPLER_TYPE_TYPICAL_P:
+                        llama_sampler_chain_add(result->chain, llama_sampler_init_typical  (params.typ_p, params.min_keep));
+                        break;
+                    case COMMON_SAMPLER_TYPE_TEMPERATURE:
+                        llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
+                        break;
+                    case COMMON_SAMPLER_TYPE_INFILL:
+                        llama_sampler_chain_add(result->chain, llama_sampler_init_infill   (vocab));
+                        break;
+                    case COMMON_SAMPLER_TYPE_PENALTIES:
+                        llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
+                        break;
+                    default:
+                        GGML_ASSERT(false && "unknown sampler type");
+                }
             }
         }
         llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
index ceaed42f649df595ea390b378154496488ade0ea..ea71591bd939d571562a00f0ae8034a584344bb7 100644 (file)
@@ -265,6 +265,14 @@ Being experimental and unique, XTC is disabled by default. The recommended combi
 
 Example usage: `--xtc-probability 0.5 --xtc-threshold 0.1`
 
+### Top-nσ Sampling
+
+-   `--top-nsigma N`: Limit the next token selection to a subset of tokens with pre-softmax logits that are within n * σ less than the max logit (default: -1, -1 = disabled).
+
+Top-nσ sampling is a text generation method that selects tokens based on a statistical threshold in pre-softmax logits. It works by only sampling from tokens with logits that are within n * σ of the maximum logit. This method helps maintain a stable sampling space regardless of temperature scaling, allowing it to perform well on reasoning tasks even in high temperatures. Without complex probability manipulation, it efficiently filters tokens directly on the pre-softmax logits. A higher value for top-nsigma (e.g., 5) will take more noisy tokens into consideration, while a lower value (e.g., 1) will focous on the more informative region of the sampling space.
+
+Example usage: `--top-nsigma 1`
+
 ### Logit Bias
 
 -   `-l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS`: Modify the likelihood of a token appearing in the generated text completion.
index 3784f7d3950e502cb10a5d8a4bb5a3e83687f366..1f5f3a09b311e0d6af0746e1414a7b8f98bb9a6f 100644 (file)
@@ -1172,6 +1172,9 @@ extern "C" {
     /// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
     LLAMA_API struct llama_sampler * llama_sampler_init_xtc        (float   p, float   t,     size_t min_keep, uint32_t seed);
 
+    /// @details Top n sigma sampling as described in academic paper "Top-nσ: Not All Logits Are You Need" https://arxiv.org/pdf/2411.07641
+    LLAMA_API struct llama_sampler * llama_sampler_init_top_n_sigma(float   n);
+
     /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
     /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
     /// @param tau  The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
index 990b6129746dee801f7742935cf82c104d476cf7..f40bf2db83a80ffe6b82ec79798fbf0c49dad42c 100644 (file)
@@ -1698,6 +1698,73 @@ struct llama_sampler * llama_sampler_init_penalties(
     );
 }
 
+// top-n-sigma
+
+struct llama_sampler_top_n_sigma {
+    const float n;
+};
+
+static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) {
+    return "top-n-sigma";
+}
+
+static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
+
+    // find max logit and calculate mean
+    float max = cur_p->data[0].logit;
+    float logits_sum = 0;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        if (cur_p->data[i].logit > max) {
+            max = cur_p->data[i].logit;
+        }
+        logits_sum += cur_p->data[i].logit;
+    }
+    float mean = logits_sum/cur_p->size;
+
+    // calculate standard deviation
+    float acc = 0;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        acc += pow(cur_p->data[i].logit - mean, 2);
+    }
+    float std = sqrt(acc/cur_p->size);
+
+    //apply mask
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        if (cur_p->data[i].logit < max - (ctx->n * std)) {
+            cur_p->data[i].logit = -INFINITY;
+        }
+    }
+    llama_sampler_softmax_impl(cur_p);
+}
+
+static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_top_n_sigma *) smpl->ctx;
+    return llama_sampler_init_top_n_sigma(ctx->n);
+}
+
+static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_top_n_sigma *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
+    /* .name   = */ llama_sampler_top_n_sigma_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_top_n_sigma_apply,
+    /* .reset  = */ nullptr,
+    /* .clone  = */ llama_sampler_top_n_sigma_clone,
+    /* .free   = */ llama_sampler_top_n_sigma_free,
+};
+
+struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
+    return llama_sampler_init(
+        /* .iface = */ &llama_sampler_top_n_sigma_i,
+        /* .ctx   = */ new llama_sampler_top_n_sigma {
+            /* .n = */ n,
+        }
+    );
+}
+
 // DRY
 
 struct llama_sampler_dry {
index 61bd6785044ea152b15f6998f8ddb28b822544e7..f1f87d454d464afd8992661a490343dc6e949844 100644 (file)
@@ -181,6 +181,17 @@ static void test_dry(
     tester.check();
 }
 
+static void test_top_n_sigma(const std::vector<float> & probs, const std::vector<float> & probs_expected, int n) {
+    sampler_tester tester(probs, probs_expected);
+
+    DUMP(&tester.cur_p);
+    tester.apply(llama_sampler_init_top_n_sigma(n));
+    tester.apply(llama_sampler_init_dist (0));
+    DUMP(&tester.cur_p);
+
+    tester.check();
+}
+
 static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p
 ) {
     sampler_tester tester(n_vocab);
@@ -348,6 +359,10 @@ int main(void) {
     test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 1}, {0.241818f, 0.241818f, 0.241818f, 0.241818f, 0.032727f}, 2.0f, 1.1f, 2, 5, {});
     test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 3, 4, 0, 1}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 1.1f, 4, 7, {});
 
+    test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.571429f, 0.428571f, 0.0f, 0.0f}, 1.00f);
+    test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.00f);
+    test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 3.00f);
+
     test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f);
     test_sampler_queue(10000, "k",     1, 1.0f, 1.0f);
     test_sampler_queue(10000, "p", 10000, 1.0f, 1.0f);