]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
cmdline option for custom amount of model parts (--n_parts N) (#348)
authoranzz1 <redacted>
Tue, 21 Mar 2023 15:42:43 +0000 (17:42 +0200)
committerGitHub <redacted>
Tue, 21 Mar 2023 15:42:43 +0000 (17:42 +0200)
* cmdline option for custom amount of model parts (--n_parts N)

* Update main.cpp

---------

Co-authored-by: Georgi Gerganov <redacted>
main.cpp
utils.cpp
utils.h

index e97611e2882c60e0266f993f0b619ef2335aa4b2..662a2a79bc4c7b64faccd9c5ac2a57c4448242a1 100644 (file)
--- a/main.cpp
+++ b/main.cpp
@@ -90,7 +90,8 @@ struct llama_model {
 };
 
 // load the model's weights from a file
-bool llama_model_load(const std::string & fname, llama_model & model, llama_vocab & vocab, int n_ctx, ggml_type memory_type = GGML_TYPE_F32) {
+
+bool llama_model_load(const std::string & fname, llama_model & model, llama_vocab & vocab, int n_ctx, int n_parts, ggml_type memory_type = GGML_TYPE_F32) {
     fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
 
     std::vector<char> f_buf(1024*1024);
@@ -127,7 +128,6 @@ bool llama_model_load(const std::string & fname, llama_model & model, llama_voca
     }
 
     int n_ff = 0;
-    int n_parts = 0;
 
     // load hparams
     {
@@ -145,7 +145,10 @@ bool llama_model_load(const std::string & fname, llama_model & model, llama_voca
         hparams.n_ctx = n_ctx;
 
         n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
-        n_parts = LLAMA_N_PARTS.at(hparams.n_embd);
+
+        if (n_parts < 1) {
+            n_parts = LLAMA_N_PARTS.at(hparams.n_embd);
+        }
 
         fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab);
         fprintf(stderr, "%s: n_ctx   = %d\n", __func__, hparams.n_ctx);
@@ -839,7 +842,7 @@ int main(int argc, char ** argv) {
     {
         const ggml_type memory_type = params.memory_f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
         const int64_t t_start_us = ggml_time_us();
-        if (!llama_model_load(params.model, model, vocab, params.n_ctx, memory_type)) {
+        if (!llama_model_load(params.model, model, vocab, params.n_ctx, params.n_parts, memory_type)) {
             fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
             return 1;
         }
index 4843b4f557f12c1b277819292ac5279b0489073f..03ed9bc0658704f1bfaf59e7369c7bab6df7da51 100644 (file)
--- a/utils.cpp
+++ b/utils.cpp
@@ -74,6 +74,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.antiprompt.push_back(argv[++i]);
         } else if (arg == "--ignore-eos") {
             params.ignore_eos = true;
+        } else if (arg == "--n_parts") {
+            params.n_parts = std::stoi(argv[++i]);
         } else if (arg == "-h" || arg == "--help") {
             gpt_print_usage(argc, argv, params);
             exit(0);
@@ -116,6 +118,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "  --ignore-eos          ignore end of stream token and continue generating\n");
     fprintf(stderr, "  --memory_f16          use f16 instead of f32 for memory key+value\n");
     fprintf(stderr, "  --temp N              temperature (default: %.1f)\n", params.temp);
+    fprintf(stderr, "  --n_parts N           number of model parts (default: -1 = determine from dimensions)\n");
     fprintf(stderr, "  -b N, --batch_size N  batch size for prompt processing (default: %d)\n", params.n_batch);
     fprintf(stderr, "  -m FNAME, --model FNAME\n");
     fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
diff --git a/utils.h b/utils.h
index 4aa7c63b2cea3f5d9309f093a4d113581aec53c6..c7fce964b4e2d65bb40397a20cb309604110d59a 100644 (file)
--- a/utils.h
+++ b/utils.h
 //
 
 struct gpt_params {
-    int32_t seed          = -1; // RNG seed
+    int32_t seed          = -1;  // RNG seed
     int32_t n_threads     = std::min(4, (int32_t) std::thread::hardware_concurrency());
     int32_t n_predict     = 128; // new tokens to predict
     int32_t repeat_last_n = 64;  // last n tokens to penalize
+    int32_t n_parts       = -1;  // amount of model parts (-1 = determine from model dimensions)
     int32_t n_ctx         = 512; //context size
 
     // sampling parameters