]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Allow using prompt files (#59)
authorBen Garney <redacted>
Sun, 12 Mar 2023 20:28:36 +0000 (13:28 -0700)
committerGitHub <redacted>
Sun, 12 Mar 2023 20:28:36 +0000 (22:28 +0200)
utils.cpp

index 5435d474757bd4aff61b14032128ba8d61789dd2..13d4aa0f8bebaae2b063cd4050be92d0ee5588d4 100644 (file)
--- a/utils.cpp
+++ b/utils.cpp
@@ -4,6 +4,10 @@
 #include <cstring>
 #include <fstream>
 #include <regex>
+#include <iostream>
+#include <iterator>
+#include <string>
+#include <math.h>
 
  #if defined(_MSC_VER) || defined(__MINGW32__)
  #include <malloc.h> // using malloc.h with MSC/MINGW
@@ -21,6 +25,14 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.n_threads = std::stoi(argv[++i]);
         } else if (arg == "-p" || arg == "--prompt") {
             params.prompt = argv[++i];
+        } else if (arg == "-f" || arg == "--file") {
+
+            std::ifstream file(argv[++i]);
+
+            std::copy(std::istreambuf_iterator<char>(file),
+                    std::istreambuf_iterator<char>(),
+                    back_inserter(params.prompt));
+                
         } else if (arg == "-n" || arg == "--n_predict") {
             params.n_predict = std::stoi(argv[++i]);
         } else if (arg == "--top_k") {
@@ -59,6 +71,8 @@ void gpt_print_usage(int argc, char ** argv, const gpt_params & params) {
     fprintf(stderr, "  -t N, --threads N     number of threads to use during computation (default: %d)\n", params.n_threads);
     fprintf(stderr, "  -p PROMPT, --prompt PROMPT\n");
     fprintf(stderr, "                        prompt to start generation with (default: random)\n");
+    fprintf(stderr, "  -f FNAME, --file FNAME\n");
+    fprintf(stderr, "                        prompt file to start generation.\n");
     fprintf(stderr, "  -n N, --n_predict N   number of tokens to predict (default: %d)\n", params.n_predict);
     fprintf(stderr, "  --top_k N             top-k sampling (default: %d)\n", params.top_k);
     fprintf(stderr, "  --top_p N             top-p sampling (default: %.1f)\n", params.top_p);