]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Slightly faster imatrix (#5050)
authorKawrakow <redacted>
Sun, 21 Jan 2024 06:01:20 +0000 (08:01 +0200)
committerGitHub <redacted>
Sun, 21 Jan 2024 06:01:20 +0000 (08:01 +0200)
* imatrix: speedup by avoiding unnecessary allocations and copies

* imatrix: add --no-ppl option to skip PPL calculations altogether

---------

Co-authored-by: Iwan Kawrakow <redacted>
examples/imatrix/imatrix.cpp

index 5a3d30b888d0318c80b2cbbaa06704ae4665d51e..5687476cdcf92798a07eeefc96dc3ded44d3966e 100644 (file)
@@ -248,7 +248,7 @@ static void process_logits(
     }
 }
 
-static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
+static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool compute_ppl) {
 
     const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
     const int n_ctx = llama_n_ctx(ctx);
@@ -269,10 +269,12 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
     }
 
     std::vector<float> logit_history;
-    logit_history.resize(tokens.size());
-
     std::vector<float> prob_history;
-    prob_history.resize(tokens.size());
+
+    if (compute_ppl) {
+        logit_history.resize(tokens.size());
+        prob_history.resize(tokens.size());
+    }
 
     const int n_chunk_max = tokens.size() / n_ctx;
 
@@ -288,12 +290,17 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
 
     std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
 
+    const int num_batches = (n_ctx + n_batch - 1) / n_batch;
+
+    std::vector<float> logits;
+    if (compute_ppl && num_batches > 1) {
+        logits.reserve((size_t)n_ctx * n_vocab);
+    }
+
     for (int i = 0; i < n_chunk; ++i) {
         const int start =     i * n_ctx;
         const int end   = start + n_ctx;
 
-        const int num_batches = (n_ctx + n_batch - 1) / n_batch;
-
         std::vector<float> logits;
 
         const auto t_start = std::chrono::high_resolution_clock::now();
@@ -321,8 +328,10 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
             // restore the original token in case it was set to BOS
             tokens[batch_start] = token_org;
 
-            const auto * batch_logits = llama_get_logits(ctx);
-            logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+            if (compute_ppl && num_batches > 1) {
+                const auto * batch_logits = llama_get_logits(ctx);
+                logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+            }
         }
 
         const auto t_end = std::chrono::high_resolution_clock::now();
@@ -338,25 +347,32 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
             fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
         }
 
-        const int first = n_ctx/2;
-        process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
-                       workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
-        count += n_ctx - first - 1;
+        if (compute_ppl) {
+            const int first = n_ctx/2;
+            const auto all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
+            process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
+                    workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
+            count += n_ctx - first - 1;
+
+            printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
+            fflush(stdout);
 
-        printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
-        fflush(stdout);
+            logits.clear();
+        }
     }
     printf("\n");
 
-    nll2 /= count;
-    nll /= count;
-    const double ppl = exp(nll);
-    nll2 -= nll * nll;
-    if (nll2 > 0) {
-        nll2 = sqrt(nll2/(count-1));
-        printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl);
-    } else {
-        printf("Unexpected negative standard deviation of log(prob)\n");
+    if (compute_ppl) {
+        nll2 /= count;
+        nll /= count;
+        const double ppl = exp(nll);
+        nll2 -= nll * nll;
+        if (nll2 > 0) {
+            nll2 = sqrt(nll2/(count-1));
+            printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl);
+        } else {
+            printf("Unexpected negative standard deviation of log(prob)\n");
+        }
     }
 
     return true;
@@ -365,6 +381,7 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
 int main(int argc, char ** argv) {
 
     StatParams sparams;
+    bool compute_ppl = true;
     std::vector<char*> args;
     args.push_back(argv[0]);
     int iarg = 1;
@@ -381,12 +398,19 @@ int main(int argc, char ** argv) {
         }
         else if (arg == "--verbosity") {
             sparams.verbosity = std::stoi(argv[++iarg]);
+        } else if (arg == "--no-ppl") {
+            compute_ppl = false;
         } else {
             args.push_back(argv[iarg]);
         }
     }
     if (iarg < argc) {
-        args.push_back(argv[iarg]);
+        std::string arg{argv[iarg]};
+        if (arg == "--no-ppl") {
+            compute_ppl = false;
+        } else {
+            args.push_back(argv[iarg]);
+        }
     }
 
     gpt_params params;
@@ -448,7 +472,7 @@ int main(int argc, char ** argv) {
         fprintf(stderr, "%s\n", get_system_info(params).c_str());
     }
 
-    bool OK = compute_imatrix(ctx, params);
+    bool OK = compute_imatrix(ctx, params, compute_ppl);
     if (!OK) {
         return 1;
     }