]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : add whisper_full_lang_id() for getting the context lang (#461)
authorkamranjon <redacted>
Sun, 5 Feb 2023 12:46:26 +0000 (04:46 -0800)
committerGitHub <redacted>
Sun, 5 Feb 2023 12:46:26 +0000 (14:46 +0200)
whisper.cpp
whisper.h

index 1a4a207157b5a9c5b119437582f2cff7956fe877..aedd343ae6fc65f419506fe2b4ebef27b3e5213b 100644 (file)
@@ -592,6 +592,8 @@ struct whisper_context {
 
     mutable std::mt19937 rng; // used for sampling at t > 0.0
 
+    int lang_id;
+
     // [EXPERIMENTAL] token-level timestamps data
     int64_t t_beg;
     int64_t t_last;
@@ -3478,7 +3480,7 @@ int whisper_full(
             fprintf(stderr, "%s: failed to auto-detect language\n", __func__);
             return -3;
         }
-
+        ctx->lang_id = lang_id;
         params.language = whisper_lang_str(lang_id);
 
         fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
@@ -3575,6 +3577,7 @@ int whisper_full(
     std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
     if (whisper_is_multilingual(ctx)) {
         const int lang_id = whisper_lang_id(params.language);
+        ctx->lang_id = lang_id;
         prompt_init.push_back(whisper_token_lang(ctx, lang_id));
         if (params.translate) {
             prompt_init.push_back(whisper_token_translate());
@@ -4295,6 +4298,10 @@ int whisper_full_n_segments(struct whisper_context * ctx) {
     return ctx->result_all.size();
 }
 
+int whisper_full_lang_id(struct whisper_context * ctx) {
+    return ctx->lang_id; 
+}
+
 int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
     return ctx->result_all[i_segment].t0;
 }
index 3a426680d2c7b494c6d9cadfe897d30d9dbe258e..72331e6abd4d1a725e660f91ce4d647a4e1cd02e 100644 (file)
--- a/whisper.h
+++ b/whisper.h
@@ -330,6 +330,9 @@ extern "C" {
     // A segment can be a few words, a sentence, or even a paragraph.
     WHISPER_API int whisper_full_n_segments(struct whisper_context * ctx);
 
+    // Language id associated with the current context
+    WHISPER_API int whisper_full_lang_id(struct whisper_context * ctx);
+
     // Get the start and end time of the specified segment.
     WHISPER_API int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment);
     WHISPER_API int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment);