]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Server: fix seed for multiple slots (#6835)
authorJohannes Gäßler <redacted>
Wed, 24 Apr 2024 09:08:36 +0000 (11:08 +0200)
committerGitHub <redacted>
Wed, 24 Apr 2024 09:08:36 +0000 (11:08 +0200)
* Server: add tests for consistent results

* sampling: separate rng per sampling context

common/common.cpp
common/sampling.cpp
common/sampling.h
examples/lookup/lookup-stats.cpp
examples/lookup/lookup.cpp
examples/main/main.cpp
examples/server/server.cpp
examples/server/tests/features/results.feature [new file with mode: 0644]
examples/server/tests/features/steps/steps.py
llama.cpp
llama.h

index 06f252ea6914b9779d032401c4c08e2e7b6316a9..a0d1f8d59cb62ff13692fbfbaab351d0c3684864 100644 (file)
@@ -242,7 +242,9 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
             invalid_param = true;
             return true;
         }
+        // This is temporary, in the future the samplign state will be moved fully to llama_sampling_context.
         params.seed = std::stoul(argv[i]);
+        sparams.seed = std::stoul(argv[i]);
         return true;
     }
     if (arg == "-t" || arg == "--threads") {
index 45d68b26c2b93f5e007617ad1499c305f04b2bf3..f2466550168a7781dbf0106fcfabef1c1e0757be 100644 (file)
@@ -1,4 +1,6 @@
+#define LLAMA_API_INTERNAL
 #include "sampling.h"
+#include <random>
 
 struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
     struct llama_sampling_context * result = new llama_sampling_context();
@@ -33,6 +35,8 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
 
     result->prev.resize(params.n_prev);
 
+    llama_sampling_set_rng_seed(result, params.seed);
+
     return result;
 }
 
@@ -62,6 +66,13 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
     ctx->cur.clear();
 }
 
+void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
+    if (seed == LLAMA_DEFAULT_SEED) {
+        seed = time(NULL);
+    }
+    ctx->rng.seed(seed);
+}
+
 void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
     if (dst->grammar) {
         llama_grammar_free(dst->grammar);
@@ -203,7 +214,7 @@ static llama_token llama_sampling_sample_impl(
 
             sampler_queue(ctx_main, params, cur_p, min_keep);
 
-            id = llama_sample_token(ctx_main, &cur_p);
+            id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng);
 
             //{
             //    const int n_top = 10;
index 639b819ab4fb2c105c4214f17c59c5214fbcc2ef..cf7081e3674f10dfdcdfa51ac2850ce177cbba8f 100644 (file)
@@ -4,9 +4,10 @@
 
 #include "grammar-parser.h"
 
+#include <random>
 #include <string>
-#include <vector>
 #include <unordered_map>
+#include <vector>
 
 // sampler types
 enum class llama_sampler_type : char {
@@ -20,25 +21,26 @@ enum class llama_sampler_type : char {
 
 // sampling parameters
 typedef struct llama_sampling_params {
-    int32_t     n_prev                = 64;       // number of previous tokens to remember
-    int32_t     n_probs               = 0;        // if greater than 0, output the probabilities of top n_probs tokens.
-    int32_t     min_keep              = 0;        // 0 = disabled, otherwise samplers should return at least min_keep tokens
-    int32_t     top_k                 = 40;       // <= 0 to use vocab size
-    float       top_p                 = 0.95f;    // 1.0 = disabled
-    float       min_p                 = 0.05f;    // 0.0 = disabled
-    float       tfs_z                 = 1.00f;    // 1.0 = disabled
-    float       typical_p             = 1.00f;    // 1.0 = disabled
-    float       temp                  = 0.80f;    // <= 0.0 to sample greedily, 0.0 to not output probabilities
-    float       dynatemp_range        = 0.00f;    // 0.0 = disabled
-    float       dynatemp_exponent     = 1.00f;    // controls how entropy maps to temperature in dynamic temperature sampler
-    int32_t     penalty_last_n        = 64;       // last n tokens to penalize (0 = disable penalty, -1 = context size)
-    float       penalty_repeat        = 1.00f;    // 1.0 = disabled
-    float       penalty_freq          = 0.00f;    // 0.0 = disabled
-    float       penalty_present       = 0.00f;    // 0.0 = disabled
-    int32_t     mirostat              = 0;        // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
-    float       mirostat_tau          = 5.00f;    // target entropy
-    float       mirostat_eta          = 0.10f;    // learning rate
-    bool        penalize_nl           = false;     // consider newlines as a repeatable token
+    int32_t     n_prev                = 64;                 // number of previous tokens to remember
+    int32_t     n_probs               = 0;                  // if greater than 0, output the probabilities of top n_probs tokens.
+    int32_t     min_keep              = 0;                  // 0 = disabled, otherwise samplers should return at least min_keep tokens
+    int32_t     top_k                 = 40;                 // <= 0 to use vocab size
+    float       top_p                 = 0.95f;              // 1.0 = disabled
+    float       min_p                 = 0.05f;              // 0.0 = disabled
+    float       tfs_z                 = 1.00f;              // 1.0 = disabled
+    float       typical_p             = 1.00f;              // 1.0 = disabled
+    float       temp                  = 0.80f;              // <= 0.0 to sample greedily, 0.0 to not output probabilities
+    float       dynatemp_range        = 0.00f;              // 0.0 = disabled
+    float       dynatemp_exponent     = 1.00f;              // controls how entropy maps to temperature in dynamic temperature sampler
+    int32_t     penalty_last_n        = 64;                 // last n tokens to penalize (0 = disable penalty, -1 = context size)
+    float       penalty_repeat        = 1.00f;              // 1.0 = disabled
+    float       penalty_freq          = 0.00f;              // 0.0 = disabled
+    float       penalty_present       = 0.00f;              // 0.0 = disabled
+    int32_t     mirostat              = 0;                  // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+    float       mirostat_tau          = 5.00f;              // target entropy
+    float       mirostat_eta          = 0.10f;              // learning rate
+    bool        penalize_nl           = false;              // consider newlines as a repeatable token
+    uint32_t    seed                  = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
 
     std::vector<llama_sampler_type> samplers_sequence = {
         llama_sampler_type::TOP_K,
@@ -79,6 +81,8 @@ struct llama_sampling_context {
     // TODO: replace with ring-buffer
     std::vector<llama_token>      prev;
     std::vector<llama_token_data> cur;
+
+    std::mt19937 rng;
 };
 
 #include "common.h"
@@ -93,6 +97,9 @@ void llama_sampling_free(struct llama_sampling_context * ctx);
 // - reset grammar
 void llama_sampling_reset(llama_sampling_context * ctx);
 
+// Set the sampler seed
+void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed);
+
 // Copy the sampler context
 void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);
 
index 41b62c2fe9f76b4966717ae5fe665aa9571cbc83..87ecc0a4f1394e615971cb9784edd005b3f726dc 100644 (file)
@@ -30,7 +30,6 @@ int main(int argc, char ** argv){
 
     // load the model
     std::tie(model, ctx) = llama_init_from_gpt_params(params);
-    llama_set_rng_seed(ctx, params.seed);
     GGML_ASSERT(llama_n_vocab(model) < (1 << 16));
 
     // tokenize the prompt
index 9526e898fe7638218ea7577203bf65dec3133b43..eebbd00a58e66cb2ffa023c2e1692cf1e05e5843 100644 (file)
@@ -38,7 +38,6 @@ int main(int argc, char ** argv){
 
     // load the model
     std::tie(model, ctx) = llama_init_from_gpt_params(params);
-    llama_set_rng_seed(ctx, params.seed);
     GGML_ASSERT(llama_n_vocab(model) < (1 << 16));
 
     // tokenize the prompt
index 1180734b9760d2cf4021fdba8fbc98bcf8a63ca3..a74d4d9c72364a752a8a6b99d87b8847fdb10d53 100644 (file)
@@ -240,7 +240,6 @@ int main(int argc, char ** argv) {
                 return 1;
             }
             session_tokens.resize(n_token_count_out);
-            llama_set_rng_seed(ctx, params.seed);
             LOG_TEE("%s: loaded a session with prompt size of %d tokens\n", __func__, (int)session_tokens.size());
         }
     }
index 25bc29639677251df04869c72c1c27661e565c26..68c63f9f1b1a04c055580525c8392109db19c710 100644 (file)
@@ -854,7 +854,7 @@ struct server_context {
         slot.sparams.penalize_nl       = json_value(data, "penalize_nl",       default_sparams.penalize_nl);
         slot.params.n_keep             = json_value(data, "n_keep",            slot.params.n_keep);
         slot.params.n_discard          = json_value(data, "n_discard",         default_params.n_discard);
-        slot.params.seed               = json_value(data, "seed",              default_params.seed);
+        slot.sparams.seed              = json_value(data, "seed",              default_sparams.seed);
         slot.sparams.n_probs           = json_value(data, "n_probs",           default_sparams.n_probs);
         slot.sparams.min_keep          = json_value(data, "min_keep",          default_sparams.min_keep);
 
@@ -1028,7 +1028,6 @@ struct server_context {
                 send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
                 return false;
             }
-            llama_set_rng_seed(ctx, slot.params.seed);
         }
 
         slot.command = SLOT_COMMAND_LOAD_PROMPT;
diff --git a/examples/server/tests/features/results.feature b/examples/server/tests/features/results.feature
new file mode 100644 (file)
index 0000000..f17120f
--- /dev/null
@@ -0,0 +1,57 @@
+@llama.cpp
+@results
+Feature: Results
+
+  Background: Server startup
+    Given a server listening on localhost:8080
+    And   a model file tinyllamas/split/stories15M-00001-of-00003.gguf from HF repo ggml-org/models
+    And   a model file test-model-00001-of-00003.gguf
+    And   128 as batch size
+    And   256 KV cache size
+    And   128 max tokens to predict
+
+  Scenario Outline: Multi users completion
+    Given <n_slots> slots
+    And   continuous batching
+    Then  the server is starting
+    Then  the server is healthy
+
+    Given 42 as seed
+    And a prompt:
+      """
+      Write a very long story about AI.
+      """
+
+    Given 42 as seed
+    And a prompt:
+      """
+      Write a very long story about AI.
+      """
+
+    Given 42 as seed
+    And a prompt:
+      """
+      Write a very long story about AI.
+      """
+
+    Given 42 as seed
+    And a prompt:
+      """
+      Write a very long story about AI.
+      """
+
+    Given 42 as seed
+    And a prompt:
+      """
+      Write a very long story about AI.
+      """
+
+    Given concurrent completion requests
+    Then the server is busy
+    Then the server is idle
+    And  all slots are idle
+    Then all predictions are equal
+    Examples:
+      | n_slots |
+      | 1       |
+      | 2       |
index ca400efa41b9e58bf033b77a26255854208f8499..f71e0d706cca9d8a5c990908b923ffdb1d1e96aa 100644 (file)
@@ -61,6 +61,7 @@ def step_server_config(context, server_fqdn, server_port):
     context.server_metrics = False
     context.server_process = None
     context.seed = None
+    context.draft = None
     context.server_seed = None
     context.user_api_key = None
     context.response_format = None
@@ -107,6 +108,11 @@ def step_n_gpu_layer(context, ngl):
     context.n_gpu_layer = ngl
 
 
+@step('{draft:d} as draft')
+def step_draft(context, draft):
+    context.draft = draft
+
+
 @step('{n_ctx:d} KV cache size')
 def step_n_ctx(context, n_ctx):
     context.n_ctx = n_ctx
@@ -254,6 +260,15 @@ def step_n_tokens_predicted(context, predicted_n):
     assert_n_tokens_predicted(context.completion, predicted_n)
 
 
+@step('all predictions are equal')
+@async_run_until_complete
+async def step_predictions_equal(context):
+    n_completions = await gather_tasks_results(context)
+    assert n_completions >= 2, "need at least 2 completions"
+    assert_all_predictions_equal(context.tasks_result)
+    context.tasks_result = []
+
+
 @step('the completion is  truncated')
 def step_assert_completion_truncated(context):
     step_assert_completion_truncated(context, '')
@@ -1020,6 +1035,23 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
         assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:'
                                                      f' {n_predicted} <> {expected_predicted_n}')
 
+def assert_all_predictions_equal(completion_responses):
+    content_0 = completion_responses[0]['content']
+
+    if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
+        print(f"content 0: {content_0}")
+
+    i = 1
+    for response in completion_responses[1:]:
+        content = response['content']
+
+        if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
+            print(f"content {i}: {content}")
+
+        assert content == content_0, "contents not equal"
+
+        i += 1
+
 
 async def gather_tasks_results(context):
     n_tasks = len(context.concurrent_tasks)
@@ -1148,6 +1180,8 @@ def start_server_background(context):
         server_args.extend(['--ubatch-size', context.n_ubatch])
     if context.n_gpu_layer:
         server_args.extend(['--n-gpu-layers', context.n_gpu_layer])
+    if context.draft is not None:
+        server_args.extend(['--draft', context.draft])
     if context.server_continuous_batching:
         server_args.append('--cont-batching')
     if context.server_embeddings:
index e4ca34bd13389cca55b4db9893f3a25e4406565e..3a4a03d8f29fb0be9b832c9f03442a089f939564 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -13667,7 +13667,7 @@ llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_da
     return result;
 }
 
-llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
+llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng) {
     GGML_ASSERT(ctx);
 
     const int64_t t_start_sample_us = ggml_time_us();
@@ -13680,7 +13680,6 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
     }
 
     std::discrete_distribution<> dist(probs.begin(), probs.end());
-    auto & rng = ctx->rng;
     int idx = dist(rng);
 
     llama_token result = candidates->data[idx].id;
@@ -13690,6 +13689,10 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
     return result;
 }
 
+llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
+    return llama_sample_token_with_rng(ctx, candidates, ctx->rng);
+}
+
 void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) {
     const int64_t t_start_sample_us = ggml_time_us();
 
diff --git a/llama.h b/llama.h
index 4effca42cc65de9176110abc04428d28a6da68da..7bfd13740cf25f157e3976e3747e69cb2e150c6c 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -987,7 +987,7 @@ extern "C" {
             struct llama_context * ctx,
           llama_token_data_array * candidates);
 
-    /// @details Randomly selects a token from the candidates based on their probabilities.
+    /// @details Randomly selects a token from the candidates based on their probabilities using the RNG of ctx.
     LLAMA_API llama_token llama_sample_token(
             struct llama_context * ctx,
           llama_token_data_array * candidates);
@@ -1074,8 +1074,9 @@ extern "C" {
 // Internal API to be implemented by llama.cpp and used by tests/benchmarks only
 #ifdef LLAMA_API_INTERNAL
 
-#include <vector>
+#include <random>
 #include <string>
+#include <vector>
 
 struct ggml_tensor;
 
@@ -1112,6 +1113,10 @@ std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
         const std::string & src,
         llama_partial_utf8   partial_start);
 
+// Randomly selects a token from the candidates based on their probabilities using given std::mt19937.
+// This is a temporary workaround in order to fix race conditions when sampling with multiple sequences.
+llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng);
+
 #endif // LLAMA_API_INTERNAL
 
 #endif // LLAMA_H