]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Perplexity: Compute scores correlated to HellaSwag (#2312)
authorklosax <redacted>
Sat, 22 Jul 2023 12:21:24 +0000 (14:21 +0200)
committerGitHub <redacted>
Sat, 22 Jul 2023 12:21:24 +0000 (14:21 +0200)
* Add parameter --perplexity-lines to perplexity.cpp

examples/common.cpp
examples/common.h
examples/perplexity/perplexity.cpp

index 09901959956f9d92686719722afe6f4e2a245b3e..730b28bde957b960794711c7e57ab7ee400f19bd 100644 (file)
@@ -387,6 +387,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.antiprompt.push_back(argv[i]);
         } else if (arg == "--perplexity") {
             params.perplexity = true;
+        } else if (arg == "--perplexity-lines") {
+            params.perplexity_lines = true;
         } else if (arg == "--ignore-eos") {
             params.logit_bias[llama_token_eos()] = -INFINITY;
         } else if (arg == "--no-penalize-nl") {
@@ -512,7 +514,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "                        not recommended: doubles context memory required and no measurable increase in quality\n");
     fprintf(stderr, "  --temp N              temperature (default: %.1f)\n", (double)params.temp);
     fprintf(stderr, "  -b N, --batch-size N  batch size for prompt processing (default: %d)\n", params.n_batch);
-    fprintf(stderr, "  --perplexity          compute perplexity over the prompt\n");
+    fprintf(stderr, "  --perplexity          compute perplexity over each ctx window of the prompt\n");
+    fprintf(stderr, "  --perplexity-lines    compute perplexity over each line of the prompt\n");
     fprintf(stderr, "  --keep                number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
     fprintf(stderr, "  --chunks N            max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
     if (llama_mlock_supported()) {
index 69170dfc098502ca20b9f67cacaeb782910297d6..c936de6fad659f87094b81bcdb471f73de6798f0 100644 (file)
@@ -82,6 +82,7 @@ struct gpt_params {
     bool instruct          = false; // instruction mode (used for Alpaca models)
     bool penalize_nl       = true;  // consider newlines as a repeatable token
     bool perplexity        = false; // compute perplexity over the prompt
+    bool perplexity_lines  = false; // compute perplexity over each line of the prompt
     bool use_mmap          = true;  // use mmap for faster loads
     bool use_mlock         = false; // use mlock to keep model in memory
     bool mem_test          = false; // compute maximum memory usage
index bfad999394fedc2401ec8db4e27a28593afca790..d23b7e7f0c1b8f2ea693178c5a16466d5df081f7 100644 (file)
@@ -4,6 +4,7 @@
 
 #include <cmath>
 #include <ctime>
+#include <sstream>
 
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
@@ -120,6 +121,77 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
     printf("\n");
 }
 
+void perplexity_lines(llama_context * ctx, const gpt_params & params) {
+    // Calculates perplexity over each line of the prompt
+
+    std::vector<std::string> prompt_lines;
+    std::istringstream strstream(params.prompt);
+    std::string line;
+
+    while (std::getline(strstream,line,'\n')) {
+        prompt_lines.push_back(line);
+    }
+
+    const int n_vocab = llama_n_vocab(ctx);
+
+    int counttotal   = 0;
+    size_t n_lines = prompt_lines.size();
+
+    double nll = 0.0;
+
+    fprintf(stderr, "%s: calculating perplexity over %lu lines\n", __func__, n_lines);
+
+    printf("\nLine\tPPL line\tPPL cumulative\n");
+
+    for (size_t i = 0; i < n_lines; ++i) {
+
+        // Tokenize and insert BOS at start
+        std::vector<int> batch_embd = ::llama_tokenize(ctx, prompt_lines[i], true);
+
+        size_t batch_size  = batch_embd.size();
+
+        // Stop if line is too long
+        if( batch_size > (size_t)params.n_ctx ) {
+            fprintf(stderr, "%s : tokens in line %lu > n_ctxl\n", __func__, i);
+            return;
+        }
+
+        if (llama_eval(ctx, batch_embd.data(), batch_size, 0, params.n_threads)) {
+            fprintf(stderr, "%s : failed to eval\n", __func__);
+            return;
+        }
+
+        const auto batch_logits = llama_get_logits(ctx);
+        std::vector<float> logits;
+        logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+
+        double nllline = 0.0;
+        int countline = 0;
+
+        // Perplexity over second half of the line
+        for (size_t j = batch_size/2; j < batch_size - 1; ++j) {
+            // Calculate probability of next token, given the previous ones.
+            const std::vector<float> tok_logits(
+                logits.begin() + (j + 0) * n_vocab,
+                logits.begin() + (j + 1) * n_vocab);
+
+            const float prob = softmax(tok_logits)[batch_embd[ j + 1]];
+
+            nllline += -std::log(prob);
+            ++countline;
+        }
+
+        nll += nllline;
+        counttotal += countline;
+
+        // perplexity is e^(average negative log-likelihood)
+        printf("%lu\t%.8lf\t%.8lf\n", i + 1, std::exp(nllline/countline), std::exp(nll / counttotal) );
+        fflush(stdout);
+    }
+
+    printf("\n");
+}
+
 int main(int argc, char ** argv) {
     gpt_params params;
 
@@ -168,7 +240,11 @@ int main(int argc, char ** argv) {
                 params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
     }
 
-    perplexity(ctx, params);
+    if (params.perplexity_lines) {
+        perplexity_lines(ctx, params);
+    } else {
+        perplexity(ctx, params);
+    }
 
     llama_print_timings(ctx);
     llama_free(ctx);