]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : fix top-p sampling to match the canonical definition (#1953)
authorAlex Renda <redacted>
Sat, 24 Jun 2023 10:15:01 +0000 (03:15 -0700)
committerGitHub <redacted>
Sat, 24 Jun 2023 10:15:01 +0000 (13:15 +0300)
* Fix top-p sampling to match the standard definition (smallest set that has probability mass at least p, not largest set with probability mass less than p)

* top-p: correct gt to gte

* add test for correct top-p behavior

llama.cpp
tests/test-sampling.cpp

index a528eef4a902036552a1ac247cd9b2b1f9011f2a..ac22a48f8ab971781b2bf05ce9d9616634a4160b 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -2015,9 +2015,10 @@ void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * can
     for (size_t i = 0; i < candidates->size; ++i) {
         cum_sum += candidates->data[i].p;
 
-        // Check if the running sum is greater than p or if we have kept at least min_keep tokens
-        if (cum_sum > p && i >= min_keep) {
-            last_idx = i;
+        // Check if the running sum is at least p or if we have kept at least min_keep tokens
+        // we set the last index to i+1 to indicate that the current iterate should be included in the set
+        if (cum_sum >= p && i + 1 >= min_keep) {
+            last_idx = i + 1;
             break;
         }
     }
index 5d693f7b561a6771bd4eb26c60613de1f9d27ad5..64f9455d72e5437c5bea9c91cd73cbbd89cb818b 100644 (file)
@@ -181,6 +181,7 @@ int main(void) {
 
     test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 0);
     test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f}, 0.7f);
+    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 0.8f);
     test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1);
 
     test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f);