]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : always sort logits before nucleus sampling (#812)
authorIvan Stepanov <redacted>
Fri, 7 Apr 2023 16:02:12 +0000 (19:02 +0300)
committerGitHub <redacted>
Fri, 7 Apr 2023 16:02:12 +0000 (19:02 +0300)
* Always sort logits before nucleus sampling

* remove second normalization

- fix windows build
- remove normalization since std::discrete_distribution does not require it

llama.cpp

index 581a8399d0229814aab9866b19d1f568dff7794e..978327a5b50d1532bfe4d34d5887562ff61242e6 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -1236,19 +1236,13 @@ static llama_vocab::id llama_sample_top_p_top_k(
         }
     }
 
-    if (top_k > 0 && top_k < n_logits) {
-        sample_top_k(logits_id, top_k);
-    }
-
-    float maxl = -std::numeric_limits<float>::infinity();
-    for (const auto & kv : logits_id) {
-        maxl = Max(maxl, kv.first);
-    }
+    sample_top_k(logits_id, top_k > 0 ? Min(top_k, n_logits) : n_logits);
 
     // compute probs for the top k tokens
     std::vector<float> probs;
     probs.reserve(logits_id.size());
 
+    float maxl = logits_id[0].first;
     double sum = 0.0;
     for (const auto & kv : logits_id) {
         const float p = expf(kv.first - maxl);
@@ -1271,16 +1265,11 @@ static llama_vocab::id llama_sample_top_p_top_k(
                 break;
             }
         }
-
-        cumsum = 1.0/cumsum;
-        for (int i = 0; i < (int) probs.size(); i++) {
-            probs[i] *= cumsum;
-        }
     }
 
     //printf("\n");
     //for (int i = 0; i < (int) 10; i++) {
-    //    printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
+    //    printf("%d: '%s' %f\n", i, lctx.vocab.id_to_token.at(logits_id[i].second).tok.c_str(), probs[i]);
     //}
     //printf("\n\n");
     //exit(0);