]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
perplexity : fix integer overflow (#9783)
authorGeorgi Gerganov <redacted>
Wed, 9 Oct 2024 14:00:18 +0000 (17:00 +0300)
committerGitHub <redacted>
Wed, 9 Oct 2024 14:00:18 +0000 (17:00 +0300)
* perplexity : fix integer overflow

ggml-ci

* perplexity : keep n_vocab as int and make appropriate casts

ggml-ci

examples/perplexity/perplexity.cpp

index 87347135e0bb76da963dd452c0f8faf6b31839ac..40bc29f7a870c0ff9c3aa6654fe070dffb21e719 100644 (file)
@@ -169,7 +169,7 @@ static void process_logits(
                 break;
             }
             lock.unlock();
-            const results_log_softmax results = log_softmax(n_vocab, logits + i*n_vocab, tokens[i+1]);
+            const results_log_softmax results = log_softmax(n_vocab, logits + size_t(i)*n_vocab, tokens[i+1]);
             const double v = -results.log_softmax;
             local_nll += v;
             local_nll2 += v*v;
@@ -203,7 +203,7 @@ static void process_logits(std::ostream& out, int n_vocab, const float * logits,
                 break;
             }
             lock.unlock();
-            const double v = log_softmax(n_vocab, logits + i*n_vocab, log_probs.data() + i*nv, tokens[i+1]);
+            const double v = log_softmax(n_vocab, logits + size_t(i)*n_vocab, log_probs.data() + i*nv, tokens[i+1]);
             local_nll += v;
             local_nll2 += v*v;
         }
@@ -281,7 +281,9 @@ static std::pair<double, float> log_softmax(int n_vocab, const float * logits, c
     kld.sum_kld  += sum;
     kld.sum_kld2 += sum*sum;
     ++kld.count;
-    if (imax == imax_base) ++kld.n_same_top;
+    if (imax == imax_base) {
+        ++kld.n_same_top;
+    }
 
     const float p_base = expf(-nll_base);
     const float p = expf(-nll);
@@ -323,7 +325,7 @@ static void process_logits(int n_vocab, const float * logits, const int * tokens
                 break;
             }
             lock.unlock();
-            std::pair<double, float> v = log_softmax(n_vocab, logits + i*n_vocab, base_log_probs.data() + i*nv, tokens[i+1], local_kld);
+            std::pair<double, float> v = log_softmax(n_vocab, logits + size_t(i)*n_vocab, base_log_probs.data() + i*nv, tokens[i+1], local_kld);
             kld_values[i]    = (float)v.first;
             p_diff_values[i] = v.second;
         }
@@ -383,9 +385,10 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
     const int n_chunk_max = (tokens.size() - calc_chunk + params.ppl_stride - 1)  / params.ppl_stride;
 
     const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
-    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
     const int n_batch = params.n_batch;
 
+    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+
     int count = 0;
     double nll = 0.0;
 
@@ -424,8 +427,8 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
                 tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
             }
 
-            const auto batch_logits = llama_get_logits(ctx);
-            logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+            const auto batch_logits = llama_get_logits(ctx);
+            logits.insert(logits.end(), batch_logits, batch_logits + size_t(batch_size) * n_vocab);
 
             if (j == 0) {
                 tokens[batch_start] = token_org;
@@ -447,11 +450,10 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
 
         //LOG_DBG("%s: using tokens %d...%d\n",__func__,params.n_ctx - params.ppl_stride + start, params.n_ctx + start);
         for (int j = n_ctx - params.ppl_stride - 1; j < n_ctx - 1; ++j) {
-
             // Calculate probability of next token, given the previous ones.
             const std::vector<float> tok_logits(
-                logits.begin() + (j + 0) * n_vocab,
-                logits.begin() + (j + 1) * n_vocab);
+                logits.begin() + size_t(j + 0) * n_vocab,
+                logits.begin() + size_t(j + 1) * n_vocab);
 
             const float prob = softmax(tok_logits)[tokens[start + j + 1]];
             logit_history[start + j + 1] = tok_logits[tokens[start + j + 1]];
@@ -521,9 +523,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     const int n_chunk_max = tokens.size() / n_ctx;
 
     const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
-    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
     const int n_batch = params.n_batch;
 
+    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+
     int count = 0;
     double nll = 0.0;
     double nll2 = 0.0;
@@ -538,7 +541,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
 
     std::vector<float> logits;
     if (num_batches > 1) {
-        logits.reserve((size_t)n_ctx * n_vocab);
+        logits.reserve(size_t(n_ctx) * n_vocab);
     }
 
     LOG_INF("%s: calculating perplexity over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq);
@@ -620,7 +623,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
 
             if (num_batches > 1 && n_outputs > 0) {
                 const auto * batch_logits = llama_get_logits(ctx);
-                logits.insert(logits.end(), batch_logits, batch_logits + n_outputs * n_vocab);
+                logits.insert(logits.end(), batch_logits, batch_logits + size_t(n_outputs) * n_vocab);
             }
         }
 
@@ -661,7 +664,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             } else {
                 double av = nll/count;
                 double av2 = nll2/count - av*av;
-                if (av2 > 0) av2 = sqrt(av2/(count-1));
+                if (av2 > 0) {
+                    av2 = sqrt(av2/(count-1));
+                }
                 LOG("%8d  %.4lf  %4lf  %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
             }
         }
@@ -686,10 +691,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     return {tokens, ppl, logit_history, prob_history};
 }
 
-static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int32_t n_batch, int32_t n_vocab) {
+static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int n_batch, int n_vocab) {
     int prev_outputs = 0;
-    for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
-        const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
+    for (int i = 0; i < (int) batch.n_tokens; i += n_batch) {
+        const int n_tokens = std::min<int>(n_batch, batch.n_tokens - i);
 
         llama_batch batch_view = {
             n_tokens,
@@ -713,7 +718,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
             n_outputs += batch_view.logits[i] != 0;
         }
 
-        memcpy(batch_logits.data() + prev_outputs*n_vocab, llama_get_logits(ctx), n_outputs*n_vocab*sizeof(float));
+        memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab, llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float));
 
         prev_outputs += n_outputs;
     }
@@ -728,7 +733,9 @@ static void compute_logprobs(const float * batch_logits, int n_vocab, std::vecto
     if (eval_results.size() != eval_pairs.size()) {
         eval_results.resize(eval_pairs.size());
     }
-    if (eval_pairs.empty()) return;
+    if (eval_pairs.empty()) {
+        return;
+    }
 
     size_t max_threads = std::min((eval_pairs.size() + K_TOKEN_CHUNK - 1)/K_TOKEN_CHUNK, workers.size());
 
@@ -736,11 +743,13 @@ static void compute_logprobs(const float * batch_logits, int n_vocab, std::vecto
     auto compute = [&counter, &eval_pairs, &eval_results, batch_logits, n_vocab] () {
         float local_logprobs[K_TOKEN_CHUNK];
         while (true) {
-            size_t first = counter.fetch_add(K_TOKEN_CHUNK, std::memory_order_relaxed);
-            if (first >= eval_results.size()) break;
-            size_t last = std::min(first + K_TOKEN_CHUNK, eval_results.size());
+            const size_t first = counter.fetch_add(K_TOKEN_CHUNK, std::memory_order_relaxed);
+            if (first >= eval_results.size()) {
+                break;
+            }
+            const size_t last = std::min(first + K_TOKEN_CHUNK, eval_results.size());
             for (size_t i = first; i < last; ++i) {
-                auto logits = batch_logits + eval_pairs[i].first * n_vocab;
+                const auto * logits = batch_logits + eval_pairs[i].first * n_vocab;
                 float max_logit = logits[0];
                 for (int j = 1; j < n_vocab; ++j) {
                     max_logit = std::max(max_logit, logits[j]);
@@ -877,10 +886,11 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
 
     double acc = 0.0f;
 
-    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
     const int n_ctx   = llama_n_ctx(ctx);
     const int n_batch = params.n_batch;
 
+    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+
     const int max_tasks_per_batch = 32;
     const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
 
@@ -888,7 +898,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
 
     std::vector<float> tok_logits(n_vocab);
     // TODO: this could be made smaller; it's currently the worst-case size
-    std::vector<float> batch_logits(n_vocab*n_ctx);
+    std::vector<float> batch_logits(size_t(n_ctx)*n_vocab);
 
     std::vector<std::pair<size_t, llama_token>> eval_pairs;
     std::vector<float> eval_results;
@@ -975,7 +985,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
             auto & hs_cur = hs_data[i];
 
             // get the logits of the last token of the common prefix
-            std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*hs_cur.i_logits, n_vocab*sizeof(float));
+            std::memcpy(tok_logits.data(), batch_logits.data() + hs_cur.i_logits*n_vocab, n_vocab*sizeof(float));
 
             const auto first_probs = softmax(tok_logits);
 
@@ -1158,10 +1168,11 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
 
     LOG_INF("%s : calculating winogrande score over selected tasks.\n", __func__);
 
-    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
     const int n_ctx   = llama_n_ctx(ctx);
     const int n_batch = params.n_batch;
 
+    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+
     const int max_tasks_per_batch = 128;
     const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
 
@@ -1169,7 +1180,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
 
     std::vector<float> tok_logits(n_vocab);
     // TODO: this could be made smaller; it's currently the worst-case size
-    std::vector<float> batch_logits(n_vocab*n_ctx);
+    std::vector<float> batch_logits(size_t(n_ctx)*n_vocab);
 
     std::vector<std::pair<size_t, llama_token>> eval_pairs;
     std::vector<float> eval_results;
@@ -1509,17 +1520,18 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
 
     LOG("\ntask\tacc_norm\n");
 
-    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
     const int n_ctx   = llama_n_ctx(ctx);
     const int n_batch = params.n_batch;
 
+    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+
     const int max_tasks_per_batch = 32;
     const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
 
     llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
 
     std::vector<float> tok_logits(n_vocab);
-    std::vector<float> batch_logits(n_vocab*n_ctx);
+    std::vector<float> batch_logits(size_t(n_ctx)*n_vocab);
 
     std::vector<std::pair<size_t, llama_token>> eval_pairs;
     std::vector<float> eval_results;
@@ -1627,7 +1639,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
             //LOG("\n    common_prefix: %zu\n", cur_task.common_prefix);
 
             // get the logits of the last token of the common prefix
-            std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*cur_task.i_logits, n_vocab*sizeof(float));
+            std::memcpy(tok_logits.data(), batch_logits.data() + cur_task.i_logits*n_vocab, n_vocab*sizeof(float));
 
             const auto first_probs = softmax(tok_logits);
 
@@ -1709,7 +1721,8 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
                 __func__, params.logits_file.c_str(), n_ctx, params.n_ctx);
     }
 
-    int n_vocab, n_chunk;
+    int n_vocab;
+    int n_chunk;
     in.read((char *)&n_vocab, sizeof(n_vocab));
     in.read((char *)&n_chunk, sizeof(n_chunk));
     if (in.fail()) {
@@ -1720,7 +1733,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
         LOG_ERR("%s: inconsistent vocabulary (%d vs %d)\n", __func__, n_vocab, llama_n_vocab(llama_get_model(ctx)));
     }
 
-    std::vector<llama_token> tokens(n_ctx * n_chunk);
+    std::vector<llama_token> tokens(size_t(n_ctx) * n_chunk);
     if (in.read((char *)tokens.data(), tokens.size()*sizeof(tokens[0])).fail()) {
         LOG_ERR("%s: failed reading evaluation tokens from %s\n", __func__, params.logits_file.c_str());
         return;
@@ -1737,7 +1750,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
     std::vector<float> p_diff_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
     std::vector<float> logits;
     if (num_batches > 1) {
-        logits.reserve(n_ctx * n_vocab);
+        logits.reserve(size_t(n_ctx) * n_vocab);
     }
 
     std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
@@ -1801,7 +1814,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
 
             if (num_batches > 1) {
                 const auto * batch_logits = llama_get_logits(ctx);
-                logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+                logits.insert(logits.end(), batch_logits, batch_logits + size_t(batch_size) * n_vocab);
             }
         }
 
@@ -1822,7 +1835,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
 
         const int first = n_ctx/2;
         const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
-        process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
+        process_logits(n_vocab, all_logits + size_t(first)*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
                 workers, log_probs_uint16, kld, kld_ptr, p_diff_ptr);
         p_diff_ptr += n_ctx - 1 - first;
         kld_ptr    += n_ctx - 1 - first;