]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
samplers : Min-P sampler implementation [alternative to Top P/Top K] (#3841)
authorkalomaze <redacted>
Tue, 31 Oct 2023 19:44:49 +0000 (14:44 -0500)
committerGitHub <redacted>
Tue, 31 Oct 2023 19:44:49 +0000 (20:44 +0100)
* Introduce the new Min-P sampler by @kalomaze
   The Min-P sampling method was designed as an alternative to Top-P, and aims to ensure a balance of quality and variety. The parameter *p* represents the minimum probability for a token to be considered, relative to the probability of the most likely token.

* Min-P enabled and set to 0.05 default

---------

Co-authored-by: Georgi Gerganov <redacted>
Co-authored-by: cebtenzzre <redacted>
common/common.cpp
common/sampling.cpp
common/sampling.h
examples/main/README.md
llama.cpp
llama.h

index c187128d6ede3d576611d9c568385155072326f4..dc4865e80b1544afbb6222fd991e68fb306d3766 100644 (file)
@@ -218,6 +218,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             sparams.top_p = std::stof(argv[i]);
+        } else if (arg == "--min-p") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            sparams.min_p = std::stof(argv[i]);
         } else if (arg == "--temp") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -679,6 +685,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  -b N, --batch-size N  batch size for prompt processing (default: %d)\n", params.n_batch);
     printf("  --top-k N             top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k);
     printf("  --top-p N             top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p);
+    printf("  --min-p N             min-p sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.min_p);
     printf("  --tfs N               tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z);
     printf("  --typical N           locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p);
     printf("  --repeat-last-n N     last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.penalty_last_n);
@@ -1275,6 +1282,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "threads: %d # default: %d\n", params.n_threads, std::thread::hardware_concurrency());
     fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k);
     fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);
+    fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p);
     fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p);
     fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
 }
index c4996c9857d8ac72f103a9d73103205d7101d6e2..673d67a6d5380e1cfe6bcb23fcbfb56d0d002cb7 100644 (file)
@@ -89,10 +89,10 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
 
     snprintf(result, sizeof(result),
             "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
-            "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, typical_p = %.3f, temp = %.3f\n"
+            "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
             "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
             params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present,
-            params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp,
+            params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp,
             params.mirostat, params.mirostat_eta, params.mirostat_tau);
 
     return std::string(result);
@@ -110,6 +110,7 @@ llama_token llama_sampling_sample(
     const float   temp            = params.temp;
     const int32_t top_k           = params.top_k <= 0 ? n_vocab : params.top_k;
     const float   top_p           = params.top_p;
+    const float   min_p           = params.min_p;
     const float   tfs_z           = params.tfs_z;
     const float   typical_p       = params.typical_p;
     const int32_t penalty_last_n  = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
@@ -190,6 +191,7 @@ llama_token llama_sampling_sample(
             llama_sample_tail_free(ctx_main, &cur_p, tfs_z,     min_keep);
             llama_sample_typical  (ctx_main, &cur_p, typical_p, min_keep);
             llama_sample_top_p    (ctx_main, &cur_p, top_p,     min_keep);
+            llama_sample_min_p    (ctx_main, &cur_p, min_p,     min_keep);
             llama_sample_temp     (ctx_main, &cur_p, temp);
 
             id = llama_sample_token(ctx_main, &cur_p);
index 62ea6d4cfb7e5e5f71c18f42c8a1c6c89480852b..7c9b8dcf23bcbff62e4d5572511b4243d3322d6d 100644 (file)
@@ -14,6 +14,7 @@ typedef struct llama_sampling_params {
     int32_t n_probs           = 0;     // if greater than 0, output the probabilities of top n_probs tokens.
     int32_t top_k             = 40;    // <= 0 to use vocab size
     float   top_p             = 0.95f; // 1.0 = disabled
+    float   min_p             = 0.05f; // 0.0 = disabled
     float   tfs_z             = 1.00f; // 1.0 = disabled
     float   typical_p         = 1.00f; // 1.0 = disabled
     float   temp              = 0.80f; // 1.0 = disabled
index a9561c383c0cba7873808626cc4114e25dc1865d..a3428b48763d0b1ee6cbaafc6c48633b74ebab0b 100644 (file)
@@ -208,6 +208,14 @@ Top-p sampling, also known as nucleus sampling, is another text generation metho
 
 Example usage: `--top-p 0.95`
 
+### Min P Sampling
+
+-   `--min-p N`: Sets a minimum base probability threshold for token selection (default: 0.05).
+
+The Min-P sampling method was designed as an alternative to Top-P, and aims to ensure a balance of quality and variety. The parameter *p* represents the minimum probability for a token to be considered, relative to the probability of the most likely token. For example, with *p*=0.05 and the most likely token having a probability of 0.9, logits with a value less than 0.045 are filtered out.
+
+Example usage: `--min-p 0.05`
+
 ### Tail Free Sampling (TFS)
 
 -   `--tfs N`: Enable tail free sampling with parameter z (default: 1.0, 1.0 = disabled).
index e599917a81eb1d27b4768c67e0b83bcb3296f9ed..7ee5892989f0a983572f0848a016c2351a666b58 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -7368,6 +7368,32 @@ void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * can
     }
 }
 
+void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
+    if (p <= 0.0f || !candidates->size) {
+        return;
+    }
+
+    llama_sample_softmax(ctx, candidates);
+
+    const int64_t t_start_sample_us = ggml_time_us();
+
+    float scale = candidates->data[0].p; // scale by max prob
+    size_t i = 1; // first token always matches
+
+    for (; i < candidates->size; ++i) {
+        if (candidates->data[i].p < p * scale && i >= min_keep) {
+            break; // prob too small
+        }
+    }
+
+    // Resize the output vector to keep only the matching tokens
+    candidates->size = i;
+
+    if (ctx) {
+        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+    }
+}
+
 void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
     if (z >= 1.0f || candidates->size <= 2) {
         return;
diff --git a/llama.h b/llama.h
index d727dbd9fd915dbe4dd46c6cc6c77c5536e93431..75fe391ef2e733a40d3af651ec0f4b143e4210b0 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -598,6 +598,13 @@ extern "C" {
                            float   p,
                           size_t   min_keep);
 
+    /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
+    LLAMA_API void llama_sample_min_p(
+            struct llama_context * ctx,
+          llama_token_data_array * candidates,
+                           float   p,
+                          size_t   min_keep);
+
     /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
     LLAMA_API void llama_sample_tail_free(
             struct llama_context * ctx,