#pragma warning(disable: 4244 4267) // possible loss of data
#endif
+// Function to check if the next argument exists
+std::string get_next_arg(int& i, int argc, char** argv, const std::string& flag, gpt_params& params) {
+ if (i + 1 < argc && argv[i + 1][0] != '-') {
+ return argv[++i];
+ } else {
+ fprintf(stderr, "error: %s requires one argument.\n", flag.c_str());
+ gpt_print_usage(argc, argv, params);
+ exit(0);
+ }
+}
+
bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
for (int i = 1; i < argc; i++) {
std::string arg = argv[i];
if (arg == "-s" || arg == "--seed") {
- params.seed = std::stoi(argv[++i]);
+ params.seed = std::stoi(get_next_arg(i, argc, argv, arg, params));
} else if (arg == "-t" || arg == "--threads") {
- params.n_threads = std::stoi(argv[++i]);
+ params.n_threads = std::stoi(get_next_arg(i, argc, argv, arg, params));
} else if (arg == "-ngl" || arg == "--gpu-layers" || arg == "--n-gpu-layers") {
- params.n_gpu_layers = std::stoi(argv[++i]);
+ params.n_gpu_layers = std::stoi(get_next_arg(i, argc, argv, arg, params));
} else if (arg == "-p" || arg == "--prompt") {
- params.prompt = argv[++i];
+ params.prompt = get_next_arg(i, argc, argv, arg, params);
} else if (arg == "-n" || arg == "--n_predict") {
- params.n_predict = std::stoi(argv[++i]);
+ params.n_predict = std::stoi(get_next_arg(i, argc, argv, arg, params));
} else if (arg == "--top_k") {
- params.top_k = std::max(1, std::stoi(argv[++i]));
+ params.top_k = std::stoi(get_next_arg(i, argc, argv, arg, params));
} else if (arg == "--top_p") {
- params.top_p = std::stof(argv[++i]);
+ params.top_p = std::stof(get_next_arg(i, argc, argv, arg, params));
} else if (arg == "--temp") {
- params.temp = std::stof(argv[++i]);
+ params.temp = std::stof(get_next_arg(i, argc, argv, arg, params));
} else if (arg == "--repeat-last-n") {
- params.repeat_last_n = std::stof(argv[++i]);
+ params.repeat_last_n = std::stoi(get_next_arg(i, argc, argv, arg, params));
} else if (arg == "--repeat-penalty") {
- params.repeat_penalty = std::stof(argv[++i]);
+ params.repeat_penalty = std::stof(get_next_arg(i, argc, argv, arg, params));
} else if (arg == "-b" || arg == "--batch_size") {
- params.n_batch = std::stoi(argv[++i]);
+ params.n_batch= std::stoi(get_next_arg(i, argc, argv, arg, params));
} else if (arg == "-m" || arg == "--model") {
- params.model = argv[++i];
+ params.model = get_next_arg(i, argc, argv, arg, params);
} else if (arg == "-i" || arg == "--interactive") {
params.interactive = true;
} else if (arg == "-ip" || arg == "--interactive-port") {
params.interactive = true;
- params.interactive_port = std::stoi(argv[++i]);
+ params.interactive_port = std::stoi(get_next_arg(i, argc, argv, arg, params));
} else if (arg == "-h" || arg == "--help") {
gpt_print_usage(argc, argv, params);
exit(0);
} else if (arg == "-f" || arg == "--file") {
- if (++i > argc) {
- fprintf(stderr, "Invalid file param");
- break;
- }
+ get_next_arg(i, argc, argv, arg, params);
std::ifstream file(argv[i]);
if (!file) {
fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
params.prompt.pop_back();
}
} else if (arg == "-tt" || arg == "--token_test") {
- params.token_test = argv[++i];
+ params.token_test = get_next_arg(i, argc, argv, arg, params);
}
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());