]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : add detect-language mode (#853)
authorCRD716 <redacted>
Tue, 2 May 2023 16:51:52 +0000 (11:51 -0500)
committerGitHub <redacted>
Tue, 2 May 2023 16:51:52 +0000 (19:51 +0300)
* add detectlanguage flag

* renaming and help

* no idea why that last one didn't commit

* run language detection if dl is set

* help message fix

* various fixes

* fix quitting

* fix language being english on print

examples/main/main.cpp
whisper.cpp
whisper.h

index 3e8c5aaa1bb02ba5810c1a2260b2c371e6441e48..c6bf32ed8b81f2e508e93d9fd56e73b6cfe3fd4a 100644 (file)
@@ -66,6 +66,7 @@ struct whisper_params {
 
     bool speed_up       = false;
     bool translate      = false;
+    bool detect_language= false;
     bool diarize        = false;
     bool split_on_word  = false;
     bool no_fallback    = false;
@@ -141,6 +142,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-pp"   || arg == "--print-progress") { params.print_progress = true; }
         else if (arg == "-nt"   || arg == "--no-timestamps")  { params.no_timestamps  = true; }
         else if (arg == "-l"    || arg == "--language")       { params.language       = argv[++i]; }
+        else if (arg == "-dl"   || arg == "--detect-language"){ params.detect_language= true; }
         else if (                  arg == "--prompt")         { params.prompt         = argv[++i]; }
         else if (arg == "-m"    || arg == "--model")          { params.model          = argv[++i]; }
         else if (arg == "-f"    || arg == "--file")           { params.fname_inp.emplace_back(argv[++i]); }
@@ -191,6 +193,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -pp,       --print-progress    [%-7s] print progress\n",                                 params.print_progress ? "true" : "false");
     fprintf(stderr, "  -nt,       --no-timestamps     [%-7s] do not print timestamps\n",                        params.no_timestamps ? "false" : "true");
     fprintf(stderr, "  -l LANG,   --language LANG     [%-7s] spoken language ('auto' for auto-detect)\n",       params.language.c_str());
+    fprintf(stderr, "  -dl,       --detect-language   [%-7s] exit after automatically detecting language\n",    params.detect_language ? "true" : "false");
     fprintf(stderr, "             --prompt PROMPT     [%-7s] initial prompt\n",                                 params.prompt.c_str());
     fprintf(stderr, "  -m FNAME,  --model FNAME       [%-7s] model path\n",                                     params.model.c_str());
     fprintf(stderr, "  -f FNAME,  --file FNAME        [%-7s] input WAV file path\n",                            "");
@@ -739,6 +742,9 @@ int main(int argc, char ** argv) {
                     fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
                 }
             }
+            if (params.detect_language) {
+                params.language = "auto";
+            }
             fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d ...\n",
                     __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
                     params.n_threads, params.n_processors,
@@ -761,6 +767,7 @@ int main(int argc, char ** argv) {
             wparams.print_special    = params.print_special;
             wparams.translate        = params.translate;
             wparams.language         = params.language.c_str();
+            wparams.detect_language  = params.detect_language;
             wparams.n_threads        = params.n_threads;
             wparams.n_max_text_ctx   = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
             wparams.offset_ms        = params.offset_t_ms;
index df283ec9019e2d687d7fc3d0a47df0ca5e96a970..158aa0b9881bb8a1ab249e646f7f87932c926831 100644 (file)
@@ -3312,6 +3312,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         /*.prompt_n_tokens  =*/ 0,
 
         /*.language         =*/ "en",
+        /*.detect_language  =*/ false,
 
         /*.suppress_blank   =*/ true,
         /*.suppress_non_speech_tokens =*/ false,
@@ -3898,7 +3899,7 @@ int whisper_full_with_state(
     }
 
     // auto-detect language if not specified
-    if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) {
+    if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0 || params.detect_language) {
         std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
 
         const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data());
@@ -3910,6 +3911,9 @@ int whisper_full_with_state(
         params.language = whisper_lang_str(lang_id);
 
         fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
+        if (params.detect_language) {
+            return 0;
+        }
     }
 
     if (params.token_timestamps) {
index 3d689a4c924ce3f2c3d35c18f940465d59b39c87..2d5b3eb98579811e86e531d8bbab9b2d75403d9e 100644 (file)
--- a/whisper.h
+++ b/whisper.h
@@ -365,6 +365,7 @@ extern "C" {
 
         // for auto-detection, set to nullptr, "" or "auto"
         const char * language;
+        bool detect_language;
 
         // common decoding parameters:
         bool suppress_blank;    // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89