]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
KL-divergence (#5076)
authorKawrakow <redacted>
Mon, 22 Jan 2024 14:10:14 +0000 (16:10 +0200)
committerGitHub <redacted>
Mon, 22 Jan 2024 14:10:14 +0000 (16:10 +0200)
* kl-divergence: be able to save all logits to a file

* Add ability to compute KL-divergence

---------

Co-authored-by: Iwan Kawrakow <redacted>
common/common.cpp
common/common.h
examples/perplexity/perplexity.cpp

index 0e4b8bab2ce657b510f7894d4e2cfdf1cdb52598..0a7096171f2b56e4cefe05c1db1b4745ac1a7b56 100644 (file)
@@ -672,6 +672,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
             if (params.logdir.back() != DIRECTORY_SEPARATOR) {
                 params.logdir += DIRECTORY_SEPARATOR;
             }
+        } else if (arg == "--save-all-logits" || arg == "--kl-divergence-base") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.logits_file = argv[i];
         } else if (arg == "--perplexity" || arg == "--all-logits") {
             params.logits_all = true;
         } else if (arg == "--ppl-stride") {
@@ -716,6 +722,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.multiple_choice_tasks = std::stoi(argv[i]);
+        } else if (arg == "--kl-divergence") {
+            params.kl_divergence = true;
         } else if (arg == "--ignore-eos") {
             params.ignore_eos = true;
         } else if (arg == "--no-penalize-nl") {
@@ -967,6 +975,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --winogrande-tasks N  number of tasks to use when computing the Winogrande score (default: %zu)\n", params.winogrande_tasks);
     printf("  --multiple-choice     compute multiple choice score over random tasks from datafile supplied with -f\n");
     printf("  --multiple-choice-tasks N number of tasks to use when computing the multiple choice score (default: %zu)\n", params.winogrande_tasks);
+    printf("  --kl-divergence       computes KL-divergence to logits provided via --kl-divergence-base");
     printf("  --keep N              number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
     printf("  --draft N             number of tokens to draft for speculative decoding (default: %d)\n", params.n_draft);
     printf("  --chunks N            max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
index c69ad7e94898f29463126663da5ee8e656fd8531..214a379b57d1ba7f8f5cf1de0cbecfcc00c7003f 100644 (file)
@@ -91,6 +91,7 @@ struct gpt_params {
     std::string input_suffix      = "";  // string to suffix user inputs with
     std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
     std::string logdir            = "";  // directory in which to save YAML log files
+    std::string logits_file       = "";  // file for saving *all* logits
 
     std::vector<llama_model_kv_override> kv_overrides;
 
@@ -111,6 +112,8 @@ struct gpt_params {
     bool   multiple_choice = false; // compute TruthfulQA score over random tasks from datafile supplied in prompt
     size_t multiple_choice_tasks = 0;     // number of tasks to use when computing the TruthfulQA score. If 0, all tasks will be computed
 
+    bool   kl_divergence   = false; // compute KL-divergence
+
     bool mul_mat_q         = true;  // if true, use mul_mat_q kernels instead of cuBLAS
     bool random_prompt     = false; // do not randomize prompt if none provided
     bool use_color         = false; // use color to distinguish generations and inputs
index b7ef9a0843fbd25653190cbed699b8d0d9c0b996..1b7f85f4985630ade477be4e70892faaf5d3d33b 100644 (file)
@@ -112,6 +112,43 @@ static results_log_softmax log_softmax(int n_vocab, const float * logits, int to
     return {logits[tok] - max_logit - log(sum_exp), logits[tok], expf(logits[tok] - max_logit) / (float) sum_exp};
 }
 
+static inline int nearest_int(float fval) {
+    //assert(fval <= 4194303.f);
+    float val = fval + 12582912.f;
+    int i; memcpy(&i, &val, sizeof(int));
+    return (i & 0x007fffff) - 0x00400000;
+}
+
+static double log_softmax(int n_vocab, const float * logits, uint16_t * log_prob, int tok) {
+    float max_logit = logits[0];
+    float min_logit = logits[0];
+    for (int i = 1; i < n_vocab; ++i) {
+        max_logit = std::max(max_logit, logits[i]);
+        min_logit = std::min(min_logit, logits[i]);
+    }
+    min_logit = std::max(min_logit, max_logit - 16);
+    double sum_exp = 0.0;
+    for (int i = 0; i < n_vocab; ++i) {
+        sum_exp += expf(logits[i] - max_logit);
+    }
+    const float log_sum_exp = log(sum_exp);
+    const float min_log_prob = min_logit - max_logit - log_sum_exp;
+    const float scale = (max_logit - min_logit)/65535.f;
+    float * d = (float *)log_prob;
+    d[0] = scale;
+    d[1] = min_log_prob;
+    log_prob += 4;
+    if (scale) {
+        const float inv_scale = 1/scale;
+        for (int i = 0; i < n_vocab; ++i) {
+            log_prob[i] = logits[i] > min_logit ? nearest_int(inv_scale*(logits[i] - min_logit)) : 0;
+        }
+    } else {
+        std::memset(log_prob, 0, n_vocab*sizeof(uint16_t));
+    }
+    return max_logit + log_sum_exp - logits[tok];
+}
+
 static void process_logits(
     int n_vocab, const float * logits, const int * tokens, int n_token, std::vector<std::thread> & workers,
     double & nll, double & nll2, float * logit_history, float * prob_history
@@ -147,6 +184,114 @@ static void process_logits(
     }
 }
 
+static void process_logits(std::ostream& out, int n_vocab, const float * logits, const int * tokens, int n_token,
+        std::vector<std::thread> & workers, std::vector<uint16_t> & log_probs, double & nll, double & nll2) {
+    std::mutex mutex;
+    const int nv = 2*((n_vocab + 1)/2) + 4;
+    int counter = 0;
+    auto compute = [&mutex, &counter, &log_probs, &nll, &nll2, n_vocab, logits, tokens, n_token, nv] () {
+        double local_nll  = 0;
+        double local_nll2 = 0;
+        while (true) {
+            std::unique_lock<std::mutex> lock(mutex);
+            int i = counter++;
+            if (i >= n_token) {
+                nll += local_nll; nll2 += local_nll2;
+                break;
+            }
+            lock.unlock();
+            const double v = log_softmax(n_vocab, logits + i*n_vocab, log_probs.data() + i*nv, tokens[i+1]);
+            local_nll += v;
+            local_nll2 += v*v;
+        }
+    };
+    for (auto & w : workers) {
+        w = std::thread(compute);
+    }
+    compute();
+    for (auto & w : workers) {
+        w.join();
+    }
+    out.write((const char *)log_probs.data(), n_token*nv*sizeof(uint16_t));
+}
+
+struct kl_divergence_result {
+    double sum_nll  = 0;
+    double sum_nll2 = 0;
+    double sum_kld  = 0;
+    double sum_kld2 = 0;
+    double sum_nll_diff  = 0;
+    double sum_nll_diff2 = 0;
+    size_t count = 0;
+};
+
+static void log_softmax(int n_vocab, const float * logits, const uint16_t * base_log_prob, int tok, kl_divergence_result & kld) {
+    float max_logit = logits[0];
+    for (int i = 1; i < n_vocab; ++i) {
+        max_logit = std::max(max_logit, logits[i]);
+    }
+    double sum_exp = 0.0;
+    for (int i = 0; i < n_vocab; ++i) {
+        sum_exp += expf(logits[i] - max_logit);
+    }
+    const float log_sum_exp = log(sum_exp);
+    const float * d = (const float *)base_log_prob;
+    const float scale = d[0];
+    const float min_log_prob = d[1];
+    base_log_prob += 4;
+    float nll = max_logit + log_sum_exp - logits[tok];
+    kld.sum_nll  += nll;
+    kld.sum_nll2 += nll*nll;
+    nll += (scale*base_log_prob[tok] + min_log_prob);
+    kld.sum_nll_diff  += nll;
+    kld.sum_nll_diff2 += nll*nll;
+    max_logit += log_sum_exp;
+    double sum = 0;
+    for (int i = 0; i < n_vocab; ++i) {
+        const float p_log_base = scale*base_log_prob[i] + min_log_prob;
+        if (p_log_base > -16.f) {
+            const float p_base = expf(p_log_base);
+            sum += p_base * (p_log_base - logits[i] + max_logit);
+        }
+    }
+    kld.sum_kld  += sum;
+    kld.sum_kld2 += sum*sum;
+    ++kld.count;
+}
+
+static void process_logits(int n_vocab, const float * logits, const int * tokens, int n_token,
+        std::vector<std::thread> & workers, const std::vector<uint16_t> & base_log_probs, kl_divergence_result & kld) {
+    std::mutex mutex;
+    const int nv = 2*((n_vocab + 1)/2) + 4;
+    int counter = 0;
+    auto compute = [&mutex, &counter, &base_log_probs, &kld, n_vocab, logits, tokens, n_token, nv] () {
+        kl_divergence_result local_kld;
+        while (true) {
+            std::unique_lock<std::mutex> lock(mutex);
+            int i = counter++;
+            if (i >= n_token) {
+                kld.sum_nll  += local_kld.sum_nll;
+                kld.sum_nll2 += local_kld.sum_nll2;
+                kld.sum_kld  += local_kld.sum_kld;
+                kld.sum_kld2 += local_kld.sum_kld2;
+                kld.sum_nll_diff  += local_kld.sum_nll_diff;
+                kld.sum_nll_diff2 += local_kld.sum_nll_diff2;
+                kld.count += local_kld.count;
+                break;
+            }
+            lock.unlock();
+            log_softmax(n_vocab, logits + i*n_vocab, base_log_probs.data() + i*nv, tokens[i+1], local_kld);
+        }
+    };
+    for (auto & w : workers) {
+        w = std::thread(compute);
+    }
+    compute();
+    for (auto & w : workers) {
+        w.join();
+    }
+}
+
 static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & params) {
     // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
     // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
@@ -294,6 +439,18 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
     const int n_ctx = llama_n_ctx(ctx);
 
+    std::ofstream logits_stream;
+    if (!params.logits_file.empty()) {
+        logits_stream.open(params.logits_file.c_str());
+        if (!logits_stream.is_open()) {
+            fprintf(stderr, "%s: failed to open %s for writing\n", __func__, params.logits_file.c_str());
+            return {};
+        }
+        fprintf(stderr, "%s: saving all logits to %s\n", __func__, params.logits_file.c_str());
+        logits_stream.write("_logits_", 8);
+        logits_stream.write((const char *)&n_ctx, sizeof(n_ctx));
+    }
+
     auto tim1 = std::chrono::high_resolution_clock::now();
     fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
 
@@ -336,6 +493,15 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
 
     std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
 
+    std::vector<uint16_t> log_probs;
+    if (!params.logits_file.empty()) {
+        logits_stream.write((const char *)&n_vocab, sizeof(n_vocab));
+        logits_stream.write((const char *)&n_chunk, sizeof(n_chunk));
+        logits_stream.write((const char *)tokens.data(), n_chunk*n_ctx*sizeof(tokens[0]));
+        const int nv = 2*((n_vocab + 1)/2) + 4;
+        log_probs.resize(n_ctx * nv);
+    }
+
     for (int i = 0; i < n_chunk; ++i) {
         const int start =     i * n_ctx;
         const int end   = start + n_ctx;
@@ -398,8 +564,13 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         // process the entire prompt.
         const int first = n_ctx/2;
         const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
-        process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
-                       workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
+        if (!params.logits_file.empty()) {
+            process_logits(logits_stream, n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
+                    workers, log_probs, nll, nll2);
+        } else {
+            process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
+                    workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
+        }
         count += n_ctx - first - 1;
 
         // perplexity is e^(average negative log-likelihood)
@@ -1414,6 +1585,148 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
     printf("\n");
 }
 
+static void kl_divergence(llama_context * ctx, const gpt_params & params) {
+    if (params.logits_file.empty()) {
+        fprintf(stderr, "%s: you must provide a name of a file containing the log probabilities of the base model\n", __func__);
+        return;
+    }
+    std::ifstream in(params.logits_file.c_str(), std::ios::binary);
+    if (!in) {
+        fprintf(stderr, "%s: failed to open %s\n", __func__, params.logits_file.c_str());
+        return;
+    }
+    {
+        char check[9]; check[8] = 0;
+        in.read(check, 8);
+        if (in.fail() || strncmp("_logits_", check, 8) != 0) {
+            fprintf(stderr, "%s: %s does not look like a file containing log-probabilities\n", __func__, params.logits_file.c_str());
+            return;
+        }
+    }
+
+    uint32_t n_ctx;
+    in.read((char *)&n_ctx, sizeof(n_ctx));
+    if (n_ctx > llama_n_ctx(ctx)) {
+        fprintf(stderr, "%s: %s has been computed with %d, while the current context is %d. Increase it with -c and retry\n",
+                __func__, params.logits_file.c_str(), n_ctx, params.n_ctx);
+    }
+
+    int n_vocab, n_chunk;
+    in.read((char *)&n_vocab, sizeof(n_vocab));
+    in.read((char *)&n_chunk, sizeof(n_chunk));
+    if (in.fail()) {
+        fprintf(stderr, "%s: failed rwading n_vocab, n_chunk from %s\n", __func__, params.logits_file.c_str());
+        return;
+    }
+    if (n_vocab != llama_n_vocab(llama_get_model(ctx))) {
+        fprintf(stderr, "%s: inconsistent vocabulary (%d vs %d)\n", __func__, n_vocab, llama_n_vocab(llama_get_model(ctx)));
+    }
+
+    std::vector<llama_token> tokens(n_ctx * n_chunk);
+    if (in.read((char *)tokens.data(), tokens.size()*sizeof(tokens[0])).fail()) {
+        fprintf(stderr, "%s: failed reading evaluation tokens from %s\n", __func__, params.logits_file.c_str());
+        return;
+    }
+
+    const int n_batch = params.n_batch;
+    const int num_batches = (n_ctx + n_batch - 1)/n_batch;
+    const int nv = 2*((n_vocab + 1)/2) + 4;
+    const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
+
+    std::vector<uint16_t> log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv);
+    std::vector<float> logits;
+    if (num_batches > 1) {
+        logits.reserve(n_ctx * n_vocab);
+    }
+
+    std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
+
+    auto mean_and_uncertainty = [] (double sum, double sum2, size_t count) {
+        if (count < 1) {
+            return std::make_pair(0., 0.);
+        }
+        double f = sum/count;
+        double df = sum2/count - f*f;
+        df = df > 0 && count > 10 ? sqrt(df/(count-1)) : 0.;
+        return std::make_pair(f, df);
+    };
+
+    kl_divergence_result kld;
+
+    for (int i = 0; i < n_chunk; ++i) {
+        const int start =     i * n_ctx;
+        const int end   = start + n_ctx;
+
+        const auto t_start = std::chrono::high_resolution_clock::now();
+
+        if (in.read((char *)log_probs_uint16.data(), log_probs_uint16.size()*sizeof(uint16_t)).fail()) {
+            fprintf(stderr, "%s: failed reading log-probs for chunk %d\n", __func__, i);
+            return;
+        }
+
+        // clear the KV cache
+        llama_kv_cache_clear(ctx);
+
+        for (int j = 0; j < num_batches; ++j) {
+            const int batch_start = start + j * n_batch;
+            const int batch_size  = std::min(end - batch_start, n_batch);
+
+            // save original token and restore it after eval
+            const auto token_org = tokens[batch_start];
+
+            // add BOS token for the first batch of each chunk
+            if (add_bos && j == 0) {
+                tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
+            }
+
+            if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
+                fprintf(stderr, "%s : failed to eval\n", __func__);
+                return;
+            }
+
+            // restore the original token in case it was set to BOS
+            tokens[batch_start] = token_org;
+
+            if (num_batches > 1) {
+                const auto * batch_logits = llama_get_logits(ctx);
+                logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+            }
+        }
+
+        const auto t_end = std::chrono::high_resolution_clock::now();
+
+        if (i == 0) {
+            const float t_total = std::chrono::duration<float>(t_end - t_start).count();
+            fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
+            int total_seconds = (int)(t_total * n_chunk);
+            if (total_seconds >= 60*60) {
+                fprintf(stderr, "%d hours ", total_seconds / (60*60));
+                total_seconds = total_seconds % (60*60);
+            }
+            fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
+
+            printf("\nchunk        PPL          ln(PPL(Q)/PPL(base))          KL-Divergence\n");
+        }
+
+        const int first = n_ctx/2;
+        const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
+        process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
+                workers, log_probs_uint16, kld);
+
+        auto ppl           = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count);
+        auto log_ppl_ratio = mean_and_uncertainty(kld.sum_nll_diff, kld.sum_nll_diff2, kld.count);
+        auto kl_div        = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count);
+
+        printf("%4d    %10.4lf    %10.5lf ± %10.5f    %10.5f ± %10.5lf\n", i+1, exp(ppl.first),
+                log_ppl_ratio.first, log_ppl_ratio.second, kl_div.first, kl_div.second);
+
+        fflush(stdout);
+
+        logits.clear();
+    }
+    printf("\n");
+
+}
 
 int main(int argc, char ** argv) {
     gpt_params params;
@@ -1476,6 +1789,8 @@ int main(int argc, char ** argv) {
         winogrande_score(ctx, params);
     } else if (params.multiple_choice) {
         multiple_choice_score(ctx, params);
+    } else if (params.kl_divergence) {
+        kl_divergence(ctx, params);
     } else {
         results = perplexity(ctx, params);
     }