]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
talk-llama : add alpaca support (#668)
authorEvan Jones <redacted>
Wed, 29 Mar 2023 20:01:14 +0000 (16:01 -0400)
committerGitHub <redacted>
Wed, 29 Mar 2023 20:01:14 +0000 (23:01 +0300)
examples/talk-llama/prompts/talk-alpaca.txt [new file with mode: 0644]
examples/talk-llama/talk-llama.cpp

diff --git a/examples/talk-llama/prompts/talk-alpaca.txt b/examples/talk-llama/prompts/talk-alpaca.txt
new file mode 100644 (file)
index 0000000..79b9610
--- /dev/null
@@ -0,0 +1,23 @@
+Below is an instruction that describes a task. Write a response that appropriately completes the request.
+
+### Instruction:
+
+Write a text transcript of a never ending dialog, where {0} interacts with an AI assistant named {1}.
+{1} is helpful, kind, honest, friendly, good at writing and never fails to answer {0}’s requests immediately and with details and precision.
+There are no annotations like (30 seconds passed...) or (to himself), just what {0} and {1} say aloud to each other.
+The transcript only includes text, it does not include markup like HTML and Markdown.
+{1} responds with short and concise answers.
+
+### Response:
+
+{0}{4} Hello, {1}!
+{1}{4} Hello {0}! How may I help you today?
+{0}{4} What time is it?
+{1}{4} It is {2} o'clock.
+{0}{4} What year is it?
+{1}{4} We are in {3}.
+{0}{4} What is a cat?
+{1}{4} A cat is a domestic species of small carnivorous mammal. It is the only domesticated species in the family Felidae.
+{0}{4} Name a color.
+{1}{4} Blue
+{0}{4}
index c7690f14ca05b160b83512ea2772dd3fd8ab320c..af5309cb4221c3c3bb4bae9d6bf39a9b33619313 100644 (file)
@@ -33,6 +33,8 @@ struct whisper_params {
     int32_t max_tokens = 32;
     int32_t audio_ctx  = 0;
 
+    int32_t n_parts_llama = -1;
+
     float vad_thold    = 0.6f;
     float freq_thold   = 100.0f;
 
@@ -41,12 +43,14 @@ struct whisper_params {
     bool print_special = false;
     bool print_energy  = false;
     bool no_timestamps = true;
+    bool verbose_prompt = false;
 
     std::string person      = "Georgi";
     std::string language    = "en";
     std::string model_wsp   = "models/ggml-base.en.bin";
     std::string model_llama = "models/ggml-llama-7B.bin";
     std::string speak       = "./examples/talk/speak.sh";
+    std::string prompt      = "";
     std::string fname_out;
 };
 
@@ -67,15 +71,24 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-ac"  || arg == "--audio-ctx")     { params.audio_ctx     = std::stoi(argv[++i]); }
         else if (arg == "-vth" || arg == "--vad-thold")     { params.vad_thold     = std::stof(argv[++i]); }
         else if (arg == "-fth" || arg == "--freq-thold")    { params.freq_thold    = std::stof(argv[++i]); }
+        else if (arg == "--n-parts-llama")                  { params.n_parts_llama = std::stoi(argv[++i]); }
         else if (arg == "-su"  || arg == "--speed-up")      { params.speed_up      = true; }
         else if (arg == "-tr"  || arg == "--translate")     { params.translate     = true; }
         else if (arg == "-ps"  || arg == "--print-special") { params.print_special = true; }
         else if (arg == "-pe"  || arg == "--print-energy")  { params.print_energy  = true; }
+        else if (arg == "--verbose-prompt")                 { params.verbose_prompt = true; }
         else if (arg == "-p"   || arg == "--person")        { params.person        = argv[++i]; }
         else if (arg == "-l"   || arg == "--language")      { params.language      = argv[++i]; }
         else if (arg == "-mw"  || arg == "--model-whisper") { params.model_wsp     = argv[++i]; }
         else if (arg == "-ml"  || arg == "--model-llama")   { params.model_llama   = argv[++i]; }
         else if (arg == "-s"   || arg == "--speak")         { params.speak         = argv[++i]; }
+        else if (arg == "--prompt-file")                    {
+            std::ifstream file(argv[++i]);
+            std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.prompt));
+            if (params.prompt.back() == '\n') {
+                params.prompt.pop_back();
+            }
+        }
         else if (arg == "-f"   || arg == "--file")          { params.fname_out     = argv[++i]; }
         else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
@@ -108,7 +121,10 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -l LANG,  --language LANG [%-7s] spoken language\n",                             params.language.c_str());
     fprintf(stderr, "  -mw FILE, --model-whisper [%-7s] whisper model file\n",                          params.model_wsp.c_str());
     fprintf(stderr, "  -mg FILE, --model-llama   [%-7s] llama model file\n",                            params.model_llama.c_str());
+    fprintf(stderr, "  --n-parts-llama N         [%-7d] num parts in llama model file\n",               params.n_parts_llama);
     fprintf(stderr, "  -s FILE,  --speak TEXT    [%-7s] command for TTS\n",                             params.speak.c_str());
+    fprintf(stderr, "  --prompt-file FNAME       [%-7s] file with custom prompt to start dialog\n",     "");
+    fprintf(stderr, "  --verbose-prompt          [%-7s] print prompt at start\n",                       params.verbose_prompt ? "true" : "false");
     fprintf(stderr, "  -f FNAME, --file FNAME    [%-7s] text output file name\n",                       params.fname_out.c_str());
     fprintf(stderr, "\n");
 }
@@ -183,8 +199,7 @@ std::string transcribe(
 
 const std::string k_prompt_whisper = R"(A conversation with a person called {1}.)";
 
-// need to have leading ' '
-const std::string k_prompt_llama = R"( Text transcript of a never ending dialog, where {0} interacts with an AI assistant named {1}.
+const std::string k_prompt_llama = R"(Text transcript of a never ending dialog, where {0} interacts with an AI assistant named {1}.
 {1} is helpful, kind, honest, friendly, good at writing and never fails to answer {0}’s requests immediately and with details and precision.
 There are no annotations like (30 seconds passed...) or (to himself), just what {0} and {1} say aloud to each other.
 The transcript only includes text, it does not include markup like HTML and Markdown.
@@ -227,6 +242,7 @@ int main(int argc, char ** argv) {
     lparams.n_ctx      = 512;
     lparams.seed       = 1;
     lparams.f16_kv     = true;
+    lparams.n_parts    = params.n_parts_llama;
 
     struct llama_context * ctx_llama = llama_init_from_file(params.model_llama.c_str(), lparams);
 
@@ -278,7 +294,10 @@ int main(int argc, char ** argv) {
     const std::string prompt_whisper = ::replace(k_prompt_whisper, "{1}", bot_name);
 
     // construct the initial prompt for LLaMA inference
-    std::string prompt_llama = k_prompt_llama;
+    std::string prompt_llama = params.prompt.empty() ? k_prompt_llama : params.prompt;
+
+    // need to have leading ' '
+    prompt_llama.insert(0, 1, ' ');
 
     prompt_llama = ::replace(prompt_llama, "{0}", params.person);
     prompt_llama = ::replace(prompt_llama, "{1}", bot_name);
@@ -323,9 +342,11 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    //fprintf(stdout, "\n");
-    //fprintf(stdout, "%s", prompt_llama.c_str());
-    //fflush(stdout);
+    if (params.verbose_prompt) {
+        fprintf(stdout, "\n");
+        fprintf(stdout, "%s", prompt_llama.c_str());
+        fflush(stdout);
+    }
 
     printf("%s : done! start speaking in the microphone\n", __func__);
     printf("\n");