]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
common : allow prompts to be loaded from file (#102)
authorNevin <redacted>
Sat, 13 May 2023 08:41:43 +0000 (10:41 +0200)
committerGitHub <redacted>
Sat, 13 May 2023 08:41:43 +0000 (11:41 +0300)
* common: allow prompts to be loaded from file

* common : extra help for -f

---------

Co-authored-by: Georgi Gerganov <redacted>
examples/common.cpp

index 081301195a496b79af88942921bdde591d2bdf57..75c443e20dafa3ab93782b9d54220fe09f4947a0 100644 (file)
@@ -38,6 +38,20 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
         } else if (arg == "-h" || arg == "--help") {
             gpt_print_usage(argc, argv, params);
             exit(0);
+        } else if (arg == "-f" || arg == "--file") {
+            if (++i > argc) {
+                fprintf(stderr, "Invalid file param");
+                break;
+            }
+            std::ifstream file(argv[i]);
+            if (!file) {
+                fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
+                break;
+            }
+            std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.prompt));
+            if (params.prompt.back() == '\n') {
+                params.prompt.pop_back();
+            }
         } else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             gpt_print_usage(argc, argv, params);
@@ -57,6 +71,8 @@ void gpt_print_usage(int argc, char ** argv, const gpt_params & params) {
     fprintf(stderr, "  -t N, --threads N     number of threads to use during computation (default: %d)\n", params.n_threads);
     fprintf(stderr, "  -p PROMPT, --prompt PROMPT\n");
     fprintf(stderr, "                        prompt to start generation with (default: random)\n");
+    fprintf(stderr, "  -f FNAME, --file FNAME\n");
+    fprintf(stderr, "                        load prompt from a file\n");
     fprintf(stderr, "  -n N, --n_predict N   number of tokens to predict (default: %d)\n", params.n_predict);
     fprintf(stderr, "  --top_k N             top-k sampling (default: %d)\n", params.top_k);
     fprintf(stderr, "  --top_p N             top-p sampling (default: %.1f)\n", params.top_p);