]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Additional KL-divergence statistics (#5081)
authorKawrakow <redacted>
Tue, 23 Jan 2024 13:17:20 +0000 (15:17 +0200)
committerGitHub <redacted>
Tue, 23 Jan 2024 13:17:20 +0000 (15:17 +0200)
* perplexity: add top-token probability

* perplexity: add additional KL-divergence statistics

* perplexity: a better organized KL-divergence statistics output

---------

Co-authored-by: Iwan Kawrakow <redacted>
examples/perplexity/perplexity.cpp

index de6d3eb4137b901f9ddf43e8d45214ae4909d051..8d2204969c0cbb3d8690c8b1b98dbec678e13d0d 100644 (file)
@@ -222,13 +222,18 @@ struct kl_divergence_result {
     double sum_kld2 = 0;
     double sum_nll_diff  = 0;
     double sum_nll_diff2 = 0;
+    size_t n_same_top = 0;
     size_t count = 0;
 };
 
-static void log_softmax(int n_vocab, const float * logits, const uint16_t * base_log_prob, int tok, kl_divergence_result & kld) {
+static double log_softmax(int n_vocab, const float * logits, const uint16_t * base_log_prob, int tok, kl_divergence_result & kld) {
     float max_logit = logits[0];
+    int imax = 0;
     for (int i = 1; i < n_vocab; ++i) {
-        max_logit = std::max(max_logit, logits[i]);
+        if (logits[i] > max_logit) {
+            max_logit = logits[i];
+            imax = i;
+        }
     }
     double sum_exp = 0.0;
     for (int i = 0; i < n_vocab; ++i) {
@@ -247,8 +252,14 @@ static void log_softmax(int n_vocab, const float * logits, const uint16_t * base
     kld.sum_nll_diff2 += nll*nll;
     max_logit += log_sum_exp;
     double sum = 0;
+    int imax_base = -1;
+    float p_log_base_max = 0;
     for (int i = 0; i < n_vocab; ++i) {
         const float p_log_base = scale*base_log_prob[i] + min_log_prob;
+        if (i == 0 || p_log_base > p_log_base_max) {
+            p_log_base_max = p_log_base;
+            imax_base = i;
+        }
         if (p_log_base > -16.f) {
             const float p_base = expf(p_log_base);
             sum += p_base * (p_log_base - logits[i] + max_logit);
@@ -257,14 +268,17 @@ static void log_softmax(int n_vocab, const float * logits, const uint16_t * base
     kld.sum_kld  += sum;
     kld.sum_kld2 += sum*sum;
     ++kld.count;
+    if (imax == imax_base) ++kld.n_same_top;
+    return sum;
 }
 
 static void process_logits(int n_vocab, const float * logits, const int * tokens, int n_token,
-        std::vector<std::thread> & workers, const std::vector<uint16_t> & base_log_probs, kl_divergence_result & kld) {
+        std::vector<std::thread> & workers, const std::vector<uint16_t> & base_log_probs, kl_divergence_result & kld,
+        float * kld_values) {
     std::mutex mutex;
     const int nv = 2*((n_vocab + 1)/2) + 4;
     int counter = 0;
-    auto compute = [&mutex, &counter, &base_log_probs, &kld, n_vocab, logits, tokens, n_token, nv] () {
+    auto compute = [&mutex, &counter, &base_log_probs, &kld, n_vocab, logits, tokens, n_token, nv, kld_values] () {
         kl_divergence_result local_kld;
         while (true) {
             std::unique_lock<std::mutex> lock(mutex);
@@ -276,11 +290,13 @@ static void process_logits(int n_vocab, const float * logits, const int * tokens
                 kld.sum_kld2 += local_kld.sum_kld2;
                 kld.sum_nll_diff  += local_kld.sum_nll_diff;
                 kld.sum_nll_diff2 += local_kld.sum_nll_diff2;
+                kld.n_same_top += local_kld.n_same_top;
                 kld.count += local_kld.count;
                 break;
             }
             lock.unlock();
-            log_softmax(n_vocab, logits + i*n_vocab, base_log_probs.data() + i*nv, tokens[i+1], local_kld);
+            double v = log_softmax(n_vocab, logits + i*n_vocab, base_log_probs.data() + i*nv, tokens[i+1], local_kld);
+            kld_values[i] = (float)v;
         }
     };
     for (auto & w : workers) {
@@ -1615,7 +1631,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
     in.read((char *)&n_vocab, sizeof(n_vocab));
     in.read((char *)&n_chunk, sizeof(n_chunk));
     if (in.fail()) {
-        fprintf(stderr, "%s: failed rwading n_vocab, n_chunk from %s\n", __func__, params.logits_file.c_str());
+        fprintf(stderr, "%s: failed reading n_vocab, n_chunk from %s\n", __func__, params.logits_file.c_str());
         return;
     }
     if (n_vocab != llama_n_vocab(llama_get_model(ctx))) {
@@ -1634,6 +1650,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
     const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
 
     std::vector<uint16_t> log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv);
+    std::vector<float> kld_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
     std::vector<float> logits;
     if (num_batches > 1) {
         logits.reserve(n_ctx * n_vocab);
@@ -1652,6 +1669,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
     };
 
     kl_divergence_result kld;
+    auto kld_ptr = kld_values.data();
 
     for (int i = 0; i < n_chunk; ++i) {
         const int start =     i * n_ctx;
@@ -1705,20 +1723,24 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
             }
             fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
 
-            printf("\nchunk        PPL          ln(PPL(Q)/PPL(base))          KL-Divergence\n");
+            printf("\nchunk        PPL          ln(PPL(Q)/PPL(base))          KL-Divergence           Same top\n");
         }
 
         const int first = n_ctx/2;
         const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
         process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
-                workers, log_probs_uint16, kld);
+                workers, log_probs_uint16, kld, kld_ptr);
+        kld_ptr += n_ctx - 1 - first;
 
         auto ppl           = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count);
         auto log_ppl_ratio = mean_and_uncertainty(kld.sum_nll_diff, kld.sum_nll_diff2, kld.count);
         auto kl_div        = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count);
+        auto p_top = 1.*kld.n_same_top/kld.count;
+        auto d_p_top = sqrt(p_top*(1 - p_top)/(kld.count - 1));
 
-        printf("%4d    %10.4lf    %10.5lf ± %10.5f    %10.5f ± %10.5lf\n", i+1, exp(ppl.first),
-                log_ppl_ratio.first, log_ppl_ratio.second, kl_div.first, kl_div.second);
+        printf("%4d    %10.4lf    %10.5lf ± %10.5f    %10.5f ± %10.5lf    %.5f ± %.5f\n", i+1, exp(ppl.first),
+                log_ppl_ratio.first, log_ppl_ratio.second, kl_div.first, kl_div.second,
+                p_top, d_p_top);
 
         fflush(stdout);
 
@@ -1726,6 +1748,35 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
     }
     printf("\n");
 
+    if (kld.count < 100) return; // we do not wish to do statistics on so few values
+
+    std::sort(kld_values.begin(), kld_values.end());
+
+    printf("===== KL-divergence statistics\n");
+    auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count);
+    printf("Average: %10.6f ±%10.6lf\n", kl_div.first, kl_div.second);
+    auto kld_median = kld_values.size()%2 == 0 ? 0.5f*(kld_values[kld_values.size()/2] + kld_values[kld_values.size()/2-1])
+                                               : kld_values[kld_values.size()/2];
+    printf("Median : %10.6f\n", kld_median);
+
+    auto percentile = [&kld_values] (float fraction) {
+        if (fraction <= 0) return kld_values.front();
+        if (fraction >= 1) return kld_values.back();
+        float p = fraction*(kld_values.size() - 1);
+        size_t ip = size_t(p); p -= ip;
+        return (1 - p)*kld_values[ip] + p*kld_values[std::min(ip+1, kld_values.size()-1)];
+    };
+
+    printf("Maximum: %10.6f\n", kld_values.back());
+    printf("KLD_99 : %10.6f\n", percentile(0.99f));
+    printf("KLD_95 : %10.6f\n", percentile(0.95f));
+    printf("KLD_90 : %10.6f\n", percentile(0.90f));
+
+    printf("Minimum: %10.6f\n", kld_values.front());
+    printf("KLD_01 : %10.6f\n", percentile(0.01f));
+    printf("KLD_05 : %10.6f\n", percentile(0.05f));
+    printf("KLD_10 : %10.6f\n", percentile(0.10f));
+
 }
 
 int main(int argc, char ** argv) {