]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
perplexity : faster HellaSwag via batching (#5017)
authorGeorgi Gerganov <redacted>
Thu, 18 Jan 2024 13:33:01 +0000 (15:33 +0200)
committerGitHub <redacted>
Thu, 18 Jan 2024 13:33:01 +0000 (15:33 +0200)
* perplexity : faster HellaSwag

ggml-ci

* perplexity : clean-up

ggml-ci

* perplexity : no need for decode_helper

ggml-ci

* perplexity : add comments

* perplexity : option to specify max batched tasks via `n_parallel`

* perplexity : remove HellaSwag restruction for n_batch

examples/perplexity/perplexity.cpp

index 57eaa713ed616759be890991f4e6f9a46b8c6bce..ea2c8026cfcecfdbe5647a172e12d192dce95c09 100644 (file)
@@ -470,7 +470,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         prompt_lines.push_back(line);
     }
 
-    ifprompt_lines.size() % 6 != 0) {
+    if (prompt_lines.size() % 6 != 0) {
         fprintf(stderr, "%s : number of lines in prompt not a multiple of 6.\n", __func__);
         return;
     }
@@ -485,7 +485,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
 
     // Number of tasks to use when computing the score
-    if ( params.hellaswag_tasks < hs_task_count  ) {
+    if (params.hellaswag_tasks < hs_task_count) {
         hs_task_count = params.hellaswag_tasks;
     }
 
@@ -502,27 +502,54 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         std::string ending[4];
         size_t ending_logprob_count[4];
         double ending_logprob[4];
+
+        size_t i_batch;         // starting index in the llama_batch
+        size_t common_prefix;   // max number of initial tokens that are the same in all sentences
+        size_t required_tokens; // needed number of tokens to evaluate all 4 endings
+        std::vector<llama_token> seq_tokens[4];
     };
 
     fprintf(stderr, "%s : selecting %zu %s tasks.\n", __func__, hs_task_count, (randomize_tasks?"randomized":"the first")  );
 
     // Select and read data from prompt lines
-    hs_data_t *hs_data = new hs_data_t[hs_task_count];
-    for (size_t i=0; i < hs_task_count; i++) {
+    std::vector<hs_data_t> hs_data(hs_task_count);
+    for (size_t i = 0; i < hs_task_count; i++) {
         size_t idx = i;
 
+        auto & hs_cur = hs_data[i];
+
         // Select a random example of those left in the prompt
         if (randomize_tasks) {
             std::uniform_int_distribution<size_t> dist(0, prompt_lines.size()/6-1 ) ;
             idx = dist(rng);
         }
 
-        hs_data[i].context = prompt_lines[idx*6];
-        hs_data[i].gold_ending_idx = std::stoi( prompt_lines[idx*6+1] );
-        for (size_t j=0; j < 4; j++) {
-            hs_data[i].ending[j] = prompt_lines[idx*6+2+j];
+        hs_cur.context = prompt_lines[idx*6];
+        hs_cur.gold_ending_idx = std::stoi( prompt_lines[idx*6+1] );
+        for (size_t j = 0; j < 4; j++) {
+            hs_cur.ending[j] = prompt_lines[idx*6+2+j];
+            hs_cur.seq_tokens[j] = ::llama_tokenize(ctx, hs_cur.context + " " + hs_cur.ending[j], add_bos);
         }
 
+        // determine the common prefix of the endings
+        hs_cur.common_prefix = 0;
+        hs_cur.required_tokens = 0;
+        for (size_t k = 0; k < hs_cur.seq_tokens[0].size(); k++) {
+            if (hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[1][k] ||
+                hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[2][k] ||
+                hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[3][k]) {
+                break;
+            }
+            hs_cur.common_prefix++;
+        }
+        hs_cur.required_tokens = hs_cur.common_prefix +
+            hs_cur.seq_tokens[0].size() - hs_cur.common_prefix +
+            hs_cur.seq_tokens[1].size() - hs_cur.common_prefix +
+            hs_cur.seq_tokens[2].size() - hs_cur.common_prefix +
+            hs_cur.seq_tokens[3].size() - hs_cur.common_prefix;
+
+        //GGML_ASSERT(hs_cur.common_prefix >= ::llama_tokenize(ctx, hs_cur.context, add_bos).size());
+
         // Delete the selected random example from the prompt
         if (randomize_tasks) {
             prompt_lines.erase( std::next(prompt_lines.begin(),idx*6)  , std::next(prompt_lines.begin(),idx*6+6) );
@@ -530,150 +557,160 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     }
 
     fprintf(stderr, "%s : calculating hellaswag score over selected tasks.\n", __func__);
+
     printf("\ntask\tacc_norm\n");
 
     double acc = 0.0f;
+
     const int n_vocab = llama_n_vocab(llama_get_model(ctx));
-    const int n_ctx = llama_n_ctx(ctx);
+    const int n_ctx   = llama_n_ctx(ctx);
+    const int n_batch = params.n_batch;
 
-    std::vector<std::vector<int>> ending_tokens(4);
+    const int max_tasks_per_batch = params.n_parallel;
+    const int max_seq = 4*max_tasks_per_batch;
 
-    std::vector<float> tok_logits(n_vocab);
+    llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
 
-    for (size_t task_idx = 0; task_idx < hs_task_count; task_idx++) {
-        // Tokenize the context to count tokens
-        std::vector<int> context_embd = ::llama_tokenize(ctx, hs_data[task_idx].context, add_bos);
-        size_t context_size = context_embd.size();
-
-        for (int i = 0; i < 4; ++i) {
-            ending_tokens[i] = ::llama_tokenize(ctx, hs_data[task_idx].context + " " + hs_data[task_idx].ending[i], add_bos);
-            for (int k = 0; k < int(context_size); ++k) {
-                if (ending_tokens[i][k] != context_embd[k]) {
-                    fprintf(stderr, "Oops: ending %d of task %d differs from context at position %d\n",i,int(task_idx),k);
-                    break;
-                }
+    std::vector<float> tok_logits(n_vocab);
+    std::vector<float> batch_logits(n_ctx*n_vocab);
+
+    auto decode_helper = [&](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
+        for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
+            const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
+
+            llama_batch batch_view = {
+                n_tokens,
+                batch.token    + i,
+                nullptr,
+                batch.pos      + i,
+                batch.n_seq_id + i,
+                batch.seq_id   + i,
+                batch.logits   + i,
+                0, 0, 0, // unused
+            };
+
+            const int ret = llama_decode(ctx, batch_view);
+            if (ret != 0) {
+                LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
+                return false;
             }
-        }
 
-        // Do the 1st ending
-        // In this case we include the context when evaluating
-        //auto query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[0], add_bos);
-        auto query_embd = ending_tokens[0];
-        auto query_size = query_embd.size();
-
-        // Stop if query wont fit the ctx window
-        if (query_size > (size_t)n_ctx) {
-            fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size);
-            return;
+            memcpy(batch_logits.data() + i*n_vocab, llama_get_logits(ctx), n_tokens*n_vocab*sizeof(float));
         }
 
-        // Speedup small evaluations by evaluating atleast 32 tokens
-        if (query_size < 32) {
-            query_embd.resize(32);
-        }
+        return true;
+    };
 
-        // clear the KV cache
-        llama_kv_cache_clear(ctx);
+    for (size_t i0 = 0; i0 < hs_task_count; i0++) {
+        int n_cur = 0;
 
-        auto logits = evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab);
-        if (logits.empty()) {
-            fprintf(stderr, "%s : failed to eval\n", __func__);
-            return;
-        }
+        size_t i1 = i0;
+        size_t i_batch = 0; // this tells us where in `llama_batch` we are currently
+
+        llama_batch_clear(batch);
 
-        std::memcpy(tok_logits.data(), logits.data() + (context_size-1)*n_vocab, n_vocab*sizeof(float));
-        const auto first_probs = softmax(tok_logits);
+        // batch as much tasks as possible into the available context
+        // each task has 4 unique seuqnce ids - one for each ending
+        // the common prefix is shared among the 4 sequences to save tokens
+        // we extract logits only from the last common token and from all ending tokens of each sequence
+        while (n_cur + (int) hs_data[i1].required_tokens <= n_ctx) {
+            auto & hs_cur = hs_data[i1];
 
-        hs_data[task_idx].ending_logprob_count[0] = 1;
-        hs_data[task_idx].ending_logprob[0] = std::log(first_probs[query_embd[context_size]]);
+            const int s0 = 4*(i1 - i0);
+            if (s0 + 4 > max_seq) {
+                break;
+            }
 
-        // Calculate the logprobs over the ending
-        for (size_t j = context_size; j < query_size - 1; j++) {
+            for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
+                llama_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
+            }
+            batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
 
-            std::memcpy(tok_logits.data(), logits.data() + j*n_vocab, n_vocab*sizeof(float));
+            for (int s = 0; s < 4; ++s) {
+                for (size_t i = hs_cur.common_prefix; i < hs_cur.seq_tokens[s].size(); ++i) {
+                    llama_batch_add(batch, hs_cur.seq_tokens[s][i], i, { s0 + s }, true);
+                }
+            }
 
-            const float prob = softmax(tok_logits)[query_embd[j + 1]];
+            hs_cur.i_batch = i_batch;
+            i_batch += hs_cur.required_tokens;
 
-            hs_data[task_idx].ending_logprob[0] += std::log(prob);
-            hs_data[task_idx].ending_logprob_count[0]++;
+            n_cur += hs_data[i1].required_tokens;
+            if (++i1 == hs_task_count) {
+                break;
+            }
         }
 
-        // Calculate the mean token logprob for acc_norm
-        hs_data[task_idx].ending_logprob[0] /= hs_data[task_idx].ending_logprob_count[0];
+        if (i0 == i1) {
+            fprintf(stderr, "%s : task %zu does not fit in the context window\n", __func__, i0);
+            return;
+        }
 
-        // Do the remaining endings
-        // For these, we use the bare ending with n_past = context_size
-        //
-        for (size_t ending_idx = 1; ending_idx < 4; ending_idx++) {
+        llama_kv_cache_clear(ctx);
 
-            // Tokenize the query
-            query_embd.resize(ending_tokens[ending_idx].size() - context_size);
-            std::memcpy(query_embd.data(), ending_tokens[ending_idx].data() + context_size, query_embd.size()*sizeof(int));
-            query_size = query_embd.size();
+        // decode all tasks [i0, i1)
+        if (!decode_helper(ctx, batch, n_batch)) {
+            fprintf(stderr, "%s: llama_decode() failed\n", __func__);
+            return;
+        }
 
-            // Stop if query wont fit the ctx window
-            if (context_size + query_size > (size_t)n_ctx) {
-                fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size);
-                return;
-            }
+        // compute the logprobs for each ending of the decoded tasks
+        for (size_t i = i0; i < i1; ++i) {
+            auto & hs_cur = hs_data[i];
 
-            // Speedup small evaluations by evaluating atleast 32 tokens
-            // No, resizing to 32 is actually slightly slower (at least on CUDA)
-            //if (query_size < 32) {
-            //    query_embd.resize(32);
-            //}
+            std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(hs_cur.i_batch + hs_cur.common_prefix - 1), n_vocab*sizeof(float));
 
-            // Evaluate the query
-            logits = evaluate_tokens(ctx, query_embd, context_size, params.n_batch, n_vocab);
-            if (logits.empty()) {
-                fprintf(stderr, "%s : failed to eval\n", __func__);
-                return;
-            }
+            const auto first_probs = softmax(tok_logits);
 
-            hs_data[task_idx].ending_logprob_count[ending_idx] = 1;
-            hs_data[task_idx].ending_logprob[ending_idx] = std::log(first_probs[query_embd[0]]);
+            size_t li = hs_cur.common_prefix; // logits index in the batch
 
-            // Calculate the logprobs over the ending
-            for (size_t j = 0; j < query_size - 1; j++) {
-                std::memcpy(tok_logits.data(), logits.data() + j*n_vocab, n_vocab*sizeof(float));
+            for (int s = 0; s < 4; ++s) {
+                hs_cur.ending_logprob_count[s] = 1;
+                hs_cur.ending_logprob[s] = std::log(first_probs[hs_cur.seq_tokens[s][hs_cur.common_prefix]]);
 
-                const float prob = softmax(tok_logits)[query_embd[j + 1]];
+                // Calculate the logprobs over the ending
+                for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) {
+                    std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(hs_cur.i_batch + li++), n_vocab*sizeof(float));
 
-                hs_data[task_idx].ending_logprob[ending_idx] += std::log(prob);
-                hs_data[task_idx].ending_logprob_count[ending_idx]++;
-            }
+                    const float prob = softmax(tok_logits)[hs_cur.seq_tokens[s][j + 1]];
 
-            // Calculate the mean token logprob for acc_norm
-            hs_data[task_idx].ending_logprob[ending_idx] /= hs_data[task_idx].ending_logprob_count[ending_idx];
+                    hs_cur.ending_logprob[s] += std::log(prob);
+                    hs_cur.ending_logprob_count[s]++;
+                }
 
+                // account that we skip the last token in the ending
+                ++li;
 
-//            printf("task %lu, ending %lu, whole_len %lu, context_len %lu, ending_logprob_count %lu, ending_logprob %.4f\n",
-//                task_idx,ending_idx,whole_size,context_size, hs_data[task_idx].ending_logprob_count[ending_idx], hs_data[task_idx].ending_logprob[ending_idx] );
-        }
+                // Calculate the mean token logprob for acc_norm
+                hs_cur.ending_logprob[s] /= hs_cur.ending_logprob_count[s];
+            }
 
-        // Find the ending with maximum logprob
-        size_t ending_logprob_max_idx = 0;
-        double ending_logprob_max_val = hs_data[task_idx].ending_logprob[0];
-        for (size_t j = 1; j < 4; j++) {
-            if (hs_data[task_idx].ending_logprob[j] > ending_logprob_max_val) {
-                ending_logprob_max_idx = j;
-                ending_logprob_max_val =  hs_data[task_idx].ending_logprob[j];
+            // Find the ending with maximum logprob
+            size_t ending_logprob_max_idx = 0;
+            double ending_logprob_max_val = hs_cur.ending_logprob[0];
+            for (size_t s = 1; s < 4; s++) {
+                if (hs_cur.ending_logprob[s] > ending_logprob_max_val) {
+                    ending_logprob_max_idx = s;
+                    ending_logprob_max_val =  hs_cur.ending_logprob[s];
+                }
             }
-        }
 
-//        printf("max logprob ending idx %lu, gold ending idx %lu\n", ending_logprob_max_idx, hs_data[task_idx].gold_ending_idx);
+            //printf("max logprob ending idx %lu, gold ending idx %lu\n", ending_logprob_max_idx, hs_cur.gold_ending_idx);
+
+            // If the gold ending got the maximum logprobe add one accuracy point
+            if (ending_logprob_max_idx == hs_cur.gold_ending_idx) {
+                acc += 1.0;
+            }
 
-        // If the gold ending got the maximum logprobe add one accuracy point
-        if (ending_logprob_max_idx == hs_data[task_idx].gold_ending_idx) {
-            acc += 1.0;
+            // Print the accumulated accuracy mean x 100
+            printf("%zu\t%.8lf\n", i + 1, acc/double(i + 1)*100.0);
+            fflush(stdout);
         }
 
-        // Print the accumulated accuracy mean x 100
-        printf("%zu\t%.8lf\n",task_idx+1, acc/double(task_idx+1)*100.0);
-        fflush(stdout);
+        i0 = i1 - 1;
     }
 
-    delete [] hs_data;
+    llama_batch_free(batch);
 
     printf("\n");
 }