return std::string(buf);
}
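+// convert a whisper timestamp (expressed in 10 ms ticks) to a sample index at WHISPER_SAMPLE_RATE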
+int timestamp_to_sample(int64_t t, int n_samples) {
+ return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
+}
+
// helper function to replace substrings
void replace_all(std::string & s, const std::string & search, const std::string & replace) {
for (size_t pos = 0; ; pos += replace.length()) {
// command-line parameters
struct whisper_params {
- int32_t seed = -1; // RNG seed, not used currently
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t n_processors = 1;
int32_t offset_t_ms = 0;
float word_thold = 0.01f;
- bool verbose = false;
- bool translate = false;
- bool output_txt = false;
- bool output_vtt = false;
- bool output_srt = false;
- bool output_wts = false;
- bool print_special_tokens = false;
- bool print_colors = false;
- bool no_timestamps = false;
+ bool speed_up = false;
+ bool translate = false;
+ bool diarize = false;
+ bool output_txt = false;
+ bool output_vtt = false;
+ bool output_srt = false;
+ bool output_wts = false;
+ bool print_special = false;
+ bool print_colors = false;
+ bool no_timestamps = false;
std::string language = "en";
std::string model = "models/ggml-base.en.bin";
continue;
}
- if (arg == "-s" || arg == "--seed") {
- params.seed = std::stoi(argv[++i]);
- } else if (arg == "-t" || arg == "--threads") {
- params.n_threads = std::stoi(argv[++i]);
- } else if (arg == "-p" || arg == "--processors") {
- params.n_processors = std::stoi(argv[++i]);
- } else if (arg == "-ot" || arg == "--offset-t") {
- params.offset_t_ms = std::stoi(argv[++i]);
- } else if (arg == "-on" || arg == "--offset-n") {
- params.offset_n = std::stoi(argv[++i]);
- } else if (arg == "-d" || arg == "--duration") {
- params.duration_ms = std::stoi(argv[++i]);
- } else if (arg == "-mc" || arg == "--max-context") {
- params.max_context = std::stoi(argv[++i]);
- } else if (arg == "-ml" || arg == "--max-len") {
- params.max_len = std::stoi(argv[++i]);
- } else if (arg == "-wt" || arg == "--word-thold") {
- params.word_thold = std::stof(argv[++i]);
- } else if (arg == "-v" || arg == "--verbose") {
- params.verbose = true;
- } else if (arg == "--translate") {
- params.translate = true;
- } else if (arg == "-l" || arg == "--language") {
- params.language = argv[++i];
- if (whisper_lang_id(params.language.c_str()) == -1) {
- fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
- whisper_print_usage(argc, argv, params);
- exit(0);
- }
- } else if (arg == "-otxt" || arg == "--output-txt") {
- params.output_txt = true;
- } else if (arg == "-ovtt" || arg == "--output-vtt") {
- params.output_vtt = true;
- } else if (arg == "-osrt" || arg == "--output-srt") {
- params.output_srt = true;
- } else if (arg == "-owts" || arg == "--output-words") {
- params.output_wts = true;
- } else if (arg == "-ps" || arg == "--print_special") {
- params.print_special_tokens = true;
- } else if (arg == "-pc" || arg == "--print_colors") {
- params.print_colors = true;
- } else if (arg == "-nt" || arg == "--no_timestamps") {
- params.no_timestamps = true;
- } else if (arg == "-m" || arg == "--model") {
- params.model = argv[++i];
- } else if (arg == "-f" || arg == "--file") {
- params.fname_inp.push_back(argv[++i]);
- } else if (arg == "-h" || arg == "--help") {
+ if (arg == "-h" || arg == "--help") {
whisper_print_usage(argc, argv, params);
exit(0);
- } else {
+ }
+ else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
+ else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); }
+ else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); }
+ else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); }
+ else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
+ else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
+ else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
+ else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
+ else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
+ else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
+ else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
+ else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
+ else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
+ else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
+ else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
+ else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
+ else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
+ else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
+ else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
+ else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
+ else if (arg == "-f" || arg == "--file") { params.fname_inp.push_back(argv[++i]); }
+ else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
- fprintf(stderr, " -h, --help show this help message and exit\n");
- fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n");
- fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
- fprintf(stderr, " -p N, --processors N number of processors to use during computation (default: %d)\n", params.n_processors);
- fprintf(stderr, " -ot N, --offset-t N time offset in milliseconds (default: %d)\n", params.offset_t_ms);
- fprintf(stderr, " -on N, --offset-n N segment index offset (default: %d)\n", params.offset_n);
- fprintf(stderr, " -d N, --duration N duration of audio to process in milliseconds (default: %d)\n", params.duration_ms);
- fprintf(stderr, " -mc N, --max-context N maximum number of text context tokens to store (default: max)\n");
- fprintf(stderr, " -ml N, --max-len N maximum segment length in characters (default: %d)\n", params.max_len);
- fprintf(stderr, " -wt N, --word-thold N word timestamp probability threshold (default: %f)\n", params.word_thold);
- fprintf(stderr, " -v, --verbose verbose output\n");
- fprintf(stderr, " --translate translate from source language to english\n");
- fprintf(stderr, " -otxt, --output-txt output result in a text file\n");
- fprintf(stderr, " -ovtt, --output-vtt output result in a vtt file\n");
- fprintf(stderr, " -osrt, --output-srt output result in a srt file\n");
- fprintf(stderr, " -owts, --output-words output script for generating karaoke video\n");
- fprintf(stderr, " -ps, --print_special print special tokens\n");
- fprintf(stderr, " -pc, --print_colors print colors\n");
- fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n");
- fprintf(stderr, " -l LANG, --language LANG spoken language (default: %s)\n", params.language.c_str());
- fprintf(stderr, " -m FNAME, --model FNAME model path (default: %s)\n", params.model.c_str());
- fprintf(stderr, " -f FNAME, --file FNAME input WAV file path\n");
+ fprintf(stderr, " -h, --help [default] show this help message and exit\n");
+ fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
+ fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
+ fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
+ fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
+ fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
+ fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
+ fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
+ fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
+ fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
+ fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
+ fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
+ fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
+ fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
+ fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
+ fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
+ fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
+ fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
+ fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
+ fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
+ fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
+ fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
fprintf(stderr, "\n");
}
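+// everything the print callback needs: the parameters and, for diarization, the stereo PCM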
+struct whisper_print_user_data {
+ const whisper_params * params;
+
+ const std::vector<std::vector<float>> * pcmf32s;
+};
+
void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
- const whisper_params & params = *(whisper_params *) user_data;
+ const auto & params = *((whisper_print_user_data *) user_data)->params;
+ const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
const int n_segments = whisper_full_n_segments(ctx);
if (params.no_timestamps) {
if (params.print_colors) {
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
- if (params.print_special_tokens == false) {
+ if (params.print_special == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) {
continue;
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
+ std::string speaker = "";
+
+ if (params.diarize && pcmf32s.size() == 2) {
+ const int64_t n_samples = pcmf32s[0].size();
+
+ const int64_t is0 = timestamp_to_sample(t0, n_samples);
+ const int64_t is1 = timestamp_to_sample(t1, n_samples);
+
+ double energy0 = 0.0;
+ double energy1 = 0.0;
+
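+ // crude heuristic: attribute the segment to the channel with the larger summed absolute amplitude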
+ for (int64_t j = is0; j < is1; j++) {
+ energy0 += fabs(pcmf32s[0][j]);
+ energy1 += fabs(pcmf32s[1][j]);
+ }
+
+ if (energy0 > 1.1*energy1) {
+ speaker = "(speaker 0)";
+ } else if (energy1 > 1.1*energy0) {
+ speaker = "(speaker 1)";
+ } else {
+ speaker = "(speaker ?)";
+ }
+
+ //printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
+ }
+
if (params.print_colors) {
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
- if (params.print_special_tokens == false) {
+ if (params.print_special == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) {
continue;
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
- printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
+ printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
}
printf("\n");
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
- printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
+ printf("[%s --> %s] %s%s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), speaker.c_str(), text);
}
}
}
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
- return 9;
+ return false;
}
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
ncnt += txt.size();
}
- ::replace_all(txt_bg, "'", "’");
+ ::replace_all(txt_bg, "'", "\u2019");
::replace_all(txt_bg, "\"", "\\\"");
- ::replace_all(txt_fg, "'", "’");
+ ::replace_all(txt_fg, "'", "\u2019");
::replace_all(txt_fg, "\"", "\\\"");
}
return 1;
}
- if (params.seed < 0) {
- params.seed = time(NULL);
- }
-
if (params.fname_inp.empty()) {
fprintf(stderr, "error: no input files specified\n");
whisper_print_usage(argc, argv, params);
return 2;
}
+ if (whisper_lang_id(params.language.c_str()) == -1) {
+ fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
+ whisper_print_usage(argc, argv, params);
+ exit(0);
+ }
+
// whisper init
struct whisper_context * ctx = whisper_init(params.model.c_str());
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
const auto fname_inp = params.fname_inp[f];
+ std::vector<float> pcmf32; // mono-channel F32 PCM
+ std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
+
// WAV input
- std::vector<float> pcmf32;
{
drwav wav;
-
+ std::vector<uint8_t> wav_data; // used for pipe input from stdin
+
if (fname_inp == "-") {
- std::vector<uint8_t> wav_data;
{
uint8_t buf[1024];
while (true)
{
const size_t n = fread(buf, 1, sizeof(buf), stdin);
- if (n == 0)
- {
+ if (n == 0) {
break;
}
wav_data.insert(wav_data.end(), buf, buf + n);
}
}
- if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), NULL) == false)
- {
+ if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), NULL) == false) {
fprintf(stderr, "error: failed to open WAV file from stdin\n");
return 4;
}
+
+ fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
}
else if (drwav_init_file(&wav, fname_inp.c_str(), NULL) == false) {
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
- return 4;
+ return 5;
}
if (wav.channels != 1 && wav.channels != 2) {
fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], fname_inp.c_str());
- return 5;
+ return 6;
+ }
+
+ if (params.diarize && wav.channels != 2 && params.no_timestamps == false) {
+ fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization, and timestamps must be enabled\n", argv[0], fname_inp.c_str());
+ return 6;
}
if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], fname_inp.c_str());
- return 6;
+ return 8;
}
if (wav.bitsPerSample != 16) {
fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], fname_inp.c_str());
- return 7;
+ return 9;
}
- int n = wav.totalPCMFrameCount;
+ const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
std::vector<int16_t> pcm16;
pcm16.resize(n*wav.channels);
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
}
}
+
+ if (params.diarize) {
+ // convert to stereo, float
+ pcmf32s.resize(2);
+
+ pcmf32s[0].resize(n);
+ pcmf32s[1].resize(n);
+ for (uint64_t i = 0; i < n; i++) {
+ pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
+ pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
+ }
+ }
}
// print system information
{
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
- wparams.print_realtime = false;
- wparams.print_progress = false;
- wparams.print_timestamps = !params.no_timestamps;
- wparams.print_special_tokens = params.print_special_tokens;
- wparams.translate = params.translate;
- wparams.language = params.language.c_str();
- wparams.n_threads = params.n_threads;
- wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
- wparams.offset_ms = params.offset_t_ms;
- wparams.duration_ms = params.duration_ms;
-
- wparams.token_timestamps = params.output_wts || params.max_len > 0;
- wparams.thold_pt = params.word_thold;
- wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
+ wparams.print_realtime = false;
+ wparams.print_progress = false;
+ wparams.print_timestamps = !params.no_timestamps;
+ wparams.print_special = params.print_special;
+ wparams.translate = params.translate;
+ wparams.language = params.language.c_str();
+ wparams.n_threads = params.n_threads;
+ wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
+ wparams.offset_ms = params.offset_t_ms;
+ wparams.duration_ms = params.duration_ms;
+
+ wparams.token_timestamps = params.output_wts || params.max_len > 0;
+ wparams.thold_pt = params.word_thold;
+ wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
+
+ wparams.speed_up = params.speed_up;
+
+ whisper_print_user_data user_data = { &params, &pcmf32s };
// this callback is called on each new segment
if (!wparams.print_realtime) {
wparams.new_segment_callback = whisper_print_segment_callback;
- wparams.new_segment_callback_user_data = &params;
+ wparams.new_segment_callback_user_data = &user_data;
+ }
+
+ // example of an abort mechanism
+ // in this example, we do not abort the processing, but we could if the flag is set to true
+ // the callback is called before every encoder run - if it returns false, the processing is aborted
+ {
+ static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
+
+ wparams.encoder_begin_callback = [](struct whisper_context * ctx, void * user_data) {
+ bool is_aborted = *(bool*)user_data;
+ return !is_aborted;
+ };
+ wparams.encoder_begin_callback_user_data = &is_aborted;
}
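+
+ // hypothetical variant with a real atomic flag (sketch only, not used here) - another
+ // thread could request a stop by storing true into the flag (requires <atomic>):
+ //
+ // static std::atomic<bool> g_abort(false);
+ //
+ // wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * ud) {
+ // return !((std::atomic<bool> *) ud)->load();
+ // };
+ // wparams.encoder_begin_callback_user_data = &g_abort;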
if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
- return 8;
+ return 10;
}
}
int64_t t_last;
whisper_token tid_last;
std::vector<float> energy; // PCM signal energy
+
+ // [EXPERIMENTAL] speed-up techniques
+ int32_t exp_n_audio_ctx; // 0 - use default
};
// load the model from a ggml file
wctx.buf_memory.resize(MEM_REQ_MEMORY.at(model.type));
wctx.buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
wctx.buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
-
- // this is the total memory required to run the inference
- const size_t mem_required =
- wctx.buf_model->size() +
- wctx.buf_memory.size() +
- wctx.buf_compute.size() +
- wctx.buf_compute_layer.size();
-
- fprintf(stderr, "%s: mem_required = %.2f MB\n", __func__, mem_required / 1024.0 / 1024.0);
}
// load mel filters
}
}
+ {
+ // this is the total memory required to run the inference
+ const size_t mem_required =
+ wctx.buf_model->size() +
+ wctx.buf_memory.size() +
+ wctx.buf_compute.size() +
+ wctx.buf_compute_layer.size();
+
+ fprintf(stderr, "%s: mem_required = %7.2f MB\n", __func__, mem_required / 1024.0 / 1024.0);
+ }
+
// for the big tensors, we have the option to store the data in 16-bit floats
// in order to save memory and also to speed up the computation
const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
-
size_t ctx_size = 0;
size_t ctx_mem_size = 0;
const int n_audio_state = hparams.n_audio_state;
const int n_audio_layer = hparams.n_audio_layer;
- const int n_text_ctx = hparams.n_text_ctx;
+ const int n_text_ctx = hparams.n_text_ctx;
const int n_text_state = hparams.n_text_state;
const int n_text_layer = hparams.n_text_layer;
ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead
- fprintf(stderr, "%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
+ fprintf(stderr, "%s: ggml ctx size = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
}
// create the ggml context
const int n_audio_state = hparams.n_audio_state;
const int n_audio_layer = hparams.n_audio_layer;
- const int n_text_ctx = hparams.n_text_ctx;
+ const int n_text_ctx = hparams.n_text_ctx;
const int n_text_state = hparams.n_text_state;
const int n_text_layer = hparams.n_text_layer;
// key/value memory for the cross-attention layer
{
- const int n_audio_ctx = hparams.n_audio_ctx;
+ const int n_audio_ctx = hparams.n_audio_ctx;
const int n_mem = n_text_layer*n_audio_ctx;
const int n_elements = n_text_state*n_mem;
ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v) +
ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v);
- fprintf(stderr, "%s: memory size = %8.2f MB\n", __func__, memory_size/1024.0/1024.0);
+ fprintf(stderr, "%s: memory size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
}
// load weights
fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
- //printf("%24s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
+ //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
total_size += ggml_nbytes(tensor);
model.n_loaded++;
}
- fprintf(stderr, "%s: model size = %8.2f MB\n", __func__, total_size/1024.0/1024.0);
+ fprintf(stderr, "%s: model size = %7.2f MB\n", __func__, total_size/1024.0/1024.0);
if (model.n_loaded == 0) {
fprintf(stderr, "%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
const auto & mel_inp = wctx.mel;
const auto & hparams = model.hparams;
- const int n_ctx = hparams.n_audio_ctx;
+ const int n_ctx = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
const int n_state = hparams.n_audio_state;
const int n_head = hparams.n_audio_head;
const int n_layer = hparams.n_audio_layer;
- const int N = n_ctx;
-
const int n_mels = hparams.n_mels;
assert(mel_inp.n_mel == n_mels);
cur = ggml_gelu(ctx0, cur);
}
- cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur));
+ // ===================================================================
+ // NOTE: experimenting with partial evaluation of the encoder (ignore)
+ //static int iter = -1;
+ //const int n_iter = 1500/n_ctx;
+
+ //iter = (iter + 1) % n_iter;
+
+ //if (iter == 0) {
+ // memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k));
+ // memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v));
+ //}
+
+ static int iter = 0;
+
+ const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe);
+ const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;
+
+ struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
+
+ cur = ggml_add(ctx0, e_pe, ggml_transpose(ctx0, cur));
+ // ===================================================================
+
+ // original:
+ //cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur));
struct ggml_tensor * inpL = cur;
ggml_permute(ctxL,
ggml_cpy(ctxL,
Qcur,
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
0, 2, 1, 3);
struct ggml_tensor * K =
ggml_permute(ctxL,
ggml_cpy(ctxL,
Kcur,
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
0, 2, 1, 3);
struct ggml_tensor * V =
ggml_permute(ctxL,
ggml_reshape_3d(ctxL,
Vcur,
- n_state/n_head, n_head, N),
+ n_state/n_head, n_head, n_ctx),
1, 2, 0, 3),
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, N, n_state/n_head, n_head)
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_ctx, n_state/n_head, n_head)
);
struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
ggml_permute(ctxL,
ggml_cpy(ctxL,
Qcur,
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)),
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
0, 2, 1, 3);
struct ggml_tensor * K =
ggml_permute(ctxL,
ggml_cpy(ctxL,
Kcur,
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
0, 2, 1, 3);
// K * Q
// ggml_permute(ctxL,
// ggml_cpy(ctxL,
// Vcur,
- // ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
+ // ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
// 1, 2, 0, 3);
//struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
ggml_permute(ctxL,
ggml_reshape_3d(ctxL,
Vcur,
- n_state/n_head, n_head, N),
+ n_state/n_head, n_head, n_ctx),
0, 2, 1, 3),
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, N, n_head)
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_ctx, n_head)
);
struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max);
cur = ggml_cpy(ctxL,
KQV_merged,
- ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N));
+ ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, n_ctx));
}
// projection
Vcross),
Vcross);
+ //struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
+ //struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*n_ctx));
struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*n_ctx));
const int n_layer = hparams.n_text_layer;
const int N = n_tokens;
- const int M = hparams.n_audio_ctx;
+ const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
struct ggml_init_params params = {
.mem_size = wctx.buf_compute.size(),
// the most basic sampling scheme - select the top token
static whisper_token_data whisper_sample_best(
const whisper_vocab & vocab,
- const float * probs) {
+ const float * probs,
+ bool force_timestamp,
+ bool is_initial) {
whisper_token_data result = {
0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
};
max_tx = std::max(max_tx, probs_id[i].first);
}
- for (int i = vocab.token_beg; i < n_logits; i++) {
+ const auto i0 = is_initial ? vocab.token_beg + 101 : vocab.token_beg;
+ const auto i1 = is_initial ? vocab.token_beg + 101 : n_logits;
+
+ // the initial timestamp cannot be larger than 100
+ // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429
+ if (is_initial) {
+ for (int i = i0; i < n_logits; ++i) {
+ probs_id[i].first = -INFINITY;
+ }
+ }
+
+ for (int i = vocab.token_beg; i < i1; i++) {
sum_ts += probs_id[i].first;
if (probs_id[i].first > max_ts) {
max_ts = probs_id[i].first;
// if the probability sum of all timestamp tokens is higher than the max probability of the text tokens - sample a
// timestamp token
- if (sum_ts > max_tx) {
+ if (sum_ts > max_tx || force_timestamp) {
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438
for (int i = 0; i < vocab.token_beg; i++) {
probs_id[i].first = -INFINITY;
return result;
}
-// samples only from the timestamps tokens
-static whisper_vocab::id whisper_sample_timestamp(
- const whisper_vocab & vocab,
- const float * probs) {
- int n_logits = vocab.id_to_token.size();
-
- std::vector<std::pair<double, whisper_vocab::id>> probs_id;
- probs_id.reserve(n_logits);
-
- for (int i = vocab.token_beg + 1; i < n_logits; i++) {
- probs_id.push_back(std::make_pair(probs[i], i));
- }
-
- const int top_k = 10;
-
- // find the top K tokens
- std::partial_sort(
- probs_id.begin(),
- probs_id.begin() + top_k, probs_id.end(),
- [](const std::pair<double, whisper_vocab::id> & a, const std::pair<double, whisper_vocab::id> & b) {
- return a.first > b.first;
- });
-
- probs_id.resize(top_k);
-
- //printf("\n");
- //for (int i = 0; i < (int) probs_id.size(); i++) {
- // printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
- //}
-
- return probs_id[0].second;
-}
-
// 500 -> 00:05.000
// 6000 -> 01:00.000
static std::string to_timestamp(int64_t t, bool comma = false) {
const int n_mel,
const int n_threads,
const whisper_filters & filters,
+ const bool speed_up,
whisper_mel & mel) {
// Hanning window
mel.n_len = (n_samples)/fft_step;
mel.data.resize(mel.n_mel*mel.n_len);
- const int n_fft = 1 + fft_size/2;
+ const int n_fft = 1 + (speed_up ? fft_size/4 : fft_size/2);
//printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len);
//printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate);
//}
}
+ if (speed_up) {
+ // scaling down in the frequency domain results in a speed-up in the time domain
+ for (int j = 0; j < n_fft; j++) {
+ fft_out[j] = 0.5*(fft_out[2*j] + fft_out[2*j + 1]);
+ }
+ }
+
// mel spectrogram
for (int j = 0; j < mel.n_mel; j++) {
double sum = 0.0;
void whisper_free(struct whisper_context * ctx) {
if (ctx) {
+ if (ctx->model.ctx) {
+ ggml_free(ctx->model.ctx);
+ }
+ if (ctx->model.ctx_mem) {
+ ggml_free(ctx->model.ctx_mem);
+ }
if (ctx->buf_model) {
delete ctx->buf_model;
}
int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
const int64_t t_start_us = ggml_time_us();
- if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, ctx->mel)) {
+ if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, ctx->mel)) {
+ fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
+ return -1;
+ }
+
+ ctx->t_mel_us = ggml_time_us() - t_start_us;
+
+ return 0;
+}
+
+// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
+int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
+ const int64_t t_start_us = ggml_time_us();
+
+ if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, 2*WHISPER_N_FFT, 2*WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, ctx->mel)) {
fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
return -1;
}
struct whisper_token_data whisper_sample_best(struct whisper_context * ctx) {
const int64_t t_start_sample_us = ggml_time_us();
- // TODO: simplify
- auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab));
+ const auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), false, false);
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
return res;
}
-whisper_token whisper_sample_timestamp(struct whisper_context * ctx) {
+struct whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial) {
const int64_t t_start_sample_us = ggml_time_us();
- // TODO: simplify
- auto res = whisper_sample_timestamp(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab));
+ const auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), true, is_initial);
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
- return ctx->vocab.token_beg;
+ return res;
}
-whisper_token whisper_token_translate() {
+whisper_token whisper_token_translate(void) {
return whisper_vocab::token_translate;
}
-whisper_token whisper_token_transcribe() {
+whisper_token whisper_token_transcribe(void) {
return whisper_vocab::token_transcribe;
}
fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
}
+void whisper_reset_timings(struct whisper_context * ctx) {
+ ctx->t_sample_us = 0;
+ ctx->t_encode_us = 0;
+ ctx->t_decode_us = 0;
+}
+
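+// NOTE: returns a pointer to function-local static storage - valid until the next call, do not free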
+const char * whisper_print_system_info(void) {
+ static std::string s;
+
+ s = "";
+ s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | ";
+ s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | ";
+ s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | ";
+ s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
+ s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
+ s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
+ s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
+
+ return s.c_str();
+}
+
////////////////////////////////////////////////////////////////////////////
struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
case WHISPER_SAMPLING_GREEDY:
{
result = {
- /*.strategy =*/ WHISPER_SAMPLING_GREEDY,
+ /*.strategy =*/ WHISPER_SAMPLING_GREEDY,
+
+ /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
+ /*.n_max_text_ctx =*/ 16384,
+ /*.offset_ms =*/ 0,
+ /*.duration_ms =*/ 0,
- /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
- /*.n_max_text_ctx =*/ 16384,
- /*.offset_ms =*/ 0,
- /*.duration_ms =*/ 0,
+ /*.translate =*/ false,
+ /*.no_context =*/ false,
+ /*.single_segment =*/ false,
+ /*.print_special =*/ false,
+ /*.print_progress =*/ true,
+ /*.print_realtime =*/ false,
+ /*.print_timestamps =*/ true,
- /*.translate =*/ false,
- /*.no_context =*/ false,
- /*.print_special_tokens =*/ false,
- /*.print_progress =*/ true,
- /*.print_realtime =*/ false,
- /*.print_timestamps =*/ true,
+ /*.token_timestamps =*/ false,
+ /*.thold_pt =*/ 0.01f,
+ /*.thold_ptsum =*/ 0.01f,
+ /*.max_len =*/ 0,
+ /*.max_tokens =*/ 0,
- /*.token_timestamps =*/ false,
- /*.thold_pt =*/ 0.01f,
- /*.thold_ptsum =*/ 0.01f,
- /*.max_len =*/ 0,
+ /*.speed_up =*/ false,
+ /*.audio_ctx =*/ 0,
- /*.language =*/ "en",
+ /*.prompt_tokens =*/ nullptr,
+ /*.prompt_n_tokens =*/ 0,
- /*.greedy =*/ {
+ /*.language =*/ "en",
+
+ /*.greedy =*/ {
/*.n_past =*/ 0,
},
- /*.beam_search =*/ {
+ /*.beam_search =*/ {
/*.n_past =*/ -1,
/*.beam_width =*/ -1,
/*.n_best =*/ -1,
},
- /*.new_segment_callback =*/ nullptr,
+ /*.new_segment_callback =*/ nullptr,
/*.new_segment_callback_user_data =*/ nullptr,
+
+ /*.encoder_begin_callback =*/ nullptr,
+ /*.encoder_begin_callback_user_data =*/ nullptr,
};
} break;
case WHISPER_SAMPLING_BEAM_SEARCH:
{
result = {
- /*.strategy =*/ WHISPER_SAMPLING_BEAM_SEARCH,
+ /*.strategy =*/ WHISPER_SAMPLING_BEAM_SEARCH,
- /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
- /*.n_max_text_ctx =*/ 16384,
- /*.offset_ms =*/ 0,
- /*.duration_ms =*/ 0,
+ /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
+ /*.n_max_text_ctx =*/ 16384,
+ /*.offset_ms =*/ 0,
+ /*.duration_ms =*/ 0,
- /*.translate =*/ false,
- /*.no_context =*/ false,
- /*.print_special_tokens =*/ false,
- /*.print_progress =*/ true,
- /*.print_realtime =*/ false,
- /*.print_timestamps =*/ true,
+ /*.translate =*/ false,
+ /*.no_context =*/ false,
+ /*.single_segment =*/ false,
+ /*.print_special =*/ false,
+ /*.print_progress =*/ true,
+ /*.print_realtime =*/ false,
+ /*.print_timestamps =*/ true,
- /*.token_timestamps =*/ false,
- /*.thold_pt =*/ 0.01f,
- /*.thold_ptsum =*/ 0.01f,
- /*.max_len =*/ 0,
+ /*.token_timestamps =*/ false,
+ /*.thold_pt =*/ 0.01f,
+ /*.thold_ptsum =*/ 0.01f,
+ /*.max_len =*/ 0,
+ /*.max_tokens =*/ 0,
- /*.language =*/ "en",
+ /*.speed_up =*/ false,
+ /*.audio_ctx =*/ 0,
- /*.greedy =*/ {
+ /*.prompt_tokens =*/ nullptr,
+ /*.prompt_n_tokens =*/ 0,
+
+ /*.language =*/ "en",
+
+ /*.greedy =*/ {
/*.n_past =*/ -1,
},
- /*.beam_search =*/ {
+ /*.beam_search =*/ {
/*.n_past =*/ 0,
/*.beam_width =*/ 10,
/*.n_best =*/ 5,
},
- /*.new_segment_callback =*/ nullptr,
+ /*.new_segment_callback =*/ nullptr,
/*.new_segment_callback_user_data =*/ nullptr,
+
+ /*.encoder_begin_callback =*/ nullptr,
+ /*.encoder_begin_callback_user_data =*/ nullptr,
};
} break;
}
result_all.clear();
// compute log mel spectrogram
- if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
- fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
- return -1;
+ if (params.speed_up) {
+ if (whisper_pcm_to_mel_phase_vocoder(ctx, samples, n_samples, params.n_threads) != 0) {
+ fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
+ return -1;
+ }
+ } else {
+ if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
+ fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
+ return -1;
+ }
}
if (params.token_timestamps) {
prompt_past.clear();
}
+ // prepend the prompt tokens to the prompt_past
+ if (params.prompt_tokens && params.prompt_n_tokens > 0) {
+ // copy the tokens from the provided pointer
+ for (int i = 0; i < params.prompt_n_tokens; i++) {
+ prompt_past.push_back(params.prompt_tokens[i]);
+ }
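+ // move the freshly appended prompt tokens to the front of prompt_past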
+ std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
+ }
+
+ // overwrite audio_ctx
+ ctx->exp_n_audio_ctx = params.audio_ctx;
+
// these tokens determine the task that will be performed
std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
if (whisper_is_multilingual(ctx)) {
break;
}
+ if (params.encoder_begin_callback) {
+ if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) {
+ fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__);
+ break;
+ }
+ }
+
// encode audio features starting at offset seek
if (whisper_encode(ctx, seek, params.n_threads) != 0) {
fprintf(stderr, "%s: failed to encode\n", __func__);
prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
- bool done = false;
int seek_delta = 100*WHISPER_CHUNK_SIZE;
// print the prompt
int result_len = 0;
tokens_cur.clear();
- for (int i = 0; i < whisper_n_text_ctx(ctx)/2 - 4; ++i) {
+ bool failed = false;
+
+ for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
fprintf(stderr, "%s: failed to decode\n", __func__);
return 8;
// feel free to experiment!
//
{
- auto token = whisper_sample_best(ctx);
-
- if (i == 0) {
- token.tid = whisper_token_beg(ctx);
- }
+ const auto token = (i == 0) ? whisper_sample_timestamp(ctx, true) : whisper_sample_best(ctx);
// timestamp token - update sliding window
if (token.id > whisper_token_beg(ctx)) {
- seek_delta = 2*(token.id - whisper_token_beg(ctx));
+ const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx));
+
+ // do not allow to go back in time
+ if (seek_delta != 100*WHISPER_CHUNK_SIZE &&
+ seek_delta > seek_delta_new && result_len < i) {
+ break;
+ }
+
+ seek_delta = seek_delta_new;
result_len = i + 1;
}
//{
// const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]";
- // printf("%s: %10s %6.3f '%s'\n", __func__, tt.c_str(), token.pt, ctx->vocab.id_to_token[token.id].c_str());
+ // printf("%s: %10s %6d %6.3f '%s'\n", __func__, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str());
//}
// end of text token
- if (token.id == whisper_token_eot(ctx)) {
+ if (token.id == whisper_token_eot(ctx) || (params.max_tokens > 0 && i > params.max_tokens)) {
if (result_len == 0) {
if (seek + seek_delta + 100 >= seek_end) {
result_len = i + 1;
} else {
- // TODO: figure out how to resolve this
- fprintf(stderr, "\n%s: failed to generate timestamp token - this should not happen\n\n", __func__);
+ failed = true;
+ break;
}
}
+
+ if (params.single_segment) {
+ result_len = i + 1;
+ seek_delta = 100*WHISPER_CHUNK_SIZE;
+ }
+
break;
}
}
}
- if (done) {
+ // sometimes, the decoding can get stuck in a repetition loop
+ // this is a simple strategy to avoid such cases - flag the decoding as failed and advance
+ // the sliding window by 1 second
+ if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
+ failed = true;
break;
}
}
+ if (failed) {
+ fprintf(stderr, "\n%s: failed to generate timestamp token - using fallback strategy\n\n", __func__);
+ seek += 100;
+ continue;
+ }
+
// shrink down to result_len
tokens_cur.resize(result_len);
// ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p,
// ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt);
- if (params.print_special_tokens == false && tokens_cur[i].id >= whisper_token_eot(ctx)) {
+ if (params.print_special == false && tokens_cur[i].id >= whisper_token_eot(ctx)) {
} else {
text += whisper_token_to_str(ctx, tokens_cur[i].id);
}
- if (tokens_cur[i].id > whisper_token_beg(ctx)) {
+ if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) {
const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
if (!text.empty()) {
+ const auto tt0 = params.speed_up ? 2*t0 : t0;
+ const auto tt1 = params.speed_up ? 2*t1 : t1;
+
if (params.print_realtime) {
if (params.print_timestamps) {
- printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text.c_str());
+ printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
} else {
printf("%s", text.c_str());
fflush(stdout);
}
}
- result_all.push_back({ t0, t1, text, {} });
+ result_all.push_back({ tt0, tt1, text, {} });
for (int j = i0; j <= i; j++) {
result_all.back().tokens.push_back(tokens_cur[j]);
}
if (!text.empty()) {
const auto t1 = seek + seek_delta;
+ const auto tt0 = params.speed_up ? 2*t0 : t0;
+ const auto tt1 = params.speed_up ? 2*t1 : t1;
+
if (params.print_realtime) {
if (params.print_timestamps) {
- printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text.c_str());
+ printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
} else {
printf("%s", text.c_str());
fflush(stdout);
}
}
- result_all.push_back({ t0, t1, text, {} });
+ result_all.push_back({ tt0, tt1, text, {} });
for (int j = i0; j < (int) tokens_cur.size(); j++) {
result_all.back().tokens.push_back(tokens_cur[j]);
}
struct whisper_full_params params,
const float * samples,
int n_samples,
- const int n_processors) {
+ int n_processors) {
if (n_processors == 1) {
return whisper_full(ctx, params, samples, n_samples);
}
// key/value memory for the cross-attention layer
{
- const int n_audio_ctx = hparams.n_audio_ctx;
+ const int n_audio_ctx = hparams.n_audio_ctx;
const int n_mem = n_text_layer*n_audio_ctx;
const int n_elements = n_text_state*n_mem;
model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
}
-
- const size_t memory_size =
- ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v) +
- ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v);
}
}
return ctx->result_all[i_segment].tokens[i_token].p;
}
-const char * whisper_print_system_info() {
- static std::string s;
-
- s = "";
- s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | ";
- s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | ";
- s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
- s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
- s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
- s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
-
- return s.c_str();
-}
-
// =================================================================================================
//
const int64_t t0 = segment.t0;
const int64_t t1 = segment.t1;
- const int s0 = timestamp_to_sample(t0, n_samples);
- const int s1 = timestamp_to_sample(t1, n_samples);
-
const int n = tokens.size();
if (n == 0) {
whisper_token id; // token id
whisper_token tid; // forced timestamp token id
- float p; // probability of the token
- float pt; // probability of the timestamp token
- float ptsum; // sum of probabilities of all timestamp tokens
+ float p; // probability of the token
+ float pt; // probability of the timestamp token
+ float ptsum; // sum of probabilities of all timestamp tokens
// token-level timestamp data
// do not use if you haven't computed token-level timestamps
- int64_t t0; // start time of the token
- int64_t t1; // end time of the token
+ int64_t t0; // start time of the token
+ int64_t t1; // end time of the token
- float vlen; // voice length of the token
+ float vlen; // voice length of the token
} whisper_token_data;
// Allocates all memory needed for the model and loads the model from the given file.
// Returns 0 on success
WHISPER_API int whisper_pcm_to_mel(
struct whisper_context * ctx,
- const float * samples,
- int n_samples,
- int n_threads);
+ const float * samples,
+ int n_samples,
+ int n_threads);
// This can be used to set a custom log mel spectrogram inside the provided whisper context.
// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
// Returns 0 on success
WHISPER_API int whisper_set_mel(
struct whisper_context * ctx,
- const float * data,
- int n_len,
- int n_mel);
+ const float * data,
+ int n_len,
+ int n_mel);
// Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context.
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
// Returns 0 on success
WHISPER_API int whisper_encode(
struct whisper_context * ctx,
- int offset,
- int n_threads);
+ int offset,
+ int n_threads);
// Run the Whisper decoder to obtain the logits and probabilities for the next token.
// Make sure to call whisper_encode() first.
// Returns 0 on success
WHISPER_API int whisper_decode(
struct whisper_context * ctx,
- const whisper_token * tokens,
- int n_tokens,
- int n_past,
- int n_threads);
+ const whisper_token * tokens,
+ int n_tokens,
+ int n_past,
+ int n_threads);
// Token sampling methods.
// These are provided for convenience and can be used after each call to whisper_decode().
// whisper_sample_best() returns the token with the highest probability
// whisper_sample_timestamp() returns the most probable timestamp token
WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx);
- WHISPER_API whisper_token whisper_sample_timestamp(struct whisper_context * ctx);
+ WHISPER_API whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial);
// Return the id of the specified language, returns -1 if not found
WHISPER_API int whisper_lang_id(const char * lang);
WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx);
// Task tokens
- WHISPER_API whisper_token whisper_token_translate ();
- WHISPER_API whisper_token whisper_token_transcribe();
+ WHISPER_API whisper_token whisper_token_translate (void);
+ WHISPER_API whisper_token whisper_token_transcribe(void);
// Performance information
WHISPER_API void whisper_print_timings(struct whisper_context * ctx);
+ WHISPER_API void whisper_reset_timings(struct whisper_context * ctx);
+
+ // Print system information
+ WHISPER_API const char * whisper_print_system_info(void);
////////////////////////////////////////////////////////////////////////////
// Use the whisper_full_...() functions to obtain the text segments
typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, int n_new, void * user_data);
+ // Encoder begin callback
+ // If not NULL, called before the encoder starts
+ // If it returns false, the computation is aborted
+ typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data);
+
+ // Parameters for the whisper_full() function
+ // If you change the order or add new parameters, make sure to update the default values in whisper.cpp:
+ // whisper_full_default_params()
struct whisper_full_params {
enum whisper_sampling_strategy strategy;
int n_threads;
int n_max_text_ctx;
- int offset_ms; // start offset in ms
- int duration_ms; // audio duration to process in ms
+ int offset_ms; // start offset in ms
+ int duration_ms; // audio duration to process in ms
bool translate;
bool no_context;
- bool print_special_tokens;
+ bool single_segment; // force single segment output (useful for streaming)
+ bool print_special;
bool print_progress;
bool print_realtime;
bool print_timestamps;
float thold_pt; // timestamp token probability threshold (~0.01)
float thold_ptsum; // timestamp token sum probability threshold (~0.01)
int max_len; // max segment length in characters
+ int max_tokens; // max tokens per segment (0 = no limit)
+
+ // [EXPERIMENTAL] speed-up techniques
+ bool speed_up; // speed-up the audio by 2x using Phase Vocoder
+ int audio_ctx; // overwrite the audio context size (0 = use default)
+
+ // tokens to provide the whisper model as initial prompt
+ // these are prepended to any existing text context from a previous call
+ const whisper_token * prompt_tokens;
+ int prompt_n_tokens;
const char * language;
whisper_new_segment_callback new_segment_callback;
void * new_segment_callback_user_data;
+
+ whisper_encoder_begin_callback encoder_begin_callback;
+ void * encoder_begin_callback_user_data;
};
WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
// Uses the specified decoding strategy to obtain the text.
WHISPER_API int whisper_full(
- struct whisper_context * ctx,
- struct whisper_full_params params,
- const float * samples,
- int n_samples);
+ struct whisper_context * ctx,
+ struct whisper_full_params params,
+ const float * samples,
+ int n_samples);
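+
+ // Typical usage (a minimal sketch - error checks omitted; pcmf32 is assumed to
+ // hold 16 kHz mono float PCM samples):
+ //
+ // struct whisper_context * ctx = whisper_init("models/ggml-base.en.bin");
+ // struct whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
+ // whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size());
+ // whisper_print_timings(ctx);
+ // whisper_free(ctx);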
// Split the input audio in chunks and process each chunk separately using whisper_full()
// It seems this approach can offer some speedup in some cases.
// However, the transcription accuracy can be worse at the beginning and end of each chunk.
WHISPER_API int whisper_full_parallel(
- struct whisper_context * ctx,
- struct whisper_full_params params,
- const float * samples,
- int n_samples,
- const int n_processors);
+ struct whisper_context * ctx,
+ struct whisper_full_params params,
+ const float * samples,
+ int n_samples,
+ int n_processors);
// Number of generated text segments.
// A segment can be a few words, a sentence, or even a paragraph.
// Get the probability of the specified token in the specified segment.
WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);
- // Print system information
- WHISPER_API const char * whisper_print_system_info();
-
#ifdef __cplusplus
}
#endif
// system info
//
+int ggml_cpu_has_avx(void);
int ggml_cpu_has_avx2(void);
int ggml_cpu_has_avx512(void);
int ggml_cpu_has_neon(void);
#include <stdio.h>
#if defined _MSC_VER || defined(__MINGW32__)
+
+#if !defined(__MINGW32__)
#include <Windows.h>
+#else
+// ref: https://github.com/ggerganov/whisper.cpp/issues/168
+#include <windows.h>
+#include <errno.h>
+#endif
typedef volatile LONG atomic_int;
typedef atomic_int atomic_bool;
typedef DWORD thread_ret_t;
static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) {
- HANDLE handle = CreateThread(NULL, 0, func, arg, 0, NULL);
+ HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL);
if (handle == NULL)
{
return EAGAIN;
sumf = _mm_cvtss_f32(r1);
+ // leftovers
+ for (int i = n32; i < n; ++i) {
+ sumf += x[i]*y[i];
+ }
+#elif defined(__AVX__)
+ // AVX 256-bit
+ const int n32 = (n & ~31);
+
+ __m256 sum0 = _mm256_setzero_ps();
+ __m256 sum1 = _mm256_setzero_ps();
+ __m256 sum2 = _mm256_setzero_ps();
+ __m256 sum3 = _mm256_setzero_ps();
+
+ __m256 x0, x1, x2, x3;
+ __m256 y0, y1, y2, y3;
+
+ for (int i = 0; i < n32; i += 32) {
+ x0 = _mm256_loadu_ps(x + i + 0);
+ x1 = _mm256_loadu_ps(x + i + 8);
+ x2 = _mm256_loadu_ps(x + i + 16);
+ x3 = _mm256_loadu_ps(x + i + 24);
+
+ y0 = _mm256_loadu_ps(y + i + 0);
+ y1 = _mm256_loadu_ps(y + i + 8);
+ y2 = _mm256_loadu_ps(y + i + 16);
+ y3 = _mm256_loadu_ps(y + i + 24);
+
+ sum0 = _mm256_add_ps(_mm256_mul_ps(x0, y0), sum0);
+ sum1 = _mm256_add_ps(_mm256_mul_ps(x1, y1), sum1);
+ sum2 = _mm256_add_ps(_mm256_mul_ps(x2, y2), sum2);
+ sum3 = _mm256_add_ps(_mm256_mul_ps(x3, y3), sum3);
+ }
+
+ sum0 = _mm256_add_ps(sum0, sum1);
+ sum2 = _mm256_add_ps(sum2, sum3);
+ sum0 = _mm256_add_ps(sum0, sum2);
+
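+ // horizontal reduction: fold the 8 lanes of the combined sum down to a single scalar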
+ const __m128 r4 = _mm_add_ps(_mm256_castps256_ps128(sum0), _mm256_extractf128_ps(sum0, 1));
+ const __m128 r2 = _mm_add_ps(r4, _mm_movehl_ps(r4, r4));
+ const __m128 r1 = _mm_add_ss(r2, _mm_movehdup_ps(r2));
+
+ sumf = _mm_cvtss_f32(r1);
+
// leftovers
for (int i = n32; i < n; ++i) {
sumf += x[i]*y[i];
sumf = _mm_cvtss_f32(r1);
+ // leftovers
+ for (int i = n32; i < n; ++i) {
+ //GGML_ASSERT(false);
+ sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]);
+ }
+#elif defined(__AVX__)
+ // AVX 256-bit
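+ // NOTE: the fp16 conversions below (_mm256_cvtph_ps/_mm256_cvtps_ph) additionally require F16C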
+ const int n32 = (n & ~31);
+
+ __m256 sum0 = _mm256_setzero_ps();
+ __m256 sum1 = _mm256_setzero_ps();
+ __m256 sum2 = _mm256_setzero_ps();
+ __m256 sum3 = _mm256_setzero_ps();
+
+ __m256 x0, x1, x2, x3;
+ __m256 y0, y1, y2, y3;
+
+ for (int i = 0; i < n32; i += 32) {
+ x0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 0 )));
+ x1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 8 )));
+ x2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16)));
+ x3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24)));
+
+ y0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 0 )));
+ y1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 8 )));
+ y2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 16)));
+ y3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 24)));
+
+ sum0 = _mm256_add_ps(_mm256_mul_ps(x0, y0), sum0);
+ sum1 = _mm256_add_ps(_mm256_mul_ps(x1, y1), sum1);
+ sum2 = _mm256_add_ps(_mm256_mul_ps(x2, y2), sum2);
+ sum3 = _mm256_add_ps(_mm256_mul_ps(x3, y3), sum3);
+ }
+
+ const __m256 sum01 = _mm256_add_ps(sum0, sum1);
+ const __m256 sum23 = _mm256_add_ps(sum2, sum3);
+ const __m256 sum0123 = _mm256_add_ps(sum01, sum23);
+
+ const __m128 r4 = _mm_add_ps(_mm256_castps256_ps128(sum0123), _mm256_extractf128_ps(sum0123, 1));
+ const __m128 r2 = _mm_add_ps(r4, _mm_movehl_ps(r4, r4));
+ const __m128 r1 = _mm_add_ss(r2, _mm_movehdup_ps(r2));
+
+ sumf = _mm_cvtss_f32(r1);
+
// leftovers
for (int i = n32; i < n; ++i) {
//GGML_ASSERT(false);
_mm256_storeu_ps(y + i + 24, y3);
}
+ // leftovers
+ for (int i = n32; i < n; ++i) {
+ y[i] += x[i]*v;
+ }
+#elif defined(__AVX__)
+ // AVX 256-bit
+ const int n32 = (n & ~31);
+
+ const __m256 v4 = _mm256_set1_ps(v);
+
+ __m256 x0, x1, x2, x3;
+ __m256 y0, y1, y2, y3;
+
+ for (int i = 0; i < n32; i += 32) {
+ x0 = _mm256_loadu_ps(x + i + 0);
+ x1 = _mm256_loadu_ps(x + i + 8);
+ x2 = _mm256_loadu_ps(x + i + 16);
+ x3 = _mm256_loadu_ps(x + i + 24);
+
+ y0 = _mm256_loadu_ps(y + i + 0);
+ y1 = _mm256_loadu_ps(y + i + 8);
+ y2 = _mm256_loadu_ps(y + i + 16);
+ y3 = _mm256_loadu_ps(y + i + 24);
+
+ y0 = _mm256_add_ps(_mm256_mul_ps(x0, v4), y0);
+ y1 = _mm256_add_ps(_mm256_mul_ps(x1, v4), y1);
+ y2 = _mm256_add_ps(_mm256_mul_ps(x2, v4), y2);
+ y3 = _mm256_add_ps(_mm256_mul_ps(x3, v4), y3);
+
+ _mm256_storeu_ps(y + i + 0, y0);
+ _mm256_storeu_ps(y + i + 8, y1);
+ _mm256_storeu_ps(y + i + 16, y2);
+ _mm256_storeu_ps(y + i + 24, y3);
+ }
+
// leftovers
for (int i = n32; i < n; ++i) {
y[i] += x[i]*v;
_mm_storeu_si128((__m128i*)(y + i + 24), _mm256_cvtps_ph(y3, 0));
}
+ // leftovers
+ for (int i = n32; i < n; ++i) {
+ GGML_ASSERT(false);
+ y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v);
+ }
+#elif defined(__AVX__)
+ // AVX 256-bit
+ const int n32 = (n & ~31);
+
+ const __m256 v8 = _mm256_set1_ps(v);
+
+ __m256 x0, x1, x2, x3;
+ __m256 y0, y1, y2, y3;
+
+ for (int i = 0; i < n32; i += 32) {
+ y0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 0 )));
+ y1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 8 )));
+ y2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 16)));
+ y3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 24)));
+
+ x0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 0 )));
+ x1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 8 )));
+ x2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16)));
+ x3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24)));
+
+ y0 = _mm256_add_ps(_mm256_mul_ps(x0, v8), y0);
+ y1 = _mm256_add_ps(_mm256_mul_ps(x1, v8), y1);
+ y2 = _mm256_add_ps(_mm256_mul_ps(x2, v8), y2);
+ y3 = _mm256_add_ps(_mm256_mul_ps(x3, v8), y3);
+
+ _mm_storeu_si128((__m128i*)(y + i + 0 ), _mm256_cvtps_ph(y0, 0));
+ _mm_storeu_si128((__m128i*)(y + i + 8 ), _mm256_cvtps_ph(y1, 0));
+ _mm_storeu_si128((__m128i*)(y + i + 16), _mm256_cvtps_ph(y2, 0));
+ _mm_storeu_si128((__m128i*)(y + i + 24), _mm256_cvtps_ph(y3, 0));
+ }
+
// leftovers
for (int i = n32; i < n; ++i) {
GGML_ASSERT(false);
////////////////////////////////////////////////////////////////////////////////
+int ggml_cpu_has_avx(void) {
+#if defined(__AVX__)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
int ggml_cpu_has_avx2(void) {
#if defined(__AVX2__)
return 1;