]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
perplexity : support using multiple sequences to allow larger batch sizes (#5946)
authorslaren <redacted>
Sat, 9 Mar 2024 18:55:54 +0000 (19:55 +0100)
committerGitHub <redacted>
Sat, 9 Mar 2024 18:55:54 +0000 (19:55 +0100)
* perplexity : support using multiple sequences to allow larger batch sizes

ggml-ci

* set cparams.n_parallel to the number of sequences

* print tested n_ctx, add assert

examples/perplexity/perplexity.cpp
llama.cpp

index 52789ee631234d28c8b4b4a214258980abf996cc..293eb52c33653e3bfdc062509fcac5e2bdf0ad33 100644 (file)
@@ -442,7 +442,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
     return {tokens, std::exp(nll / count), logit_history, prob_history};
 }
 
-static results_perplexity perplexity(llama_context * ctx, const gpt_params & params) {
+static results_perplexity perplexity(llama_context * ctx, const gpt_params & params, const int32_t n_ctx) {
     if (params.ppl_stride > 0) {
         return perplexity_v2(ctx, params);
     }
@@ -453,7 +453,6 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     // BOS tokens will be added for each chunk before eval
 
     const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
-    const int n_ctx = llama_n_ctx(ctx);
 
     std::ofstream logits_stream;
     if (!params.logits_file.empty()) {
@@ -499,13 +498,19 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     double nll2 = 0.0;
 
     const int num_batches = (n_ctx + n_batch - 1) / n_batch;
+    const int n_seq = std::max(1, n_batch / n_ctx);
+
+    GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0);
+    GGML_ASSERT(params.n_ctx == n_seq * n_ctx);
+
+    llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1);
 
     std::vector<float> logits;
     if (num_batches > 1) {
         logits.reserve((size_t)n_ctx * n_vocab);
     }
 
-    fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
+    fprintf(stderr, "%s: calculating perplexity over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq);
 
     std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
 
@@ -518,10 +523,26 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         log_probs.resize(n_ctx * nv);
     }
 
-    for (int i = 0; i < n_chunk; ++i) {
+    // We get the logits for all the tokens in the context window (params.n_ctx)
+    // from llama_eval above.  Now, based on https://huggingface.co/docs/transformers/perplexity,
+    // calculate the perplexity over the last half of the window (so the model always has
+    // some context to predict the token).
+    //
+    // We rely on the fact that attention in the forward pass only looks at previous
+    // tokens here, so the logits returned for each token are an accurate representation
+    // of what the model would have predicted at that point.
+    //
+    // Example, we have a context window of 512, we will compute perplexity for each of the
+    // last 256 tokens.  Then, we split the input up into context window size chunks to
+    // process the entire prompt.
+    const int first = n_ctx/2;
+
+    for (int i = 0; i < n_chunk; i += n_seq) {
         const int start =     i * n_ctx;
         const int end   = start + n_ctx;
 
+        const int n_seq_batch = std::min(n_seq, n_chunk - i);
+
         const auto t_start = std::chrono::high_resolution_clock::now();
 
         // clear the KV cache
@@ -531,22 +552,37 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             const int batch_start = start + j * n_batch;
             const int batch_size  = std::min(end - batch_start, n_batch);
 
-            // save original token and restore it after eval
-            const auto token_org = tokens[batch_start];
+            batch.n_tokens = 0;
+            for (int seq = 0; seq < n_seq_batch; seq++) {
+                int seq_start = batch_start + seq*n_ctx;
 
-            // add BOS token for the first batch of each chunk
-            if (add_bos && j == 0) {
-                tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
+                // save original token and restore it after eval
+                const auto token_org = tokens[seq_start];
+
+                // add BOS token for the first batch of each chunk
+                if (add_bos && j == 0) {
+                    tokens[seq_start] = llama_token_bos(llama_get_model(ctx));
+                }
+
+                for (int k = 0; k < batch_size; ++k) {
+                    const int idx = seq*n_ctx + k;
+                    batch.token[idx] = tokens[seq_start + k];
+                    batch.pos[idx] = j*n_batch + k;
+                    batch.n_seq_id[idx] = 1;
+                    batch.seq_id[idx][0] = seq;
+                    batch.logits[idx] = batch.pos[idx] >= first ? 1 : 0;
+                }
+                batch.n_tokens += batch_size;
+
+                // restore the original token in case it was set to BOS
+                tokens[seq_start] = token_org;
             }
 
-            if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
+            if (llama_decode(ctx, batch)) {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
                 return {tokens, -1, logit_history, prob_history};
             }
 
-            // restore the original token in case it was set to BOS
-            tokens[batch_start] = token_org;
-
             if (num_batches > 1) {
                 const auto * batch_logits = llama_get_logits(ctx);
                 logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
@@ -558,7 +594,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         if (i == 0) {
             const float t_total = std::chrono::duration<float>(t_end - t_start).count();
             fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
-            int total_seconds = (int)(t_total * n_chunk);
+            int total_seconds = (int)(t_total*n_chunk/n_seq);
             if (total_seconds >= 60*60) {
                 fprintf(stderr, "%d hours ", total_seconds / (60*60));
                 total_seconds = total_seconds % (60*60);
@@ -566,37 +602,31 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
         }
 
-        // We get the logits for all the tokens in the context window (params.n_ctx)
-        // from llama_eval above.  Now, based on https://huggingface.co/docs/transformers/perplexity,
-        // calculate the perplexity over the last half of the window (so the model always has
-        // some context to predict the token).
-        //
-        // We rely on the fact that attention in the forward pass only looks at previous
-        // tokens here, so the logits returned for each token are an accurate representation
-        // of what the model would have predicted at that point.
-        //
-        // Example, we have a context window of 512, we will compute perplexity for each of the
-        // last 256 tokens.  Then, we split the input up into context window size chunks to
-        // process the entire prompt.
-        const int first = n_ctx/2;
-        const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
-        if (!params.logits_file.empty()) {
-            process_logits(logits_stream, n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
-                    workers, log_probs, nll, nll2);
-        } else {
-            process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
-                    workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
-        }
-        count += n_ctx - first - 1;
-
-        // perplexity is e^(average negative log-likelihood)
-        if (params.ppl_output_type == 0) {
-            printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
-        } else {
-            double av = nll/count;
-            double av2 = nll2/count - av*av;
-            if (av2 > 0) av2 = sqrt(av2/(count-1));
-            printf("%8d  %.4lf  %4lf  %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
+        for (int seq = 0; seq < n_seq_batch; seq++) {
+            const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx);
+            llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first;
+            if (!params.logits_file.empty()) {
+                process_logits(logits_stream, n_vocab, all_logits + first*n_vocab,
+                        tokens_data, n_ctx - 1 - first,
+                        workers, log_probs, nll, nll2);
+            } else {
+                process_logits(n_vocab, all_logits + first*n_vocab,
+                        tokens_data, n_ctx - 1 - first,
+                        workers, nll, nll2,
+                        logit_history.data() + start + seq*n_ctx + first,
+                        prob_history.data()  + start + seq*n_ctx + first);
+            }
+            count += n_ctx - first - 1;
+
+            // perplexity is e^(average negative log-likelihood)
+            if (params.ppl_output_type == 0) {
+                printf("[%d]%.4lf,", i + seq + 1, std::exp(nll / count));
+            } else {
+                double av = nll/count;
+                double av2 = nll2/count - av*av;
+                if (av2 > 0) av2 = sqrt(av2/(count-1));
+                printf("%8d  %.4lf  %4lf  %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
+            }
         }
         fflush(stdout);
 
@@ -615,6 +645,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         printf("Unexpected negative standard deviation of log(prob)\n");
     }
 
+    llama_batch_free(batch);
+
     return {tokens, ppl, logit_history, prob_history};
 }
 
@@ -1782,13 +1814,24 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
 int main(int argc, char ** argv) {
     gpt_params params;
 
-    params.n_batch = 512;
     if (!gpt_params_parse(argc, argv, params)) {
         return 1;
     }
 
     params.logits_all = true;
-    params.n_batch = std::min(params.n_batch, params.n_ctx);
+
+    const int32_t n_ctx = params.n_ctx;
+
+    const bool ppl = !params.hellaswag && !params.winogrande && !params.multiple_choice && !params.kl_divergence;
+    if (ppl) {
+        int n_seq = std::max(1, params.n_batch / n_ctx);
+        int32_t n_kv = n_seq * n_ctx;
+        params.n_parallel = n_seq;
+        params.n_ctx = n_kv;
+        params.n_batch = std::min(params.n_batch, n_kv);
+    } else {
+        params.n_batch = std::min(params.n_batch, params.n_ctx);
+    }
 
     if (params.ppl_stride > 0) {
         fprintf(stderr, "Will perform strided perplexity calculation -> adjusting context size from %d to %d\n",
@@ -1847,7 +1890,7 @@ int main(int argc, char ** argv) {
     } else if (params.kl_divergence) {
         kl_divergence(ctx, params);
     } else {
-        results = perplexity(ctx, params);
+        results = perplexity(ctx, params, n_ctx);
     }
 
     llama_print_timings(ctx);
index c58a029f74faff13badd37d1495d8e7c6984e9df..b19616e8f9a5fab04de37fbe3ef22560dd03c982 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -8925,17 +8925,29 @@ static int llama_decode_internal(
 
         if (batch.logits) {
             logits_out.resize(n_vocab * n_tokens);
+            int32_t i_first = -1;
             for (uint32_t i = 0; i < n_tokens; i++) {
-                if (batch.logits[i] == 0) {
-                    continue;
+                if (batch.logits[i] && i_first == -1) {
+                    i_first = (int32_t) i;
+                }
+                if (batch.logits[i] == 0 || i == n_tokens - 1) {
+                    if (i_first != -1) {
+                        int i_last = batch.logits[i] == 0 ? i : i + 1;
+                        // extract logits for the range [i_first, i_last)
+                        // group the requests to minimize the number of calls to the backend
+                        ggml_backend_tensor_get_async(backend_res, res,
+                            logits_out.data() + (n_vocab*i_first),
+                            (n_vocab*i_first)*sizeof(float),
+                            (i_last - i_first)*n_vocab*sizeof(float));
+                        i_first = -1;
+                    }
                 }
-                ggml_backend_tensor_get_async(backend_res, res, logits_out.data() + (n_vocab*i), (n_vocab*i)*sizeof(float), n_vocab*sizeof(float));
 #ifndef NDEBUG
-                logits_valid[i] = true;
+                logits_valid[i] = batch.logits[i] != 0;
 #endif
             }
         } else if (lctx.logits_all) {
-            logits_out.resize(n_vocab * n_tokens);
+            logits_out.resize(n_vocab*n_tokens);
             ggml_backend_tensor_get_async(backend_res, res, logits_out.data(), 0, n_vocab*n_tokens*sizeof(float));
 #ifndef NDEBUG
             std::fill(logits_valid.begin(), logits_valid.end(), true);