]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : by default disable non-speech tokens suppression (#473)
authorGeorgi Gerganov <redacted>
Wed, 15 Feb 2023 19:48:49 +0000 (21:48 +0200)
committerGeorgi Gerganov <redacted>
Wed, 15 Feb 2023 19:48:49 +0000 (21:48 +0200)
This seems to be causing hallucinations in the end of the audio, e.g.:

"Thank you for listening"
"Amen"
..

whisper.cpp

index 331d4084c6b7ed7a17bfb6462c2bb8fc2f735055..04cbc36b2ca909c778cdeaef8968ba396aeb629c 100644 (file)
@@ -2936,7 +2936,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         /*.language         =*/ "en",
 
         /*.suppress_blank   =*/ true,
-        /*.suppress_non_speech_tokens =*/true,
+        /*.suppress_non_speech_tokens =*/ false,
 
         /*.temperature      =*/  0.0f,
         /*.max_initial_ts   =*/  1.0f,
@@ -3078,8 +3078,7 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool
     return res;
 }
 
-static const std::vector<std::string> non_speech_tokens
-{
+static const std::vector<std::string> non_speech_tokens = {
     "\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^",
     "_", "`", "{", "|", "}", "~", "「", "」", "『", "』", "<<", ">>", "<<<", ">>>", "--",
     "---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪",
@@ -3149,26 +3148,21 @@ static void whisper_process_logits(
 
         // suppress non-speech tokens
         // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
-        if (params.suppress_non_speech_tokens)
-        {
-            for (const std::string &token : non_speech_tokens)
-            {
-                std::string suppress_tokens[] = {token, " " + token};
-                for (const std::string &suppress_token : suppress_tokens)
-                {
-                    if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end())
-                    {
+        if (params.suppress_non_speech_tokens) {
+            for (const std::string & token : non_speech_tokens) {
+                const std::string suppress_tokens[] = {token, " " + token};
+                for (const std::string & suppress_token : suppress_tokens) {
+                    if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end()) {
                         logits[vocab.token_to_id.at(suppress_token)] = -INFINITY;
                     }
                 }
             }
+
             // allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
-            if (vocab.token_to_id.find(" -") != vocab.token_to_id.end())
-            {
+            if (vocab.token_to_id.find(" -") != vocab.token_to_id.end()) {
                 logits[vocab.token_to_id.at(" -")] = -INFINITY;
             }
-            if (vocab.token_to_id.find(" '") != vocab.token_to_id.end())
-            {
+            if (vocab.token_to_id.find(" '") != vocab.token_to_id.end()) {
                 logits[vocab.token_to_id.at(" '")] = -INFINITY;
             }
         }