]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : add option for greedy sampling with probs (#3813)
authorGeorgi Gerganov <redacted>
Sat, 28 Oct 2023 11:23:11 +0000 (14:23 +0300)
committerGitHub <redacted>
Sat, 28 Oct 2023 11:23:11 +0000 (14:23 +0300)
* llama : add option for greedy sampling with probs

* llama : add comment about llama_sample_token_greedy() missing probs

* sampling : temp == 0.0 -> no probs, temp < 0.0 -> probs

common/common.cpp
common/sampling.cpp
examples/speculative/speculative.cpp
llama.h

index c0d4924e2d4a5331107c7d43783ccdc8ec264f51..f81f4d354bc01755c0de2af0097e28888ae0a3c1 100644 (file)
@@ -224,6 +224,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             sparams.temp = std::stof(argv[i]);
+            sparams.temp = std::max(sparams.temp, 0.0f);
         } else if (arg == "--tfs") {
             if (++i >= argc) {
                 invalid_param = true;
index 5258d4e8263693811b0f56c9165ea0a2fefc0018..c4996c9857d8ac72f103a9d73103205d7101d6e2 100644 (file)
@@ -167,8 +167,12 @@ llama_token llama_sampling_sample(
         llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
     }
 
-    if (temp <= 0) {
-        // greedy sampling
+    if (temp < 0.0) {
+        // greedy sampling, with probs
+        llama_sample_softmax(ctx_main, &cur_p);
+        id = cur_p.data[0].id;
+    } else if (temp == 0.0) {
+        // greedy sampling, no probs
         id = llama_sample_token_greedy(ctx_main, &cur_p);
     } else {
         if (mirostat == 1) {
index f921b78455a72c4be7cc1ad86da78217f8132447..323c74652c9a6900abf26719ce320935bac00590 100644 (file)
@@ -148,7 +148,7 @@ int main(int argc, char ** argv) {
     std::vector<seq_draft> drafts(n_seq_dft);
 
     params.sparams.grammar.clear(); // the draft samplers will copy the target sampler's grammar
-    params.sparams.temp = std::max(0.01f, params.sparams.temp);
+    params.sparams.temp = -1.0f;    // force greedy sampling with probs for the draft model
 
     for (int s = 0; s < n_seq_dft; ++s) {
         drafts[s].ctx_sampling = llama_sampling_init(params.sparams);
diff --git a/llama.h b/llama.h
index beac9a0cedd76c1f2c3d810f41446ee13c766d71..d901dcd9116d3de7b3017f15fe1c9ea1553fc28f 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -658,6 +658,7 @@ extern "C" {
                            float * mu);
 
     /// @details Selects the token with the highest probability.
+    ///          Does not compute the token probabilities. Use llama_sample_softmax() instead.
     LLAMA_API llama_token llama_sample_token_greedy(
             struct llama_context * ctx,
           llama_token_data_array * candidates);