]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Strided perplexity (#2714)
authorKawrakow <redacted>
Wed, 23 Aug 2023 09:56:42 +0000 (12:56 +0300)
committerGitHub <redacted>
Wed, 23 Aug 2023 09:56:42 +0000 (12:56 +0300)
* Implementing strided computation of perplexity

* Alternative way to output PPL results

---------

Co-authored-by: Iwan Kawrakow <redacted>
common/common.cpp
common/common.h
examples/perplexity/perplexity.cpp

index 2a83b379ec4f56a117a950a34518f8a277f63ccb..88a962ae385de5f5d918dd81412465edc6554b55 100644 (file)
@@ -417,6 +417,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.antiprompt.push_back(argv[i]);
         } else if (arg == "--perplexity") {
             params.perplexity = true;
+        } else if (arg == "--ppl-stride") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.ppl_stride = std::stoi(argv[i]);
+        } else if (arg == "--ppl-output-type") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.ppl_output_type = std::stoi(argv[i]);
         } else if (arg == "--hellaswag") {
             params.hellaswag = true;
         } else if (arg == "--hellaswag-tasks") {
index 18fd951ead9df6581ce4a3fd993b4ed52eff9cf1..d68a8ef88c97cc1ad924189abf3e7ee88423c45c 100644 (file)
@@ -64,6 +64,10 @@ struct gpt_params {
     std::string lora_adapter = "";  // lora adapter path
     std::string lora_base    = "";  // base model path for the lora adapter
 
+    int  ppl_stride        = 0;     // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
+    int  ppl_output_type   = 0;     // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
+                                    //                                       (which is more convenient to use for plotting)
+                                    //
     bool hellaswag         = false; // compute HellaSwag score over random tasks from datafile supplied in prompt
     size_t hellaswag_tasks = 400;   // number of tasks to use when computing the HellaSwag score
 
index f3c045aeca2b26409062dbcf6b642f70096e5885..e89725efc3db6bedbec50c6984553b4c8bfb9904 100644 (file)
@@ -27,7 +27,121 @@ std::vector<float> softmax(const std::vector<float>& logits) {
     return probs;
 }
 
+void perplexity_v2(llama_context * ctx, const gpt_params & params) {
+
+    // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
+    // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
+    // Output: `perplexity: 13.5106 [114/114]`
+    // BOS tokens will be added for each chunk before eval
+
+    if (params.ppl_stride <= 0) {
+        fprintf(stderr, "%s: stride is %d but must be greater than zero!\n",__func__,params.ppl_stride);
+        return;
+    }
+    auto tokens = ::llama_tokenize(ctx, params.prompt, true);
+
+    const int calc_chunk = params.n_ctx;
+
+    fprintf(stderr, "%s: have %zu tokens. Calculation chunk = %d\n", __func__, tokens.size(), calc_chunk);
+
+    if (int(tokens.size()) <= calc_chunk) {
+        fprintf(stderr, "%s: there are only %zu tokens, this is not enough for a context size of %d and stride %d\n",__func__,
+                tokens.size(), params.n_ctx, params.ppl_stride);
+        return;
+    }
+
+    const int n_chunk_max = (tokens.size() - calc_chunk + params.ppl_stride - 1)  / params.ppl_stride;
+
+    const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
+    const int n_vocab = llama_n_vocab(ctx);
+    const int n_batch = params.n_batch;
+
+    int count = 0;
+    double nll = 0.0;
+
+    fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
+
+    for (int i = 0; i < n_chunk; ++i) {
+        const int start =     i * params.ppl_stride;
+        const int end   = start + calc_chunk;
+
+        const int num_batches = (calc_chunk + n_batch - 1) / n_batch;
+        //fprintf(stderr, "%s: evaluating %d...%d using %d batches\n", __func__, start, end, num_batches);
+
+        std::vector<float> logits;
+
+        const auto t_start = std::chrono::high_resolution_clock::now();
+
+        for (int j = 0; j < num_batches; ++j) {
+            const int batch_start = start + j * n_batch;
+            const int batch_size  = std::min(end - batch_start, n_batch);
+
+            //fprintf(stderr, "    Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
+            if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {
+                //fprintf(stderr, "%s : failed to eval\n", __func__);
+                return;
+            }
+
+            // save original token and restore it after eval
+            const auto token_org = tokens[batch_start];
+
+            // add BOS token for the first batch of each chunk
+            if (j == 0) {
+                tokens[batch_start] = llama_token_bos(ctx);
+            }
+
+            const auto batch_logits = llama_get_logits(ctx);
+            logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+
+            if (j == 0) {
+                tokens[batch_start] = token_org;
+            }
+        }
+
+        const auto t_end = std::chrono::high_resolution_clock::now();
+
+        if (i == 0) {
+            const float t_total = std::chrono::duration<float>(t_end - t_start).count();
+            fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
+            int total_seconds = (int)(t_total * n_chunk);
+            if (total_seconds >= 60*60) {
+                fprintf(stderr, "%d hours ", total_seconds / (60*60));
+                total_seconds = total_seconds % (60*60);
+            }
+            fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
+        }
+
+        //fprintf(stderr, "%s: using tokens %d...%d\n",__func__,params.n_ctx - params.ppl_stride + start, params.n_ctx + start);
+        for (int j = params.n_ctx - params.ppl_stride - 1; j < params.n_ctx - 1; ++j) {
+
+            // Calculate probability of next token, given the previous ones.
+            const std::vector<float> tok_logits(
+                logits.begin() + (j + 0) * n_vocab,
+                logits.begin() + (j + 1) * n_vocab);
+
+            const float prob = softmax(tok_logits)[tokens[start + j + 1]];
+
+            nll += -std::log(prob);
+            ++count;
+        }
+        // perplexity is e^(average negative log-likelihood)
+        if (params.ppl_output_type == 0) {
+            printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
+        } else {
+            printf("%8d  %.4lf\n", i*params.ppl_stride, std::exp(nll / count));
+        }
+        fflush(stdout);
+    }
+    printf("\n");
+}
+
 void perplexity(llama_context * ctx, const gpt_params & params) {
+
+    if (params.ppl_stride > 0) {
+        perplexity_v2(ctx, params);
+        return;
+    }
+
     // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
     // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
     // Output: `perplexity: 13.5106 [114/114]`
@@ -116,7 +230,11 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
             ++count;
         }
         // perplexity is e^(average negative log-likelihood)
-        printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
+        if (params.ppl_output_type == 0) {
+            printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
+        } else {
+            printf("%8d  %.4lf\n", i*params.n_ctx, std::exp(nll / count));
+        }
         fflush(stdout);
     }
     printf("\n");
@@ -369,6 +487,12 @@ int main(int argc, char ** argv) {
     params.perplexity = true;
     params.n_batch = std::min(params.n_batch, params.n_ctx);
 
+    if (params.ppl_stride > 0) {
+        fprintf(stderr, "Will perform strided perplexity calculation -> adjusting context size from %d to %d\n",
+                params.n_ctx, params.n_ctx + params.ppl_stride/2);
+        params.n_ctx += params.ppl_stride/2;
+    }
+
     if (params.n_ctx > 2048) {
         fprintf(stderr, "%s: warning: model might not support context sizes greater than 2048 tokens (%d specified);"
                 "expect poor results\n", __func__, params.n_ctx);