]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Apply min_p to unsorted tokens (#5115)
authorJohannes Gäßler <redacted>
Sun, 28 Jan 2024 08:59:49 +0000 (09:59 +0100)
committerGitHub <redacted>
Sun, 28 Jan 2024 08:59:49 +0000 (09:59 +0100)
llama.cpp

index 391c956ec19a6870261513b80dd76ec227f04746..45eeadbe882d95ecd6490db459916774e626fcae 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -52,6 +52,7 @@
 #include <algorithm>
 #include <array>
 #include <cassert>
+#include <cfloat>
 #include <cinttypes>
 #include <climits>
 #include <cmath>
@@ -8246,21 +8247,56 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
         return;
     }
 
-    llama_sample_softmax(ctx, candidates);
-
     const int64_t t_start_sample_us = ggml_time_us();
 
-    float scale = candidates->data[0].p; // scale by max prob
-    size_t i = 1; // first token always matches
+    bool min_p_applied = false;
+
+    // if the candidates aren't sorted, try the unsorted implementation first
+    if (!candidates->sorted) {
+        std::vector<llama_token_data> filtered_tokens;
+
+        float max_logit = -FLT_MAX;
+        for (size_t i = 0; i < candidates->size; ++i) {
+            max_logit = std::max(max_logit, candidates->data[i].logit);
+        }
+        const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max
+
+        for (size_t i = 0; i < candidates->size; ++i) {
+            if (candidates->data[i].logit >= min_logit) {
+                filtered_tokens.push_back(candidates->data[i]);
+            }
+        }
 
-    for (; i < candidates->size; ++i) {
-        if (candidates->data[i].p < p * scale && i >= min_keep) {
-            break; // prob too small
+        // if we have enough values the operation was a success
+        if (filtered_tokens.size() >= min_keep) {
+            memcpy(candidates->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
+            candidates->size = filtered_tokens.size();
+            min_p_applied = true;
         }
     }
 
-    // Resize the output vector to keep only the matching tokens
-    candidates->size = i;
+    // if the candidates are sorted or the unsorted implementation failed, use this implementation
+    if (!min_p_applied) {
+        // Sort the logits in descending order
+        if (!candidates->sorted) {
+            std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
+                return a.logit > b.logit;
+            });
+            candidates->sorted = true;
+        }
+
+        const float min_logit = candidates->data[0].logit + logf(p); // min logit for p_i >= p * p_max
+        size_t i = 1; // first token always matches
+
+        for (; i < candidates->size; ++i) {
+            if (candidates->data[i].logit < min_logit && i >= min_keep) {
+                break; // prob too small
+            }
+        }
+
+        // Resize the output vector to keep only the matching tokens
+        candidates->size = i;
+    }
 
     if (ctx) {
         ctx->t_sample_us += ggml_time_us() - t_start_sample_us;