]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
sampling : Integrate Top-nσ into main sampling chain (and add it to the server) ...
authoroobabooga <redacted>
Mon, 5 May 2025 20:12:19 +0000 (17:12 -0300)
committerGitHub <redacted>
Mon, 5 May 2025 20:12:19 +0000 (22:12 +0200)
* sampling: add Top-nσ sampler to `llama-server` and sampler ordering

* revert: sampler ordering

* revert: VS' crappy auto-formatting

* revert: VS' crappy auto-formatting pt.2

* revert: my crappy eye sight...

* sampling: add XTC to Top-nσ sampler chain

* sampling: add Dyna. Temp. to Top-nσ sampler chain

* sampling: actually remove Top-nσ from sampler(oops)

* Integrate top_n_sigma into main sampler chain

* Define COMMON_SAMPLER_TYPE_TOP_N_SIGMA

* Formatting

* Lint

* Exit early in the sampler if nsigma < 0

---------

Co-authored-by: CasualAutopsy <redacted>
common/common.h
common/sampling.cpp
src/llama-sampling.cpp
tools/server/server.cpp

index 416939da9a602a175e541f93b2b344acce9a9a1d..dfd6e20933f15995384cbcb8c4343f1f04313a7b 100644 (file)
@@ -96,6 +96,7 @@ enum common_sampler_type {
     COMMON_SAMPLER_TYPE_XTC         = 8,
     COMMON_SAMPLER_TYPE_INFILL      = 9,
     COMMON_SAMPLER_TYPE_PENALTIES   = 10,
+    COMMON_SAMPLER_TYPE_TOP_N_SIGMA = 11,
 };
 
 // dimensionality reduction methods, used by cvector-generator
@@ -161,6 +162,7 @@ struct common_params_sampling {
     std::vector<enum common_sampler_type> samplers = {
         COMMON_SAMPLER_TYPE_PENALTIES,
         COMMON_SAMPLER_TYPE_DRY,
+        COMMON_SAMPLER_TYPE_TOP_N_SIGMA,
         COMMON_SAMPLER_TYPE_TOP_K,
         COMMON_SAMPLER_TYPE_TYPICAL_P,
         COMMON_SAMPLER_TYPE_TOP_P,
index 1735b650183c83fffa1e7414e0c7b07da911c0b9..bbaec5b80adebd3fd0a2158d0a0c2e16a5da6037 100644 (file)
@@ -229,51 +229,48 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                 params.logit_bias.data()));
 
     if (params.mirostat == 0) {
-        if (params.top_n_sigma >= 0) {
-            llama_sampler_chain_add(result->chain, llama_sampler_init_top_k        (params.top_k));
-            llama_sampler_chain_add(result->chain, llama_sampler_init_temp         (params.temp));
-            llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma  (params.top_n_sigma));
-        } else {
-            for (const auto & cnstr : params.samplers) {
-                switch (cnstr) {
-                    case COMMON_SAMPLER_TYPE_DRY:
-                        {
-                            std::vector<const char *> c_breakers;
-                            c_breakers.reserve(params.dry_sequence_breakers.size());
-                            for (const auto & str : params.dry_sequence_breakers) {
-                                c_breakers.push_back(str.c_str());
-                            }
-
-                            llama_sampler_chain_add(result->chain, llama_sampler_init_dry      (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
+        for (const auto & cnstr : params.samplers) {
+            switch (cnstr) {
+                case COMMON_SAMPLER_TYPE_DRY:
+                    {
+                        std::vector<const char *> c_breakers;
+                        c_breakers.reserve(params.dry_sequence_breakers.size());
+                        for (const auto & str : params.dry_sequence_breakers) {
+                            c_breakers.push_back(str.c_str());
                         }
-                        break;
-                    case COMMON_SAMPLER_TYPE_TOP_K:
-                        llama_sampler_chain_add(result->chain, llama_sampler_init_top_k    (params.top_k));
-                        break;
-                    case COMMON_SAMPLER_TYPE_TOP_P:
-                        llama_sampler_chain_add(result->chain, llama_sampler_init_top_p    (params.top_p, params.min_keep));
-                        break;
-                    case COMMON_SAMPLER_TYPE_MIN_P:
-                        llama_sampler_chain_add(result->chain, llama_sampler_init_min_p    (params.min_p, params.min_keep));
-                        break;
-                    case COMMON_SAMPLER_TYPE_XTC:
-                        llama_sampler_chain_add(result->chain, llama_sampler_init_xtc      (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
-                        break;
-                    case COMMON_SAMPLER_TYPE_TYPICAL_P:
-                        llama_sampler_chain_add(result->chain, llama_sampler_init_typical  (params.typ_p, params.min_keep));
-                        break;
-                    case COMMON_SAMPLER_TYPE_TEMPERATURE:
-                        llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
-                        break;
-                    case COMMON_SAMPLER_TYPE_INFILL:
-                        llama_sampler_chain_add(result->chain, llama_sampler_init_infill   (vocab));
-                        break;
-                    case COMMON_SAMPLER_TYPE_PENALTIES:
-                        llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
-                        break;
-                    default:
-                        GGML_ASSERT(false && "unknown sampler type");
-                }
+
+                        llama_sampler_chain_add(result->chain, llama_sampler_init_dry      (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
+                    }
+                    break;
+                case COMMON_SAMPLER_TYPE_TOP_K:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_k       (params.top_k));
+                    break;
+                case COMMON_SAMPLER_TYPE_TOP_P:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_p       (params.top_p, params.min_keep));
+                    break;
+                case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma));
+                    break;
+                case COMMON_SAMPLER_TYPE_MIN_P:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_min_p       (params.min_p, params.min_keep));
+                    break;
+                case COMMON_SAMPLER_TYPE_XTC:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_xtc         (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
+                    break;
+                case COMMON_SAMPLER_TYPE_TYPICAL_P:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_typical     (params.typ_p, params.min_keep));
+                    break;
+                case COMMON_SAMPLER_TYPE_TEMPERATURE:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext    (params.temp, params.dynatemp_range, params.dynatemp_exponent));
+                    break;
+                case COMMON_SAMPLER_TYPE_INFILL:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_infill      (vocab));
+                    break;
+                case COMMON_SAMPLER_TYPE_PENALTIES:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_penalties   (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
+                    break;
+                default:
+                    GGML_ASSERT(false && "unknown sampler type");
             }
         }
         llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
@@ -475,6 +472,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
         case COMMON_SAMPLER_TYPE_TOP_K:       return 'k';
         case COMMON_SAMPLER_TYPE_TYPICAL_P:   return 'y';
         case COMMON_SAMPLER_TYPE_TOP_P:       return 'p';
+        case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return 's';
         case COMMON_SAMPLER_TYPE_MIN_P:       return 'm';
         case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
         case COMMON_SAMPLER_TYPE_XTC:         return 'x';
@@ -490,6 +488,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
         case COMMON_SAMPLER_TYPE_TOP_K:       return "top_k";
         case COMMON_SAMPLER_TYPE_TYPICAL_P:   return "typ_p";
         case COMMON_SAMPLER_TYPE_TOP_P:       return "top_p";
+        case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return "top_n_sigma";
         case COMMON_SAMPLER_TYPE_MIN_P:       return "min_p";
         case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
         case COMMON_SAMPLER_TYPE_XTC:         return "xtc";
@@ -504,6 +503,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
         { "dry",         COMMON_SAMPLER_TYPE_DRY },
         { "top_k",       COMMON_SAMPLER_TYPE_TOP_K },
         { "top_p",       COMMON_SAMPLER_TYPE_TOP_P },
+        { "top_n_sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
         { "typ_p",       COMMON_SAMPLER_TYPE_TYPICAL_P },
         { "min_p",       COMMON_SAMPLER_TYPE_MIN_P },
         { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
@@ -517,6 +517,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
     std::unordered_map<std::string, common_sampler_type> sampler_alt_name_map {
         { "top-k",       COMMON_SAMPLER_TYPE_TOP_K },
         { "top-p",       COMMON_SAMPLER_TYPE_TOP_P },
+        { "top-n-sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
         { "nucleus",     COMMON_SAMPLER_TYPE_TOP_P },
         { "typical-p",   COMMON_SAMPLER_TYPE_TYPICAL_P },
         { "typical",     COMMON_SAMPLER_TYPE_TYPICAL_P },
@@ -552,6 +553,7 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K),       COMMON_SAMPLER_TYPE_TOP_K },
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P),   COMMON_SAMPLER_TYPE_TYPICAL_P },
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P),       COMMON_SAMPLER_TYPE_TOP_P },
+        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_N_SIGMA), COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P),       COMMON_SAMPLER_TYPE_MIN_P },
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC),         COMMON_SAMPLER_TYPE_XTC },
index c0a5f9340d5851beade2deb00ef41bea2e38bc7e..0c9c6a3102929a72109172379ab86a4ddd58e56d 100644 (file)
@@ -1750,6 +1750,10 @@ static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler *
 static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
 
+    if (ctx->n < 0.0f) {
+        return;
+    }
+
     // find max logit and calculate mean
     float max = cur_p->data[0].logit;
     float logits_sum = 0;
index c580ec123299cebcbd61db562fa9f76a2f5044e5..e0e99eafcdf5ce83c91f61bcfee6e5fef3a7dd54 100644 (file)
@@ -146,6 +146,7 @@ struct slot_params {
             {"top_k",                     sampling.top_k},
             {"top_p",                     sampling.top_p},
             {"min_p",                     sampling.min_p},
+            {"top_n_sigma",               sampling.top_n_sigma},
             {"xtc_probability",           sampling.xtc_probability},
             {"xtc_threshold",             sampling.xtc_threshold},
             {"typical_p",                 sampling.typ_p},
@@ -248,6 +249,7 @@ struct server_task {
         params.sampling.top_k              = json_value(data, "top_k",              defaults.sampling.top_k);
         params.sampling.top_p              = json_value(data, "top_p",              defaults.sampling.top_p);
         params.sampling.min_p              = json_value(data, "min_p",              defaults.sampling.min_p);
+        params.sampling.top_n_sigma        = json_value(data, "top_n_sigma",        defaults.sampling.top_n_sigma);
         params.sampling.xtc_probability    = json_value(data, "xtc_probability",    defaults.sampling.xtc_probability);
         params.sampling.xtc_threshold      = json_value(data, "xtc_threshold",      defaults.sampling.xtc_threshold);
         params.sampling.typ_p              = json_value(data, "typical_p",          defaults.sampling.typ_p);