]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
whisper : suppress tokens with a regex (whisper/1997)
authorulatekh <redacted>
Tue, 9 Apr 2024 15:27:28 +0000 (08:27 -0700)
committerGeorgi Gerganov <redacted>
Tue, 9 Apr 2024 17:28:26 +0000 (20:28 +0300)
* Allow a regular expression to describe tokens to suppress.

Example: --suppress-tokens-re "[,\.]|[ ]?[0-9]+" will suppress commas, periods, and numeric tokens.

Technique inspired by https://github.com/openai/whisper/discussions/1041

Co-authored-by: Georgi Gerganov <redacted>
* Blind change to fix Java test.

---------

Co-authored-by: Georgi Gerganov <redacted>
examples/whisper/main.cpp
examples/whisper/whisper.cpp
examples/whisper/whisper.h

index 42b067e718d4d4bb116f8a4c9f3eb1f29d8281d5..af8b5ca4e013f8cd1c55bce77857c0b9a9ef47f7 100644 (file)
@@ -6,6 +6,7 @@
 #include <cmath>
 #include <fstream>
 #include <cstdio>
+#include <regex>
 #include <string>
 #include <thread>
 #include <vector>
@@ -78,6 +79,9 @@ struct whisper_params {
     // [TDRZ] speaker turn string
     std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line
 
+    // A regular expression that matches tokens to suppress
+    std::string suppress_regex;
+
     std::string openvino_encode_device = "CPU";
 
     std::string dtw = "";
@@ -160,6 +164,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-dtw"  || arg == "--dtw")             { params.dtw             = argv[++i]; }
         else if (arg == "-ls"   || arg == "--log-score")       { params.log_score       = true; }
         else if (arg == "-ng"   || arg == "--no-gpu")          { params.use_gpu         = false; }
+        else if (                  arg == "--suppress-regex")  { params.suppress_regex = argv[++i]; }
         else if (                  arg == "--grammar")         { params.grammar         = argv[++i]; }
         else if (                  arg == "--grammar-rule")    { params.grammar_rule    = argv[++i]; }
         else if (                  arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
@@ -223,6 +228,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -dtw MODEL --dtw MODEL         [%-7s] compute token-level timestamps\n",                 params.dtw.c_str());
     fprintf(stderr, "  -ls,       --log-score         [%-7s] log best decoder scores of tokens\n",              params.log_score?"true":"false");
     fprintf(stderr, "  -ng,       --no-gpu            [%-7s] disable GPU\n",                                    params.use_gpu ? "false" : "true");
+    fprintf(stderr, "  --suppress-regex REGEX         [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str());
     fprintf(stderr, "  --grammar GRAMMAR              [%-7s] GBNF grammar to guide decoding\n",                 params.grammar.c_str());
     fprintf(stderr, "  --grammar-rule RULE            [%-7s] top-level GBNF grammar rule name\n",               params.grammar_rule.c_str());
     fprintf(stderr, "  --grammar-penalty N            [%-7.1f] scales down logits of nongrammar tokens\n",      params.grammar_penalty);
@@ -1033,6 +1039,8 @@ int main(int argc, char ** argv) {
 
             wparams.tdrz_enable      = params.tinydiarize; // [TDRZ]
 
+            wparams.suppress_regex   = params.suppress_regex.c_str();
+
             wparams.initial_prompt   = params.prompt.c_str();
 
             wparams.greedy.best_of        = params.best_of;
index d50c788b3c6a2dc9e696414cee56ecfd9aaa36b8..fd9737379dbb3e77306d2b2e076e15612cd7f2b7 100644 (file)
@@ -4553,6 +4553,8 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
 
         /*.tdrz_enable       =*/ false,
 
+        /* suppress_regex    =*/ nullptr,
+
         /*.initial_prompt    =*/ nullptr,
         /*.prompt_tokens     =*/ nullptr,
         /*.prompt_n_tokens   =*/ 0,
@@ -4796,6 +4798,17 @@ static void whisper_process_logits(
             params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
         }
 
+        // suppress any tokens matching a regular expression
+        // ref: https://github.com/openai/whisper/discussions/1041
+        if (params.suppress_regex != nullptr) {
+            std::regex re(params.suppress_regex);
+            for (std::pair<whisper_vocab::token, whisper_vocab::id> token_id : vocab.token_to_id) {
+                if (std::regex_match(token_id.first, re)) {
+                    logits[token_id.second] = -INFINITY;
+                }
+            }
+        }
+
         // suppress non-speech tokens
         // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
         if (params.suppress_non_speech_tokens) {
index bd8d8df828a805fca8445b20edca438da5dc1de2..6a875d3bbb9d34e67b6efe15e9eadc99511d5344 100644 (file)
@@ -505,6 +505,9 @@ extern "C" {
         // [EXPERIMENTAL] [TDRZ] tinydiarize
         bool tdrz_enable;       // enable tinydiarize speaker turn detection
 
+        // A regular expression that matches tokens to suppress
+        const char * suppress_regex;
+
         // tokens to provide to the whisper decoder as initial prompt
         // these are prepended to any existing text context from a previous call
         // use whisper_tokenize() to convert text to tokens