]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
command : clean-up / refactoring / formatting (#383)
authorGeorgi Gerganov <redacted>
Sat, 7 Jan 2023 19:43:24 +0000 (21:43 +0200)
committerGeorgi Gerganov <redacted>
Sat, 7 Jan 2023 19:43:24 +0000 (21:43 +0200)
examples/command/command.cpp

index 524ad67f7f6bfb5cabb79323325c52a33148383e..74a14f9cc867d3542dd95d16b8bda0c8d7f41b52 100644 (file)
@@ -11,7 +11,6 @@
 #include <SDL.h>
 #include <SDL_audio.h>
 
-#include <iostream>
 #include <sstream>
 #include <cassert>
 #include <cstdio>
@@ -515,440 +514,406 @@ std::vector<std::string> read_allowed_commands(const std::string & fname) {
     return allowed_commands;
 }
 
+std::vector<std::string> get_words(const std::string &txt) {
+    std::vector<std::string> words;
+
+    std::istringstream iss(txt);
+    std::string word;
+    while (iss >> word) {
+        words.push_back(word);
+    }
+
+    return words;
+}
+
+// returns true if no exit event was received
+bool process_sdl_events() {
+    SDL_Event event;
+    while (SDL_PollEvent(&event)) {
+        switch (event.type) {
+            case SDL_QUIT:
+                {
+                    return false;
+                } break;
+            default:
+                break;
+        }
+    }
+
+    return true;
+}
+
 // command-list mode
 // guide the transcription to match the most likely command from a provided list
 int process_command_list(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) {
-   fprintf(stderr, "\n");
-   fprintf(stderr, "%s: guided mode\n", __func__);
-
-   std::vector<std::string> allowed_commands = read_allowed_commands(params.commands);
-
-   if (allowed_commands.empty()) {
-      fprintf(stderr, "%s: error: failed to read allowed commands from '%s'\n", __func__, params.commands.c_str());
-      return 2;
-   }
-
-   int max_len = 0;
-
-   std::vector<std::vector<whisper_token>> allowed_tokens;
-
-   for (const auto & cmd : allowed_commands) {
-      whisper_token tokens[1024];
-      allowed_tokens.emplace_back();
-
-      for (int l = 0; l < (int) cmd.size(); ++l) {
-         // NOTE: very important to add the whitespace !
-         //       the reason is that the first decoded token starts with a whitespace too!
-         std::string ss = std::string(" ") + cmd.substr(0, l + 1);
-
-         const int n = whisper_tokenize(ctx, ss.c_str(), tokens, 1024);
-         if (n < 0) {
-            fprintf(stderr, "%s: error: failed to tokenize command '%s'\n", __func__, cmd.c_str());
-            return 3;
-         }
-
-         if (n == 1) {
-            allowed_tokens.back().push_back(tokens[0]);
-         }
-      }
-
-      max_len = std::max(max_len, (int) cmd.size());
-   }
-
-   fprintf(stderr, "%s: allowed commands [ tokens ]:\n", __func__);
-   fprintf(stderr, "\n");
-   for (int i = 0; i < (int) allowed_commands.size(); ++i) {
-      fprintf(stderr, "  - \033[1m%-*s\033[0m = [", max_len, allowed_commands[i].c_str());
-      for (const auto & token : allowed_tokens[i]) {
-         fprintf(stderr, " %5d", token);
-      }
-      fprintf(stderr, " ]\n");
-   }
-
-   std::string  k_prompt = "select one from the available words: ";
-   for (int i = 0; i < (int) allowed_commands.size(); ++i) {
-      if (i > 0) {
-         k_prompt += ", ";
-      }
-      k_prompt += allowed_commands[i];
-   }
-   k_prompt += ". selected word: ";
-
-   // tokenize prompt
-   std::vector<whisper_token> k_tokens;
-   {
-      k_tokens.resize(1024);
-      const int n = whisper_tokenize(ctx, k_prompt.c_str(), k_tokens.data(), 1024);
-      if (n < 0) {
-         fprintf(stderr, "%s: error: failed to tokenize prompt '%s'\n", __func__, k_prompt.c_str());
-         return 4;
-      }
-      k_tokens.resize(n);
-   }
-
-   fprintf(stderr, "\n");
-   fprintf(stderr, "%s: prompt: '%s'\n", __func__, k_prompt.c_str());
-   fprintf(stderr, "%s: tokens: [", __func__);
-   for (const auto & token : k_tokens) {
-      fprintf(stderr, " %d", token);
-   }
-   fprintf(stderr, " ]\n");
-
-   fprintf(stderr, "\n");
-   fprintf(stderr, "%s: listening for a command ...\n", __func__);
-   fprintf(stderr, "\n");
-
-   bool is_running  = true;
-
-   std::vector<float> pcmf32_cur;
-   std::vector<float> pcmf32_prompt;
-
-   // main loop
-   while (is_running) {
-      // handle Ctrl + C
-      {
-         SDL_Event event;
-         while (SDL_PollEvent(&event)) {
-            switch (event.type) {
-               case SDL_QUIT:
-               {
-                  is_running = false;
-               } break;
-               default:
-                  break;
+    fprintf(stderr, "\n");
+    fprintf(stderr, "%s: guided mode\n", __func__);
+
+    std::vector<std::string> allowed_commands = read_allowed_commands(params.commands);
+
+    if (allowed_commands.empty()) {
+        fprintf(stderr, "%s: error: failed to read allowed commands from '%s'\n", __func__, params.commands.c_str());
+        return 2;
+    }
+
+    int max_len = 0;
+
+    std::vector<std::vector<whisper_token>> allowed_tokens;
+
+    for (const auto & cmd : allowed_commands) {
+        whisper_token tokens[1024];
+        allowed_tokens.emplace_back();
+
+        for (int l = 0; l < (int) cmd.size(); ++l) {
+            // NOTE: very important to add the whitespace !
+            //       the reason is that the first decoded token starts with a whitespace too!
+            std::string ss = std::string(" ") + cmd.substr(0, l + 1);
+
+            const int n = whisper_tokenize(ctx, ss.c_str(), tokens, 1024);
+            if (n < 0) {
+                fprintf(stderr, "%s: error: failed to tokenize command '%s'\n", __func__, cmd.c_str());
+                return 3;
             }
-         }
 
-         if (!is_running) {
-            return 0;
-         }
-      }
+            if (n == 1) {
+                allowed_tokens.back().push_back(tokens[0]);
+            }
+        }
 
-      // delay
-      std::this_thread::sleep_for(std::chrono::milliseconds(100));
+        max_len = std::max(max_len, (int) cmd.size());
+    }
 
-      audio.get(2000, pcmf32_cur);
+    fprintf(stderr, "%s: allowed commands [ tokens ]:\n", __func__);
+    fprintf(stderr, "\n");
+    for (int i = 0; i < (int) allowed_commands.size(); ++i) {
+        fprintf(stderr, "  - \033[1m%-*s\033[0m = [", max_len, allowed_commands[i].c_str());
+        for (const auto & token : allowed_tokens[i]) {
+            fprintf(stderr, " %5d", token);
+        }
+        fprintf(stderr, " ]\n");
+    }
+
+    std::string  k_prompt = "select one from the available words: ";
+    for (int i = 0; i < (int) allowed_commands.size(); ++i) {
+        if (i > 0) {
+            k_prompt += ", ";
+        }
+        k_prompt += allowed_commands[i];
+    }
+    k_prompt += ". selected word: ";
+
+    // tokenize prompt
+    std::vector<whisper_token> k_tokens;
+    {
+        k_tokens.resize(1024);
+        const int n = whisper_tokenize(ctx, k_prompt.c_str(), k_tokens.data(), 1024);
+        if (n < 0) {
+            fprintf(stderr, "%s: error: failed to tokenize prompt '%s'\n", __func__, k_prompt.c_str());
+            return 4;
+        }
+        k_tokens.resize(n);
+    }
+
+    fprintf(stderr, "\n");
+    fprintf(stderr, "%s: prompt: '%s'\n", __func__, k_prompt.c_str());
+    fprintf(stderr, "%s: tokens: [", __func__);
+    for (const auto & token : k_tokens) {
+        fprintf(stderr, " %d", token);
+    }
+    fprintf(stderr, " ]\n");
+
+    fprintf(stderr, "\n");
+    fprintf(stderr, "%s: listening for a command ...\n", __func__);
+    fprintf(stderr, "\n");
+
+    bool is_running  = true;
 
-      if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
-         fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
+    std::vector<float> pcmf32_cur;
+    std::vector<float> pcmf32_prompt;
 
-         const auto t_start = std::chrono::high_resolution_clock::now();
+    // main loop
+    while (is_running) {
+        // handle Ctrl + C
+        is_running = process_sdl_events();
 
-         whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
+        // delay
+        std::this_thread::sleep_for(std::chrono::milliseconds(100));
+
+        audio.get(2000, pcmf32_cur);
+
+        if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
+            fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
 
-         wparams.print_progress   = false;
-         wparams.print_special    = params.print_special;
-         wparams.print_realtime   = false;
-         wparams.print_timestamps = !params.no_timestamps;
-         wparams.translate        = params.translate;
-         wparams.no_context       = true;
-         wparams.single_segment   = true;
-         wparams.max_tokens       = 1;
-         wparams.language         = params.language.c_str();
-         wparams.n_threads        = params.n_threads;
+            const auto t_start = std::chrono::high_resolution_clock::now();
 
-         wparams.audio_ctx        = params.audio_ctx;
-         wparams.speed_up         = params.speed_up;
+            whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
 
-         wparams.prompt_tokens    = k_tokens.data();
-         wparams.prompt_n_tokens  = k_tokens.size();
+            wparams.print_progress   = false;
+            wparams.print_special    = params.print_special;
+            wparams.print_realtime   = false;
+            wparams.print_timestamps = !params.no_timestamps;
+            wparams.translate        = params.translate;
+            wparams.no_context       = true;
+            wparams.single_segment   = true;
+            wparams.max_tokens       = 1;
+            wparams.language         = params.language.c_str();
+            wparams.n_threads        = params.n_threads;
 
-         // run the transformer and a single decoding pass
-         if (whisper_full(ctx, wparams, pcmf32_cur.data(), pcmf32_cur.size()) != 0) {
-            fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__);
-            break;
-         }
+            wparams.audio_ctx        = params.audio_ctx;
+            wparams.speed_up         = params.speed_up;
 
-         const auto * probs = whisper_get_probs(ctx);
-         std::vector<std::pair<float, int>> probs_id;
+            wparams.prompt_tokens    = k_tokens.data();
+            wparams.prompt_n_tokens  = k_tokens.size();
 
-         double psum = 0.0;
-         for (int i = 0; i < (int) allowed_commands.size(); ++i) {
-            probs_id.emplace_back(probs[allowed_tokens[i][0]], i);
-            for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
-               probs_id.back().first += probs[allowed_tokens[i][j]];
+            // run the transformer and a single decoding pass
+            if (whisper_full(ctx, wparams, pcmf32_cur.data(), pcmf32_cur.size()) != 0) {
+                fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__);
+                break;
             }
-            probs_id.back().first /= allowed_tokens[i].size();
-            psum += probs_id.back().first;
-         }
-
-         // normalize
-         for (auto & p : probs_id) {
-            p.first /= psum;
-         }
-
-         // sort descending
-         {
-            using pair_type = decltype(probs_id)::value_type;
-            std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
-               return a.first > b.first;
-            });
-         }
-
-         // print the commands and the respective probabilities
-         {
-            fprintf(stdout, "\n");
-            for (const auto & cmd : probs_id) {
-               fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first);
-               for (int token : allowed_tokens[cmd.second]) {
-                  fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, token), probs[token]);
-               }
-               fprintf(stdout, "\n");
+
+            const auto * probs = whisper_get_probs(ctx);
+            std::vector<std::pair<float, int>> probs_id;
+
+            double psum = 0.0;
+            for (int i = 0; i < (int) allowed_commands.size(); ++i) {
+                probs_id.emplace_back(probs[allowed_tokens[i][0]], i);
+                for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
+                    probs_id.back().first += probs[allowed_tokens[i][j]];
+                }
+                probs_id.back().first /= allowed_tokens[i].size();
+                psum += probs_id.back().first;
             }
-         }
 
-         // best command
-         {
-            const auto t_end = std::chrono::high_resolution_clock::now();
+            // normalize
+            for (auto & p : probs_id) {
+                p.first /= psum;
+            }
 
-            const float prob = probs_id[0].first;
-            const int index = probs_id[0].second;
+            // sort descending
+            {
+                using pair_type = decltype(probs_id)::value_type;
+                std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
+                    return a.first > b.first;
+                });
+            }
 
-            fprintf(stdout, "\n");
-            fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__,
-                    "\033[1m", allowed_commands[index].c_str(), "\033[0m", prob,
-                  (int) std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count());
-            fprintf(stdout, "\n");
-         }
+            // print the commands and the respective probabilities
+            {
+                fprintf(stdout, "\n");
+                for (const auto & cmd : probs_id) {
+                    fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first);
+                    for (int token : allowed_tokens[cmd.second]) {
+                        fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, token), probs[token]);
+                    }
+                    fprintf(stdout, "\n");
+                }
+            }
 
-         audio.clear();
-      }
-   }
+            // best command
+            {
+                const auto t_end = std::chrono::high_resolution_clock::now();
 
-   return 0;
-}
+                const float prob = probs_id[0].first;
+                const int index = probs_id[0].second;
 
-// general-purpose mode
-// freely transcribe the voice into text
-int process_general_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) {
-   bool is_running  = true;
-   bool have_prompt = false;
-   bool ask_prompt  = true;
-
-   float prob0 = 0.0f;
-   float prob  = 0.0f;
-
-   std::vector<float> pcmf32_cur;
-   std::vector<float> pcmf32_prompt;
-
-   const std::string k_prompt = "Ok Whisper, start listening for commands.";
-
-   fprintf(stderr, "\n");
-   fprintf(stderr, "%s: general-purpose mode\n", __func__);
-
-   // main loop
-   while (is_running) {
-      // handle Ctrl + C
-      {
-         SDL_Event event;
-         while (SDL_PollEvent(&event)) {
-            switch (event.type) {
-               case SDL_QUIT:
-               {
-                  is_running = false;
-               } break;
-               default:
-                  break;
+                fprintf(stdout, "\n");
+                fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__,
+                        "\033[1m", allowed_commands[index].c_str(), "\033[0m", prob,
+                        (int) std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count());
+                fprintf(stdout, "\n");
             }
-         }
 
-         if (!is_running) {
-            return 0;
-         }
-      }
+            audio.clear();
+        }
+    }
+
+    return 0;
+}
 
-      // delay
-      std::this_thread::sleep_for(std::chrono::milliseconds(100));
+// always-prompt mode
+// transcribe the voice into text after valid prompt
+int always_prompt_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params) {
+    bool is_running = true;
+    bool ask_prompt = true;
 
-      if (ask_prompt) {
-         fprintf(stdout, "\n");
-         fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
-         fprintf(stdout, "\n");
+    float prob = 0.0f;
 
-         ask_prompt = false;
-      }
+    std::vector<float> pcmf32_cur;
 
-      {
-         audio.get(2000, pcmf32_cur);
+    const std::string k_prompt = params.prompt;
 
-         if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
-            fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
+    const int k_prompt_length = get_words(k_prompt).size();
 
-            int64_t t_ms = 0;
+    fprintf(stderr, "\n");
+    fprintf(stderr, "%s: always-prompt mode\n", __func__);
+
+    // main loop
+    while (is_running) {
+        // handle Ctrl + C
+        is_running = process_sdl_events();
 
-            if (!have_prompt) {
-               // wait for activation phrase
-               audio.get(params.prompt_ms, pcmf32_cur);
+        // delay
+        std::this_thread::sleep_for(std::chrono::milliseconds(100));
 
-               const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms));
+        if (ask_prompt) {
+            fprintf(stdout, "\n");
+            fprintf(stdout, "%s: The prompt is: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
+            fprintf(stdout, "\n");
 
-               fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
+            ask_prompt = false;
+        }
 
-               const float sim = similarity(txt, k_prompt);
+        {
+            audio.get(2000, pcmf32_cur);
 
-               if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) {
-                  fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__);
-                  ask_prompt = true;
-               } else {
-                  fprintf(stdout, "\n");
-                  fprintf(stdout, "%s: The prompt has been recognized!\n", __func__);
-                  fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__);
-                  fprintf(stdout, "\n");
+            if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
+                fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
 
-                  // save the audio for the prompt
-                  pcmf32_prompt = pcmf32_cur;
-                  have_prompt = true;
-               }
-            } else {
-               // we have heard the activation phrase, now detect the commands
-               audio.get(params.command_ms, pcmf32_cur);
+                int64_t t_ms = 0;
 
-               // prepend the prompt audio
-               pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
+                // detect the commands
+                audio.get(params.command_ms, pcmf32_cur);
 
-               const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
+                const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
 
-               prob = 100.0f*(prob - prob0);
+                const auto words = get_words(txt);
 
-               //fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
+                std::string prompt;
+                std::string command;
 
-               // find the prompt in the text
-               float best_sim = 0.0f;
-               size_t best_len = 0;
-               for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
-                  const auto prompt = txt.substr(0, n);
+                for (int i = 0; i < words.size(); ++i) {
+                    if (i < k_prompt_length) {
+                        prompt += words[i] + " ";
+                    } else {
+                        command += words[i] + " ";
+                    }
+                }
 
-                  const float sim = similarity(prompt, k_prompt);
+                const float sim = similarity(prompt, k_prompt);
 
-                  //fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim);
+                //debug
+                //fprintf(stdout, "command size: %i\n", command_length);
 
-                  if (sim > best_sim) {
-                     best_sim = sim;
-                     best_len = n;
-                  }
-               }
+                if ((sim > 0.7f) && (command.size() > 0)) {
+                    fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
+                }
 
-               const std::string command = ::trim(txt.substr(best_len));
+                fprintf(stdout, "\n");
 
-               fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
-               fprintf(stdout, "\n");
+                audio.clear();
             }
+        }
+    }
 
-            audio.clear();
-         }
-      }
-   }
-
-   return 0;
+    return 0;
 }
 
+// general-purpose mode
+// freely transcribe the voice into text
+int process_general_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) {
+    bool is_running  = true;
+    bool have_prompt = false;
+    bool ask_prompt  = true;
 
-// always prompt mode
-// transcribe the voice into text after valid prompt
-int always_prompt_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) {
-   bool is_running  = true;
-   bool ask_prompt  = true;
+    float prob0 = 0.0f;
+    float prob  = 0.0f;
 
-   float prob  = 0.0f;
+    std::vector<float> pcmf32_cur;
+    std::vector<float> pcmf32_prompt;
 
-   std::vector<float> pcmf32_cur;
+    const std::string k_prompt = "Ok Whisper, start listening for commands.";
 
-   const std::string k_prompt = params.prompt;
+    fprintf(stderr, "\n");
+    fprintf(stderr, "%s: general-purpose mode\n", __func__);
 
-   std::vector<std::string> words;
+    // main loop
+    while (is_running) {
+        // handle Ctrl + C
+        is_running = process_sdl_events();
 
-   std::istringstream iss(k_prompt);
-   std::string word;
+        // delay
+        std::this_thread::sleep_for(std::chrono::milliseconds(100));
 
-   while (iss >> word) {
-       words.push_back(word);
-   }
+        if (ask_prompt) {
+            fprintf(stdout, "\n");
+            fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
+            fprintf(stdout, "\n");
 
-   int k_prompt_length = words.size();
+            ask_prompt = false;
+        }
 
-   // main loop
-   while (is_running) {
-      // handle Ctrl + C
-      {
-         SDL_Event event;
-         while (SDL_PollEvent(&event)) {
-            switch (event.type) {
-               case SDL_QUIT:
-               {
-                  is_running = false;
-               } break;
-               default:
-                  break;
-            }
-         }
+        {
+            audio.get(2000, pcmf32_cur);
 
-         if (!is_running) {
-            return 0;
-         }
-      }
+            if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
+                fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
 
-      // delay
-      std::this_thread::sleep_for(std::chrono::milliseconds(100));
+                int64_t t_ms = 0;
 
-      if (ask_prompt) {
-         fprintf(stdout, "\n");
-         fprintf(stdout, "%s: The prompt is: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
-         fprintf(stdout, "\n");
+                if (!have_prompt) {
+                    // wait for activation phrase
+                    audio.get(params.prompt_ms, pcmf32_cur);
 
-         ask_prompt = false;
-      }
+                    const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms));
 
-      {
-         audio.get(2000, pcmf32_cur);
+                    fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
 
-         if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
-            fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
+                    const float sim = similarity(txt, k_prompt);
 
-            int64_t t_ms = 0;
+                    if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) {
+                        fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__);
+                        ask_prompt = true;
+                    } else {
+                        fprintf(stdout, "\n");
+                        fprintf(stdout, "%s: The prompt has been recognized!\n", __func__);
+                        fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__);
+                        fprintf(stdout, "\n");
 
-            // detect the commands
-            audio.get(params.command_ms, pcmf32_cur);
+                        // save the audio for the prompt
+                        pcmf32_prompt = pcmf32_cur;
+                        have_prompt = true;
+                    }
+                } else {
+                    // we have heard the activation phrase, now detect the commands
+                    audio.get(params.command_ms, pcmf32_cur);
 
-            const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
+                    // prepend the prompt audio
+                    pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
 
-            std::istringstream iss(txt);
-            std::string word;
-            std::string prompt;
-            std::string command;
-            int i = 0;
-            int command_length = 0;
-            while (iss >> word) {
-                if (i == k_prompt_length - 1) {
-                    prompt += word + ' ';
-                    break;
-                }
-                prompt += word + ' ';
-                i++;
-            }
-            while (iss >> word) {
-             command += word + ' ';
-             command_length++;
-            }
+                    const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
 
-            const float sim = similarity(prompt, k_prompt);
+                    prob = 100.0f*(prob - prob0);
 
-            //debug
-            //fprintf(stdout, "command size: %i\n", command_length); 
+                    //fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
 
+                    // find the prompt in the text
+                    float best_sim = 0.0f;
+                    size_t best_len = 0;
+                    for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
+                        const auto prompt = txt.substr(0, n);
 
-            if ((sim > 0.7f) && (command_length >0)){
-                fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
-            }
+                        const float sim = similarity(prompt, k_prompt);
 
-            fprintf(stdout, "\n");
+                        //fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim);
 
+                        if (sim > best_sim) {
+                            best_sim = sim;
+                            best_len = n;
+                        }
+                    }
 
-            audio.clear();
-         }
-      }
-   }
+                    const std::string command = ::trim(txt.substr(best_len));
+
+                    fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
+                    fprintf(stdout, "\n");
+                }
+
+                audio.clear();
+            }
+        }
+    }
 
-   return 0;
+    return 0;
 }
 
 int main(int argc, char ** argv) {
@@ -1005,11 +970,11 @@ int main(int argc, char ** argv) {
     int  ret_val = 0;
 
     if (!params.commands.empty()) {
-       ret_val = process_command_list(ctx, audio, params);
+        ret_val = process_command_list(ctx, audio, params);
     } else if (!params.prompt.empty()) {
-       ret_val = always_prompt_transcription(ctx, audio, params);
+        ret_val = always_prompt_transcription(ctx, audio, params);
     } else {
-       ret_val = process_general_transcription(ctx, audio, params);
+        ret_val = process_general_transcription(ctx, audio, params);
     }
 
     audio.pause();