]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
More efficient Hellaswag implementation (#2677)
authorKawrakow <redacted>
Sun, 20 Aug 2023 13:44:46 +0000 (16:44 +0300)
committerGitHub <redacted>
Sun, 20 Aug 2023 13:44:46 +0000 (16:44 +0300)
Co-authored-by: Iwan Kawrakow <redacted>
examples/perplexity/perplexity.cpp

index b9b28a20b58aee2d904f8e0c9397e5e39207de04..682c39b16894e93db5eaacd17ac363350bf860e9 100644 (file)
@@ -5,6 +5,7 @@
 #include <cmath>
 #include <ctime>
 #include <sstream>
+#include <cstring>
 
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
@@ -209,50 +210,97 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     double acc = 0.0f;
     const int n_vocab = llama_n_vocab(ctx);
 
+    std::vector<float> tok_logits(n_vocab);
+
     for (size_t task_idx = 0; task_idx < hs_task_count; task_idx++) {
 
         // Tokenize the context to count tokens
         std::vector<int> context_embd = ::llama_tokenize(ctx, hs_data[task_idx].context, prepend_bos);
         size_t context_size = context_embd.size();
 
-        for (size_t ending_idx=0;ending_idx<4;ending_idx++) {
+        // Do the 1st ending
+        // In this case we include the context when evaluating
+        auto query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[0], prepend_bos);
+        auto query_size = query_embd.size();
+        //printf("First query: %d\n",(int)query_size);
+
+        // Stop if query wont fit the ctx window
+        if (query_size > (size_t)params.n_ctx) {
+            fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size);
+            return;
+        }
+
+        // Speedup small evaluations by evaluating atleast 32 tokens
+        if (query_size < 32) {
+            query_embd.resize(32);
+        }
+
+        // Evaluate the query
+        if (llama_eval(ctx, query_embd.data(), query_embd.size(), 0, params.n_threads)) {
+            fprintf(stderr, "%s : failed to eval\n", __func__);
+            return;
+        }
+
+        auto query_logits = llama_get_logits(ctx);
+
+        std::memcpy(tok_logits.data(), query_logits + (context_size-1)*n_vocab, n_vocab*sizeof(float));
+        const auto first_probs = softmax(tok_logits);
+
+        hs_data[task_idx].ending_logprob_count[0] = 1;
+        hs_data[task_idx].ending_logprob[0] = std::log(first_probs[query_embd[context_size]]);
+
+        // Calculate the logprobs over the ending
+        for (size_t j = context_size; j < query_size - 1; j++) {
+
+            std::memcpy(tok_logits.data(), query_logits + j*n_vocab, n_vocab*sizeof(float));
+
+            const float prob = softmax(tok_logits)[query_embd[j + 1]];
+
+            hs_data[task_idx].ending_logprob[0] += std::log(prob);
+            hs_data[task_idx].ending_logprob_count[0]++;
+        }
+
+        // Calculate the mean token logprob for acc_norm
+        hs_data[task_idx].ending_logprob[0] /= hs_data[task_idx].ending_logprob_count[0];
+
+        // Do the remaining endings
+        // For these, we use the bare ending with n_past = context_size
+        //
+        for (size_t ending_idx = 1; ending_idx < 4; ending_idx++) {
 
             // Tokenize the query
-            std::vector<int> query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[ending_idx], prepend_bos);
-            size_t query_size = query_embd.size();
+            query_embd = ::llama_tokenize(ctx, hs_data[task_idx].ending[ending_idx], false);
+            query_size = query_embd.size();
+            //printf("Second query: %d\n",(int)query_size);
 
             // Stop if query wont fit the ctx window
-            if (query_size > (size_t)params.n_ctx) {
+            if (context_size + query_size > (size_t)params.n_ctx) {
                 fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size);
                 return;
             }
 
             // Speedup small evaluations by evaluating atleast 32 tokens
-            if (query_size < 32) {
-                query_embd.resize(32);
-            }
+            // No, resizing to 32 is actually slightly slower (at least on CUDA)
+            //if (query_size < 32) {
+            //    query_embd.resize(32);
+            //}
 
             // Evaluate the query
-            if (llama_eval(ctx, query_embd.data(), query_embd.size(), 0, params.n_threads)) {
+            if (llama_eval(ctx, query_embd.data(), query_embd.size(), context_size, params.n_threads)) {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
                 return;
             }
 
-            const auto query_logits = llama_get_logits(ctx);
-            std::vector<float> logits;
-            logits.insert(logits.end(), query_logits, query_logits + query_size * n_vocab);
+            query_logits = llama_get_logits(ctx);
 
-            hs_data[task_idx].ending_logprob_count[ending_idx] = 0;
-            hs_data[task_idx].ending_logprob[ending_idx] = 0.0f;
+            hs_data[task_idx].ending_logprob_count[ending_idx] = 1;
+            hs_data[task_idx].ending_logprob[ending_idx] = std::log(first_probs[query_embd[0]]);
 
             // Calculate the logprobs over the ending
-            for (size_t j = context_size-1; j < query_size - 1; j++) {
-                // Calculate probability of next token, given the previous ones.
-                const std::vector<float> tok_logits(
-                    logits.begin() + (j + 0) * n_vocab,
-                    logits.begin() + (j + 1) * n_vocab);
+            for (size_t j = 0; j < query_size - 1; j++) {
+                std::memcpy(tok_logits.data(), query_logits + j*n_vocab, n_vocab*sizeof(float));
 
-                const float prob = softmax(tok_logits)[query_embd[ j + 1]];
+                const float prob = softmax(tok_logits)[query_embd[j + 1]];
 
                 hs_data[task_idx].ending_logprob[ending_idx] += std::log(prob);
                 hs_data[task_idx].ending_logprob_count[ending_idx]++;
@@ -267,9 +315,9 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         }
 
         // Find the ending with maximum logprob
-        size_t ending_logprob_max_idx = -1;
-        double ending_logprob_max_val = -INFINITY;
-        for (size_t j=0; j < 4; j++) {
+        size_t ending_logprob_max_idx = 0;
+        double ending_logprob_max_val = hs_data[task_idx].ending_logprob[0];
+        for (size_t j = 1; j < 4; j++) {
             if (hs_data[task_idx].ending_logprob[j] > ending_logprob_max_val) {
                 ending_logprob_max_idx = j;
                 ending_logprob_max_val =  hs_data[task_idx].ending_logprob[j];