]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
command : always-prompt mode (#383)
authorDavid <redacted>
Sat, 7 Jan 2023 19:41:11 +0000 (20:41 +0100)
committerGitHub <redacted>
Sat, 7 Jan 2023 19:41:11 +0000 (21:41 +0200)
examples/command/command.cpp

index 3ea563add3055497eaf9c4d371ccd21d8d22820b..524ad67f7f6bfb5cabb79323325c52a33148383e 100644 (file)
@@ -11,6 +11,8 @@
 #include <SDL.h>
 #include <SDL_audio.h>
 
+#include <iostream>
+#include <sstream>
 #include <cassert>
 #include <cstdio>
 #include <fstream>
@@ -25,7 +27,7 @@
 struct whisper_params {
     int32_t n_threads  = std::min(4, (int32_t) std::thread::hardware_concurrency());
     int32_t prompt_ms  = 5000;
-    int32_t command_ms = 4000;
+    int32_t command_ms = 8000;
     int32_t capture_id = -1;
     int32_t max_tokens = 32;
     int32_t audio_ctx  = 0;
@@ -43,6 +45,7 @@ struct whisper_params {
     std::string model     = "models/ggml-base.en.bin";
     std::string fname_out;
     std::string commands;
+    std::string prompt;
 };
 
 void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@@ -71,6 +74,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-m"   || arg == "--model")         { params.model         = argv[++i]; }
         else if (arg == "-f"   || arg == "--file")          { params.fname_out     = argv[++i]; }
         else if (arg == "-cmd" || arg == "--commands")      { params.commands      = argv[++i]; }
+        else if (arg == "-p"   || arg == "--prompt")        { params.prompt        = argv[++i]; }
         else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             whisper_print_usage(argc, argv, params);
@@ -103,6 +107,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -m FNAME,   --model FNAME    [%-7s] model path\n",                                  params.model.c_str());
     fprintf(stderr, "  -f FNAME,   --file FNAME     [%-7s] text output file name\n",                       params.fname_out.c_str());
     fprintf(stderr, "  -cmd FNAME, --commands FNAME [%-7s] text file with allowed commands\n",             params.commands.c_str());
+    fprintf(stderr, "  -p,         --prompt         [%-7s] the required activation prompt\n",              params.prompt.c_str());
     fprintf(stderr, "\n");
 }
 
@@ -837,6 +842,115 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
    return 0;
 }
 
+
+// always prompt mode
+// transcribe the voice into text after valid prompt
+int always_prompt_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) {
+   bool is_running  = true;
+   bool ask_prompt  = true;
+
+   float prob  = 0.0f;
+
+   std::vector<float> pcmf32_cur;
+
+   const std::string k_prompt = params.prompt;
+
+   std::vector<std::string> words;
+
+   std::istringstream iss(k_prompt);
+   std::string word;
+
+   while (iss >> word) {
+       words.push_back(word);
+   }
+
+   int k_prompt_length = words.size();
+
+   // main loop
+   while (is_running) {
+      // handle Ctrl + C
+      {
+         SDL_Event event;
+         while (SDL_PollEvent(&event)) {
+            switch (event.type) {
+               case SDL_QUIT:
+               {
+                  is_running = false;
+               } break;
+               default:
+                  break;
+            }
+         }
+
+         if (!is_running) {
+            return 0;
+         }
+      }
+
+      // delay
+      std::this_thread::sleep_for(std::chrono::milliseconds(100));
+
+      if (ask_prompt) {
+         fprintf(stdout, "\n");
+         fprintf(stdout, "%s: The prompt is: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
+         fprintf(stdout, "\n");
+
+         ask_prompt = false;
+      }
+
+      {
+         audio.get(2000, pcmf32_cur);
+
+         if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
+            fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
+
+            int64_t t_ms = 0;
+
+            // detect the commands
+            audio.get(params.command_ms, pcmf32_cur);
+
+            const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
+
+            std::istringstream iss(txt);
+            std::string word;
+            std::string prompt;
+            std::string command;
+            int i = 0;
+            int command_length = 0;
+            while (iss >> word) {
+                if (i == k_prompt_length - 1) {
+                    prompt += word + ' ';
+                    break;
+                }
+                prompt += word + ' ';
+                i++;
+            }
+            while (iss >> word) {
+             command += word + ' ';
+             command_length++;
+            }
+
+            const float sim = similarity(prompt, k_prompt);
+
+            //debug
+            //fprintf(stdout, "command size: %i\n", command_length); 
+
+
+            if ((sim > 0.7f) && (command_length >0)){
+                fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
+            }
+
+            fprintf(stdout, "\n");
+
+
+            audio.clear();
+         }
+      }
+   }
+
+   return 0;
+}
+
 int main(int argc, char ** argv) {
     whisper_params params;
 
@@ -892,6 +1006,8 @@ int main(int argc, char ** argv) {
 
     if (!params.commands.empty()) {
        ret_val = process_command_list(ctx, audio, params);
+    } else if (!params.prompt.empty()) {
+       ret_val = always_prompt_transcription(ctx, audio, params);
     } else {
        ret_val = process_general_transcription(ctx, audio, params);
     }