]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
main : add "--prompt" command line argument (#90)
authorGeorgi Gerganov <redacted>
Fri, 16 Dec 2022 17:43:16 +0000 (19:43 +0200)
committerGeorgi Gerganov <redacted>
Fri, 16 Dec 2022 17:43:16 +0000 (19:43 +0200)
This allows to provide an initial prompt to be used at the start of the
processing.

examples/main/main.cpp

index 4071bd28fe75d8dd11deffa84754487788ee7e65..3ef576dddec67d5306c760b248dd153ddfb97f68 100644 (file)
@@ -73,8 +73,9 @@ struct whisper_params {
     bool print_colors  = false;
     bool no_timestamps = false;
 
-    std::string language  = "en";
-    std::string model     = "models/ggml-base.en.bin";
+    std::string language = "en";
+    std::string prompt   = "";
+    std::string model    = "models/ggml-base.en.bin";
 
     std::vector<std::string> fname_inp = {};
 };
@@ -113,6 +114,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-pc"   || arg == "--print-colors")  { params.print_colors  = true; }
         else if (arg == "-nt"   || arg == "--no-timestamps") { params.no_timestamps = true; }
         else if (arg == "-l"    || arg == "--language")      { params.language      = argv[++i]; }
+        else if (                  arg == "--prompt")        { params.prompt        = argv[++i]; }
         else if (arg == "-m"    || arg == "--model")         { params.model         = argv[++i]; }
         else if (arg == "-f"    || arg == "--file")          { params.fname_inp.push_back(argv[++i]); }
         else {
@@ -150,6 +152,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
     fprintf(stderr, "  -pc,      --print-colors  [%-7s] print colors\n",                                   params.print_colors ? "true" : "false");
     fprintf(stderr, "  -nt,      --no-timestamps [%-7s] do not print timestamps\n",                        params.no_timestamps ? "false" : "true");
     fprintf(stderr, "  -l LANG,  --language LANG [%-7s] spoken language\n",                                params.language.c_str());
+    fprintf(stderr, "            --prompt PROMPT [%-7s] initial prompt\n",                                 params.prompt.c_str());
     fprintf(stderr, "  -m FNAME, --model FNAME   [%-7s] model path\n",                                     params.model.c_str());
     fprintf(stderr, "  -f FNAME, --file FNAME    [%-7s] input WAV file path\n",                            "");
     fprintf(stderr, "\n");
@@ -462,6 +465,22 @@ int main(int argc, char ** argv) {
         return 3;
     }
 
+    // initial prompt
+    std::vector<whisper_token> prompt_tokens;
+
+    if (params.prompt.size() > 0) {
+        prompt_tokens.resize(1024);
+        prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size()));
+
+        fprintf(stderr, "\n");
+        fprintf(stderr, "initial prompt: '%s'\n", params.prompt.c_str());
+        fprintf(stderr, "initial tokens: [ ");
+        for (int i = 0; i < (int) prompt_tokens.size(); ++i) {
+            fprintf(stderr, "%d ", prompt_tokens[i]);
+        }
+        fprintf(stderr, "]\n");
+    }
+
     for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
         const auto fname_inp = params.fname_inp[f];
 
@@ -577,7 +596,6 @@ int main(int argc, char ** argv) {
             fprintf(stderr, "\n");
         }
 
-
         // run the inference
         {
             whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
@@ -599,6 +617,9 @@ int main(int argc, char ** argv) {
 
             wparams.speed_up         = params.speed_up;
 
+            wparams.prompt_tokens    = prompt_tokens.size() == 0 ? nullptr : prompt_tokens.data();
+            wparams.prompt_n_tokens  = prompt_tokens.size() == 0 ? 0       : prompt_tokens.size();
+
             whisper_print_user_data user_data = { &params, &pcmf32s };
 
             // this callback is called on each new segment