return probs;
}
+void perplexity_v2(llama_context * ctx, const gpt_params & params) {
+
+ // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
+ // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
+ // Output: `perplexity: 13.5106 [114/114]`
+ // BOS tokens will be added for each chunk before eval
+
+    if (params.ppl_stride <= 0) {
+        fprintf(stderr, "%s: stride is %d but must be greater than zero!\n", __func__, params.ppl_stride);
+        return;
+    }
+ auto tokens = ::llama_tokenize(ctx, params.prompt, true);
+
+ const int calc_chunk = params.n_ctx;
+
+ fprintf(stderr, "%s: have %zu tokens. Calculation chunk = %d\n", __func__, tokens.size(), calc_chunk);
+
+ if (int(tokens.size()) <= calc_chunk) {
+        fprintf(stderr, "%s: there are only %zu tokens, this is not enough for a context size of %d and stride %d\n",
+                __func__, tokens.size(), params.n_ctx, params.ppl_stride);
+ return;
+ }
+
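+    // each chunk spans calc_chunk tokens and consecutive chunks start ppl_stride tokens apart,
+    // so round up the number of stride steps that fit into the remaining tokens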
+ const int n_chunk_max = (tokens.size() - calc_chunk + params.ppl_stride - 1) / params.ppl_stride;
+
+ const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
+ const int n_vocab = llama_n_vocab(ctx);
+ const int n_batch = params.n_batch;
+
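+    // running totals across all chunks - the perplexity printed after chunk i uses every token scored so far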
+ int count = 0;
+ double nll = 0.0;
+
+ fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
+
+ for (int i = 0; i < n_chunk; ++i) {
+ const int start = i * params.ppl_stride;
+ const int end = start + calc_chunk;
+
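+        // the whole window [start, end) must be evaluated so that logits exist for the
+        // last ppl_stride positions, which are the only ones scored below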
+ const int num_batches = (calc_chunk + n_batch - 1) / n_batch;
+ //fprintf(stderr, "%s: evaluating %d...%d using %d batches\n", __func__, start, end, num_batches);
+
+ std::vector<float> logits;
+
+ const auto t_start = std::chrono::high_resolution_clock::now();
+
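+        // evaluate the window in n_batch sized pieces; passing n_past = j * n_batch lets each
+        // batch extend the context built up by the previous batches of the same chunk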
+ for (int j = 0; j < num_batches; ++j) {
+ const int batch_start = start + j * n_batch;
+ const int batch_size = std::min(end - batch_start, n_batch);
+
+            //fprintf(stderr, "    Batch %d: starts at %d, size is %d, n_past is %d\n", j, batch_start, batch_size, j * n_batch);
+
+            // save the original token and restore it after eval
+            const auto token_org = tokens[batch_start];
+
+            // add BOS token for the first batch of each chunk, before evaluation
+            if (j == 0) {
+                tokens[batch_start] = llama_token_bos(ctx);
+            }
+
+            if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {
+                fprintf(stderr, "%s : failed to eval\n", __func__);
+                return;
+            }
+
+            // restore the original token in case it was replaced with BOS
+            if (j == 0) {
+                tokens[batch_start] = token_org;
+            }
+
+            const auto * batch_logits = llama_get_logits(ctx);
+            logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+ }
+
+ const auto t_end = std::chrono::high_resolution_clock::now();
+
+ if (i == 0) {
+ const float t_total = std::chrono::duration<float>(t_end - t_start).count();
+ fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
+ int total_seconds = (int)(t_total * n_chunk);
+ if (total_seconds >= 60*60) {
+ fprintf(stderr, "%d hours ", total_seconds / (60*60));
+ total_seconds = total_seconds % (60*60);
+ }
+ fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
+ }
+
+ //fprintf(stderr, "%s: using tokens %d...%d\n",__func__,params.n_ctx - params.ppl_stride + start, params.n_ctx + start);
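+        // only the last ppl_stride tokens of the chunk contribute to the NLL, so each scored
+        // token is conditioned on at least n_ctx - ppl_stride preceding tokens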
+ for (int j = params.n_ctx - params.ppl_stride - 1; j < params.n_ctx - 1; ++j) {
+
+ // Calculate probability of next token, given the previous ones.
+ const std::vector<float> tok_logits(
+ logits.begin() + (j + 0) * n_vocab,
+ logits.begin() + (j + 1) * n_vocab);
+
+ const float prob = softmax(tok_logits)[tokens[start + j + 1]];
+
+ nll += -std::log(prob);
+ ++count;
+ }
+ // perplexity is e^(average negative log-likelihood)
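+        //   ppl = exp( (1/count) * sum_i -log p(token_i | token_1 ... token_{i-1}) )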
+ if (params.ppl_output_type == 0) {
+ printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
+ } else {
+ printf("%8d %.4lf\n", i*params.ppl_stride, std::exp(nll / count));
+ }
+ fflush(stdout);
+ }
+ printf("\n");
+}
+
void perplexity(llama_context * ctx, const gpt_params & params) {
+
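+    // a positive ppl_stride selects the sliding-window ("strided") perplexity calculation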
+ if (params.ppl_stride > 0) {
+ perplexity_v2(ctx, params);
+ return;
+ }
+
// Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
// Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
// Output: `perplexity: 13.5106 [114/114]`
++count;
}
// perplexity is e^(average negative log-likelihood)
- printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
+ if (params.ppl_output_type == 0) {
+ printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
+ } else {
+ printf("%8d %.4lf\n", i*params.n_ctx, std::exp(nll / count));
+ }
fflush(stdout);
}
printf("\n");
params.perplexity = true;
params.n_batch = std::min(params.n_batch, params.n_ctx);
+ if (params.ppl_stride > 0) {
+ fprintf(stderr, "Will perform strided perplexity calculation -> adjusting context size from %d to %d\n",
+ params.n_ctx, params.n_ctx + params.ppl_stride/2);
+ params.n_ctx += params.ppl_stride/2;
+ }
+
if (params.n_ctx > 2048) {
fprintf(stderr, "%s: warning: model might not support context sizes greater than 2048 tokens (%d specified);"
"expect poor results\n", __func__, params.n_ctx);