]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
perplexity : add Hellaswag calculation (#2389)
authorklosax <redacted>
Fri, 28 Jul 2023 18:25:36 +0000 (20:25 +0200)
committerGitHub <redacted>
Fri, 28 Jul 2023 18:25:36 +0000 (21:25 +0300)
* common.h : add hellaswag / remove perplexity-lines

* common.cpp : add hellaswag / remove perplexity-lines

* perplexity.cpp : add hellswag scores / remove perplexity-lines

* perplexity.cpp : clean up

* common.h : change default param value

* common.cpp : Change default param

* perplexity.cpp : alter wording

* common.h : alter wording

* common.cpp : alter wording

examples/common.cpp
examples/common.h
examples/perplexity/perplexity.cpp

index dd964c8a7481a32808781f8b10d96d8c16c124d8..fe7308b1787eb9377fea2818902f1180c7392f1a 100644 (file)
@@ -402,8 +402,14 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.antiprompt.push_back(argv[i]);
         } else if (arg == "--perplexity") {
             params.perplexity = true;
-        } else if (arg == "--perplexity-lines") {
-            params.perplexity_lines = true;
+        } else if (arg == "--hellaswag") {
+            params.hellaswag = true;
+        } else if (arg == "--hellaswag-tasks") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.hellaswag_tasks = std::stoi(argv[i]);
         } else if (arg == "--ignore-eos") {
             params.logit_bias[llama_token_eos()] = -INFINITY;
         } else if (arg == "--no-penalize-nl") {
@@ -559,8 +565,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, "                        not recommended: doubles context memory required and no measurable increase in quality\n");
     fprintf(stdout, "  --temp N              temperature (default: %.1f)\n", (double)params.temp);
     fprintf(stdout, "  --perplexity          compute perplexity over each ctx window of the prompt\n");
-    fprintf(stdout, "  --perplexity-lines    compute perplexity over each line of the prompt\n");
-    fprintf(stdout, "  --keep                number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
+    fprintf(stdout, "  --hellaswag           compute HellaSwag score over random tasks from datafile supplied with -f\n");
+    fprintf(stdout, "  --hellaswag-tasks N   number of tasks to use when computing the HellaSwag score (default: %d)\n", params.hellaswag_tasks);
+    fprintf(stdout, "  --keep N              number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
     fprintf(stdout, "  --chunks N            max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
     if (llama_mlock_supported()) {
         fprintf(stdout, "  --mlock               force system to keep model in RAM rather than swapping or compressing\n");
index 672dcf77c8f992166ebeba2b5a7e630ac59172b9..1184f32df50742040d38d5e7164d33ff8d198c26 100644 (file)
@@ -70,7 +70,10 @@ struct gpt_params {
     std::string lora_adapter = "";  // lora adapter path
     std::string lora_base    = "";  // base model path for the lora adapter
 
-    bool low_vram          = false;   // if true, reduce VRAM usage at the cost of performance
+    bool hellaswag         = false; // compute HellaSwag score over random tasks from datafile supplied in prompt
+    size_t hellaswag_tasks = 400;   // number of tasks to use when computing the HellaSwag score
+
+    bool low_vram          = false; // if true, reduce VRAM usage at the cost of performance
     bool memory_f16        = true;  // use f16 instead of f32 for memory kv
     bool random_prompt     = false; // do not randomize prompt if none provided
     bool use_color         = false; // use color to distinguish generations and inputs
@@ -86,7 +89,6 @@ struct gpt_params {
     bool instruct          = false; // instruction mode (used for Alpaca models)
     bool penalize_nl       = true;  // consider newlines as a repeatable token
     bool perplexity        = false; // compute perplexity over the prompt
-    bool perplexity_lines  = false; // compute perplexity over each line of the prompt
     bool use_mmap          = true;  // use mmap for faster loads
     bool use_mlock         = false; // use mlock to keep model in memory
     bool mem_test          = false; // compute maximum memory usage
index d23b7e7f0c1b8f2ea693178c5a16466d5df081f7..6870a11b931dba4941026e05f4c2279049440b2c 100644 (file)
@@ -121,8 +121,23 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
     printf("\n");
 }
 
-void perplexity_lines(llama_context * ctx, const gpt_params & params) {
-    // Calculates perplexity over each line of the prompt
+void hellaswag_score(llama_context * ctx, const gpt_params & params) {
+    // Calculates hellaswag score (acc_norm) from prompt
+    //
+    // Data extracted from the HellaSwag validation dataset (MIT license) https://github.com/rowanz/hellaswag/blob/master/data/hellaswag_val.jsonl
+    // All used data fields are preprocessed as in https://github.com/EleutherAI/lm-evaluation-harness/blob/df3da98c5405deafd519c2ddca52bb7c3fe36bef/lm_eval/tasks/hellaswag.py#L62-L68
+    //
+    // All 10042 tasks should be extracted to keep the results standardized like other implementations.
+    //
+    // Datafile layout:
+    // ['??'] denotes json fields
+    // 6 lines per task:
+    // ['activity_label'] + ": " +['ctx']  - The first part of the query, the context
+    // ['label'] - The index the best common sense ending aka gold ending
+    // ['endings'][0] - Endings added to the first part of the query
+    // ['endings'][1]
+    // ['endings'][2]
+    // ['endings'][3]
 
     std::vector<std::string> prompt_lines;
     std::istringstream strstream(params.prompt);
@@ -132,63 +147,149 @@ void perplexity_lines(llama_context * ctx, const gpt_params & params) {
         prompt_lines.push_back(line);
     }
 
-    const int n_vocab = llama_n_vocab(ctx);
+    if( prompt_lines.size() % 6 != 0) {
+        fprintf(stderr, "%s : number of lines in prompt not a multiple of 6.\n", __func__);
+        return;
+    }
 
-    int counttotal   = 0;
-    size_t n_lines = prompt_lines.size();
+    size_t hs_task_count = prompt_lines.size()/6;
+    fprintf(stderr, "%s : loaded %lu tasks from prompt.\n", __func__, hs_task_count);
 
-    double nll = 0.0;
+    // This is needed as usual for LLaMA models
+    bool prepend_bos = true;
+
+    // Number of tasks to use when computing the score
+    if ( params.hellaswag_tasks < hs_task_count  ) {
+        hs_task_count = params.hellaswag_tasks;
+    }
 
-    fprintf(stderr, "%s: calculating perplexity over %lu lines\n", __func__, n_lines);
+    // The tasks should be randomized so the score stabilizes quickly.
+    bool randomize_tasks = true;
 
-    printf("\nLine\tPPL line\tPPL cumulative\n");
+    // The random seed should not impact the final result if the computation is done over enough tasks, so kept hardcoded for now
+    std::mt19937 rng(1);
 
-    for (size_t i = 0; i < n_lines; ++i) {
+    // Dataholder for hellaswag tasks
+    struct hs_data_t {
+        std::string context;
+        size_t gold_ending_idx;
+        std::string ending[4];
+        size_t ending_logprob_count[4];
+        double ending_logprob[4];
+    };
 
-        // Tokenize and insert BOS at start
-        std::vector<int> batch_embd = ::llama_tokenize(ctx, prompt_lines[i], true);
+    fprintf(stderr, "%s : selecting %lu %s tasks.\n", __func__, hs_task_count, (randomize_tasks?"randomized":"the first")  );
 
-        size_t batch_size  = batch_embd.size();
+    // Select and read data from prompt lines
+    hs_data_t *hs_data = new hs_data_t[hs_task_count];
+    for (size_t i=0; i < hs_task_count; i++) {
+        size_t idx = i;
 
-        // Stop if line is too long
-        if( batch_size > (size_t)params.n_ctx ) {
-            fprintf(stderr, "%s : tokens in line %lu > n_ctxl\n", __func__, i);
-            return;
+        // Select a random example of those left in the prompt
+        if (randomize_tasks) {
+            std::uniform_int_distribution<size_t> dist(0, prompt_lines.size()/6-1 ) ;
+            idx = dist(rng);
         }
 
-        if (llama_eval(ctx, batch_embd.data(), batch_size, 0, params.n_threads)) {
-            fprintf(stderr, "%s : failed to eval\n", __func__);
-            return;
+        hs_data[i].context = prompt_lines[idx*6];
+        hs_data[i].gold_ending_idx = std::stoi( prompt_lines[idx*6+1] );
+        for (size_t j=0; j < 4; j++) {
+            hs_data[i].ending[j] = " " + prompt_lines[idx*6+2+j];
         }
 
-        const auto batch_logits = llama_get_logits(ctx);
-        std::vector<float> logits;
-        logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+        // Delete the selected random example from the prompt
+        if (randomize_tasks) {
+            prompt_lines.erase( std::next(prompt_lines.begin(),idx*6)  , std::next(prompt_lines.begin(),idx*6+6) );
+        }
+    }
 
-        double nllline = 0.0;
-        int countline = 0;
+    fprintf(stderr, "%s : calculating hellaswag score over selected tasks.\n", __func__);
+    printf("\ntask\tacc_norm\n");
 
-        // Perplexity over second half of the line
-        for (size_t j = batch_size/2; j < batch_size - 1; ++j) {
-            // Calculate probability of next token, given the previous ones.
-            const std::vector<float> tok_logits(
-                logits.begin() + (j + 0) * n_vocab,
-                logits.begin() + (j + 1) * n_vocab);
+    double acc = 0.0f;
+    const int n_vocab = llama_n_vocab(ctx);
+
+    for (size_t task_idx = 0; task_idx < hs_task_count; task_idx++) {
+
+        // Tokenize the context to count tokens
+        std::vector<int> context_embd = ::llama_tokenize(ctx, hs_data[task_idx].context, prepend_bos);
+        size_t context_size = context_embd.size();
+
+        for (size_t ending_idx=0;ending_idx<4;ending_idx++) {
+
+            // Tokenize the query
+            std::vector<int> query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[ending_idx], prepend_bos);
+            size_t query_size = query_embd.size();
+
+            // Stop if query wont fit the ctx window
+            if (query_size > (size_t)params.n_ctx) {
+                fprintf(stderr, "%s : number of tokens in query %lu > n_ctxl\n", __func__, query_size);
+                return;
+            }
 
-            const float prob = softmax(tok_logits)[batch_embd[ j + 1]];
+            // Speedup small evaluations by evaluating atleast 32 tokens
+            if (query_size < 32) {
+                query_embd.resize(32);
+            }
+
+            // Evaluate the query
+            if (llama_eval(ctx, query_embd.data(), query_embd.size(), 0, params.n_threads)) {
+                fprintf(stderr, "%s : failed to eval\n", __func__);
+                return;
+            }
+
+            const auto query_logits = llama_get_logits(ctx);
+            std::vector<float> logits;
+            logits.insert(logits.end(), query_logits, query_logits + query_size * n_vocab);
+
+            hs_data[task_idx].ending_logprob_count[ending_idx] = 0;
+            hs_data[task_idx].ending_logprob[ending_idx] = 0.0f;
+
+            // Calculate the logprobs over the ending
+            for (size_t j = context_size-1; j < query_size - 1; j++) {
+                // Calculate probability of next token, given the previous ones.
+                const std::vector<float> tok_logits(
+                    logits.begin() + (j + 0) * n_vocab,
+                    logits.begin() + (j + 1) * n_vocab);
+
+                const float prob = softmax(tok_logits)[query_embd[ j + 1]];
+
+                hs_data[task_idx].ending_logprob[ending_idx] += std::log(prob);
+                hs_data[task_idx].ending_logprob_count[ending_idx]++;
+            }
+
+            // Calculate the mean token logprob for acc_norm
+            hs_data[task_idx].ending_logprob[ending_idx] /= hs_data[task_idx].ending_logprob_count[ending_idx];
+
+
+//            printf("task %lu, ending %lu, whole_len %lu, context_len %lu, ending_logprob_count %lu, ending_logprob %.4f\n",
+//                task_idx,ending_idx,whole_size,context_size, hs_data[task_idx].ending_logprob_count[ending_idx], hs_data[task_idx].ending_logprob[ending_idx] );
+        }
 
-            nllline += -std::log(prob);
-            ++countline;
+        // Find the ending with maximum logprob
+        size_t ending_logprob_max_idx = -1;
+        double ending_logprob_max_val = -INFINITY;
+        for (size_t j=0; j < 4; j++) {
+            if (hs_data[task_idx].ending_logprob[j] > ending_logprob_max_val) {
+                ending_logprob_max_idx = j;
+                ending_logprob_max_val =  hs_data[task_idx].ending_logprob[j];
+            }
         }
 
-        nll += nllline;
-        counttotal += countline;
+//        printf("max logprob ending idx %lu, gold ending idx %lu\n", ending_logprob_max_idx, hs_data[task_idx].gold_ending_idx);
 
-        // perplexity is e^(average negative log-likelihood)
-        printf("%lu\t%.8lf\t%.8lf\n", i + 1, std::exp(nllline/countline), std::exp(nll / counttotal) );
+        // If the gold ending got the maximum logprobe add one accuracy point
+        if (ending_logprob_max_idx == hs_data[task_idx].gold_ending_idx) {
+            acc += 1.0;
+        }
+
+        // Print the accumulated accuracy mean x 100
+        printf("%li\t%.8lf\n",task_idx+1, acc/double(task_idx+1)*100.0);
         fflush(stdout);
     }
 
+    delete [] hs_data;
+
     printf("\n");
 }
 
@@ -240,8 +341,8 @@ int main(int argc, char ** argv) {
                 params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
     }
 
-    if (params.perplexity_lines) {
-        perplexity_lines(ctx, params);
+    if (params.hellaswag) {
+        hellaswag_score(ctx, params);
     } else {
         perplexity(ctx, params);
     }