}
}
-static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
+static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool compute_ppl) {
const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
const int n_ctx = llama_n_ctx(ctx);
}
std::vector<float> logit_history;
- logit_history.resize(tokens.size());
-
std::vector<float> prob_history;
- prob_history.resize(tokens.size());
+
+ if (compute_ppl) {
+ logit_history.resize(tokens.size());
+ prob_history.resize(tokens.size());
+ }
const int n_chunk_max = tokens.size() / n_ctx;
std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
+ const int num_batches = (n_ctx + n_batch - 1) / n_batch;
+
+ std::vector<float> logits;
+ if (compute_ppl && num_batches > 1) {
+ logits.reserve((size_t)n_ctx * n_vocab);
+ }
+
for (int i = 0; i < n_chunk; ++i) {
const int start = i * n_ctx;
const int end = start + n_ctx;
- const int num_batches = (n_ctx + n_batch - 1) / n_batch;
-
std::vector<float> logits;
const auto t_start = std::chrono::high_resolution_clock::now();
// restore the original token in case it was set to BOS
tokens[batch_start] = token_org;
- const auto * batch_logits = llama_get_logits(ctx);
- logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+ if (compute_ppl && num_batches > 1) {
+ const auto * batch_logits = llama_get_logits(ctx);
+ logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+ }
}
const auto t_end = std::chrono::high_resolution_clock::now();
fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
}
- const int first = n_ctx/2;
- process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
- workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
- count += n_ctx - first - 1;
+ if (compute_ppl) {
+ const int first = n_ctx/2;
+ const auto all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
+ process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
+ workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
+ count += n_ctx - first - 1;
+
+ printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
+ fflush(stdout);
- printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
- fflush(stdout);
+ logits.clear();
+ }
}
printf("\n");
- nll2 /= count;
- nll /= count;
- const double ppl = exp(nll);
- nll2 -= nll * nll;
- if (nll2 > 0) {
- nll2 = sqrt(nll2/(count-1));
- printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl);
- } else {
- printf("Unexpected negative standard deviation of log(prob)\n");
+ if (compute_ppl) {
+ nll2 /= count;
+ nll /= count;
+ const double ppl = exp(nll);
+ nll2 -= nll * nll;
+ if (nll2 > 0) {
+ nll2 = sqrt(nll2/(count-1));
+ printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl);
+ } else {
+ printf("Unexpected negative standard deviation of log(prob)\n");
+ }
}
return true;
int main(int argc, char ** argv) {
StatParams sparams;
+ bool compute_ppl = true;
std::vector<char*> args;
args.push_back(argv[0]);
int iarg = 1;
}
else if (arg == "--verbosity") {
sparams.verbosity = std::stoi(argv[++iarg]);
+ } else if (arg == "--no-ppl") {
+ compute_ppl = false;
} else {
args.push_back(argv[iarg]);
}
}
if (iarg < argc) {
- args.push_back(argv[iarg]);
+ std::string arg{argv[iarg]};
+ if (arg == "--no-ppl") {
+ compute_ppl = false;
+ } else {
+ args.push_back(argv[iarg]);
+ }
}
gpt_params params;
fprintf(stderr, "%s\n", get_system_info(params).c_str());
}
- bool OK = compute_imatrix(ctx, params);
+ bool OK = compute_imatrix(ctx, params, compute_ppl);
if (!OK) {
return 1;
}