]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
HellaSwag: split token evaluation into batches if needed (#2681)
authorKawrakow <redacted>
Mon, 21 Aug 2023 08:11:31 +0000 (11:11 +0300)
committerGitHub <redacted>
Mon, 21 Aug 2023 08:11:31 +0000 (11:11 +0300)
Co-authored-by: Iwan Kawrakow <redacted>
examples/perplexity/perplexity.cpp

index 682c39b16894e93db5eaacd17ac363350bf860e9..2409db69f1afd5615d846ac321084180d6502ffc 100644 (file)
@@ -122,6 +122,27 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
     printf("\n");
 }
 
+std::vector<float> hellaswag_evaluate_tokens(llama_context * ctx, const std::vector<int>& tokens, int n_past, int n_batch,
+        int n_vocab, int n_thread) {
+    std::vector<float> result;
+    result.reserve(tokens.size() * n_vocab);
+    size_t n_chunk = (tokens.size() + n_batch - 1)/n_batch;
+    for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) {
+        size_t n_tokens = tokens.size() - i_chunk * n_batch;
+        n_tokens = std::min(n_tokens, size_t(n_batch));
+        if (llama_eval(ctx, tokens.data() + i_chunk * n_batch, n_tokens, n_past, n_thread)) {
+            fprintf(stderr, "%s : failed to eval\n", __func__);
+            return {};
+        }
+
+        const auto logits = llama_get_logits(ctx);
+        result.insert(result.end(), logits, logits + n_tokens * n_vocab);
+
+        n_past += n_tokens;
+    }
+    return result;
+}
+
 void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     // Calculates hellaswag score (acc_norm) from prompt
     //
@@ -235,15 +256,13 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
             query_embd.resize(32);
         }
 
-        // Evaluate the query
-        if (llama_eval(ctx, query_embd.data(), query_embd.size(), 0, params.n_threads)) {
+        auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab, params.n_threads);
+        if (logits.empty()) {
             fprintf(stderr, "%s : failed to eval\n", __func__);
             return;
         }
 
-        auto query_logits = llama_get_logits(ctx);
-
-        std::memcpy(tok_logits.data(), query_logits + (context_size-1)*n_vocab, n_vocab*sizeof(float));
+        std::memcpy(tok_logits.data(), logits.data() + (context_size-1)*n_vocab, n_vocab*sizeof(float));
         const auto first_probs = softmax(tok_logits);
 
         hs_data[task_idx].ending_logprob_count[0] = 1;
@@ -252,7 +271,7 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         // Calculate the logprobs over the ending
         for (size_t j = context_size; j < query_size - 1; j++) {
 
-            std::memcpy(tok_logits.data(), query_logits + j*n_vocab, n_vocab*sizeof(float));
+            std::memcpy(tok_logits.data(), logits.data() + j*n_vocab, n_vocab*sizeof(float));
 
             const float prob = softmax(tok_logits)[query_embd[j + 1]];
 
@@ -271,7 +290,6 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
             // Tokenize the query
             query_embd = ::llama_tokenize(ctx, hs_data[task_idx].ending[ending_idx], false);
             query_size = query_embd.size();
-            //printf("Second query: %d\n",(int)query_size);
 
             // Stop if query wont fit the ctx window
             if (context_size + query_size > (size_t)params.n_ctx) {
@@ -286,19 +304,18 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
             //}
 
             // Evaluate the query
-            if (llama_eval(ctx, query_embd.data(), query_embd.size(), context_size, params.n_threads)) {
+            logits = hellaswag_evaluate_tokens(ctx, query_embd, context_size, params.n_batch, n_vocab, params.n_threads);
+            if (logits.empty()) {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
                 return;
             }
 
-            query_logits = llama_get_logits(ctx);
-
             hs_data[task_idx].ending_logprob_count[ending_idx] = 1;
             hs_data[task_idx].ending_logprob[ending_idx] = std::log(first_probs[query_embd[0]]);
 
             // Calculate the logprobs over the ending
             for (size_t j = 0; j < query_size - 1; j++) {
-                std::memcpy(tok_logits.data(), query_logits + j*n_vocab, n_vocab*sizeof(float));
+                std::memcpy(tok_logits.data(), logits.data() + j*n_vocab, n_vocab*sizeof(float));
 
                 const float prob = softmax(tok_logits)[query_embd[j + 1]];