* Allow a regular expression to describe tokens to suppress.
Example: --suppress-regex "[,\.]|[ ]?[0-9]+" will suppress commas, periods, and numeric tokens.
Technique inspired by https://github.com/openai/whisper/discussions/1041
Co-authored-by: Georgi Gerganov <redacted>
* Blind change to fix Java test.
---------
Co-authored-by: Georgi Gerganov <redacted>
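
Usage sketch (illustrative, not part of the diff): the new suppress_regex field can be set directly through the C API. The model path and the silent audio buffer below are placeholders; only the suppress_regex assignment depends on this change.

#include "whisper.h"

#include <vector>

int main() {
    // load a model (path is a placeholder)
    struct whisper_context * ctx = whisper_init_from_file_with_params(
            "ggml-base.en.bin", whisper_context_default_params());
    if (ctx == nullptr) {
        return 1;
    }

    // suppress numeric tokens, mirroring the example from the commit message
    struct whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
    wparams.suppress_regex = "[ ]?[0-9]+";

    // one second of silence stands in for real 16 kHz mono float PCM audio
    std::vector<float> pcmf32(WHISPER_SAMPLE_RATE, 0.0f);
    whisper_full(ctx, wparams, pcmf32.data(), (int) pcmf32.size());

    whisper_free(ctx);
    return 0;
}
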
tdrz_enable = enable ? CBool.TRUE : CBool.FALSE;
}

+ /** Regular expression matching tokens to suppress. */
+ public String suppress_regex;
+
/** Tokens to provide to the whisper decoder as an initial prompt.
* These are prepended to any existing text context from a previous call. */
public String initial_prompt;
"no_context", "single_segment", "no_timestamps",
"print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps",
"thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "speed_up", "audio_ctx",
- "tdrz_enable", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
+ "tdrz_enable", "suppress_regex", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
"suppress_blank", "suppress_non_speech_tokens", "temperature", "max_initial_ts", "length_penalty",
"temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search",
"new_segment_callback", "new_segment_callback_user_data",
std::string prompt;
std::string context;
std::string grammar;
+
+ // A regular expression that matches tokens to suppress
+ std::string suppress_regex;
};
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
else if (arg == "-ctx" || arg == "--context") { params.context = argv[++i]; }
else if ( arg == "--grammar") { params.grammar = argv[++i]; }
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
+ else if ( arg == "--suppress-regex") { params.suppress_regex = argv[++i]; }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
fprintf(stderr, " -ctx, --context [%-7s] sample text to help the transcription\n", params.context.c_str());
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
+ fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str());
fprintf(stderr, "\n");
}
wparams.initial_prompt = params.context.data();
+ wparams.suppress_regex = params.suppress_regex.c_str();
+
const auto & grammar_parsed = params.grammar_parsed;
auto grammar_rules = grammar_parsed.c_rules();
#include <cmath>
#include <fstream>
#include <cstdio>
+#include <regex>
#include <string>
#include <thread>
#include <vector>
// [TDRZ] speaker turn string
std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line
+ // A regular expression that matches tokens to suppress
+ std::string suppress_regex;
+
std::string openvino_encode_device = "CPU";
std::string dtw = "";
else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
+ else if ( arg == "--suppress-regex") { params.suppress_regex = argv[++i]; }
else if ( arg == "--grammar") { params.grammar = argv[++i]; }
else if ( arg == "--grammar-rule") { params.grammar_rule = argv[++i]; }
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str());
fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false");
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
+ fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str());
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str());
fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
wparams.tdrz_enable = params.tinydiarize; // [TDRZ]
+ wparams.suppress_regex = params.suppress_regex.c_str();
+
wparams.initial_prompt = params.prompt.c_str();
wparams.greedy.best_of = params.best_of;
/*.tdrz_enable =*/ false,
+ /*.suppress_regex =*/ nullptr,
+
/*.initial_prompt =*/ nullptr,
/*.prompt_tokens =*/ nullptr,
/*.prompt_n_tokens =*/ 0,
params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
}
+ // suppress any tokens matching a regular expression
+ // ref: https://github.com/openai/whisper/discussions/1041
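+ // note: std::regex_match requires the pattern to match the entire token text
+ // (not a substring), and matching tokens get their logits forced to -INFINITY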
+ if (params.suppress_regex != nullptr) {
+ std::regex re(params.suppress_regex);
+ for (const auto & token_id : vocab.token_to_id) {
+ if (std::regex_match(token_id.first, re)) {
+ logits[token_id.second] = -INFINITY;
+ }
+ }
+ }
+
// suppress non-speech tokens
// ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
if (params.suppress_non_speech_tokens) {
// [EXPERIMENTAL] [TDRZ] tinydiarize
bool tdrz_enable; // enable tinydiarize speaker turn detection
+ // A regular expression that matches tokens to suppress
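+ // (the pattern is matched against whole vocabulary tokens; nullptr disables suppression)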
+ const char * suppress_regex;
+
// tokens to provide to the whisper decoder as initial prompt
// these are prepended to any existing text context from a previous call
// use whisper_tokenize() to convert text to tokens