]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
perplexity : add support for batch size to `--perplexity` (#407)
authorGary Linscott <redacted>
Thu, 13 Apr 2023 21:50:42 +0000 (14:50 -0700)
committerGitHub <redacted>
Thu, 13 Apr 2023 21:50:42 +0000 (00:50 +0300)
* Add support to batch size for perplexity

* Revert "Fix memory allocation issues and seg faults"

This reverts commit 4870e455b3653f7d7769fa5772b2c90ffad088df.

* update from merge

* Remove perplexity from main

* updates

* Update batch size for efficiency

examples/perplexity/perplexity.cpp

index b62f00d0cf110f34095b158d15f47979550ca313..38e3643b1ca5eb5f96a59f58d4a7720566eb988c 100644 (file)
@@ -27,20 +27,27 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
 
     int count = 0;
     int seq_count = tokens.size() / params.n_ctx;
+    int n_vocab = llama_n_vocab(ctx);
 
     double nll = 0.0;
-
-    fprintf(stderr, "%s : calculating perplexity over %d chunks\n", __func__, seq_count);
+    fprintf(stderr, "%s : calculating perplexity over %d chunks, batch_size=%d\n", __func__, seq_count, params.n_batch);
 
     for (int i = 0; i < seq_count; ++i) {
         int start = i * params.n_ctx;
-        int end = start + params.n_ctx - 1; // TODO: this is not optimal, e.g. it makes the batch 511 instead of 512
-                                            //       it is better to always be power of 2 for better performance
-        std::vector<llama_token> embd(tokens.begin() + start, tokens.begin() + end);
+        int end = start + params.n_ctx;
+
+        std::vector<float> logits;
+        int num_batches = (params.n_ctx + params.n_batch - 1) / params.n_batch;
         auto start_t = std::chrono::high_resolution_clock::now();
-        if (llama_eval(ctx, embd.data(), embd.size(), 0, params.n_threads)) {
-            fprintf(stderr, "%s : failed to eval\n", __func__);
-            return;
+        for (int j = 0; j < num_batches; ++j) {
+            int batch_start = start + j * params.n_batch;
+            int batch_size = std::min(end - batch_start, params.n_batch);
+            if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * params.n_batch, params.n_threads)) {
+                fprintf(stderr, "%s : failed to eval\n", __func__);
+                return;
+            }
+            auto batch_logits = llama_get_logits(ctx);
+            logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
         }
         auto end_t = std::chrono::high_resolution_clock::now();
         if (i == 0) {
@@ -59,15 +66,12 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
         // Example, we have a context window of 512, we will compute perplexity for each of the
         // last 256 tokens.  Then, we split the input up into context window size chunks to
         // process the entire prompt.
-
-        auto logits = llama_get_logits(ctx);
-        for (int j = params.n_ctx / 2; j < params.n_ctx - 1; ++j) {
+        for (int j = std::min(512, params.n_ctx / 2); j < params.n_ctx - 1; ++j) {
             // Calculate probability of next token, given the previous ones.
-            int n_vocab = llama_n_vocab(ctx);
             std::vector<float> tok_logits(
-                logits + j * n_vocab,
-                logits + (j + 1) * n_vocab);
-            const float prob = softmax(tok_logits)[tokens[start + j + 1]];
+                logits.begin() + j * n_vocab,
+                logits.begin() + (j + 1) * n_vocab);
+            float prob = softmax(tok_logits)[tokens[start + j + 1]];
             nll += -std::log(prob);
             ++count;
         }
@@ -82,11 +86,13 @@ int main(int argc, char ** argv) {
     gpt_params params;
     params.model = "models/llama-7B/ggml-model.bin";
 
+    params.n_batch = 512;
     if (gpt_params_parse(argc, argv, params) == false) {
         return 1;
     }
 
     params.perplexity = true;
+    params.n_batch = std::min(params.n_batch, params.n_ctx);
 
     if (params.n_ctx > 2048) {
         fprintf(stderr, "%s: warning: model does not support context sizes greater than 2048 tokens (%d specified);"