]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : minor improvemnt in decoding strategy (#244)
authorGeorgi Gerganov <redacted>
Sat, 10 Dec 2022 11:38:26 +0000 (13:38 +0200)
committerGeorgi Gerganov <redacted>
Sat, 10 Dec 2022 11:38:26 +0000 (13:38 +0200)
Do not allow for text segments to go beyond end of audio.
This partially mitigates some issues when the last audio window is 1-2
seconds just before the end of the audio file and the decoding spirals
into a repetition of the last transcribed phrase.

whisper.cpp

index abfc44fee916b059b51c9381fbe226987e978eb4..67451dc80b9b14270f7913431c9e74b5330604a9 100644 (file)
@@ -2687,6 +2687,7 @@ int whisper_full(
         tokens_cur.clear();
 
         bool failed = false;
+        bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment?
 
         for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
             if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
@@ -2712,13 +2713,13 @@ int whisper_full(
                     const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx));
 
                     // do not allow to go back in time
-                    if (seek_delta != 100*WHISPER_CHUNK_SIZE &&
-                        seek_delta > seek_delta_new && result_len < i) {
+                    if (has_ts && seek_delta > seek_delta_new && result_len < i) {
                         break;
                     }
 
                     seek_delta = seek_delta_new;
                     result_len = i + 1;
+                    has_ts = true;
                 }
 
                 // add it to the context
@@ -2730,8 +2731,11 @@ int whisper_full(
                 //    printf("%s: %10s %6d %6.3f '%s'\n", __func__, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str());
                 //}
 
-                // end of text token
-                if (token.id == whisper_token_eot(ctx) || (params.max_tokens > 0 && i > params.max_tokens)) {
+                // end of segment
+                if (token.id == whisper_token_eot(ctx) ||               // end of text token
+                    (params.max_tokens > 0 && i > params.max_tokens) || // max tokens per segment reached
+                    (has_ts && seek + seek_delta + 100 >= seek_end)     // end of audio reached
+                    ) {
                     if (result_len == 0) {
                         if (seek + seek_delta + 100 >= seek_end) {
                             result_len = i + 1;