]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : move random seed generation to the samplers (#9398)
authorslaren <redacted>
Tue, 10 Sep 2024 16:04:25 +0000 (18:04 +0200)
committerGitHub <redacted>
Tue, 10 Sep 2024 16:04:25 +0000 (18:04 +0200)
* llama_sampler_penalties : clamp penalty_last_n to zero

common/arg.cpp
common/sampling.cpp
common/sampling.h
examples/embedding/embedding.cpp
examples/infill/infill.cpp
examples/main/main.cpp
examples/perplexity/perplexity.cpp
examples/server/server.cpp
include/llama.h
src/llama-sampling.cpp

index c5134be515b6ea31ec421cda83b52d71022ecaa3..ca569494f35af1dccf4405450dfa11c7511ec40f 100644 (file)
@@ -173,7 +173,6 @@ static bool gpt_params_parse_ex(int argc, char ** argv, gpt_params_context & ctx
     std::string arg;
     const std::string arg_prefix = "--";
     gpt_params & params = ctx_arg.params;
-    gpt_sampler_params & sparams = params.sparams;
 
     std::unordered_map<std::string, llama_arg *> arg_to_options;
     for (auto & opt : ctx_arg.options) {
@@ -283,10 +282,6 @@ static bool gpt_params_parse_ex(int argc, char ** argv, gpt_params_context & ctx
         params.kv_overrides.back().key[0] = 0;
     }
 
-    if (sparams.seed == LLAMA_DEFAULT_SEED) {
-        sparams.seed = time(NULL);
-    }
-
     return true;
 }
 
@@ -909,7 +904,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
     ).set_sparam());
     add_opt(llama_arg(
         {"-s", "--seed"}, "SEED",
-        format("RNG seed (default: %d, use random seed for < 0)", params.sparams.seed),
+        format("RNG seed (default: %u, use random seed for %u)", params.sparams.seed, LLAMA_DEFAULT_SEED),
         [](gpt_params & params, const std::string & value) {
             params.sparams.seed = std::stoul(value);
         }
index 21b95646272dfc2ab15bb51cbf41aa0f5e8751e8..4498feb117b7ea81599a06872f88f8f75d0bab9d 100644 (file)
@@ -310,6 +310,10 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
     return cur_p.data[cur_p.selected].id;
 }
 
+uint32_t gpt_sampler_get_seed(const struct gpt_sampler * gsmpl) {
+    return llama_sampler_get_seed(gsmpl->chain);
+}
+
 // helpers
 
 llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl) {
index 0a4461fab98505104cceae5ffd47a5b5f49872ea..d0e1a9203e99aa132bb3834bef4fe8f905014563 100644 (file)
@@ -60,6 +60,8 @@ void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler *
 //
 llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
 
+uint32_t gpt_sampler_get_seed(const struct gpt_sampler * gsmpl);
+
 // helpers
 
 // access the internal list of current candidate tokens
index da7c7925362af97d90c626065106e02914754eed..db00c636330fc0a092df8bfebbb5a10ac5eeec1f 100644 (file)
@@ -90,8 +90,6 @@ int main(int argc, char ** argv) {
 
     print_build_info();
 
-    LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);
-
     llama_backend_init();
     llama_numa_init(params.numa);
 
index 9a527e24468427314d5502c655bd390a779b142a..7e252ce093d759cd9f1de3194f2698438773dd05 100644 (file)
@@ -159,8 +159,6 @@ int main(int argc, char ** argv) {
 
     print_build_info();
 
-    LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);
-
     LOG("%s: llama backend init\n", __func__);
     llama_backend_init();
     llama_numa_init(params.numa);
@@ -301,6 +299,9 @@ int main(int argc, char ** argv) {
             LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str());
         }
     }
+    smpl = gpt_sampler_init(model, sparams);
+
+    LOG_TEE("sampling seed: %u\n", gpt_sampler_get_seed(smpl));
     LOG_TEE("sampling: \n%s\n", sparams.print().c_str());
     LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
     LOG_TEE("\n\n");
@@ -340,8 +341,6 @@ int main(int argc, char ** argv) {
 
     std::vector<llama_token> embd;
 
-    smpl = gpt_sampler_init(model, sparams);
-
     while (n_remain != 0 || params.interactive) {
         // predict
         if (!embd.empty()) {
index b986a865a5f6e412d69ad29a60925261124b9493..f41be53082a45ccef88b358ec5615e22648fda21 100644 (file)
@@ -191,8 +191,6 @@ int main(int argc, char ** argv) {
 
     print_build_info();
 
-    LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);
-
     LOG("%s: llama backend init\n", __func__);
     llama_backend_init();
     llama_numa_init(params.numa);
@@ -470,8 +468,10 @@ int main(int argc, char ** argv) {
         exit(1);
     }
 
+    LOG_TEE("sampling seed: %u\n", gpt_sampler_get_seed(smpl));
     LOG_TEE("sampling params: \n%s\n", sparams.print().c_str());
-    LOG_TEE(" sampler constr: \n%s\n", gpt_sampler_print(smpl).c_str());
+    LOG_TEE("sampler constr: \n%s\n", gpt_sampler_print(smpl).c_str());
+
     LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
 
     // group-attention state
index c7d617988b2ed7df4de28aec362714fe3b898674..04df65b0a58922b2a585b2c8ff9d8fd8efe6e496 100644 (file)
@@ -2007,8 +2007,6 @@ int main(int argc, char ** argv) {
 
     print_build_info();
 
-    LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);
-
     llama_backend_init();
     llama_numa_init(params.numa);
 
index 7495821f99c323255f0f1cb67a39297f49bd0dbb..5b263f646979bba453406ea3c900bb9def30be42 100644 (file)
@@ -1266,6 +1266,7 @@ struct server_context {
             {"n_predict",                 slot.n_predict},     // Server configured n_predict
             {"model",                     params.model_alias},
             {"seed",                      slot.sparams.seed},
+            {"seed_cur",                  slot.smpl ? gpt_sampler_get_seed(slot.smpl) : 0},
             {"temperature",               slot.sparams.temp},
             {"dynatemp_range",            slot.sparams.dynatemp_range},
             {"dynatemp_exponent",         slot.sparams.dynatemp_exponent},
index 93b3e6e85c485de7eccc9f5c707d5c5c0e3a6a2c..405af912c46868be5fb41e2e21e9439ab5c052a5 100644 (file)
@@ -1127,6 +1127,10 @@ extern "C" {
                              int32_t   n_logit_bias,
               const llama_logit_bias * logit_bias);
 
+
+    // Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
+    LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
+
     /// @details Sample and accept a token from the idx-th output of the last evaluation
     //
     // Shorthand for:
index 6f448b80c44c1979809663ab9a3033be0bcf5935..fd1b7f9196f373af1f7a1a3897de3c707f28425d 100644 (file)
@@ -8,6 +8,7 @@
 #include <cstring>
 #include <ctime>
 #include <cfloat>
+#include <chrono>
 #include <cmath>
 #include <numeric>
 #include <random>
@@ -162,6 +163,19 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
     cur_p->size = k;
 }
 
+static uint32_t get_rng_seed(uint32_t seed) {
+    if (seed == LLAMA_DEFAULT_SEED) {
+        // use system clock if std::random_device is not a true RNG
+        static bool is_rd_prng = std::random_device().entropy() == 0;
+        if (is_rd_prng) {
+            return (uint32_t) std::chrono::system_clock::now().time_since_epoch().count();
+        }
+        std::random_device rd;
+        return rd();
+    }
+    return seed;
+}
+
 // llama_sampler API
 
 const char * llama_sampler_name(const struct llama_sampler * smpl) {
@@ -387,6 +401,7 @@ struct llama_sampler * llama_sampler_init_greedy() {
 
 struct llama_sampler_dist {
     const uint32_t seed;
+          uint32_t seed_cur;
 
     std::mt19937 rng;
 };
@@ -416,7 +431,8 @@ static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sample
 
 static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
     auto * ctx = (llama_sampler_dist *) smpl->ctx;
-    ctx->rng = std::mt19937(ctx->seed);
+    ctx->seed_cur = get_rng_seed(ctx->seed);
+    ctx->rng.seed(ctx->seed_cur);
 }
 
 static void llama_sampler_dist_free(struct llama_sampler * smpl) {
@@ -433,11 +449,13 @@ static struct llama_sampler_i llama_sampler_dist_i = {
 };
 
 struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
+    auto seed_cur = get_rng_seed(seed);
     return new llama_sampler {
         /* .iface = */ &llama_sampler_dist_i,
         /* .ctx   = */ new llama_sampler_dist {
-            /* .seed = */ seed,
-            /* .rng  = */ std::mt19937(seed),
+            /* .seed     = */ seed,
+            /* .seed_cur = */ seed_cur,
+            /* .rng      = */ std::mt19937(seed_cur),
         },
     };
 }
@@ -1032,6 +1050,7 @@ struct llama_sampler_mirostat {
     const int32_t n_vocab;
 
     const uint32_t seed;
+          uint32_t seed_cur;
 
     const float tau;
     const float eta;
@@ -1100,7 +1119,8 @@ static struct llama_sampler * llama_sampler_mirostat_clone(const struct llama_sa
 static void llama_sampler_mirostat_reset(struct llama_sampler * smpl) {
     auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
     ctx->mu = 2.0f*ctx->tau;
-    ctx->rng = std::mt19937(ctx->seed);
+    ctx->seed_cur = get_rng_seed(ctx->seed);
+    ctx->rng.seed(ctx->seed_cur);
 }
 
 static void llama_sampler_mirostat_free(struct llama_sampler * smpl) {
@@ -1117,16 +1137,18 @@ static struct llama_sampler_i llama_sampler_mirostat_i = {
 };
 
 struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
+    auto seed_cur = get_rng_seed(seed);
     return new llama_sampler {
         /* .iface = */ &llama_sampler_mirostat_i,
         /* .ctx   = */ new llama_sampler_mirostat {
-            /* .n_vocab = */ n_vocab,
-            /* .seed    = */ seed,
-            /* .tau     = */ tau,
-            /* .eta     = */ eta,
-            /* .m       = */ m,
-            /* .mu      = */ 2.0f*tau,
-            /* .rng     = */ std::mt19937(seed),
+            /* .n_vocab  = */ n_vocab,
+            /* .seed     = */ seed,
+            /* .seed_cur = */ seed_cur,
+            /* .tau      = */ tau,
+            /* .eta      = */ eta,
+            /* .m        = */ m,
+            /* .mu       = */ 2.0f*tau,
+            /* .rng      = */ std::mt19937(seed_cur),
         },
     };
 }
@@ -1135,6 +1157,7 @@ struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t see
 
 struct llama_sampler_mirostat_v2 {
     const uint32_t seed;
+          uint32_t seed_cur;
 
     const float tau;
     const float eta;
@@ -1179,7 +1202,8 @@ static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_t
 static void llama_sampler_mirostat_v2_reset(struct llama_sampler * smpl) {
     auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
     ctx->mu = 2.0f*ctx->tau;
-    ctx->rng = std::mt19937(ctx->seed);
+    ctx->seed_cur = get_rng_seed(ctx->seed);
+    ctx->rng.seed(ctx->seed_cur);
 }
 
 static struct llama_sampler * llama_sampler_mirostat_v2_clone(const struct llama_sampler * smpl) {
@@ -1212,14 +1236,16 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
 };
 
 struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
+    auto seed_cur = get_rng_seed(seed);
     return new llama_sampler {
         /* .iface = */ &llama_sampler_mirostat_v2_i,
         /* .ctx   = */ new llama_sampler_mirostat_v2 {
-            /* .seed  = */ seed,
-            /* .tau   = */ tau,
-            /* .eta   = */ eta,
-            /* .mu    = */ 2.0f*tau,
-            /* .rng   = */ std::mt19937(seed),
+            /* .seed     = */ seed,
+            /* .seed_cur = */ seed_cur,
+            /* .tau      = */ tau,
+            /* .eta      = */ eta,
+            /* .mu       = */ 2.0f*tau,
+            /* .rng      = */ std::mt19937(seed_cur),
         },
     };
 }
@@ -1505,6 +1531,8 @@ struct llama_sampler * llama_sampler_init_penalties(
         ignore_eos = false;
     }
 
+    penalty_last_n = std::max(penalty_last_n, 0);
+
     return new llama_sampler {
         /* .iface = */ &llama_sampler_penalties_i,
         /* .ctx   = */ new llama_sampler_penalties {
@@ -1568,6 +1596,7 @@ static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_to
         }
     }
 }
+
 static struct llama_sampler * llama_sampler_logit_bias_clone(const struct llama_sampler * smpl) {
     const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx;
     return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data());
@@ -1599,3 +1628,31 @@ struct llama_sampler * llama_sampler_init_logit_bias(
         },
     };
 }
+
+// utils
+
+uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
+    if (smpl->iface == &llama_sampler_dist_i) {
+        return ((const llama_sampler_dist *) smpl->ctx)->seed_cur;
+    }
+
+    if (smpl->iface == &llama_sampler_mirostat_i) {
+        return ((const llama_sampler_mirostat *) smpl->ctx)->seed_cur;
+    }
+
+    if (smpl->iface == &llama_sampler_mirostat_v2_i) {
+        return ((const llama_sampler_mirostat_v2 *) smpl->ctx)->seed_cur;
+    }
+
+    if (smpl->iface == &llama_sampler_chain_i) {
+        const auto * ctx = (const llama_sampler_chain *) smpl->ctx;
+        for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) {
+            const uint32_t seed = llama_sampler_get_seed(*it);
+            if (seed != LLAMA_DEFAULT_SEED) {
+                return seed;
+            }
+        }
+    }
+
+    return LLAMA_DEFAULT_SEED;
+}