//
struct gpt_params {
- int32_t seed = -1; // RNG seed
+ int32_t seed = -1; // RNG seed
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t n_predict = 200; // new tokens to predict
+ int32_t n_batch = 8; // batch size for prompt processing
// sampling parameters
- int32_t top_k = 40;
- float top_p = 0.9f;
- float temp = 0.9f;
+ int32_t top_k = 40;
+ float top_p = 0.9f;
+ float temp = 0.9f;
int32_t repeat_last_n = 64;
float repeat_penalty = 1.00f;
- int32_t n_batch = 8; // batch size for prompt processing
-
std::string model = "models/gpt-2-117M/ggml-model.bin"; // model path
std::string prompt = "";
std::string token_test = "";
- bool interactive = false;
- int interactive_port = -1;
+ bool interactive = false;
+ int32_t interactive_port = -1;
};
bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
float entropy_thold = 2.40f;
float logprob_thold = -1.00f;
- bool speed_up = false;
- bool translate = false;
- bool detect_language= false;
- bool diarize = false;
- bool split_on_word = false;
- bool no_fallback = false;
- bool output_txt = false;
- bool output_vtt = false;
- bool output_srt = false;
- bool output_wts = false;
- bool output_csv = false;
- bool output_jsn = false;
- bool output_lrc = false;
- bool print_special = false;
- bool print_colors = false;
- bool print_progress = false;
- bool no_timestamps = false;
-
- std::string language = "en";
+ bool speed_up = false;
+ bool translate = false;
+ bool detect_language = false;
+ bool diarize = false;
+ bool tinydiarize = false;
+ bool split_on_word = false;
+ bool no_fallback = false;
+ bool output_txt = false;
+ bool output_vtt = false;
+ bool output_srt = false;
+ bool output_wts = false;
+ bool output_csv = false;
+ bool output_jsn = false;
+ bool output_lrc = false;
+ bool print_special = false;
+ bool print_colors = false;
+ bool print_progress = false;
+ bool no_timestamps = false;
+
+ std::string language = "en";
std::string prompt;
std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
- std::string model = "models/ggml-base.en.bin";
+ std::string model = "models/ggml-base.en.bin";
+
+ // [TDRZ] speaker turn string
+ std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line
+
+ std::string openvino_encode_device = "CPU";
std::vector<std::string> fname_inp = {};
std::vector<std::string> fname_out = {};
whisper_print_usage(argc, argv, params);
exit(0);
}
- else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
- else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); }
- else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); }
- else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); }
- else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
- else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
- else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
- else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); }
- else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); }
- else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
- else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
- else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
- else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
- else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
- else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
- else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; }
- else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
- else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
- else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
- else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
- else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
- else if (arg == "-olrc" || arg == "--output-lrc") { params.output_lrc = true; }
- else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; }
- else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
- else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; }
- else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); }
- else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
- else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
- else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
- else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
- else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
- else if (arg == "-dl" || arg == "--detect-language"){ params.detect_language= true; }
- else if ( arg == "--prompt") { params.prompt = argv[++i]; }
- else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
- else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
+ else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
+ else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); }
+ else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); }
+ else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); }
+ else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
+ else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
+ else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
+ else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); }
+ else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); }
+ else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
+ else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
+ else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
+ else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
+ else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
+ else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
+ else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; }
+ else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; }
+ else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
+ else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
+ else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
+ else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
+ else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
+ else if (arg == "-olrc" || arg == "--output-lrc") { params.output_lrc = true; }
+ else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; }
+ else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
+ else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; }
+ else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); }
+ else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
+ else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
+ else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
+ else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
+ else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
+ else if (arg == "-dl" || arg == "--detect-language") { params.detect_language = true; }
+ else if ( arg == "--prompt") { params.prompt = argv[++i]; }
+ else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
+ else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
+ else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
+ fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
+ fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str());
fprintf(stderr, "\n");
}
printf("%s%s", speaker.c_str(), text);
}
+ if (params.tinydiarize) {
+ if (whisper_full_get_segment_speaker_turn_next(ctx, i)) {
+ printf("%s", params.tdrz_speaker_turn.c_str());
+ }
+ }
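+ // illustrative (not part of this patch): with --tinydiarize enabled, a printed
+ // segment then ends with the marker when a turn is predicted,
+ // e.g. "... thank you. [SPEAKER_TURN]"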
+
// with timestamps or speakers: each segment on new line
if (!params.no_timestamps || params.diarize) {
printf("\n");
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
+
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
value_i("from", t0 * 10, false);
value_i("to", t1 * 10, true);
end_obj(false);
- value_s("text", text, !params.diarize);
+ value_s("text", text, !params.diarize && !params.tinydiarize);
if (params.diarize && pcmf32s.size() == 2) {
value_s("speaker", estimate_diarization_speaker(pcmf32s, t0, t1, true).c_str(), true);
}
+
+ if (params.tinydiarize) {
+ value_b("speaker_turn_next", whisper_full_get_segment_speaker_turn_next(ctx, i), true);
+ }
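+ // illustrative (not part of this patch): each JSON segment object then carries
+ // an extra boolean field, e.g. "speaker_turn_next": false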
end_obj(i == (n_segments - 1));
}
exit(0);
}
+ if (params.diarize && params.tinydiarize) {
+ fprintf(stderr, "error: cannot use both --diarize and --tinydiarize\n");
+ whisper_print_usage(argc, argv, params);
+ exit(0);
+ }
+
// whisper init
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
return 3;
}
+ // initialize OpenVINO encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
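+ // (note, per the header docs: on such builds the call below returns 1 and
+ // transcription simply proceeds with the default ggml encoder)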
+ whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
+
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
const auto fname_inp = params.fname_inp[f];
const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
if (params.detect_language) {
params.language = "auto";
}
- fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d ...\n",
+ fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, %stimestamps = %d ...\n",
__func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
params.n_threads, params.n_processors,
params.language.c_str(),
params.translate ? "translate" : "transcribe",
+ params.tinydiarize ? "tdrz = 1, " : "",
params.no_timestamps ? 0 : 1);
fprintf(stderr, "\n");
wparams.speed_up = params.speed_up;
+ wparams.tdrz_enable = params.tinydiarize; // [TDRZ]
+
wparams.initial_prompt = params.prompt.c_str();
wparams.greedy.best_of = params.best_of;
#include "coreml/whisper-encoder.h"
#endif
+#if WHISPER_USE_OPENVINO
+#include "openvino/whisper-openvino-encoder.h"
+#endif
+
#include "ggml.h"
#include <algorithm>
std::map<token, id> token_to_id;
std::map<id, token> id_to_token;
- id token_eot = 50256;
- id token_sot = 50257;
- id token_prev = 50360;
- id token_solm = 50361; // ??
- id token_not = 50362; // no timestamps
- id token_beg = 50363;
-
- // available tasks
- static const id token_translate = 50358;
- static const id token_transcribe = 50359;
+ // reference: https://github.com/openai/whisper/blob/248b6cb124225dd263bb9bd32d060b6517e067f8/whisper/tokenizer.py#L334-L349
+ id token_eot = 50256;
+ id token_sot = 50257;
+ // task tokens (used only for multilingual models)
+ id token_translate = 50357;
+ id token_transcribe = 50358;
+ // other special tokens
+ id token_solm = 50359; // [TDRZ] used by tinydiarize models to indicate speaker turn
+ id token_prev = 50360;
+ id token_nosp = 50361;
+ id token_not = 50362; // no timestamps
+ id token_beg = 50363; // begin timestamps
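+ // note: for multilingual models (n_vocab == 51865) each special token id above
+ // is shifted up by one when the vocabulary is loaded (see the is_multilingual() handling)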
bool is_multilingual() const {
return n_vocab == 51865;
std::string text;
std::vector<whisper_token_data> tokens;
+
+ bool speaker_turn_next;
};
// medium
whisper_coreml_context * ctx_coreml = nullptr;
#endif
+#ifdef WHISPER_USE_OPENVINO
+ whisper_openvino_context * ctx_openvino = nullptr;
+#endif
+
// [EXPERIMENTAL] token-level timestamps data
int64_t t_beg = 0;
int64_t t_last = 0;
if (vocab.is_multilingual()) {
vocab.token_eot++;
vocab.token_sot++;
- vocab.token_prev++;
+ vocab.token_translate++;
+ vocab.token_transcribe++;
vocab.token_solm++;
+ vocab.token_prev++;
+ vocab.token_nosp++;
vocab.token_not++;
vocab.token_beg++;
}
word = "[_EOT_]";
} else if (i == vocab.token_sot) {
word = "[_SOT_]";
+ } else if (i == vocab.token_solm) {
+ word = "[_SOLM_]";
} else if (i == vocab.token_prev) {
word = "[_PREV_]";
+ } else if (i == vocab.token_nosp) {
+ word = "[_NOSP_]";
} else if (i == vocab.token_not) {
word = "[_NOT_]";
} else if (i == vocab.token_beg) {
const bool use_coreml = wstate.ctx_coreml != nullptr;
#endif
- if (!use_coreml) {
+#ifndef WHISPER_USE_OPENVINO
+ const bool use_openvino = false;
+#else
+ const bool use_openvino = wstate.ctx_openvino != nullptr;
+#endif
+
+ if (!use_coreml && !use_openvino) {
// convolution + gelu
{
wstate.use_buf(ctx0, 1);
}
}
#ifdef WHISPER_USE_COREML
- else
- {
+ else if (use_coreml) {
wstate.use_buf(ctx0, -1);
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data);
}
#endif
+#ifdef WHISPER_USE_OPENVINO
+ else if (use_openvino) {
+ wstate.use_buf(ctx0, -1);
+
+ cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
+
+ if (!whisper_openvino_encode(wstate.ctx_openvino, mel, cur)) {
+ return false;
+ }
+ }
+#endif
// cur
//{
}
#endif
+#ifdef WHISPER_USE_OPENVINO
+// replace .bin with -encoder-openvino.xml
+static std::string whisper_openvino_get_path_encoder(std::string path_bin) {
+ auto pos = path_bin.rfind('.');
+ if (pos != std::string::npos) {
+ path_bin = path_bin.substr(0, pos);
+ }
+
+ path_bin += "-encoder-openvino.xml";
+
+ return path_bin;
+}
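+// e.g. (illustrative): "models/ggml-base.en.bin" -> "models/ggml-base.en-encoder-openvino.xml"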
+
+static std::string whisper_openvino_get_path_cache(std::string path_bin) {
+ auto pos = path_bin.rfind('.');
+ if (pos != std::string::npos) {
+ path_bin = path_bin.substr(0, pos);
+ }
+
+ path_bin += "-encoder-openvino-cache";
+
+ return path_bin;
+}
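+// e.g. (illustrative): "models/ggml-base.en.bin" -> "models/ggml-base.en-encoder-openvino-cache"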
+#endif
+
struct whisper_state * whisper_init_state(whisper_context * ctx) {
whisper_state * state = new whisper_state;
return state;
}
+int whisper_ctx_init_openvino_encoder(
+ struct whisper_context * ctx,
+ const char * model_path,
+ const char * device,
+ const char * cache_dir) {
+#ifndef WHISPER_USE_OPENVINO
+ (void)(ctx);
+ (void)(model_path);
+ (void)(device);
+ (void)(cache_dir);
+
+ return 1;
+#else
+ if (!model_path && ctx->path_model.empty()) {
+ fprintf(stderr, "%s: model_path is nullptr, and ctx has no model_path set.\n", __func__);
+ return 1;
+ }
+
+ std::string path_encoder;
+ if (!model_path) {
+ // if model_path is not set, attempt to find it in the same directory as the ggml-<model>.bin model
+ path_encoder = whisper_openvino_get_path_encoder(ctx->path_model);
+ } else {
+ path_encoder = model_path;
+ }
+
+ std::string path_cache;
+ if (!cache_dir) {
+ // if cache_dir is not set, set it as a directory residing next to ggml-<model>.bin
+ path_cache = whisper_openvino_get_path_cache(ctx->path_model);
+ } else {
+ path_cache = cache_dir;
+ }
+
+ fprintf(stderr, "%s: loading OpenVINO model from '%s'\n", __func__, path_openvino.c_str());
+ fprintf(stderr, "%s: first run on a device may take a while ...\n", __func__);
+
+ ctx->state->ctx_openvino = whisper_openvino_init(path_openvino.c_str(), device, path_cache.c_str());
+ if (!ctx->state->ctx_openvino) {
+ fprintf(stderr, "%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_openvino.c_str());
+ return 1;
+ } else {
+ fprintf(stderr, "%s: OpenVINO model loaded\n", __func__);
+ }
+
+ return 0;
+#endif
+}
+
struct whisper_context * whisper_init_from_file_no_state(const char * path_model) {
fprintf(stderr, "%s: loading model from '%s'\n", __func__, path_model);
}
#endif
+#ifdef WHISPER_USE_OPENVINO
+ if (state->ctx_openvino != nullptr) {
+ whisper_openvino_free(state->ctx_openvino);
+ state->ctx_openvino = nullptr;
+ }
+#endif
+
delete state;
}
}
return ctx->vocab.token_sot;
}
+whisper_token whisper_token_solm(struct whisper_context * ctx) {
+ return ctx->vocab.token_solm;
+}
+
whisper_token whisper_token_prev(struct whisper_context * ctx) {
return ctx->vocab.token_prev;
}
-whisper_token whisper_token_solm(struct whisper_context * ctx) {
- return ctx->vocab.token_solm;
+whisper_token whisper_token_nosp(struct whisper_context * ctx) {
+ return ctx->vocab.token_nosp;
}
whisper_token whisper_token_not(struct whisper_context * ctx) {
return whisper_token_sot(ctx) + 1 + lang_id;
}
-whisper_token whisper_token_translate(void) {
- return whisper_vocab::token_translate;
+whisper_token whisper_token_translate(struct whisper_context * ctx) {
+ return ctx->vocab.token_translate;
}
-whisper_token whisper_token_transcribe(void) {
- return whisper_vocab::token_transcribe;
+whisper_token whisper_token_transcribe(struct whisper_context * ctx) {
+ return ctx->vocab.token_transcribe;
}
void whisper_print_timings(struct whisper_context * ctx) {
#endif
}
+static int whisper_has_openvino(void) {
+#ifdef WHISPER_USE_OPENVINO
+ return 1;
+#else
+ return 0;
+#endif
+}
+
const char * whisper_print_system_info(void) {
static std::string s;
s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
s += "COREML = " + std::to_string(whisper_has_coreml()) + " | ";
+ s += "OPENVINO = " + std::to_string(whisper_has_openvino()) + " | ";
return s.c_str();
}
struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
struct whisper_full_params result = {
- /*.strategy =*/ strategy,
-
- /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
- /*.n_max_text_ctx =*/ 16384,
- /*.offset_ms =*/ 0,
- /*.duration_ms =*/ 0,
-
- /*.translate =*/ false,
- /*.no_context =*/ true,
- /*.single_segment =*/ false,
- /*.print_special =*/ false,
- /*.print_progress =*/ true,
- /*.print_realtime =*/ false,
- /*.print_timestamps =*/ true,
-
- /*.token_timestamps =*/ false,
- /*.thold_pt =*/ 0.01f,
- /*.thold_ptsum =*/ 0.01f,
- /*.max_len =*/ 0,
- /*.split_on_word =*/ false,
- /*.max_tokens =*/ 0,
-
- /*.speed_up =*/ false,
- /*.audio_ctx =*/ 0,
-
- /*.initial_prompt =*/ nullptr,
- /*.prompt_tokens =*/ nullptr,
- /*.prompt_n_tokens =*/ 0,
-
- /*.language =*/ "en",
- /*.detect_language =*/ false,
-
- /*.suppress_blank =*/ true,
+ /*.strategy =*/ strategy,
+
+ /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
+ /*.n_max_text_ctx =*/ 16384,
+ /*.offset_ms =*/ 0,
+ /*.duration_ms =*/ 0,
+
+ /*.translate =*/ false,
+ /*.no_context =*/ true,
+ /*.single_segment =*/ false,
+ /*.print_special =*/ false,
+ /*.print_progress =*/ true,
+ /*.print_realtime =*/ false,
+ /*.print_timestamps =*/ true,
+
+ /*.token_timestamps =*/ false,
+ /*.thold_pt =*/ 0.01f,
+ /*.thold_ptsum =*/ 0.01f,
+ /*.max_len =*/ 0,
+ /*.split_on_word =*/ false,
+ /*.max_tokens =*/ 0,
+
+ /*.speed_up =*/ false,
+ /*.audio_ctx =*/ 0,
+
+ /*.tdrz_enable =*/ false,
+
+ /*.initial_prompt =*/ nullptr,
+ /*.prompt_tokens =*/ nullptr,
+ /*.prompt_n_tokens =*/ 0,
+
+ /*.language =*/ "en",
+ /*.detect_language =*/ false,
+
+ /*.suppress_blank =*/ true,
/*.suppress_non_speech_tokens =*/ false,
-    /*.temperature      =*/  0.0f,
-    /*.max_initial_ts   =*/  1.0f,
-    /*.length_penalty   =*/ -1.0f,
-
-    /*.temperature_inc  =*/  0.4f,
-    /*.entropy_thold    =*/  2.4f,
-    /*.logprob_thold    =*/ -1.0f,
-    /*.no_speech_thold  =*/  0.6f,
-
-    /*.greedy =*/ {
+    /*.temperature      =*/  0.0f,
+    /*.max_initial_ts   =*/  1.0f,
+    /*.length_penalty   =*/ -1.0f,
+
+    /*.temperature_inc  =*/  0.4f,
+    /*.entropy_thold    =*/  2.4f,
+    /*.logprob_thold    =*/ -1.0f,
+    /*.no_speech_thold  =*/  0.6f,
+
+    /*.greedy           =*/ {
/*.best_of =*/ -1,
},
float thold_pt,
float thold_ptsum);
-// trim from start (in place)
-static inline void ltrim(std::string &s) {
- s.erase(s.begin(), std::find_if_not(s.begin(), s.end(), [](unsigned char ch) {
- return std::isspace(ch);
- }));
-}
-
-// trim from end (in place)
-static inline void rtrim(std::string &s) {
- s.erase(std::find_if_not(s.rbegin(), s.rend(), [](unsigned char ch) {
- return std::isspace(ch);
- }).base(), s.end());
-}
-
-// trim from both ends (in place)
-static inline void trim(std::string &s) {
- rtrim(s);
- ltrim(s);
-}
-
static inline bool should_split_on_word(const char * txt, bool split_on_word) {
if (!split_on_word) return true;
const int cur = strlen(txt);
if (acc + cur > max_len && i > 0 && should_split_on_word(txt, split_on_word)) {
- // split here
- if (split_on_word) {
- trim(text);
- }
-
state.result_all.back().text = std::move(text);
state.result_all.back().t1 = token.t0;
state.result_all.back().tokens.resize(i);
+ state.result_all.back().speaker_turn_next = false;
state.result_all.push_back({});
state.result_all.back().t0 = token.t0;
segment.tokens.begin() + i,
segment.tokens.end());
+ state.result_all.back().speaker_turn_next = segment.speaker_turn_next;
+
acc = 0;
text = "";
}
}
- if (split_on_word) {
- trim(text);
- }
state.result_all.back().text = std::move(text);
return res;
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412
logits[vocab.token_not] = -INFINITY;
- // suppress sot and solm tokens
+ // suppress sot and nosp tokens
logits[vocab.token_sot] = -INFINITY;
- logits[vocab.token_solm] = -INFINITY;
+ logits[vocab.token_nosp] = -INFINITY; // TODO: ignore this token for now
+
+ // [TDRZ] when tinydiarize is disabled, suppress solm token
+ if (params.tdrz_enable == false) {
+ logits[vocab.token_solm] = -INFINITY;
+ }
// suppress task tokens
logits[vocab.token_translate] = -INFINITY;
state->lang_id = lang_id;
prompt_init.push_back(whisper_token_lang(ctx, lang_id));
if (params.translate) {
- prompt_init.push_back(whisper_token_translate());
+ prompt_init.push_back(whisper_token_translate(ctx));
} else {
- prompt_init.push_back(whisper_token_transcribe());
+ prompt_init.push_back(whisper_token_transcribe(ctx));
}
}
prompt_past.push_back(tokens_cur[i].id);
}
- // store the text from this iteration
if (!tokens_cur.empty() && ctx->model.n_loaded > 0) {
int i0 = 0;
auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
std::string text;
+ bool speaker_turn_next = false;
for (int i = 0; i < (int) tokens_cur.size(); i++) {
//printf("%s: %18s %6.3f %18s %6.3f\n", __func__,
// ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p,
// ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt);
- if (params.print_special == false && tokens_cur[i].id >= whisper_token_eot(ctx)) {
- } else {
+ if (params.print_special || tokens_cur[i].id < whisper_token_eot(ctx)) {
text += whisper_token_to_str(ctx, tokens_cur[i].id);
}
+ // [TDRZ] record if speaker turn was predicted after current segment
+ if (params.tdrz_enable && tokens_cur[i].id == whisper_token_solm(ctx)) {
+ speaker_turn_next = true;
+ }
+
if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) {
const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
//printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid);
- result_all.push_back({ tt0, tt1, text, {} });
+ result_all.push_back({ tt0, tt1, text, {}, speaker_turn_next });
for (int j = i0; j <= i; j++) {
result_all.back().tokens.push_back(tokens_cur[j]);
}
i--;
t0 = t1;
i0 = i + 1;
+ speaker_turn_next = false;
}
}
}
}
- result_all.push_back({ tt0, tt1, text, {} });
+ result_all.push_back({ tt0, tt1, text, {}, speaker_turn_next });
for (int j = i0; j < (int) tokens_cur.size(); j++) {
result_all.back().tokens.push_back(tokens_cur[j]);
}
return ctx->state->result_all[i_segment].t1;
}
+bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment) {
+ return ctx->state->result_all[i_segment].speaker_turn_next;
+}
+
const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment) {
return state->result_all[i_segment].text.c_str();
}
WHISPER_API struct whisper_state * whisper_init_state(struct whisper_context * ctx);
+ // Given a context, enable use of OpenVINO for encode inference.
+ // model_path: Optional path to OpenVINO encoder IR model. If set to nullptr,
+ // the path will be generated from the ggml model path that was passed
+ // in to whisper_init_from_file. For example, if 'path_model' was
+ // "/path/to/ggml-base.en.bin", then OpenVINO IR model path will be
+ // assumed to be "/path/to/ggml-base.en-encoder-openvino.xml".
+ // device: OpenVINO device to run inference on ("CPU", "GPU", etc.)
+ // cache_dir: Optional cache directory that can speed up init time, especially for
+ // GPU, by caching compiled 'blobs' there.
+ // Set to nullptr if not used.
+ // Returns 0 on success. If OpenVINO is not enabled in the build, this simply returns 1.
+ WHISPER_API int whisper_ctx_init_openvino_encoder(
+ struct whisper_context * ctx,
+ const char * model_path,
+ const char * device,
+ const char * cache_dir);
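+ //
+ // Minimal usage sketch (illustrative only; assumes a model file at "models/ggml-base.en.bin"):
+ //
+ //     struct whisper_context * ctx = whisper_init_from_file("models/ggml-base.en.bin");
+ //     if (whisper_ctx_init_openvino_encoder(ctx, nullptr, "CPU", nullptr) != 0) {
+ //         // OpenVINO disabled at build time, or the IR model was not found;
+ //         // whisper.cpp keeps using the default ggml encoder
+ //     }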
+
// Frees all allocated memory
WHISPER_API void whisper_free (struct whisper_context * ctx);
WHISPER_API void whisper_free_state(struct whisper_state * state);
// Special tokens
WHISPER_API whisper_token whisper_token_eot (struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_sot (struct whisper_context * ctx);
- WHISPER_API whisper_token whisper_token_prev(struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_solm(struct whisper_context * ctx);
+ WHISPER_API whisper_token whisper_token_prev(struct whisper_context * ctx);
+ WHISPER_API whisper_token whisper_token_nosp(struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_not (struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id);
// Task tokens
- WHISPER_API whisper_token whisper_token_translate (void);
- WHISPER_API whisper_token whisper_token_transcribe(void);
+ WHISPER_API whisper_token whisper_token_translate (struct whisper_context * ctx);
+ WHISPER_API whisper_token whisper_token_transcribe(struct whisper_context * ctx);
// Performance information from the default state.
WHISPER_API void whisper_print_timings(struct whisper_context * ctx);
bool speed_up; // speed-up the audio by 2x using Phase Vocoder
int audio_ctx; // overwrite the audio context size (0 = use default)
+ // [EXPERIMENTAL] [TDRZ] tinydiarize
+ bool tdrz_enable; // enable tinydiarize speaker turn detection
+
// tokens to provide to the whisper decoder as initial prompt
// these are prepended to any existing text context from a previous call
const char * initial_prompt;
WHISPER_API int64_t whisper_full_get_segment_t1 (struct whisper_context * ctx, int i_segment);
WHISPER_API int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment);
+ // Get whether the next segment is predicted as a speaker turn
+ WHISPER_API bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment);
+
// Get the text of the specified segment
WHISPER_API const char * whisper_full_get_segment_text (struct whisper_context * ctx, int i_segment);
WHISPER_API const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment);
// Temporary helpers needed for exposing ggml interface
- WHISPER_API int whisper_bench_memcpy(int n_threads);
- WHISPER_API const char * whisper_bench_memcpy_str(int n_threads);
- WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads);
+ WHISPER_API int whisper_bench_memcpy (int n_threads);
+ WHISPER_API const char * whisper_bench_memcpy_str (int n_threads);
+ WHISPER_API int whisper_bench_ggml_mul_mat (int n_threads);
WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads);
#ifdef __cplusplus