]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : language auto-detect (#59)
authorGeorgi Gerganov <redacted>
Sat, 17 Dec 2022 15:58:08 +0000 (17:58 +0200)
committerGeorgi Gerganov <redacted>
Sat, 17 Dec 2022 16:49:44 +0000 (18:49 +0200)
examples/main/main.cpp
whisper.cpp
whisper.h

index 042f8928d6955a0b3b6c23858725db87b1f319f1..9d889ebb3c0148478eb6b58af43c403af4bab31e 100644 (file)
@@ -154,7 +154,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
     fprintf(stderr, "  -pc,      --print-colors   [%-7s] print colors\n",                                   params.print_colors ? "true" : "false");
     fprintf(stderr, "  -pp,      --print-progress [%-7s] print progress\n",                                 params.print_progress ? "true" : "false");
     fprintf(stderr, "  -nt,      --no-timestamps  [%-7s] do not print timestamps\n",                        params.no_timestamps ? "false" : "true");
-    fprintf(stderr, "  -l LANG,  --language LANG  [%-7s] spoken language\n",                                params.language.c_str());
+    fprintf(stderr, "  -l LANG,  --language LANG  [%-7s] spoken language ('auto' for auto-detect)\n",       params.language.c_str());
     fprintf(stderr, "            --prompt PROMPT  [%-7s] initial prompt\n",                                 params.prompt.c_str());
     fprintf(stderr, "  -m FNAME, --model FNAME    [%-7s] model path\n",                                     params.model.c_str());
     fprintf(stderr, "  -f FNAME, --file FNAME     [%-7s] input WAV file path\n",                            "");
@@ -453,7 +453,7 @@ int main(int argc, char ** argv) {
         return 2;
     }
 
-    if (whisper_lang_id(params.language.c_str()) == -1) {
+    if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
         fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
         whisper_print_usage(argc, argv, params);
         exit(0);
index 0aca60ccbb9e59d52d3b825c5e68c43964c256b2..95bbcdde3fade4ef4184bb6e8c4ca90b57dd773a 100644 (file)
@@ -1105,7 +1105,7 @@ static bool whisper_encode(
 
     struct ggml_init_params params;
     params.mem_size   = wctx.buf_compute.size();
-    params.mem_buffer = wctx.buf_compute.data();   
+    params.mem_buffer = wctx.buf_compute.data();
 
     struct ggml_context * ctx0 = ggml_init(params);
 
@@ -2372,8 +2372,23 @@ int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_to
     return res.size();
 }
 
+int whisper_lang_max_id() {
+    auto max_id = 0;
+    for (const auto & kv : g_lang) {
+        max_id = std::max(max_id, kv.second.first);
+    }
+
+    return max_id;
+}
+
 int whisper_lang_id(const char * lang) {
     if (!g_lang.count(lang)) {
+        for (const auto & kv : g_lang) {
+            if (kv.second.second == lang) {
+                return kv.second.first;
+            }
+        }
+
         fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang);
         return -1;
     }
@@ -2381,6 +2396,86 @@ int whisper_lang_id(const char * lang) {
     return g_lang.at(lang).first;
 }
 
+const char * whisper_lang_str(int id) {
+    for (const auto & kv : g_lang) {
+        if (kv.second.first == id) {
+            return kv.first.c_str();
+        }
+    }
+
+    fprintf(stderr, "%s: unknown language id %d\n", __func__, id);
+    return NULL;
+}
+
+int whisper_lang_auto_detect(
+        struct whisper_context * ctx,
+        int offset_ms,
+        int n_threads,
+        float * lang_probs) {
+    const int seek = offset_ms/10;
+
+    if (seek < 0) {
+        fprintf(stderr, "%s: offset %dms is before the start of the audio\n", __func__, offset_ms);
+        return -1;
+    }
+
+    if (seek >= ctx->mel.n_len) {
+        fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, ctx->mel.n_len*10);
+        return -2;
+    }
+
+    // run the encoder
+    if (whisper_encode(ctx, seek, n_threads) != 0) {
+        fprintf(stderr, "%s: failed to encode\n", __func__);
+        return -6;
+    }
+
+    const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };
+
+    if (whisper_decode(ctx, prompt.data(), prompt.size(), 0, n_threads) != 0) {
+        fprintf(stderr, "%s: failed to decode\n", __func__);
+        return -7;
+    }
+
+    std::vector<std::pair<float, int>> probs_id;
+    for (const auto kv : g_lang) {
+        const auto token_lang = whisper_token_lang(ctx, kv.second.first);
+        probs_id.push_back({ ctx->probs[token_lang], kv.second.first });
+    }
+
+    // sort descending
+    {
+        using pair_type = decltype(probs_id)::value_type;
+        std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
+            return a.first > b.first;
+        });
+    }
+
+    // softmax
+    {
+        float sum = 0;
+        for (const auto & kv : probs_id) {
+            sum += exp(kv.first);
+        }
+
+        for (auto & kv : probs_id) {
+            kv.first = exp(kv.first) / sum;
+        }
+    }
+
+    {
+        for (int i = 0; i < probs_id.size(); i++) {
+            if (lang_probs) {
+                lang_probs[probs_id[i].second] = probs_id[i].first;
+            }
+
+            //printf("%s: lang %2d (%3s): %f\n", __func__, probs_id[i].second, whisper_lang_str(probs_id[i].second), probs_id[i].first);
+        }
+    }
+
+    return probs_id[0].second;
+}
+
 int whisper_n_len(struct whisper_context * ctx) {
     return ctx->mel.n_len;
 }
@@ -2429,6 +2524,10 @@ whisper_token whisper_token_beg(struct whisper_context * ctx) {
     return ctx->vocab.token_beg;
 }
 
+whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id) {
+    return whisper_token_sot(ctx) + 1 + lang_id;
+}
+
 whisper_token whisper_token_translate(void) {
     return whisper_vocab::token_translate;
 }
@@ -2661,10 +2760,25 @@ int whisper_full(
     } else {
         if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
             fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
-            return -1;
+            return -2;
         }
     }
 
+    // auto-detect language if not specified
+    if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) {
+        std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
+
+        const auto lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, probs.data());
+        if (lang_id < 0) {
+            fprintf(stderr, "%s: failed to auto-detect language\n", __func__);
+            return -3;
+        }
+
+        params.language = whisper_lang_str(lang_id);
+
+        fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
+    }
+
     if (params.token_timestamps) {
         ctx->t_beg = 0;
         ctx->t_last = 0;
@@ -2703,7 +2817,8 @@ int whisper_full(
     // these tokens determine the task that will be performed
     std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
     if (whisper_is_multilingual(ctx)) {
-        prompt_init.push_back(whisper_token_sot(ctx) + 1 + whisper_lang_id(params.language));
+        const int lang_id = whisper_lang_id(params.language);
+        prompt_init.push_back(whisper_token_lang(ctx, lang_id));
         if (params.translate) {
             prompt_init.push_back(whisper_token_translate());
         } else {
@@ -2752,7 +2867,7 @@ int whisper_full(
         // encode audio features starting at offset seek
         if (whisper_encode(ctx, seek, params.n_threads) != 0) {
             fprintf(stderr, "%s: failed to encode\n", __func__);
-            return 7;
+            return -4;
         }
 
         int n_past = 0;
@@ -2790,7 +2905,7 @@ int whisper_full(
         for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
             if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
                 fprintf(stderr, "%s: failed to decode\n", __func__);
-                return 8;
+                return -5;
             }
 
             n_past += prompt.size();
index a28a3b7e70408627139e2cbeb1988ab4d3dcf593..e2657c1b5c3337bf7d3eafe3c848f4eb8fdb7097 100644 (file)
--- a/whisper.h
+++ b/whisper.h
@@ -150,9 +150,30 @@ extern "C" {
                      whisper_token * tokens,
                                int   n_max_tokens);
 
+    // Largest language id (i.e. number of available languages - 1)
+    WHISPER_API int whisper_lang_max_id();
+
     // Return the id of the specified language, returns -1 if not found
+    // Examples:
+    //   "de" -> 2
+    //   "german" -> 2
     WHISPER_API int whisper_lang_id(const char * lang);
 
+    // Return the short string of the specified language id (e.g. 2 -> "de"), returns nullptr if not found
+    WHISPER_API const char * whisper_lang_str(int id);
+
+    // Use mel data at offset_ms to try and auto-detect the spoken language
+    // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first
+    // Returns the top language id or negative on failure
+    // If not null, fills the lang_probs array with the probabilities of all languages
+    // The array must be whispe_lang_max_id() + 1 in size
+    // ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69
+    WHISPER_API int whisper_lang_auto_detect(
+            struct whisper_context * ctx,
+                               int   offset_ms,
+                               int   n_threads,
+                             float * lang_probs);
+
     WHISPER_API int whisper_n_len          (struct whisper_context * ctx); // mel length
     WHISPER_API int whisper_n_vocab        (struct whisper_context * ctx);
     WHISPER_API int whisper_n_text_ctx     (struct whisper_context * ctx);
@@ -171,6 +192,7 @@ extern "C" {
     WHISPER_API whisper_token whisper_token_solm(struct whisper_context * ctx);
     WHISPER_API whisper_token whisper_token_not (struct whisper_context * ctx);
     WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx);
+    WHISPER_API whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id);
 
     // Task tokens
     WHISPER_API whisper_token whisper_token_translate (void);
@@ -236,6 +258,7 @@ extern "C" {
         const whisper_token * prompt_tokens;
         int prompt_n_tokens;
 
+        // for auto-detection, set to nullptr, "" or "auto"
         const char * language;
 
         struct {