]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
perplexity : faster Winogrande via batching (#5024)
authorGeorgi Gerganov <redacted>
Fri, 19 Jan 2024 08:45:06 +0000 (10:45 +0200)
committerGitHub <redacted>
Fri, 19 Jan 2024 08:45:06 +0000 (10:45 +0200)
* perplexity : faster Winogrande via batching

ggml-ci

* perplexity : remove unused function

* perplexity : only tokenize selected tasks for Winogrande

examples/perplexity/perplexity.cpp

index f72ea6d1caea55e7af2934270da4a15b4aabc768..df902fb1c9a8317bdfa47a661a6ed45c3709109f 100644 (file)
@@ -423,26 +423,31 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     return {tokens, ppl, logit_history, prob_history};
 }
 
-static std::vector<float> evaluate_tokens(llama_context * ctx, std::vector<int> & tokens,
-        int n_past, int n_batch, int n_vocab) {
-    std::vector<float> result;
-    result.reserve(tokens.size() * n_vocab);
-    size_t n_chunk = (tokens.size() + n_batch - 1)/n_batch;
-    for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) {
-        size_t n_tokens = tokens.size() - i_chunk * n_batch;
-        n_tokens = std::min(n_tokens, size_t(n_batch));
-        llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
-        if (llama_decode(ctx, llama_batch_get_one(tokens.data() + i_chunk * n_batch, n_tokens, n_past, 0))) {
-            fprintf(stderr, "%s : failed to eval\n", __func__);
-            return {};
+static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int32_t n_batch, int32_t n_vocab) {
+    for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
+        const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
+
+        llama_batch batch_view = {
+            n_tokens,
+            batch.token    + i,
+            nullptr,
+            batch.pos      + i,
+            batch.n_seq_id + i,
+            batch.seq_id   + i,
+            batch.logits   + i,
+            0, 0, 0, // unused
+        };
+
+        const int ret = llama_decode(ctx, batch_view);
+        if (ret != 0) {
+            LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
+            return false;
         }
 
-        const auto logits = llama_get_logits(ctx);
-        result.insert(result.end(), logits, logits + n_tokens * n_vocab);
-
-        n_past += n_tokens;
+        memcpy(batch_logits.data() + i*n_vocab, llama_get_logits(ctx), n_tokens*n_vocab*sizeof(float));
     }
-    return result;
+
+    return true;
 }
 
 static void hellaswag_compute_logprobs(const float * batch_logits, int n_vocab, std::vector<std::thread>& workers,
@@ -576,7 +581,6 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
 
         // determine the common prefix of the endings
         hs_cur.common_prefix = 0;
-        hs_cur.required_tokens = 0;
         for (size_t k = 0; k < hs_cur.seq_tokens[0].size(); k++) {
             if (hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[1][k] ||
                 hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[2][k] ||
@@ -609,45 +613,18 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     const int n_ctx   = llama_n_ctx(ctx);
     const int n_batch = params.n_batch;
 
-    const int max_tasks_per_batch = params.n_parallel;
+    const int max_tasks_per_batch = 32;
     const int max_seq = 4*max_tasks_per_batch;
 
     llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
 
     std::vector<float> tok_logits(n_vocab);
-    std::vector<float> batch_logits(n_ctx*n_vocab);
+    std::vector<float> batch_logits(n_vocab*n_ctx);
 
     std::vector<std::pair<size_t, llama_token>> eval_pairs;
     std::vector<float> eval_results;
     std::vector<std::thread> workers(std::thread::hardware_concurrency());
 
-    auto decode_helper = [&](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
-        for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
-            const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
-
-            llama_batch batch_view = {
-                n_tokens,
-                batch.token    + i,
-                nullptr,
-                batch.pos      + i,
-                batch.n_seq_id + i,
-                batch.seq_id   + i,
-                batch.logits   + i,
-                0, 0, 0, // unused
-            };
-
-            const int ret = llama_decode(ctx, batch_view);
-            if (ret != 0) {
-                LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
-                return false;
-            }
-
-            memcpy(batch_logits.data() + i*n_vocab, llama_get_logits(ctx), n_tokens*n_vocab*sizeof(float));
-        }
-
-        return true;
-    };
-
     for (size_t i0 = 0; i0 < hs_task_count; i0++) {
         int n_cur = 0;
 
@@ -696,7 +673,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         llama_kv_cache_clear(ctx);
 
         // decode all tasks [i0, i1)
-        if (!decode_helper(ctx, batch, n_batch)) {
+        if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
             fprintf(stderr, "%s: llama_decode() failed\n", __func__);
             return;
         }
@@ -772,6 +749,13 @@ struct winogrande_entry {
     std::string second;
     std::array<std::string, 2> choices;
     int answer;
+
+    size_t i_batch;
+    size_t common_prefix;
+    size_t required_tokens;
+    size_t n_base1; // number of tokens for context + choice 1
+    size_t n_base2; // number of tokens for context + choice 2
+    std::vector<llama_token> seq_tokens[2];
 };
 
 static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string& prompt) {
@@ -875,115 +859,164 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
         data = std::move(selected);
     }
 
+    fprintf(stderr, "%s : tokenizing selected tasks\n", __func__);
+
     // This is needed as usual for LLaMA models
     const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
 
+    for (auto & task : data) {
+        task.seq_tokens[0] = ::llama_tokenize(ctx, task.first + task.choices[0] + task.second, add_bos);
+        task.seq_tokens[1] = ::llama_tokenize(ctx, task.first + task.choices[1] + task.second, add_bos);
+
+        task.common_prefix = 0;
+        for (size_t k = 0; k < task.seq_tokens[0].size(); k++) {
+            if (task.seq_tokens[0][k] != task.seq_tokens[1][k]) {
+                break;
+            }
+            task.common_prefix++;
+        }
+
+        task.required_tokens = task.common_prefix +
+            task.seq_tokens[0].size() - task.common_prefix +
+            task.seq_tokens[1].size() - task.common_prefix;
+
+        task.n_base1 = ::llama_tokenize(ctx, task.first + task.choices[0], add_bos).size();
+        task.n_base2 = ::llama_tokenize(ctx, task.first + task.choices[1], add_bos).size();
+    }
+
     fprintf(stderr, "%s : calculating winogrande score over selected tasks.\n", __func__);
 
     const int n_vocab = llama_n_vocab(llama_get_model(ctx));
-    const int n_ctx = llama_n_ctx(ctx);
+    const int n_ctx   = llama_n_ctx(ctx);
+    const int n_batch = params.n_batch;
+
+    const int max_tasks_per_batch = 128;
+    const int max_seq = 2*max_tasks_per_batch;
+
+    llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
 
     std::vector<float> tok_logits(n_vocab);
+    std::vector<float> batch_logits(n_vocab*n_ctx);
 
     int n_correct = 0;
     int n_done    = 0;
 
-    for (size_t task_idx = 0; task_idx < data.size(); task_idx++) {
-        const auto& task = data[task_idx];
+    for (size_t i0 = 0; i0 < data.size(); i0++) {
+        int n_cur = 0;
 
-        auto base_context = ::llama_tokenize(ctx, task.first, add_bos);
-        auto base_ctx_1st = ::llama_tokenize(ctx, task.first + task.choices[0], add_bos);
-        auto base_ctx_2nd = ::llama_tokenize(ctx, task.first + task.choices[1], add_bos);
+        size_t i1 = i0;
+        size_t i_batch = 0;
 
-        auto sentence_1st = task.first + task.choices[0] + task.second;
-        auto sentence_2nd = task.first + task.choices[1] + task.second;
-        auto query_1st = ::llama_tokenize(ctx, sentence_1st, add_bos);
-        auto query_2nd = ::llama_tokenize(ctx, sentence_2nd, add_bos);
+        llama_batch_clear(batch);
 
-        if (query_1st.size() > (size_t)n_ctx || query_2nd.size() > (size_t)n_ctx) {
-            fprintf(stderr, "%s : number of tokens in queries %zu, %zu > n_ctxl\n", __func__, query_1st.size(), query_2nd.size());
-            return;
-        }
+        while (n_cur + (int) data[i1].required_tokens <= n_ctx) {
+            const int s0 = 2*(i1 - i0);
+            if (s0 + 2 > max_seq) {
+                break;
+            }
+
+            for (size_t i = 0; i < data[i1].common_prefix; ++i) {
+                llama_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1}, false);
+            }
+            batch.logits[batch.n_tokens - 1] = true;
 
-        auto query_1st_size = query_1st.size();
-        auto query_2nd_size = query_2nd.size();
+            for (int s = 0; s < 2; ++s) {
+                for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) {
+                    llama_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true);
+                }
+            }
 
-        // Speedup small evaluations by evaluating atleast 32 tokens
-        // For Winogrande this seems to slow it down rather than speed it up.
-        //if (query_1st.size() < 32) query_1st.resize(32);
-        //if (query_2nd.size() < 32) query_2nd.resize(32);
+            data[i1].i_batch = i_batch;
+            i_batch += data[i1].required_tokens;
 
-        llama_kv_cache_clear(ctx);
-        auto logits_1st = evaluate_tokens(ctx, query_1st, 0, params.n_batch, n_vocab);
+            n_cur += data[i1].required_tokens;
+            if (++i1 == data.size()) {
+                break;
+            }
+        }
+
+        if (i0 == i1) {
+            fprintf(stderr, "%s : task %zu does not fit in the context window\n", __func__, i0);
+            return;
+        }
 
         llama_kv_cache_clear(ctx);
-        auto logits_2nd = evaluate_tokens(ctx, query_2nd, 0, params.n_batch, n_vocab);
 
-        if (logits_1st.empty() || logits_2nd.empty()) {
-            fprintf(stderr, "%s : failed to eval\n", __func__);
+        // decode all tasks [i0, i1)
+        if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
+            fprintf(stderr, "%s: llama_decode() failed\n", __func__);
             return;
         }
 
-        bool skip_choice = query_1st_size - base_ctx_1st.size() > k_min_trailing_ctx &&
-                           query_2nd_size - base_ctx_2nd.size() > k_min_trailing_ctx;
-
-        float score_1st = 0;
-        bool is_nan_1st = false;
-        const auto& base_1 = skip_choice ? base_ctx_1st : base_context;
-        const int last_1st = query_1st_size - base_1.size() > 1 ? 1 : 0;
-        for (size_t j = base_1.size()-1; j < query_1st_size-1-last_1st; ++j) {
-            std::memcpy(tok_logits.data(), logits_1st.data() + j*n_vocab, n_vocab*sizeof(float));
-            const float prob = softmax(tok_logits)[query_1st[j+1]];
-            if (std::isnan(prob) || !prob) {
-                fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__,
-                        prob, j, sentence_1st.c_str(), base_context.size());
-                is_nan_1st = true;
-                break;
+        for (size_t i = i0; i < i1; ++i) {
+            auto & task = data[i];
+
+            const bool skip_choice =
+                task.seq_tokens[0].size() - task.common_prefix > k_min_trailing_ctx &&
+                task.seq_tokens[1].size() - task.common_prefix > k_min_trailing_ctx;
+
+            float score_1st = 0;
+            bool is_nan_1st = false;
+            const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix;
+            const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0;
+            size_t li = n_base1 - 1;
+            for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) {
+                std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(task.i_batch + li++), n_vocab*sizeof(float));
+                const float prob = softmax(tok_logits)[task.seq_tokens[0][j+1]];
+                if (std::isnan(prob) || !prob) {
+                    fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__,
+                            prob, j, (task.first + task.choices[0] + task.second).c_str(), n_base1);
+                    is_nan_1st = true;
+                    break;
+                }
+                score_1st += std::log(prob);
             }
-            score_1st += std::log(prob);
-        }
-        score_1st /= (query_1st_size - base_1.size() - last_1st);
-
-        float score_2nd = 0;
-        bool is_nan_2nd = false;
-        const auto& base_2 = skip_choice ? base_ctx_2nd : base_context;
-        const int last_2nd = query_2nd_size - base_2.size() > 1 ? 1 : 0;
-        for (size_t j = base_2.size()-1; j < query_2nd_size-1-last_2nd; ++j) {
-            std::memcpy(tok_logits.data(), logits_2nd.data() + j*n_vocab, n_vocab*sizeof(float));
-            const float prob = softmax(tok_logits)[query_2nd[j+1]];
-            if (std::isnan(prob) || !prob) {
-                fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__,
-                        prob, j, sentence_2nd.c_str(), base_context.size());
-                is_nan_2nd = true;
-                break;
+            score_1st /= (task.seq_tokens[0].size() - n_base1 - last_1st);
+
+            float score_2nd = 0;
+            bool is_nan_2nd = false;
+            const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix;
+            const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0;
+            li = task.seq_tokens[0].size() - task.common_prefix + n_base2 - 1;
+            for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) {
+                std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(task.i_batch + li++), n_vocab*sizeof(float));
+                const float prob = softmax(tok_logits)[task.seq_tokens[1][j+1]];
+                if (std::isnan(prob) || !prob) {
+                    fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__,
+                            prob, j, (task.first + task.choices[1] + task.second).c_str(), n_base2);
+                    is_nan_2nd = true;
+                    break;
+                }
+                score_2nd += std::log(prob);
             }
-            score_2nd += std::log(prob);
-        }
-        score_2nd /= (query_2nd_size - base_2.size() - last_2nd);
+            score_2nd /= (task.seq_tokens[1].size() - n_base2 - last_2nd);
 
-        if (is_nan_1st || is_nan_2nd) {
-            continue;
-        }
+            if (is_nan_1st || is_nan_2nd) {
+                continue;
+            }
 
-        if (std::isnan(score_1st) || std::isnan(score_2nd)) {
-            printf("================== NaN score %g, %g) for:\n", score_1st, score_2nd);
-            printf("Q1: <%s> - %zu tokens\n", sentence_1st.c_str(), query_1st_size);
-            printf("Q2: <%s> - %zu tokens\n", sentence_2nd.c_str(), query_2nd_size);
-            printf("B : <%s> - %zu tokens\n", task.first.c_str(), base_context.size());
-            printf("base_1 has %zu tokens, base_2 has %zu tokens, skip_choice = %d\n", base_1.size(), base_2.size(), skip_choice);
-            continue;
-        }
+            if (std::isnan(score_1st) || std::isnan(score_2nd)) {
+                printf("================== NaN score %g, %g) for:\n", score_1st, score_2nd);
+                printf("Q1: <%s> - %zu tokens\n", (task.first + task.choices[0] + task.second).c_str(), task.seq_tokens[0].size());
+                printf("Q2: <%s> - %zu tokens\n", (task.first + task.choices[1] + task.second).c_str(), task.seq_tokens[1].size());
+                printf("B : <%s> - %zu tokens\n", task.first.c_str(), task.common_prefix);
+                printf("base_1 has %zu tokens, base_2 has %zu tokens, skip_choice = %d\n", n_base1, n_base2, skip_choice);
+                continue;
+            }
 
-        int result = score_1st > score_2nd ? 1 : 2;
+            int result = score_1st > score_2nd ? 1 : 2;
+
+            if (result == task.answer) {
+                ++n_correct;
+            }
+            ++n_done;
 
-        if (result == task.answer) {
-            ++n_correct;
+            // Print the accumulated accuracy mean x 100
+            printf("%zu\t%.4lf\t%10.6f  %10.6f  %d  %d\n", i+1, 100.0 * n_correct/n_done, score_1st, score_2nd, result, task.answer);
+            fflush(stdout);
         }
-        ++n_done;
 
-        // Print the accumulated accuracy mean x 100
-        printf("%zu\t%.4lf\t%10.6f  %10.6f  %d  %d\n",task_idx+1, 100.0 * n_correct/n_done,score_1st,score_2nd,result,task.answer);
-        fflush(stdout);
+        i0 = i1 - 1;
     }
 
     printf("\n");