]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
sampling : deduplicated code for probability distribution access (#6240)
authorMinsoo Cheong <redacted>
Sun, 24 Mar 2024 08:54:07 +0000 (17:54 +0900)
committerGitHub <redacted>
Sun, 24 Mar 2024 08:54:07 +0000 (10:54 +0200)
* sampling: remove duplicated code for probability distribution access

* free original_logits

* fix original_logits allocation

* fixes based on review @cebtenzzre

* change function name to `llama_sampling_prepare`

common/sampling.cpp
common/sampling.h
examples/speculative/speculative.cpp
retrieval [new file with mode: 0755]

index 5a54509827cbf37840215f7bb87d3856be01bfe8..45d68b26c2b93f5e007617ad1499c305f04b2bf3 100644 (file)
@@ -168,76 +168,19 @@ static llama_token llama_sampling_sample_impl(
                   bool is_resampling) {  // Add a parameter to indicate if we are resampling
     const llama_sampling_params & params = ctx_sampling->params;
 
-    const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
-
     const float   temp            = params.temp;
-    const int32_t penalty_last_n  = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
-    const float   penalty_repeat  = params.penalty_repeat;
-    const float   penalty_freq    = params.penalty_freq;
-    const float   penalty_present = params.penalty_present;
     const int     mirostat        = params.mirostat;
     const float   mirostat_tau    = params.mirostat_tau;
     const float   mirostat_eta    = params.mirostat_eta;
-    const bool    penalize_nl     = params.penalize_nl;
 
-    auto & prev = ctx_sampling->prev;
-    auto & cur  = ctx_sampling->cur;
-
-    llama_token id = 0;
-
-    // Get a pointer to the logits
-    float * logits = llama_get_logits_ith(ctx_main, idx);
-
-    // Declare original_logits at the beginning of the function scope
     std::vector<float> original_logits;
-
+    auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, !is_resampling, &original_logits);
     if (!is_resampling) {
-        // Only make a copy of the original logits if we are not in the resampling phase, not sure if I actually have to do this.
-        original_logits = std::vector<float>(logits, logits + llama_n_vocab(llama_get_model(ctx_main)));
-    }
-
-    // apply params.logit_bias map
-    for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
-        logits[it->first] += it->second;
-    }
-
-    if (ctx_cfg) {
-        float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
-        llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
-    }
-
-    cur.clear();
-
-    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
-    }
-
-    llama_token_data_array cur_p = { cur.data(), cur.size(), false };
-
-    // apply penalties
-    const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
-    const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
-    if (penalty_tokens_used_size) {
-        const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
-
-        llama_sample_repetition_penalties(ctx_main, &cur_p,
-                penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
-                penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
-
-        if (!penalize_nl) {
-            for (size_t idx = 0; idx < cur_p.size; idx++) {
-                if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
-                    cur_p.data[idx].logit = nl_logit;
-                    break;
-                }
-            }
-        }
-    }
-
-    // If we are in the resampling phase, apply grammar checks before sampling logic
-    if (is_resampling && ctx_sampling->grammar != NULL) {
-        llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
+        GGML_ASSERT(!original_logits.empty());
     }
+    llama_token id = 0;
+    // Get a pointer to the logits
+    float * logits = llama_get_logits_ith(ctx_main, idx);
 
     if (temp < 0.0) {
         // greedy sampling, with probs
@@ -302,11 +245,13 @@ static llama_token llama_sampling_sample_impl(
     return id;
 }
 
-static llama_token_data_array llama_sample_probability_distribution_impl(
+static llama_token_data_array llama_sampling_prepare_impl(
                   struct llama_sampling_context * ctx_sampling,
                   struct llama_context * ctx_main,
                   struct llama_context * ctx_cfg,
-                  const int idx) {
+                  const int idx,
+                  bool apply_grammar,
+                  std::vector<float> * original_logits) {
     const llama_sampling_params & params = ctx_sampling->params;
 
     const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
@@ -315,6 +260,7 @@ static llama_token_data_array llama_sample_probability_distribution_impl(
     const float   penalty_repeat  = params.penalty_repeat;
     const float   penalty_freq    = params.penalty_freq;
     const float   penalty_present = params.penalty_present;
+
     const bool    penalize_nl     = params.penalize_nl;
 
     auto & prev = ctx_sampling->prev;
@@ -323,8 +269,10 @@ static llama_token_data_array llama_sample_probability_distribution_impl(
     // Get a pointer to the logits
     float * logits = llama_get_logits_ith(ctx_main, idx);
 
-    // Declare original_logits at the beginning of the function scope
-    std::vector<float> original_logits;
+    if (apply_grammar && original_logits != NULL) {
+        // Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this.
+        *original_logits = {logits, logits + llama_n_vocab(llama_get_model(ctx_main))};
+    }
 
     // apply params.logit_bias map
     for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
@@ -364,12 +312,11 @@ static llama_token_data_array llama_sample_probability_distribution_impl(
         }
     }
 
-    // apply grammar checks
-    if (ctx_sampling->grammar != NULL) {
+    // apply grammar checks before sampling logic
+    if (apply_grammar && ctx_sampling->grammar != NULL) {
         llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
     }
 
-    llama_sample_softmax(ctx_main, &cur_p);
     return cur_p;
 }
 
@@ -382,12 +329,14 @@ llama_token llama_sampling_sample(
     return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
 }
 
-llama_token_data_array llama_sampling_probability_distribution(
+llama_token_data_array llama_sampling_prepare(
                   struct llama_sampling_context * ctx_sampling,
                   struct llama_context * ctx_main,
                   struct llama_context * ctx_cfg,
-                  const int idx) {
-    return llama_sample_probability_distribution_impl(ctx_sampling,ctx_main, ctx_cfg, idx);
+                  const int idx,
+                  bool apply_grammar,
+                  std::vector<float> * original_logits) {
+    return llama_sampling_prepare_impl(ctx_sampling,ctx_main, ctx_cfg, idx, apply_grammar, original_logits);
 }
 
 void llama_sampling_accept(
index 79a998be8e408d7214b95f475856dacd809a48da..56ed991b8478aeac56f3bb3ea0dd449d925a4b47 100644 (file)
@@ -131,12 +131,14 @@ llama_token llama_sampling_sample(
         struct llama_context * ctx_cfg,
         int idx = 0);
 
-// returns the probability that token of given id will be sampled
-llama_token_data_array llama_sampling_probability_distribution(
+// Prepares and adjusts the set of token candidates for sampling based on penalties, biases, and sampling parameters.
+llama_token_data_array llama_sampling_prepare(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,
         struct llama_context * ctx_cfg,
-        int idx = 0);
+        int idx = 0,
+        bool apply_grammar = true,
+        std::vector<float> * original_logits = nullptr);
 
 void llama_sampling_accept(
         struct llama_sampling_context * ctx_sampling,
index e991b884607132e62319b683c075e706540d4892..8b31b678a6849a4764aaf4de25d69db842213c39 100644 (file)
@@ -219,7 +219,8 @@ int main(int argc, char ** argv) {
                 if (params.sparams.temp > 0) {
                     // stochastic verification
 
-                    llama_token_data_array dist_tgt = llama_sampling_probability_distribution(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
+                    llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL);
+                    llama_sample_softmax(ctx_tgt, &dist_tgt);
                     float p_tgt = 0, p_dft = 0;
 
                     // GGML_ASSERT(dist_tgt.size() == dist_dft.size());
diff --git a/retrieval b/retrieval
new file mode 100755 (executable)
index 0000000..dd31789
Binary files /dev/null and b/retrieval differ