]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
sampling : make sure samplers return at least 1 token (#13822)
authorGeorgi Gerganov <redacted>
Tue, 27 May 2025 09:07:52 +0000 (12:07 +0300)
committerGitHub <redacted>
Tue, 27 May 2025 09:07:52 +0000 (12:07 +0300)
* sampling : min-p should always return at least one token

ggml-ci

* sampling : same for typical sampling

* tests : sampling tests use min_keep == 0

ggml-ci

src/llama-sampling.cpp
tests/test-sampling.cpp

index 804b11e0a943e9625c78516c5da629ec91261968..bfbf5fa23011240c0dec57b390670ef1ff47079b 100644 (file)
@@ -798,7 +798,7 @@ static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_d
         }
 
         // if we have enough values the operation was a success
-        if (filtered_tokens.size() >= ctx->min_keep) {
+        if (!filtered_tokens.empty() && filtered_tokens.size() >= ctx->min_keep) {
             memcpy(cur_p->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
             cur_p->size = filtered_tokens.size();
             min_p_applied = true;
@@ -909,7 +909,7 @@ static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token
         cum_sum += cur_p->data[idx].p;
 
         // Check if the running sum is greater than typical or if we have kept at least min_keep tokens
-        if (cum_sum > ctx->p && i >= ctx->min_keep - 1) {
+        if (cum_sum > ctx->p && (ctx->min_keep == 0 || i >= ctx->min_keep - 1)) {
             last_idx = i + 1;
             break;
         }
index 60ac62b385f352876295b23bbd257ce216b9a9ee..6300f25caebe3090fb2a157f44ea5d49a4c3941e 100644 (file)
@@ -98,7 +98,7 @@ static void test_top_p(const std::vector<float> & probs, const std::vector<float
     sampler_tester tester(probs, probs_expected);
 
     DUMP(&tester.cur_p);
-    tester.apply(llama_sampler_init_top_p(p, 1));
+    tester.apply(llama_sampler_init_top_p(p, 0));
     tester.apply(llama_sampler_init_dist (0));
     DUMP(&tester.cur_p);
 
@@ -109,7 +109,7 @@ static void test_min_p(const std::vector<float> & probs, const std::vector<float
     sampler_tester tester(probs, probs_expected);
 
     DUMP(&tester.cur_p);
-    tester.apply(llama_sampler_init_min_p(p, 1));
+    tester.apply(llama_sampler_init_min_p(p, 0));
     tester.apply(llama_sampler_init_dist (0));
     DUMP(&tester.cur_p);
 
@@ -130,7 +130,7 @@ static void test_typical(const std::vector<float> & probs, const std::vector<flo
     sampler_tester tester(probs, probs_expected);
 
     DUMP(&tester.cur_p);
-    tester.apply(llama_sampler_init_typical(p, 1));
+    tester.apply(llama_sampler_init_typical(p, 0));
     DUMP(&tester.cur_p);
 
     tester.check();
@@ -332,6 +332,7 @@ int main(void) {
     test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.7f, 0.3f/0.7f},                       0.74f);
     test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f},                                  0.76f);
     test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f},                                  1.00f);
+    test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f},                                  1.05f);
 
     printf("XTC should:\n");
     test_xtc({0.4f, 0.3f, 0.2f, 0.1f},   {0.1f},                                0.99f, 0.09f);
@@ -341,8 +342,8 @@ int main(void) {
     printf("XTC should not:\n");
     test_xtc({0.4f, 0.3f, 0.2f, 0.1f},   {0.4f, 0.3f, 0.2f, 0.1f},              0.99f, 0.39f);
 
-    test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f);
-    test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f);
+    test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f},            0.5f);
+    test_typical({0.4f, 0.2f, 0.2f, 0.2f},     {0.2f, 0.2f, 0.2f}, 0.5f);
 
     test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0},   50.0f, 0.0f, 0.0f);
     test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0},       50.0f, 0.0f, 0.0f);