]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : support no_speech_thold (#2625)
authorKarthick <redacted>
Tue, 17 Dec 2024 17:15:47 +0000 (22:45 +0530)
committerGitHub <redacted>
Tue, 17 Dec 2024 17:15:47 +0000 (19:15 +0200)
* Implement no_speech_thold

no_speech_thold functionality is on par with OpenAI's whisper

* Addressed review comments

include/whisper.h
src/whisper.cpp

index 9188d686a3189a2bbfddc4cc70de299b9ff6f37e..71949bdd39727fb1fe92391bfcfbd4309e1e9660 100644 (file)
@@ -534,7 +534,7 @@ extern "C" {
         float temperature_inc;
         float entropy_thold;    // similar to OpenAI's "compression_ratio_threshold"
         float logprob_thold;
-        float no_speech_thold;  // TODO: not implemented
+        float no_speech_thold;
 
         struct {
             int best_of;    // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264
index 810a8d267aba2520673d18f0e21e06842ba047f7..bcc530ae8918f6806316b4b0e3da055ac8ae1545 100644 (file)
@@ -867,6 +867,7 @@ struct whisper_state {
     whisper_token tid_last;
 
     std::vector<float> energy; // PCM signal energy
+    float no_speech_prob = 0.0f;
 
     // [EXPERIMENTAL] Token-level timestamps with DTW
     whisper_aheads_masks aheads_masks;
@@ -4825,6 +4826,42 @@ static const std::vector<std::string> non_speech_tokens = {
     "♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯"
 };
 
+static void whisper_compute_logprobs(
+                const std::vector<float> & logits,
+                              const int    n_logits,
+                      std::vector<float> & logprobs) {
+    const float logit_max = *std::max_element(logits.begin(), logits.end());
+    float logsumexp = 0.0f;
+    for (int i = 0; i < n_logits; ++i) {
+        if (logits[i] > -INFINITY) {
+            logsumexp += expf(logits[i] - logit_max);
+        }
+    }
+    logsumexp = logf(logsumexp) + logit_max;
+
+    for (int i = 0; i < n_logits; ++i) {
+        if (logits[i] > -INFINITY) {
+            logprobs[i] = logits[i] - logsumexp;
+        } else {
+            logprobs[i] = -INFINITY;
+        }
+    }
+}
+
+static void whisper_compute_probs(
+    const std::vector<float> & logits,
+                  const int    n_logits,
+    const std::vector<float> & logprobs,
+          std::vector<float> & probs)     {
+    for (int i = 0; i < n_logits; ++i) {
+        if (logits[i] == -INFINITY) {
+            probs[i] = 0.0f;
+        } else {
+            probs[i] = expf(logprobs[i]);
+        }
+    }
+}
+
 // process the logits for the selected decoder
 // - applies logit filters
 // - computes logprobs and probs
@@ -4886,7 +4923,7 @@ static void whisper_process_logits(
 
         // suppress sot and nosp tokens
         logits[vocab.token_sot]  = -INFINITY;
-        logits[vocab.token_nosp] = -INFINITY; // TODO: ignore this token for now
+        logits[vocab.token_nosp] = -INFINITY;
 
         // [TDRZ] when tinydiarize is disabled, suppress solm token
         if (params.tdrz_enable == false) {
@@ -4985,24 +5022,7 @@ static void whisper_process_logits(
         }
 
         // populate the logprobs array (log_softmax)
-        {
-            const float logit_max = *std::max_element(logits.begin(), logits.end());
-            float logsumexp = 0.0f;
-            for (int i = 0; i < n_logits; ++i) {
-                if (logits[i] > -INFINITY) {
-                    logsumexp += expf(logits[i] - logit_max);
-                }
-            }
-            logsumexp = logf(logsumexp) + logit_max;
-
-            for (int i = 0; i < n_logits; ++i) {
-                if (logits[i] > -INFINITY) {
-                    logprobs[i] = logits[i] - logsumexp;
-                } else {
-                    logprobs[i] = -INFINITY;
-                }
-            }
-        }
+        whisper_compute_logprobs(logits, n_logits, logprobs);
 
         // if sum of probability over timestamps is above any other token, sample timestamp
         // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L431-L437
@@ -5060,15 +5080,7 @@ static void whisper_process_logits(
     }
 
     // compute probs
-    {
-        for (int i = 0; i < n_logits; ++i) {
-            if (logits[i] == -INFINITY) {
-                probs[i] = 0.0f;
-            } else {
-                probs[i] = expf(logprobs[i]);
-            }
-        }
-    }
+    whisper_compute_probs(logits, n_logits, logprobs, probs);
 
 #if 0
     // print first 100 logits - token string : logit
@@ -5647,6 +5659,18 @@ int whisper_full_with_state(
                     return -8;
                 }
 
+                // Calculate no_speech probability after first decode.
+                // This has to be done before any logit filtering. Hence we cannot use the probs from the whisper_process_logits.
+                {
+                    const int n_logits = ctx->vocab.id_to_token.size();
+                    std::vector<float> logprobs(n_logits);
+                    std::vector<float> probs(n_logits);
+
+                    whisper_compute_logprobs(state->logits, n_logits, logprobs);
+                    whisper_compute_probs(state->logits, n_logits, logprobs, probs);
+                    state->no_speech_prob = probs[whisper_token_nosp(ctx)];
+                }
+
                 {
                     const int64_t t_start_sample_us = ggml_time_us();
 
@@ -6038,8 +6062,9 @@ int whisper_full_with_state(
             if (it != (int) temperatures.size() - 1) {
                 const auto & decoder = state->decoders[best_decoder_id];
 
-                if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
-                    WHISPER_LOG_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold);
+                if (decoder.failed ||
+                    (decoder.sequence.avg_logprobs < params.logprob_thold && state->no_speech_prob < params.no_speech_thold)) {
+                    WHISPER_LOG_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f and no_speech_prob %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold, state->no_speech_prob, params.no_speech_thold);
                     success = false;
                     state->n_fail_p++;
                 }
@@ -6068,6 +6093,9 @@ int whisper_full_with_state(
             // [EXPERIMENTAL] Token-level timestamps with DTW
             const auto n_segments_before = state->result_all.size();
 
+            const bool is_no_speech = (state->no_speech_prob > params.no_speech_thold &&
+                best_decoder.sequence.avg_logprobs < params.logprob_thold);
+
             //WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);
 
             // update prompt_past
@@ -6076,11 +6104,11 @@ int whisper_full_with_state(
                 prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size());
             }
 
-            for (int i = 0; i < result_len; ++i) {
+            for (int i = 0; i < result_len && !is_no_speech; ++i) {
                 prompt_past.push_back(tokens_cur[i].id);
             }
 
-            if (!tokens_cur.empty() && ctx->model.n_loaded > 0) {
+            if (!tokens_cur.empty() && ctx->model.n_loaded > 0 && !is_no_speech) {
                 int  i0 = 0;
                 auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));