]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
perplexity: add proper batching (#19661)
authorAesSedai <redacted>
Mon, 16 Feb 2026 16:44:44 +0000 (08:44 -0800)
committerGitHub <redacted>
Mon, 16 Feb 2026 16:44:44 +0000 (18:44 +0200)
tools/perplexity/perplexity.cpp

index 1ead9c871e995a3390ae82b286ae6bcf6db3762c..433b747f0d457b84e5bf5cf1331bccb20f018476 100644 (file)
@@ -347,7 +347,8 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
     int count = 0;
     double nll = 0.0;
 
-    LOG_INF("%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
+    const int n_seq = std::max(1, n_batch / n_ctx);
+    LOG_INF("%s: computing over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq);
 
     for (int i = 0; i < n_chunk; ++i) {
         const int start =     i * params.ppl_stride;
@@ -1737,11 +1738,21 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
     }
 
     const int n_batch = params.n_batch;
-    const int num_batches = (n_ctx + n_batch - 1)/n_batch;
+    const int num_batches = (static_cast<int>(n_ctx) + n_batch - 1) / n_batch;
+    // Calculate n_seq based on the logits file's n_ctx, but cap it at what the context supports
+    const int n_seq_max = llama_n_seq_max(ctx);
+    int n_seq = std::max(1, n_batch / static_cast<int>(n_ctx));
+    if (n_seq > n_seq_max) {
+        LOG_WRN("%s: calculated n_seq=%d exceeds context's n_seq_max=%d, capping at %d\n",
+                __func__, n_seq, n_seq_max, n_seq_max);
+        n_seq = n_seq_max;
+    }
     const int nv = 2*((n_vocab + 1)/2) + 4;
     const bool add_bos = llama_vocab_get_add_bos(vocab);
     GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
 
+    llama_batch batch = llama_batch_init(std::min(n_batch, static_cast<int>(n_ctx)*n_seq), 0, 1);
+
     std::vector<uint16_t> log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv);
     std::vector<float>    kld_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
     std::vector<float> p_diff_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
@@ -1750,6 +1761,8 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
         logits.reserve(size_t(n_ctx) * n_vocab);
     }
 
+    LOG_INF("%s: computing over %d chunks, n_ctx=%u, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq);
+
     std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
 
     auto mean_and_uncertainty = [] (double sum, double sum2, size_t count) {
@@ -1774,107 +1787,122 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
     auto    kld_ptr =    kld_values.data();
     auto p_diff_ptr = p_diff_values.data();
 
-    for (int i = 0; i < n_chunk; ++i) {
+    const int first = n_ctx/2;
+
+    for (int i = 0; i < n_chunk; i += n_seq) {
         const int start =     i * n_ctx;
         const int end   = start + n_ctx;
 
-        const auto t_start = std::chrono::high_resolution_clock::now();
+        const int n_seq_batch = std::min(n_seq, n_chunk - i);
 
-        if (in.read((char *)log_probs_uint16.data(), log_probs_uint16.size()*sizeof(uint16_t)).fail()) {
-            LOG_ERR("%s: failed reading log-probs for chunk %d\n", __func__, i);
-            return;
-        }
+        const auto t_start = std::chrono::high_resolution_clock::now();
 
         // clear the KV cache
         llama_memory_clear(llama_get_memory(ctx), true);
 
-        llama_batch batch = llama_batch_init(n_batch, 0, 1);
-
         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;
             const int batch_size  = std::min(end - batch_start, n_batch);
 
-            // save original token and restore it after eval
-            const auto token_org = tokens[batch_start];
-
-            // add BOS token for the first batch of each chunk
-            if (add_bos && j == 0) {
-                tokens[batch_start] = llama_vocab_bos(vocab);
-            }
+            int n_outputs = 0;
 
             common_batch_clear(batch);
-            for (int i = 0; i < batch_size; i++) {
-                common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true);
+            for (int seq = 0; seq < n_seq_batch; seq++) {
+                int seq_start = batch_start + seq*n_ctx;
+
+                // save original token and restore it after eval
+                const auto token_org = tokens[seq_start];
+
+                // add BOS token for the first batch of each chunk
+                if (add_bos && j == 0) {
+                    tokens[seq_start] = llama_vocab_bos(vocab);
+                }
+
+                for (int k = 0; k < batch_size; ++k) {
+                    const int pos = j*n_batch + k;
+                    const bool need_logits = pos >= first;
+                    common_batch_add(batch, tokens[seq_start + k], pos, { seq }, need_logits);
+                    n_outputs += need_logits;
+                }
+
+                // restore the original token in case it was set to BOS
+                tokens[seq_start] = token_org;
             }
 
             if (llama_decode(ctx, batch)) {
-                LOG_ERR("%s : failed to eval\n", __func__);
+                LOG_ERR("%s : failed to decode\n", __func__);
                 llama_batch_free(batch);
                 return;
             }
 
-            // restore the original token in case it was set to BOS
-            tokens[batch_start] = token_org;
-
-            if (num_batches > 1) {
+            if (num_batches > 1 && n_outputs > 0) {
                 const auto * batch_logits = llama_get_logits(ctx);
-                logits.insert(logits.end(), batch_logits, batch_logits + size_t(batch_size) * n_vocab);
+                logits.insert(logits.end(), batch_logits, batch_logits + size_t(n_outputs) * n_vocab);
             }
         }
 
-        llama_batch_free(batch);
-
-        const auto t_end = std::chrono::high_resolution_clock::now();
-
         if (i == 0) {
+            llama_synchronize(ctx);
+            const auto t_end = std::chrono::high_resolution_clock::now();
             const float t_total = std::chrono::duration<float>(t_end - t_start).count();
             LOG_INF("%s: %.2f seconds per pass - ETA ", __func__, t_total);
-            int total_seconds = (int)(t_total * n_chunk);
+            int total_seconds = (int)(t_total * n_chunk / n_seq);
             if (total_seconds >= 60*60) {
                 LOG("%d hours ", total_seconds / (60*60));
                 total_seconds = total_seconds % (60*60);
             }
             LOG("%.2f minutes\n", total_seconds / 60.0);
+            LOG("\n");
+            LOG("chunk             PPL               ln(PPL(Q)/PPL(base))          KL Divergence              Δp RMS            Same top p\n");
         }
-        LOG("\n");
-        LOG("chunk             PPL               ln(PPL(Q)/PPL(base))          KL Divergence              Δp RMS            Same top p\n");
 
-        const int first = n_ctx/2;
-        const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
-        process_logits(n_vocab, all_logits + size_t(first)*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
-                workers, log_probs_uint16, kld, kld_ptr, p_diff_ptr);
-        p_diff_ptr += n_ctx - 1 - first;
-        kld_ptr    += n_ctx - 1 - first;
+        // Read log probs for each sequence in the batch
+        for (int seq = 0; seq < n_seq_batch; seq++) {
+            if (in.read((char *)log_probs_uint16.data(), log_probs_uint16.size()*sizeof(uint16_t)).fail()) {
+                LOG_ERR("%s: failed reading log-probs for chunk %d\n", __func__, i + seq);
+                llama_batch_free(batch);
+                return;
+            }
 
-        LOG("%4d", i+1);
+            const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);
 
-        auto log_ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count);
-        const double ppl_val = exp(log_ppl.first);
-        const double ppl_unc = ppl_val * log_ppl.second; // ppl_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl.second ** 2 )
-        LOG("    %9.4lf ± %9.4lf", ppl_val, ppl_unc);
+            process_logits(n_vocab, all_logits, tokens.data() + start + seq*n_ctx + first, n_ctx - 1 - first,
+                    workers, log_probs_uint16, kld, kld_ptr, p_diff_ptr);
+            p_diff_ptr += n_ctx - 1 - first;
+            kld_ptr    += n_ctx - 1 - first;
 
-        auto log_ppl_base = mean_and_uncertainty(kld.sum_nll_base, kld.sum_nll_base2, kld.count);
-        const double log_ppl_cov = covariance(kld.sum_nll, kld.sum_nll_base, kld.sum_nll_nll_base, kld.count);
-        const double log_ppl_ratio_val = log_ppl.first - log_ppl_base.first;
-        const double log_ppl_ratio_unc = sqrt(log_ppl.second*log_ppl.second + log_ppl_base.second*log_ppl_base.second - 2.0*log_ppl_cov);
-        LOG("    %10.5lf ± %10.5lf", log_ppl_ratio_val, log_ppl_ratio_unc);
+            LOG("%4d", i + seq + 1);
 
-        auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count);
-        LOG("    %10.5lf ± %10.5lf", kl_div.first, kl_div.second);
+            auto log_ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count);
+            const double ppl_val = exp(log_ppl.first);
+            const double ppl_unc = ppl_val * log_ppl.second;
+            LOG("    %9.4lf ± %9.4lf", ppl_val, ppl_unc);
 
-        auto p_diff_mse   = mean_and_uncertainty(kld.sum_p_diff2, kld.sum_p_diff4, kld.count);
-        const double p_diff_rms_val = sqrt(p_diff_mse.first);
-        const double p_diff_rms_unc = 0.5/p_diff_rms_val * p_diff_mse.second;
-        LOG("    %6.3lf ± %6.3lf %%", 100.0*p_diff_rms_val, 100.0*p_diff_rms_unc);
+            auto log_ppl_base = mean_and_uncertainty(kld.sum_nll_base, kld.sum_nll_base2, kld.count);
+            const double log_ppl_cov = covariance(kld.sum_nll, kld.sum_nll_base, kld.sum_nll_nll_base, kld.count);
+            const double log_ppl_ratio_val = log_ppl.first - log_ppl_base.first;
+            const double log_ppl_ratio_unc = sqrt(log_ppl.second*log_ppl.second + log_ppl_base.second*log_ppl_base.second - 2.0*log_ppl_cov);
+            LOG("    %10.5lf ± %10.5lf", log_ppl_ratio_val, log_ppl_ratio_unc);
 
-        double p_top_val = 1.*kld.n_same_top/kld.count;
-        double p_top_unc = sqrt(p_top_val*(1 - p_top_val)/(kld.count - 1));
-        LOG("    %6.3lf ± %6.3lf %%", 100.0*p_top_val, 100.0*p_top_unc);
+            auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count);
+            LOG("    %10.5lf ± %10.5lf", kl_div.first, kl_div.second);
 
-        LOG("\n");
+            auto p_diff_mse   = mean_and_uncertainty(kld.sum_p_diff2, kld.sum_p_diff4, kld.count);
+            const double p_diff_rms_val = sqrt(p_diff_mse.first);
+            const double p_diff_rms_unc = 0.5/p_diff_rms_val * p_diff_mse.second;
+            LOG("    %6.3lf ± %6.3lf %%", 100.0*p_diff_rms_val, 100.0*p_diff_rms_unc);
+
+            double p_top_val = 1.*kld.n_same_top/kld.count;
+            double p_top_unc = sqrt(p_top_val*(1 - p_top_val)/(kld.count - 1));
+            LOG("    %6.3lf ± %6.3lf %%", 100.0*p_top_val, 100.0*p_top_unc);
+
+            LOG("\n");
+        }
 
         logits.clear();
     }
+
+    llama_batch_free(batch);
     LOG("\n");
 
     if (kld.count < 100) return; // we do not wish to do statistics on so few values
@@ -1996,7 +2024,7 @@ int main(int argc, char ** argv) {
 
     const bool ppl = !params.hellaswag && !params.winogrande && !params.multiple_choice && !params.kl_divergence;
 
-    if (ppl) {
+    if (ppl || params.kl_divergence) {
         const int32_t n_seq = std::max(1, params.n_batch / n_ctx);
         const int32_t n_kv = n_seq * n_ctx;
 
@@ -2006,12 +2034,8 @@ int main(int argc, char ** argv) {
         params.n_batch = std::min(params.n_batch, n_kv);
     } else {
         params.n_batch = std::min(params.n_batch, params.n_ctx);
-        if (params.kl_divergence) {
-            params.n_parallel = 1;
-        } else {
-            // ensure there's at least enough seq_ids for HellaSwag
-            params.n_parallel = std::max(4, params.n_parallel);
-        }
+        // ensure there's at least enough seq_ids for HellaSwag
+        params.n_parallel = std::max(4, params.n_parallel);
     }
 
     if (params.ppl_stride > 0) {