]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : pad audio instead of spectrogram (#579)
authorGeorgi Gerganov <redacted>
Sat, 15 Apr 2023 14:18:43 +0000 (17:18 +0300)
committerGeorgi Gerganov <redacted>
Sat, 15 Apr 2023 14:19:19 +0000 (17:19 +0300)
Also, fallback only if more temperatures are available and if we are
at least 3 seconds before the end of the audio

whisper.cpp

index 3ba2f8e13ec67e516b04d55d006983acb4e02149..3137373b918a72adccf2248b46a53d2acf13d79d 100644 (file)
@@ -297,6 +297,7 @@ static const std::map<e_model, size_t> MEM_REQ_DECODE = {
 
 struct whisper_mel {
     int n_len;
+    int n_len_org;
     int n_mel;
 
     std::vector<float> data;
@@ -2388,8 +2389,28 @@ static bool log_mel_spectrogram(
         hann[i] = 0.5*(1.0 - cos((2.0*M_PI*i)/(fft_size)));
     }
 
-    mel.n_mel = n_mel;
-    mel.n_len = (n_samples)/fft_step;
+    mel.n_mel     = n_mel;
+    mel.n_len     = n_samples/fft_step;
+    mel.n_len_org = mel.n_len;
+
+    std::vector<float> samples_padded;
+
+    // pad audio with at least one extra chunk of zeros
+    {
+        const int pad = (100*WHISPER_CHUNK_SIZE)/2;
+
+        if (mel.n_len % pad != 0) {
+            mel.n_len = (mel.n_len/pad + 1)*pad;
+        }
+        mel.n_len += pad;
+
+        samples_padded.resize(mel.n_len*fft_step);
+        memcpy(samples_padded.data(), samples, n_samples*sizeof(float));
+        memset(samples_padded.data() + n_samples, 0, (mel.n_len*fft_step - n_samples)*sizeof(float));
+
+        samples = samples_padded.data();
+    }
+
     mel.data.resize(mel.n_mel*mel.n_len);
 
     //printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len);
@@ -2433,6 +2454,8 @@ static bool log_mel_spectrogram(
 
     wstate.t_mel_us += ggml_time_us() - t_start_us;
 
+    //printf("mel.n_len() = %d, divided by 1500: %f, n_samples / fft_step: %d\n", mel.n_len, mel.n_len / 1500.0, n_samples / fft_step);
+
     return true;
 }
 
@@ -2786,8 +2809,9 @@ int whisper_set_mel_with_state(
         return -1;
     }
 
-    state->mel.n_len = n_len;
-    state->mel.n_mel = n_mel;
+    state->mel.n_len     = n_len;
+    state->mel.n_len_org = n_len;
+    state->mel.n_mel     = n_mel;
 
     state->mel.data.resize(n_len*n_mel);
     memcpy(state->mel.data.data(), data, n_len*n_mel*sizeof(float));
@@ -2913,8 +2937,8 @@ int whisper_lang_auto_detect_with_state(
         return -1;
     }
 
-    if (seek >= state->mel.n_len) {
-        fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len*10);
+    if (seek >= state->mel.n_len_org) {
+        fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len_org*10);
         return -2;
     }
 
@@ -3049,11 +3073,11 @@ const char *whisper_model_type_readable(struct whisper_context * ctx) {
 }
 
 int whisper_n_len_from_state(struct whisper_state * state) {
-    return state->mel.n_len;
+    return state->mel.n_len_org;
 }
 
 int whisper_n_len(struct whisper_context * ctx) {
-    return ctx->state->mel.n_len;
+    return ctx->state->mel.n_len_org;
 }
 
 int whisper_n_vocab(struct whisper_context * ctx) {
@@ -4354,7 +4378,11 @@ int whisper_full_with_state(
             }
 
             // was the decoding successful for the current temperature?
-            {
+            // do fallback only if:
+            // - we are not at the last temperature
+            // - we are not at the end of the audio (3 sec)
+            if (it != (int) temperatures.size() - 1 &&
+                seek_end - seek > 10*WHISPER_CHUNK_SIZE) {
                 bool success = true;
 
                 const auto & decoder = state->decoders[best_decoder_id];