]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
HellaSwag: speed up by parallelizing log-prob evaluation (#5020)
authorKawrakow <redacted>
Thu, 18 Jan 2024 17:18:21 +0000 (19:18 +0200)
committerGitHub <redacted>
Thu, 18 Jan 2024 17:18:21 +0000 (19:18 +0200)
For Mistral-7B and fp16, time on my system goes down from 536 seconds
to 423 seconds for the full evaluation dataset (10042 tasks).

Co-authored-by: Iwan Kawrakow <redacted>
examples/perplexity/perplexity.cpp

index ea2c8026cfcecfdbe5647a172e12d192dce95c09..9498dd535ee132f7a04edd41d9fe14a4c5a6bb83 100644 (file)
@@ -8,6 +8,7 @@
 #include <sstream>
 #include <thread>
 #include <mutex>
+#include <atomic>
 #include <vector>
 #include <array>
 #include <fstream>
@@ -444,6 +445,48 @@ static std::vector<float> evaluate_tokens(llama_context * ctx, std::vector<int>
     return result;
 }
 
+static void hellaswag_compute_logprobs(const float * batch_logits, int n_vocab, std::vector<std::thread>& workers,
+        const std::vector<std::pair<size_t, llama_token>>& eval_pairs, std::vector<float>& eval_results) {
+    constexpr int k_token_chunk = 4;
+    if (eval_results.size() != eval_pairs.size()) {
+        eval_results.resize(eval_pairs.size());
+    }
+    if (eval_pairs.empty()) return;
+
+    size_t max_threads = std::min((eval_pairs.size() + k_token_chunk - 1)/k_token_chunk, workers.size());
+
+    std::atomic<int> counter(0);
+    auto compute = [&counter, &eval_pairs, &eval_results, batch_logits, n_vocab] () {
+        float local_logprobs[k_token_chunk];
+        while (true) {
+            size_t first = counter.fetch_add(k_token_chunk, std::memory_order_relaxed);
+            if (first >= eval_results.size()) break;
+            size_t last = std::min(first + k_token_chunk, eval_results.size());
+            for (size_t i = first; i < last; ++i) {
+                auto logits = batch_logits + eval_pairs[i].first * n_vocab;
+                float max_logit = logits[0];
+                for (int j = 1; j < n_vocab; ++j) {
+                    max_logit = std::max(max_logit, logits[j]);
+                }
+                float sum_p = 0.f;
+                for (int j = 0; j < n_vocab; ++j) {
+                    sum_p += expf(logits[j] - max_logit);
+                }
+                local_logprobs[i - first] = logits[eval_pairs[i].second] - max_logit - std::log(sum_p);
+            }
+            std::memcpy(eval_results.data() + first, local_logprobs, (last - first)*sizeof(float));
+        }
+    };
+
+    for (size_t it = 0; it < max_threads; ++it) {
+        workers[it] = std::thread(compute);
+    }
+    for (size_t it = 0; it < max_threads; ++it) {
+        workers[it].join();
+    }
+
+}
+
 static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     // Calculates hellaswag score (acc_norm) from prompt
     //
@@ -574,6 +617,10 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     std::vector<float> tok_logits(n_vocab);
     std::vector<float> batch_logits(n_ctx*n_vocab);
 
+    std::vector<std::pair<size_t, llama_token>> eval_pairs;
+    std::vector<float> eval_results;
+    std::vector<std::thread> workers(std::thread::hardware_concurrency());
+
     auto decode_helper = [&](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
         for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
             const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
@@ -654,6 +701,24 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
             return;
         }
 
+        // Compute log-probs in parallel
+        // First we collect all tasks
+        eval_pairs.clear();
+        for (size_t i = i0; i < i1; ++i) {
+            auto & hs_cur = hs_data[i];
+            size_t li = hs_cur.common_prefix;
+            for (int s = 0; s < 4; ++s) {
+                for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) {
+                    eval_pairs.push_back(std::make_pair(hs_cur.i_batch + li++, hs_cur.seq_tokens[s][j + 1]));
+                }
+                ++li;
+            }
+        }
+        // Then we do the actual calculation
+        hellaswag_compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);
+
+        size_t ir = 0;
+
         // compute the logprobs for each ending of the decoded tasks
         for (size_t i = i0; i < i1; ++i) {
             auto & hs_cur = hs_data[i];
@@ -662,26 +727,13 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
 
             const auto first_probs = softmax(tok_logits);
 
-            size_t li = hs_cur.common_prefix; // logits index in the batch
-
             for (int s = 0; s < 4; ++s) {
                 hs_cur.ending_logprob_count[s] = 1;
                 hs_cur.ending_logprob[s] = std::log(first_probs[hs_cur.seq_tokens[s][hs_cur.common_prefix]]);
-
-                // Calculate the logprobs over the ending
                 for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) {
-                    std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(hs_cur.i_batch + li++), n_vocab*sizeof(float));
-
-                    const float prob = softmax(tok_logits)[hs_cur.seq_tokens[s][j + 1]];
-
-                    hs_cur.ending_logprob[s] += std::log(prob);
+                    hs_cur.ending_logprob[s] += eval_results[ir++];
                     hs_cur.ending_logprob_count[s]++;
                 }
-
-                // account that we skip the last token in the ending
-                ++li;
-
-                // Calculate the mean token logprob for acc_norm
                 hs_cur.ending_logprob[s] /= hs_cur.ending_logprob_count[s];
             }