From: Georgi Gerganov
Date: Fri, 20 Oct 2023 06:57:04 +0000 (+0300)
Subject: gpt-2 : allow setting custom context size (i.e. large KV cache)
X-Git-Tag: upstream/0.0.1642~1214
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=6a4e20f53037a61a2fa32cd462bb86dda4e7e99a;p=pkg%2Fggml%2Fsources%2Fggml

gpt-2 : allow setting custom context size (i.e. large KV cache)
---

diff --git a/examples/common.cpp b/examples/common.cpp
index 42017055..d55708ad 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -58,6 +58,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.repeat_penalty = std::stof(get_next_arg(i, argc, argv, arg, params));
         } else if (arg == "-b" || arg == "--batch_size") {
             params.n_batch= std::stoi(get_next_arg(i, argc, argv, arg, params));
+        } else if (arg == "-c" || arg == "--context") {
+            params.n_ctx= std::stoi(get_next_arg(i, argc, argv, arg, params));
         } else if (arg == "-m" || arg == "--model") {
             params.model = get_next_arg(i, argc, argv, arg, params);
         } else if (arg == "-i" || arg == "--interactive") {
@@ -113,6 +115,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "  --repeat-last-n N     last n tokens to consider for penalize (default: %d, 0 = disabled)\n", params.repeat_last_n);
     fprintf(stderr, "  --repeat-penalty N    penalize repeat sequence of tokens (default: %.2f, 1.0 = disabled)\n", (double)params.repeat_penalty);
     fprintf(stderr, "  -b N, --batch_size N  batch size for prompt processing (default: %d)\n", params.n_batch);
+    fprintf(stderr, "  -c N, --context N     context / KV cache size (default: %d)\n", params.n_ctx);
     fprintf(stderr, "  -m FNAME, --model FNAME\n");
     fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
     fprintf(stderr, "\n");
diff --git a/examples/common.h b/examples/common.h
index ac059f3b..635afc24 100644
--- a/examples/common.h
+++ b/examples/common.h
@@ -15,11 +15,12 @@
 //
 
 struct gpt_params {
-    int32_t seed = -1; // RNG seed
+    int32_t seed = -1;        // RNG seed
     int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
-    int32_t n_predict = 200; // new tokens to predict
-    int32_t n_parallel = 1; // number of parallel streams
-    int32_t n_batch = 8; // batch size for prompt processing
+    int32_t n_predict = 200;  // new tokens to predict
+    int32_t n_parallel = 1;   // number of parallel streams
+    int32_t n_batch = 8;      // batch size for prompt processing
+    int32_t n_ctx = 2048;     // context size (this is the KV cache max size)
 
     // sampling parameters
     int32_t top_k = 40;
diff --git a/examples/gpt-2/main-batched.cpp b/examples/gpt-2/main-batched.cpp
index 119df33f..1c499bd5 100644
--- a/examples/gpt-2/main-batched.cpp
+++ b/examples/gpt-2/main-batched.cpp
@@ -144,7 +144,7 @@ struct gpt2_batch {
 };
 
 // load the model's weights from a file
-bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab, int n_gpu_layers) {
+bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab, int n_ctx, int n_gpu_layers) {
     printf("%s: loading model from '%s'\n", __func__, fname.c_str());
 
     auto fin = std::ifstream(fname, std::ios::binary);
@@ -386,6 +386,9 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
         }
     }
 
+    // override the default training context with the user-provided
+    model.hparams.n_ctx = n_ctx;
+
     // key + value memory
     {
         const auto & hparams = model.hparams;
@@ -1013,7 +1016,7 @@ int main(int argc, char ** argv) {
     {
         const int64_t t_start_us = ggml_time_us();
 
-        if (!gpt2_model_load(params.model, model, vocab, params.n_gpu_layers)) {
+        if (!gpt2_model_load(params.model, model, vocab, params.n_ctx, params.n_gpu_layers)) {
             fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
             return 1;
         }
diff --git a/examples/gpt-2/main.cpp b/examples/gpt-2/main.cpp
index 0acb3a1b..5702fd8e 100644
--- a/examples/gpt-2/main.cpp
+++ b/examples/gpt-2/main.cpp
@@ -96,7 +96,7 @@ struct gpt2_model {
 };
 
 // load the model's weights from a file
-bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab, int n_gpu_layers) {
+bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab, int n_ctx, int n_gpu_layers) {
     printf("%s: loading model from '%s'\n", __func__, fname.c_str());
 
     auto fin = std::ifstream(fname, std::ios::binary);
@@ -338,6 +338,9 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
         }
     }
 
+    // override the default training context with the user-provided
+    model.hparams.n_ctx = n_ctx;
+
     // key + value memory
     {
         const auto & hparams = model.hparams;
@@ -859,7 +862,7 @@ int main(int argc, char ** argv) {
     {
         const int64_t t_start_us = ggml_time_us();
 
-        if (!gpt2_model_load(params.model, model, vocab, params.n_gpu_layers)) {
+        if (!gpt2_model_load(params.model, model, vocab, params.n_ctx, params.n_gpu_layers)) {
             fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
             return 1;
         }
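
For context (not part of the patch): the KV cache sized by the new -c / --context option grows linearly with n_ctx, since the model keeps one key and one value vector of n_embd elements per layer for every position in the context window. Below is a minimal standalone sketch of that footprint, assuming F32 key/value storage and GPT-2 small hyperparameters (n_layer = 12, n_embd = 768); the function name and the chosen values are purely illustrative, not taken from the patch.

#include <cstdint>
#include <cstdio>

// Rough KV cache footprint: one key and one value vector of size n_embd
// per layer, for every position in the context window (illustrative sketch).
static int64_t kv_cache_bytes(int64_t n_ctx, int64_t n_layer, int64_t n_embd, int64_t bytes_per_elem) {
    return 2 /* K and V */ * n_layer * n_ctx * n_embd * bytes_per_elem;
}

int main() {
    const int64_t n_layer = 12;  // GPT-2 small (assumed for illustration)
    const int64_t n_embd  = 768; // GPT-2 small (assumed for illustration)

    for (int64_t n_ctx : {1024, 2048, 8192}) {
        const double mib = kv_cache_bytes(n_ctx, n_layer, n_embd, 4 /* F32 */) / (1024.0 * 1024.0);
        printf("n_ctx = %5lld -> KV cache ~ %6.0f MiB\n", (long long) n_ctx, mib);
    }

    return 0;
}

With the default n_ctx = 2048 this works out to about 144 MiB, and an invocation along the lines of gpt-2 -m <model> -c 8192 (binary name and model path illustrative) would grow the cache to roughly 576 MiB.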