]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
whisper : sync whisper.cpp (tinydiarize + OpenVINO)
authorGeorgi Gerganov <redacted>
Tue, 4 Jul 2023 17:24:22 +0000 (20:24 +0300)
committerGeorgi Gerganov <redacted>
Tue, 4 Jul 2023 17:24:22 +0000 (20:24 +0300)
examples/common.h
examples/whisper/main.cpp
examples/whisper/whisper.cpp
examples/whisper/whisper.h

index 74655cbfc6c9a1b2aa986421477c06c2d0e6c0a4..f9740a3c3d8e1c8fb989a74e184f5a0c2b40060c 100644 (file)
 //
 
 struct gpt_params {
-    int32_t seed      = -1; // RNG seed
+    int32_t seed      = -1;  // RNG seed
     int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
     int32_t n_predict = 200; // new tokens to predict
+    int32_t n_batch   = 8;   // batch size for prompt processing
 
     // sampling parameters
-    int32_t top_k = 40;
-    float   top_p = 0.9f;
-    float   temp  = 0.9f;
+    int32_t top_k          = 40;
+    float   top_p          = 0.9f;
+    float   temp           = 0.9f;
     int32_t repeat_last_n  = 64;
     float   repeat_penalty = 1.00f;
 
-    int32_t n_batch = 8; // batch size for prompt processing
-
     std::string model      = "models/gpt-2-117M/ggml-model.bin"; // model path
     std::string prompt     = "";
     std::string token_test = "";
 
-    bool interactive = false;
-    int interactive_port = -1;
+    bool    interactive      = false;
+    int32_t interactive_port = -1;
 };
 
 bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
index ff62f74b88761ba1c7809caf233838957f5d34ac..8dd31d028b1e732bf6572230dab0fd8cc835d721 100644 (file)
@@ -68,28 +68,34 @@ struct whisper_params {
     float entropy_thold =  2.40f;
     float logprob_thold = -1.00f;
 
-    bool speed_up       = false;
-    bool translate      = false;
-    bool detect_language= false;
-    bool diarize        = false;
-    bool split_on_word  = false;
-    bool no_fallback    = false;
-    bool output_txt     = false;
-    bool output_vtt     = false;
-    bool output_srt     = false;
-    bool output_wts     = false;
-    bool output_csv     = false;
-    bool output_jsn     = false;
-    bool output_lrc     = false;
-    bool print_special  = false;
-    bool print_colors   = false;
-    bool print_progress = false;
-    bool no_timestamps  = false;
-
-    std::string language = "en";
+    bool speed_up        = false;
+    bool translate       = false;
+    bool detect_language = false;
+    bool diarize         = false;
+    bool tinydiarize     = false;
+    bool split_on_word   = false;
+    bool no_fallback     = false;
+    bool output_txt      = false;
+    bool output_vtt      = false;
+    bool output_srt      = false;
+    bool output_wts      = false;
+    bool output_csv      = false;
+    bool output_jsn      = false;
+    bool output_lrc      = false;
+    bool print_special   = false;
+    bool print_colors    = false;
+    bool print_progress  = false;
+    bool no_timestamps   = false;
+
+    std::string language  = "en";
     std::string prompt;
     std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
-    std::string model    = "models/ggml-base.en.bin";
+    std::string model     = "models/ggml-base.en.bin";
+
+    // [TDRZ] speaker turn string
+    std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line
+
+    std::string openvino_encode_device = "CPU";
 
     std::vector<std::string> fname_inp = {};
     std::vector<std::string> fname_out = {};
@@ -115,41 +121,43 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
             whisper_print_usage(argc, argv, params);
             exit(0);
         }
-        else if (arg == "-t"    || arg == "--threads")        { params.n_threads      = std::stoi(argv[++i]); }
-        else if (arg == "-p"    || arg == "--processors")     { params.n_processors   = std::stoi(argv[++i]); }
-        else if (arg == "-ot"   || arg == "--offset-t")       { params.offset_t_ms    = std::stoi(argv[++i]); }
-        else if (arg == "-on"   || arg == "--offset-n")       { params.offset_n       = std::stoi(argv[++i]); }
-        else if (arg == "-d"    || arg == "--duration")       { params.duration_ms    = std::stoi(argv[++i]); }
-        else if (arg == "-mc"   || arg == "--max-context")    { params.max_context    = std::stoi(argv[++i]); }
-        else if (arg == "-ml"   || arg == "--max-len")        { params.max_len        = std::stoi(argv[++i]); }
-        else if (arg == "-bo"   || arg == "--best-of")        { params.best_of        = std::stoi(argv[++i]); }
-        else if (arg == "-bs"   || arg == "--beam-size")      { params.beam_size      = std::stoi(argv[++i]); }
-        else if (arg == "-wt"   || arg == "--word-thold")     { params.word_thold     = std::stof(argv[++i]); }
-        else if (arg == "-et"   || arg == "--entropy-thold")  { params.entropy_thold  = std::stof(argv[++i]); }
-        else if (arg == "-lpt"  || arg == "--logprob-thold")  { params.logprob_thold  = std::stof(argv[++i]); }
-        else if (arg == "-su"   || arg == "--speed-up")       { params.speed_up       = true; }
-        else if (arg == "-tr"   || arg == "--translate")      { params.translate      = true; }
-        else if (arg == "-di"   || arg == "--diarize")        { params.diarize        = true; }
-        else if (arg == "-sow"  || arg == "--split-on-word")  { params.split_on_word  = true; }
-        else if (arg == "-nf"   || arg == "--no-fallback")    { params.no_fallback    = true; }
-        else if (arg == "-otxt" || arg == "--output-txt")     { params.output_txt     = true; }
-        else if (arg == "-ovtt" || arg == "--output-vtt")     { params.output_vtt     = true; }
-        else if (arg == "-osrt" || arg == "--output-srt")     { params.output_srt     = true; }
-        else if (arg == "-owts" || arg == "--output-words")   { params.output_wts     = true; }
-        else if (arg == "-olrc" || arg == "--output-lrc")     { params.output_lrc     = true; }
-        else if (arg == "-fp"   || arg == "--font-path")      { params.font_path      = argv[++i]; }
-        else if (arg == "-ocsv" || arg == "--output-csv")     { params.output_csv     = true; }
-        else if (arg == "-oj"   || arg == "--output-json")    { params.output_jsn     = true; }
-        else if (arg == "-of"   || arg == "--output-file")    { params.fname_out.emplace_back(argv[++i]); }
-        else if (arg == "-ps"   || arg == "--print-special")  { params.print_special  = true; }
-        else if (arg == "-pc"   || arg == "--print-colors")   { params.print_colors   = true; }
-        else if (arg == "-pp"   || arg == "--print-progress") { params.print_progress = true; }
-        else if (arg == "-nt"   || arg == "--no-timestamps")  { params.no_timestamps  = true; }
-        else if (arg == "-l"    || arg == "--language")       { params.language       = argv[++i]; }
-        else if (arg == "-dl"   || arg == "--detect-language"){ params.detect_language= true; }
-        else if (                  arg == "--prompt")         { params.prompt         = argv[++i]; }
-        else if (arg == "-m"    || arg == "--model")          { params.model          = argv[++i]; }
-        else if (arg == "-f"    || arg == "--file")           { params.fname_inp.emplace_back(argv[++i]); }
+        else if (arg == "-t"    || arg == "--threads")         { params.n_threads       = std::stoi(argv[++i]); }
+        else if (arg == "-p"    || arg == "--processors")      { params.n_processors    = std::stoi(argv[++i]); }
+        else if (arg == "-ot"   || arg == "--offset-t")        { params.offset_t_ms     = std::stoi(argv[++i]); }
+        else if (arg == "-on"   || arg == "--offset-n")        { params.offset_n        = std::stoi(argv[++i]); }
+        else if (arg == "-d"    || arg == "--duration")        { params.duration_ms     = std::stoi(argv[++i]); }
+        else if (arg == "-mc"   || arg == "--max-context")     { params.max_context     = std::stoi(argv[++i]); }
+        else if (arg == "-ml"   || arg == "--max-len")         { params.max_len         = std::stoi(argv[++i]); }
+        else if (arg == "-bo"   || arg == "--best-of")         { params.best_of         = std::stoi(argv[++i]); }
+        else if (arg == "-bs"   || arg == "--beam-size")       { params.beam_size       = std::stoi(argv[++i]); }
+        else if (arg == "-wt"   || arg == "--word-thold")      { params.word_thold      = std::stof(argv[++i]); }
+        else if (arg == "-et"   || arg == "--entropy-thold")   { params.entropy_thold   = std::stof(argv[++i]); }
+        else if (arg == "-lpt"  || arg == "--logprob-thold")   { params.logprob_thold   = std::stof(argv[++i]); }
+        else if (arg == "-su"   || arg == "--speed-up")        { params.speed_up        = true; }
+        else if (arg == "-tr"   || arg == "--translate")       { params.translate       = true; }
+        else if (arg == "-di"   || arg == "--diarize")         { params.diarize         = true; }
+        else if (arg == "-tdrz" || arg == "--tinydiarize")     { params.tinydiarize     = true; }
+        else if (arg == "-sow"  || arg == "--split-on-word")   { params.split_on_word   = true; }
+        else if (arg == "-nf"   || arg == "--no-fallback")     { params.no_fallback     = true; }
+        else if (arg == "-otxt" || arg == "--output-txt")      { params.output_txt      = true; }
+        else if (arg == "-ovtt" || arg == "--output-vtt")      { params.output_vtt      = true; }
+        else if (arg == "-osrt" || arg == "--output-srt")      { params.output_srt      = true; }
+        else if (arg == "-owts" || arg == "--output-words")    { params.output_wts      = true; }
+        else if (arg == "-olrc" || arg == "--output-lrc")      { params.output_lrc      = true; }
+        else if (arg == "-fp"   || arg == "--font-path")       { params.font_path       = argv[++i]; }
+        else if (arg == "-ocsv" || arg == "--output-csv")      { params.output_csv      = true; }
+        else if (arg == "-oj"   || arg == "--output-json")     { params.output_jsn      = true; }
+        else if (arg == "-of"   || arg == "--output-file")     { params.fname_out.emplace_back(argv[++i]); }
+        else if (arg == "-ps"   || arg == "--print-special")   { params.print_special   = true; }
+        else if (arg == "-pc"   || arg == "--print-colors")    { params.print_colors    = true; }
+        else if (arg == "-pp"   || arg == "--print-progress")  { params.print_progress  = true; }
+        else if (arg == "-nt"   || arg == "--no-timestamps")   { params.no_timestamps   = true; }
+        else if (arg == "-l"    || arg == "--language")        { params.language        = argv[++i]; }
+        else if (arg == "-dl"   || arg == "--detect-language") { params.detect_language = true; }
+        else if (                  arg == "--prompt")          { params.prompt          = argv[++i]; }
+        else if (arg == "-m"    || arg == "--model")           { params.model           = argv[++i]; }
+        else if (arg == "-f"    || arg == "--file")            { params.fname_inp.emplace_back(argv[++i]); }
+        else if (arg == "-oved" || arg == "--ov-e-device")     { params.openvino_encode_device = argv[++i]; }
         else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             whisper_print_usage(argc, argv, params);
@@ -182,6 +190,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -su,       --speed-up          [%-7s] speed up audio by x2 (reduced accuracy)\n",        params.speed_up ? "true" : "false");
     fprintf(stderr, "  -tr,       --translate         [%-7s] translate from source language to english\n",      params.translate ? "true" : "false");
     fprintf(stderr, "  -di,       --diarize           [%-7s] stereo audio diarization\n",                       params.diarize ? "true" : "false");
+    fprintf(stderr, "  -tdrz,     --tinydiarize       [%-7s] enable tinydiarize (requires a tdrz model)\n",     params.tinydiarize ? "true" : "false");
     fprintf(stderr, "  -nf,       --no-fallback       [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
     fprintf(stderr, "  -otxt,     --output-txt        [%-7s] output result in a text file\n",                   params.output_txt ? "true" : "false");
     fprintf(stderr, "  -ovtt,     --output-vtt        [%-7s] output result in a vtt file\n",                    params.output_vtt ? "true" : "false");
@@ -201,6 +210,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "             --prompt PROMPT     [%-7s] initial prompt\n",                                 params.prompt.c_str());
     fprintf(stderr, "  -m FNAME,  --model FNAME       [%-7s] model path\n",                                     params.model.c_str());
     fprintf(stderr, "  -f FNAME,  --file FNAME        [%-7s] input WAV file path\n",                            "");
+    fprintf(stderr, "  -oved D,   --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n",  params.openvino_encode_device.c_str());
     fprintf(stderr, "\n");
 }
 
@@ -297,6 +307,12 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
             printf("%s%s", speaker.c_str(), text);
         }
 
+        if (params.tinydiarize) {
+            if (whisper_full_get_segment_speaker_turn_next(ctx, i)) {
+                printf("%s", params.tdrz_speaker_turn.c_str());
+            }
+        }
+
         // with timestamps or speakers: each segment on new line
         if (!params.no_timestamps || params.diarize) {
             printf("\n");
@@ -564,6 +580,7 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
             const int n_segments = whisper_full_n_segments(ctx);
             for (int i = 0; i < n_segments; ++i) {
                 const char * text = whisper_full_get_segment_text(ctx, i);
+
                 const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
                 const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
 
@@ -576,11 +593,15 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
                         value_i("from", t0 * 10, false);
                         value_i("to", t1 * 10, true);
                     end_obj(false);
-                    value_s("text", text, !params.diarize);
+                    value_s("text", text, !params.diarize && !params.tinydiarize);
 
                     if (params.diarize && pcmf32s.size() == 2) {
                         value_s("speaker", estimate_diarization_speaker(pcmf32s, t0, t1, true).c_str(), true);
                     }
+
+                    if (params.tinydiarize) {
+                        value_b("speaker_turn_next", whisper_full_get_segment_speaker_turn_next(ctx, i), true);
+                    }
                 end_obj(i == (n_segments - 1));
             }
 
@@ -777,6 +798,12 @@ int main(int argc, char ** argv) {
         exit(0);
     }
 
+    if (params.diarize && params.tinydiarize) {
+        fprintf(stderr, "error: cannot use both --diarize and --tinydiarize\n");
+        whisper_print_usage(argc, argv, params);
+        exit(0);
+    }
+
     // whisper init
 
     struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
@@ -786,6 +813,9 @@ int main(int argc, char ** argv) {
         return 3;
     }
 
+    // initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
+    whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
+
     for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
         const auto fname_inp = params.fname_inp[f];
                const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
@@ -818,11 +848,12 @@ int main(int argc, char ** argv) {
             if (params.detect_language) {
                 params.language = "auto";
             }
-            fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d ...\n",
+            fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, %stimestamps = %d ...\n",
                     __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
                     params.n_threads, params.n_processors,
                     params.language.c_str(),
                     params.translate ? "translate" : "transcribe",
+                    params.tinydiarize ? "tdrz = 1, " : "",
                     params.no_timestamps ? 0 : 1);
 
             fprintf(stderr, "\n");
@@ -853,6 +884,8 @@ int main(int argc, char ** argv) {
 
             wparams.speed_up         = params.speed_up;
 
+            wparams.tdrz_enable      = params.tinydiarize; // [TDRZ]
+
             wparams.initial_prompt   = params.prompt.c_str();
 
             wparams.greedy.best_of        = params.best_of;
index 610180ab7ef08a49706a5b88bebd0f03cd39662a..8b0f1f60ab8a7a7316c59dbc4ae3ea9b0a2380a4 100644 (file)
@@ -3,6 +3,10 @@
 #include "coreml/whisper-encoder.h"
 #endif
 
+#if WHISPER_USE_OPENVINO
+#include "openvino/whisper-openvino-encoder.h"
+#endif
+
 #include "ggml.h"
 
 #include <algorithm>
@@ -380,16 +384,18 @@ struct whisper_vocab {
     std::map<token, id> token_to_id;
     std::map<id, token> id_to_token;
 
-    id token_eot  = 50256;
-    id token_sot  = 50257;
-    id token_prev = 50360;
-    id token_solm = 50361; // ??
-    id token_not  = 50362; // no timestamps
-    id token_beg  = 50363;
-
-    // available tasks
-    static const id token_translate  = 50358;
-    static const id token_transcribe = 50359;
+    // reference: https://github.com/openai/whisper/blob/248b6cb124225dd263bb9bd32d060b6517e067f8/whisper/tokenizer.py#L334-L349
+    id token_eot        = 50256;
+    id token_sot        = 50257;
+    // task tokens (used only for multilingual models)
+    id token_translate  = 50357;
+    id token_transcribe = 50358;
+    // other special tokens
+    id token_solm       = 50359; // [TDRZ] used by tinydiarize models to indicate speaker turn
+    id token_prev       = 50360;
+    id token_nosp       = 50361;
+    id token_not        = 50362; // no timestamps
+    id token_beg        = 50363; // begin timestamps
 
     bool is_multilingual() const {
         return n_vocab == 51865;
@@ -403,6 +409,8 @@ struct whisper_segment {
     std::string text;
 
     std::vector<whisper_token_data> tokens;
+
+    bool speaker_turn_next;
 };
 
 // medium
@@ -656,6 +664,10 @@ struct whisper_state {
     whisper_coreml_context * ctx_coreml = nullptr;
 #endif
 
+#ifdef WHISPER_USE_OPENVINO
+    whisper_openvino_context * ctx_openvino = nullptr;
+#endif
+
     // [EXPERIMENTAL] token-level timestamps data
     int64_t t_beg = 0;
     int64_t t_last = 0;
@@ -966,8 +978,11 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         if (vocab.is_multilingual()) {
             vocab.token_eot++;
             vocab.token_sot++;
-            vocab.token_prev++;
+            vocab.token_translate++;
+            vocab.token_transcribe++;
             vocab.token_solm++;
+            vocab.token_prev++;
+            vocab.token_nosp++;
             vocab.token_not++;
             vocab.token_beg++;
         }
@@ -981,8 +996,12 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                     word = "[_EOT_]";
                 } else if (i == vocab.token_sot) {
                     word = "[_SOT_]";
+                } else if (i == vocab.token_solm) {
+                    word = "[_SOLM_]";
                 } else if (i == vocab.token_prev) {
                     word = "[_PREV_]";
+                } else if (i == vocab.token_nosp) {
+                    word = "[_NOSP_]";
                 } else if (i == vocab.token_not) {
                     word = "[_NOT_]";
                 } else if (i == vocab.token_beg) {
@@ -1467,7 +1486,13 @@ static bool whisper_encode_internal(
     const bool use_coreml = wstate.ctx_coreml != nullptr;
 #endif
 
-    if (!use_coreml) {
+#ifndef WHISPER_USE_OPENVINO
+    const bool use_openvino = false;
+#else
+    const bool use_openvino = wstate.ctx_openvino != nullptr;
+#endif
+
+    if (!use_coreml && !use_openvino) {
         // convolution + gelu
         {
             wstate.use_buf(ctx0, 1);
@@ -1766,8 +1791,7 @@ static bool whisper_encode_internal(
         }
     }
 #ifdef WHISPER_USE_COREML
-    else
-    {
+    else if (use_coreml) {
         wstate.use_buf(ctx0, -1);
 
         cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
@@ -1775,6 +1799,17 @@ static bool whisper_encode_internal(
         whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data);
     }
 #endif
+#ifdef WHISPER_USE_OPENVINO
+    else if (use_openvino) {
+        wstate.use_buf(ctx0, -1);
+
+        cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
+
+        if (!whisper_openvino_encode(wstate.ctx_openvino, mel, cur)) {
+            return false;
+        }
+    }
+#endif
 
     // cur
     //{
@@ -2617,6 +2652,31 @@ static std::string whisper_get_coreml_path_encoder(std::string path_bin) {
 }
 #endif
 
+#ifdef WHISPER_USE_OPENVINO
+// replace .bin with-encoder-openvino.xml
+static std::string whisper_openvino_get_path_encoder(std::string path_bin) {
+    auto pos = path_bin.rfind('.');
+    if (pos != std::string::npos) {
+        path_bin = path_bin.substr(0, pos);
+    }
+
+    path_bin += "-encoder-openvino.xml";
+
+    return path_bin;
+}
+
+static std::string whisper_openvino_get_path_cache(std::string path_bin) {
+    auto pos = path_bin.rfind('.');
+    if (pos != std::string::npos) {
+        path_bin = path_bin.substr(0, pos);
+    }
+
+    path_bin += "-encoder-openvino-cache";
+
+    return path_bin;
+}
+#endif
+
 struct whisper_state * whisper_init_state(whisper_context * ctx) {
     whisper_state * state = new whisper_state;
 
@@ -2683,6 +2743,55 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
     return state;
 }
 
+int whisper_ctx_init_openvino_encoder(
+        struct whisper_context * ctx,
+                    const char * model_path,
+                    const char * device,
+                    const char * cache_dir) {
+#ifndef WHISPER_USE_OPENVINO
+    (void)(ctx);
+    (void)(model_path);
+    (void)(device);
+    (void)(cache_dir);
+
+    return 1;
+#else
+    if (!model_path && ctx->path_model.empty()) {
+        fprintf(stderr, "%s: model_path is nullptr, and ctx has no model_path set.\n", __func__);
+        return 1;
+    }
+
+    std::string path_encoder;
+    if (!model_path) {
+        //if model_path is not set, attempt to find it in the same directory as ggml-<model>.bin model
+        path_encoder = whisper_openvino_get_path_encoder(ctx->path_model);
+    } else {
+        path_encoder = model_path;
+    }
+
+    std::string path_cache;
+    if (!cache_dir) {
+        //if cache_dir is not set, set it as a dir residing next to ggml-<model>.bin
+        path_cache = whisper_openvino_get_path_cache(ctx->path_model);
+    } else {
+        path_cache = cache_dir;
+    }
+
+    fprintf(stderr, "%s: loading OpenVINO model from '%s'\n", __func__, path_openvino.c_str());
+    fprintf(stderr, "%s: first run on a device may take a while ...\n", __func__);
+
+    ctx->state->ctx_openvino = whisper_openvino_init(path_openvino.c_str(), device, path_cache.c_str());
+    if (!ctx->state->ctx_openvino) {
+        fprintf(stderr, "%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_openvino.c_str());
+        return 1;
+    } else {
+        fprintf(stderr, "%s: OpenVINO model loaded\n", __func__);
+    }
+
+    return 0;
+#endif
+}
+
 struct whisper_context * whisper_init_from_file_no_state(const char * path_model) {
 
     fprintf(stderr, "%s: loading model from '%s'\n", __func__, path_model);
@@ -2837,6 +2946,13 @@ void whisper_free_state(struct whisper_state * state)
         }
 #endif
 
+#ifdef WHISPER_USE_OPENVINO
+        if (state->ctx_openvino != nullptr) {
+            whisper_openvino_free(state->ctx_openvino);
+            state->ctx_openvino = nullptr;
+        }
+#endif
+
         delete state;
     }
 }
@@ -3208,12 +3324,16 @@ whisper_token whisper_token_sot(struct whisper_context * ctx) {
     return ctx->vocab.token_sot;
 }
 
+whisper_token whisper_token_solm(struct whisper_context * ctx) {
+    return ctx->vocab.token_solm;
+}
+
 whisper_token whisper_token_prev(struct whisper_context * ctx) {
     return ctx->vocab.token_prev;
 }
 
-whisper_token whisper_token_solm(struct whisper_context * ctx) {
-    return ctx->vocab.token_solm;
+whisper_token whisper_token_nosp(struct whisper_context * ctx) {
+    return ctx->vocab.token_nosp;
 }
 
 whisper_token whisper_token_not(struct whisper_context * ctx) {
@@ -3228,12 +3348,12 @@ whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id) {
     return whisper_token_sot(ctx) + 1 + lang_id;
 }
 
-whisper_token whisper_token_translate(void) {
-    return whisper_vocab::token_translate;
+whisper_token whisper_token_translate(struct whisper_context * ctx) {
+    return ctx->vocab.token_translate;
 }
 
-whisper_token whisper_token_transcribe(void) {
-    return whisper_vocab::token_transcribe;
+whisper_token whisper_token_transcribe(struct whisper_context * ctx) {
+    return ctx->vocab.token_transcribe;
 }
 
 void whisper_print_timings(struct whisper_context * ctx) {
@@ -3272,6 +3392,14 @@ static int whisper_has_coreml(void) {
 #endif
 }
 
+static int whisper_has_openvino(void) {
+#ifdef WHISPER_USE_OPENVINO
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 const char * whisper_print_system_info(void) {
     static std::string s;
 
@@ -3289,6 +3417,7 @@ const char * whisper_print_system_info(void) {
     s += "SSE3 = "      + std::to_string(ggml_cpu_has_sse3())      + " | ";
     s += "VSX = "       + std::to_string(ggml_cpu_has_vsx())       + " | ";
     s += "COREML = "    + std::to_string(whisper_has_coreml())     + " | ";
+    s += "OPENVINO = "  + std::to_string(whisper_has_openvino())   + " | ";
 
     return s.c_str();
 }
@@ -3305,51 +3434,53 @@ struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sam
 
 struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
     struct whisper_full_params result = {
-        /*.strategy         =*/ strategy,
-
-        /*.n_threads        =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
-        /*.n_max_text_ctx   =*/ 16384,
-        /*.offset_ms        =*/ 0,
-        /*.duration_ms      =*/ 0,
-
-        /*.translate        =*/ false,
-        /*.no_context       =*/ true,
-        /*.single_segment   =*/ false,
-        /*.print_special    =*/ false,
-        /*.print_progress   =*/ true,
-        /*.print_realtime   =*/ false,
-        /*.print_timestamps =*/ true,
-
-        /*.token_timestamps =*/ false,
-        /*.thold_pt         =*/ 0.01f,
-        /*.thold_ptsum      =*/ 0.01f,
-        /*.max_len          =*/ 0,
-        /*.split_on_word    =*/ false,
-        /*.max_tokens       =*/ 0,
-
-        /*.speed_up         =*/ false,
-        /*.audio_ctx        =*/ 0,
-
-        /*.initial_prompt   =*/ nullptr,
-        /*.prompt_tokens    =*/ nullptr,
-        /*.prompt_n_tokens  =*/ 0,
-
-        /*.language         =*/ "en",
-        /*.detect_language  =*/ false,
-
-        /*.suppress_blank   =*/ true,
+        /*.strategy          =*/ strategy,
+
+        /*.n_threads         =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
+        /*.n_max_text_ctx    =*/ 16384,
+        /*.offset_ms         =*/ 0,
+        /*.duration_ms       =*/ 0,
+
+        /*.translate         =*/ false,
+        /*.no_context        =*/ true,
+        /*.single_segment    =*/ false,
+        /*.print_special     =*/ false,
+        /*.print_progress    =*/ true,
+        /*.print_realtime    =*/ false,
+        /*.print_timestamps  =*/ true,
+
+        /*.token_timestamps  =*/ false,
+        /*.thold_pt          =*/ 0.01f,
+        /*.thold_ptsum       =*/ 0.01f,
+        /*.max_len           =*/ 0,
+        /*.split_on_word     =*/ false,
+        /*.max_tokens        =*/ 0,
+
+        /*.speed_up          =*/ false,
+        /*.audio_ctx         =*/ 0,
+
+        /*.tdrz_enable       =*/ false,
+
+        /*.initial_prompt    =*/ nullptr,
+        /*.prompt_tokens     =*/ nullptr,
+        /*.prompt_n_tokens   =*/ 0,
+
+        /*.language          =*/ "en",
+        /*.detect_language   =*/ false,
+
+        /*.suppress_blank    =*/ true,
         /*.suppress_non_speech_tokens =*/ false,
 
-        /*.temperature      =*/  0.0f,
-        /*.max_initial_ts   =*/  1.0f,
-        /*.length_penalty   =*/ -1.0f,
+        /*.temperature       =*/  0.0f,
+        /*.max_initial_ts    =*/  1.0f,
+        /*.length_penalty    =*/ -1.0f,
 
-        /*.temperature_inc  =*/  0.4f,
-        /*.entropy_thold    =*/  2.4f,
-        /*.logprob_thold    =*/ -1.0f,
-        /*.no_speech_thold  =*/  0.6f,
+        /*.temperature_inc   =*/  0.4f,
+        /*.entropy_thold     =*/  2.4f,
+        /*.logprob_thold     =*/ -1.0f,
+        /*.no_speech_thold   =*/  0.6f,
 
-        /*.greedy           =*/ {
+        /*.greedy            =*/ {
             /*.best_of   =*/ -1,
         },
 
@@ -3401,26 +3532,6 @@ static void whisper_exp_compute_token_level_timestamps(
                          float   thold_pt,
                          float   thold_ptsum);
 
-// trim from start (in place)
-static inline void ltrim(std::string &s) {
-    s.erase(s.begin(), std::find_if_not(s.begin(), s.end(), [](unsigned char ch) {
-        return std::isspace(ch);
-    }));
-}
-
-// trim from end (in place)
-static inline void rtrim(std::string &s) {
-    s.erase(std::find_if_not(s.rbegin(), s.rend(), [](unsigned char ch) {
-        return std::isspace(ch);
-    }).base(), s.end());
-}
-
-// trim from both ends (in place)
-static inline void trim(std::string &s) {
-    rtrim(s);
-    ltrim(s);
-}
-
 static inline bool should_split_on_word(const char * txt, bool split_on_word) {
     if (!split_on_word) return true;
 
@@ -3447,14 +3558,10 @@ static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_sta
         const int cur = strlen(txt);
 
         if (acc + cur > max_len && i > 0 && should_split_on_word(txt, split_on_word)) {
-            // split here
-            if (split_on_word) {
-                trim(text);
-            }
-
             state.result_all.back().text = std::move(text);
             state.result_all.back().t1 = token.t0;
             state.result_all.back().tokens.resize(i);
+            state.result_all.back().speaker_turn_next = false;
 
             state.result_all.push_back({});
             state.result_all.back().t0 = token.t0;
@@ -3466,6 +3573,8 @@ static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_sta
                     segment.tokens.begin() + i,
                     segment.tokens.end());
 
+            state.result_all.back().speaker_turn_next = segment.speaker_turn_next;
+
             acc = 0;
             text = "";
 
@@ -3479,9 +3588,6 @@ static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_sta
         }
     }
 
-    if (split_on_word) {
-        trim(text);
-    }
     state.result_all.back().text = std::move(text);
 
     return res;
@@ -3547,9 +3653,14 @@ static void whisper_process_logits(
         // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412
         logits[vocab.token_not] = -INFINITY;
 
-        // suppress sot and solm tokens
+        // suppress sot and nosp tokens
         logits[vocab.token_sot]  = -INFINITY;
-        logits[vocab.token_solm] = -INFINITY;
+        logits[vocab.token_nosp] = -INFINITY; // TODO: ignore this token for now
+
+        // [TDRZ] when tinydiarize is disabled, suppress solm token
+        if (params.tdrz_enable == false) {
+            logits[vocab.token_solm] = -INFINITY;
+        }
 
         // suppress task tokens
         logits[vocab.token_translate]  = -INFINITY;
@@ -4046,9 +4157,9 @@ int whisper_full_with_state(
         state->lang_id = lang_id;
         prompt_init.push_back(whisper_token_lang(ctx, lang_id));
         if (params.translate) {
-            prompt_init.push_back(whisper_token_translate());
+            prompt_init.push_back(whisper_token_translate(ctx));
         } else {
-            prompt_init.push_back(whisper_token_transcribe());
+            prompt_init.push_back(whisper_token_transcribe(ctx));
         }
     }
 
@@ -4528,23 +4639,27 @@ int whisper_full_with_state(
                 prompt_past.push_back(tokens_cur[i].id);
             }
 
-            // store the text from this iteration
             if (!tokens_cur.empty() && ctx->model.n_loaded > 0) {
                 int  i0 = 0;
                 auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
 
                 std::string text;
+                bool speaker_turn_next = false;
 
                 for (int i = 0; i < (int) tokens_cur.size(); i++) {
                     //printf("%s: %18s %6.3f %18s %6.3f\n", __func__,
                     //        ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p,
                     //        ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt);
 
-                    if (params.print_special == false && tokens_cur[i].id >= whisper_token_eot(ctx)) {
-                    } else {
+                    if (params.print_special || tokens_cur[i].id < whisper_token_eot(ctx)) {
                         text += whisper_token_to_str(ctx, tokens_cur[i].id);
                     }
 
+                    // [TDRZ] record if speaker turn was predicted after current segment
+                    if (params.tdrz_enable && tokens_cur[i].id == whisper_token_solm(ctx)) {
+                        speaker_turn_next = true;
+                    }
+
                     if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) {
                         const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
 
@@ -4563,7 +4678,7 @@ int whisper_full_with_state(
 
                             //printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid);
 
-                            result_all.push_back({ tt0, tt1, text, {} });
+                            result_all.push_back({ tt0, tt1, text, {}, speaker_turn_next });
                             for (int j = i0; j <= i; j++) {
                                 result_all.back().tokens.push_back(tokens_cur[j]);
                             }
@@ -4589,6 +4704,7 @@ int whisper_full_with_state(
                         i--;
                         t0 = t1;
                         i0 = i + 1;
+                        speaker_turn_next = false;
                     }
                 }
 
@@ -4607,7 +4723,7 @@ int whisper_full_with_state(
                         }
                     }
 
-                    result_all.push_back({ tt0, tt1, text, {} });
+                    result_all.push_back({ tt0, tt1, text, {} , speaker_turn_next });
                     for (int j = i0; j < (int) tokens_cur.size(); j++) {
                         result_all.back().tokens.push_back(tokens_cur[j]);
                     }
@@ -4787,6 +4903,10 @@ int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment)
     return ctx->state->result_all[i_segment].t1;
 }
 
+bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment) {
+    return ctx->state->result_all[i_segment].speaker_turn_next;
+}
+
 const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment) {
     return state->result_all[i_segment].text.c_str();
 }
index e983c7d4fa323f65ac9912e1e53982367cf2a8ea..83af11bd848c66ff1fd6b5c9e97eec6f1484165d 100644 (file)
@@ -110,6 +110,23 @@ extern "C" {
 
     WHISPER_API struct whisper_state * whisper_init_state(struct whisper_context * ctx);
 
+    // Given a context, enable use of OpenVINO for encode inference.
+    // model_path: Optional path to OpenVINO encoder IR model. If set to nullptr,
+    //                      the path will be generated from the ggml model path that was passed
+    //                      in to whisper_init_from_file. For example, if 'path_model' was
+    //                      "/path/to/ggml-base.en.bin", then OpenVINO IR model path will be
+    //                      assumed to be "/path/to/ggml-base.en-encoder-openvino.xml".
+    // device: OpenVINO device to run inference on ("CPU", "GPU", etc.)
+    // cache_dir: Optional cache directory that can speed up init time, especially for
+    //                     GPU, by caching compiled 'blobs' there.
+    //                     Set to nullptr if not used.
+    // Returns 0 on success. If OpenVINO is not enabled in build, this simply returns 1.
+    WHISPER_API int whisper_ctx_init_openvino_encoder(
+        struct whisper_context * ctx,
+                    const char * model_path,
+                    const char * device,
+                    const char * cache_dir);
+
     // Frees all allocated memory
     WHISPER_API void whisper_free      (struct whisper_context * ctx);
     WHISPER_API void whisper_free_state(struct whisper_state * state);
@@ -277,15 +294,16 @@ extern "C" {
     // Special tokens
     WHISPER_API whisper_token whisper_token_eot (struct whisper_context * ctx);
     WHISPER_API whisper_token whisper_token_sot (struct whisper_context * ctx);
-    WHISPER_API whisper_token whisper_token_prev(struct whisper_context * ctx);
     WHISPER_API whisper_token whisper_token_solm(struct whisper_context * ctx);
+    WHISPER_API whisper_token whisper_token_prev(struct whisper_context * ctx);
+    WHISPER_API whisper_token whisper_token_nosp(struct whisper_context * ctx);
     WHISPER_API whisper_token whisper_token_not (struct whisper_context * ctx);
     WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx);
     WHISPER_API whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id);
 
     // Task tokens
-    WHISPER_API whisper_token whisper_token_translate (void);
-    WHISPER_API whisper_token whisper_token_transcribe(void);
+    WHISPER_API whisper_token whisper_token_translate (struct whisper_context * ctx);
+    WHISPER_API whisper_token whisper_token_transcribe(struct whisper_context * ctx);
 
     // Performance information from the default state.
     WHISPER_API void whisper_print_timings(struct whisper_context * ctx);
@@ -358,6 +376,9 @@ extern "C" {
         bool speed_up;          // speed-up the audio by 2x using Phase Vocoder
         int  audio_ctx;         // overwrite the audio context size (0 = use default)
 
+        // [EXPERIMENTAL] [TDRZ] tinydiarize
+        bool tdrz_enable;       // enable tinydiarize speaker turn detection
+
         // tokens to provide to the whisper decoder as initial prompt
         // these are prepended to any existing text context from a previous call
         const char * initial_prompt;
@@ -460,6 +481,9 @@ extern "C" {
     WHISPER_API int64_t whisper_full_get_segment_t1           (struct whisper_context * ctx, int i_segment);
     WHISPER_API int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment);
 
+    // Get whether the next segment is predicted as a speaker turn
+    WHISPER_API bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment);
+
     // Get the text of the specified segment
     WHISPER_API const char * whisper_full_get_segment_text           (struct whisper_context * ctx, int i_segment);
     WHISPER_API const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment);
@@ -488,9 +512,9 @@ extern "C" {
 
     // Temporary helpers needed for exposing ggml interface
 
-    WHISPER_API int whisper_bench_memcpy(int n_threads);
-    WHISPER_API const char * whisper_bench_memcpy_str(int n_threads);
-    WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads);
+    WHISPER_API int          whisper_bench_memcpy          (int n_threads);
+    WHISPER_API const char * whisper_bench_memcpy_str      (int n_threads);
+    WHISPER_API int          whisper_bench_ggml_mul_mat    (int n_threads);
     WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads);
 
 #ifdef __cplusplus