rb_define_method(cParams, #param_name, ruby_whisper_params_get_ ## param_name, 0); \
rb_define_method(cParams, #param_name "=", ruby_whisper_params_set_ ## param_name, 1);
-#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 36
+#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 37
extern VALUE cParams;
extern VALUE cVADParams;
static ID id_print_progress;
static ID id_print_realtime;
static ID id_print_timestamps;
+static ID id_carry_initial_prompt;
static ID id_suppress_blank;
static ID id_suppress_nst;
static ID id_token_timestamps;
{
BOOL_PARAMS_GETTER(self, print_timestamps)
}
+
+/*
+ * call-seq:
+ * carry_initial_prompt -> true or false
+ */
+static VALUE
+ruby_whisper_params_get_carry_initial_prompt(VALUE self)
+{
+ BOOL_PARAMS_GETTER(self, carry_initial_prompt)
+}
+
+/*
+ * call-seq:
+ * carry_initial_prompt = bool -> bool
+ */
+static VALUE
+ruby_whisper_params_set_carry_initial_prompt(VALUE self, VALUE value)
+{
+ BOOL_PARAMS_SETTER(self, carry_initial_prompt, value)
+}
+
/*
* call-seq:
* suppress_blank = force_suppress -> force_suppress
SET_PARAM_IF_SAME(max_len)
SET_PARAM_IF_SAME(split_on_word)
SET_PARAM_IF_SAME(initial_prompt)
+ SET_PARAM_IF_SAME(carry_initial_prompt)
SET_PARAM_IF_SAME(offset)
SET_PARAM_IF_SAME(duration)
SET_PARAM_IF_SAME(max_text_tokens)
DEFINE_PARAM(max_len, 11)
DEFINE_PARAM(split_on_word, 12)
DEFINE_PARAM(initial_prompt, 13)
- DEFINE_PARAM(diarize, 14)
- DEFINE_PARAM(offset, 15)
- DEFINE_PARAM(duration, 16)
- DEFINE_PARAM(max_text_tokens, 17)
- DEFINE_PARAM(temperature, 18)
- DEFINE_PARAM(max_initial_ts, 19)
- DEFINE_PARAM(length_penalty, 20)
- DEFINE_PARAM(temperature_inc, 21)
- DEFINE_PARAM(entropy_thold, 22)
- DEFINE_PARAM(logprob_thold, 23)
- DEFINE_PARAM(no_speech_thold, 24)
- DEFINE_PARAM(new_segment_callback, 25)
- DEFINE_PARAM(new_segment_callback_user_data, 26)
- DEFINE_PARAM(progress_callback, 27)
- DEFINE_PARAM(progress_callback_user_data, 28)
- DEFINE_PARAM(encoder_begin_callback, 29)
- DEFINE_PARAM(encoder_begin_callback_user_data, 30)
- DEFINE_PARAM(abort_callback, 31)
- DEFINE_PARAM(abort_callback_user_data, 32)
- DEFINE_PARAM(vad, 33)
- DEFINE_PARAM(vad_model_path, 34)
- DEFINE_PARAM(vad_params, 35)
+ DEFINE_PARAM(carry_initial_prompt, 14)
+ DEFINE_PARAM(diarize, 15)
+ DEFINE_PARAM(offset, 16)
+ DEFINE_PARAM(duration, 17)
+ DEFINE_PARAM(max_text_tokens, 18)
+ DEFINE_PARAM(temperature, 19)
+ DEFINE_PARAM(max_initial_ts, 20)
+ DEFINE_PARAM(length_penalty, 21)
+ DEFINE_PARAM(temperature_inc, 22)
+ DEFINE_PARAM(entropy_thold, 23)
+ DEFINE_PARAM(logprob_thold, 24)
+ DEFINE_PARAM(no_speech_thold, 25)
+ DEFINE_PARAM(new_segment_callback, 26)
+ DEFINE_PARAM(new_segment_callback_user_data, 27)
+ DEFINE_PARAM(progress_callback, 28)
+ DEFINE_PARAM(progress_callback_user_data, 29)
+ DEFINE_PARAM(encoder_begin_callback, 30)
+ DEFINE_PARAM(encoder_begin_callback_user_data, 31)
+ DEFINE_PARAM(abort_callback, 32)
+ DEFINE_PARAM(abort_callback_user_data, 33)
+ DEFINE_PARAM(vad, 34)
+ DEFINE_PARAM(vad_model_path, 35)
+ DEFINE_PARAM(vad_params, 36)
rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0);
rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0);
#include "grammar-parser.h"
#include <cmath>
+#include <algorithm>
#include <fstream>
#include <cstdio>
#include <string>
bool use_gpu = true;
bool flash_attn = true;
bool suppress_nst = false;
+ bool carry_initial_prompt = false;
std::string language = "en";
std::string prompt;
exit(0);
}
#define ARGV_NEXT (((i + 1) < argc) ? argv[++i] : requires_value_error(arg))
- else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(ARGV_NEXT); }
- else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(ARGV_NEXT); }
- else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(ARGV_NEXT); }
- else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(ARGV_NEXT); }
- else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(ARGV_NEXT); }
- else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(ARGV_NEXT); }
- else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(ARGV_NEXT); }
- else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(ARGV_NEXT); }
- else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(ARGV_NEXT); }
- else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(ARGV_NEXT); }
- else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(ARGV_NEXT); }
- else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(ARGV_NEXT); }
- else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(ARGV_NEXT); }
- else if (arg == "-nth" || arg == "--no-speech-thold") { params.no_speech_thold = std::stof(ARGV_NEXT); }
- else if (arg == "-tp" || arg == "--temperature") { params.temperature = std::stof(ARGV_NEXT); }
- else if (arg == "-tpi" || arg == "--temperature-inc") { params.temperature_inc = std::stof(ARGV_NEXT); }
- else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; }
- else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
- else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
- else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; }
- else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; }
- else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
- else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
- else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
- else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
- else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
- else if (arg == "-olrc" || arg == "--output-lrc") { params.output_lrc = true; }
- else if (arg == "-fp" || arg == "--font-path") { params.font_path = ARGV_NEXT; }
- else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
- else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; }
- else if (arg == "-ojf" || arg == "--output-json-full"){ params.output_jsn_full = params.output_jsn = true; }
- else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(ARGV_NEXT); }
- else if (arg == "-np" || arg == "--no-prints") { params.no_prints = true; }
- else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
- else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
- else if ( arg == "--print-confidence"){ params.print_confidence= true; }
- else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
- else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
- else if (arg == "-l" || arg == "--language") { params.language = whisper_param_turn_lowercase(ARGV_NEXT); }
- else if (arg == "-dl" || arg == "--detect-language") { params.detect_language = true; }
- else if ( arg == "--prompt") { params.prompt = ARGV_NEXT; }
- else if (arg == "-m" || arg == "--model") { params.model = ARGV_NEXT; }
- else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(ARGV_NEXT); }
- else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = ARGV_NEXT; }
- else if (arg == "-dtw" || arg == "--dtw") { params.dtw = ARGV_NEXT; }
- else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
- else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
- else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
- else if (arg == "-nfa" || arg == "--no-flash-attn") { params.flash_attn = false; }
- else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; }
- else if ( arg == "--suppress-regex") { params.suppress_regex = ARGV_NEXT; }
- else if ( arg == "--grammar") { params.grammar = ARGV_NEXT; }
- else if ( arg == "--grammar-rule") { params.grammar_rule = ARGV_NEXT; }
- else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(ARGV_NEXT); }
+ else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(ARGV_NEXT); }
+ else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(ARGV_NEXT); }
+ else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(ARGV_NEXT); }
+ else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(ARGV_NEXT); }
+ else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(ARGV_NEXT); }
+ else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(ARGV_NEXT); }
+ else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(ARGV_NEXT); }
+ else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(ARGV_NEXT); }
+ else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(ARGV_NEXT); }
+ else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(ARGV_NEXT); }
+ else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(ARGV_NEXT); }
+ else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(ARGV_NEXT); }
+ else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(ARGV_NEXT); }
+ else if (arg == "-nth" || arg == "--no-speech-thold") { params.no_speech_thold = std::stof(ARGV_NEXT); }
+ else if (arg == "-tp" || arg == "--temperature") { params.temperature = std::stof(ARGV_NEXT); }
+ else if (arg == "-tpi" || arg == "--temperature-inc") { params.temperature_inc = std::stof(ARGV_NEXT); }
+ else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; }
+ else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
+ else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
+ else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; }
+ else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; }
+ else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
+ else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
+ else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
+ else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
+ else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
+ else if (arg == "-olrc" || arg == "--output-lrc") { params.output_lrc = true; }
+ else if (arg == "-fp" || arg == "--font-path") { params.font_path = ARGV_NEXT; }
+ else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
+ else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; }
+ else if (arg == "-ojf" || arg == "--output-json-full") { params.output_jsn_full = params.output_jsn = true; }
+ else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(ARGV_NEXT); }
+ else if (arg == "-np" || arg == "--no-prints") { params.no_prints = true; }
+ else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
+ else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
+ else if ( arg == "--print-confidence") { params.print_confidence= true; }
+ else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
+ else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
+ else if (arg == "-l" || arg == "--language") { params.language = whisper_param_turn_lowercase(ARGV_NEXT); }
+ else if (arg == "-dl" || arg == "--detect-language") { params.detect_language = true; }
+ else if ( arg == "--prompt") { params.prompt = ARGV_NEXT; }
+ else if ( arg == "--carry-initial-prompt") { params.carry_initial_prompt = true; }
+ else if (arg == "-m" || arg == "--model") { params.model = ARGV_NEXT; }
+ else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(ARGV_NEXT); }
+ else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = ARGV_NEXT; }
+ else if (arg == "-dtw" || arg == "--dtw") { params.dtw = ARGV_NEXT; }
+ else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
+ else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
+ else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
+ else if (arg == "-nfa" || arg == "--no-flash-attn") { params.flash_attn = false; }
+ else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; }
+ else if ( arg == "--suppress-regex") { params.suppress_regex = ARGV_NEXT; }
+ else if ( arg == "--grammar") { params.grammar = ARGV_NEXT; }
+ else if ( arg == "--grammar-rule") { params.grammar_rule = ARGV_NEXT; }
+ else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(ARGV_NEXT); }
// Voice Activity Detection (VAD)
else if ( arg == "--vad") { params.vad = true; }
else if (arg == "-vm" || arg == "--vad-model") { params.vad_model = ARGV_NEXT; }
fprintf(stderr, "supported audio formats: flac, mp3, ogg, wav\n");
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
- fprintf(stderr, " -h, --help [default] show this help message and exit\n");
- fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
- fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
- fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
- fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
- fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
- fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
- fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
- fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false");
- fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of);
- fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
- fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
- fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
- fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
- fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
- fprintf(stderr, " -nth N, --no-speech-thold N [%-7.2f] no speech threshold\n", params.no_speech_thold);
- fprintf(stderr, " -tp, --temperature N [%-7.2f] The sampling temperature, between 0 and 1\n", params.temperature);
- fprintf(stderr, " -tpi, --temperature-inc N [%-7.2f] The increment of temperature, between 0 and 1\n",params.temperature_inc);
- fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false");
- fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
- fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
- fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
- fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
- fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
- fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
- fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
- fprintf(stderr, " -olrc, --output-lrc [%-7s] output result in a lrc file\n", params.output_lrc ? "true" : "false");
- fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
- fprintf(stderr, " -fp, --font-path [%-7s] path to a monospace font for karaoke video\n", params.font_path.c_str());
- fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false");
- fprintf(stderr, " -oj, --output-json [%-7s] output result in a JSON file\n", params.output_jsn ? "true" : "false");
- fprintf(stderr, " -ojf, --output-json-full [%-7s] include more information in the JSON file\n", params.output_jsn_full ? "true" : "false");
- fprintf(stderr, " -of FNAME, --output-file FNAME [%-7s] output file path (without file extension)\n", "");
- fprintf(stderr, " -np, --no-prints [%-7s] do not print anything other than the results\n", params.no_prints ? "true" : "false");
- fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
- fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
- fprintf(stderr, " --print-confidence [%-7s] print confidence\n", params.print_confidence ? "true" : "false");
- fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
- fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false");
- fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
- fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false");
- fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt (max n_text_ctx/2 tokens)\n", params.prompt.c_str());
- fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
- fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input audio file path\n", "");
- fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str());
- fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str());
- fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false");
- fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
- fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false");
- fprintf(stderr, " -nfa, --no-flash-attn [%-7s] disable flash attention\n", params.flash_attn ? "false" : "true");
- fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false");
- fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str());
- fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
- fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str());
- fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
+ fprintf(stderr, " -h, --help [default] show this help message and exit\n");
+ fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
+ fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
+ fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
+ fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
+ fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
+ fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
+ fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
+ fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false");
+ fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of);
+ fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
+ fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
+ fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
+ fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
+ fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
+ fprintf(stderr, " -nth N, --no-speech-thold N [%-7.2f] no speech threshold\n", params.no_speech_thold);
+ fprintf(stderr, " -tp, --temperature N [%-7.2f] The sampling temperature, between 0 and 1\n", params.temperature);
+ fprintf(stderr, " -tpi, --temperature-inc N [%-7.2f] The increment of temperature, between 0 and 1\n",params.temperature_inc);
+ fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false");
+ fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
+ fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
+ fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
+ fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
+ fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
+ fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
+ fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
+ fprintf(stderr, " -olrc, --output-lrc [%-7s] output result in a lrc file\n", params.output_lrc ? "true" : "false");
+ fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
+ fprintf(stderr, " -fp, --font-path [%-7s] path to a monospace font for karaoke video\n", params.font_path.c_str());
+ fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false");
+ fprintf(stderr, " -oj, --output-json [%-7s] output result in a JSON file\n", params.output_jsn ? "true" : "false");
+ fprintf(stderr, " -ojf, --output-json-full [%-7s] include more information in the JSON file\n", params.output_jsn_full ? "true" : "false");
+ fprintf(stderr, " -of FNAME, --output-file FNAME [%-7s] output file path (without file extension)\n", "");
+ fprintf(stderr, " -np, --no-prints [%-7s] do not print anything other than the results\n", params.no_prints ? "true" : "false");
+ fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
+ fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
+ fprintf(stderr, " --print-confidence [%-7s] print confidence\n", params.print_confidence ? "true" : "false");
+ fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
+ fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false");
+ fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
+ fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false");
+ fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt (max n_text_ctx/2 tokens)\n", params.prompt.c_str());
+ fprintf(stderr, " --carry-initial-prompt [%-7s] always prepend initial prompt\n", params.carry_initial_prompt ? "true" : "false");
+ fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
+ fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input audio file path\n", "");
+ fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str());
+ fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str());
+ fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false");
+ fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
+ fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false");
+ fprintf(stderr, " -nfa, --no-flash-attn [%-7s] disable flash attention\n", params.flash_attn ? "false" : "true");
+ fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false");
+ fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str());
+ fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
+ fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str());
+ fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
// Voice Activity Detection (VAD) parameters
fprintf(stderr, "\nVoice Activity Detection (VAD) options:\n");
fprintf(stderr, " --vad [%-7s] enable Voice Activity Detection (VAD)\n", params.vad ? "true" : "false");
const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j);
- const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(p, 3)*float(k_colors.size()))));
+ const int n_colors = (int) k_colors.size();
+ const int col = std::max(0, std::min(n_colors - 1, (int) (std::pow(p, 3)*float(n_colors))));
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
}
wparams.suppress_regex = params.suppress_regex.empty() ? nullptr : params.suppress_regex.c_str();
- wparams.initial_prompt = params.prompt.c_str();
+ wparams.initial_prompt = params.prompt.c_str();
+ wparams.carry_initial_prompt = params.carry_initial_prompt;
wparams.greedy.best_of = params.best_of;
wparams.beam_search.beam_size = params.beam_size;
} while (0)
#define WHISPER_MAX_DECODERS 8
+
+// temperature below which we condition on past text history
+static constexpr float WHISPER_HISTORY_CONDITIONING_TEMP_CUTOFF = 0.5f;
+
#define WHISPER_MAX_NODES 4096
static std::string format(const char * fmt, ...) {
std::vector<float> logits;
std::vector<whisper_segment> result_all;
- std::vector<whisper_token> prompt_past;
+
+ // prompt history split into static prefix (prompt_past0) and dynamic rolling context (prompt_past1)
+ std::vector<whisper_token> prompt_past0; // static carried initial prompt (if enabled)
+ std::vector<whisper_token> prompt_past1; // dynamic context from decoded output
int lang_id = 0; // english by default
/* suppress_regex =*/ nullptr,
- /*.initial_prompt =*/ nullptr,
- /*.prompt_tokens =*/ nullptr,
- /*.prompt_n_tokens =*/ 0,
+ /*.initial_prompt =*/ nullptr,
+ /*.carry_initial_prompt =*/ false,
+ /*.prompt_tokens =*/ nullptr,
+ /*.prompt_n_tokens =*/ 0,
/*.language =*/ "en",
/*.detect_language =*/ false,
decoder.rng = std::mt19937(j);
}
- // the accumulated text context so far
- auto & prompt_past = state->prompt_past;
+ // the accumulated text context split into static (prompt_past0) and dynamic (prompt_past1)
+ auto & prompt_past0 = state->prompt_past0;
+ auto & prompt_past1 = state->prompt_past1;
if (params.no_context) {
- prompt_past.clear();
+ prompt_past0.clear();
+ prompt_past1.clear();
}
+ // calculate the maximum context budget for prompt history
+ const int max_prompt_ctx = std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2);
+
// prepare prompt
{
std::vector<whisper_token> prompt_tokens;
- // initial prompt
+ // tokenize the initial prompt
if (!params.prompt_tokens && params.initial_prompt) {
prompt_tokens.resize(1024);
int n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size());
params.prompt_tokens = prompt_tokens.data();
params.prompt_n_tokens = prompt_tokens.size();
}
-
- // prepend the prompt tokens to the prompt_past
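+ // seed the prompt history from the provided prompt tokens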
if (params.prompt_tokens && params.prompt_n_tokens > 0) {
- // parse tokens from the pointer
- for (int i = 0; i < params.prompt_n_tokens; i++) {
- prompt_past.push_back(params.prompt_tokens[i]);
+ if (params.carry_initial_prompt) {
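+ // keep the carried initial prompt in prompt_past0 so it can be prepended to the decoder context of every chunk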
+ if (prompt_past0.empty()) {
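+ // reserve one slot of the context budget for the whisper_token_prev marker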
+ const int max_tokens = std::max(1, max_prompt_ctx - 1);
+
+ if (params.prompt_n_tokens > max_tokens) {
+ WHISPER_LOG_WARN("%s: initial prompt is too long (%d tokens), will use only the last %d tokens\n",
+ __func__, params.prompt_n_tokens, max_tokens);
+ }
+
+ const int n_tokens = std::min(params.prompt_n_tokens, max_tokens);
+ prompt_past0.assign(params.prompt_tokens + (params.prompt_n_tokens - n_tokens), params.prompt_tokens + params.prompt_n_tokens);
+ }
+ } else {
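+ // not carrying: prepend the prompt tokens to the rolling context (prompt_past1), as before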
+ for (int i = 0; i < params.prompt_n_tokens; ++i) {
+ prompt_past1.push_back(params.prompt_tokens[i]);
+ }
+ std::rotate(prompt_past1.begin(), prompt_past1.end() - params.prompt_n_tokens, prompt_past1.end());
}
- std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
}
}
// if there is a very short audio segment left to process, we remove any past prompt since it tends
// to confuse the decoder and often make it repeat or hallucinate stuff
if (seek > seek_start && seek + 500 >= seek_end) {
- prompt_past.clear();
+ prompt_past0.clear();
+ prompt_past1.clear();
}
int best_decoder_id = 0;
{
prompt.clear();
- // if we have already generated some text, use it as a prompt to condition the next generation
- if (!prompt_past.empty() && t_cur < 0.5f && params.n_max_text_ctx > 0) {
- int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
- prompt = { whisper_token_prev(ctx) };
- prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
+ // use the prompt history (carried initial prompt and/or previously generated text) to condition the next generation
+ if (params.n_max_text_ctx > 0 && t_cur < WHISPER_HISTORY_CONDITIONING_TEMP_CUTOFF) {
+ const bool can_take0 = params.carry_initial_prompt && !prompt_past0.empty();
+ const bool can_take1 = !prompt_past1.empty();
+
+ if (max_prompt_ctx > 0 && (can_take0 || can_take1)) {
+ // always start with the previous-text marker token to keep continuity with the prior context
+ prompt.push_back(whisper_token_prev(ctx));
+
+ // take the static tokens (the carried initial prompt) first
+ int n_take0 = 0;
+ if (can_take0) {
+ n_take0 = (int) prompt_past0.size();
+ prompt.insert(prompt.end(), prompt_past0.end() - n_take0, prompt_past0.end());
+ }
+
+ // fill the remaining budget with dynamic tokens (the rolling context)
+ const int n_take1 = std::min<int>(max_prompt_ctx - n_take0 - 1, prompt_past1.size());
+ prompt.insert(prompt.end(), prompt_past1.end() - n_take1, prompt_past1.end());
+ }
}
// init new transcription with sot, language (opt) and task tokens
//WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);
- // update prompt_past
- prompt_past.clear();
- if (prompt.front() == whisper_token_prev(ctx)) {
- prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size());
+ // update prompt_past1
+ prompt_past1.clear();
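+ // when carrying, the initial prompt is prepended from prompt_past0 at decode time, so the previous prompt is not copied back here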
+ if (!params.carry_initial_prompt && !prompt.empty() && prompt.front() == whisper_token_prev(ctx)) {
+ prompt_past1.insert(prompt_past1.end(), prompt.begin() + 1, prompt.end() - prompt_init.size());
}
- for (int i = 0; i < result_len && !is_no_speech; ++i) {
- prompt_past.push_back(tokens_cur[i].id);
+ // add the newly decoded tokens to the rolling context (prompt_past1)
+ if (!is_no_speech) {
+ for (int i = 0; i < result_len; ++i) {
+ prompt_past1.push_back(tokens_cur[i].id);
+ }
}
if (!tokens_cur.empty() && ctx->model.n_loaded > 0 && !is_no_speech) {