]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Add ability to evauate multiple choice tasks (#5047)
authorKawrakow <redacted>
Sun, 21 Jan 2024 12:42:44 +0000 (14:42 +0200)
committerGitHub <redacted>
Sun, 21 Jan 2024 12:42:44 +0000 (14:42 +0200)
* TruthfulQA: 1st attempt, does not look like it is working

The same implementation can be used for HellaSwag as well,
so I converted a HellaSwag validation dataset to the binary
format used here and tested with that. The score is only
around 50, so something is not quite right.

* TruthfulQA: works but the result is bad

I know it works because if I convert the HellaSwag validation
data to the binary format used in the truthful_qa_score() function
I get the exact same result as from the hellaswag_score() function.
But I guess, the questions are tricky and the way I have done
the combination of question + answer is very likely not the best.
The TruthfulQA validation dataset contains 817 questions, with
random chance result around 19%. With this version I get
29.1% for Mistral-7B and 55.2% for Mistral-7B-Instruct-v0.2.
The HF leader board results for these two models are
42.2% and 68.3%, respectively.

* TruthfulQA: fix random sample

* TruthfulQA: prepare tasks in parallel for large test datasets

* Rename truthful_qa to multiple_choice

* Make MSVC happy

I had forgotten that MSVC does not make constexpr's available
inside a lambda.

---------

Co-authored-by: Iwan Kawrakow <redacted>
common/common.cpp
common/common.h
examples/perplexity/perplexity.cpp

index ce20360a4f85b8df53bcb4b1f80d2c30cdcfebf6..0e4b8bab2ce657b510f7894d4e2cfdf1cdb52598 100644 (file)
@@ -203,6 +203,25 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
             params.prompt_cache_all = true;
         } else if (arg == "--prompt-cache-ro") {
             params.prompt_cache_ro = true;
+        } else if (arg == "-bf" || arg == "--binary-file") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            std::ifstream file(argv[i], std::ios::binary);
+            if (!file) {
+                fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
+                invalid_param = true;
+                break;
+            }
+            // store the external file name in params
+            params.prompt_file = argv[i];
+            file.seekg(0, std::ios::end);
+            size_t size = file.tellg();
+            file.seekg(0, std::ios::beg);
+            params.prompt.resize(size);
+            file.read((char *)params.prompt.data(), size);
+            fprintf(stderr, "Read %zu bytes from binary file %s\n", size, argv[i]);
         } else if (arg == "-f" || arg == "--file") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -689,6 +708,14 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.winogrande_tasks = std::stoi(argv[i]);
+        } else if (arg == "--multiple-choice") {
+            params.multiple_choice = true;
+        } else if (arg == "--multiple-choice-tasks") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.multiple_choice_tasks = std::stoi(argv[i]);
         } else if (arg == "--ignore-eos") {
             params.ignore_eos = true;
         } else if (arg == "--no-penalize-nl") {
@@ -888,6 +915,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --in-suffix STRING    string to suffix after user inputs with (default: empty)\n");
     printf("  -f FNAME, --file FNAME\n");
     printf("                        prompt file to start generation.\n");
+    printf("  -bf FNAME, --binary-file FNAME\n");
+    printf("                        binary file containing multiple choice tasks.\n");
     printf("  -n N, --n-predict N   number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict);
     printf("  -c N, --ctx-size N    size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx);
     printf("  -b N, --batch-size N  batch size for prompt processing (default: %d)\n", params.n_batch);
@@ -936,6 +965,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --hellaswag-tasks N   number of tasks to use when computing the HellaSwag score (default: %zu)\n", params.hellaswag_tasks);
     printf("  --winogrande          compute Winogrande score over random tasks from datafile supplied with -f\n");
     printf("  --winogrande-tasks N  number of tasks to use when computing the Winogrande score (default: %zu)\n", params.winogrande_tasks);
+    printf("  --multiple-choice     compute multiple choice score over random tasks from datafile supplied with -f\n");
+    printf("  --multiple-choice-tasks N number of tasks to use when computing the multiple choice score (default: %zu)\n", params.winogrande_tasks);
     printf("  --keep N              number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
     printf("  --draft N             number of tokens to draft for speculative decoding (default: %d)\n", params.n_draft);
     printf("  --chunks N            max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
index 0ae9c18b3114c4001a7f8c0ee8c814b85c45f40a..c69ad7e94898f29463126663da5ee8e656fd8531 100644 (file)
@@ -108,6 +108,9 @@ struct gpt_params {
     bool   winogrande      = false; // compute Winogrande score over random tasks from datafile supplied in prompt
     size_t winogrande_tasks= 0;     // number of tasks to use when computing the Winogrande score. If 0, all tasks will be computed
 
+    bool   multiple_choice = false; // compute TruthfulQA score over random tasks from datafile supplied in prompt
+    size_t multiple_choice_tasks = 0;     // number of tasks to use when computing the TruthfulQA score. If 0, all tasks will be computed
+
     bool mul_mat_q         = true;  // if true, use mul_mat_q kernels instead of cuBLAS
     bool random_prompt     = false; // do not randomize prompt if none provided
     bool use_color         = false; // use color to distinguish generations and inputs
index f91f5795a9851434723c9ea53886353441be7a33..b7ef9a0843fbd25653190cbed699b8d0d9c0b996 100644 (file)
@@ -540,14 +540,14 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     // This is needed as usual for LLaMA models
     const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
 
+    // The tasks should be randomized so the score stabilizes quickly.
+    bool randomize_tasks = true;
+
     // Number of tasks to use when computing the score
     if (params.hellaswag_tasks < hs_task_count) {
         hs_task_count = params.hellaswag_tasks;
     }
 
-    // The tasks should be randomized so the score stabilizes quickly.
-    bool randomize_tasks = true;
-
     // The random seed should not impact the final result if the computation is done over enough tasks, so kept hardcoded for now
     std::mt19937 rng(1);
 
@@ -1031,6 +1031,389 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
     printf("Final Winogrande score(%d tasks): %.4lf +/- %.4lf\n", n_done, 100*p, sigma);
 }
 
+static bool deserialize_string(std::istream& in, std::string& str) {
+    uint32_t size;
+    if (!in.read((char *)&size, sizeof(size)).fail()) {
+        str.resize(size);
+        if (!in.read((char *)str.data(), size).fail()) return true;
+    }
+    return false;
+}
+
+struct multiple_choice_answers {
+    std::vector<std::string> answers;
+    std::vector<int>         labels;
+    bool deserialize(std::istream& in) {
+        uint32_t n;
+        in.read((char *)&n, sizeof(n));
+        if (in.fail() || n > 100) return false; // 100 as max. number of answers should be good enough for any practical purpose
+        answers.resize(n);
+        labels.resize(n);
+        for (auto& a : answers) {
+            if (!deserialize_string(in, a)) return false;
+        }
+        in.read((char *)labels.data(), n*sizeof(int));
+        return !in.fail();
+    }
+};
+
+struct multiple_choice_task {
+    std::string question;         // the question (or context that needs to be continued)
+    multiple_choice_answers mc1;  // possible answers (continuations) with a single correct answer
+    multiple_choice_answers mc2;  // possible answers (continuations) with multiple correct answers - not handled yet
+    bool deserialize(std::istream& in) {
+        if (!deserialize_string(in, question)) return false;
+        return mc1.deserialize(in) && mc2.deserialize(in);
+    }
+
+    // For evaluation
+    size_t i_batch;         // starting index in the llama_batch
+    size_t common_prefix;   // max number of initial tokens that are the same in all sentences
+    size_t required_tokens; // needed number of tokens to evaluate all answers
+    std::vector<std::vector<llama_token>> seq_tokens;
+    std::vector<float> log_probs;
+};
+
+static bool multiple_choice_prepare_one_task(llama_context * ctx, bool add_bos, multiple_choice_task& task, bool log_error) {
+    if (task.question.empty() || task.mc1.answers.empty()) {
+        if (log_error) {
+            printf("%s: found bad task with empty question and/or answers\n", __func__);
+        }
+        return false;
+    }
+    task.seq_tokens.reserve(task.mc1.answers.size());
+    for (auto& answer : task.mc1.answers) {
+        if (answer.empty()) {
+            if (log_error) {
+                printf("%s: found empty answer\n", __func__);
+            }
+            return false;
+        }
+        task.seq_tokens.emplace_back(::llama_tokenize(ctx, task.question + " " + answer, add_bos));
+    }
+    auto min_len = task.seq_tokens.front().size();
+    for (auto& seq : task.seq_tokens) {
+        min_len = std::min(min_len, seq.size());
+    }
+    task.common_prefix = 0;
+    for (size_t k = 0; k < min_len; ++k) {
+        auto token = task.seq_tokens[0][k];
+        bool all_same = true;
+        for (size_t i = 1; i < task.seq_tokens.size(); ++i) {
+            if (task.seq_tokens[i][k] != token) {
+                all_same = false;
+                break;
+            }
+        }
+        if (!all_same) {
+            break;
+        }
+        ++task.common_prefix;
+    }
+    task.required_tokens = task.common_prefix;
+    for (auto& seq : task.seq_tokens) {
+        task.required_tokens += seq.size() - task.common_prefix;
+    }
+    return true;
+}
+
+//
+// Calculates score for multiple choice tasks with single correct answer from prompt.
+// Commonly used LLM evaluation metrics of this type are
+//   * ARC
+//   * HellaSwag
+//   * MMLU
+//   * TruthfulQA
+//
+// Validation datasets for these 4 tests can be found at
+//     https://huggingface.co/datasets/ikawrakow/validation-datasets-for-llama.cpp
+// The data for these datasets was extracted from
+//     git@hf.co:datasets/allenai/ai2_arc
+//     https://github.com/rowanz/hellaswag/blob/master/data/hellaswag_val.jsonl
+//     git@hf.co:datasets/Stevross/mmlu
+//     https://huggingface.co/datasets/truthful_qa
+//
+static void multiple_choice_score(llama_context * ctx, const gpt_params & params) {
+
+    std::istringstream strstream(params.prompt);
+    uint32_t n_task;
+    strstream.read((char *)&n_task, sizeof(n_task));
+    if (strstream.fail() || n_task == 0) {
+        printf("%s: no tasks\n", __func__);
+        return;
+    }
+    printf("%s: there are %u tasks in prompt\n", __func__, n_task);
+    std::vector<uint32_t> task_pos(n_task);
+    strstream.read((char *)task_pos.data(), task_pos.size()*sizeof(uint32_t));
+    if (strstream.fail()) {
+        printf("%s: failed to raad task positions from prompt\n", __func__);
+        return;
+    }
+
+    std::vector<multiple_choice_task> tasks;
+    if (params.multiple_choice_tasks == 0 || params.multiple_choice_tasks >= (size_t)n_task) {
+        // Use all tasks
+        tasks.resize(n_task);
+        printf("%s: reading tasks", __func__);
+        int n_dot = n_task/100;
+        int i = 0;
+        for (auto& task : tasks) {
+            ++i;
+            if (!task.deserialize(strstream)) {
+                printf("%s: failed to read task %d of %u\n", __func__, i, n_task);
+                return;
+            }
+            if (i%n_dot == 0) printf(".");
+        }
+        printf("done\n");
+    }
+    else {
+        printf("%s: selecting %zu random tasks from %u tasks available\n", __func__, params.multiple_choice_tasks, n_task);
+        std::mt19937 rng(1);
+        std::vector<int> aux(n_task);
+        for (uint32_t i = 0; i < n_task; ++i) aux[i] = i;
+        float scale = 1.f/(1.f + (float)std::mt19937::max());
+        tasks.resize(params.multiple_choice_tasks);
+        for (auto& task : tasks) {
+            int j = (int)(scale * rng() * aux.size());
+            int idx = aux[j];
+            aux[j] = aux.back();
+            aux.pop_back();
+            strstream.seekg(task_pos[idx], std::ios::beg);
+            if (!task.deserialize(strstream)) {
+                printf("%s: failed to read task %d at position %u\n", __func__, idx, task_pos[idx]);
+                return;
+            }
+        }
+        n_task = params.multiple_choice_tasks;
+    }
+
+    // This is needed as usual for LLaMA models
+    const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
+
+    printf("%s: preparing task data", __func__);
+    fflush(stdout);
+    if (n_task > 500) {
+        printf("...");
+        fflush(stdout);
+        std::atomic<int> counter(0);
+        std::atomic<int> n_bad(0);
+        auto prepare = [&counter, &n_bad, &tasks, ctx, add_bos] () {
+            int num_tasks = tasks.size();
+            int n_bad_local = 0;
+            while (true) {
+                int first = counter.fetch_add(K_TOKEN_CHUNK);
+                if (first >= num_tasks) {
+                    if (n_bad_local > 0) n_bad += n_bad_local;
+                    break;
+                }
+                int last = std::min(first + K_TOKEN_CHUNK, num_tasks);
+                for (int i = first; i < last; ++i) {
+                    if (!multiple_choice_prepare_one_task(ctx, add_bos, tasks[i], false)) ++n_bad_local;
+                }
+            }
+        };
+        size_t max_thread = std::thread::hardware_concurrency();
+        max_thread = std::min(max_thread, (tasks.size() + K_TOKEN_CHUNK - 1)/K_TOKEN_CHUNK);
+        std::vector<std::thread> workers(max_thread-1);
+        for (auto& w : workers) w = std::thread(prepare);
+        prepare();
+        for (auto& w : workers) w.join();
+        printf("done\n");
+        fflush(stdout);
+        int nbad = n_bad;
+        if (nbad > 0) {
+            printf("%s: found %d malformed tasks\n", __func__, nbad);
+            return;
+        }
+    } else {
+        int n_dot = n_task/100;
+        int i_task = 0;
+        for (auto& task : tasks) {
+            ++i_task;
+            if (!multiple_choice_prepare_one_task(ctx, add_bos, task, true)) {
+                return;
+            }
+            if (i_task%n_dot == 0) {
+                printf(".");
+                fflush(stdout);
+            }
+        }
+        printf("done\n");
+    }
+
+    printf("%s : calculating TruthfulQA score over %zu tasks.\n", __func__, tasks.size());
+
+    printf("\ntask\tacc_norm\n");
+
+    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+    const int n_ctx   = llama_n_ctx(ctx);
+    const int n_batch = params.n_batch;
+
+    const int max_tasks_per_batch = 32;
+    const int max_seq = 4*max_tasks_per_batch;
+
+    llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
+
+    std::vector<float> tok_logits(n_vocab);
+    std::vector<float> batch_logits(n_vocab*n_ctx);
+
+    std::vector<std::pair<size_t, llama_token>> eval_pairs;
+    std::vector<float> eval_results;
+    std::vector<std::thread> workers(std::thread::hardware_concurrency());
+    std::vector<int> batch_indeces;
+
+    int n_done = 0;
+    int n_correct = 0;
+    int n_tot_answers = 0;
+
+    for (size_t i0 = 0; i0 < tasks.size(); i0++) {
+        int n_cur = 0;
+
+        size_t i1 = i0;
+        size_t i_batch = 0; // this tells us where in `llama_batch` we are currently
+
+        llama_batch_clear(batch);
+
+        // batch as much tasks as possible into the available context
+        // each task has 4 unique seuqnce ids - one for each ending
+        // the common prefix is shared among the 4 sequences to save tokens
+        // we extract logits only from the last common token and from all ending tokens of each sequence
+        int s0 = 0;
+        while (n_cur + (int) tasks[i1].required_tokens <= n_ctx) {
+            auto& cur_task = tasks[i1];
+
+            int num_answers = cur_task.seq_tokens.size();
+            if (s0 + num_answers > max_seq) {
+                break;
+            }
+
+            if (int(batch_indeces.size()) != num_answers) {
+                batch_indeces.resize(num_answers);
+            }
+            for (int s = 0; s < num_answers; ++s) batch_indeces[s] = s0 + s;
+
+            for (size_t i = 0; i < cur_task.common_prefix; ++i) {
+                //llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
+                llama_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false);
+            }
+            batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
+
+            for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
+                for (size_t i = cur_task.common_prefix; i < cur_task.seq_tokens[s].size(); ++i) {
+                    llama_batch_add(batch, cur_task.seq_tokens[s][i], i, { s0 + s }, true);
+                }
+            }
+
+            s0 += num_answers;
+
+            cur_task.i_batch = i_batch;
+            i_batch += cur_task.required_tokens;
+
+            n_cur += cur_task.required_tokens;
+            if (++i1 == tasks.size()) {
+                break;
+            }
+        }
+
+        if (i0 == i1) {
+            fprintf(stderr, "%s : task %zu does not fit in the context window\n", __func__, i0);
+            return;
+        }
+
+        llama_kv_cache_clear(ctx);
+
+        // decode all tasks [i0, i1)
+        if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
+            fprintf(stderr, "%s: llama_decode() failed\n", __func__);
+            return;
+        }
+
+        // Compute log-probs in parallel
+        // First we collect all tasks
+        eval_pairs.clear();
+        for (size_t i = i0; i < i1; ++i) {
+            auto& cur_task = tasks[i];
+            size_t li = cur_task.common_prefix;
+            for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
+                for (size_t j = cur_task.common_prefix; j < cur_task.seq_tokens[s].size() - 1; j++) {
+                    eval_pairs.push_back(std::make_pair(cur_task.i_batch + li++, cur_task.seq_tokens[s][j + 1]));
+                }
+                ++li;
+            }
+        }
+        // Then we do the actual calculation
+        compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);
+
+        size_t ir = 0;
+
+        // compute the logprobs for each ending of the decoded tasks
+        for (size_t i = i0; i < i1; ++i) {
+            auto & cur_task = tasks[i];
+            //printf("==== Evaluating <%s> with correct answer ", cur_task.question.c_str());
+            //for (int j = 0; j < int(cur_task.mc1.labels.size()); ++j) {
+            //    if (cur_task.mc1.labels[j] == 1) {
+            //        printf("%d", j+1);
+            //    }
+            //}
+            //printf("\n    common_prefix: %zu\n", cur_task.common_prefix);
+
+            std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(cur_task.i_batch + cur_task.common_prefix - 1), n_vocab*sizeof(float));
+
+            const auto first_probs = softmax(tok_logits);
+
+            cur_task.log_probs.resize(cur_task.seq_tokens.size());
+            for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
+                size_t count = 1;
+                float  log_prob  = std::log(first_probs[cur_task.seq_tokens[s][cur_task.common_prefix]]);
+                for (size_t j = cur_task.common_prefix; j < cur_task.seq_tokens[s].size() - 1; j++) {
+                    //printf("        %zu  %g\n", ir, eval_results[ir]);
+                    ++count;
+                    log_prob += eval_results[ir++];
+                }
+                cur_task.log_probs[s] = log_prob / count;
+                //printf("        Final: %g\n", log_prob / count);
+                //printf("    <%s> : %g\n", cur_task.mc1.answers[s].c_str(), log_prob/count);
+            }
+
+            // Find the ending with maximum logprob
+            size_t logprob_max_idx = 0;
+            float  logprob_max_val = cur_task.log_probs[0];
+            for (size_t s = 1; s < cur_task.log_probs.size(); s++) {
+                if (cur_task.log_probs[s] > logprob_max_val) {
+                    logprob_max_val = cur_task.log_probs[s];
+                    logprob_max_idx = s;
+                }
+            }
+
+            n_tot_answers += cur_task.log_probs.size();
+            if (cur_task.mc1.labels[logprob_max_idx] == 1) {
+                ++n_correct;
+            }
+            ++n_done;
+
+            // Print the accumulated accuracy mean x 100
+            printf("%d\t%.8lf\n", n_done, 100.*n_correct/n_done);
+            fflush(stdout);
+        }
+
+        i0 = i1 - 1;
+    }
+
+    llama_batch_free(batch);
+
+    if (n_done < 100) return;
+
+    float p = 1.f*n_correct/n_done;
+    float sigma = sqrt(p*(1-p)/(n_done-1));
+    printf("\n Final result: %.4f +/- %.4f\n", 100.f*p, 100.f*sigma);
+    p = 1.f*n_done/n_tot_answers;
+    sigma = sqrt(p*(1-p)/(n_done-1));
+    printf("Random chance: %.4f +/- %.4f\n", 100.f*p, 100.f*sigma);
+
+    printf("\n");
+}
+
 
 int main(int argc, char ** argv) {
     gpt_params params;
@@ -1091,6 +1474,8 @@ int main(int argc, char ** argv) {
         hellaswag_score(ctx, params);
     } else if (params.winogrande) {
         winogrande_score(ctx, params);
+    } else if (params.multiple_choice) {
+        multiple_choice_score(ctx, params);
     } else {
         results = perplexity(ctx, params);
     }