]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
common : use enums for sampler types (#5418)
authorAlexey Parfenov <redacted>
Sun, 11 Feb 2024 13:43:31 +0000 (13:43 +0000)
committerGitHub <redacted>
Sun, 11 Feb 2024 13:43:31 +0000 (15:43 +0200)
* common: use enums for sampler types

* Apply suggestions from code review

Co-authored-by: Georgi Gerganov <redacted>
* minor : spaces

---------

Co-authored-by: Georgi Gerganov <redacted>
common/common.cpp
common/common.h
common/sampling.cpp
common/sampling.h

index 9a489a553b60434b262a2f17a0156d2e776ca791..f64da2cb66bb8ac6106c91bab70d5f01a6a1ff86 100644 (file)
@@ -340,13 +340,14 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                 invalid_param = true;
                 break;
             }
-            sparams.samplers_sequence = parse_samplers_input(argv[i]);
+            const auto sampler_names = string_split(argv[i], ';');
+            sparams.samplers_sequence = sampler_types_from_names(sampler_names);
         } else if (arg == "--sampling-seq") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
-            sparams.samplers_sequence = argv[i];
+            sparams.samplers_sequence = sampler_types_from_chars(argv[i]);
         } else if (arg == "--top-p") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -906,6 +907,14 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
 void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     const llama_sampling_params & sparams = params.sparams;
 
+    std::string sampler_type_chars;
+    std::string sampler_type_names;
+    for (const auto sampler_type : sparams.samplers_sequence) {
+        sampler_type_chars += static_cast<char>(sampler_type);
+        sampler_type_names += sampler_type_to_name_string(sampler_type) + ";";
+    }
+    sampler_type_names.pop_back();
+
     printf("\n");
     printf("usage: %s [options]\n", argv[0]);
     printf("\n");
@@ -947,8 +956,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  -n N, --n-predict N   number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict);
     printf("  -c N, --ctx-size N    size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx);
     printf("  -b N, --batch-size N  batch size for prompt processing (default: %d)\n", params.n_batch);
-    printf("  --samplers            samplers that will be used for generation in the order, separated by \';\', for example: \"top_k;tfs;typical;top_p;min_p;temp\"\n");
-    printf("  --sampling-seq        simplified sequence for samplers that will be used (default: %s)\n", sparams.samplers_sequence.c_str());
+    printf("  --samplers            samplers that will be used for generation in the order, separated by \';\' (default: %s)\n", sampler_type_names.c_str());
+    printf("  --sampling-seq        simplified sequence for samplers that will be used (default: %s)\n", sampler_type_chars.c_str());
     printf("  --top-k N             top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k);
     printf("  --top-p N             top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p);
     printf("  --min-p N             min-p sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.min_p);
@@ -1097,45 +1106,85 @@ std::string gpt_random_prompt(std::mt19937 & rng) {
 }
 
 //
-// String parsing
+// String utils
 //
 
-std::string parse_samplers_input(std::string input) {
-    std::string output = "";
+std::vector<std::string> string_split(std::string input, char separator) {
+    std::vector<std::string> parts;
+    size_t separator_pos = input.find(separator);
+    while (separator_pos != std::string::npos) {
+        std::string part = input.substr(0, separator_pos);
+        parts.emplace_back(part);
+        input = input.substr(separator_pos + 1);
+        separator_pos = input.find(separator);
+    }
+    parts.emplace_back(input);
+    return parts;
+}
+
+std::vector<llama_sampler_type> sampler_types_from_names(const std::vector<std::string> & names) {
     // since samplers names are written multiple ways
     // make it ready for both system names and input names
-    std::unordered_map<std::string, char> samplers_symbols {
-        {"top_k",      'k'},
-        {"top-k",      'k'},
-        {"top_p",      'p'},
-        {"top-p",      'p'},
-        {"nucleus",    'p'},
-        {"typical_p",  'y'},
-        {"typical-p",  'y'},
-        {"typical",    'y'},
-        {"min_p",      'm'},
-        {"min-p",      'm'},
-        {"tfs_z",      'f'},
-        {"tfs-z",      'f'},
-        {"tfs",        'f'},
-        {"temp",       't'},
-        {"temperature",'t'}
+    std::unordered_map<std::string, llama_sampler_type> sampler_name_map {
+        {"top_k",       llama_sampler_type::TOP_K},
+        {"top-k",       llama_sampler_type::TOP_K},
+        {"top_p",       llama_sampler_type::TOP_P},
+        {"top-p",       llama_sampler_type::TOP_P},
+        {"nucleus",     llama_sampler_type::TOP_P},
+        {"typical_p",   llama_sampler_type::TYPICAL_P},
+        {"typical-p",   llama_sampler_type::TYPICAL_P},
+        {"typical",     llama_sampler_type::TYPICAL_P},
+        {"min_p",       llama_sampler_type::MIN_P},
+        {"min-p",       llama_sampler_type::MIN_P},
+        {"tfs_z",       llama_sampler_type::TFS_Z},
+        {"tfs-z",       llama_sampler_type::TFS_Z},
+        {"tfs",         llama_sampler_type::TFS_Z},
+        {"temp",        llama_sampler_type::TEMP},
+        {"temperature", llama_sampler_type::TEMP}
+    };
+
+    std::vector<llama_sampler_type> sampler_types;
+    sampler_types.reserve(names.size());
+    for (const auto& name : names) {
+        const auto sampler_item = sampler_name_map.find(name);
+        if (sampler_item != sampler_name_map.end()) {
+            sampler_types.push_back(sampler_item->second);
+        }
+    }
+    return sampler_types;
+}
+
+std::vector<llama_sampler_type> sampler_types_from_chars(const std::string & names_string) {
+    std::unordered_map<char, llama_sampler_type> sampler_name_map {
+        {'k', llama_sampler_type::TOP_K},
+        {'p', llama_sampler_type::TOP_P},
+        {'y', llama_sampler_type::TYPICAL_P},
+        {'m', llama_sampler_type::MIN_P},
+        {'f', llama_sampler_type::TFS_Z},
+        {'t', llama_sampler_type::TEMP}
     };
-    // expected format example: "temp;top_k;tfs_z;typical_p;top_p;min_p"
-    size_t separator = input.find(';');
-    while (separator != input.npos) {
-        std::string name = input.substr(0,separator);
-        input = input.substr(separator+1);
-        separator = input.find(';');
-
-        if (samplers_symbols.find(name) != samplers_symbols.end()) {
-            output += samplers_symbols[name];
+
+    std::vector<llama_sampler_type> sampler_types;
+    sampler_types.reserve(names_string.size());
+    for (const auto & c : names_string) {
+        const auto sampler_item = sampler_name_map.find(c);
+        if (sampler_item != sampler_name_map.end()) {
+            sampler_types.push_back(sampler_item->second);
         }
     }
-    if (samplers_symbols.find(input) != samplers_symbols.end()) {
-        output += samplers_symbols[input];
+    return sampler_types;
+}
+
+std::string sampler_type_to_name_string(llama_sampler_type sampler_type) {
+    switch (sampler_type) {
+        case llama_sampler_type::TOP_K:     return "top_k";
+        case llama_sampler_type::TFS_Z:     return "tfs_z";
+        case llama_sampler_type::TYPICAL_P: return "typical_p";
+        case llama_sampler_type::TOP_P:     return "top_p";
+        case llama_sampler_type::MIN_P:     return "min_p";
+        case llama_sampler_type::TEMP:      return "temp";
+        default : return "";
     }
-    return output;
 }
 
 //
index 62de25d6a287c6955c64c7a2d1ffd7ca08509376..9bdd45cf9f84f45c76f8cfa690235374929deecb 100644 (file)
@@ -162,10 +162,13 @@ std::string gpt_random_prompt(std::mt19937 & rng);
 void process_escapes(std::string& input);
 
 //
-// String parsing
+// String utils
 //
 
-std::string parse_samplers_input(std::string input);
+std::vector<llama_sampler_type> sampler_types_from_names(const std::vector<std::string> & names);
+std::vector<llama_sampler_type> sampler_types_from_chars(const std::string & names_string);
+std::vector<std::string> string_split(std::string input, char separator);
+std::string sampler_type_to_name_string(llama_sampler_type sampler_type);
 
 //
 // Model utils
index 82cbdeceabaa1efb0f87e7cf40013e1ed7a776c7..a001750da0ce272678b1c315d12251f493f51c5a 100644 (file)
@@ -103,15 +103,10 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
 std::string llama_sampling_order_print(const llama_sampling_params & params) {
     std::string result = "CFG -> Penalties ";
     if (params.mirostat == 0) {
-        for (auto s : params.samplers_sequence) {
-            switch (s) {
-                case 'k': result += "-> top_k "; break;
-                case 'f': result += "-> tfs_z "; break;
-                case 'y': result += "-> typical_p "; break;
-                case 'p': result += "-> top_p "; break;
-                case 'm': result += "-> min_p "; break;
-                case 't': result += "-> temp "; break;
-                default : break;
+        for (auto sampler_type : params.samplers_sequence) {
+            const auto sampler_type_name = sampler_type_to_name_string(sampler_type);
+            if (!sampler_type_name.empty()) {
+                result += "-> " + sampler_type_name + " ";
             }
         }
     } else {
@@ -135,16 +130,16 @@ static void sampler_queue(
     const float         min_p             = params.min_p;
     const float         tfs_z             = params.tfs_z;
     const float         typical_p         = params.typical_p;
-    const std::string & samplers_sequence = params.samplers_sequence;
-
-    for (auto s : samplers_sequence) {
-        switch (s){
-            case 'k': llama_sample_top_k    (ctx_main, &cur_p, top_k,     min_keep); break;
-            case 'f': llama_sample_tail_free(ctx_main, &cur_p, tfs_z,     min_keep); break;
-            case 'y': llama_sample_typical  (ctx_main, &cur_p, typical_p, min_keep); break;
-            case 'p': llama_sample_top_p    (ctx_main, &cur_p, top_p,     min_keep); break;
-            case 'm': llama_sample_min_p    (ctx_main, &cur_p, min_p,     min_keep); break;
-            case 't':
+    const std::vector<llama_sampler_type> & samplers_sequence = params.samplers_sequence;
+
+    for (auto sampler_type : samplers_sequence) {
+        switch (sampler_type) {
+            case llama_sampler_type::TOP_K    : llama_sample_top_k    (ctx_main, &cur_p, top_k,     min_keep); break;
+            case llama_sampler_type::TFS_Z    : llama_sample_tail_free(ctx_main, &cur_p, tfs_z,     min_keep); break;
+            case llama_sampler_type::TYPICAL_P: llama_sample_typical  (ctx_main, &cur_p, typical_p, min_keep); break;
+            case llama_sampler_type::TOP_P    : llama_sample_top_p    (ctx_main, &cur_p, top_p,     min_keep); break;
+            case llama_sampler_type::MIN_P    : llama_sample_min_p    (ctx_main, &cur_p, min_p,     min_keep); break;
+            case llama_sampler_type::TEMP:
                 if (dynatemp_range > 0) {
                     float dynatemp_min = std::max(0.0f, temp - dynatemp_range);
                     float dynatemp_max = std::max(0.0f, temp + dynatemp_range);
index 88899c094866eff2dfb199b7daa3360475758b61..2bd6a75d2153453a05e7b42c288fea00f8b5dc90 100644 (file)
@@ -8,6 +8,16 @@
 #include <vector>
 #include <unordered_map>
 
+// sampler types
+enum class llama_sampler_type : char {
+    TOP_K     = 'k',
+    TOP_P     = 'p',
+    MIN_P     = 'm',
+    TFS_Z     = 'f',
+    TYPICAL_P = 'y',
+    TEMP      = 't'
+};
+
 // sampling parameters
 typedef struct llama_sampling_params {
     int32_t     n_prev                = 64;       // number of previous tokens to remember
@@ -28,7 +38,15 @@ typedef struct llama_sampling_params {
     float       mirostat_tau          = 5.00f;    // target entropy
     float       mirostat_eta          = 0.10f;    // learning rate
     bool        penalize_nl           = true;     // consider newlines as a repeatable token
-    std::string samplers_sequence     = "kfypmt"; // top_k, tail_free, typical_p, top_p, min_p, temp
+
+    std::vector<llama_sampler_type> samplers_sequence = {
+        llama_sampler_type::TOP_K,
+        llama_sampler_type::TFS_Z,
+        llama_sampler_type::TYPICAL_P,
+        llama_sampler_type::TOP_P,
+        llama_sampler_type::MIN_P,
+        llama_sampler_type::TEMP
+    };
 
     std::string grammar;  // optional BNF-like grammar to constrain sampling