]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : slightly faster Log Mel computation + n-1 FFT threads (#568)
authorGeorgi Gerganov <redacted>
Sat, 15 Apr 2023 11:18:46 +0000 (14:18 +0300)
committerGeorgi Gerganov <redacted>
Sat, 15 Apr 2023 11:18:46 +0000 (14:18 +0300)
whisper.cpp

index 178766e8b1507bd17ebcd9908ff8057806f7a37f..8e9fa6cded8abe49794a579b7a90f2d948e76dac 100644 (file)
@@ -2306,10 +2306,10 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
     std::vector<float> fft_in(fft_size, 0.0);
     std::vector<float> fft_out(2 * fft_size);
     int n_fft = 1 + (speed_up ? fft_size / 4 : fft_size / 2);
-    
+
     for (int i = ith; i < mel.n_len; i += n_threads) {
         const int offset = i * fft_step;
-        
+
         // apply Hanning window
         for (int j = 0; j < fft_size; j++) {
             if (offset + j < n_samples) {
@@ -2318,37 +2318,49 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
                 fft_in[j] = 0.0;
             }
         }
-        
+
         // FFT -> mag^2
         fft(fft_in, fft_out);
-        
+
         for (int j = 0; j < fft_size; j++) {
             fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
         }
         for (int j = 1; j < fft_size / 2; j++) {
             fft_out[j] += fft_out[fft_size - j];
         }
-        
+
         if (speed_up) {
             // scale down in the frequency domain results in a speed up in the time domain
             for (int j = 0; j < n_fft; j++) {
                 fft_out[j] = 0.5 * (fft_out[2 * j] + fft_out[2 * j + 1]);
             }
         }
-        
+
         // mel spectrogram
         for (int j = 0; j < mel.n_mel; j++) {
             double sum = 0.0;
-            
-            for (int k = 0; k < n_fft; k++) {
+
+            // unroll loop (suggested by GH user @lunixbochs)
+            int k = 0;
+            for (k = 0; k < n_fft - 3; k += 4) {
+                sum +=
+                    fft_out[k + 0] * filters.data[j*n_fft + k + 0] +
+                    fft_out[k + 1] * filters.data[j*n_fft + k + 1] +
+                    fft_out[k + 2] * filters.data[j*n_fft + k + 2] +
+                    fft_out[k + 3] * filters.data[j*n_fft + k + 3];
+            }
+
+            // handle n_fft remainder
+            for (; k < n_fft; k++) {
                 sum += fft_out[k] * filters.data[j * n_fft + k];
             }
+
             if (sum < 1e-10) {
                 sum = 1e-10;
             }
-            
+
             sum = log10(sum);
-            
+
             mel.data[j * mel.n_len + i] = sum;
         }
     }
@@ -2383,17 +2395,19 @@ static bool log_mel_spectrogram(
     //printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len);
     //printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate);
 
-    if (n_threads == 1) {
-        log_mel_spectrogram_worker_thread(0, hann, samples, n_samples, fft_size, fft_step, n_threads, filters, speed_up, mel);
-    } else {
-        std::vector<std::thread> workers(n_threads);
-        for (int iw = 0; iw < n_threads; ++iw) {
-            workers[iw] = std::thread(log_mel_spectrogram_worker_thread, iw, std::cref(hann), samples,
-                                      n_samples, fft_size, fft_step, n_threads,
-                                      std::cref(filters), speed_up, std::ref(mel));
+    {
+        std::vector<std::thread> workers(n_threads - 1);
+        for (int iw = 0; iw < n_threads - 1; ++iw) {
+            workers[iw] = std::thread(
+                    log_mel_spectrogram_worker_thread, iw + 1, std::cref(hann), samples,
+                    n_samples, fft_size, fft_step, n_threads,
+                    std::cref(filters), speed_up, std::ref(mel));
         }
 
-        for (int iw = 0; iw < n_threads; ++iw) {
+        // main thread
+        log_mel_spectrogram_worker_thread(0, hann, samples, n_samples, fft_size, fft_step, n_threads, filters, speed_up, mel);
+
+        for (int iw = 0; iw < n_threads - 1; ++iw) {
             workers[iw].join();
         }
     }