params.repeat_penalty = std::stof(get_next_arg(i, argc, argv, arg, params));
} else if (arg == "-b" || arg == "--batch_size") {
params.n_batch = std::stoi(get_next_arg(i, argc, argv, arg, params));
+ } else if (arg == "-c" || arg == "--context") {
+ params.n_ctx = std::stoi(get_next_arg(i, argc, argv, arg, params));
} else if (arg == "-m" || arg == "--model") {
params.model = get_next_arg(i, argc, argv, arg, params);
} else if (arg == "-i" || arg == "--interactive") {
fprintf(stderr, " --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled)\n", params.repeat_last_n);
fprintf(stderr, " --repeat-penalty N penalize repeat sequence of tokens (default: %.2f, 1.0 = disabled)\n", (double)params.repeat_penalty);
fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);
+ fprintf(stderr, " -c N, --context N context / KV cache size (default: %d)\n", params.n_ctx);
fprintf(stderr, " -m FNAME, --model FNAME\n");
fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
fprintf(stderr, "\n");
//
struct gpt_params {
int32_t seed       = -1;  // RNG seed
int32_t n_threads  = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t n_predict  = 200; // new tokens to predict
int32_t n_parallel = 1;   // number of parallel streams
int32_t n_batch    = 8;   // batch size for prompt processing
+ int32_t n_ctx    = 2048; // context size (this is the KV cache max size)
// sampling parameters
int32_t top_k = 40;
};
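With n_ctx available in gpt_params, the generation side can also make sure the prompt plus the requested new tokens never overflow the KV cache. A sketch of such a guard; clamp_n_predict is a hypothetical helper, not part of this patch:

#include <algorithm>
#include <cstdint>

// Hypothetical helper (not in the patch): cap the number of tokens to
// generate so prompt + generated tokens fit in the n_ctx-sized KV cache,
// clamping to zero when the prompt alone already fills the cache.
static int32_t clamp_n_predict(int32_t n_predict, int32_t n_ctx, int32_t n_prompt) {
    return std::max<int32_t>(0, std::min(n_predict, n_ctx - n_prompt));
}

It would be called before the main loop, e.g. params.n_predict = clamp_n_predict(params.n_predict, params.n_ctx, (int32_t) embd_inp.size()), where embd_inp holds the tokenized prompt.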
// load the model's weights from a file
-bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab, int n_gpu_layers) {
+bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab, int n_ctx, int n_gpu_layers) {
printf("%s: loading model from '%s'\n", __func__, fname.c_str());
auto fin = std::ifstream(fname, std::ios::binary);
}
}
+ // override the default training context with the user-provided value
+ model.hparams.n_ctx = n_ctx;
+
// key + value memory
{
const auto & hparams = model.hparams;
{
const int64_t t_start_us = ggml_time_us();
- if (!gpt2_model_load(params.model, model, vocab, params.n_gpu_layers)) {
+ if (!gpt2_model_load(params.model, model, vocab, params.n_ctx, params.n_gpu_layers)) {
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
return 1;
}
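With the loader signature and its call site both updated, -c now controls the KV cache size directly: the model's training-time n_ctx is replaced by the user's value. Note the override is unconditional, so even without -c the cache is sized to the 2048 default rather than whatever context the checkpoint was trained with.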