]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server: fix reported top tokens for temperature 0 (#7203)
authorJohannes Gäßler <redacted>
Sat, 11 May 2024 08:11:28 +0000 (10:11 +0200)
committerGitHub <redacted>
Sat, 11 May 2024 08:11:28 +0000 (10:11 +0200)
common/sampling.cpp
common/sampling.h
examples/server/server.cpp

index 3715a798531ae615e7c21a37a68ab243d42a6760..f0f1b92d37f596d8ab0c95b4a3927d1c2792cccf 100644 (file)
@@ -35,7 +35,7 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
 
     result->prev.resize(params.n_prev);
 
-    result->n_considered = 0;
+    result->n_valid = 0;
 
     llama_sampling_set_rng_seed(result, params.seed);
 
@@ -66,7 +66,7 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
 
     std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
     ctx->cur.clear();
-    ctx->n_considered = 0;
+    ctx->n_valid = 0;
 }
 
 void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
@@ -256,7 +256,7 @@ static llama_token llama_sampling_sample_impl(
         }
     }
 
-    ctx_sampling->n_considered = cur_p.size;
+    ctx_sampling->n_valid = temp == 0.0f ? 0 : cur_p.size;
 
     return id;
 }
index 5b73ecdcdb37b60ae9733c148a9ab964f52f8050..655732ad17206f9eb186a4c74c2e1da811583781 100644 (file)
@@ -81,7 +81,7 @@ struct llama_sampling_context {
     // TODO: replace with ring-buffer
     std::vector<llama_token>      prev;
     std::vector<llama_token_data> cur;
-    size_t n_considered;
+    size_t n_valid; // Number of correct top tokens with correct probabilities.
 
     std::mt19937 rng;
 };
index 305f79492a0552b63a777629abda5f7350f2239a..2bf4026d5084234c739adbcc4cc1f0651a72d40d 100644 (file)
@@ -2270,10 +2270,10 @@ struct server_context {
 
                 const size_t n_probs = std::min(cur_p.size, (size_t) slot.sparams.n_probs);
                 if (n_probs > 0) {
-                    const size_t n_considered = slot.ctx_sampling->n_considered;
+                    const size_t n_valid = slot.ctx_sampling->n_valid;
 
                     // Make sure at least n_probs top tokens are at the front of the vector:
-                    if (slot.sparams.temp == 0.0f && n_probs > n_considered) {
+                    if (slot.sparams.temp == 0.0f && n_probs > n_valid) {
                         llama_sample_top_k(ctx, &cur_p, n_probs, 0);
                     }
 
@@ -2289,7 +2289,7 @@ struct server_context {
                         for (size_t i = 0; i < n_probs; ++i) {
                             result.probs.push_back({
                                 cur_p.data[i].id,
-                                i >= n_considered ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability.
+                                i >= n_valid ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability.
                             });
                         }
                     }