]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
sampling: fix top_k <= 0 (#5388)
authorJohannes Gäßler <redacted>
Thu, 8 Feb 2024 08:46:30 +0000 (09:46 +0100)
committerGitHub <redacted>
Thu, 8 Feb 2024 08:46:30 +0000 (09:46 +0100)
* sampling: fix top_k <= 0

* Update llama.cpp

Co-authored-by: Georgi Gerganov <redacted>
---------

Co-authored-by: Georgi Gerganov <redacted>
common/sampling.cpp
llama.cpp
tests/test-sampling.cpp

index e8675a8c0c18902da77e8966b4ca81d6cea276b8..844ad7c53605deea9501689a1774b31f79571ba4 100644 (file)
@@ -132,7 +132,7 @@ static void sampler_queue(
     const float         temp              = params.temp;
     const float         dynatemp_range    = params.dynatemp_range;
     const float         dynatemp_exponent = params.dynatemp_exponent;
-    const int32_t       top_k             = params.top_k <= 0 ? n_vocab : params.top_k;
+    const int32_t       top_k             = params.top_k;
     const float         top_p             = params.top_p;
     const float         min_p             = params.min_p;
     const float         tfs_z             = params.tfs_z;
index c45ae1d5088bffd173d6965af5a67cf9dea3964e..f8f5796a4381474328bf9d5ceb453bacba81af57 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -8585,6 +8585,10 @@ void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * can
     // }
 
     const int64_t t_start_sample_us = ggml_time_us();
+    
+    if (k <= 0) {
+        k = candidates->size;
+    }
 
     k = std::max(k, (int) min_keep);
     k = std::min(k, (int) candidates->size);
index c3b3d6629d4ba2366049152f8ce7dc48aa5ad8cc..6374958fee8e6ada12303dbec4e1eaea7a251911 100644 (file)
@@ -235,6 +235,8 @@ int main(void) {
 
     test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 1);
     test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 3);
+    test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);
+    test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0);
 
     test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 0);
     test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f}, 0.7f);