]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : add "samplers" param to control the samplers order (#5494)
authorAlexey Parfenov <redacted>
Fri, 16 Feb 2024 11:33:25 +0000 (11:33 +0000)
committerGitHub <redacted>
Fri, 16 Feb 2024 11:33:25 +0000 (13:33 +0200)
common/common.cpp
common/common.h
common/sampling.cpp
common/sampling.h
examples/server/README.md
examples/server/server.cpp

index c5e83cc2a9e40f1f903ff2342dc4ff4cf5654390..3a92d3797492f8fc2b4f394ed01c0e51c080cee0 100644 (file)
@@ -341,7 +341,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             const auto sampler_names = string_split(argv[i], ';');
-            sparams.samplers_sequence = sampler_types_from_names(sampler_names);
+            sparams.samplers_sequence = sampler_types_from_names(sampler_names, true);
         } else if (arg == "--sampling-seq") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -964,7 +964,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  -n N, --n-predict N   number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict);
     printf("  -c N, --ctx-size N    size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx);
     printf("  -b N, --batch-size N  batch size for prompt processing (default: %d)\n", params.n_batch);
-    printf("  --samplers            samplers that will be used for generation in the order, separated by \';\' (default: %s)\n", sampler_type_names.c_str());
+    printf("  --samplers            samplers that will be used for generation in the order, separated by \';\'\n");
+    printf("                        (default: %s)\n", sampler_type_names.c_str());
     printf("  --sampling-seq        simplified sequence for samplers that will be used (default: %s)\n", sampler_type_chars.c_str());
     printf("  --top-k N             top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k);
     printf("  --top-p N             top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p);
@@ -1133,34 +1134,50 @@ std::vector<std::string> string_split(std::string input, char separator) {
     return parts;
 }
 
-std::vector<llama_sampler_type> sampler_types_from_names(const std::vector<std::string> & names) {
+std::vector<llama_sampler_type> sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
+    std::unordered_map<std::string, llama_sampler_type> sampler_canonical_name_map {
+        {"top_k",       llama_sampler_type::TOP_K},
+        {"top_p",       llama_sampler_type::TOP_P},
+        {"typical_p",   llama_sampler_type::TYPICAL_P},
+        {"min_p",       llama_sampler_type::MIN_P},
+        {"tfs_z",       llama_sampler_type::TFS_Z},
+        {"temperature", llama_sampler_type::TEMPERATURE}
+    };
+
     // since samplers names are written multiple ways
     // make it ready for both system names and input names
-    std::unordered_map<std::string, llama_sampler_type> sampler_name_map {
-        {"top_k",       llama_sampler_type::TOP_K},
+    std::unordered_map<std::string, llama_sampler_type> sampler_alt_name_map {
         {"top-k",       llama_sampler_type::TOP_K},
-        {"top_p",       llama_sampler_type::TOP_P},
         {"top-p",       llama_sampler_type::TOP_P},
         {"nucleus",     llama_sampler_type::TOP_P},
-        {"typical_p",   llama_sampler_type::TYPICAL_P},
         {"typical-p",   llama_sampler_type::TYPICAL_P},
         {"typical",     llama_sampler_type::TYPICAL_P},
-        {"min_p",       llama_sampler_type::MIN_P},
         {"min-p",       llama_sampler_type::MIN_P},
-        {"tfs_z",       llama_sampler_type::TFS_Z},
         {"tfs-z",       llama_sampler_type::TFS_Z},
         {"tfs",         llama_sampler_type::TFS_Z},
-        {"temp",        llama_sampler_type::TEMP},
-        {"temperature", llama_sampler_type::TEMP}
+        {"temp",        llama_sampler_type::TEMPERATURE}
     };
 
     std::vector<llama_sampler_type> sampler_types;
     sampler_types.reserve(names.size());
-    for (const auto& name : names) {
-        const auto sampler_item = sampler_name_map.find(name);
-        if (sampler_item != sampler_name_map.end()) {
+    for (const auto & name : names)
+    {
+        auto sampler_item = sampler_canonical_name_map.find(name);
+        if (sampler_item != sampler_canonical_name_map.end())
+        {
             sampler_types.push_back(sampler_item->second);
         }
+        else
+        {
+            if (allow_alt_names)
+            {
+                sampler_item = sampler_alt_name_map.find(name);
+                if (sampler_item != sampler_alt_name_map.end())
+                {
+                    sampler_types.push_back(sampler_item->second);
+                }
+            }
+        }
     }
     return sampler_types;
 }
@@ -1172,7 +1189,7 @@ std::vector<llama_sampler_type> sampler_types_from_chars(const std::string & nam
         {'y', llama_sampler_type::TYPICAL_P},
         {'m', llama_sampler_type::MIN_P},
         {'f', llama_sampler_type::TFS_Z},
-        {'t', llama_sampler_type::TEMP}
+        {'t', llama_sampler_type::TEMPERATURE}
     };
 
     std::vector<llama_sampler_type> sampler_types;
@@ -1188,12 +1205,12 @@ std::vector<llama_sampler_type> sampler_types_from_chars(const std::string & nam
 
 std::string sampler_type_to_name_string(llama_sampler_type sampler_type) {
     switch (sampler_type) {
-        case llama_sampler_type::TOP_K:     return "top_k";
-        case llama_sampler_type::TFS_Z:     return "tfs_z";
-        case llama_sampler_type::TYPICAL_P: return "typical_p";
-        case llama_sampler_type::TOP_P:     return "top_p";
-        case llama_sampler_type::MIN_P:     return "min_p";
-        case llama_sampler_type::TEMP:      return "temp";
+        case llama_sampler_type::TOP_K:       return "top_k";
+        case llama_sampler_type::TFS_Z:       return "tfs_z";
+        case llama_sampler_type::TYPICAL_P:   return "typical_p";
+        case llama_sampler_type::TOP_P:       return "top_p";
+        case llama_sampler_type::MIN_P:       return "min_p";
+        case llama_sampler_type::TEMPERATURE: return "temperature";
         default : return "";
     }
 }
index 74c1369953d480c5c73360817f4fc82b4f8e800a..935771d44ca9caee77f7f37f0cc8fcd1b9eec8d4 100644 (file)
@@ -165,7 +165,7 @@ void process_escapes(std::string& input);
 // String utils
 //
 
-std::vector<llama_sampler_type> sampler_types_from_names(const std::vector<std::string> & names);
+std::vector<llama_sampler_type> sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
 std::vector<llama_sampler_type> sampler_types_from_chars(const std::string & names_string);
 std::vector<std::string> string_split(std::string input, char separator);
 std::string sampler_type_to_name_string(llama_sampler_type sampler_type);
index a001750da0ce272678b1c315d12251f493f51c5a..53013138a9eb48e9c86156c653d2f45d34f0128c 100644 (file)
@@ -139,7 +139,7 @@ static void sampler_queue(
             case llama_sampler_type::TYPICAL_P: llama_sample_typical  (ctx_main, &cur_p, typical_p, min_keep); break;
             case llama_sampler_type::TOP_P    : llama_sample_top_p    (ctx_main, &cur_p, top_p,     min_keep); break;
             case llama_sampler_type::MIN_P    : llama_sample_min_p    (ctx_main, &cur_p, min_p,     min_keep); break;
-            case llama_sampler_type::TEMP:
+            case llama_sampler_type::TEMPERATURE:
                 if (dynatemp_range > 0) {
                     float dynatemp_min = std::max(0.0f, temp - dynatemp_range);
                     float dynatemp_max = std::max(0.0f, temp + dynatemp_range);
index 2bd6a75d2153453a05e7b42c288fea00f8b5dc90..e1279a8941ce0dff72b018028b938dcdd68b3ea5 100644 (file)
 
 // sampler types
 enum class llama_sampler_type : char {
-    TOP_K     = 'k',
-    TOP_P     = 'p',
-    MIN_P     = 'm',
-    TFS_Z     = 'f',
-    TYPICAL_P = 'y',
-    TEMP      = 't'
+    TOP_K       = 'k',
+    TOP_P       = 'p',
+    MIN_P       = 'm',
+    TFS_Z       = 'f',
+    TYPICAL_P   = 'y',
+    TEMPERATURE = 't'
 };
 
 // sampling parameters
@@ -45,7 +45,7 @@ typedef struct llama_sampling_params {
         llama_sampler_type::TYPICAL_P,
         llama_sampler_type::TOP_P,
         llama_sampler_type::MIN_P,
-        llama_sampler_type::TEMP
+        llama_sampler_type::TEMPERATURE
     };
 
     std::string grammar;  // optional BNF-like grammar to constrain sampling
index 8e141d22d1716d6a28f5522eb0409f025f571471..249368749ff07b2e8503b900eda4f1df39d1a6f3 100644 (file)
@@ -204,6 +204,8 @@ node index.js
 
     `system_prompt`: Change the system prompt (initial prompt of all slots), this is useful for chat applications. [See more](#change-system-prompt-on-runtime)
 
+    `samplers`: The order the samplers should be applied in. An array of strings representing sampler type names. If a sampler is not set, it will not be used. If a sampler is specified more than once, it will be applied multiple times. (default: `["top_k", "tfs_z", "typical_p", "top_p", "min_p", "temperature"]` - these are all the available values)
+
 ### Result JSON
 
 - Note: When using streaming mode (`stream`) only `content` and `stop` will be returned until end of completion.
index 0cb802ce851adbbbeff324042519dc4e8df68d3b..a0b46970b83a9e901aa394f1b549060f3876efb1 100644 (file)
@@ -672,6 +672,24 @@ struct llama_server_context
             }
         }
 
+        const auto &samplers_sequence = data.find("samplers");
+        if (samplers_sequence != data.end() && samplers_sequence->is_array())
+        {
+            std::vector<std::string> sampler_names;
+            for (const auto &sampler_name : *samplers_sequence)
+            {
+                if (sampler_name.is_string())
+                {
+                    sampler_names.emplace_back(sampler_name);
+                }
+            }
+            slot->sparams.samplers_sequence = sampler_types_from_names(sampler_names, false);
+        }
+        else
+        {
+            slot->sparams.samplers_sequence = default_sparams.samplers_sequence;
+        }
+
         if (multimodal)
         {
             const auto &images_data = data.find("image_data");
@@ -1026,6 +1044,12 @@ struct llama_server_context
         const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
         const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() &&
                                 eos_bias->second < 0.0f && std::isinf(eos_bias->second);
+        std::vector<std::string> samplers_sequence;
+        for (const auto &sampler_type : slot.sparams.samplers_sequence)
+        {
+            samplers_sequence.emplace_back(sampler_type_to_name_string(sampler_type));
+        }
+
         return json {
             {"n_ctx",             slot.n_ctx},
             {"model",             params.model_alias},
@@ -1056,6 +1080,7 @@ struct llama_server_context
             {"logit_bias",        slot.sparams.logit_bias},
             {"n_probs",           slot.sparams.n_probs},
             {"grammar",           slot.sparams.grammar},
+            {"samplers",          samplers_sequence}
         };
     }