]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : add option to speed up the audio tempo by x2
authorGeorgi Gerganov <redacted>
Sat, 12 Nov 2022 16:03:49 +0000 (18:03 +0200)
committerGeorgi Gerganov <redacted>
Sun, 13 Nov 2022 14:25:43 +0000 (16:25 +0200)
Using a Phase Vocoder for speeding up the audio tempo by scaling down
the frequencies in the frequency domain.

This reduces the computation in the Encoder by a factor of 2.
The transcription accuracy is degraded, but for slow to normal speech -
it seems to be still very good.

I think this can find application for real-time transcription - i.e. the
"stream" example.

examples/main/main.cpp
examples/stream/stream.cpp
whisper.cpp
whisper.h

index 70580315769a5110b1dc67ee90433c3dd4a5f057..204703c4d9de25c916823de7c7e170a277d3ef47 100644 (file)
@@ -59,6 +59,7 @@ struct whisper_params {
 
     float word_thold = 0.01f;
 
+    bool speed_up             = false;
     bool verbose              = false;
     bool translate            = false;
     bool output_txt           = false;
@@ -104,6 +105,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
             params.max_len = std::stoi(argv[++i]);
         } else if (arg == "-wt" || arg == "--word-thold") {
             params.word_thold = std::stof(argv[++i]);
+        } else if (arg == "-su" || arg == "--speed-up") {
+            params.speed_up = true;
         } else if (arg == "-v" || arg == "--verbose") {
             params.verbose = true;
         } else if (arg == "--translate") {
@@ -161,6 +164,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
     fprintf(stderr, "  -mc N,    --max-context N  maximum number of text context tokens to store (default: max)\n");
     fprintf(stderr, "  -ml N,    --max-len N      maximum segment length in characters (default: %d)\n", params.max_len);
     fprintf(stderr, "  -wt N,    --word-thold N   word timestamp probability threshold (default: %f)\n", params.word_thold);
+    fprintf(stderr, "  -su,      --speed-up       speed up audio by factor of 2 (faster processing, reduced accuracy, default: %s)\n", params.speed_up ? "true" : "false");
     fprintf(stderr, "  -v,       --verbose        verbose output\n");
     fprintf(stderr, "            --translate      translate from source language to english\n");
     fprintf(stderr, "  -otxt,    --output-txt     output result in a text file\n");
@@ -454,7 +458,7 @@ int main(int argc, char ** argv) {
         std::vector<float> pcmf32;
         {
             drwav wav;
-            
+
             if (fname_inp == "-") {
                 std::vector<uint8_t> wav_data;
                 {
@@ -563,6 +567,8 @@ int main(int argc, char ** argv) {
             wparams.thold_pt             = params.word_thold;
             wparams.max_len              = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
 
+            wparams.speed_up             = params.speed_up;
+
             // this callback is called on each new segment
             if (!wparams.print_realtime) {
                 wparams.new_segment_callback           = whisper_print_segment_callback;
index c39375a08fba1c7432eb2d51e423c5999d3fd40f..718c8151d39d4e413e454dd8b6cfb42765591932 100644 (file)
@@ -41,6 +41,7 @@ struct whisper_params {
     int32_t length_ms  = 10000;
     int32_t capture_id = -1;
 
+    bool speed_up             = false;
     bool verbose              = false;
     bool translate            = false;
     bool no_context           = true;
@@ -68,6 +69,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
             params.length_ms = std::stoi(argv[++i]);
         } else if (arg == "-c" || arg == "--capture") {
             params.capture_id = std::stoi(argv[++i]);
+        } else if (arg == "-su" || arg == "--speed-up") {
+            params.speed_up = true;
         } else if (arg == "-v" || arg == "--verbose") {
             params.verbose = true;
         } else if (arg == "--translate") {
@@ -113,6 +116,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
     fprintf(stderr, "            --step N         audio step size in milliseconds (default: %d)\n", params.step_ms);
     fprintf(stderr, "            --length N       audio length in milliseconds (default: %d)\n", params.length_ms);
     fprintf(stderr, "  -c ID,    --capture ID     capture device ID (default: -1)\n");
+    fprintf(stderr, "  -su,      --speed-up       speed up audio by factor of 2 (faster processing, reduced accuracy, default: %s)\n", params.speed_up ? "true" : "false");
     fprintf(stderr, "  -v,       --verbose        verbose output\n");
     fprintf(stderr, "            --translate      translate from source language to english\n");
     fprintf(stderr, "  -kc,      --keep-context   keep text context from earlier audio (default: false)\n");
@@ -326,6 +330,8 @@ int main(int argc, char ** argv) {
             wparams.language             = params.language.c_str();
             wparams.n_threads            = params.n_threads;
 
+            wparams.speed_up             = params.speed_up;
+
             if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
                 fprintf(stderr, "%s: failed to process audio\n", argv[0]);
                 return 6;
index 7078863aa3e28e68c2bdda802aa43d038bd31081..d894f69cd45932a2417d7384f3b545a8329359b9 100644 (file)
@@ -2031,6 +2031,7 @@ static bool log_mel_spectrogram(
     const int n_mel,
     const int n_threads,
     const whisper_filters & filters,
+    const bool speed_up,
     whisper_mel & mel) {
 
     // Hanning window
@@ -2044,7 +2045,7 @@ static bool log_mel_spectrogram(
     mel.n_len = (n_samples)/fft_step;
     mel.data.resize(mel.n_mel*mel.n_len);
 
-    const int n_fft = 1 + fft_size/2;
+    const int n_fft = 1 + (speed_up ? fft_size/4 : fft_size/2);
 
     //printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len);
     //printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate);
@@ -2091,6 +2092,13 @@ static bool log_mel_spectrogram(
                     //}
                 }
 
+                if (speed_up) {
+                    // scale down in the frequency domain results in a speed up in the time domain
+                    for (int j = 0; j < n_fft; j++) {
+                        fft_out[j] = 0.5*(fft_out[2*j] + fft_out[2*j + 1]);
+                    }
+                }
+
                 // mel spectrogram
                 for (int j = 0; j < mel.n_mel; j++) {
                     double sum = 0.0;
@@ -2171,7 +2179,21 @@ void whisper_free(struct whisper_context * ctx) {
 int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
     const int64_t t_start_us = ggml_time_us();
 
-    if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, ctx->mel)) {
+    if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, ctx->mel)) {
+        fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
+        return -1;
+    }
+
+    ctx->t_mel_us = ggml_time_us() - t_start_us;
+
+    return 0;
+}
+
+// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
+int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
+    const int64_t t_start_us = ggml_time_us();
+
+    if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, 2*WHISPER_N_FFT, 2*WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, ctx->mel)) {
         fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
         return -1;
     }
@@ -2353,6 +2375,8 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
                     /*.thold_ptsum          =*/ 0.01f,
                     /*.max_len              =*/ 0,
 
+                    /*.speed_up             =*/ false,
+
                     /*.language             =*/ "en",
 
                     /*.greedy               =*/ {
@@ -2391,6 +2415,8 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
                     /*.thold_ptsum          =*/ 0.01f,
                     /*.max_len              =*/ 0,
 
+                    /*.speed_up             =*/ false,
+
                     /*.language             =*/ "en",
 
                     /*.greedy               =*/ {
@@ -2485,9 +2511,16 @@ int whisper_full(
     result_all.clear();
 
     // compute log mel spectrogram
-    if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
-        fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
-        return -1;
+    if (params.speed_up) {
+        if (whisper_pcm_to_mel_phase_vocoder(ctx, samples, n_samples, params.n_threads) != 0) {
+            fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
+            return -1;
+        }
+    } else {
+        if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
+            fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
+            return -1;
+        }
     }
 
     if (params.token_timestamps) {
@@ -2673,16 +2706,19 @@ int whisper_full(
                 if (tokens_cur[i].id > whisper_token_beg(ctx)) {
                     const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
                     if (!text.empty()) {
+                        const auto tt0 = params.speed_up ? 2*t0 : t0;
+                        const auto tt1 = params.speed_up ? 2*t1 : t1;
+
                         if (params.print_realtime) {
                             if (params.print_timestamps) {
-                                printf("[%s --> %s]  %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text.c_str());
+                                printf("[%s --> %s]  %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
                             } else {
                                 printf("%s", text.c_str());
                                 fflush(stdout);
                             }
                         }
 
-                        result_all.push_back({ t0, t1, text, {} });
+                        result_all.push_back({ tt0, tt1, text, {} });
                         for (int j = i0; j <= i; j++) {
                             result_all.back().tokens.push_back(tokens_cur[j]);
                         }
@@ -2714,16 +2750,19 @@ int whisper_full(
             if (!text.empty()) {
                 const auto t1 = seek + seek_delta;
 
+                const auto tt0 = params.speed_up ? 2*t0 : t0;
+                const auto tt1 = params.speed_up ? 2*t1 : t1;
+
                 if (params.print_realtime) {
                     if (params.print_timestamps) {
-                        printf("[%s --> %s]  %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text.c_str());
+                        printf("[%s --> %s]  %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
                     } else {
                         printf("%s", text.c_str());
                         fflush(stdout);
                     }
                 }
 
-                result_all.push_back({ t0, t1, text, {} });
+                result_all.push_back({ tt0, tt1, text, {} });
                 for (int j = i0; j < (int) tokens_cur.size(); j++) {
                     result_all.back().tokens.push_back(tokens_cur[j]);
                 }
index 4c112f49c0ca8ca6fa8653ae9f3ff0ee852867a3..ea677eafd0c05a6b5a167917eef6c7bad12f270a 100644 (file)
--- a/whisper.h
+++ b/whisper.h
@@ -202,6 +202,9 @@ extern "C" {
         float thold_ptsum;      // timestamp token sum probability threshold (~0.01)
         int   max_len;          // max segment length in characters
 
+        // [EXPERIMENTAL] speed-up techniques
+        bool speed_up; // speed-up the audio by 2x using Phase Vocoder
+
         const char * language;
 
         struct {