LOG_INF("%s : calculating hellaswag score over selected tasks.\n", __func__);
- LOG("\ntask\tacc_norm\n");
+ LOG("\ntask\tacc_norm\t95%% confidence interval\n");
double acc = 0.0f;
acc += 1.0;
}
- // Print the accumulated accuracy mean x 100
- LOG("%zu\t%.8lf\n", i + 1, acc/double(i + 1)*100.0);
+ double freq = acc / double(i + 1);
+
+ const double za = 1.95996398454;
+
+ // // Wald normal approx
+ // double conf =za*sqrt(freq*(1-freq)/double(i + 1));
+ // LOG("%zu\t%.8lf +/- %.8lf\n", i + 1, freq*100.0, conf*100.0);
+
+ // Wilson score interval, more accurate
+ double z = za * za / double(i + 1);
+ double cnf = z * sqrt(double(i + 1) * (4.0 * freq * (1 - freq) + z)) / (za + za);
+ double a = (freq + z * 0.5 - cnf) / (1.0 + z);
+ double b = (freq + z * 0.5 + cnf) / (1.0 + z);
+
+ // Print the accumulated accuracy mean x 100 and confidence interval
+ LOG("%zu\t%3.8lf%%\t[%3.4lf%%, %3.4lf%%]\n", i + 1, freq * 100.0, a * 100.0, b * 100.0);
}
i0 = i1 - 1;