]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
hellaswag: display estimated score confidence interval (#12797)
authorstduhpf <redacted>
Mon, 7 Apr 2025 15:47:08 +0000 (17:47 +0200)
committerGitHub <redacted>
Mon, 7 Apr 2025 15:47:08 +0000 (18:47 +0300)
examples/perplexity/perplexity.cpp

index 8c413f7d66e6d1c3d7c224fb1c17ecc88646aebb..175f2804b5da007afeee7cf4381e0e3d7de48588 100644 (file)
@@ -851,7 +851,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
 
     LOG_INF("%s : calculating hellaswag score over selected tasks.\n", __func__);
 
-    LOG("\ntask\tacc_norm\n");
+    LOG("\ntask\tacc_norm\t95%% confidence interval\n");
 
     double acc = 0.0f;
 
@@ -985,8 +985,22 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
                 acc += 1.0;
             }
 
-            // Print the accumulated accuracy mean x 100
-            LOG("%zu\t%.8lf\n", i + 1, acc/double(i + 1)*100.0);
+            double freq = acc / double(i + 1);
+
+            const double za = 1.95996398454;
+
+            // // Wald normal approx
+            // double conf =za*sqrt(freq*(1-freq)/double(i + 1));
+            // LOG("%zu\t%.8lf +/- %.8lf\n", i + 1, freq*100.0, conf*100.0);
+
+            // Wilson score interval, more accurate
+            double z   = za * za / double(i + 1);
+            double cnf = z * sqrt(double(i + 1) * (4.0 * freq * (1 - freq) + z)) / (za + za);
+            double a   = (freq + z * 0.5 - cnf) / (1.0 + z);
+            double b   = (freq + z * 0.5 + cnf) / (1.0 + z);
+
+            // Print the accumulated accuracy mean x 100 and confidence interval
+            LOG("%zu\t%3.8lf%%\t[%3.4lf%%, %3.4lf%%]\n", i + 1, freq * 100.0, a * 100.0, b * 100.0);
         }
 
         i0 = i1 - 1;