]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Faster perplexity computation (#2786)
authorKawrakow <redacted>
Fri, 25 Aug 2023 16:05:02 +0000 (19:05 +0300)
committerGitHub <redacted>
Fri, 25 Aug 2023 16:05:02 +0000 (19:05 +0300)
Co-authored-by: Iwan Kawrakow <redacted>
examples/perplexity/perplexity.cpp

index a7bd9db2a3fd323286cad1fc667c5d3367aad22f..18635932b3a39ddf66b04f43f1a4ee4269fb0630 100644 (file)
@@ -6,6 +6,8 @@
 #include <ctime>
 #include <sstream>
 #include <cstring>
+#include <thread>
+#include <mutex>
 
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
@@ -27,6 +29,40 @@ std::vector<float> softmax(const std::vector<float>& logits) {
     return probs;
 }
 
+float log_softmax(int n_vocab, const float * logits, int tok) {
+    float max_logit = logits[0];
+    for (int i = 1; i < n_vocab; ++i) max_logit = std::max(max_logit, logits[i]);
+    double sum_exp = 0.0;
+    for (int i = 0; i < n_vocab; ++i) sum_exp += expf(logits[i] - max_logit);
+    return logits[tok] - max_logit - log(sum_exp);
+}
+
+void process_logits(int n_vocab, const float * logits, const int * tokens, int n_token, std::vector<std::thread>& workers,
+        double& nll, double& nll2) {
+
+    std::mutex mutex;
+    int counter = 0;
+    auto compute = [&mutex, &counter, &nll, &nll2, n_vocab, logits, tokens, n_token] () {
+        double local_nll = 0, local_nll2 = 0;
+        while (true) {
+            std::unique_lock<std::mutex> lock(mutex);
+            int i = counter++;
+            if (i >= n_token) {
+                nll += local_nll; nll2 += local_nll2;
+                break;
+            }
+            lock.unlock();
+            double v = -log_softmax(n_vocab, logits + i*n_vocab, tokens[i+1]);
+            local_nll += v;
+            local_nll2 += v*v;
+        }
+    };
+    for (auto& w : workers) w = std::thread(compute);
+    compute();
+    for (auto& w : workers) w.join();
+
+}
+
 void perplexity_v2(llama_context * ctx, const gpt_params & params) {
     // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
     // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
@@ -166,9 +202,12 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
 
     int count = 0;
     double nll = 0.0;
+    double nll2 = 0.0;
 
     fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
 
+    std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
+
     for (int i = 0; i < n_chunk; ++i) {
         const int start =     i * params.n_ctx;
         const int end   = start + params.n_ctx;
@@ -228,26 +267,32 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
         // Example, we have a context window of 512, we will compute perplexity for each of the
         // last 256 tokens.  Then, we split the input up into context window size chunks to
         // process the entire prompt.
-        for (int j = std::min(512, params.n_ctx / 2); j < params.n_ctx - 1; ++j) {
-            // Calculate probability of next token, given the previous ones.
-            const std::vector<float> tok_logits(
-                logits.begin() + (j + 0) * n_vocab,
-                logits.begin() + (j + 1) * n_vocab);
-
-            const float prob = softmax(tok_logits)[tokens[start + j + 1]];
+        const int first = std::min(512, params.n_ctx/2);
+        process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, params.n_ctx - 1 - first, workers, nll, nll2);
+        count += params.n_ctx - first - 1;
 
-            nll += -std::log(prob);
-            ++count;
-        }
         // perplexity is e^(average negative log-likelihood)
         if (params.ppl_output_type == 0) {
             printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
         } else {
-            printf("%8d  %.4lf\n", i*params.n_ctx, std::exp(nll / count));
+            double av = nll/count;
+            double av2 = nll2/count - av*av;
+            if (av2 > 0) av2 = sqrt(av2/(count-1));
+            printf("%8d  %.4lf  %4lf  %4lf\n", i*params.n_ctx, std::exp(nll / count), av, av2);
         }
         fflush(stdout);
     }
     printf("\n");
+    nll2 /= count;
+    nll /= count;
+    nll2 -= nll * nll;
+    if (nll2 > 0) {
+        nll2 = sqrt(nll2/(count-1));
+        double ppl = exp(nll);
+        printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl);
+    } else {
+        printf("Unexpected negative standard deviation of log(prob)\n");
+    }
 }
 
 std::vector<float> hellaswag_evaluate_tokens(llama_context * ctx, const std::vector<int>& tokens, int n_past, int n_batch,