]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
sampling : don't consider -infinity values in top_n_sigma (#13344)
authoroobabooga <redacted>
Tue, 6 May 2025 18:24:15 +0000 (15:24 -0300)
committerGitHub <redacted>
Tue, 6 May 2025 18:24:15 +0000 (20:24 +0200)
src/llama-sampling.cpp

index 0c9c6a3102929a72109172379ab86a4ddd58e56d..2869f60d204a148bcd47b622b1a18e00580831cc 100644 (file)
@@ -1757,20 +1757,28 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t
     // find max logit and calculate mean
     float max = cur_p->data[0].logit;
     float logits_sum = 0;
+    size_t valid_count = 0;
     for (size_t i = 0; i < cur_p->size; ++i) {
-        if (cur_p->data[i].logit > max) {
-            max = cur_p->data[i].logit;
+        // Only count non-negative infinity values
+        if (cur_p->data[i].logit != -INFINITY) {
+            if (cur_p->data[i].logit > max) {
+                max = cur_p->data[i].logit;
+            }
+            logits_sum += cur_p->data[i].logit;
+            valid_count++;
         }
-        logits_sum += cur_p->data[i].logit;
     }
-    float mean = logits_sum/cur_p->size;
+    float mean = valid_count > 0 ? logits_sum/valid_count : 0;
 
     // calculate standard deviation
     float acc = 0;
     for (size_t i = 0; i < cur_p->size; ++i) {
-        acc += pow(cur_p->data[i].logit - mean, 2);
+        // Skip -infinity in std calculation
+        if (cur_p->data[i].logit != -INFINITY) {
+            acc += pow(cur_p->data[i].logit - mean, 2);
+        }
     }
-    float std = sqrt(acc/cur_p->size);
+    float std = valid_count > 0 ? sqrt(acc/valid_count) : 0;
 
     //apply mask
     for (size_t i = 0; i < cur_p->size; ++i) {