]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : suppress non-speech-related token outputs (#473)
authorshibukazu <redacted>
Wed, 8 Feb 2023 07:05:34 +0000 (16:05 +0900)
committerGitHub <redacted>
Wed, 8 Feb 2023 07:05:34 +0000 (09:05 +0200)
* add non-speech-token suppression

* add suppress non-speech_tokens param

whisper.cpp
whisper.h

index aebb4813f080c8f0d1584597adecd5e26986d75c..24e16bd5cf90e9039ae1b054bf7eb57501a0e382 100644 (file)
@@ -2936,6 +2936,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         /*.language         =*/ "en",
 
         /*.suppress_blank   =*/ true,
+        /*.suppress_non_speech_tokens =*/true,
 
         /*.temperature      =*/  0.0f,
         /*.max_initial_ts   =*/  1.0f,
@@ -3077,6 +3078,14 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool
     return res;
 }
 
+static const std::vector<std::string> non_speech_tokens
+{
+    "\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^",
+    "_", "`", "{", "|", "}", "~", "「", "」", "『", "』", "<<", ">>", "<<<", ">>>", "--",
+    "---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪",
+    "♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯"
+};
+
 // process the logits for the selected decoder
 // - applies logit filters
 // - computes logprobs and probs
@@ -3137,6 +3146,33 @@ static void whisper_process_logits(
         logits[vocab.token_translate]  = -INFINITY;
         logits[vocab.token_transcribe] = -INFINITY;
 
+
+        // suppress non-speech tokens
+        // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
+        if (params.suppress_non_speech_tokens)
+        {
+            for (const std::string &token : non_speech_tokens)
+            {
+                std::string suppress_tokens[] = {token, " " + token};
+                for (const std::string &suppress_token : suppress_tokens)
+                {
+                    if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end())
+                    {
+                        logits[vocab.token_to_id.at(suppress_token)] = -INFINITY;
+                    }
+                }
+            }
+            // allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
+            if (vocab.token_to_id.find(" -") != vocab.token_to_id.end())
+            {
+                logits[vocab.token_to_id.at(" -")] = -INFINITY;
+            }
+            if (vocab.token_to_id.find(" '") != vocab.token_to_id.end())
+            {
+                logits[vocab.token_to_id.at(" '")] = -INFINITY;
+            }
+        }
+
         // timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
         // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424
         {
index 786d67d9cb437d7425e1b8aa5135770910245d5e..7eece797c16b84f31116289aaa50e84afb7c4fa6 100644 (file)
--- a/whisper.h
+++ b/whisper.h
@@ -285,6 +285,7 @@ extern "C" {
 
         // common decoding parameters:
         bool suppress_blank;    // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
+        bool suppress_non_speech_tokens; // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
 
         float temperature;      // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478
         float max_initial_ts;   // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97