From: Reza Rezvan Date: Sun, 23 Jul 2023 15:12:47 +0000 (+0200) Subject: common : fix param parsing (#391) X-Git-Tag: upstream/0.0.1642~1299 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=f1c5a11547b7ff77e6039b136d97a430cb1138c2;p=pkg%2Fggml%2Fsources%2Fggml common : fix param parsing (#391) --- diff --git a/examples/common.cpp b/examples/common.cpp index 57f5039f..c8392518 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -21,47 +21,55 @@ #pragma warning(disable: 4244 4267) // possible loss of data #endif +// Function to check if the next argument exists +std::string get_next_arg(int& i, int argc, char** argv, const std::string& flag, gpt_params& params) { + if (i + 1 < argc && argv[i + 1][0] != '-') { + return argv[++i]; + } else { + fprintf(stderr, "error: %s requires one argument.\n", flag.c_str()); + gpt_print_usage(argc, argv, params); + exit(0); + } +} + bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { for (int i = 1; i < argc; i++) { std::string arg = argv[i]; if (arg == "-s" || arg == "--seed") { - params.seed = std::stoi(argv[++i]); + params.seed = std::stoi(get_next_arg(i, argc, argv, arg, params)); } else if (arg == "-t" || arg == "--threads") { - params.n_threads = std::stoi(argv[++i]); + params.n_threads = std::stoi(get_next_arg(i, argc, argv, arg, params)); } else if (arg == "-ngl" || arg == "--gpu-layers" || arg == "--n-gpu-layers") { - params.n_gpu_layers = std::stoi(argv[++i]); + params.n_gpu_layers = std::stoi(get_next_arg(i, argc, argv, arg, params)); } else if (arg == "-p" || arg == "--prompt") { - params.prompt = argv[++i]; + params.prompt = get_next_arg(i, argc, argv, arg, params); } else if (arg == "-n" || arg == "--n_predict") { - params.n_predict = std::stoi(argv[++i]); + params.n_predict = std::stoi(get_next_arg(i, argc, argv, arg, params)); } else if (arg == "--top_k") { - params.top_k = std::max(1, std::stoi(argv[++i])); + params.top_k = std::stoi(get_next_arg(i, argc, argv, arg, params)); } else if (arg == "--top_p") { - params.top_p = std::stof(argv[++i]); + params.top_p = std::stof(get_next_arg(i, argc, argv, arg, params)); } else if (arg == "--temp") { - params.temp = std::stof(argv[++i]); + params.temp = std::stof(get_next_arg(i, argc, argv, arg, params)); } else if (arg == "--repeat-last-n") { - params.repeat_last_n = std::stof(argv[++i]); + params.repeat_last_n = std::stoi(get_next_arg(i, argc, argv, arg, params)); } else if (arg == "--repeat-penalty") { - params.repeat_penalty = std::stof(argv[++i]); + params.repeat_penalty = std::stof(get_next_arg(i, argc, argv, arg, params)); } else if (arg == "-b" || arg == "--batch_size") { - params.n_batch = std::stoi(argv[++i]); + params.n_batch= std::stoi(get_next_arg(i, argc, argv, arg, params)); } else if (arg == "-m" || arg == "--model") { - params.model = argv[++i]; + params.model = get_next_arg(i, argc, argv, arg, params); } else if (arg == "-i" || arg == "--interactive") { params.interactive = true; } else if (arg == "-ip" || arg == "--interactive-port") { params.interactive = true; - params.interactive_port = std::stoi(argv[++i]); + params.interactive_port = std::stoi(get_next_arg(i, argc, argv, arg, params)); } else if (arg == "-h" || arg == "--help") { gpt_print_usage(argc, argv, params); exit(0); } else if (arg == "-f" || arg == "--file") { - if (++i > argc) { - fprintf(stderr, "Invalid file param"); - break; - } + get_next_arg(i, argc, argv, arg, params); std::ifstream file(argv[i]); if (!file) { fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); @@ -72,7 +80,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.prompt.pop_back(); } } else if (arg == "-tt" || arg == "--token_test") { - params.token_test = argv[++i]; + params.token_test = get_next_arg(i, argc, argv, arg, params); } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); diff --git a/examples/starcoder/main.cpp b/examples/starcoder/main.cpp index d84e3663..7bbd9c15 100644 --- a/examples/starcoder/main.cpp +++ b/examples/starcoder/main.cpp @@ -746,7 +746,6 @@ int main(int argc, char ** argv) { const int64_t t_main_start_us = ggml_time_us(); gpt_params params; - params.model = "models/gpt-2-117M/ggml-model.bin"; if (gpt_params_parse(argc, argv, params) == false) { return 1;