]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper: use global cache for sin/cos vals and Hann window (#2194)
authorBorislav Stanimirov <redacted>
Wed, 29 May 2024 16:09:21 +0000 (19:09 +0300)
committerGitHub <redacted>
Wed, 29 May 2024 16:09:21 +0000 (19:09 +0300)
- also rename Hanning to Hann as it's named after Julius von Hann
 as per Wikipedia

whisper.cpp

index 7b8c683fca72585d0c5be637a703dc03383fed23..a22da8896bb055e09e1e9732ed6b47a9924422d2 100644 (file)
@@ -2857,20 +2857,44 @@ static std::string to_timestamp(int64_t t, bool comma = false) {
 }
 
 #define SIN_COS_N_COUNT WHISPER_N_FFT
-static float sin_vals[SIN_COS_N_COUNT];
-static float cos_vals[SIN_COS_N_COUNT];
+namespace {
+struct whisper_global_cache {
+    // In FFT, we frequently use sine and cosine operations with the same values.
+    // We can use precalculated values to speed up the process.
+    float sin_vals[SIN_COS_N_COUNT];
+    float cos_vals[SIN_COS_N_COUNT];
+
+    // Hann window (Use cosf to eliminate difference)
+    // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
+    // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
+    float hann_window[WHISPER_N_FFT];
+    float hann_window2x[WHISPER_N_FFT * 2];
+
+    whisper_global_cache() {
+        fill_sin_cos_table();
+#define FILL_HANN_WINDOW(arr) fill_hann_window(sizeof(arr) / sizeof(arr[0]), true, arr)
+        FILL_HANN_WINDOW(hann_window);
+        FILL_HANN_WINDOW(hann_window2x);
+    }
+
+    void fill_sin_cos_table() {
+        for (int i = 0; i < SIN_COS_N_COUNT; i++) {
+            double theta = (2 * M_PI * i) / SIN_COS_N_COUNT;
+            sin_vals[i] = sinf(theta);
+            cos_vals[i] = cosf(theta);
+        }
+    }
 
-// In FFT, we frequently use sine and cosine operations with the same values.
-// We can use precalculated values to speed up the process.
-static void fill_sin_cos_table() {
-    static bool is_filled = false;
-    if (is_filled) return;
-    for (int i = 0; i < SIN_COS_N_COUNT; i++) {
-        double theta = (2*M_PI*i)/SIN_COS_N_COUNT;
-        sin_vals[i] = sinf(theta);
-        cos_vals[i] = cosf(theta);
+    void fill_hann_window(int length, bool periodic, float* output) {
+        int offset = -1;
+        if (periodic) {
+            offset = 0;
+        }
+        for (int i = 0; i < length; i++) {
+            output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
+        }
     }
-    is_filled = true;
+} global_cache;
 }
 
 // naive Discrete Fourier Transform
@@ -2888,8 +2912,8 @@ static void dft(const std::vector<float> & in, std::vector<float> & out) {
 
         for (int n = 0; n < N; n++) {
             int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N
-            re += in[n]*cos_vals[idx]; // cos(t)
-            im -= in[n]*sin_vals[idx]; // sin(t)
+            re += in[n]*global_cache.cos_vals[idx]; // cos(t)
+            im -= in[n]*global_cache.sin_vals[idx]; // sin(t)
         }
 
         out[k*2 + 0] = re;
@@ -2940,8 +2964,8 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
     const int sin_cos_step = SIN_COS_N_COUNT / N;
     for (int k = 0; k < N/2; k++) {
         int idx = k * sin_cos_step; // t = 2*M_PI*k/N
-        float re = cos_vals[idx]; // cos(t)
-        float im = -sin_vals[idx]; // sin(t)
+        float re = global_cache.cos_vals[idx]; // cos(t)
+        float im = -global_cache.sin_vals[idx]; // sin(t)
 
         float re_odd = odd_fft[2*k + 0];
         float im_odd = odd_fft[2*k + 1];
@@ -2954,22 +2978,7 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
     }
 }
 
-static bool hann_window(int length, bool periodic, std::vector<float> & output) {
-    if (output.size() < static_cast<size_t>(length)) {
-        output.resize(length);
-    }
-    int offset = -1;
-    if (periodic) {
-        offset = 0;
-    }
-    for (int i = 0; i < length; i++) {
-        output[i] = 0.5*(1.0 - cosf((2.0*M_PI*i)/(length + offset)));
-    }
-
-    return true;
-}
-
-static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float> & hann, const std::vector<float> & samples,
+static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
                                               int n_samples, int frame_size, int frame_step, int n_threads,
                                               const whisper_filters & filters, whisper_mel & mel) {
     std::vector<float> fft_in(frame_size, 0.0);
@@ -2984,7 +2993,7 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
     for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) {
         const int offset = i * frame_step;
 
-        // apply Hanning window (~10% faster)
+        // apply Hann window (~10% faster)
         for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) {
             fft_in[j] = hann[j] * samples[offset + j];
         }
@@ -3051,12 +3060,16 @@ static bool log_mel_spectrogram(
               whisper_mel & mel) {
     const int64_t t_start_us = ggml_time_us();
 
-    // Hanning window (Use cosf to eliminate difference)
-    // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
-    // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
-    std::vector<float> hann;
-    hann_window(frame_size, true, hann);
-
+    // Hann window
+    const float * hann = nullptr;
+    if (frame_size == WHISPER_N_FFT) {
+        hann = global_cache.hann_window;
+    } else if (frame_size == 2 * WHISPER_N_FFT) {
+        hann = global_cache.hann_window2x;
+    } else {
+        WHISPER_ASSERT(false && "Unsupported frame_size");
+        return false;
+    }
 
     // Calculate the length of padding
     int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30;
@@ -3086,7 +3099,7 @@ static bool log_mel_spectrogram(
         std::vector<std::thread> workers(n_threads - 1);
         for (int iw = 0; iw < n_threads - 1; ++iw) {
             workers[iw] = std::thread(
-                    log_mel_spectrogram_worker_thread, iw + 1, std::cref(hann), samples_padded,
+                    log_mel_spectrogram_worker_thread, iw + 1, hann, samples_padded,
                     n_samples + stage_2_pad, frame_size, frame_step, n_threads,
                     std::cref(filters), std::ref(mel));
         }
@@ -3246,8 +3259,6 @@ static std::string whisper_openvino_get_path_cache(std::string path_bin) {
 #endif
 
 struct whisper_state * whisper_init_state(whisper_context * ctx) {
-    fill_sin_cos_table();
-
     whisper_state * state = new whisper_state;
 
     state->backend = whisper_backend_init(ctx->params);
@@ -7235,7 +7246,7 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
     // operation (after median filter)
     // IN: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
     // OUT: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
-    w = ggml_norm(gctx, w, 1e-9);
+    w = ggml_norm(gctx, w, 1e-9f);
     w = ggml_permute(gctx, ggml_permute(gctx, w, 2, 1, 0 ,3), 0, 2, 1, 3);
 
     // Pass median filter - this is done over AUDIO_TOKENS dimension.