]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CLI args use - instead of _, backwards compatible (#1416)
authorJohannes Gäßler <redacted>
Fri, 12 May 2023 14:34:55 +0000 (16:34 +0200)
committerGitHub <redacted>
Fri, 12 May 2023 14:34:55 +0000 (14:34 +0000)
examples/common.cpp

index f3085b08e5b25e2d086fd264174f8422b418377c..80e35d2e9cec8f90846f8bc09cb66b0b7df9575f 100644 (file)
@@ -91,9 +91,13 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
     bool escape_prompt = false;
     std::string arg;
     gpt_params default_params;
+    const std::string arg_prefix = "--";
 
     for (int i = 1; i < argc; i++) {
         arg = argv[i];
+        if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
+            std::replace(arg.begin(), arg.end(), '_', '-');
+        }
 
         if (arg == "-s" || arg == "--seed") {
 #if defined(GGML_USE_CUBLAS)
@@ -141,27 +145,27 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             if (params.prompt.back() == '\n') {
                 params.prompt.pop_back();
             }
-        } else if (arg == "-n" || arg == "--n_predict") {
+        } else if (arg == "-n" || arg == "--n-predict") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params.n_predict = std::stoi(argv[i]);
-        } else if (arg == "--top_k") {
+        } else if (arg == "--top-k") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params.top_k = std::stoi(argv[i]);
-        } else if (arg == "-c" || arg == "--ctx_size") {
+        } else if (arg == "-c" || arg == "--ctx-size") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params.n_ctx = std::stoi(argv[i]);
-        } else if (arg == "--memory_f32") {
+        } else if (arg == "--memory-f32") {
             params.memory_f16 = false;
-        } else if (arg == "--top_p") {
+        } else if (arg == "--top-p") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
@@ -185,25 +189,25 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.typical_p = std::stof(argv[i]);
-        } else if (arg == "--repeat_last_n") {
+        } else if (arg == "--repeat-last-n") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params.repeat_last_n = std::stoi(argv[i]);
-        } else if (arg == "--repeat_penalty") {
+        } else if (arg == "--repeat-penalty") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params.repeat_penalty = std::stof(argv[i]);
-        } else if (arg == "--frequency_penalty") {
+        } else if (arg == "--frequency-penalty") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params.frequency_penalty = std::stof(argv[i]);
-        } else if (arg == "--presence_penalty") {
+        } else if (arg == "--presence-penalty") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
@@ -215,19 +219,19 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.mirostat = std::stoi(argv[i]);
-        } else if (arg == "--mirostat_lr") {
+        } else if (arg == "--mirostat-lr") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params.mirostat_eta = std::stof(argv[i]);
-        } else if (arg == "--mirostat_ent") {
+        } else if (arg == "--mirostat-ent") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params.mirostat_tau = std::stof(argv[i]);
-        } else if (arg == "-b" || arg == "--batch_size") {
+        } else if (arg == "-b" || arg == "--batch-size") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
@@ -310,7 +314,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 invalid_param = true;
                 break;
             }
-        } else if (arg == "--n_parts") {
+        } else if (arg == "--n-parts") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
@@ -384,31 +388,31 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "  --in-suffix STRING    string to suffix after user inputs with (default: empty)\n");
     fprintf(stderr, "  -f FNAME, --file FNAME\n");
     fprintf(stderr, "                        prompt file to start generation.\n");
-    fprintf(stderr, "  -n N, --n_predict N   number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict);
-    fprintf(stderr, "  --top_k N             top-k sampling (default: %d, 0 = disabled)\n", params.top_k);
-    fprintf(stderr, "  --top_p N             top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
+    fprintf(stderr, "  -n N, --n-predict N   number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict);
+    fprintf(stderr, "  --top-k N             top-k sampling (default: %d, 0 = disabled)\n", params.top_k);
+    fprintf(stderr, "  --top-p N             top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
     fprintf(stderr, "  --tfs N               tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z);
     fprintf(stderr, "  --typical N           locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)params.typical_p);
-    fprintf(stderr, "  --repeat_last_n N     last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", params.repeat_last_n);
-    fprintf(stderr, "  --repeat_penalty N    penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)params.repeat_penalty);
-    fprintf(stderr, "  --presence_penalty N  repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)params.presence_penalty);
-    fprintf(stderr, "  --frequency_penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)params.frequency_penalty);
+    fprintf(stderr, "  --repeat-last-n N     last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", params.repeat_last_n);
+    fprintf(stderr, "  --repeat-penalty N    penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)params.repeat_penalty);
+    fprintf(stderr, "  --presence-penalty N  repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)params.presence_penalty);
+    fprintf(stderr, "  --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)params.frequency_penalty);
     fprintf(stderr, "  --mirostat N          use Mirostat sampling.\n");
     fprintf(stderr, "                        Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n");
     fprintf(stderr, "                        (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", params.mirostat);
-    fprintf(stderr, "  --mirostat_lr N       Mirostat learning rate, parameter eta (default: %.1f)\n", (double)params.mirostat_eta);
-    fprintf(stderr, "  --mirostat_ent N      Mirostat target entropy, parameter tau (default: %.1f)\n", (double)params.mirostat_tau);
+    fprintf(stderr, "  --mirostat-lr N       Mirostat learning rate, parameter eta (default: %.1f)\n", (double)params.mirostat_eta);
+    fprintf(stderr, "  --mirostat-ent N      Mirostat target entropy, parameter tau (default: %.1f)\n", (double)params.mirostat_tau);
     fprintf(stderr, "  -l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS\n");
     fprintf(stderr, "                        modifies the likelihood of token appearing in the completion,\n");
     fprintf(stderr, "                        i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n");
     fprintf(stderr, "                        or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n");
-    fprintf(stderr, "  -c N, --ctx_size N    size of the prompt context (default: %d)\n", params.n_ctx);
+    fprintf(stderr, "  -c N, --ctx-size N    size of the prompt context (default: %d)\n", params.n_ctx);
     fprintf(stderr, "  --ignore-eos          ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
     fprintf(stderr, "  --no-penalize-nl      do not penalize newline token\n");
-    fprintf(stderr, "  --memory_f32          use f32 instead of f16 for memory key+value\n");
+    fprintf(stderr, "  --memory-f32          use f32 instead of f16 for memory key+value\n");
     fprintf(stderr, "  --temp N              temperature (default: %.1f)\n", (double)params.temp);
-    fprintf(stderr, "  --n_parts N           number of model parts (default: -1 = determine from dimensions)\n");
-    fprintf(stderr, "  -b N, --batch_size N  batch size for prompt processing (default: %d)\n", params.n_batch);
+    fprintf(stderr, "  --n-parts N           number of model parts (default: -1 = determine from dimensions)\n");
+    fprintf(stderr, "  -b N, --batch-size N  batch size for prompt processing (default: %d)\n", params.n_batch);
     fprintf(stderr, "  --perplexity          compute perplexity over the prompt\n");
     fprintf(stderr, "  --keep                number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
     if (llama_mlock_supported()) {