bool is_resampling) { // whether this is a resampling pass after a grammar rejection
const llama_sampling_params & params = ctx_sampling->params;
- const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
-
const float temp = params.temp;
- const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
- const float penalty_repeat = params.penalty_repeat;
- const float penalty_freq = params.penalty_freq;
- const float penalty_present = params.penalty_present;
const int mirostat = params.mirostat;
const float mirostat_tau = params.mirostat_tau;
const float mirostat_eta = params.mirostat_eta;
- const bool penalize_nl = params.penalize_nl;
- auto & prev = ctx_sampling->prev;
- auto & cur = ctx_sampling->cur;
-
- llama_token id = 0;
-
- // Get a pointer to the logits
- float * logits = llama_get_logits_ith(ctx_main, idx);
-
- // Declare original_logits at the beginning of the function scope
std::vector<float> original_logits;
-
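+ // Build the candidate list in one call: llama_sampling_prepare applies the
+ // logit bias, the repetition penalties and (when not resampling) the grammar,
+ // and hands the pre-bias logits back through original_logits.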
+ auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, !is_resampling, &original_logits);
if (!is_resampling) {
- // Only make a copy of the original logits if we are not in the resampling phase, not sure if I actually have to do this.
- original_logits = std::vector<float>(logits, logits + llama_n_vocab(llama_get_model(ctx_main)));
- }
-
- // apply params.logit_bias map
- for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
- logits[it->first] += it->second;
- }
-
- if (ctx_cfg) {
- float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
- llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
- }
-
- cur.clear();
-
- for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
- cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
- }
-
- llama_token_data_array cur_p = { cur.data(), cur.size(), false };
-
- // apply penalties
- const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
- const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
- if (penalty_tokens_used_size) {
- const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
-
- llama_sample_repetition_penalties(ctx_main, &cur_p,
- penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
- penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
-
- if (!penalize_nl) {
- for (size_t idx = 0; idx < cur_p.size; idx++) {
- if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
- cur_p.data[idx].logit = nl_logit;
- break;
- }
- }
- }
- }
-
- // If we are in the resampling phase, apply grammar checks before sampling logic
- if (is_resampling && ctx_sampling->grammar != NULL) {
- llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
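+ // On a first pass, the prepare call above must have captured the original
+ // logits so they can be restored for a resample after a grammar rejection.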
+ GGML_ASSERT(!original_logits.empty());
}
+ llama_token id = 0;
+ // Get a pointer to the logits
+ float * logits = llama_get_logits_ith(ctx_main, idx);
if (temp < 0.0) {
    // greedy sampling, with probs
    llama_sample_softmax(ctx_main, &cur_p);
    id = cur_p.data[0].id;
}

return id;
}
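+// Prepares the token candidates for sampling: optionally copies the original
+// logits, applies the logit bias map and the repetition penalties, and applies
+// the grammar when apply_grammar is set.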
-static llama_token_data_array llama_sample_probability_distribution_impl(
+static llama_token_data_array llama_sampling_prepare_impl(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
- const int idx) {
+ const int idx,
+ bool apply_grammar,
+ std::vector<float> * original_logits) {
const llama_sampling_params & params = ctx_sampling->params;
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
const float penalty_repeat = params.penalty_repeat;
const float penalty_freq = params.penalty_freq;
const float penalty_present = params.penalty_present;
+
const bool penalize_nl = params.penalize_nl;
auto & prev = ctx_sampling->prev;
// Get a pointer to the logits
float * logits = llama_get_logits_ith(ctx_main, idx);
- // Declare original_logits at the beginning of the function scope
- std::vector<float> original_logits;
+ if (apply_grammar && original_logits != NULL) {
+ // Copy the original logits only when grammar checks will be applied, so the
+ // caller can restore them if the sampled token is later rejected by the grammar.
+ *original_logits = {logits, logits + n_vocab};
+ }
// apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
    logits[it->first] += it->second;
}
}
- // apply grammar checks
- if (ctx_sampling->grammar != NULL) {
+ // apply grammar checks before sampling logic
+ if (apply_grammar && ctx_sampling->grammar != NULL) {
llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
}
- llama_sample_softmax(ctx_main, &cur_p);
return cur_p;
}
return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
}
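+// Public wrapper around llama_sampling_prepare_impl; original_logits may be
+// NULL when the caller does not need the pre-bias logits back.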
-llama_token_data_array llama_sampling_probability_distribution(
+llama_token_data_array llama_sampling_prepare(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
- const int idx) {
- return llama_sample_probability_distribution_impl(ctx_sampling,ctx_main, ctx_cfg, idx);
+ const int idx,
+ bool apply_grammar,
+ std::vector<float> * original_logits) {
+ return llama_sampling_prepare_impl(ctx_sampling, ctx_main, ctx_cfg, idx, apply_grammar, original_logits);
}
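+// Example (hypothetical caller; ctx_sampling, ctx_main and idx are assumed to
+// be set up elsewhere): prepare candidates with the grammar applied and keep
+// the original logits for a later resample:
+//
+//     std::vector<float> original_logits;
+//     llama_token_data_array cur_p = llama_sampling_prepare(
+//         ctx_sampling, ctx_main, /*ctx_cfg=*/NULL, idx, /*apply_grammar=*/true, &original_logits);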
void llama_sampling_accept(