]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
gpt-2 : allow setting custom context size (i.e. large KV cache)
authorGeorgi Gerganov <redacted>
Fri, 20 Oct 2023 06:57:04 +0000 (09:57 +0300)
committerGeorgi Gerganov <redacted>
Fri, 20 Oct 2023 06:57:04 +0000 (09:57 +0300)
examples/common.cpp
examples/common.h
examples/gpt-2/main-batched.cpp
examples/gpt-2/main.cpp

index 420170551145dc95f958fc648c27231f5a5e3872..d55708ad1bec29112feba2aaeef42b6c106dfe6d 100644 (file)
@@ -58,6 +58,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.repeat_penalty = std::stof(get_next_arg(i, argc, argv, arg, params));
         } else if (arg == "-b" || arg == "--batch_size") {
             params.n_batch= std::stoi(get_next_arg(i, argc, argv, arg, params));
+        } else if (arg == "-c" || arg == "--context") {
+            params.n_ctx= std::stoi(get_next_arg(i, argc, argv, arg, params));
         } else if (arg == "-m" || arg == "--model") {
             params.model = get_next_arg(i, argc, argv, arg, params);
         } else if (arg == "-i" || arg == "--interactive") {
@@ -113,6 +115,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "  --repeat-last-n N     last n tokens to consider for penalize (default: %d, 0 = disabled)\n", params.repeat_last_n);
     fprintf(stderr, "  --repeat-penalty N    penalize repeat sequence of tokens (default: %.2f, 1.0 = disabled)\n", (double)params.repeat_penalty);
     fprintf(stderr, "  -b N, --batch_size N  batch size for prompt processing (default: %d)\n", params.n_batch);
+    fprintf(stderr, "  -c N, --context N     context / KV cache size (default: %d)\n", params.n_ctx);
     fprintf(stderr, "  -m FNAME, --model FNAME\n");
     fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
     fprintf(stderr, "\n");
index ac059f3b79d63db80339cbe148bb21928c476356..635afc24f55e7e94519fd43e6b5f7f6fa5f3deab 100644 (file)
 //
 
 struct gpt_params {
-    int32_t seed       = -1;  // RNG seed
+    int32_t seed       = -1;   // RNG seed
     int32_t n_threads  = std::min(4, (int32_t) std::thread::hardware_concurrency());
-    int32_t n_predict  = 200; // new tokens to predict
-    int32_t n_parallel = 1;   // number of parallel streams
-    int32_t n_batch    = 8;   // batch size for prompt processing
+    int32_t n_predict  = 200;  // new tokens to predict
+    int32_t n_parallel = 1;    // number of parallel streams
+    int32_t n_batch    = 8;    // batch size for prompt processing
+    int32_t n_ctx      = 2048; // context size (this is the KV cache max size)
 
     // sampling parameters
     int32_t top_k          = 40;
index 119df33f52a0f23fb74b223780737063be1315d1..1c499bd59d20d6bea2c5fde70bc1a9a3a02c4ba2 100644 (file)
@@ -144,7 +144,7 @@ struct gpt2_batch {
 };
 
 // load the model's weights from a file
-bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab, int n_gpu_layers) {
+bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab, int n_ctx, int n_gpu_layers) {
     printf("%s: loading model from '%s'\n", __func__, fname.c_str());
 
     auto fin = std::ifstream(fname, std::ios::binary);
@@ -386,6 +386,9 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
         }
     }
 
+    // override the default training context with the user-provided
+    model.hparams.n_ctx = n_ctx;
+
     // key + value memory
     {
         const auto & hparams = model.hparams;
@@ -1013,7 +1016,7 @@ int main(int argc, char ** argv) {
     {
         const int64_t t_start_us = ggml_time_us();
 
-        if (!gpt2_model_load(params.model, model, vocab, params.n_gpu_layers)) {
+        if (!gpt2_model_load(params.model, model, vocab, params.n_ctx, params.n_gpu_layers)) {
             fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
             return 1;
         }
index 0acb3a1b1eeae5f2f87e2410e006007e79bccd9b..5702fd8efee6ecb3fec6ad25f5695d126cab47aa 100644 (file)
@@ -96,7 +96,7 @@ struct gpt2_model {
 };
 
 // load the model's weights from a file
-bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab, int n_gpu_layers) {
+bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab, int n_ctx, int n_gpu_layers) {
     printf("%s: loading model from '%s'\n", __func__, fname.c_str());
 
     auto fin = std::ifstream(fname, std::ios::binary);
@@ -338,6 +338,9 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
         }
     }
 
+    // override the default training context with the user-provided
+    model.hparams.n_ctx = n_ctx;
+
     // key + value memory
     {
         const auto & hparams = model.hparams;
@@ -859,7 +862,7 @@ int main(int argc, char ** argv) {
     {
         const int64_t t_start_us = ggml_time_us();
 
-        if (!gpt2_model_load(params.model, model, vocab, params.n_gpu_layers)) {
+        if (!gpt2_model_load(params.model, model, vocab, params.n_ctx, params.n_gpu_layers)) {
             fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
             return 1;
         }