return {tokens, std::exp(nll / count), logit_history, prob_history};
}
-static results_perplexity perplexity(llama_context * ctx, const gpt_params & params) {
+static results_perplexity perplexity(llama_context * ctx, const gpt_params & params, const int32_t n_ctx) {
if (params.ppl_stride > 0) {
return perplexity_v2(ctx, params);
}
// BOS tokens will be added for each chunk before eval
const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
- const int n_ctx = llama_n_ctx(ctx);
std::ofstream logits_stream;
if (!params.logits_file.empty()) {
double nll2 = 0.0;
const int num_batches = (n_ctx + n_batch - 1) / n_batch;
+ const int n_seq = std::max(1, n_batch / n_ctx);
+
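+ // each sequence fills a full window of n_ctx tokens: when n_batch is larger than n_ctx
+ // it must be an exact multiple of it, and the KV cache (params.n_ctx) must hold exactly
+ // n_seq windows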
+ GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0);
+ GGML_ASSERT(params.n_ctx == n_seq * n_ctx);
+
+ llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1);
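+ // the batch is created once here, reset before every decode and freed after the loop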
std::vector<float> logits;
if (num_batches > 1) {
logits.reserve((size_t)n_ctx * n_vocab);
}
- fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
+ fprintf(stderr, "%s: calculating perplexity over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq);
std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
log_probs.resize(n_ctx * nv);
}
- for (int i = 0; i < n_chunk; ++i) {
+ // We get the logits for the tokens in each context window (n_ctx) from llama_decode
+ // below. Now, based on https://huggingface.co/docs/transformers/perplexity,
+ // calculate the perplexity over the last half of the window (so the model always has
+ // some context to predict the token).
+ //
+ // We rely on the fact that attention in the forward pass only looks at previous
+ // tokens here, so the logits returned for each token are an accurate representation
+ // of what the model would have predicted at that point.
+ //
+ // For example, with a context window of 512 we compute the perplexity of each of the
+ // last 256 tokens. The input is then split into context-window-sized chunks so that
+ // the entire prompt is processed.
+ const int first = n_ctx/2;
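+ // With n_ctx = 512 this gives first = 256: logits are requested for positions 256..511
+ // and each sequence contributes n_ctx - first - 1 = 255 scored tokens per chunk.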
+
+ for (int i = 0; i < n_chunk; i += n_seq) {
const int start = i * n_ctx;
const int end = start + n_ctx;
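+ // the final pass may have fewer than n_seq chunks left to process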
+ const int n_seq_batch = std::min(n_seq, n_chunk - i);
+
const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache
const int batch_start = start + j * n_batch;
const int batch_size = std::min(end - batch_start, n_batch);
- // save original token and restore it after eval
- const auto token_org = tokens[batch_start];
+ batch.n_tokens = 0;
+ for (int seq = 0; seq < n_seq_batch; seq++) {
+ int seq_start = batch_start + seq*n_ctx;
- // add BOS token for the first batch of each chunk
- if (add_bos && j == 0) {
- tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
+ // save original token and restore it after eval
+ const auto token_org = tokens[seq_start];
+
+ // add BOS token for the first batch of each chunk
+ if (add_bos && j == 0) {
+ tokens[seq_start] = llama_token_bos(llama_get_model(ctx));
+ }
+
+ for (int k = 0; k < batch_size; ++k) {
+ const int idx = seq*n_ctx + k;
+ batch.token[idx] = tokens[seq_start + k];
+ batch.pos[idx] = j*n_batch + k;
+ batch.n_seq_id[idx] = 1;
+ batch.seq_id[idx][0] = seq;
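+ // request logits only for positions in the second half of the window (pos >= first);
+ // earlier positions only provide context and are not scored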
+ batch.logits[idx] = batch.pos[idx] >= first ? 1 : 0;
+ }
+ batch.n_tokens += batch_size;
+
+ // restore the original token in case it was set to BOS
+ tokens[seq_start] = token_org;
}
- if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
+ if (llama_decode(ctx, batch)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return {tokens, -1, logit_history, prob_history};
}
- // restore the original token in case it was set to BOS
- tokens[batch_start] = token_org;
-
if (num_batches > 1) {
const auto * batch_logits = llama_get_logits(ctx);
logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
if (i == 0) {
const float t_total = std::chrono::duration<float>(t_end - t_start).count();
fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
- int total_seconds = (int)(t_total * n_chunk);
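+ // each pass now evaluates n_seq chunks, so scale the ETA accordingly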
+ int total_seconds = (int)(t_total*n_chunk/n_seq);
if (total_seconds >= 60*60) {
fprintf(stderr, "%d hours ", total_seconds / (60*60));
total_seconds = total_seconds % (60*60);
fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
}
- // We get the logits for all the tokens in the context window (params.n_ctx)
- // from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity,
- // calculate the perplexity over the last half of the window (so the model always has
- // some context to predict the token).
- //
- // We rely on the fact that attention in the forward pass only looks at previous
- // tokens here, so the logits returned for each token are an accurate representation
- // of what the model would have predicted at that point.
- //
- // Example, we have a context window of 512, we will compute perplexity for each of the
- // last 256 tokens. Then, we split the input up into context window size chunks to
- // process the entire prompt.
- const int first = n_ctx/2;
- const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
- if (!params.logits_file.empty()) {
- process_logits(logits_stream, n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
- workers, log_probs, nll, nll2);
- } else {
- process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
- workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
- }
- count += n_ctx - first - 1;
-
- // perplexity is e^(average negative log-likelihood)
- if (params.ppl_output_type == 0) {
- printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
- } else {
- double av = nll/count;
- double av2 = nll2/count - av*av;
- if (av2 > 0) av2 = sqrt(av2/(count-1));
- printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
+ for (int seq = 0; seq < n_seq_batch; seq++) {
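+ // logits for this chunk come either from the accumulated `logits` buffer (when the
+ // chunk was split across several decode calls) or directly from the context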
+ const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx);
+ llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first;
+ if (!params.logits_file.empty()) {
+ process_logits(logits_stream, n_vocab, all_logits + first*n_vocab,
+ tokens_data, n_ctx - 1 - first,
+ workers, log_probs, nll, nll2);
+ } else {
+ process_logits(n_vocab, all_logits + first*n_vocab,
+ tokens_data, n_ctx - 1 - first,
+ workers, nll, nll2,
+ logit_history.data() + start + seq*n_ctx + first,
+ prob_history.data() + start + seq*n_ctx + first);
+ }
+ count += n_ctx - first - 1;
+
+ // perplexity is e^(average negative log-likelihood)
+ if (params.ppl_output_type == 0) {
+ printf("[%d]%.4lf,", i + seq + 1, std::exp(nll / count));
+ } else {
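+ // av is the mean negative log-likelihood per token and av2 an estimate of its
+ // statistical uncertainty, printed alongside the running perplexity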
+ double av = nll/count;
+ double av2 = nll2/count - av*av;
+ if (av2 > 0) av2 = sqrt(av2/(count-1));
+ printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
+ }
}
fflush(stdout);
printf("Unexpected negative standard deviation of log(prob)\n");
}
+ llama_batch_free(batch);
+
return {tokens, ppl, logit_history, prob_history};
}
int main(int argc, char ** argv) {
gpt_params params;
- params.n_batch = 512;
if (!gpt_params_parse(argc, argv, params)) {
return 1;
}
params.logits_all = true;
- params.n_batch = std::min(params.n_batch, params.n_ctx);
+
+ const int32_t n_ctx = params.n_ctx;
+
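+ // for the plain perplexity path, pack as many full n_ctx windows as fit into one batch
+ // (e.g. n_ctx = 512 and n_batch = 2048 gives n_seq = 4) and size the KV cache
+ // (params.n_ctx) to hold all of them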
+ const bool ppl = !params.hellaswag && !params.winogrande && !params.multiple_choice && !params.kl_divergence;
+ if (ppl) {
+ int n_seq = std::max(1, params.n_batch / n_ctx);
+ int32_t n_kv = n_seq * n_ctx;
+ params.n_parallel = n_seq;
+ params.n_ctx = n_kv;
+ params.n_batch = std::min(params.n_batch, n_kv);
+ } else {
+ params.n_batch = std::min(params.n_batch, params.n_ctx);
+ }
if (params.ppl_stride > 0) {
fprintf(stderr, "Will perform strided perplexity calculation -> adjusting context size from %d to %d\n",
} else if (params.kl_divergence) {
kl_divergence(ctx, params);
} else {
- results = perplexity(ctx, params);
+ results = perplexity(ctx, params, n_ctx);
}
llama_print_timings(ctx);