]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : refactor samplers internal implementation (#9370)
authorslaren <redacted>
Sun, 8 Sep 2024 13:52:07 +0000 (15:52 +0200)
committerGitHub <redacted>
Sun, 8 Sep 2024 13:52:07 +0000 (15:52 +0200)
src/llama-impl.h
src/llama-sampling.cpp
src/llama-sampling.h
tests/test-sampling.cpp

index fa2e09e1f688e6b5c632af9baa9d7bef800738cc..87012617feed146d83b096b22ec6b14840464e96 100644 (file)
@@ -101,6 +101,10 @@ struct ring_buffer {
     }
 
     void push_back(const T & value) {
+        if (capacity == 0) {
+            throw std::runtime_error("ring buffer: capacity is zero");
+        }
+
         if (sz == capacity) {
             // advance the start when buffer is full
             first = (first + 1) % capacity;
index 1661d9a83ec80ea52f9818d85dec9a174c9b8853..41f48ec2867797e64ab983d6855d3ea2cc0741db 100644 (file)
 #include <unordered_map>
 
 static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng, std::vector<float> & probs) {
+#if 1
     probs.resize(cur_p->size);
     for (size_t i = 0; i < cur_p->size; ++i) {
         probs[i] = cur_p->data[i].p;
     }
 
     std::discrete_distribution<size_t> dist(probs.begin(), probs.end());
+#else
+    // avoid the copy with a custom iterator
+    #pragma GCC diagnostic push
+    #pragma GCC diagnostic ignored "-Wunused-local-typedefs"
+
+    struct probs_iterator {
+        typedef std::input_iterator_tag iterator_category;
+        typedef float value_type;
+        typedef float * pointer;
+        typedef float & reference;
+        typedef size_t difference_type;
+
+        const llama_token_data_array * data;
+        size_t i;
+
+        bool operator==(const probs_iterator & other) const { return data + i == other.data + other.i; }
+        bool operator!=(const probs_iterator & other) const { return data + i != other.data + other.i; }
+        float operator*() const { return data->data[i].p; }
+        probs_iterator & operator++() { ++i; return *this; }
+        probs_iterator operator++(int) { probs_iterator tmp = *this; ++i; return tmp; }
+    };
+    #pragma GCC diagnostic pop
+
+    std::discrete_distribution<size_t> dist(probs_iterator{cur_p, 0}, probs_iterator{cur_p, cur_p->size});
+
+    GGML_UNUSED(probs);
+#endif
 
     return dist(rng);
 }
@@ -138,301 +166,6 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
     cur_p->size = k;
 }
 
-static void llama_sampler_top_p_impl(llama_token_data_array * cur_p, float p, size_t min_keep) {
-    if (p >= 1.0f) {
-        return;
-    }
-
-    llama_sampler_softmax_impl(cur_p);
-
-    // Compute the cumulative probabilities
-    float cum_sum = 0.0f;
-    size_t last_idx = cur_p->size;
-
-    for (size_t i = 0; i < cur_p->size; ++i) {
-        cum_sum += cur_p->data[i].p;
-
-        // Check if the running sum is at least p or if we have kept at least min_keep tokens
-        // we set the last index to i+1 to indicate that the current iterate should be included in the set
-        if (cum_sum >= p && i + 1 >= min_keep) {
-            last_idx = i + 1;
-            break;
-        }
-    }
-
-    // Resize the output vector to keep only the top-p tokens
-    cur_p->size = last_idx;
-}
-
-static void llama_sampler_min_p_impl(llama_token_data_array * cur_p, float p, size_t min_keep) {
-    if (p <= 0.0f || !cur_p->size) {
-        return;
-    }
-
-    bool min_p_applied = false;
-
-    // if the cur_p aren't sorted, try the unsorted implementation first
-    if (!cur_p->sorted) {
-        std::vector<llama_token_data> filtered_tokens;
-
-        float max_logit = -FLT_MAX;
-        for (size_t i = 0; i < cur_p->size; ++i) {
-            max_logit = std::max(max_logit, cur_p->data[i].logit);
-        }
-        const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max
-
-        for (size_t i = 0; i < cur_p->size; ++i) {
-            if (cur_p->data[i].logit >= min_logit) {
-                filtered_tokens.push_back(cur_p->data[i]);
-            }
-        }
-
-        // if we have enough values the operation was a success
-        if (filtered_tokens.size() >= min_keep) {
-            memcpy(cur_p->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
-            cur_p->size = filtered_tokens.size();
-            min_p_applied = true;
-        }
-    }
-
-    // if the cur_p are sorted or the unsorted implementation failed, use this implementation
-    if (!min_p_applied) {
-        // Sort the logits in descending order
-        if (!cur_p->sorted) {
-            std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
-                return a.logit > b.logit;
-            });
-            cur_p->sorted = true;
-        }
-
-        const float min_logit = cur_p->data[0].logit + logf(p); // min logit for p_i >= p * p_max
-        size_t i = 1; // first token always matches
-
-        for (; i < cur_p->size; ++i) {
-            if (cur_p->data[i].logit < min_logit && i >= min_keep) {
-                break; // prob too small
-            }
-        }
-
-        // Resize the output vector to keep only the matching tokens
-        cur_p->size = i;
-    }
-}
-
-static void llama_sampler_tail_free_impl(llama_token_data_array * cur_p, float z, size_t min_keep) {
-    if (z >= 1.0f || cur_p->size <= 2) {
-        return;
-    }
-
-    llama_sampler_softmax_impl(cur_p);
-
-    // Compute the first and second derivatives
-    std::vector<float> first_derivatives(cur_p->size - 1);
-    std::vector<float> second_derivatives(cur_p->size - 2);
-
-    for (size_t i = 0; i < first_derivatives.size(); ++i) {
-        first_derivatives[i] = cur_p->data[i].p - cur_p->data[i + 1].p;
-    }
-    for (size_t i = 0; i < second_derivatives.size(); ++i) {
-        second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1];
-    }
-
-    // Calculate absolute value of second derivatives
-    for (size_t i = 0; i < second_derivatives.size(); ++i) {
-        second_derivatives[i] = std::abs(second_derivatives[i]);
-    }
-
-    // Normalize the second derivatives
-    {
-        const float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f);
-
-        if (second_derivatives_sum > 1e-6f) {
-            for (float & value : second_derivatives) {
-                value /= second_derivatives_sum;
-            }
-        } else {
-            for (float & value : second_derivatives) {
-                value = 1.0f / second_derivatives.size();
-            }
-        }
-    }
-
-    float cum_sum = 0.0f;
-    size_t last_idx = cur_p->size;
-    for (size_t i = 0; i < second_derivatives.size(); ++i) {
-        cum_sum += second_derivatives[i];
-
-        // Check if the running sum is greater than z or if we have kept at least min_keep tokens
-        if (cum_sum > z && i >= min_keep) {
-            last_idx = i;
-            break;
-        }
-    }
-
-    // Resize the output vector to keep only the tokens above the tail location
-    cur_p->size = last_idx;
-}
-
-static void llama_sampler_typical_impl(llama_token_data_array * cur_p, float p, size_t min_keep) {
-    // Reference implementation:
-    // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
-    if (p >= 1.0f) {
-        return;
-    }
-
-    // Compute the softmax of logits and calculate entropy
-    llama_sampler_softmax_impl(cur_p);
-
-    float entropy = 0.0f;
-    for (size_t i = 0; i < cur_p->size; ++i) {
-        entropy += -cur_p->data[i].p * logf(cur_p->data[i].p);
-    }
-
-    // Compute the absolute difference between negative log probability and entropy for each candidate
-    std::vector<float> shifted_scores;
-    for (size_t i = 0; i < cur_p->size; ++i) {
-        float shifted_score = fabsf(-logf(cur_p->data[i].p) - entropy);
-        shifted_scores.push_back(shifted_score);
-    }
-
-    // Sort tokens based on the shifted_scores and their corresponding indices
-    std::vector<size_t> indices(cur_p->size);
-    std::iota(indices.begin(), indices.end(), 0);
-
-    std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) {
-        return shifted_scores[a] < shifted_scores[b];
-    });
-
-    // Compute the cumulative probabilities
-    float cum_sum = 0.0f;
-    size_t last_idx = indices.size();
-
-    for (size_t i = 0; i < indices.size(); ++i) {
-        size_t idx = indices[i];
-        cum_sum += cur_p->data[idx].p;
-
-        // Check if the running sum is greater than typical or if we have kept at least min_keep tokens
-        if (cum_sum > p && i >= min_keep - 1) {
-            last_idx = i + 1;
-            break;
-        }
-    }
-
-    // Resize the output vector to keep only the locally typical tokens
-    std::vector<llama_token_data> cur_p_new;
-    for (size_t i = 0; i < last_idx; ++i) {
-        size_t idx = indices[i];
-        cur_p_new.push_back(cur_p->data[idx]);
-    }
-
-    // Replace the data in cur_p with the cur_p_new data
-    std::copy(cur_p_new.begin(), cur_p_new.end(), cur_p->data);
-    cur_p->size = cur_p_new.size();
-    cur_p->sorted = false;
-}
-
-static void llama_sampler_entropy_impl(llama_token_data_array * cur_p, float min_temp, float max_temp, float exponent_val) {
-    // no need to do anything if there is only one (or zero) candidates
-    if (cur_p->size <= 1) {
-        return;
-    }
-
-    // Calculate maximum possible entropy
-    float max_entropy = -logf(1.0f / cur_p->size);
-
-    llama_sampler_softmax_impl(cur_p);
-
-    // Calculate entropy of the softmax probabilities
-    float entropy = 0.0f;
-    for (size_t i = 0; i < cur_p->size; ++i) {
-        float prob = cur_p->data[i].p;
-        if (prob > 0.0f) { // Ensure no log(0)
-            entropy -= prob * logf(prob);
-        }
-    }
-
-    // Normalize the entropy (max_entropy cannot be 0 here because we checked cur_p->size != 1 above)
-    float normalized_entropy = entropy / max_entropy;
-
-    // Map the normalized entropy to the desired temperature range using the power function
-    float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val);
-
-#ifdef DEBUG
-    LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp);
-    LLAMA_LOG_INFO("Entropy: %f\n", entropy);
-    LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy);
-    LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy);
-    LLAMA_LOG_INFO("Exponent: %f\n", exponent_val);
-    LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp);
-#endif
-
-    // Apply the dynamically calculated temperature scaling
-    for (size_t i = 0; i < cur_p->size; ++i) {
-        cur_p->data[i].logit /= dyn_temp;
-    }
-
-    // Re-compute softmax probabilities after scaling logits with dynamic temperature
-    const double max_l_double = cur_p->data[0].logit;
-
-    double cum_sum_double = 0.0;
-    for (size_t i = 0; i < cur_p->size; ++i) {
-        double p = exp(cur_p->data[i].logit - max_l_double);
-        cur_p->data[i].p = p; // Store the scaled probability
-        cum_sum_double += p;
-    }
-
-    for (size_t i = 0; i < cur_p->size; ++i) {
-        cur_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities
-    }
-
-#ifdef DEBUG
-    // Print the updated top 25 probabilities after temperature scaling
-    LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n");
-    for (size_t i = 0; i < 25 && i < cur_p->size; ++i) {
-        LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, cur_p->data[i].p * 100.0f);
-    }
-#endif
-}
-
-static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) {
-    for (size_t i = 0; i < cur_p->size; ++i) {
-        cur_p->data[i].logit /= temp;
-    }
-}
-
-static void llama_sampler_grammar_impl(llama_token_data_array * cur_p, const struct llama_grammar & grammar) {
-    llama_grammar_apply_impl(grammar, cur_p);
-}
-
-void llama_sampler_penalties_impl(
-       llama_token_data_array * cur_p,
-        const llama_token_cnt & token_count,
-                        float   penalty_repeat,
-                        float   penalty_freq,
-                        float   penalty_present) {
-    // Apply frequency and presence penalties to the cur_p
-    for (size_t i = 0; i < cur_p->size; ++i) {
-        const auto token_iter = token_count.find(cur_p->data[i].id);
-        if (token_iter == token_count.end()) {
-            continue;
-        }
-
-        const int count = token_iter->second;
-
-        // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
-        // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
-        if (cur_p->data[i].logit <= 0) {
-            cur_p->data[i].logit *= penalty_repeat;
-        } else {
-            cur_p->data[i].logit /= penalty_repeat;
-        }
-
-        cur_p->data[i].logit -= float(count) * penalty_freq + float(count > 0) * penalty_present;
-    }
-
-    cur_p->sorted = false;
-}
-
 // llama_sampler API
 
 const char * llama_sampler_name(const struct llama_sampler * smpl) {
@@ -600,17 +333,23 @@ int llama_sampler_chain_n(const struct llama_sampler * chain) {
 
 // greedy
 
+static const char * llama_sampler_greedy_name(const struct llama_sampler * /*smpl*/) {
+    return "greedy";
+}
+
+static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
+    cur_p->selected = 0;
+    for (size_t i = 1; i < cur_p->size; ++i) {
+        if (cur_p->data[i].logit > cur_p->data[cur_p->selected].logit) {
+            cur_p->selected = i;
+        }
+    }
+}
+
 static struct llama_sampler_i llama_sampler_greedy_i = {
-    /* .name   = */ [](const struct llama_sampler * /*smpl*/) { return "greedy"; },
+    /* .name   = */ llama_sampler_greedy_name,
     /* .accept = */ nullptr,
-    /* .apply  = */ [](struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
-        cur_p->selected = 0;
-        for (size_t i = 1; i < cur_p->size; ++i) {
-            if (cur_p->data[i].logit > cur_p->data[cur_p->selected].logit) {
-                cur_p->selected = i;
-            }
-        }
-    },
+    /* .apply  = */ llama_sampler_greedy_apply,
     /* .reset  = */ nullptr,
     /* .clone  = */ nullptr,
     /* .free   = */ nullptr,
@@ -633,30 +372,45 @@ struct llama_sampler_dist {
     std::vector<float> probs; // work array
 };
 
-static struct llama_sampler_i llama_sampler_dist_i = {
-    /* .name   = */ [](const struct llama_sampler * /*smpl*/) { return "dist"; },
-    /* .accept = */ nullptr,
-    /* .apply  = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
-        auto * ctx = (llama_sampler_dist *) smpl->ctx;
-        cur_p->selected = llama_sample_dist(cur_p, ctx->rng, ctx->probs);
-    },
-    /* .reset  = */ nullptr,
-    /* .clone  = */ [](const struct llama_sampler * smpl) {
-        const auto * ctx = (const llama_sampler_dist *) smpl->ctx;
-        auto * result = llama_sampler_init_dist(ctx->seed);
+static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) {
+    return "dist";
+}
 
-        // copy the state
-        {
-            auto * result_ctx = (llama_sampler_dist *) result->ctx;
+static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_dist *) smpl->ctx;
+    cur_p->selected = llama_sample_dist(cur_p, ctx->rng, ctx->probs);
+}
 
-            result_ctx->rng = ctx->rng;
-        }
+static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_dist *) smpl->ctx;
+    auto * result = llama_sampler_init_dist(ctx->seed);
 
-        return result;
-    },
-    /* .free   = */ [](struct llama_sampler * smpl) {
-        delete (llama_sampler_dist *) smpl->ctx;
-    },
+    // copy the state
+    {
+        auto * result_ctx = (llama_sampler_dist *) result->ctx;
+
+        result_ctx->rng = ctx->rng;
+    }
+
+    return result;
+}
+
+static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_dist *) smpl->ctx;
+    ctx->rng = std::mt19937(ctx->seed);
+}
+
+static void llama_sampler_dist_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_dist *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_dist_i = {
+    /* .name   = */ llama_sampler_dist_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_dist_apply,
+    /* .reset  = */ llama_sampler_dist_reset,
+    /* .clone  = */ llama_sampler_dist_clone,
+    /* .free   = */ llama_sampler_dist_free,
 };
 
 struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
@@ -672,12 +426,18 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
 
 // softmax
 
+static const char * llama_sampler_softmax_name(const struct llama_sampler * /*smpl*/) {
+    return "softmax";
+}
+
+static void llama_sampler_softmax_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
+    llama_sampler_softmax_impl(cur_p);
+}
+
 static struct llama_sampler_i llama_sampler_softmax_i = {
-    /* .name   = */ [](const struct llama_sampler * /*smpl*/) { return "softmax"; },
+    /* .name   = */ llama_sampler_softmax_name,
     /* .accept = */ nullptr,
-    /* .apply  = */ [](struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
-        llama_sampler_softmax_impl(cur_p);
-    },
+    /* .apply  = */ llama_sampler_softmax_apply,
     /* .reset  = */ nullptr,
     /* .clone  = */ nullptr,
     /* .free   = */ nullptr,
@@ -696,21 +456,31 @@ struct llama_sampler_top_k {
     const int32_t k;
 };
 
+static const char * llama_sampler_top_k_name(const struct llama_sampler * /*smpl*/) {
+    return "top-k";
+}
+
+static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    const auto * ctx = (llama_sampler_top_k *) smpl->ctx;
+    llama_sampler_top_k_impl(cur_p, ctx->k);
+}
+
+static struct llama_sampler * llama_sampler_top_k_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_top_k *) smpl->ctx;
+    return llama_sampler_init_top_k(ctx->k);
+}
+
+static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_top_k *) smpl->ctx;
+}
+
 static struct llama_sampler_i llama_sampler_top_k_i = {
-    /* .name   = */ [](const struct llama_sampler * /*smpl*/) { return "top-k"; },
+    /* .name   = */ llama_sampler_top_k_name,
     /* .accept = */ nullptr,
-    /* .apply  = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
-        const auto * ctx = (llama_sampler_top_k *) smpl->ctx;
-        llama_sampler_top_k_impl(cur_p, ctx->k);
-    },
+    /* .apply  = */ llama_sampler_top_k_apply,
     /* .reset  = */ nullptr,
-    /* .clone  = */ [](const struct llama_sampler * smpl) {
-        const auto * ctx = (const llama_sampler_top_k *) smpl->ctx;
-        return llama_sampler_init_top_k(ctx->k);
-    },
-    /* .free   = */ [](struct llama_sampler * smpl) {
-        delete (llama_sampler_top_k *) smpl->ctx;
-    },
+    /* .clone  = */ llama_sampler_top_k_clone,
+    /* .free   = */ llama_sampler_top_k_free,
 };
 
 struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
@@ -729,21 +499,54 @@ struct llama_sampler_top_p {
     const size_t min_keep;
 };
 
+static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl*/) {
+    return "top-p";
+}
+
+static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    const auto * ctx = (llama_sampler_top_p *) smpl->ctx;
+
+    if (ctx->p >= 1.0f) {
+        return;
+    }
+
+    llama_sampler_softmax_impl(cur_p);
+
+    // Compute the cumulative probabilities
+    float cum_sum = 0.0f;
+    size_t last_idx = cur_p->size;
+
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cum_sum += cur_p->data[i].p;
+
+        // Check if the running sum is at least p or if we have kept at least min_keep tokens
+        // we set the last index to i+1 to indicate that the current iterate should be included in the set
+        if (cum_sum >= ctx->p && i + 1 >= ctx->min_keep) {
+            last_idx = i + 1;
+            break;
+        }
+    }
+
+    // Resize the output vector to keep only the top-p tokens
+    cur_p->size = last_idx;
+}
+
+static struct llama_sampler * llama_sampler_top_p_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_top_p *) smpl->ctx;
+    return llama_sampler_init_top_p(ctx->p, ctx->min_keep);
+}
+
+static void llama_sampler_top_p_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_top_p *) smpl->ctx;
+}
+
 static struct llama_sampler_i llama_sampler_top_p_i = {
-    /* .name   = */ [](const struct llama_sampler * /*smpl*/) { return "top-p"; },
+    /* .name   = */ llama_sampler_top_p_name,
     /* .accept = */ nullptr,
-    /* .apply  = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
-        const auto * ctx = (llama_sampler_top_p *) smpl->ctx;
-        llama_sampler_top_p_impl(cur_p, ctx->p, ctx->min_keep);
-    },
+    /* .apply  = */ llama_sampler_top_p_apply,
     /* .reset  = */ nullptr,
-    /* .clone  = */ [](const struct llama_sampler * smpl) {
-        const auto * ctx = (const llama_sampler_top_p *) smpl->ctx;
-        return llama_sampler_init_top_p(ctx->p, ctx->min_keep);
-    },
-    /* .free   = */ [](struct llama_sampler * smpl) {
-        delete (llama_sampler_top_p *) smpl->ctx;
-    },
+    /* .clone  = */ llama_sampler_top_p_clone,
+    /* .free   = */ llama_sampler_top_p_free,
 };
 
 struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
@@ -763,21 +566,83 @@ struct llama_sampler_min_p {
     const size_t min_keep;
 };
 
+static const char * llama_sampler_min_p_name(const struct llama_sampler * /*smpl*/) {
+    return "min-p";
+}
+
+static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    const auto * ctx = (llama_sampler_min_p *) smpl->ctx;
+
+    if (ctx->p <= 0.0f || !cur_p->size) {
+        return;
+    }
+
+    bool min_p_applied = false;
+
+    // if the cur_p aren't sorted, try the unsorted implementation first
+    if (!cur_p->sorted) {
+        std::vector<llama_token_data> filtered_tokens;
+
+        float max_logit = -FLT_MAX;
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            max_logit = std::max(max_logit, cur_p->data[i].logit);
+        }
+        const float min_logit = max_logit + logf(ctx->p); // min logit for p_i >= p * p_max
+
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            if (cur_p->data[i].logit >= min_logit) {
+                filtered_tokens.push_back(cur_p->data[i]);
+            }
+        }
+
+        // if we have enough values the operation was a success
+        if (filtered_tokens.size() >= ctx->min_keep) {
+            memcpy(cur_p->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
+            cur_p->size = filtered_tokens.size();
+            min_p_applied = true;
+        }
+    }
+
+    // if the cur_p are sorted or the unsorted implementation failed, use this implementation
+    if (!min_p_applied) {
+        // Sort the logits in descending order
+        if (!cur_p->sorted) {
+            std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
+                return a.logit > b.logit;
+            });
+            cur_p->sorted = true;
+        }
+
+        const float min_logit = cur_p->data[0].logit + logf(ctx->p); // min logit for p_i >= p * p_max
+        size_t i = 1; // first token always matches
+
+        for (; i < cur_p->size; ++i) {
+            if (cur_p->data[i].logit < min_logit && i >= ctx->min_keep) {
+                break; // prob too small
+            }
+        }
+
+        // Resize the output vector to keep only the matching tokens
+        cur_p->size = i;
+    }
+}
+
+static struct llama_sampler * llama_sampler_min_p_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_min_p *) smpl->ctx;
+    return llama_sampler_init_min_p(ctx->p, ctx->min_keep);
+}
+
+static void llama_sampler_min_p_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_min_p *) smpl->ctx;
+}
+
 static struct llama_sampler_i llama_sampler_min_p_i = {
-    /* .name   = */ [](const struct llama_sampler * /*smpl*/) { return "min-p"; },
+    /* .name   = */ llama_sampler_min_p_name,
     /* .accept = */ nullptr,
-    /* .apply  = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
-        const auto * ctx = (llama_sampler_min_p *) smpl->ctx;
-        llama_sampler_min_p_impl(cur_p, ctx->p, ctx->min_keep);
-    },
+    /* .apply  = */ llama_sampler_min_p_apply,
     /* .reset  = */ nullptr,
-    /* .clone  = */ [](const struct llama_sampler * smpl) {
-        const auto * ctx = (const llama_sampler_min_p *) smpl->ctx;
-        return llama_sampler_init_min_p(ctx->p, ctx->min_keep);
-    },
-    /* .free   = */ [](struct llama_sampler * smpl) {
-        delete (llama_sampler_min_p *) smpl->ctx;
-    },
+    /* .clone  = */ llama_sampler_min_p_clone,
+    /* .free   = */ llama_sampler_min_p_free,
 };
 
 struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
@@ -797,21 +662,82 @@ struct llama_sampler_tail_free {
     const size_t min_keep;
 };
 
+static const char * llama_sampler_tail_free_name(const struct llama_sampler * /*smpl*/) {
+    return "tail-free";
+}
+
+static void llama_sampler_tail_free_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    const auto * ctx = (llama_sampler_tail_free *) smpl->ctx;
+
+    if (ctx->z >= 1.0f || cur_p->size <= 2) {
+        return;
+    }
+
+    llama_sampler_softmax_impl(cur_p);
+
+    // Compute the first and second derivatives
+    std::vector<float> first_derivatives(cur_p->size - 1);
+    std::vector<float> second_derivatives(cur_p->size - 2);
+
+    for (size_t i = 0; i < first_derivatives.size(); ++i) {
+        first_derivatives[i] = cur_p->data[i].p - cur_p->data[i + 1].p;
+    }
+    for (size_t i = 0; i < second_derivatives.size(); ++i) {
+        second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1];
+    }
+
+    // Calculate absolute value of second derivatives
+    for (size_t i = 0; i < second_derivatives.size(); ++i) {
+        second_derivatives[i] = std::abs(second_derivatives[i]);
+    }
+
+    // Normalize the second derivatives
+    {
+        const float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f);
+
+        if (second_derivatives_sum > 1e-6f) {
+            for (float & value : second_derivatives) {
+                value /= second_derivatives_sum;
+            }
+        } else {
+            for (float & value : second_derivatives) {
+                value = 1.0f / second_derivatives.size();
+            }
+        }
+    }
+
+    float cum_sum = 0.0f;
+    size_t last_idx = cur_p->size;
+    for (size_t i = 0; i < second_derivatives.size(); ++i) {
+        cum_sum += second_derivatives[i];
+
+        // Check if the running sum is greater than z or if we have kept at least min_keep tokens
+        if (cum_sum > ctx->z && i >= ctx->min_keep) {
+            last_idx = i;
+            break;
+        }
+    }
+
+    // Resize the output vector to keep only the tokens above the tail location
+    cur_p->size = last_idx;
+}
+
+static struct llama_sampler * llama_sampler_tail_free_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_tail_free *) smpl->ctx;
+    return llama_sampler_init_tail_free(ctx->z, ctx->min_keep);
+}
+
+static void llama_sampler_tail_free_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_tail_free *) smpl->ctx;
+}
+
 static struct llama_sampler_i llama_sampler_tail_free_i = {
-    /* .name   = */ [](const struct llama_sampler * /*smpl*/) { return "tail-free"; },
+    /* .name   = */ llama_sampler_tail_free_name,
     /* .accept = */ nullptr,
-    /* .apply  = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
-        const auto * ctx = (llama_sampler_tail_free *) smpl->ctx;
-        llama_sampler_tail_free_impl(cur_p, ctx->z, ctx->min_keep);
-    },
+    /* .apply  = */ llama_sampler_tail_free_apply,
     /* .reset  = */ nullptr,
-    /* .clone  = */ [](const struct llama_sampler * smpl) {
-        const auto * ctx = (const llama_sampler_tail_free *) smpl->ctx;
-        return llama_sampler_init_tail_free(ctx->z, ctx->min_keep);
-    },
-    /* .free   = */ [](struct llama_sampler * smpl) {
-        delete (llama_sampler_tail_free *) smpl->ctx;
-    },
+    /* .clone  = */ llama_sampler_tail_free_clone,
+    /* .free   = */ llama_sampler_tail_free_free,
 };
 
 struct llama_sampler * llama_sampler_init_tail_free(float z, size_t min_keep) {
@@ -831,21 +757,86 @@ struct llama_sampler_typical {
     const size_t min_keep;
 };
 
+static const char * llama_sampler_typical_name(const struct llama_sampler * /*smpl*/) {
+    return "typical";
+}
+
+static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    const auto * ctx = (llama_sampler_typical *) smpl->ctx;
+
+    // Reference implementation:
+    // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
+    if (ctx->p >= 1.0f) {
+        return;
+    }
+
+    // Compute the softmax of logits and calculate entropy
+    llama_sampler_softmax_impl(cur_p);
+
+    float entropy = 0.0f;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        entropy += -cur_p->data[i].p * logf(cur_p->data[i].p);
+    }
+
+    // Compute the absolute difference between negative log probability and entropy for each candidate
+    std::vector<float> shifted_scores;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        float shifted_score = fabsf(-logf(cur_p->data[i].p) - entropy);
+        shifted_scores.push_back(shifted_score);
+    }
+
+    // Sort tokens based on the shifted_scores and their corresponding indices
+    std::vector<size_t> indices(cur_p->size);
+    std::iota(indices.begin(), indices.end(), 0);
+
+    std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) {
+        return shifted_scores[a] < shifted_scores[b];
+    });
+
+    // Compute the cumulative probabilities
+    float cum_sum = 0.0f;
+    size_t last_idx = indices.size();
+
+    for (size_t i = 0; i < indices.size(); ++i) {
+        size_t idx = indices[i];
+        cum_sum += cur_p->data[idx].p;
+
+        // Check if the running sum is greater than typical or if we have kept at least min_keep tokens
+        if (cum_sum > ctx->p && i >= ctx->min_keep - 1) {
+            last_idx = i + 1;
+            break;
+        }
+    }
+
+    // Resize the output vector to keep only the locally typical tokens
+    std::vector<llama_token_data> cur_p_new;
+    for (size_t i = 0; i < last_idx; ++i) {
+        size_t idx = indices[i];
+        cur_p_new.push_back(cur_p->data[idx]);
+    }
+
+    // Replace the data in cur_p with the cur_p_new data
+    std::copy(cur_p_new.begin(), cur_p_new.end(), cur_p->data);
+    cur_p->size = cur_p_new.size();
+    cur_p->sorted = false;
+}
+
+static struct llama_sampler * llama_sampler_typical_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_typical *) smpl->ctx;
+    return llama_sampler_init_typical(ctx->p, ctx->min_keep);
+}
+
+static void llama_sampler_typical_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_typical *) smpl->ctx;
+}
+
 static struct llama_sampler_i llama_sampler_typical_i = {
-    /* .name   = */ [](const struct llama_sampler * /*smpl*/) { return "typical"; },
+    /* .name   = */ llama_sampler_typical_name,
     /* .accept = */ nullptr,
-    /* .apply  = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
-        const auto * ctx = (llama_sampler_typical *) smpl->ctx;
-        llama_sampler_typical_impl(cur_p, ctx->p, ctx->min_keep);
-    },
+    /* .apply  = */ llama_sampler_typical_apply,
     /* .reset  = */ nullptr,
-    /* .clone  = */ [](const struct llama_sampler * smpl) {
-        const auto * ctx = (const llama_sampler_typical *) smpl->ctx;
-        return llama_sampler_init_typical(ctx->p, ctx->min_keep);
-    },
-    /* .free   = */ [](struct llama_sampler * smpl) {
-        delete (llama_sampler_typical *) smpl->ctx;
-    },
+    /* .clone  = */ llama_sampler_typical_clone,
+    /* .free   = */ llama_sampler_typical_free,
 };
 
 struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
@@ -858,27 +849,39 @@ struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
     };
 }
 
-// temp
+// temp
+
+struct llama_sampler_temp {
+    const float temp;
+};
+
+static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*/) {
+    return "temp";
+}
+
+static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    const auto * ctx = (llama_sampler_temp *) smpl->ctx;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].logit /= ctx->temp;
+    }
+}
+
+static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_temp *) smpl->ctx;
+    return llama_sampler_init_temp(ctx->temp);
+}
 
-struct llama_sampler_temp {
-    const float temp;
-};
+static void llama_sampler_temp_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_temp *) smpl->ctx;
+}
 
 static struct llama_sampler_i llama_sampler_temp_i = {
-    /* .name   = */ [](const struct llama_sampler * /*smpl*/) { return "temp"; },
+    /* .name   = */ llama_sampler_temp_name,
     /* .accept = */ nullptr,
-    /* .apply  = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
-        const auto * ctx = (llama_sampler_temp *) smpl->ctx;
-        llama_sampler_temp_impl(cur_p, ctx->temp);
-    },
+    /* .apply  = */ llama_sampler_temp_apply,
     /* .reset  = */ nullptr,
-    /* .clone  = */ [](const struct llama_sampler * smpl) {
-        const auto * ctx = (const llama_sampler_temp *) smpl->ctx;
-        return llama_sampler_init_temp(ctx->temp);
-    },
-    /* .free   = */ [](struct llama_sampler * smpl) {
-        delete (llama_sampler_temp *) smpl->ctx;
-    },
+    /* .clone  = */ llama_sampler_temp_clone,
+    /* .free   = */ llama_sampler_temp_free,
 };
 
 struct llama_sampler * llama_sampler_init_temp(float temp) {
@@ -898,28 +901,100 @@ struct llama_sampler_temp_ext {
     const float exponent;
 };
 
-static struct llama_sampler_i llama_sampler_temp_ext_i = {
-    /* .name   = */ [](const struct llama_sampler * /*smpl*/) { return "temp-ext"; },
-    /* .accept = */ nullptr,
-    /* .apply  = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
-        const auto * ctx = (llama_sampler_temp_ext *) smpl->ctx;
-        if (ctx->delta > 0) {
-            const float temp_min = std::max(0.0f, ctx->temp - ctx->delta);
-            const float temp_max = ctx->temp + ctx->delta;
+static const char * llama_sampler_temp_ext_name(const struct llama_sampler * /*smpl*/) {
+    return "temp-ext";
+}
 
-            llama_sampler_entropy_impl(cur_p, temp_min, temp_max, ctx->exponent);
-        } else {
-            llama_sampler_temp_impl(cur_p, ctx->temp);
+static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    const auto * ctx = (llama_sampler_temp_ext *) smpl->ctx;
+    if (ctx->delta > 0) {
+        const float min_temp = std::max(0.0f, ctx->temp - ctx->delta);
+        const float max_temp = ctx->temp + ctx->delta;
+        float exponent_val = ctx->exponent;
+
+        // no need to do anything if there is only one (or zero) candidates
+        if (cur_p->size <= 1) {
+            return;
         }
-    },
+
+        // Calculate maximum possible entropy
+        float max_entropy = -logf(1.0f / cur_p->size);
+
+        llama_sampler_softmax_impl(cur_p);
+
+        // Calculate entropy of the softmax probabilities
+        float entropy = 0.0f;
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            float prob = cur_p->data[i].p;
+            if (prob > 0.0f) { // Ensure no log(0)
+                entropy -= prob * logf(prob);
+            }
+        }
+
+        // Normalize the entropy (max_entropy cannot be 0 here because we checked cur_p->size != 1 above)
+        float normalized_entropy = entropy / max_entropy;
+
+        // Map the normalized entropy to the desired temperature range using the power function
+        float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val);
+
+    #ifdef DEBUG
+        LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp);
+        LLAMA_LOG_INFO("Entropy: %f\n", entropy);
+        LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy);
+        LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy);
+        LLAMA_LOG_INFO("Exponent: %f\n", exponent_val);
+        LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp);
+    #endif
+
+        // Apply the dynamically calculated temperature scaling
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            cur_p->data[i].logit /= dyn_temp;
+        }
+
+        // Re-compute softmax probabilities after scaling logits with dynamic temperature
+        const double max_l_double = cur_p->data[0].logit;
+
+        double cum_sum_double = 0.0;
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            double p = exp(cur_p->data[i].logit - max_l_double);
+            cur_p->data[i].p = p; // Store the scaled probability
+            cum_sum_double += p;
+        }
+
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            cur_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities
+        }
+
+    #ifdef DEBUG
+        // Print the updated top 25 probabilities after temperature scaling
+        LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n");
+        for (size_t i = 0; i < 25 && i < cur_p->size; ++i) {
+            LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, cur_p->data[i].p * 100.0f);
+        }
+    #endif
+    } else {
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            cur_p->data[i].logit /= ctx->temp;
+        }
+    }
+}
+
+static struct llama_sampler * llama_sampler_temp_ext_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_temp_ext *) smpl->ctx;
+    return llama_sampler_init_temp_ext(ctx->temp, ctx->delta, ctx->exponent);
+}
+
+static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_temp_ext *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_temp_ext_i = {
+    /* .name   = */ llama_sampler_temp_ext_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_temp_ext_apply,
     /* .reset  = */ nullptr,
-    /* .clone  = */ [](const struct llama_sampler * smpl) {
-        const auto * ctx = (const llama_sampler_temp_ext *) smpl->ctx;
-        return llama_sampler_init_temp_ext(ctx->temp, ctx->delta, ctx->exponent);
-    },
-    /* .free   = */ [](struct llama_sampler * smpl) {
-        delete (llama_sampler_temp_ext *) smpl->ctx;
-    },
+    /* .clone  = */ llama_sampler_temp_ext_clone,
+    /* .free   = */ llama_sampler_temp_ext_free,
 };
 
 struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
@@ -952,65 +1027,77 @@ struct llama_sampler_mirostat {
     std::vector<float> probs;
 };
 
-static struct llama_sampler_i llama_sampler_mirostat_i = {
-    /* .name   = */ [](const struct llama_sampler * /*smpl*/) { return "mirostat"; },
-    /* .accept = */ nullptr,
-    /* .apply  = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
-        auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
+static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*smpl*/) {
+    return "mirostat";
+}
 
-        llama_sampler_softmax_impl(cur_p);
+static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
 
-        // Estimate s_hat using the most probable m tokens
-        float s_hat = 0.0;
-        float sum_ti_bi = 0.0;
-        float sum_ti_sq = 0.0;
-        for (size_t i = 0; i < size_t(ctx->m - 1) && i < cur_p->size - 1; ++i) {
-            float t_i = logf(float(i + 2) / float(i + 1));
-            float b_i = logf(cur_p->data[i].p / cur_p->data[i + 1].p);
-            sum_ti_bi += t_i * b_i;
-            sum_ti_sq += t_i * t_i;
-        }
-        s_hat = sum_ti_bi / sum_ti_sq;
+    llama_sampler_softmax_impl(cur_p);
 
-        // Compute k from the estimated s_hat and target surprise value
-        float epsilon_hat = s_hat - 1;
-        float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->n_vocab, -epsilon_hat)), 1 / s_hat);
+    // Estimate s_hat using the most probable m tokens
+    float s_hat = 0.0;
+    float sum_ti_bi = 0.0;
+    float sum_ti_sq = 0.0;
+    for (size_t i = 0; i < size_t(ctx->m - 1) && i < cur_p->size - 1; ++i) {
+        float t_i = logf(float(i + 2) / float(i + 1));
+        float b_i = logf(cur_p->data[i].p / cur_p->data[i + 1].p);
+        sum_ti_bi += t_i * b_i;
+        sum_ti_sq += t_i * t_i;
+    }
+    s_hat = sum_ti_bi / sum_ti_sq;
 
-        llama_sampler_top_k_impl(cur_p, std::max(int(k), 1));
-        llama_sampler_softmax_impl(cur_p);
+    // Compute k from the estimated s_hat and target surprise value
+    float epsilon_hat = s_hat - 1;
+    float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->n_vocab, -epsilon_hat)), 1 / s_hat);
 
-        const int idx = llama_sample_dist(cur_p, ctx->rng, ctx->probs);
+    llama_sampler_top_k_impl(cur_p, std::max(int(k), 1));
+    llama_sampler_softmax_impl(cur_p);
 
-        cur_p->selected = idx;
+    const int idx = llama_sample_dist(cur_p, ctx->rng, ctx->probs);
 
-        float observed_surprise = -log2f(cur_p->data[idx].p);
-        float e = observed_surprise - ctx->tau;
+    cur_p->selected = idx;
 
-        // Update mu using the learning rate and error
-        ctx->mu = ctx->mu - ctx->eta * e;
-    },
-    /* .reset  = */ [](struct llama_sampler * smpl) {
-        auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
-        ctx->mu = 2.0f*ctx->tau;
-        ctx->rng = std::mt19937(ctx->seed);
-    },
-    /* .clone  = */ [](const struct llama_sampler * smpl) {
-        const auto * ctx = (const llama_sampler_mirostat *) smpl->ctx;
-        auto * result = llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m);
+    float observed_surprise = -log2f(cur_p->data[idx].p);
+    float e = observed_surprise - ctx->tau;
 
-        // copy the state
-        {
-            auto * result_ctx = (llama_sampler_mirostat *) smpl->ctx;
+    // Update mu using the learning rate and error
+    ctx->mu = ctx->mu - ctx->eta * e;
+}
 
-            result_ctx->mu  = ctx->mu;
-            result_ctx->rng = ctx->rng;
-        }
+static struct llama_sampler * llama_sampler_mirostat_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_mirostat *) smpl->ctx;
+    auto * result = llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m);
 
-        return result;
-    },
-    /* .free   = */ [](struct llama_sampler * smpl) {
-        delete (llama_sampler_mirostat *) smpl->ctx;
-    },
+    // copy the state
+    {
+        auto * result_ctx = (llama_sampler_mirostat *) smpl->ctx;
+
+        result_ctx->mu  = ctx->mu;
+        result_ctx->rng = ctx->rng;
+    }
+
+    return result;
+}
+
+static void llama_sampler_mirostat_reset(struct llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
+    ctx->mu = 2.0f*ctx->tau;
+    ctx->rng = std::mt19937(ctx->seed);
+}
+
+static void llama_sampler_mirostat_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_mirostat *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_mirostat_i = {
+    /* .name   = */ llama_sampler_mirostat_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_mirostat_apply,
+    /* .reset  = */ llama_sampler_mirostat_reset,
+    /* .clone  = */ llama_sampler_mirostat_clone,
+    /* .free   = */ llama_sampler_mirostat_free,
 };
 
 struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
@@ -1044,59 +1131,71 @@ struct llama_sampler_mirostat_v2 {
     std::vector<float> probs;
 };
 
-static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
-    /* .name   = */ [](const struct llama_sampler * /*smpl*/) { return "mirostat-v2"; },
-    /* .accept = */ nullptr,
-    /* .apply  = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
-        auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
+static const char * llama_sampler_mirostat_v2_name(const struct llama_sampler * /*smpl*/) {
+    return "mirostat-v2";
+}
 
-        llama_sampler_softmax_impl(cur_p);
+static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
 
-        // Truncate the words with surprise values greater than mu
-        cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) {
-            return -log2f(candidate.p) > ctx->mu;
-        }));
+    llama_sampler_softmax_impl(cur_p);
 
-        if (cur_p->size == 0) {
-            cur_p->size = 1;
-        }
+    // Truncate the words with surprise values greater than mu
+    cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) {
+        return -log2f(candidate.p) > ctx->mu;
+    }));
 
-        // Normalize the probabilities of the remaining words
-        llama_sampler_softmax_impl(cur_p);
+    if (cur_p->size == 0) {
+        cur_p->size = 1;
+    }
 
-        const int idx = llama_sample_dist(cur_p, ctx->rng, ctx->probs);
+    // Normalize the probabilities of the remaining words
+    llama_sampler_softmax_impl(cur_p);
 
-        cur_p->selected = idx;
+    const int idx = llama_sample_dist(cur_p, ctx->rng, ctx->probs);
 
-        float observed_surprise = -log2f(cur_p->data[idx].p);
-        float e = observed_surprise - ctx->tau;
+    cur_p->selected = idx;
 
-        // Update mu using the learning rate and error
-        ctx->mu = ctx->mu - ctx->eta * e;
-    },
-    /* .reset  = */ [](struct llama_sampler * smpl) {
-        auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
-        ctx->mu = 2.0f*ctx->tau;
-        ctx->rng = std::mt19937(ctx->seed);
-    },
-    /* .clone  = */ [](const struct llama_sampler * smpl) {
-        const auto * ctx = (const llama_sampler_mirostat_v2 *) smpl->ctx;
+    float observed_surprise = -log2f(cur_p->data[idx].p);
+    float e = observed_surprise - ctx->tau;
 
-        auto * result = llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta);
+    // Update mu using the learning rate and error
+    ctx->mu = ctx->mu - ctx->eta * e;
+}
 
-        // copy the state
-        {
-            auto * result_ctx = (llama_sampler_mirostat_v2 *) result->ctx;
+static void llama_sampler_mirostat_v2_reset(struct llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
+    ctx->mu = 2.0f*ctx->tau;
+    ctx->rng = std::mt19937(ctx->seed);
+}
 
-            result_ctx->mu  = ctx->mu;
-            result_ctx->rng = ctx->rng;
-        }
+static struct llama_sampler * llama_sampler_mirostat_v2_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_mirostat_v2 *) smpl->ctx;
 
-        return result;
-    },
-    /* .free   = */ [](struct llama_sampler * smpl) {
-        delete (llama_sampler_mirostat_v2 *) smpl->ctx;
-    },
+    auto * result = llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta);
+
+    // copy the state
+    {
+        auto * result_ctx = (llama_sampler_mirostat_v2 *) result->ctx;
+
+        result_ctx->mu  = ctx->mu;
+        result_ctx->rng = ctx->rng;
+    }
+
+    return result;
+}
+
+static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_mirostat_v2 *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
+    /* .name   = */ llama_sampler_mirostat_v2_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_mirostat_v2_apply,
+    /* .reset  = */ llama_sampler_mirostat_v2_reset,
+    /* .clone  = */ llama_sampler_mirostat_v2_clone,
+    /* .free   = */ llama_sampler_mirostat_v2_free,
 };
 
 struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
@@ -1124,59 +1223,73 @@ struct llama_sampler_grammar {
     struct llama_grammar * grammar;
 };
 
-static struct llama_sampler_i llama_sampler_grammar_i = {
-    /* .name   = */ [](const struct llama_sampler * /*smpl*/) { return "grammar"; },
-    /* .accept = */ [](struct llama_sampler * smpl, llama_token token) {
-        const auto * ctx = (llama_sampler_grammar *) smpl->ctx;
-        if (ctx->grammar) {
-            llama_grammar_accept_impl(*ctx->grammar, token);
-        }
-    },
-    /* .apply  = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
-        const auto * ctx = (llama_sampler_grammar *) smpl->ctx;
-        if (ctx->grammar) {
-            llama_sampler_grammar_impl(cur_p, *ctx->grammar);
-        }
-    },
-    /* .reset  = */ [](struct llama_sampler * smpl) {
-        auto * ctx = (llama_sampler_grammar *) smpl->ctx;
-        if (!ctx->grammar) {
-            return;
-        }
+static const char * llama_sampler_grammar_name(const struct llama_sampler * /*smpl*/) {
+    return "grammar";
+}
 
-        auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str());
+static void llama_sampler_grammar_accept_impl(struct llama_sampler * smpl, llama_token token) {
+    auto * ctx = (llama_sampler_grammar *) smpl->ctx;
+    if (ctx->grammar) {
+        llama_grammar_accept_impl(*ctx->grammar, token);
+    }
+}
 
-        llama_grammar_free_impl(ctx->grammar);
-        ctx->grammar = grammar_new;
-    },
-    /* .clone  = */ [](const struct llama_sampler * smpl) {
-        const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
+static void llama_sampler_grammar_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_grammar *) smpl->ctx;
+    if (ctx->grammar) {
+        llama_grammar_apply_impl(*ctx->grammar, cur_p);
+    }
+}
 
-        auto * result = llama_sampler_init_grammar_impl(*ctx->vocab, nullptr, nullptr);
+static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_grammar *) smpl->ctx;
+    if (!ctx->grammar) {
+        return;
+    }
 
-        // copy the state
-        {
-            auto * result_ctx = (llama_sampler_grammar *) result->ctx;
+    auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str());
 
-            if (ctx->grammar) {
-                result_ctx->grammar_str  = ctx->grammar_str;
-                result_ctx->grammar_root = ctx->grammar_root;
+    llama_grammar_free_impl(ctx->grammar);
+    ctx->grammar = grammar_new;
+}
 
-                result_ctx->grammar = llama_grammar_clone_impl(*ctx->grammar);
-            }
-        }
+static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
 
-        return result;
-    },
-    /* .free   = */ [](struct llama_sampler * smpl) {
-        const auto * ctx = (llama_sampler_grammar *) smpl->ctx;
+    auto * result = llama_sampler_init_grammar_impl(*ctx->vocab, nullptr, nullptr);
+
+    // copy the state
+    {
+        auto * result_ctx = (llama_sampler_grammar *) result->ctx;
 
         if (ctx->grammar) {
-            llama_grammar_free_impl(ctx->grammar);
+            result_ctx->grammar_str  = ctx->grammar_str;
+            result_ctx->grammar_root = ctx->grammar_root;
+
+            result_ctx->grammar = llama_grammar_clone_impl(*ctx->grammar);
         }
+    }
 
-        delete ctx;
-    },
+    return result;
+}
+
+static void llama_sampler_grammar_free(struct llama_sampler * smpl) {
+    const auto * ctx = (llama_sampler_grammar *) smpl->ctx;
+
+    if (ctx->grammar) {
+        llama_grammar_free_impl(ctx->grammar);
+    }
+
+    delete ctx;
+}
+
+static struct llama_sampler_i llama_sampler_grammar_i = {
+    /* .name   = */ llama_sampler_grammar_name,
+    /* .accept = */ llama_sampler_grammar_accept_impl,
+    /* .apply  = */ llama_sampler_grammar_apply,
+    /* .reset  = */ llama_sampler_grammar_reset,
+    /* .clone  = */ llama_sampler_grammar_clone,
+    /* .free   = */ llama_sampler_grammar_free,
 };
 
 struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) {
@@ -1222,106 +1335,144 @@ struct llama_sampler_penalties {
     ring_buffer<llama_token> prev;
 };
 
-static struct llama_sampler_i llama_sampler_penalties_i = {
-    /* .name   = */ [](const struct llama_sampler * /*smpl*/) { return "penalties"; },
-    /* .accept = */ [](struct llama_sampler * smpl, llama_token token) {
-        auto * ctx = (llama_sampler_penalties *) smpl->ctx;
-        if (ctx->prev.size()) {
-            ctx->prev.push_back(token);
-        }
-    },
-    /* .apply  = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
-        auto * ctx = (llama_sampler_penalties *) smpl->ctx;
-
-        if (ctx->ignore_eos) {
-            assert(ctx->special_eos_id >= 0);
-
-            // optimistically check if the candidates are not yet sorted/shuffled/truncated
-            if (cur_p->size > (size_t) ctx->special_eos_id && cur_p->data[ctx->special_eos_id].id == ctx->special_eos_id) {
-                cur_p->data[ctx->special_eos_id].logit = -INFINITY;
-            } else {
-                // else, search for the special EOS token
-                for (size_t i = 0; i < cur_p->size; ++i) {
-                    if (cur_p->data[i].id == ctx->special_eos_id) {
-                        cur_p->data[i].logit = -INFINITY;
-                        break;
-                    }
+static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) {
+    return "penalties";
+}
+
+static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_token token) {
+    auto * ctx = (llama_sampler_penalties *) smpl->ctx;
+    if (ctx->penalty_last_n == 0) {
+        return;
+    }
+
+    ctx->prev.push_back(token);
+}
+
+static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_penalties *) smpl->ctx;
+
+    if (ctx->ignore_eos) {
+        assert(ctx->special_eos_id >= 0);
+
+        // optimistically check if the candidates are not yet sorted/shuffled/truncated
+        if (cur_p->size > (size_t) ctx->special_eos_id && cur_p->data[ctx->special_eos_id].id == ctx->special_eos_id) {
+            cur_p->data[ctx->special_eos_id].logit = -INFINITY;
+        } else {
+            // else, search for the special EOS token
+            for (size_t i = 0; i < cur_p->size; ++i) {
+                if (cur_p->data[i].id == ctx->special_eos_id) {
+                    cur_p->data[i].logit = -INFINITY;
+                    break;
                 }
             }
         }
+    }
 
-        if ((ctx->penalty_last_n == 0) ||
-            (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
-            return;
-        }
+    if ((ctx->penalty_last_n == 0) ||
+        (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
+        return;
+    }
 
-        bool nl_found = false;
-        size_t nl_idx = 0;
-        float nl_logit = -INFINITY;
-        if (!ctx->penalize_nl) {
-            assert(ctx->linefeed_id >= 0);
-
-            // optimistically check if the candidates are not yet sorted/shuffled/truncated
-            if (cur_p->size > (size_t) ctx->linefeed_id && cur_p->data[ctx->linefeed_id].id == ctx->linefeed_id) {
-                nl_found = true;
-                nl_idx = ctx->linefeed_id;
-                nl_logit = cur_p->data[ctx->linefeed_id].logit;
-            } else {
-                // else, search for the linefeed token
-                for (size_t i = 0; i < cur_p->size; ++i) {
-                    if (cur_p->data[i].id == ctx->linefeed_id) {
-                        nl_found = true;
-                        nl_idx = i;
-                        nl_logit = cur_p->data[i].logit;
-                        break;
-                    }
+    bool nl_found = false;
+    size_t nl_idx = 0;
+    float nl_logit = -INFINITY;
+    if (!ctx->penalize_nl) {
+        assert(ctx->linefeed_id >= 0);
+
+        // optimistically check if the candidates are not yet sorted/shuffled/truncated
+        if (cur_p->size > (size_t) ctx->linefeed_id && cur_p->data[ctx->linefeed_id].id == ctx->linefeed_id) {
+            nl_found = true;
+            nl_idx = ctx->linefeed_id;
+            nl_logit = cur_p->data[ctx->linefeed_id].logit;
+        } else {
+            // else, search for the linefeed token
+            for (size_t i = 0; i < cur_p->size; ++i) {
+                if (cur_p->data[i].id == ctx->linefeed_id) {
+                    nl_found = true;
+                    nl_idx = i;
+                    nl_logit = cur_p->data[i].logit;
+                    break;
                 }
             }
         }
+    }
 
-        // Create a frequency map to count occurrences of each token in last_tokens
-        // TODO: optimize this by maintaining the token count in the sampler context
-        llama_token_cnt token_count;
-        for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
-            token_count[ctx->prev.rat(i)]++;
-        }
+    // Create a frequency map to count occurrences of each token in last_tokens
+    // TODO: optimize this by maintaining the token count in the sampler context
+    using llama_token_cnt = std::unordered_map<llama_token, int>;
+    llama_token_cnt token_count;
 
-        llama_sampler_penalties_impl(cur_p, token_count, ctx->penalty_repeat, ctx->penalty_freq, ctx->penalty_present);
+    for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
+        token_count[ctx->prev.rat(i)]++;
+    }
 
-        if (!ctx->penalize_nl && nl_found) {
-            // restore the logit of the newline token if it was penalized
-            cur_p->data[nl_idx].logit = nl_logit;
+    // Apply frequency and presence penalties to the cur_p
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        const auto token_iter = token_count.find(cur_p->data[i].id);
+        if (token_iter == token_count.end()) {
+            continue;
         }
-    },
-    /* .reset  = */ [](struct llama_sampler * smpl) {
-        auto * ctx = (llama_sampler_penalties *) smpl->ctx;
-        ctx->prev.clear();
-    },
-    /* .clone  = */ [](const struct llama_sampler * smpl) {
-        const auto * ctx = (const llama_sampler_penalties *) smpl->ctx;
-        auto * result = llama_sampler_init_penalties(
-                ctx->n_vocab,
-                ctx->special_eos_id,
-                ctx->linefeed_id,
-                ctx->penalty_last_n,
-                ctx->penalty_repeat,
-                ctx->penalty_freq,
-                ctx->penalty_present,
-                ctx->penalize_nl,
-                ctx->ignore_eos);
-
-        // copy the state
-        {
-            auto * result_ctx = (llama_sampler_penalties *) result->ctx;
-
-            result_ctx->prev = ctx->prev;
+
+        const int count = token_iter->second;
+
+        // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
+        // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
+        if (cur_p->data[i].logit <= 0) {
+            cur_p->data[i].logit *= ctx->penalty_repeat;
+        } else {
+            cur_p->data[i].logit /= ctx->penalty_repeat;
         }
 
-        return result;
-    },
-    /* .free   = */ [](struct llama_sampler * smpl) {
-        delete (llama_sampler_penalties *) smpl->ctx;
-    },
+        cur_p->data[i].logit -= float(count) * ctx->penalty_freq + float(count > 0) * ctx->penalty_present;
+    }
+
+    cur_p->sorted = false;
+
+    if (!ctx->penalize_nl && nl_found) {
+        // restore the logit of the newline token if it was penalized
+        cur_p->data[nl_idx].logit = nl_logit;
+    }
+}
+
+static void llama_sampler_penalties_reset(struct llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_penalties *) smpl->ctx;
+    ctx->prev.clear();
+}
+
+static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_penalties *) smpl->ctx;
+    auto * result = llama_sampler_init_penalties(
+            ctx->n_vocab,
+            ctx->special_eos_id,
+            ctx->linefeed_id,
+            ctx->penalty_last_n,
+            ctx->penalty_repeat,
+            ctx->penalty_freq,
+            ctx->penalty_present,
+            ctx->penalize_nl,
+            ctx->ignore_eos);
+
+    // copy the state
+    {
+        auto * result_ctx = (llama_sampler_penalties *) result->ctx;
+
+        result_ctx->prev = ctx->prev;
+    }
+
+    return result;
+}
+
+static void llama_sampler_penalties_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_penalties *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_penalties_i = {
+    /* .name   = */ llama_sampler_penalties_name,
+    /* .accept = */ llama_sampler_penalties_accept,
+    /* .apply  = */ llama_sampler_penalties_apply,
+    /* .reset  = */ llama_sampler_penalties_reset,
+    /* .clone  = */ llama_sampler_penalties_clone,
+    /* .free   = */ llama_sampler_penalties_free,
 };
 
 struct llama_sampler * llama_sampler_init_penalties(
@@ -1335,11 +1486,11 @@ struct llama_sampler * llama_sampler_init_penalties(
         bool penalize_nl,
         bool ignore_eos) {
     if (linefeed_id == LLAMA_TOKEN_NULL) {
-        penalize_nl = false;
+        penalize_nl = true;
     }
 
     if (special_eos_id == LLAMA_TOKEN_NULL) {
-        ignore_eos = true;
+        ignore_eos = false;
     }
 
     return new llama_sampler {
@@ -1369,41 +1520,50 @@ struct llama_sampler_logit_bias {
     std::vector<llama_logit_bias> to_search;
 };
 
-static struct llama_sampler_i llama_sampler_logit_bias_i = {
-    /* .name   = */ [](const struct llama_sampler * /*smpl*/) { return "logit-bias"; },
-    /* .accept = */ nullptr,
-    /* .apply  = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
-        auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;
+static const char * llama_sampler_logit_bias_name(const struct llama_sampler * /*smpl*/) {
+    return "logit-bias";
+}
 
-        ctx->to_search.clear();
+static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;
 
-        // update the candidates that have not been shuffled in the vocabulary (i.e. idx == id)
-        for (const auto & lb : ctx->logit_bias) {
-            if (lb.token >= 0 && cur_p->size > (size_t) lb.token && cur_p->data[lb.token].id == lb.token) {
-                cur_p->data[lb.token].logit += lb.bias;
-            } else {
-                ctx->to_search.push_back(lb);
-            }
+    ctx->to_search.clear();
+
+    // update the candidates that have not been shuffled in the vocabulary (i.e. idx == id)
+    for (const auto & lb : ctx->logit_bias) {
+        if (lb.token >= 0 && cur_p->size > (size_t) lb.token && cur_p->data[lb.token].id == lb.token) {
+            cur_p->data[lb.token].logit += lb.bias;
+        } else {
+            ctx->to_search.push_back(lb);
         }
+    }
 
-        // search for the remaining candidates that were not found in the previous step
-        for (size_t i = 0; i < cur_p->size; ++i) {
-            for (const auto & lb : ctx->to_search) {
-                if (cur_p->data[i].id == lb.token) {
-                    cur_p->data[i].logit += lb.bias;
-                    break;
-                }
+    // search for the remaining candidates that were not found in the previous step
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        for (const auto & lb : ctx->to_search) {
+            if (cur_p->data[i].id == lb.token) {
+                cur_p->data[i].logit += lb.bias;
+                break;
             }
         }
-    },
+    }
+}
+static struct llama_sampler * llama_sampler_logit_bias_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx;
+    return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data());
+}
+
+static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_logit_bias *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_logit_bias_i = {
+    /* .name   = */ llama_sampler_logit_bias_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_logit_bias_apply,
     /* .reset  = */ nullptr,
-    /* .clone  = */ [](const struct llama_sampler * smpl) {
-        const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx;
-        return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data());
-    },
-    /* .free   = */ [](struct llama_sampler * smpl) {
-        delete (llama_sampler_logit_bias *) smpl->ctx;
-    },
+    /* .clone  = */ llama_sampler_logit_bias_clone,
+    /* .free   = */ llama_sampler_logit_bias_free,
 };
 
 struct llama_sampler * llama_sampler_init_logit_bias(
index 137c0025ce0d89ea0aa765ae4f3b57604907062b..d90b147130e4b10d691d797dc99137569ca22a6e 100644 (file)
@@ -23,16 +23,6 @@ struct llama_sampler_chain {
     mutable int32_t n_sample;
 };
 
-using llama_token_cnt = std::unordered_map<llama_token, int>;
-
-// TODO: tmp exposed until test-sampling is fixed
-void llama_sampler_penalties_impl(
-       llama_token_data_array * cur_p,
-        const llama_token_cnt & token_count,
-                        float   penalty_repeat,
-                        float   penalty_freq,
-                        float   penalty_present);
-
 struct llama_sampler * llama_sampler_init_grammar_impl(
         const struct llama_vocab & vocab,
                       const char * grammar_str,
index cc4882d37579a6327c07256dbca362304fd9795d..37400c179e9bdd2a6c96e61d0cfdfbe38fa5e2c7 100644 (file)
@@ -148,15 +148,17 @@ static void test_penalties(
         cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
     }
 
-    llama_token_cnt token_count;
+    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+
+    auto * sampler = llama_sampler_init_penalties(n_vocab, LLAMA_TOKEN_NULL, LLAMA_TOKEN_NULL, last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, false, false);
+
     for (size_t i = 0; i < last_tokens.size(); i++) {
-        token_count[last_tokens[i]]++;
+        llama_sampler_accept(sampler, last_tokens[i]);
     }
 
-    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
     APPLY(llama_sampler_init_softmax(), &cur_p);
     DUMP(&cur_p);
-    llama_sampler_penalties_impl(&cur_p, token_count, repeat_penalty, alpha_frequency, alpha_presence); // TODO: avoid
+    APPLY(sampler, &cur_p);
     APPLY(llama_sampler_init_softmax(), &cur_p);
     DUMP(&cur_p);