params.top_p = std::stof(argv[++i]);
} else if (arg == "--temp") {
params.temp = std::stof(argv[++i]);
+ } else if (arg == "--repeat-last-n") {
+ params.repeat_last_n = std::stoi(argv[++i]);
+ } else if (arg == "--repeat-penalty") {
+ params.repeat_penalty = std::stof(argv[++i]);
} else if (arg == "-b" || arg == "--batch_size") {
params.n_batch = std::stoi(argv[++i]);
} else if (arg == "-m" || arg == "--model") {
fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k);
fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", params.top_p);
fprintf(stderr, " --temp N temperature (default: %.1f)\n", params.temp);
+ fprintf(stderr, " --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled)\n", params.repeat_last_n);
+ fprintf(stderr, " --repeat-penalty N penalize repeat sequence of tokens (default: %.2f, 1.0 = disabled)\n", (double)params.repeat_penalty);
fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);
fprintf(stderr, " -m FNAME, --model FNAME\n");
fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
test_gpt_tokenizer(vocab, params.token_test);
}
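+ // repeat_last_n == -1 means: use the full context size as the penalty window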
+ if (params.repeat_last_n == -1) {
+ params.repeat_last_n = model.hparams.n_ctx;
+ }
+ printf("\n");
+ printf("%s: temp = %.3f\n", __func__, params.temp);
+ printf("%s: top_k = %d\n", __func__, params.top_k);
+ printf("%s: top_p = %.3f\n", __func__, params.top_p);
+ printf("%s: repeat_last_n = %d\n", __func__, params.repeat_last_n);
+ printf("%s: repeat_penalty = %.3f\n", __func__, params.repeat_penalty);
+
int n_past = 0;
int64_t t_sample_us = 0;
std::vector<float> logits;
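+ // sliding window of the last n_ctx tokens, consulted when applying the repeat penalty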
+ std::vector<int32_t> last_n_tokens(model.hparams.n_ctx);
+ std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
+
// tokenize the prompt
std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);
{
const int64_t t_start_sample_us = ggml_time_us();
- id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng);
-
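+ // sample the next token, penalizing tokens already present in last_n_tokens
+ // (a repeat_penalty of 1.0 leaves the distribution unchanged)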
+ id = gpt_sample_top_k_top_p_repeat(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens.data(), last_n_tokens.size(), top_k, top_p, temp, params.repeat_last_n, params.repeat_penalty, rng);
t_sample_us += ggml_time_us() - t_start_sample_us;
}
// add it to the context
embd.push_back(id);
+
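+ // advance the window: drop the oldest token and record the one just sampled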
+ last_n_tokens.erase(last_n_tokens.begin());
+ last_n_tokens.push_back(id);
} else {
// if here, it means we are still processing the input prompt
for (int k = i; k < embd_inp.size(); k++) {
embd.push_back(embd_inp[k]);
+
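+ // prompt tokens are recorded too, so they also count towards the repeat penalty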
+ last_n_tokens.erase(last_n_tokens.begin());
+ last_n_tokens.push_back(embd_inp[k]);
+
if (embd.size() >= params.n_batch) {
break;
}