#include <cmath>
#include <ctime>
#include <sstream>
+#include <cstring>
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
double acc = 0.0f;
const int n_vocab = llama_n_vocab(ctx);
+ std::vector<float> tok_logits(n_vocab);
+
for (size_t task_idx = 0; task_idx < hs_task_count; task_idx++) {
// Tokenize the context to count tokens
std::vector<int> context_embd = ::llama_tokenize(ctx, hs_data[task_idx].context, prepend_bos);
size_t context_size = context_embd.size();
- for (size_t ending_idx=0;ending_idx<4;ending_idx++) {
+ // Do the 1st ending
+ // In this case we include the context when evaluating
+ auto query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[0], prepend_bos);
+ auto query_size = query_embd.size();
+ //printf("First query: %d\n",(int)query_size);
+
+ // Stop if query wont fit the ctx window
+ if (query_size > (size_t)params.n_ctx) {
+ fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size);
+ return;
+ }
+
+ // Speedup small evaluations by evaluating atleast 32 tokens
+ if (query_size < 32) {
+ query_embd.resize(32);
+ }
+
+ // Evaluate the query
+ if (llama_eval(ctx, query_embd.data(), query_embd.size(), 0, params.n_threads)) {
+ fprintf(stderr, "%s : failed to eval\n", __func__);
+ return;
+ }
+
+ auto query_logits = llama_get_logits(ctx);
+
+ std::memcpy(tok_logits.data(), query_logits + (context_size-1)*n_vocab, n_vocab*sizeof(float));
+ const auto first_probs = softmax(tok_logits);
+
+ hs_data[task_idx].ending_logprob_count[0] = 1;
+ hs_data[task_idx].ending_logprob[0] = std::log(first_probs[query_embd[context_size]]);
+
+ // Calculate the logprobs over the ending
+ for (size_t j = context_size; j < query_size - 1; j++) {
+
+ std::memcpy(tok_logits.data(), query_logits + j*n_vocab, n_vocab*sizeof(float));
+
+ const float prob = softmax(tok_logits)[query_embd[j + 1]];
+
+ hs_data[task_idx].ending_logprob[0] += std::log(prob);
+ hs_data[task_idx].ending_logprob_count[0]++;
+ }
+
+ // Calculate the mean token logprob for acc_norm
+ hs_data[task_idx].ending_logprob[0] /= hs_data[task_idx].ending_logprob_count[0];
+
+ // Do the remaining endings
+ // For these, we use the bare ending with n_past = context_size
+ //
+ for (size_t ending_idx = 1; ending_idx < 4; ending_idx++) {
// Tokenize the query
- std::vector<int> query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[ending_idx], prepend_bos);
- size_t query_size = query_embd.size();
+ query_embd = ::llama_tokenize(ctx, hs_data[task_idx].ending[ending_idx], false);
+ query_size = query_embd.size();
+ //printf("Second query: %d\n",(int)query_size);
// Stop if query wont fit the ctx window
- if (query_size > (size_t)params.n_ctx) {
+ if (context_size + query_size > (size_t)params.n_ctx) {
fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size);
return;
}
// Speedup small evaluations by evaluating atleast 32 tokens
- if (query_size < 32) {
- query_embd.resize(32);
- }
+ // No, resizing to 32 is actually slightly slower (at least on CUDA)
+ //if (query_size < 32) {
+ // query_embd.resize(32);
+ //}
// Evaluate the query
- if (llama_eval(ctx, query_embd.data(), query_embd.size(), 0, params.n_threads)) {
+ if (llama_eval(ctx, query_embd.data(), query_embd.size(), context_size, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return;
}
- const auto query_logits = llama_get_logits(ctx);
- std::vector<float> logits;
- logits.insert(logits.end(), query_logits, query_logits + query_size * n_vocab);
+ query_logits = llama_get_logits(ctx);
- hs_data[task_idx].ending_logprob_count[ending_idx] = 0;
- hs_data[task_idx].ending_logprob[ending_idx] = 0.0f;
+ hs_data[task_idx].ending_logprob_count[ending_idx] = 1;
+ hs_data[task_idx].ending_logprob[ending_idx] = std::log(first_probs[query_embd[0]]);
// Calculate the logprobs over the ending
- for (size_t j = context_size-1; j < query_size - 1; j++) {
- // Calculate probability of next token, given the previous ones.
- const std::vector<float> tok_logits(
- logits.begin() + (j + 0) * n_vocab,
- logits.begin() + (j + 1) * n_vocab);
+ for (size_t j = 0; j < query_size - 1; j++) {
+ std::memcpy(tok_logits.data(), query_logits + j*n_vocab, n_vocab*sizeof(float));
- const float prob = softmax(tok_logits)[query_embd[ j + 1]];
+ const float prob = softmax(tok_logits)[query_embd[j + 1]];
hs_data[task_idx].ending_logprob[ending_idx] += std::log(prob);
hs_data[task_idx].ending_logprob_count[ending_idx]++;
}
// Find the ending with maximum logprob
- size_t ending_logprob_max_idx = -1;
- double ending_logprob_max_val = -INFINITY;
- for (size_t j=0; j < 4; j++) {
+ size_t ending_logprob_max_idx = 0;
+ double ending_logprob_max_val = hs_data[task_idx].ending_logprob[0];
+ for (size_t j = 1; j < 4; j++) {
if (hs_data[task_idx].ending_logprob[j] > ending_logprob_max_val) {
ending_logprob_max_idx = j;
ending_logprob_max_val = hs_data[task_idx].ending_logprob[j];