]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server: fix incorrectly reported token probabilities (#7125)
authorJohannes Gäßler <redacted>
Tue, 7 May 2024 21:07:58 +0000 (23:07 +0200)
committerGitHub <redacted>
Tue, 7 May 2024 21:07:58 +0000 (23:07 +0200)
* server: normalize token probabilities

* fix temperature == 0.0f

common/sampling.cpp
common/sampling.h
examples/server/README.md
examples/server/server.cpp

index cc83600d9926ee68007e3d28fb92766150a4c143..3715a798531ae615e7c21a37a68ab243d42a6760 100644 (file)
@@ -35,6 +35,8 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
 
     result->prev.resize(params.n_prev);
 
+    result->n_considered = 0;
+
     llama_sampling_set_rng_seed(result, params.seed);
 
     return result;
@@ -64,6 +66,7 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
 
     std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
     ctx->cur.clear();
+    ctx->n_considered = 0;
 }
 
 void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
@@ -253,6 +256,8 @@ static llama_token llama_sampling_sample_impl(
         }
     }
 
+    ctx_sampling->n_considered = cur_p.size;
+
     return id;
 }
 
index cf7081e3674f10dfdcdfa51ac2850ce177cbba8f..5b73ecdcdb37b60ae9733c148a9ab964f52f8050 100644 (file)
@@ -81,6 +81,7 @@ struct llama_sampling_context {
     // TODO: replace with ring-buffer
     std::vector<llama_token>      prev;
     std::vector<llama_token_data> cur;
+    size_t n_considered;
 
     std::mt19937 rng;
 };
index bf3713640c2bacb3dab7bc1865d88632b6923f8f..a7c3f0b5fc129e1aa9eaa63fdf582c4ff656aa8d 100644 (file)
@@ -272,7 +272,7 @@ node index.js
 
     `logit_bias`: Modify the likelihood of a token appearing in the generated text completion. For example, use `"logit_bias": [[15043,1.0]]` to increase the likelihood of the token 'Hello', or `"logit_bias": [[15043,-1.0]]` to decrease its likelihood. Setting the value to false, `"logit_bias": [[15043,false]]` ensures that the token `Hello` is never produced. The tokens can also be represented as strings, e.g. `[["Hello, World!",-0.5]]` will reduce the likelihood of all the individual tokens that represent the string `Hello, World!`, just like the `presence_penalty` does. Default: `[]`
 
-    `n_probs`: If greater than 0, the response also contains the probabilities of top N tokens for each generated token. Default: `0`
+    `n_probs`: If greater than 0, the response also contains the probabilities of top N tokens for each generated token given the sampling settings. Note that for temperature < 0 the tokens are sampled greedily but token probabilities are still being calculated via a simple softmax of the logits without considering any other sampler settings. Default: `0`
 
     `min_keep`: If greater than 0, force samplers to return N possible tokens at minimum. Default: `0`
 
index ff0814b2f28bfb2d58eab6a8644d4e41898f3cd1..85ae1ad9617ebd6bc9cf2828a40288a3aa3648f9 100644 (file)
@@ -2266,17 +2266,31 @@ struct server_context {
                 llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
                 result.tok = id;
 
-                const int32_t n_probs = slot.sparams.n_probs;
-                if (slot.sparams.temp <= 0 && n_probs > 0) {
-                    // for llama_sample_token_greedy we need to sort candidates
-                    llama_sample_softmax(ctx, &cur_p);
-                }
+                const size_t n_probs = std::min(cur_p.size, (size_t) slot.sparams.n_probs);
+                if (n_probs > 0) {
+                    const size_t n_considered = slot.ctx_sampling->n_considered;
 
-                for (size_t i = 0; i < std::min(cur_p.size, (size_t) n_probs); ++i) {
-                    result.probs.push_back({
-                        cur_p.data[i].id,
-                        cur_p.data[i].p
-                    });
+                    // Make sure at least n_probs top tokens are at the front of the vector:
+                    if (slot.sparams.temp == 0.0f && n_probs > n_considered) {
+                        llama_sample_top_k(ctx, &cur_p, n_probs, 0);
+                    }
+
+                    if (slot.sparams.temp == 0.0f) {
+                        // With greedy sampling the probabilities have possibly not been calculated.
+                        for (size_t i = 0; i < n_probs; ++i) {
+                            result.probs.push_back({
+                                cur_p.data[i].id,
+                                i == 0 ? 1.0f : 0.0f
+                            });
+                        }
+                    } else {
+                        for (size_t i = 0; i < n_probs; ++i) {
+                            result.probs.push_back({
+                                cur_p.data[i].id,
+                                i >= n_considered ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability.
+                            });
+                        }
+                    }
                 }
 
                 if (!process_token(result, slot)) {