]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
sampling : add XTC sampler (#9742)
authorMaggotHATE <redacted>
Tue, 15 Oct 2024 10:54:55 +0000 (15:54 +0500)
committerGitHub <redacted>
Tue, 15 Oct 2024 10:54:55 +0000 (12:54 +0200)
* Initial XTC commit

Adds XTC sampler, not activated by default, but recommended settings by default.

* Cleanup

* Simplified chances calculation

To be more inline with the original implementation, chance is calculated once at the beginning.

* First fixes by comments

Still need to look into sorting

* Fixed trailing backspaces

* Fixed RNG to be reproduceable

Thanks to @slaren for directions

* Fixed forgotten header

* Moved `min_keep`

Moved from conditions to a simple check at the end.

* Fixed broken randomization

Thanks to @slaren for explanation

* Swapped sorting for a custom algorithm

Shifts tokens to remove the penalized ones, then puts the penalized at the back. Should make `min_keep` still viable.

* Algorithm rework

1. Scan token from top till the first non-penalizable
2. Remove the last captured token (the least probable above threshold)
3. Shift all tokens to override the remaining penalizable
4. Penalize and put them at the the bottom.

* Added XTC to `test-sampling`

* Simplified algorithm and more tests

* Updated info in common and args

* Merged back lost commits in common and arg

* Update dump info in common

* Fixed incorrect min_keep check

* Added XTC to README

* Renamed parameters, fixed info and defaults

* probability is at 0 by default, but XTC is included in sampling queue
* threshold higher than 0.5 switches XTC off

* Initial server support

* Added XTC to server UIs

* Fixed labels in old server UI

* Made algorithm safer and more readable

* Removed xtc_threshold_max

* Fixed arg after update

* Quick fixes by comments

* Simplified algorithm since threshold_max is removed

* Renamed random distribution

* Fixed tests and outdated README

* Small fixes

common/arg.cpp
common/common.cpp
common/common.h
common/sampling.cpp
examples/main/README.md
examples/server/public/index-new.html
examples/server/public/index.html
examples/server/server.cpp
include/llama.h
src/llama-sampling.cpp
tests/test-sampling.cpp

index 8969fc1073c8501aa8195ac54a631db2f535573d..d6a8e1f6ff0bf8a7a4ce18b967d1decd174754fd 100644 (file)
@@ -947,6 +947,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.sparams.tfs_z = std::stof(value);
         }
     ).set_sparam());
+    add_opt(common_arg(
+        {"--xtc-probability"}, "N",
+        string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sparams.xtc_probability),
+        [](common_params & params, const std::string & value) {
+            params.sparams.xtc_probability = std::stof(value);
+        }
+    ).set_sparam());
+    add_opt(common_arg(
+        {"--xtc-threshold"}, "N",
+        string_format("xtc threshold (default: %.1f, 1.0 = disabled)", (double)params.sparams.xtc_threshold),
+        [](common_params & params, const std::string & value) {
+            params.sparams.xtc_threshold = std::stof(value);
+        }
+    ).set_sparam());
     add_opt(common_arg(
         {"--typical"}, "N",
         string_format("locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)params.sparams.typ_p),
index 451307b554b6b109954b07d2429eb521d75e0ab0..c08f01b429056e96ef90a5ae15ecaca2a00df35d 100644 (file)
@@ -2104,6 +2104,8 @@ void yaml_dump_non_result_info(FILE * stream, const common_params & params, cons
     fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k);
     fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);
     fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p);
+    fprintf(stream, "xtc_probability: %f # default: 0.0\n", sparams.xtc_probability);
+    fprintf(stream, "xtc_threshold: %f # default: 0.1\n", sparams.xtc_threshold);
     fprintf(stream, "typ_p: %f # default: 1.0\n", sparams.typ_p);
     fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
     fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false");
index 5507b1c59bb182ab23711cee9f3aa35658a8faf5..df2ee6bd43a3312cc74f94ab82896e833b44efe1 100644 (file)
@@ -90,6 +90,8 @@ enum common_sampler_type {
     COMMON_SAMPLER_TYPE_TFS_Z       = 4,
     COMMON_SAMPLER_TYPE_TYPICAL_P   = 5,
     COMMON_SAMPLER_TYPE_TEMPERATURE = 6,
+    COMMON_SAMPLER_TYPE_XTC         = 7,
+
 };
 
 // dimensionality reduction methods, used by cvector-generator
@@ -108,6 +110,8 @@ struct common_sampler_params {
     int32_t top_k             = 40;    // <= 0 to use vocab size
     float   top_p             = 0.95f; // 1.0 = disabled
     float   min_p             = 0.05f; // 0.0 = disabled
+    float   xtc_probability   = 0.00f; // 0.0 = disabled
+    float   xtc_threshold     = 0.10f; // > 0.5 disables XTC
     float   tfs_z             = 1.00f; // 1.0 = disabled
     float   typ_p             = 1.00f; // typical_p, 1.0 = disabled
     float   temp              = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
@@ -124,12 +128,14 @@ struct common_sampler_params {
     bool    ignore_eos        = false;
     bool    no_perf           = false; // disable performance metrics
 
+
     std::vector<enum common_sampler_type> samplers = {
         COMMON_SAMPLER_TYPE_TOP_K,
         COMMON_SAMPLER_TYPE_TFS_Z,
         COMMON_SAMPLER_TYPE_TYPICAL_P,
         COMMON_SAMPLER_TYPE_TOP_P,
         COMMON_SAMPLER_TYPE_MIN_P,
+        COMMON_SAMPLER_TYPE_XTC,
         COMMON_SAMPLER_TYPE_TEMPERATURE
     };
 
index cd49ade69af3680fa49f23daba4f6b2646b3bf55..fb95bcd3bf2b098187c59fe8caa8a5c5c920cefb 100644 (file)
@@ -130,10 +130,10 @@ std::string common_sampler_params::print() const {
 
     snprintf(result, sizeof(result),
             "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
-            "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
+            "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
             "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
             penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
-            top_k, tfs_z, top_p, min_p, typ_p, temp,
+            top_k, tfs_z, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
             mirostat, mirostat_eta, mirostat_tau);
 
     return std::string(result);
@@ -184,6 +184,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                     case COMMON_SAMPLER_TYPE_MIN_P:
                         llama_sampler_chain_add(result->chain, llama_sampler_init_min_p    (params.min_p, params.min_keep));
                         break;
+                    case COMMON_SAMPLER_TYPE_XTC:
+                        llama_sampler_chain_add(result->chain, llama_sampler_init_xtc      (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
+                        break;
                     case COMMON_SAMPLER_TYPE_TFS_Z:
                         llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep));
                         break;
@@ -372,6 +375,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
         case COMMON_SAMPLER_TYPE_TOP_P:       return 'p';
         case COMMON_SAMPLER_TYPE_MIN_P:       return 'm';
         case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
+        case COMMON_SAMPLER_TYPE_XTC:         return 'x';
         default : return '?';
     }
 }
@@ -384,6 +388,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
         case COMMON_SAMPLER_TYPE_TOP_P:       return "top_p";
         case COMMON_SAMPLER_TYPE_MIN_P:       return "min_p";
         case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
+        case COMMON_SAMPLER_TYPE_XTC:         return "xtc";
         default : return "";
     }
 }
@@ -396,6 +401,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
         { "min_p",       COMMON_SAMPLER_TYPE_MIN_P },
         { "tfs_z",       COMMON_SAMPLER_TYPE_TFS_Z },
         { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
+        { "xtc",         COMMON_SAMPLER_TYPE_XTC },
     };
 
     // since samplers names are written multiple ways
@@ -441,7 +447,8 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P),   COMMON_SAMPLER_TYPE_TYPICAL_P },
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P),       COMMON_SAMPLER_TYPE_TOP_P },
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P),       COMMON_SAMPLER_TYPE_MIN_P },
-        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE }
+        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
+        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC),         COMMON_SAMPLER_TYPE_XTC }
     };
 
     std::vector<common_sampler_type> samplers;
index f0c3031ab130e395fbb2e8f0334b77b7e25cbb14..620934dad4ad559fd38d28dbe0a32deb12acfbf9 100644 (file)
@@ -241,6 +241,19 @@ The `--mirostat-ent` option sets the Mirostat target entropy (tau), which repres
 
 Example usage: `--mirostat 2 --mirostat-lr 0.05 --mirostat-ent 3.0`
 
+### XTC Sampling
+
+-   `--xtc-probability N`: Sets the chance for token removal (checked once on sampler start) (default: 0.0).
+-   `--xtc-threshold N`: Sets a minimum probability threshold for tokens to be removed (default: 0.1).
+
+Exclude Top Choices (XTC) is a unique sampler that is designed to remove top tokens from consideration and avoid more obvious and repetitive outputs. With a chance of `xtc-probability` it searches for tokens with probabilities of `xtc-threshold` and above, then removes all such tokens except the least probable one.
+
+By removing top tokens XTC can improve the variety of answers, break writing clichés and inhibit repition, since clichés and repeated phrases are usually more likely to appear. By keeping the last token above the threshold, XTC ensures that the answer is still coherent. XTC is meant to be used for creative tasks, but feel free to experiment with different settings for different models.
+
+Being experimental and unique, XTC is disabled by default. The recommended combination of samplers is Min-P followed by XTC on its default settings: `--sampling-seq mx --min-p 0.02 --xtc-probability 0.5`.
+
+Example usage: `--xtc-probability 0.5 --xtc-threshold 0.1`
+
 ### Logit Bias
 
 -   `-l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS`: Modify the likelihood of a token appearing in the generated text completion.
index c87dd8f1e1d32b8fb3b648625fdb89631b54bcfb..ad4183cd928f7f39aeb608ea4b681fa0e5de7212 100644 (file)
@@ -43,6 +43,8 @@
       top_k: 0, // <= 0 to use vocab size
       top_p: 1.0, // 1.0 = disabled
       min_p: 0.05, // 0 = disabled; recommended for non-english: ~ 0.4
+      xtc_probability: 0.0, // 0 = disabled;
+      xtc_threshold: 0.1, // > 0.5 disables XTC;
       tfs_z: 1.0, // 1.0 = disabled
       typical_p: 1.0, // 1.0 = disabled
       presence_penalty: 0.0, // 0.0 = disabled
@@ -836,6 +838,8 @@ return html`
           ${FloatField({ label: "TFS-Z", title: "Activates tail-free sampling, a method used to limit the prediction of tokens that are too frequent. The parameter z controls the strength of this limitation. A value of 1.0 means that this function is deactivated.", max: 1.0, min: 0.0, name: "tfs_z", step: 0.01, value: params.value.tfs_z })}
           ${FloatField({ label: "Frequency Penalty", title: "A penalty that is applied based on the frequency with which certain tokens occur in the training data set. A higher value results in rare tokens being favoured.", max: 1.0, min: 0.0, name: "frequency_penalty", step: 0.01, value: params.value.frequency_penalty })}
           ${FloatField({ label: "Typical-P", title: "Activates local typical sampling, a method used to limit the prediction of tokens that are atypical in the current context. The parameter p controls the strength of this limitation. A value of 1.0 means that this function is deactivated.", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p })}
+          ${FloatField({ label: "XTC probability", title: "Sets the chance for token removal (checked once on sampler start)", max: 1.0, min: 0.0, name: "xtc_probability", step: 0.01, value: params.value.xtc_probability })}
+          ${FloatField({ label: "XTC threshold", title: "Sets a minimum probability threshold for tokens to be removed", max: 0.5, min: 0.0, name: "xtc_threshold", step: 0.01, value: params.value.xtc_threshold })}
           ${IntField({ label: "Min Keep", title: "If greater than 0, samplers are forced to return N possible tokens at minimum. Default is 0", max: 10, min: 0, name: "min_keep", value: params.value.min_keep })}
         </fieldset>
 
@@ -1132,6 +1136,8 @@ document.addEventListener('DOMContentLoaded', (event) => {
   const snapSettings = {
     temperature: { snapValue: 1.0, snapRangeMultiplier: 6 },
     min_p: { snapValue: 0.05, snapRangeMultiplier: 2 },
+    xtc_probability: { snapValue: 0.0, snapRangeMultiplier: 4 },
+    xtc_threshold: { snapValue: 0.5, snapRangeMultiplier: 4 },
     top_p: { snapValue: 1.0, snapRangeMultiplier: 4 },
     tfs_z: { snapValue: 1.0, snapRangeMultiplier: 4 },
     typical_p: { snapValue: 1.0, snapRangeMultiplier: 4 },
index 07fec6a38bbcdeb93a806a344e95e15471cffd8c..88065705fb66915e62313c40644081783fa92ff3 100644 (file)
       top_k: 40, // <= 0 to use vocab size
       top_p: 0.95, // 1.0 = disabled
       min_p: 0.05, // 0 = disabled
+      xtc_probability: 0.0, // 0 = disabled;
+      xtc_threshold: 0.1, // > 0.5 disables XTC;
       tfs_z: 1.0, // 1.0 = disabled
       typical_p: 1.0, // 1.0 = disabled
       presence_penalty: 0.0, // 0.0 = disabled
               ${FloatField({ label: "Typical P", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p })}
               ${FloatField({ label: "Presence penalty", max: 1.0, min: 0.0, name: "presence_penalty", step: 0.01, value: params.value.presence_penalty })}
               ${FloatField({ label: "Frequency penalty", max: 1.0, min: 0.0, name: "frequency_penalty", step: 0.01, value: params.value.frequency_penalty })}
+              ${FloatField({ label: "XTC probability", max: 1.0, min: 0.0, name: "xtc_probability", step: 0.01, value: params.value.xtc_probability })}
+              ${FloatField({ label: "XTC threshold", max: 0.5, min: 0.0, name: "xtc_threshold", step: 0.01, value: params.value.xtc_threshold })}
             </fieldset>
             <hr />
             <fieldset class="three">
index 18bcad3f06bca0c418a5c469db424db4483759a2..8d4380e12f35af78248fbc3147e7345c5d1db261 100644 (file)
@@ -863,6 +863,8 @@ struct server_context {
         slot.sparams.top_k             = json_value(data, "top_k",             default_sparams.top_k);
         slot.sparams.top_p             = json_value(data, "top_p",             default_sparams.top_p);
         slot.sparams.min_p             = json_value(data, "min_p",             default_sparams.min_p);
+        slot.sparams.xtc_probability   = json_value(data, "xtc_probability",   default_sparams.xtc_probability);
+        slot.sparams.xtc_threshold     = json_value(data, "xtc_threshold",     default_sparams.xtc_threshold);
         slot.sparams.tfs_z             = json_value(data, "tfs_z",             default_sparams.tfs_z);
         slot.sparams.typ_p             = json_value(data, "typical_p",         default_sparams.typ_p);
         slot.sparams.temp              = json_value(data, "temperature",       default_sparams.temp);
@@ -1196,6 +1198,8 @@ struct server_context {
             {"top_k",                     slot.sparams.top_k},
             {"top_p",                     slot.sparams.top_p},
             {"min_p",                     slot.sparams.min_p},
+            {"xtc_probability",           slot.sparams.xtc_probability},
+            {"xtc_threshold",             slot.sparams.xtc_threshold},
             {"tfs_z",                     slot.sparams.tfs_z},
             {"typical_p",                 slot.sparams.typ_p},
             {"repeat_last_n",             slot.sparams.penalty_last_n},
index 9110b5956c0b3672465c05332e224f1d65ed9323..92d4c70c13b876ba2253750900ee4ffe6fae877e 100644 (file)
@@ -1101,6 +1101,9 @@ extern "C" {
     /// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772.
     LLAMA_API struct llama_sampler * llama_sampler_init_temp_ext   (float   t, float   delta, float exponent);
 
+    /// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
+    LLAMA_API struct llama_sampler * llama_sampler_init_xtc        (float   p, float   t,     size_t min_keep, uint32_t seed);
+
     /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
     /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
     /// @param tau  The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
index e255a8fc4fd548e135bb00465698b1b61a5043f2..67a78c3ac4fe8eba7ca20a0c4d7c78d609571723 100644 (file)
@@ -1059,6 +1059,101 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa
     };
 }
 
+// xtc
+
+struct llama_sampler_xtc {
+    const float    probability;
+    const float    threshold;
+    const size_t   min_keep;
+
+    const uint32_t seed;
+    uint32_t       seed_cur;
+
+    std::mt19937   rng;
+};
+
+static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
+    return "xtc";
+}
+
+static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_xtc *) smpl->ctx;
+
+    if (ctx->probability <= 0.0f
+        || ctx->threshold > 0.5f
+        || cur_p->size < 2) {
+        return;
+    }
+
+    std::uniform_real_distribution<float> distribution(0.0f, 1.0f);
+    float chance = distribution(ctx->rng);
+    if (chance > ctx->probability) return;
+
+    // in case it's not sorted/recalculated yet
+    llama_sampler_softmax_impl(cur_p);
+
+    int pos_last = 0;
+
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        if (cur_p->data[i].p >= ctx->threshold) {
+            pos_last = i;
+        } else break;
+    }
+
+    if (cur_p->size - pos_last >= ctx->min_keep && pos_last > 0) {
+        cur_p->data += pos_last;
+        cur_p->size -= pos_last;
+    }
+}
+
+static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_xtc *) smpl->ctx;
+    auto * result = llama_sampler_init_xtc(ctx->probability, ctx->threshold, ctx->min_keep, ctx->seed);
+
+    // copy the state
+    {
+        auto * result_ctx = (llama_sampler_xtc *) result->ctx;
+
+        result_ctx->rng = ctx->rng;
+    }
+
+    return result;
+}
+
+static void llama_sampler_xtc_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_xtc *) smpl->ctx;
+}
+
+static void llama_sampler_xtc_reset(struct llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_xtc *) smpl->ctx;
+    ctx->seed_cur = get_rng_seed(ctx->seed);
+    ctx->rng.seed(ctx->seed_cur);
+}
+
+static struct llama_sampler_i llama_sampler_xtc_i = {
+    /* .name   = */ llama_sampler_xtc_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sample_xtc_apply,
+    /* .reset  = */ llama_sampler_xtc_reset,
+    /* .clone  = */ llama_sampler_xtc_clone,
+    /* .free   = */ llama_sampler_xtc_free,
+};
+
+struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
+    auto seed_cur = get_rng_seed(seed);
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_xtc_i,
+        /* .ctx   = */ new llama_sampler_xtc {
+            /* .probability   = */ p,
+            /* .threshold     = */ t,
+            /* .min_keep      = */ min_keep,
+            /* .seed          = */ seed,
+            /* .seed_cur      = */ seed_cur,
+            /* .rng           = */ std::mt19937(seed_cur),
+        },
+    };
+}
+
 // mirostat
 
 struct llama_sampler_mirostat {
index 6e021c4c70357d123783510b318100c91e901e90..1372bdf13f2f608529c1b0b67e3ffa1c9410a65c 100644 (file)
@@ -111,6 +111,28 @@ static void test_min_p(const std::vector<float> & probs, const std::vector<float
     }
 }
 
+static void test_xtc(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p, float t) {
+    const size_t n_vocab = probs.size();
+
+    std::vector<llama_token_data> cur;
+    cur.reserve(n_vocab);
+    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
+        const float logit = logf(probs[token_id]);
+        cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
+    }
+
+    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+    APPLY(llama_sampler_init_softmax(), &cur_p);
+    DUMP(&cur_p);
+    APPLY(llama_sampler_init_xtc(p, t, 0, 0), &cur_p);
+    DUMP(&cur_p);
+
+    GGML_ASSERT(cur_p.size == expected_probs.size());
+    for (size_t i = 0; i < cur_p.size; i++) {
+        GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-5);
+    }
+}
+
 static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
     const size_t n_vocab = probs.size();
 
@@ -263,7 +285,7 @@ static void bench(llama_sampler * cnstr, const char * cnstr_name, const std::vec
     }
     const int64_t t_end = ggml_time_us();
     llama_sampler_free(cnstr);
-    printf("%-42s: %8.3f us/iter\n", cnstr_name, (t_end - t_start) / (float)n_iter);
+    printf("%-43s: %8.3f us/iter\n", cnstr_name, (t_end - t_start) / (float)n_iter);
 }
 
 #define BENCH(__cnstr, __data, __n_iter) bench((__cnstr), #__cnstr, (__data), (__n_iter))
@@ -279,12 +301,13 @@ static void test_perf() {
         data.emplace_back(llama_token_data{i, logit, 0.0f});
     }
 
-    BENCH(llama_sampler_init_top_k    (40),      data, 32);
-    BENCH(llama_sampler_init_top_p    (0.8f, 1), data, 32);
-    BENCH(llama_sampler_init_min_p    (0.2f, 1), data, 32);
-    BENCH(llama_sampler_init_tail_free(0.5f, 1), data, 32);
-    BENCH(llama_sampler_init_typical  (0.5f, 1), data, 32);
-    BENCH(llama_sampler_init_softmax  (),        data, 32);
+    BENCH(llama_sampler_init_top_k    (40),                     data, 32);
+    BENCH(llama_sampler_init_top_p    (0.8f, 1),                data, 32);
+    BENCH(llama_sampler_init_min_p    (0.2f, 1),                data, 32);
+    BENCH(llama_sampler_init_tail_free(0.5f, 1),                data, 32);
+    BENCH(llama_sampler_init_typical  (0.5f, 1),                data, 32);
+    BENCH(llama_sampler_init_xtc      (1.0f, 0.1f, 1, 1),       data, 32);
+    BENCH(llama_sampler_init_softmax  (),                       data, 32);
 }
 
 int main(void) {
@@ -309,6 +332,14 @@ int main(void) {
     test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f},                                  0.76f);
     test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f},                                  1.00f);
 
+    printf("XTC should:\n");
+    test_xtc({0.4f, 0.3f, 0.2f, 0.1f},   {0.1f},                                0.99f, 0.09f);
+    test_xtc({0.4f, 0.3f, 0.2f, 0.1f},   {0.2f, 0.1f},                          0.99f, 0.19f);
+    test_xtc({0.4f, 0.3f, 0.2f, 0.1f},   {0.3f, 0.2f, 0.1f},                    0.99f, 0.29f);
+
+    printf("XTC should not:\n");
+    test_xtc({0.4f, 0.3f, 0.2f, 0.1f},   {0.4f, 0.3f, 0.2f, 0.1f},              0.99f, 0.39f);
+
     test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f);
     test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.75f);
     test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.99f);