}
}
-llama_token llama_sampling_sample(
+static llama_token llama_sampling_sample_impl(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
- const int idx) {
+ const int idx,
+ bool is_resampling) { // Add a parameter to indicate if we are resampling
const llama_sampling_params & params = ctx_sampling->params;
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
llama_token id = 0;
+ // Get a pointer to the logits
float * logits = llama_get_logits_ith(ctx_main, idx);
+ // Declare original_logits at the beginning of the function scope
+ std::vector<float> original_logits;
+
+ if (!is_resampling) {
+ // Only make a copy of the original logits if we are not in the resampling phase, not sure if I actually have to do this.
+ original_logits = std::vector<float>(logits, logits + llama_n_vocab(llama_get_model(ctx_main)));
+ }
+
// apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
logits[it->first] += it->second;
}
}
- if (ctx_sampling->grammar != NULL) {
+ // If we are in the resampling phase, apply grammar checks before sampling logic
+ if (is_resampling && ctx_sampling->grammar != NULL) {
llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
}
}
}
+ if (ctx_sampling->grammar != NULL && !is_resampling) {
+ // Create an array with a single token data element for the sampled id
+ llama_token_data single_token_data = {id, logits[id], 0.0f};
+ llama_token_data_array single_token_data_array = { &single_token_data, 1, false };
+
+ // Apply grammar constraints to the single token
+ llama_sample_grammar(ctx_main, &single_token_data_array, ctx_sampling->grammar);
+
+ // Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
+ bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
+
+ // If the token is not valid according to the grammar, perform resampling
+ if (!is_valid) {
+ LOG("Resampling because token %d: '%s' does not meet grammar rules\n", id, llama_token_to_piece(ctx_main, id).c_str());
+
+ // Restore logits from the copy
+ std::copy(original_logits.begin(), original_logits.end(), logits);
+
+ return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, true); // Pass true for is_resampling
+ }
+ }
+
return id;
}
+llama_token llama_sampling_sample(
+ struct llama_sampling_context * ctx_sampling,
+ struct llama_context * ctx_main,
+ struct llama_context * ctx_cfg,
+ const int idx) {
+ // Call the implementation function with is_resampling set to false by default
+ return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
+}
+
void llama_sampling_accept(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,