From: Georgi Gerganov Date: Tue, 4 Jul 2023 17:24:22 +0000 (+0300) Subject: whisper : sync whisper.cpp (tinydiarize + OpenVINO) X-Git-Tag: upstream/0.0.1642~1353 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=ba8547430eaa3484c928686aff5ccf0bfa0f1f69;p=pkg%2Fggml%2Fsources%2Fggml whisper : sync whisper.cpp (tinydiarize + OpenVINO) --- diff --git a/examples/common.h b/examples/common.h index 74655cbf..f9740a3c 100644 --- a/examples/common.h +++ b/examples/common.h @@ -15,25 +15,24 @@ // struct gpt_params { - int32_t seed = -1; // RNG seed + int32_t seed = -1; // RNG seed int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); int32_t n_predict = 200; // new tokens to predict + int32_t n_batch = 8; // batch size for prompt processing // sampling parameters - int32_t top_k = 40; - float top_p = 0.9f; - float temp = 0.9f; + int32_t top_k = 40; + float top_p = 0.9f; + float temp = 0.9f; int32_t repeat_last_n = 64; float repeat_penalty = 1.00f; - int32_t n_batch = 8; // batch size for prompt processing - std::string model = "models/gpt-2-117M/ggml-model.bin"; // model path std::string prompt = ""; std::string token_test = ""; - bool interactive = false; - int interactive_port = -1; + bool interactive = false; + int32_t interactive_port = -1; }; bool gpt_params_parse(int argc, char ** argv, gpt_params & params); diff --git a/examples/whisper/main.cpp b/examples/whisper/main.cpp index ff62f74b..8dd31d02 100644 --- a/examples/whisper/main.cpp +++ b/examples/whisper/main.cpp @@ -68,28 +68,34 @@ struct whisper_params { float entropy_thold = 2.40f; float logprob_thold = -1.00f; - bool speed_up = false; - bool translate = false; - bool detect_language= false; - bool diarize = false; - bool split_on_word = false; - bool no_fallback = false; - bool output_txt = false; - bool output_vtt = false; - bool output_srt = false; - bool output_wts = false; - bool output_csv = false; - bool output_jsn = false; - bool output_lrc = false; - bool print_special = false; - bool print_colors = false; - bool print_progress = false; - bool no_timestamps = false; - - std::string language = "en"; + bool speed_up = false; + bool translate = false; + bool detect_language = false; + bool diarize = false; + bool tinydiarize = false; + bool split_on_word = false; + bool no_fallback = false; + bool output_txt = false; + bool output_vtt = false; + bool output_srt = false; + bool output_wts = false; + bool output_csv = false; + bool output_jsn = false; + bool output_lrc = false; + bool print_special = false; + bool print_colors = false; + bool print_progress = false; + bool no_timestamps = false; + + std::string language = "en"; std::string prompt; std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf"; - std::string model = "models/ggml-base.en.bin"; + std::string model = "models/ggml-base.en.bin"; + + // [TDRZ] speaker turn string + std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line + + std::string openvino_encode_device = "CPU"; std::vector fname_inp = {}; std::vector fname_out = {}; @@ -115,41 +121,43 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { whisper_print_usage(argc, argv, params); exit(0); } - else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } - else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); } - else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); } - else if (arg == "-on" || arg 
== "--offset-n") { params.offset_n = std::stoi(argv[++i]); } - else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); } - else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); } - else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); } - else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); } - else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); } - else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); } - else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); } - else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); } - else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } - else if (arg == "-tr" || arg == "--translate") { params.translate = true; } - else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } - else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; } - else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; } - else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; } - else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; } - else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; } - else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; } - else if (arg == "-olrc" || arg == "--output-lrc") { params.output_lrc = true; } - else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; } - else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; } - else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; } - else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); } - else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } - else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; } - else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; } - else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; } - else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } - else if (arg == "-dl" || arg == "--detect-language"){ params.detect_language= true; } - else if ( arg == "--prompt") { params.prompt = argv[++i]; } - else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } - else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); } + else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } + else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); } + else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); } + else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); } + else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); } + else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); } + else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); } + else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); } + else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = 
std::stoi(argv[++i]); } + else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); } + else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); } + else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); } + else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } + else if (arg == "-tr" || arg == "--translate") { params.translate = true; } + else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } + else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; } + else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; } + else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; } + else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; } + else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; } + else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; } + else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; } + else if (arg == "-olrc" || arg == "--output-lrc") { params.output_lrc = true; } + else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; } + else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; } + else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; } + else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); } + else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } + else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; } + else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; } + else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; } + else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } + else if (arg == "-dl" || arg == "--detect-language") { params.detect_language = true; } + else if ( arg == "--prompt") { params.prompt = argv[++i]; } + else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } + else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); } + else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); whisper_print_usage(argc, argv, params); @@ -182,6 +190,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); + fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false"); fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false"); fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false"); fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? 
"true" : "false"); @@ -201,6 +210,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str()); fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", ""); + fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str()); fprintf(stderr, "\n"); } @@ -297,6 +307,12 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper printf("%s%s", speaker.c_str(), text); } + if (params.tinydiarize) { + if (whisper_full_get_segment_speaker_turn_next(ctx, i)) { + printf("%s", params.tdrz_speaker_turn.c_str()); + } + } + // with timestamps or speakers: each segment on new line if (!params.no_timestamps || params.diarize) { printf("\n"); @@ -564,6 +580,7 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper const int n_segments = whisper_full_n_segments(ctx); for (int i = 0; i < n_segments; ++i) { const char * text = whisper_full_get_segment_text(ctx, i); + const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(ctx, i); @@ -576,11 +593,15 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper value_i("from", t0 * 10, false); value_i("to", t1 * 10, true); end_obj(false); - value_s("text", text, !params.diarize); + value_s("text", text, !params.diarize && !params.tinydiarize); if (params.diarize && pcmf32s.size() == 2) { value_s("speaker", estimate_diarization_speaker(pcmf32s, t0, t1, true).c_str(), true); } + + if (params.tinydiarize) { + value_b("speaker_turn_next", whisper_full_get_segment_speaker_turn_next(ctx, i), true); + } end_obj(i == (n_segments - 1)); } @@ -777,6 +798,12 @@ int main(int argc, char ** argv) { exit(0); } + if (params.diarize && params.tinydiarize) { + fprintf(stderr, "error: cannot use both --diarize and --tinydiarize\n"); + whisper_print_usage(argc, argv, params); + exit(0); + } + // whisper init struct whisper_context * ctx = whisper_init_from_file(params.model.c_str()); @@ -786,6 +813,9 @@ int main(int argc, char ** argv) { return 3; } + // initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured + whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr); + for (int f = 0; f < (int) params.fname_inp.size(); ++f) { const auto fname_inp = params.fname_inp[f]; const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f]; @@ -818,11 +848,12 @@ int main(int argc, char ** argv) { if (params.detect_language) { params.language = "auto"; } - fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d ...\n", + fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, %stimestamps = %d ...\n", __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE, params.n_threads, params.n_processors, params.language.c_str(), params.translate ? "translate" : "transcribe", + params.tinydiarize ? "tdrz = 1, " : "", params.no_timestamps ? 
0 : 1); fprintf(stderr, "\n"); @@ -853,6 +884,8 @@ int main(int argc, char ** argv) { wparams.speed_up = params.speed_up; + wparams.tdrz_enable = params.tinydiarize; // [TDRZ] + wparams.initial_prompt = params.prompt.c_str(); wparams.greedy.best_of = params.best_of; diff --git a/examples/whisper/whisper.cpp b/examples/whisper/whisper.cpp index 610180ab..8b0f1f60 100644 --- a/examples/whisper/whisper.cpp +++ b/examples/whisper/whisper.cpp @@ -3,6 +3,10 @@ #include "coreml/whisper-encoder.h" #endif +#if WHISPER_USE_OPENVINO +#include "openvino/whisper-openvino-encoder.h" +#endif + #include "ggml.h" #include @@ -380,16 +384,18 @@ struct whisper_vocab { std::map token_to_id; std::map id_to_token; - id token_eot = 50256; - id token_sot = 50257; - id token_prev = 50360; - id token_solm = 50361; // ?? - id token_not = 50362; // no timestamps - id token_beg = 50363; - - // available tasks - static const id token_translate = 50358; - static const id token_transcribe = 50359; + // reference: https://github.com/openai/whisper/blob/248b6cb124225dd263bb9bd32d060b6517e067f8/whisper/tokenizer.py#L334-L349 + id token_eot = 50256; + id token_sot = 50257; + // task tokens (used only for multilingual models) + id token_translate = 50357; + id token_transcribe = 50358; + // other special tokens + id token_solm = 50359; // [TDRZ] used by tinydiarize models to indicate speaker turn + id token_prev = 50360; + id token_nosp = 50361; + id token_not = 50362; // no timestamps + id token_beg = 50363; // begin timestamps bool is_multilingual() const { return n_vocab == 51865; @@ -403,6 +409,8 @@ struct whisper_segment { std::string text; std::vector tokens; + + bool speaker_turn_next; }; // medium @@ -656,6 +664,10 @@ struct whisper_state { whisper_coreml_context * ctx_coreml = nullptr; #endif +#ifdef WHISPER_USE_OPENVINO + whisper_openvino_context * ctx_openvino = nullptr; +#endif + // [EXPERIMENTAL] token-level timestamps data int64_t t_beg = 0; int64_t t_last = 0; @@ -966,8 +978,11 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con if (vocab.is_multilingual()) { vocab.token_eot++; vocab.token_sot++; - vocab.token_prev++; + vocab.token_translate++; + vocab.token_transcribe++; vocab.token_solm++; + vocab.token_prev++; + vocab.token_nosp++; vocab.token_not++; vocab.token_beg++; } @@ -981,8 +996,12 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con word = "[_EOT_]"; } else if (i == vocab.token_sot) { word = "[_SOT_]"; + } else if (i == vocab.token_solm) { + word = "[_SOLM_]"; } else if (i == vocab.token_prev) { word = "[_PREV_]"; + } else if (i == vocab.token_nosp) { + word = "[_NOSP_]"; } else if (i == vocab.token_not) { word = "[_NOT_]"; } else if (i == vocab.token_beg) { @@ -1467,7 +1486,13 @@ static bool whisper_encode_internal( const bool use_coreml = wstate.ctx_coreml != nullptr; #endif - if (!use_coreml) { +#ifndef WHISPER_USE_OPENVINO + const bool use_openvino = false; +#else + const bool use_openvino = wstate.ctx_openvino != nullptr; +#endif + + if (!use_coreml && !use_openvino) { // convolution + gelu { wstate.use_buf(ctx0, 1); @@ -1766,8 +1791,7 @@ static bool whisper_encode_internal( } } #ifdef WHISPER_USE_COREML - else - { + else if (use_coreml) { wstate.use_buf(ctx0, -1); cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx); @@ -1775,6 +1799,17 @@ static bool whisper_encode_internal( whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data); } #endif +#ifdef WHISPER_USE_OPENVINO + else if (use_openvino) { + 
wstate.use_buf(ctx0, -1);
+
+ cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
+
+ if (!whisper_openvino_encode(wstate.ctx_openvino, mel, cur)) {
+ return false;
+ }
+ }
+#endif
// cur
//{
@@ -2617,6 +2652,31 @@ static std::string whisper_get_coreml_path_encoder(std::string path_bin) {
}
#endif
+#ifdef WHISPER_USE_OPENVINO
+// replace .bin with-encoder-openvino.xml
+static std::string whisper_openvino_get_path_encoder(std::string path_bin) {
+ auto pos = path_bin.rfind('.');
+ if (pos != std::string::npos) {
+ path_bin = path_bin.substr(0, pos);
+ }
+
+ path_bin += "-encoder-openvino.xml";
+
+ return path_bin;
+}
+
+static std::string whisper_openvino_get_path_cache(std::string path_bin) {
+ auto pos = path_bin.rfind('.');
+ if (pos != std::string::npos) {
+ path_bin = path_bin.substr(0, pos);
+ }
+
+ path_bin += "-encoder-openvino-cache";
+
+ return path_bin;
+}
+#endif
+
struct whisper_state * whisper_init_state(whisper_context * ctx) {
whisper_state * state = new whisper_state;
@@ -2683,6 +2743,55 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
return state;
}
+int whisper_ctx_init_openvino_encoder(
+ struct whisper_context * ctx,
+ const char * model_path,
+ const char * device,
+ const char * cache_dir) {
+#ifndef WHISPER_USE_OPENVINO
+ (void)(ctx);
+ (void)(model_path);
+ (void)(device);
+ (void)(cache_dir);
+
+ return 1;
+#else
+ if (!model_path && ctx->path_model.empty()) {
+ fprintf(stderr, "%s: model_path is nullptr, and ctx has no model_path set.\n", __func__);
+ return 1;
+ }
+
+ std::string path_encoder;
+ if (!model_path) {
+ //if model_path is not set, attempt to find it in the same directory as ggml-<model>.bin model
+ path_encoder = whisper_openvino_get_path_encoder(ctx->path_model);
+ } else {
+ path_encoder = model_path;
+ }
+
+ std::string path_cache;
+ if (!cache_dir) {
+ //if cache_dir is not set, set it as a dir residing next to ggml-<model>.bin
+ path_cache = whisper_openvino_get_path_cache(ctx->path_model);
+ } else {
+ path_cache = cache_dir;
+ }
+
+ fprintf(stderr, "%s: loading OpenVINO model from '%s'\n", __func__, path_encoder.c_str());
+ fprintf(stderr, "%s: first run on a device may take a while ...\n", __func__);
+
+ ctx->state->ctx_openvino = whisper_openvino_init(path_encoder.c_str(), device, path_cache.c_str());
+ if (!ctx->state->ctx_openvino) {
+ fprintf(stderr, "%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_encoder.c_str());
+ return 1;
+ } else {
+ fprintf(stderr, "%s: OpenVINO model loaded\n", __func__);
+ }
+
+ return 0;
+#endif
+}
+
struct whisper_context * whisper_init_from_file_no_state(const char * path_model) {
fprintf(stderr, "%s: loading model from '%s'\n", __func__, path_model);
@@ -2837,6 +2946,13 @@ void whisper_free_state(struct whisper_state * state)
}
#endif
+#ifdef WHISPER_USE_OPENVINO
+ if (state->ctx_openvino != nullptr) {
+ whisper_openvino_free(state->ctx_openvino);
+ state->ctx_openvino = nullptr;
+ }
+#endif
+
delete state;
}
}
@@ -3208,12 +3324,16 @@ whisper_token whisper_token_sot(struct whisper_context * ctx) {
return ctx->vocab.token_sot;
}
+whisper_token whisper_token_solm(struct whisper_context * ctx) {
+ return ctx->vocab.token_solm;
+}
+
whisper_token whisper_token_prev(struct whisper_context * ctx) {
return ctx->vocab.token_prev;
}
-whisper_token whisper_token_solm(struct whisper_context * ctx) {
- return ctx->vocab.token_solm;
+whisper_token whisper_token_nosp(struct whisper_context * ctx) {
+ return ctx->vocab.token_nosp;
}
whisper_token whisper_token_not(struct 
whisper_context * ctx) { @@ -3228,12 +3348,12 @@ whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id) { return whisper_token_sot(ctx) + 1 + lang_id; } -whisper_token whisper_token_translate(void) { - return whisper_vocab::token_translate; +whisper_token whisper_token_translate(struct whisper_context * ctx) { + return ctx->vocab.token_translate; } -whisper_token whisper_token_transcribe(void) { - return whisper_vocab::token_transcribe; +whisper_token whisper_token_transcribe(struct whisper_context * ctx) { + return ctx->vocab.token_transcribe; } void whisper_print_timings(struct whisper_context * ctx) { @@ -3272,6 +3392,14 @@ static int whisper_has_coreml(void) { #endif } +static int whisper_has_openvino(void) { +#ifdef WHISPER_USE_OPENVINO + return 1; +#else + return 0; +#endif +} + const char * whisper_print_system_info(void) { static std::string s; @@ -3289,6 +3417,7 @@ const char * whisper_print_system_info(void) { s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | "; s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | "; s += "COREML = " + std::to_string(whisper_has_coreml()) + " | "; + s += "OPENVINO = " + std::to_string(whisper_has_openvino()) + " | "; return s.c_str(); } @@ -3305,51 +3434,53 @@ struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sam struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) { struct whisper_full_params result = { - /*.strategy =*/ strategy, - - /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), - /*.n_max_text_ctx =*/ 16384, - /*.offset_ms =*/ 0, - /*.duration_ms =*/ 0, - - /*.translate =*/ false, - /*.no_context =*/ true, - /*.single_segment =*/ false, - /*.print_special =*/ false, - /*.print_progress =*/ true, - /*.print_realtime =*/ false, - /*.print_timestamps =*/ true, - - /*.token_timestamps =*/ false, - /*.thold_pt =*/ 0.01f, - /*.thold_ptsum =*/ 0.01f, - /*.max_len =*/ 0, - /*.split_on_word =*/ false, - /*.max_tokens =*/ 0, - - /*.speed_up =*/ false, - /*.audio_ctx =*/ 0, - - /*.initial_prompt =*/ nullptr, - /*.prompt_tokens =*/ nullptr, - /*.prompt_n_tokens =*/ 0, - - /*.language =*/ "en", - /*.detect_language =*/ false, - - /*.suppress_blank =*/ true, + /*.strategy =*/ strategy, + + /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), + /*.n_max_text_ctx =*/ 16384, + /*.offset_ms =*/ 0, + /*.duration_ms =*/ 0, + + /*.translate =*/ false, + /*.no_context =*/ true, + /*.single_segment =*/ false, + /*.print_special =*/ false, + /*.print_progress =*/ true, + /*.print_realtime =*/ false, + /*.print_timestamps =*/ true, + + /*.token_timestamps =*/ false, + /*.thold_pt =*/ 0.01f, + /*.thold_ptsum =*/ 0.01f, + /*.max_len =*/ 0, + /*.split_on_word =*/ false, + /*.max_tokens =*/ 0, + + /*.speed_up =*/ false, + /*.audio_ctx =*/ 0, + + /*.tdrz_enable =*/ false, + + /*.initial_prompt =*/ nullptr, + /*.prompt_tokens =*/ nullptr, + /*.prompt_n_tokens =*/ 0, + + /*.language =*/ "en", + /*.detect_language =*/ false, + + /*.suppress_blank =*/ true, /*.suppress_non_speech_tokens =*/ false, - /*.temperature =*/ 0.0f, - /*.max_initial_ts =*/ 1.0f, - /*.length_penalty =*/ -1.0f, + /*.temperature =*/ 0.0f, + /*.max_initial_ts =*/ 1.0f, + /*.length_penalty =*/ -1.0f, - /*.temperature_inc =*/ 0.4f, - /*.entropy_thold =*/ 2.4f, - /*.logprob_thold =*/ -1.0f, - /*.no_speech_thold =*/ 0.6f, + /*.temperature_inc =*/ 0.4f, + /*.entropy_thold =*/ 2.4f, + /*.logprob_thold =*/ -1.0f, + /*.no_speech_thold =*/ 0.6f, - /*.greedy 
=*/ { + /*.greedy =*/ { /*.best_of =*/ -1, }, @@ -3401,26 +3532,6 @@ static void whisper_exp_compute_token_level_timestamps( float thold_pt, float thold_ptsum); -// trim from start (in place) -static inline void ltrim(std::string &s) { - s.erase(s.begin(), std::find_if_not(s.begin(), s.end(), [](unsigned char ch) { - return std::isspace(ch); - })); -} - -// trim from end (in place) -static inline void rtrim(std::string &s) { - s.erase(std::find_if_not(s.rbegin(), s.rend(), [](unsigned char ch) { - return std::isspace(ch); - }).base(), s.end()); -} - -// trim from both ends (in place) -static inline void trim(std::string &s) { - rtrim(s); - ltrim(s); -} - static inline bool should_split_on_word(const char * txt, bool split_on_word) { if (!split_on_word) return true; @@ -3447,14 +3558,10 @@ static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_sta const int cur = strlen(txt); if (acc + cur > max_len && i > 0 && should_split_on_word(txt, split_on_word)) { - // split here - if (split_on_word) { - trim(text); - } - state.result_all.back().text = std::move(text); state.result_all.back().t1 = token.t0; state.result_all.back().tokens.resize(i); + state.result_all.back().speaker_turn_next = false; state.result_all.push_back({}); state.result_all.back().t0 = token.t0; @@ -3466,6 +3573,8 @@ static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_sta segment.tokens.begin() + i, segment.tokens.end()); + state.result_all.back().speaker_turn_next = segment.speaker_turn_next; + acc = 0; text = ""; @@ -3479,9 +3588,6 @@ static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_sta } } - if (split_on_word) { - trim(text); - } state.result_all.back().text = std::move(text); return res; @@ -3547,9 +3653,14 @@ static void whisper_process_logits( // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412 logits[vocab.token_not] = -INFINITY; - // suppress sot and solm tokens + // suppress sot and nosp tokens logits[vocab.token_sot] = -INFINITY; - logits[vocab.token_solm] = -INFINITY; + logits[vocab.token_nosp] = -INFINITY; // TODO: ignore this token for now + + // [TDRZ] when tinydiarize is disabled, suppress solm token + if (params.tdrz_enable == false) { + logits[vocab.token_solm] = -INFINITY; + } // suppress task tokens logits[vocab.token_translate] = -INFINITY; @@ -4046,9 +4157,9 @@ int whisper_full_with_state( state->lang_id = lang_id; prompt_init.push_back(whisper_token_lang(ctx, lang_id)); if (params.translate) { - prompt_init.push_back(whisper_token_translate()); + prompt_init.push_back(whisper_token_translate(ctx)); } else { - prompt_init.push_back(whisper_token_transcribe()); + prompt_init.push_back(whisper_token_transcribe(ctx)); } } @@ -4528,23 +4639,27 @@ int whisper_full_with_state( prompt_past.push_back(tokens_cur[i].id); } - // store the text from this iteration if (!tokens_cur.empty() && ctx->model.n_loaded > 0) { int i0 = 0; auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx)); std::string text; + bool speaker_turn_next = false; for (int i = 0; i < (int) tokens_cur.size(); i++) { //printf("%s: %18s %6.3f %18s %6.3f\n", __func__, // ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p, // ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt); - if (params.print_special == false && tokens_cur[i].id >= whisper_token_eot(ctx)) { - } else { + if (params.print_special || tokens_cur[i].id < whisper_token_eot(ctx)) { text += 
whisper_token_to_str(ctx, tokens_cur[i].id); } + // [TDRZ] record if speaker turn was predicted after current segment + if (params.tdrz_enable && tokens_cur[i].id == whisper_token_solm(ctx)) { + speaker_turn_next = true; + } + if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) { const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx)); @@ -4563,7 +4678,7 @@ int whisper_full_with_state( //printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid); - result_all.push_back({ tt0, tt1, text, {} }); + result_all.push_back({ tt0, tt1, text, {}, speaker_turn_next }); for (int j = i0; j <= i; j++) { result_all.back().tokens.push_back(tokens_cur[j]); } @@ -4589,6 +4704,7 @@ int whisper_full_with_state( i--; t0 = t1; i0 = i + 1; + speaker_turn_next = false; } } @@ -4607,7 +4723,7 @@ int whisper_full_with_state( } } - result_all.push_back({ tt0, tt1, text, {} }); + result_all.push_back({ tt0, tt1, text, {} , speaker_turn_next }); for (int j = i0; j < (int) tokens_cur.size(); j++) { result_all.back().tokens.push_back(tokens_cur[j]); } @@ -4787,6 +4903,10 @@ int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) return ctx->state->result_all[i_segment].t1; } +bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment) { + return ctx->state->result_all[i_segment].speaker_turn_next; +} + const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment) { return state->result_all[i_segment].text.c_str(); } diff --git a/examples/whisper/whisper.h b/examples/whisper/whisper.h index e983c7d4..83af11bd 100644 --- a/examples/whisper/whisper.h +++ b/examples/whisper/whisper.h @@ -110,6 +110,23 @@ extern "C" { WHISPER_API struct whisper_state * whisper_init_state(struct whisper_context * ctx); + // Given a context, enable use of OpenVINO for encode inference. + // model_path: Optional path to OpenVINO encoder IR model. If set to nullptr, + // the path will be generated from the ggml model path that was passed + // in to whisper_init_from_file. For example, if 'path_model' was + // "/path/to/ggml-base.en.bin", then OpenVINO IR model path will be + // assumed to be "/path/to/ggml-base.en-encoder-openvino.xml". + // device: OpenVINO device to run inference on ("CPU", "GPU", etc.) + // cache_dir: Optional cache directory that can speed up init time, especially for + // GPU, by caching compiled 'blobs' there. + // Set to nullptr if not used. + // Returns 0 on success. If OpenVINO is not enabled in build, this simply returns 1. 
+ WHISPER_API int whisper_ctx_init_openvino_encoder( + struct whisper_context * ctx, + const char * model_path, + const char * device, + const char * cache_dir); + // Frees all allocated memory WHISPER_API void whisper_free (struct whisper_context * ctx); WHISPER_API void whisper_free_state(struct whisper_state * state); @@ -277,15 +294,16 @@ extern "C" { // Special tokens WHISPER_API whisper_token whisper_token_eot (struct whisper_context * ctx); WHISPER_API whisper_token whisper_token_sot (struct whisper_context * ctx); - WHISPER_API whisper_token whisper_token_prev(struct whisper_context * ctx); WHISPER_API whisper_token whisper_token_solm(struct whisper_context * ctx); + WHISPER_API whisper_token whisper_token_prev(struct whisper_context * ctx); + WHISPER_API whisper_token whisper_token_nosp(struct whisper_context * ctx); WHISPER_API whisper_token whisper_token_not (struct whisper_context * ctx); WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx); WHISPER_API whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id); // Task tokens - WHISPER_API whisper_token whisper_token_translate (void); - WHISPER_API whisper_token whisper_token_transcribe(void); + WHISPER_API whisper_token whisper_token_translate (struct whisper_context * ctx); + WHISPER_API whisper_token whisper_token_transcribe(struct whisper_context * ctx); // Performance information from the default state. WHISPER_API void whisper_print_timings(struct whisper_context * ctx); @@ -358,6 +376,9 @@ extern "C" { bool speed_up; // speed-up the audio by 2x using Phase Vocoder int audio_ctx; // overwrite the audio context size (0 = use default) + // [EXPERIMENTAL] [TDRZ] tinydiarize + bool tdrz_enable; // enable tinydiarize speaker turn detection + // tokens to provide to the whisper decoder as initial prompt // these are prepended to any existing text context from a previous call const char * initial_prompt; @@ -460,6 +481,9 @@ extern "C" { WHISPER_API int64_t whisper_full_get_segment_t1 (struct whisper_context * ctx, int i_segment); WHISPER_API int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment); + // Get whether the next segment is predicted as a speaker turn + WHISPER_API bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment); + // Get the text of the specified segment WHISPER_API const char * whisper_full_get_segment_text (struct whisper_context * ctx, int i_segment); WHISPER_API const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment); @@ -488,9 +512,9 @@ extern "C" { // Temporary helpers needed for exposing ggml interface - WHISPER_API int whisper_bench_memcpy(int n_threads); - WHISPER_API const char * whisper_bench_memcpy_str(int n_threads); - WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads); + WHISPER_API int whisper_bench_memcpy (int n_threads); + WHISPER_API const char * whisper_bench_memcpy_str (int n_threads); + WHISPER_API int whisper_bench_ggml_mul_mat (int n_threads); WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads); #ifdef __cplusplus
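Usage note (not part of the diff above): a minimal sketch of how an application might exercise the two features introduced by this sync, OpenVINO encode inference and tinydiarize speaker-turn detection, based on the header comments for whisper_ctx_init_openvino_encoder and the new tdrz fields. The model path and the audio-loading step are placeholders (any tinydiarize-capable ggml model and 16 kHz mono f32 samples are assumed); on builds without WHISPER_USE_OPENVINO the OpenVINO call simply returns 1 and the rest of the code runs unchanged.

    // sketch only: combine whisper_ctx_init_openvino_encoder() with tdrz_enable
    #include "whisper.h"

    #include <cstdio>
    #include <vector>

    int main() {
        // hypothetical model path; must point to a tinydiarize-capable ggml model
        struct whisper_context * ctx = whisper_init_from_file("models/ggml-small.en-tdrz.bin");
        if (!ctx) {
            fprintf(stderr, "failed to load model\n");
            return 1;
        }

        // no effect (returns 1) on builds without OpenVINO support
        whisper_ctx_init_openvino_encoder(ctx, nullptr, "CPU", nullptr);

        std::vector<float> pcmf32; // assumed to be filled with 16 kHz mono f32 PCM by the caller

        struct whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
        wparams.tdrz_enable = true; // [TDRZ] predict a speaker turn after each segment

        if (whisper_full(ctx, wparams, pcmf32.data(), (int) pcmf32.size()) != 0) {
            fprintf(stderr, "whisper_full failed\n");
            whisper_free(ctx);
            return 1;
        }

        const int n_segments = whisper_full_n_segments(ctx);
        for (int i = 0; i < n_segments; ++i) {
            // append a marker when the model predicts a speaker change after this segment
            printf("%s%s\n",
                whisper_full_get_segment_text(ctx, i),
                whisper_full_get_segment_speaker_turn_next(ctx, i) ? " [SPEAKER_TURN]" : "");
        }

        whisper_free(ctx);
        return 0;
    }

The " [SPEAKER_TURN]" suffix mirrors what examples/whisper/main.cpp prints via params.tdrz_speaker_turn when --tinydiarize is enabled.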