]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
sampling : refactor + optimize penalties sampler (#10803)
authorGeorgi Gerganov <redacted>
Mon, 16 Dec 2024 10:31:14 +0000 (12:31 +0200)
committerGitHub <redacted>
Mon, 16 Dec 2024 10:31:14 +0000 (12:31 +0200)
* sampling : refactor + optimize penalties sampler

ggml-ci

* common : apply ignore_eos as logit bias

ggml-ci

* batched : remove penalties sampler

* params : allow penalty_last_n == -1 to be equal to context size

ggml-ci

* common : by default, move the penalties at the end of the sampling chain

ggml-ci

* common : ignore all EOG tokens

Co-authored-by: Diego Devesa <redacted>
* common : move back the penalties at the front of the sampling chain

ggml-ci

* readme : restore hint about --ignore-eos flag [no ci]

* llama : minor

ggml-ci

* webui : update

---------

Co-authored-by: Diego Devesa <redacted>
17 files changed:
common/arg.cpp
common/common.cpp
common/common.h
common/sampling.cpp
examples/batched/batched.cpp
examples/main/README.md
examples/server/README.md
examples/server/public/index.html.gz
examples/server/public_legacy/index-new.html
examples/server/public_legacy/index.html
examples/server/server.cpp
examples/server/themes/buttons-top/index.html
examples/server/themes/wild/index.html
examples/server/webui/src/main.js
include/llama.h
src/llama-sampling.cpp
tests/test-sampling.cpp

index 39bc874c818bbab74285c2858c25bd471157e3d7..3d55289c331922d39560c20f449b537d31b1198b 100644 (file)
@@ -855,13 +855,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.sampling.ignore_eos = true;
         }
     ).set_sparam());
-    add_opt(common_arg(
-        {"--penalize-nl"},
-        string_format("penalize newline tokens (default: %s)", params.sampling.penalize_nl ? "true" : "false"),
-        [](common_params & params) {
-            params.sampling.penalize_nl = true;
-        }
-    ).set_sparam());
     add_opt(common_arg(
         {"--temp"}, "N",
         string_format("temperature (default: %.1f)", (double)params.sampling.temp),
@@ -916,6 +909,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         {"--repeat-last-n"}, "N",
         string_format("last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)", params.sampling.penalty_last_n),
         [](common_params & params, int value) {
+            if (value < -1) {
+                throw std::runtime_error(string_format("error: invalid repeat-last-n = %d\n", value));
+            }
             params.sampling.penalty_last_n = value;
             params.sampling.n_prev = std::max(params.sampling.n_prev, params.sampling.penalty_last_n);
         }
@@ -970,6 +966,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         {"--dry-penalty-last-n"}, "N",
         string_format("set DRY penalty for the last n tokens (default: %d, 0 = disable, -1 = context size)", params.sampling.dry_penalty_last_n),
         [](common_params & params, int value) {
+            if (value < -1) {
+                throw std::runtime_error(string_format("error: invalid dry-penalty-last-n = %d\n", value));
+            }
             params.sampling.dry_penalty_last_n = value;
         }
     ).set_sparam());
index 3adfb0329377f59e285b45ef061ceb7188449780..c0c98232ed3bba1720ec76957f3e3b44f9d9b8e9 100644 (file)
@@ -940,6 +940,25 @@ struct common_init_result common_init_from_params(common_params & params) {
         params.sampling.ignore_eos = false;
     }
 
+    if (params.sampling.ignore_eos) {
+        for (llama_token i = 0; i < llama_n_vocab(model); i++) {
+            if (llama_token_is_eog(model, i)) {
+                LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
+                params.sampling.logit_bias.push_back({i, -INFINITY});
+            }
+        }
+    }
+
+    if (params.sampling.penalty_last_n == -1) {
+        LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
+        params.sampling.penalty_last_n = llama_n_ctx(lctx);
+    }
+
+    if (params.sampling.dry_penalty_last_n == -1) {
+        LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
+        params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
+    }
+
     if (params.warmup) {
         LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
 
index 9e47b70a4e9069b7db0322cf01bad40d570fc17f..5f556c24d933c61e8666207a5ca9c02389be2158 100644 (file)
@@ -95,6 +95,7 @@ enum common_sampler_type {
     COMMON_SAMPLER_TYPE_TEMPERATURE = 7,
     COMMON_SAMPLER_TYPE_XTC         = 8,
     COMMON_SAMPLER_TYPE_INFILL      = 9,
+    COMMON_SAMPLER_TYPE_PENALTIES   = 10,
 };
 
 // dimensionality reduction methods, used by cvector-generator
@@ -130,7 +131,6 @@ struct common_params_sampling {
     int32_t mirostat           = 0;     // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
     float   mirostat_tau       = 5.00f; // target entropy
     float   mirostat_eta       = 0.10f; // learning rate
-    bool    penalize_nl        = false; // consider newlines as a repeatable token
     bool    ignore_eos         = false;
     bool    no_perf            = false; // disable performance metrics
     bool    timing_per_token   = false;
@@ -139,6 +139,7 @@ struct common_params_sampling {
 
 
     std::vector<enum common_sampler_type> samplers = {
+        COMMON_SAMPLER_TYPE_PENALTIES,
         COMMON_SAMPLER_TYPE_DRY,
         COMMON_SAMPLER_TYPE_TOP_K,
         COMMON_SAMPLER_TYPE_TYPICAL_P,
@@ -193,11 +194,13 @@ struct common_params {
     float   defrag_thold          =  0.1f; // KV cache defragmentation threshold
 
     // offload params
-    std::vector<ggml_backend_dev_t> devices;         // devices to use for offloading
-    int32_t n_gpu_layers                    =    -1; // number of layers to store in VRAM (-1 - use default)
-    int32_t main_gpu                        =     0; // the GPU that is used for scratch and small tensors
-    float   tensor_split[128]               =   {0}; // how split tensors should be distributed across GPUs
-    enum llama_split_mode        split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
+    std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
+
+    int32_t n_gpu_layers      = -1;  // number of layers to store in VRAM (-1 - use default)
+    int32_t main_gpu          = 0;   // the GPU that is used for scratch and small tensors
+    float   tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
+
+    enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
 
     struct cpu_params cpuparams;
     struct cpu_params cpuparams_batch;
index 0c4699a89c8b24a10fe8638bd6aba9cde2191187..e83a971c74d5258e2eb6572e71892d066f440e55 100644 (file)
@@ -161,32 +161,20 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                 params.logit_bias.size(),
                 params.logit_bias.data()));
 
-    llama_sampler_chain_add(result->chain,
-            llama_sampler_init_penalties(
-                llama_n_vocab  (model),
-                llama_token_eos(model),
-                llama_token_nl (model),
-                params.penalty_last_n,
-                params.penalty_repeat,
-                params.penalty_freq,
-                params.penalty_present,
-                params.penalize_nl,
-                params.ignore_eos));
-
     if (params.mirostat == 0) {
         for (const auto & cnstr : params.samplers) {
             switch (cnstr) {
-                    case COMMON_SAMPLER_TYPE_DRY:
+                case COMMON_SAMPLER_TYPE_DRY:
                     {
-                        std::vector<const char*> c_breakers;
+                        std::vector<const char *> c_breakers;
                         c_breakers.reserve(params.dry_sequence_breakers.size());
-                        for (const auto& str : params.dry_sequence_breakers) {
+                        for (const auto & str : params.dry_sequence_breakers) {
                             c_breakers.push_back(str.c_str());
                         }
 
                         llama_sampler_chain_add(result->chain, llama_sampler_init_dry      (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
                     }
-                        break;
+                    break;
                 case COMMON_SAMPLER_TYPE_TOP_K:
                     llama_sampler_chain_add(result->chain, llama_sampler_init_top_k    (params.top_k));
                     break;
@@ -208,6 +196,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                 case COMMON_SAMPLER_TYPE_INFILL:
                     llama_sampler_chain_add(result->chain, llama_sampler_init_infill   (model));
                     break;
+                case COMMON_SAMPLER_TYPE_PENALTIES:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
+                    break;
                 default:
                     GGML_ASSERT(false && "unknown sampler type");
             }
@@ -415,6 +406,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
         case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
         case COMMON_SAMPLER_TYPE_XTC:         return 'x';
         case COMMON_SAMPLER_TYPE_INFILL:      return 'i';
+        case COMMON_SAMPLER_TYPE_PENALTIES:   return 'e';
         default : return '?';
     }
 }
@@ -429,6 +421,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
         case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
         case COMMON_SAMPLER_TYPE_XTC:         return "xtc";
         case COMMON_SAMPLER_TYPE_INFILL:      return "infill";
+        case COMMON_SAMPLER_TYPE_PENALTIES:   return "penalties";
         default : return "";
     }
 }
@@ -443,6 +436,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
         { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
         { "xtc",         COMMON_SAMPLER_TYPE_XTC },
         { "infill",      COMMON_SAMPLER_TYPE_INFILL },
+        { "penalties",   COMMON_SAMPLER_TYPE_PENALTIES },
     };
 
     // since samplers names are written multiple ways
@@ -489,6 +483,7 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC),         COMMON_SAMPLER_TYPE_XTC },
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL),      COMMON_SAMPLER_TYPE_INFILL },
+        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES),   COMMON_SAMPLER_TYPE_PENALTIES },
     };
 
     std::vector<common_sampler_type> samplers;
index ba219cd4b32ae88112824cebb16a6fbf708807d7..e2e01f2d598cde974b516b16d71445fef54f4656 100644 (file)
@@ -65,6 +65,7 @@ int main(int argc, char ** argv) {
     llama_context * ctx = llama_new_context_with_model(model, ctx_params);
 
     auto sparams = llama_sampler_chain_default_params();
+    sparams.no_perf = false;
 
     llama_sampler * smpl = llama_sampler_chain_init(sparams);
 
index 7787f7b11b81f4cc98cc610748dc93cf63a201f6..17d80a622a8bbd3db6dbeb85247c861886147d1c 100644 (file)
@@ -177,16 +177,11 @@ Example usage: `--temp 0`
 
 -   `--repeat-penalty N`: Control the repetition of token sequences in the generated text default: 1.0, 1.0 = disabled).
 -   `--repeat-last-n N`: Last n tokens to consider for penalizing repetition (default: 64, 0 = disabled, -1 = ctx-size).
--   `--no-penalize-nl`: Disable penalization for newline tokens when applying the repeat penalty.
 
 The `repeat-penalty` option helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. The default value is 1.
 
 The `repeat-last-n` option controls the number of tokens in the history to consider for penalizing repetition. A larger value will look further back in the generated text to prevent repetitions, while a smaller value will only consider recent tokens. A value of 0 disables the penalty, and a value of -1 sets the number of tokens considered equal to the context size (`ctx-size`).
 
-Use the `--no-penalize-nl` option to disable newline penalization when applying the repeat penalty. This option is particularly useful for generating chat conversations, dialogues, code, poetry, or any text where newline tokens play a significant role in structure and formatting. Disabling newline penalization helps maintain the natural flow and intended formatting in these specific use cases.
-
-Example usage: `--repeat-penalty 1.15 --repeat-last-n 128 --no-penalize-nl`
-
 ### DRY Repetition Penalty
 
 DRY (Don't Repeat Yourself) sampling is an effective technique for reducing repetition in generated text even across long contexts by penalizing tokens based on their recent usage patterns (original [PR link](https://github.com/oobabooga/text-generation-webui/pull/5677)).
index a9443e56b82cc8eccf717f1b1b9c49b851dfc869..63a7bf43a920d030ac662adc45af2b0e130a2477 100644 (file)
@@ -104,7 +104,6 @@ The project is under active development, and we are [looking for feedback and co
 | `-s, --seed SEED` | RNG seed (default: -1, use random seed for -1) |
 | `--sampling-seq SEQUENCE` | simplified sequence for samplers that will be used (default: dkypmxt) |
 | `--ignore-eos` | ignore end of stream token and continue generating (implies --logit-bias EOS-inf) |
-| `--penalize-nl` | penalize newline tokens (default: false) |
 | `--temp N` | temperature (default: 0.8) |
 | `--top-k N` | top-k sampling (default: 40, 0 = disabled) |
 | `--top-p N` | top-p sampling (default: 0.9, 1.0 = disabled) |
@@ -393,8 +392,6 @@ These words will not be included in the completion, so make sure to add them to
 
 `repeat_last_n`: Last n tokens to consider for penalizing repetition. Default: `64`, where `0` is disabled and `-1` is ctx-size.
 
-`penalize_nl`: Penalize newline tokens when applying the repeat penalty. Default: `true`
-
 `presence_penalty`: Repeat alpha presence penalty. Default: `0.0`, which is disabled.
 
 `frequency_penalty`: Repeat alpha frequency penalty. Default: `0.0`, which is disabled.
@@ -655,7 +652,6 @@ This endpoint is public (no API key check). By default, it is read-only. To make
       "mirostat": 0,
       "mirostat_tau": 5.0,
       "mirostat_eta": 0.10000000149011612,
-      "penalize_nl": false,
       "stop": [],
       "max_tokens": -1,
       "n_keep": 0,
@@ -845,7 +841,6 @@ Example:
       "mirostat": 0,
       "mirostat_tau": 5.0,
       "mirostat_eta": 0.10000000149011612,
-      "penalize_nl": false,
       "stop": [],
       "max_tokens": -1,
       "n_keep": 0,
index d7816b04fad757c640aebf53123c1ef7e26dc9d4..53d33feb344cd92d89ab2ca1f59e77722d7e6cfa 100644 (file)
Binary files a/examples/server/public/index.html.gz and b/examples/server/public/index.html.gz differ
index 8bfa380e57388c82981be36dfa2ed4ede0f0fdec..cbfbbdf2806fa1a08307b90f8bbf7e21df3a3885 100644 (file)
@@ -39,7 +39,6 @@
       temperature: 0.8, // adapt all following parameters to optimized min-p requierements. If for non-english, set to 0.6 or lower
       repeat_last_n: 0, // 0 = disable penalty, -1 = context size
       repeat_penalty: 1.0, // 1.0 = disabled
-      penalize_nl: false, // true only useful for infinite completion
       dry_multiplier: 0.0, // 0.0 = disabled, 0.8 works well
       dry_base: 1.75,     // 0.0 = disabled
       dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well
index a95f5c6df877577d0db606e19696aebd63e45177..75f39330a789d9d325628b7763d1887abf0a1601 100644 (file)
       temperature: 0.7,
       repeat_last_n: 256, // 0 = disable penalty, -1 = context size
       repeat_penalty: 1.18, // 1.0 = disabled
-      penalize_nl: false,
       dry_multiplier: 0.0, // 0.0 = disabled, 0.8 works well
       dry_base: 1.75,     // 0.0 = disabled
       dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well
             ${FloatField({ label: "Temperature", max: 2.0, min: 0.0, name: "temperature", step: 0.01, value: params.value.temperature })}
             ${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })}
             ${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })}
-            ${BoolField({ label: "Penalize repetition of newlines", name: "penalize_nl", value: params.value.penalize_nl })}
             ${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })}
             ${FloatField({ label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })}
             ${FloatField({ label: "Min-P sampling", max: 1.0, min: 0.0, name: "min_p", step: 0.01, value: params.value.min_p })}
index aa8c312590029460ffdcd2d41c5eea6f435330fb..bc0d042ae924735c6e620ad86d3993813bc5b5a8 100644 (file)
@@ -135,7 +135,6 @@ struct slot_params {
             {"mirostat",                  sampling.mirostat},
             {"mirostat_tau",              sampling.mirostat_tau},
             {"mirostat_eta",              sampling.mirostat_eta},
-            {"penalize_nl",               sampling.penalize_nl},
             {"stop",                      antiprompt},
             {"max_tokens",                n_predict}, // User configured n_predict
             {"n_keep",                    n_keep},
@@ -184,6 +183,7 @@ struct server_task {
 
     static slot_params params_from_json_cmpl(
             const llama_model * model,
+            const llama_context * ctx,
             const common_params & params_base,
             const json & data) {
         slot_params params;
@@ -226,7 +226,6 @@ struct server_task {
         params.sampling.mirostat           = json_value(data, "mirostat",           defaults.sampling.mirostat);
         params.sampling.mirostat_tau       = json_value(data, "mirostat_tau",       defaults.sampling.mirostat_tau);
         params.sampling.mirostat_eta       = json_value(data, "mirostat_eta",       defaults.sampling.mirostat_eta);
-        params.sampling.penalize_nl        = json_value(data, "penalize_nl",        defaults.sampling.penalize_nl);
         params.sampling.seed               = json_value(data, "seed",               defaults.sampling.seed);
         params.sampling.n_probs            = json_value(data, "n_probs",            defaults.sampling.n_probs);
         params.sampling.min_keep           = json_value(data, "min_keep",           defaults.sampling.min_keep);
@@ -239,8 +238,27 @@ struct server_task {
         params.speculative.n_min = std::max(params.speculative.n_min, 2);
         params.speculative.n_max = std::max(params.speculative.n_max, 0);
 
+        // TODO: add more sanity checks for the input parameters
+
+        if (params.sampling.penalty_last_n < -1) {
+            throw std::runtime_error("Error: repeat_last_n must be >= -1");
+        }
+
+        if (params.sampling.dry_penalty_last_n < -1) {
+            throw std::runtime_error("Error: dry_penalty_last_n must be >= -1");
+        }
+
+        if (params.sampling.penalty_last_n == -1) {
+            // note: should be the slot's context and not the full context, but it's ok
+            params.sampling.penalty_last_n = llama_n_ctx(ctx);
+        }
+
+        if (params.sampling.dry_penalty_last_n == -1) {
+            params.sampling.dry_penalty_last_n = llama_n_ctx(ctx);
+        }
+
         if (params.sampling.dry_base < 1.0f) {
-           params.sampling.dry_base = defaults.sampling.dry_base;
+            params.sampling.dry_base = defaults.sampling.dry_base;
         }
 
         // sequence breakers for DRY
@@ -1469,7 +1487,7 @@ struct server_context {
         n_ctx = llama_n_ctx(ctx);
 
         add_bos_token = llama_add_bos_token(model);
-        has_eos_token = !llama_add_eos_token(model);
+        has_eos_token = llama_token_eos(model) != LLAMA_TOKEN_NULL;
 
         if (!params_base.speculative.model.empty()) {
             SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str());
@@ -3381,7 +3399,7 @@ int main(int argc, char ** argv) {
                 task.index = i;
 
                 task.prompt_tokens    = std::move(tokenized_prompts[i]);
-                task.params           = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.params_base, data);
+                task.params           = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.ctx, ctx_server.params_base, data);
                 task.id_selected_slot = json_value(data, "id_slot", -1);
 
                 // OAI-compat
index 2797c37c96456239906caf6b01d8f0129ac9c867..3fb88fcc88d319e16908314a26e19f641150d65f 100644 (file)
       temperature: 0.7,
       repeat_last_n: 256, // 0 = disable penalty, -1 = context size
       repeat_penalty: 1.18, // 1.0 = disabled
-      penalize_nl: false,
       top_k: 40, // <= 0 to use vocab size
       top_p: 0.95, // 1.0 = disabled
       min_p: 0.05, // 0 = disabled
             ${FloatField({ label: "Temperature", max: 2.0, min: 0.0, name: "temperature", step: 0.01, value: params.value.temperature })}
             ${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })}
             ${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })}
-            ${BoolField({ label: "Penalize repetition of newlines", name: "penalize_nl", value: params.value.penalize_nl })}
             ${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })}
             ${FloatField({ label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })}
             ${FloatField({ label: "Min-P sampling", max: 1.0, min: 0.0, name: "min_p", step: 0.01, value: params.value.min_p })}
index dbe23c40241556f069258c3ec046bcd37d89cbb6..73f36d4b29fdd884baf0a2e91c193a7283449ef4 100644 (file)
       temperature: 0.7,
       repeat_last_n: 256, // 0 = disable penalty, -1 = context size
       repeat_penalty: 1.18, // 1.0 = disabled
-      penalize_nl: false,
       top_k: 40, // <= 0 to use vocab size
       top_p: 0.95, // 1.0 = disabled
       min_p: 0.05, // 0 = disabled
             ${FloatField({ label: "Temperature", max: 2.0, min: 0.0, name: "temperature", step: 0.01, value: params.value.temperature })}
             ${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })}
             ${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })}
-            ${BoolField({ label: "Penalize repetition of newlines", name: "penalize_nl", value: params.value.penalize_nl })}
             ${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })}
             ${FloatField({ label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })}
             ${FloatField({ label: "Min-P sampling", max: 1.0, min: 0.0, name: "min_p", step: 0.01, value: params.value.min_p })}
index a16df1cccbfc379232d7d158cef5b55b4a45a7ee..441fd73519a6b888545ce9e6147184a83c26f0c6 100644 (file)
@@ -33,7 +33,7 @@ const CONFIG_DEFAULT = {
   systemMessage: 'You are a helpful assistant.',
   showTokensPerSecond: false,
   // make sure these default values are in sync with `common.h`
-  samplers: 'dkypmxt',
+  samplers: 'edkypmxt',
   temperature: 0.8,
   dynatemp_range: 0.0,
   dynatemp_exponent: 1.0,
index c67988a33a78a480dd5a4ef168979ee2df058313..efbb27d21523acbfa8a3253998f05aa062becca1 100644 (file)
@@ -1139,16 +1139,12 @@ extern "C" {
                           const char * grammar_str,
                           const char * grammar_root);
 
+    /// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first.
     LLAMA_API struct llama_sampler * llama_sampler_init_penalties(
-                             int32_t   n_vocab,         // llama_n_vocab()
-                         llama_token   special_eos_id,  // llama_token_eos()
-                         llama_token   linefeed_id,     // llama_token_nl()
-                             int32_t   penalty_last_n,  // last n tokens to penalize (0 = disable penalty, -1 = context size)
-                               float   penalty_repeat,  // 1.0 = disabled
-                               float   penalty_freq,    // 0.0 = disabled
-                               float   penalty_present, // 0.0 = disabled
-                                bool   penalize_nl,     // consider newlines as a repeatable token
-                                bool   ignore_eos);     // ignore the end-of-sequence token
+                             int32_t   penalty_last_n,   // last n tokens to penalize (0 = disable penalty, -1 = context size)
+                               float   penalty_repeat,   // 1.0 = disabled
+                               float   penalty_freq,     // 0.0 = disabled
+                               float   penalty_present); // 0.0 = disabled
 
     ///  @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982
     LLAMA_API struct llama_sampler *    llama_sampler_init_dry(
index fd8ca8a9edfac5298c7550ce4f821c2bc78e194e..bebff77cfa09d9e9e892e909158fcdcfdcbe7312 100644 (file)
@@ -1396,19 +1396,15 @@ struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab
 // penalties
 
 struct llama_sampler_penalties {
-    const int32_t     n_vocab;
-    const llama_token special_eos_id;
-    const llama_token linefeed_id;
-
     const int32_t penalty_last_n;
     const float   penalty_repeat;
     const float   penalty_freq;
     const float   penalty_present;
 
-    const bool    penalize_nl;
-    const bool    ignore_eos;
-
     ring_buffer<llama_token> prev;
+
+    // a frequency map to count token occurrences
+    std::unordered_map<llama_token, int> token_count;
 };
 
 static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) {
@@ -1421,76 +1417,50 @@ static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_to
         return;
     }
 
-    ctx->prev.push_back(token);
-}
-
-static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
-    auto * ctx = (llama_sampler_penalties *) smpl->ctx;
+    ctx->token_count[token]++;
 
-    if (ctx->ignore_eos) {
-        assert(ctx->special_eos_id >= 0);
+    // if the ring buffer is full, remove the oldest token
+    if (ctx->prev.size() >= (size_t) ctx->penalty_last_n) {
+        const auto old = ctx->prev.front();
 
-        // optimistically check if the candidates are not yet sorted/shuffled/truncated
-        if (cur_p->size > (size_t) ctx->special_eos_id && cur_p->data[ctx->special_eos_id].id == ctx->special_eos_id) {
-            cur_p->data[ctx->special_eos_id].logit = -INFINITY;
-        } else {
-            // else, search for the special EOS token
-            for (size_t i = 0; i < cur_p->size; ++i) {
-                if (cur_p->data[i].id == ctx->special_eos_id) {
-                    cur_p->data[i].logit = -INFINITY;
-                    break;
-                }
-            }
+        ctx->token_count[old]--;
+        if (ctx->token_count[old] == 0) {
+            ctx->token_count.erase(old);
         }
     }
 
-    if ((ctx->penalty_last_n == 0) ||
-        (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
-        return;
-    }
-
-    bool nl_found = false;
-    size_t nl_idx = 0;
-    float nl_logit = -INFINITY;
-    if (!ctx->penalize_nl) {
-        assert(ctx->linefeed_id >= 0);
+    ctx->prev.push_back(token);
 
-        // optimistically check if the candidates are not yet sorted/shuffled/truncated
-        if (cur_p->size > (size_t) ctx->linefeed_id && cur_p->data[ctx->linefeed_id].id == ctx->linefeed_id) {
-            nl_found = true;
-            nl_idx = ctx->linefeed_id;
-            nl_logit = cur_p->data[ctx->linefeed_id].logit;
-        } else {
-            // else, search for the linefeed token
-            for (size_t i = 0; i < cur_p->size; ++i) {
-                if (cur_p->data[i].id == ctx->linefeed_id) {
-                    nl_found = true;
-                    nl_idx = i;
-                    nl_logit = cur_p->data[i].logit;
-                    break;
-                }
-            }
-        }
+#if 0
+    // sanity check
+    std::unordered_map<llama_token, int> tmp;
+    for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
+        tmp[ctx->prev.rat(i)]++;
     }
 
-    // Create a frequency map to count occurrences of each token in last_tokens
-    // TODO: optimize this by maintaining the token count in the sampler context
-    using llama_token_cnt = std::unordered_map<llama_token, int>;
-    llama_token_cnt token_count;
+    assert(ctx->token_count == tmp);
+#endif
+}
+
+static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_penalties *) smpl->ctx;
 
-    for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
-        token_count[ctx->prev.rat(i)]++;
+    if ((ctx->penalty_last_n == 0) ||
+        (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
+        return;
     }
 
     // Apply frequency and presence penalties to the cur_p
     for (size_t i = 0; i < cur_p->size; ++i) {
-        const auto token_iter = token_count.find(cur_p->data[i].id);
-        if (token_iter == token_count.end()) {
+        const auto token_iter = ctx->token_count.find(cur_p->data[i].id);
+        if (token_iter == ctx->token_count.end()) {
             continue;
         }
 
         const int count = token_iter->second;
 
+        assert(count > 0 && count <= ctx->penalty_last_n);
+
         // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
         // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
         if (cur_p->data[i].logit <= 0) {
@@ -1503,30 +1473,21 @@ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_tok
     }
 
     cur_p->sorted = false;
-
-    if (!ctx->penalize_nl && nl_found) {
-        // restore the logit of the newline token if it was penalized
-        cur_p->data[nl_idx].logit = nl_logit;
-    }
 }
 
 static void llama_sampler_penalties_reset(struct llama_sampler * smpl) {
     auto * ctx = (llama_sampler_penalties *) smpl->ctx;
     ctx->prev.clear();
+    ctx->token_count.clear();
 }
 
 static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) {
     const auto * ctx = (const llama_sampler_penalties *) smpl->ctx;
     auto * result = llama_sampler_init_penalties(
-            ctx->n_vocab,
-            ctx->special_eos_id,
-            ctx->linefeed_id,
             ctx->penalty_last_n,
             ctx->penalty_repeat,
             ctx->penalty_freq,
-            ctx->penalty_present,
-            ctx->penalize_nl,
-            ctx->ignore_eos);
+            ctx->penalty_present);
 
     // copy the state
     {
@@ -1552,38 +1513,21 @@ static struct llama_sampler_i llama_sampler_penalties_i = {
 };
 
 struct llama_sampler * llama_sampler_init_penalties(
-        int32_t n_vocab,
-        llama_token special_eos_id,
-        llama_token linefeed_id,
         int32_t penalty_last_n,
         float penalty_repeat,
         float penalty_freq,
-        float penalty_present,
-        bool penalize_nl,
-        bool ignore_eos) {
-    if (linefeed_id == LLAMA_TOKEN_NULL) {
-        penalize_nl = true;
-    }
-
-    if (special_eos_id == LLAMA_TOKEN_NULL) {
-        ignore_eos = false;
-    }
-
+        float penalty_present) {
     penalty_last_n = std::max(penalty_last_n, 0);
 
     return new llama_sampler {
         /* .iface = */ &llama_sampler_penalties_i,
         /* .ctx   = */ new llama_sampler_penalties {
-            /* .n_vocab         = */ n_vocab,
-            /* .special_eos_id  = */ special_eos_id,
-            /* .linefeed_id     = */ linefeed_id,
             /* .penalty_last_n  = */ penalty_last_n,
             /* .penalty_repeat  = */ penalty_repeat,
             /* .penalty_freq    = */ penalty_freq,
             /* .penalty_present = */ penalty_present,
-            /* .penalize_nl     = */ penalize_nl,
-            /* .ignore_eos      = */ ignore_eos,
             /* .prev            = */ ring_buffer<llama_token>(penalty_last_n),
+            /* .token_count     = */ {},
         },
     };
 }
@@ -1611,7 +1555,8 @@ static void get_overlapping_token_sequences(const llama_vocab & vocab, const std
         if (word.find(str) != std::string::npos) {
             token_sequences.emplace(token_id, std::vector<llama_token>());
         } else {
-            size_t word_len = word.size(), str_len = str.size();
+            size_t word_len = word.size();
+            size_t str_len = str.size();
             size_t pos = -1;
             while ((pos = word.find(str[0], pos + 1)) != std::string::npos) {
                 bool match = true;
index e5c9e75e4185f0a0d02f579b0a8d04f91603f4a7..c0dcb4848b4031aca0c9fd6c894a54df27d8d0d1 100644 (file)
@@ -145,7 +145,7 @@ static void test_penalties(
     sampler_tester tester(probs, probs_expected);
 
     const size_t n_vocab = probs.size();
-    auto * sampler = llama_sampler_init_penalties(n_vocab, LLAMA_TOKEN_NULL, LLAMA_TOKEN_NULL, last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, false, false);
+    auto * sampler = llama_sampler_init_penalties(last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence);
 
     for (size_t i = 0; i < last_tokens.size(); i++) {
         llama_sampler_accept(sampler, last_tokens[i]);