]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
talk-llama : optional wake-up command and audio confirmation (#1765)
authorBenjamin Heiniger <redacted>
Tue, 16 Jan 2024 13:52:01 +0000 (14:52 +0100)
committerGitHub <redacted>
Tue, 16 Jan 2024 13:52:01 +0000 (15:52 +0200)
* talk-llama: add optional wake-word detection from command

* talk-llama: add optional audio confirmation before generating answer

* talk-llama: fix small formatting issue in output

* talk-llama.cpp: fix Windows build

examples/talk-llama/talk-llama.cpp

index 5eef1f4e619795f7a38d84b0d416e001ed3cf463..d418d0c32fc605e8c4905873c343d8d60c91fc72 100644 (file)
@@ -14,6 +14,7 @@
 #include <thread>
 #include <vector>
 #include <regex>
+#include <sstream>
 
 std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) {
     auto * model = llama_get_model(ctx);
@@ -68,6 +69,8 @@ struct whisper_params {
 
     std::string person      = "Georgi";
     std::string bot_name    = "LLaMA";
+    std::string wake_cmd    = "";
+    std::string heard_ok    = "";
     std::string language    = "en";
     std::string model_wsp   = "models/ggml-base.en.bin";
     std::string model_llama = "models/ggml-llama-7B.bin";
@@ -104,6 +107,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-p"   || arg == "--person")         { params.person         = argv[++i]; }
         else if (arg == "-bn"   || arg == "--bot-name")      { params.bot_name       = argv[++i]; }
         else if (arg == "--session")                         { params.path_session   = argv[++i]; }
+        else if (arg == "-w"   || arg == "--wake-command")   { params.wake_cmd       = argv[++i]; }
+        else if (arg == "-ho"  || arg == "--heard-ok")       { params.heard_ok       = argv[++i]; }
         else if (arg == "-l"   || arg == "--language")       { params.language       = argv[++i]; }
         else if (arg == "-mw"  || arg == "--model-whisper")  { params.model_wsp      = argv[++i]; }
         else if (arg == "-ml"  || arg == "--model-llama")    { params.model_llama    = argv[++i]; }
@@ -149,6 +154,8 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -ng,      --no-gpu         [%-7s] disable GPU\n",                                 params.use_gpu ? "false" : "true");
     fprintf(stderr, "  -p NAME,  --person NAME    [%-7s] person name (for prompt selection)\n",          params.person.c_str());
     fprintf(stderr, "  -bn NAME, --bot-name NAME  [%-7s] bot name (to display)\n",                       params.bot_name.c_str());
+    fprintf(stderr, "  -w TEXT,  --wake-command T [%-7s] wake-up command to listen for\n",               params.wake_cmd.c_str());
+    fprintf(stderr, "  -ho TEXT, --heard-ok TEXT  [%-7s] said by TTS before generating reply\n",         params.heard_ok.c_str());
     fprintf(stderr, "  -l LANG,  --language LANG  [%-7s] spoken language\n",                             params.language.c_str());
     fprintf(stderr, "  -mw FILE, --model-whisper  [%-7s] whisper model file\n",                          params.model_wsp.c_str());
     fprintf(stderr, "  -ml FILE, --model-llama    [%-7s] llama model file\n",                            params.model_llama.c_str());
@@ -227,6 +234,18 @@ std::string transcribe(
     return result;
 }
 
+std::vector<std::string> get_words(const std::string &txt) {
+    std::vector<std::string> words;
+
+    std::istringstream iss(txt);
+    std::string word;
+    while (iss >> word) {
+        words.push_back(word);
+    }
+
+    return words;
+}
+
 const std::string k_prompt_whisper = R"(A conversation with a person called {1}.)";
 
 const std::string k_prompt_llama = R"(Text transcript of a never ending dialog, where {0} interacts with an AI assistant named {1}.
@@ -441,6 +460,16 @@ int main(int argc, char ** argv) {
     bool need_to_save_session = !path_session.empty() && n_matching_session_tokens < (embd_inp.size() * 3 / 4);
 
     printf("%s : done! start speaking in the microphone\n", __func__);
+
+    // show wake command if enabled
+    const std::string wake_cmd = params.wake_cmd;
+    const int wake_cmd_length = get_words(wake_cmd).size();
+    const bool use_wake_cmd = wake_cmd_length > 0;
+
+    if (use_wake_cmd) {
+        printf("%s : the wake-up command is: '%s%s%s'\n", __func__, "\033[1m", wake_cmd.c_str(), "\033[0m");
+    }
+
     printf("\n");
     printf("%s%s", params.person.c_str(), chat_symb.c_str());
     fflush(stdout);
@@ -486,10 +515,41 @@ int main(int argc, char ** argv) {
 
                 audio.get(params.voice_ms, pcmf32_cur);
 
-                std::string text_heard;
+                std::string all_heard;
 
                 if (!force_speak) {
-                    text_heard = ::trim(::transcribe(ctx_wsp, params, pcmf32_cur, prompt_whisper, prob0, t_ms));
+                    all_heard = ::trim(::transcribe(ctx_wsp, params, pcmf32_cur, prompt_whisper, prob0, t_ms));
+                }
+
+                const auto words = get_words(all_heard);
+
+                std::string wake_cmd_heard;
+                std::string text_heard;
+
+                for (int i = 0; i < (int) words.size(); ++i) {
+                    if (i < wake_cmd_length) {
+                        wake_cmd_heard += words[i] + " ";
+                    } else {
+                        text_heard += words[i] + " ";
+                    }
+                }
+
+                // check if audio starts with the wake-up command if enabled
+                if (use_wake_cmd) {
+                    const float sim = similarity(wake_cmd_heard, wake_cmd);
+
+                    if ((sim < 0.7f) || (text_heard.empty())) {
+                        audio.clear();
+                        continue;
+                    }
+                }
+
+                // optionally give audio feedback that the current text is being processed
+                if (!params.heard_ok.empty()) {
+                    int ret = system((params.speak + " " + std::to_string(voice_id) + " '" + params.heard_ok + "'").c_str());
+                    if (ret != 0) {
+                        fprintf(stderr, "%s: failed to speak\n", __func__);
+                    }
                 }
 
                 // remove text between brackets using regex