}
// apply penalties
- if (!prev.empty()) {
+ const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
+ const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
+ if (penalty_tokens_used_size) {
const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
llama_sample_repetition_penalties(ctx_main, &cur_p,
- prev.data() + prev.size() - penalty_last_n,
- penalty_last_n, penalty_repeat, penalty_freq, penalty_present);
+ penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
+ penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
if (!penalize_nl) {
for (size_t idx = 0; idx < cur_p.size; idx++) {
slot->prompt = "";
}
+ slot->sparams.penalty_prompt_tokens.clear();
+ slot->sparams.use_penalty_prompt_tokens = false;
+ const auto &penalty_prompt = data.find("penalty_prompt");
+ if (penalty_prompt != data.end())
+ {
+ if (penalty_prompt->is_string())
+ {
+ const auto penalty_prompt_string = penalty_prompt->get<std::string>();
+ auto penalty_tokens = llama_tokenize(model, penalty_prompt_string, false);
+ slot->sparams.penalty_prompt_tokens.swap(penalty_tokens);
+ if (slot->params.n_predict > 0)
+ {
+ slot->sparams.penalty_prompt_tokens.reserve(slot->sparams.penalty_prompt_tokens.size() + slot->params.n_predict);
+ }
+ slot->sparams.use_penalty_prompt_tokens = true;
+ }
+ else if (penalty_prompt->is_array())
+ {
+ const auto n_tokens = penalty_prompt->size();
+ slot->sparams.penalty_prompt_tokens.reserve(n_tokens + std::max(0, slot->params.n_predict));
+ const int n_vocab = llama_n_vocab(model);
+ for (const auto &penalty_token : *penalty_prompt)
+ {
+ if (penalty_token.is_number_integer())
+ {
+ const auto tok = penalty_token.get<llama_token>();
+ if (tok >= 0 && tok < n_vocab)
+ {
+ slot->sparams.penalty_prompt_tokens.push_back(tok);
+ }
+ }
+ }
+ slot->sparams.use_penalty_prompt_tokens = true;
+ }
+ }
+
slot->sparams.logit_bias.clear();
if (json_value(data, "ignore_eos", false))
slot.generated_text += token_str;
slot.has_next_token = true;
+ if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1)
+ {
+ // appending to penalty_prompt_tokens here is safe because the vector is rebuilt from scratch for every request
+ slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok);
+ }
+
// check if there is incomplete UTF-8 character at the end
bool incomplete = false;
for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i)
{"repeat_penalty", slot.sparams.penalty_repeat},
{"presence_penalty", slot.sparams.penalty_present},
{"frequency_penalty", slot.sparams.penalty_freq},
+ {"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens},
+ {"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens},
{"mirostat", slot.sparams.mirostat},
{"mirostat_tau", slot.sparams.mirostat_tau},
{"mirostat_eta", slot.sparams.mirostat_eta},