]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
sampling : custom samplers order (#4285)
authorMaggotHATE <redacted>
Tue, 5 Dec 2023 10:05:51 +0000 (15:05 +0500)
committerGitHub <redacted>
Tue, 5 Dec 2023 10:05:51 +0000 (12:05 +0200)
* Samplers sequence order w parameter

* Cleaned commented code

* Fixed formatting

* Rewrote with unordered_map

* Revert and rewrite, too many problems and safeguards would be needed

* Fixed code style

* Code style fixes according to review

* More readable samplers input string, fixed help

* Style fix in sampler_queue

* Formatting fixes

* Fixing whitespaces

common/common.cpp
common/common.h
common/sampling.cpp
common/sampling.h
examples/main/main.cpp

index 1dcc235eac0e6b4d78f57b1f0051e1d6eb92ecef..8e6d74d0d704a3cb2fb3b8986c69171171a99d80 100644 (file)
@@ -280,6 +280,18 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
             params.yarn_beta_slow = std::stof(argv[i]);
         } else if (arg == "--memory-f32") {
             params.memory_f16 = false;
+        } else if (arg == "--samplers") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            sparams.samplers_sequence = parse_samplers_input(argv[i]);
+        } else if (arg == "--sampling-seq") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            sparams.samplers_sequence = argv[i];
         } else if (arg == "--top-p") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -761,6 +773,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  -n N, --n-predict N   number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict);
     printf("  -c N, --ctx-size N    size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx);
     printf("  -b N, --batch-size N  batch size for prompt processing (default: %d)\n", params.n_batch);
+    printf("  --samplers            samplers that will be used for generation in the order, separated by \';\', for example: \"top_k;tfs;typical;top_p;min_p;temp\"\n");
+    printf("  --sampling-seq        simplified sequence for samplers that will be used (default: %s)\n", sparams.samplers_sequence.c_str());
     printf("  --top-k N             top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k);
     printf("  --top-p N             top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p);
     printf("  --min-p N             min-p sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.min_p);
@@ -886,6 +900,48 @@ std::string gpt_random_prompt(std::mt19937 & rng) {
     GGML_UNREACHABLE();
 }
 
+//
+// String parsing
+//
+
+std::string parse_samplers_input(std::string input) {
+    std::string output = "";
+    // since samplers names are written multiple ways
+    // make it ready for both system names and input names
+    std::unordered_map<std::string, char> samplers_symbols {
+        {"top_k",      'k'},
+        {"top-k",      'k'},
+        {"top_p",      'p'},
+        {"top-p",      'p'},
+        {"nucleus",    'p'},
+        {"typical_p",  'y'},
+        {"typical-p",  'y'},
+        {"typical",    'y'},
+        {"min_p",      'm'},
+        {"min-p",      'm'},
+        {"tfs_z",      'f'},
+        {"tfs-z",      'f'},
+        {"tfs",        'f'},
+        {"temp",       't'},
+        {"temperature",'t'}
+    };
+    // expected format example: "temp;top_k;tfs_z;typical_p;top_p;min_p"
+    size_t separator = input.find(';');
+    while (separator != input.npos) {
+        std::string name = input.substr(0,separator);
+        input = input.substr(separator+1);
+        separator = input.find(';');
+
+        if (samplers_symbols.find(name) != samplers_symbols.end()) {
+            output += samplers_symbols[name];
+        }
+    }
+    if (samplers_symbols.find(input) != samplers_symbols.end()) {
+        output += samplers_symbols[input];
+    }
+    return output;
+}
+
 //
 // Model utils
 //
index 2f6fe48ab53d3527df02f2589e7ca3d63d4117c8..534f7b1322da2adaa0a88b835e95a385af7065e1 100644 (file)
@@ -141,6 +141,12 @@ std::string gpt_random_prompt(std::mt19937 & rng);
 
 void process_escapes(std::string& input);
 
+//
+// String parsing
+//
+
+std::string parse_samplers_input(std::string input);
+
 //
 // Model utils
 //
index 1317024c2c11cffd9d05a8d8347b86f9064f3981..b6bb886c6c7d765b4a433f31c6e2035acd8a87c1 100644 (file)
@@ -99,6 +99,54 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
     return std::string(result);
 }
 
+std::string llama_sampling_order_print(const llama_sampling_params & params) {
+    std::string result = "CFG -> Penalties ";
+    if (params.mirostat == 0) {
+        for (auto s : params.samplers_sequence) {
+            switch (s) {
+                case 'k': result += "-> top_k "; break;
+                case 'f': result += "-> tfs_z "; break;
+                case 'y': result += "-> typical_p "; break;
+                case 'p': result += "-> top_p "; break;
+                case 'm': result += "-> min_p "; break;
+                case 't': result += "-> temp "; break;
+                default : break;
+            }
+        }
+    } else result += "-> mirostat ";
+
+    return result;
+}
+
+// no reasons to expose this function in header
+void sampler_queue(
+                   struct llama_context * ctx_main,
+            const llama_sampling_params & params,
+                 llama_token_data_array & cur_p,
+                                 size_t & min_keep) {
+    const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+
+    const float         temp              = params.temp;
+    const int32_t       top_k             = params.top_k <= 0 ? n_vocab : params.top_k;
+    const float         top_p             = params.top_p;
+    const float         min_p             = params.min_p;
+    const float         tfs_z             = params.tfs_z;
+    const float         typical_p         = params.typical_p;
+    const std::string & samplers_sequence = params.samplers_sequence;
+
+    for (auto s : samplers_sequence) {
+        switch (s){
+            case 'k': llama_sample_top_k    (ctx_main, &cur_p, top_k,     min_keep); break;
+            case 'f': llama_sample_tail_free(ctx_main, &cur_p, tfs_z,     min_keep); break;
+            case 'y': llama_sample_typical  (ctx_main, &cur_p, typical_p, min_keep); break;
+            case 'p': llama_sample_top_p    (ctx_main, &cur_p, top_p,     min_keep); break;
+            case 'm': llama_sample_min_p    (ctx_main, &cur_p, min_p,     min_keep); break;
+            case 't': llama_sample_temp     (ctx_main, &cur_p, temp); break;
+            default : break;
+        }
+    }
+}
+
 llama_token llama_sampling_sample(
                   struct llama_sampling_context * ctx_sampling,
                   struct llama_context * ctx_main,
@@ -109,11 +157,6 @@ llama_token llama_sampling_sample(
     const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
 
     const float   temp            = params.temp;
-    const int32_t top_k           = params.top_k <= 0 ? n_vocab : params.top_k;
-    const float   top_p           = params.top_p;
-    const float   min_p           = params.min_p;
-    const float   tfs_z           = params.tfs_z;
-    const float   typical_p       = params.typical_p;
     const int32_t penalty_last_n  = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
     const float   penalty_repeat  = params.penalty_repeat;
     const float   penalty_freq    = params.penalty_freq;
@@ -188,12 +231,7 @@ llama_token llama_sampling_sample(
             // temperature sampling
             size_t min_keep = std::max(1, params.n_probs);
 
-            llama_sample_top_k    (ctx_main, &cur_p, top_k,     min_keep);
-            llama_sample_tail_free(ctx_main, &cur_p, tfs_z,     min_keep);
-            llama_sample_typical  (ctx_main, &cur_p, typical_p, min_keep);
-            llama_sample_top_p    (ctx_main, &cur_p, top_p,     min_keep);
-            llama_sample_min_p    (ctx_main, &cur_p, min_p,     min_keep);
-            llama_sample_temp     (ctx_main, &cur_p, temp);
+            sampler_queue(ctx_main, params, cur_p, min_keep);
 
             id = llama_sample_token(ctx_main, &cur_p);
 
index 7c9b8dcf23bcbff62e4d5572511b4243d3322d6d..fdfa9eed1467b1fba8165e1f4034d5f87537bd22 100644 (file)
 
 // sampling parameters
 typedef struct llama_sampling_params {
-    int32_t n_prev            = 64;    // number of previous tokens to remember
-    int32_t n_probs           = 0;     // if greater than 0, output the probabilities of top n_probs tokens.
-    int32_t top_k             = 40;    // <= 0 to use vocab size
-    float   top_p             = 0.95f; // 1.0 = disabled
-    float   min_p             = 0.05f; // 0.0 = disabled
-    float   tfs_z             = 1.00f; // 1.0 = disabled
-    float   typical_p         = 1.00f; // 1.0 = disabled
-    float   temp              = 0.80f; // 1.0 = disabled
-    int32_t penalty_last_n    = 64;    // last n tokens to penalize (0 = disable penalty, -1 = context size)
-    float   penalty_repeat    = 1.10f; // 1.0 = disabled
-    float   penalty_freq      = 0.00f; // 0.0 = disabled
-    float   penalty_present   = 0.00f; // 0.0 = disabled
-    int32_t mirostat          = 0;     // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
-    float   mirostat_tau      = 5.00f; // target entropy
-    float   mirostat_eta      = 0.10f; // learning rate
-    bool    penalize_nl       = true;  // consider newlines as a repeatable token
+    int32_t     n_prev                = 64;       // number of previous tokens to remember
+    int32_t     n_probs               = 0;        // if greater than 0, output the probabilities of top n_probs tokens.
+    int32_t     top_k                 = 40;       // <= 0 to use vocab size
+    float       top_p                 = 0.95f;    // 1.0 = disabled
+    float       min_p                 = 0.05f;    // 0.0 = disabled
+    float       tfs_z                 = 1.00f;    // 1.0 = disabled
+    float       typical_p             = 1.00f;    // 1.0 = disabled
+    float       temp                  = 0.80f;    // 1.0 = disabled
+    int32_t     penalty_last_n        = 64;       // last n tokens to penalize (0 = disable penalty, -1 = context size)
+    float       penalty_repeat        = 1.10f;    // 1.0 = disabled
+    float       penalty_freq          = 0.00f;    // 0.0 = disabled
+    float       penalty_present       = 0.00f;    // 0.0 = disabled
+    int32_t     mirostat              = 0;        // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+    float       mirostat_tau          = 5.00f;    // target entropy
+    float       mirostat_eta          = 0.10f;    // learning rate
+    bool        penalize_nl           = true;     // consider newlines as a repeatable token
+    std::string samplers_sequence     = "kfypmt"; // top_k, tail_free, typical_p, top_p, min_p, temp
 
     std::string grammar;  // optional BNF-like grammar to constrain sampling
 
@@ -80,6 +81,9 @@ std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama
 // Print sampling parameters into a string
 std::string llama_sampling_print(const llama_sampling_params & params);
 
+// Print sampling order into a string
+std::string llama_sampling_order_print(const llama_sampling_params & params);
+
 // this is a common sampling function used across the examples for convenience
 // it can serve as a starting point for implementing your own sampling function
 // Note: When using multiple sequences, it is the caller's responsibility to call
index c5cdfbf21b9547e2bdde8a21e58836729039a2a9..c096f110b32c55f19959dba82820dc36c38369c8 100644 (file)
@@ -437,6 +437,7 @@ int main(int argc, char ** argv) {
         }
     }
     LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
+    LOG_TEE("sampling order: \n%s\n", llama_sampling_order_print(sparams).c_str());
     LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
     LOG_TEE("\n\n");