]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : improve decoding strategy (#244)
authorGeorgi Gerganov <redacted>
Fri, 16 Dec 2022 16:31:17 +0000 (18:31 +0200)
committerGeorgi Gerganov <redacted>
Fri, 16 Dec 2022 16:34:35 +0000 (18:34 +0200)
- Clear past prompt when there is very short audio left for processing.
  My observation is that in these cases the decoding tends to repeat and
  hallucinate stuff and I think this is induced by the existing prompt
- When we fail to sample timestamp token, retry by clearing the past
  prompt. If it fails again, then we advance the window by 1 second

whisper.cpp

index 1bc799676bddf4829cf253f01fed33597de5fdd0..da35456a774156ee71f9e989b6969979af159d4c 100644 (file)
@@ -2650,10 +2650,17 @@ int whisper_full(
             }
         }
 
+        // of only 1 second left, then stop
         if (seek + 100 >= seek_end) {
             break;
         }
 
+        // if there is a very short audio segment left to process, we remove any past prompt since it tends
+        // to confuse the decoder and often make it repeat or hallucinate stuff
+        if (seek > seek_start && seek + 500 >= seek_end) {
+            prompt_past.clear();
+        }
+
         if (params.encoder_begin_callback) {
             if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) {
                 fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__);
@@ -2780,8 +2787,14 @@ int whisper_full(
         }
 
         if (failed) {
-            fprintf(stderr, "\n%s: failed to generate timestamp token - using fallback strategy\n\n", __func__);
-            seek += 100;
+            // when we fail to sample timestamp token, retry by clearing the past prompt
+            // if it fails again, then we advance the window by 1 second
+            if (prompt_past.size() > 0) {
+                prompt_past.clear();
+            } else {
+                fprintf(stderr, "\n%s: failed to generate timestamp token - skipping one second\n\n", __func__);
+                seek += 100;
+            }
             continue;
         }