]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
common : fix param parsing (#391)
authorReza Rezvan <redacted>
Sun, 23 Jul 2023 15:12:47 +0000 (17:12 +0200)
committerGitHub <redacted>
Sun, 23 Jul 2023 15:12:47 +0000 (18:12 +0300)
examples/common.cpp
examples/starcoder/main.cpp

index 57f5039f7d9505ba53ee9230cbde88378e0c2d91..c839251869f84ce02b0f4e955553a6f8f822541e 100644 (file)
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
+// Function to check if the next argument exists
+std::string get_next_arg(int& i, int argc, char** argv, const std::string& flag, gpt_params& params) {
+    if (i + 1 < argc && argv[i + 1][0] != '-') {
+        return argv[++i];
+    } else {
+        fprintf(stderr, "error: %s requires one argument.\n", flag.c_str());
+        gpt_print_usage(argc, argv, params);
+        exit(0);
+    }
+}
+
 bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
     for (int i = 1; i < argc; i++) {
         std::string arg = argv[i];
 
         if (arg == "-s" || arg == "--seed") {
-            params.seed = std::stoi(argv[++i]);
+            params.seed = std::stoi(get_next_arg(i, argc, argv, arg, params));
         } else if (arg == "-t" || arg == "--threads") {
-            params.n_threads = std::stoi(argv[++i]);
+            params.n_threads = std::stoi(get_next_arg(i, argc, argv, arg, params));
         } else if (arg == "-ngl" || arg == "--gpu-layers" || arg == "--n-gpu-layers") {
-            params.n_gpu_layers = std::stoi(argv[++i]);
+            params.n_gpu_layers = std::stoi(get_next_arg(i, argc, argv, arg, params));
         } else if (arg == "-p" || arg == "--prompt") {
-            params.prompt = argv[++i];
+            params.prompt = get_next_arg(i, argc, argv, arg, params);
         } else if (arg == "-n" || arg == "--n_predict") {
-            params.n_predict = std::stoi(argv[++i]);
+            params.n_predict = std::stoi(get_next_arg(i, argc, argv, arg, params));
         } else if (arg == "--top_k") {
-            params.top_k = std::max(1, std::stoi(argv[++i]));
+            params.top_k = std::stoi(get_next_arg(i, argc, argv, arg, params));
         } else if (arg == "--top_p") {
-            params.top_p = std::stof(argv[++i]);
+            params.top_p = std::stof(get_next_arg(i, argc, argv, arg, params));
         } else if (arg == "--temp") {
-            params.temp = std::stof(argv[++i]);
+            params.temp = std::stof(get_next_arg(i, argc, argv, arg, params));
         } else if (arg == "--repeat-last-n") {
-            params.repeat_last_n = std::stof(argv[++i]);
+            params.repeat_last_n = std::stoi(get_next_arg(i, argc, argv, arg, params));
         } else if (arg == "--repeat-penalty") {
-            params.repeat_penalty = std::stof(argv[++i]);
+            params.repeat_penalty = std::stof(get_next_arg(i, argc, argv, arg, params));
         } else if (arg == "-b" || arg == "--batch_size") {
-            params.n_batch = std::stoi(argv[++i]);
+            params.n_batch= std::stoi(get_next_arg(i, argc, argv, arg, params));
         } else if (arg == "-m" || arg == "--model") {
-            params.model = argv[++i];
+            params.model = get_next_arg(i, argc, argv, arg, params);
         } else if (arg == "-i" || arg == "--interactive") {
             params.interactive = true;
         } else if (arg == "-ip" || arg == "--interactive-port") {
             params.interactive = true;
-            params.interactive_port = std::stoi(argv[++i]);
+            params.interactive_port = std::stoi(get_next_arg(i, argc, argv, arg, params));
         } else if (arg == "-h" || arg == "--help") {
             gpt_print_usage(argc, argv, params);
             exit(0);
         } else if (arg == "-f" || arg == "--file") {
-            if (++i > argc) {
-                fprintf(stderr, "Invalid file param");
-                break;
-            }
+            get_next_arg(i, argc, argv, arg, params);
             std::ifstream file(argv[i]);
             if (!file) {
                 fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
@@ -72,7 +80,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 params.prompt.pop_back();
             }
         } else if (arg == "-tt" || arg == "--token_test") {
-            params.token_test = argv[++i];
+            params.token_test = get_next_arg(i, argc, argv, arg, params);
         }
         else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
index d84e36634f82a0349a4b6e93915aaa56dff1ca8a..7bbd9c1593ae867acde6d14da787a7e9b814c261 100644 (file)
@@ -746,7 +746,6 @@ int main(int argc, char ** argv) {
     const int64_t t_main_start_us = ggml_time_us();
 
     gpt_params params;
-    params.model = "models/gpt-2-117M/ggml-model.bin";
 
     if (gpt_params_parse(argc, argv, params) == false) {
         return 1;