]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : do not launch log_mel threads when n_thread is 1 (#763)
authorMaximiliano Levi <redacted>
Fri, 14 Apr 2023 19:35:34 +0000 (16:35 -0300)
committerGitHub <redacted>
Fri, 14 Apr 2023 19:35:34 +0000 (22:35 +0300)
whisper.cpp

index 846d3a93dbe316b37cb5611ce8a94acb3f50b6cc..0d677153c754883bd3172e5857388db64b4e1a1e 100644 (file)
@@ -2284,6 +2284,60 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
     }
 }
 
+static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float> &hann, const float *samples,
+                                              int n_samples, int fft_size, int fft_step, int n_threads,
+                                              const whisper_filters &filters, bool speed_up, whisper_mel &mel) {
+    std::vector<float> fft_in(fft_size, 0.0);
+    std::vector<float> fft_out(2 * fft_size);
+    int n_fft = 1 + (speed_up ? fft_size / 4 : fft_size / 2);
+    
+    for (int i = ith; i < mel.n_len; i += n_threads) {
+        const int offset = i * fft_step;
+        
+        // apply Hanning window
+        for (int j = 0; j < fft_size; j++) {
+            if (offset + j < n_samples) {
+                fft_in[j] = hann[j] * samples[offset + j];
+            } else {
+                fft_in[j] = 0.0;
+            }
+        }
+        
+        // FFT -> mag^2
+        fft(fft_in, fft_out);
+        
+        for (int j = 0; j < fft_size; j++) {
+            fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
+        }
+        for (int j = 1; j < fft_size / 2; j++) {
+            fft_out[j] += fft_out[fft_size - j];
+        }
+        
+        if (speed_up) {
+            // scale down in the frequency domain results in a speed up in the time domain
+            for (int j = 0; j < n_fft; j++) {
+                fft_out[j] = 0.5 * (fft_out[2 * j] + fft_out[2 * j + 1]);
+            }
+        }
+        
+        // mel spectrogram
+        for (int j = 0; j < mel.n_mel; j++) {
+            double sum = 0.0;
+            
+            for (int k = 0; k < n_fft; k++) {
+                sum += fft_out[k] * filters.data[j * n_fft + k];
+            }
+            if (sum < 1e-10) {
+                sum = 1e-10;
+            }
+            
+            sum = log10(sum);
+            
+            mel.data[j * mel.n_len + i] = sum;
+        }
+    }
+}
+
 // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
 static bool log_mel_spectrogram(
           whisper_state & wstate,
@@ -2310,81 +2364,22 @@ static bool log_mel_spectrogram(
     mel.n_len = (n_samples)/fft_step;
     mel.data.resize(mel.n_mel*mel.n_len);
 
-    const int n_fft = 1 + (speed_up ? fft_size/4 : fft_size/2);
-
     //printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len);
     //printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate);
 
-    std::vector<std::thread> workers(n_threads);
-    for (int iw = 0; iw < n_threads; ++iw) {
-        workers[iw] = std::thread([&](int ith) {
-            std::vector<float> fft_in;
-            fft_in.resize(fft_size);
-            for (int i = 0; i < fft_size; i++) {
-                fft_in[i] = 0.0;
-            }
-
-            std::vector<float> fft_out;
-            fft_out.resize(2*fft_size);
-
-            for (int i = ith; i < mel.n_len; i += n_threads) {
-                const int offset = i*fft_step;
-
-                // apply Hanning window
-                for (int j = 0; j < fft_size; j++) {
-                    if (offset + j < n_samples) {
-                        fft_in[j] = hann[j]*samples[offset + j];
-                    } else {
-                        fft_in[j] = 0.0;
-                    }
-                }
-
-                // FFT -> mag^2
-                fft(fft_in, fft_out);
-
-                for (int j = 0; j < fft_size; j++) {
-                    fft_out[j] = (fft_out[2*j + 0]*fft_out[2*j + 0] + fft_out[2*j + 1]*fft_out[2*j + 1]);
-                }
-                for (int j = 1; j < fft_size/2; j++) {
-                    //if (i == 0) {
-                    //    printf("%d: %f %f\n", j, fft_out[j], fft_out[fft_size - j]);
-                    //}
-                    fft_out[j] += fft_out[fft_size - j];
-                }
-                if (i == 0) {
-                    //for (int j = 0; j < fft_size; j++) {
-                    //    printf("%d: %e\n", j, fft_out[j]);
-                    //}
-                }
-
-                if (speed_up) {
-                    // scale down in the frequency domain results in a speed up in the time domain
-                    for (int j = 0; j < n_fft; j++) {
-                        fft_out[j] = 0.5*(fft_out[2*j] + fft_out[2*j + 1]);
-                    }
-                }
-
-                // mel spectrogram
-                for (int j = 0; j < mel.n_mel; j++) {
-                    double sum = 0.0;
-
-                    for (int k = 0; k < n_fft; k++) {
-                        sum += fft_out[k]*filters.data[j*n_fft + k];
-                    }
-                    if (sum < 1e-10) {
-                        sum = 1e-10;
-                    }
-
-                    sum = log10(sum);
-
-                    mel.data[j*mel.n_len + i] = sum;
-                }
-            }
-        }, iw);
-    }
+    if (n_threads == 1) {
+        log_mel_spectrogram_worker_thread(0, hann, samples, n_samples, fft_size, fft_step, n_threads, filters, speed_up, mel);
+    } else {
+        std::vector<std::thread> workers(n_threads);
+        for (int iw = 0; iw < n_threads; ++iw) {
+            workers[iw] = std::thread(log_mel_spectrogram_worker_thread, iw, std::cref(hann), samples,
+                                      n_samples, fft_size, fft_step, n_threads,
+                                      std::cref(filters), speed_up, std::ref(mel));
+        }
 
-    for (int iw = 0; iw < n_threads; ++iw) {
-        workers[iw].join();
+        for (int iw = 0; iw < n_threads; ++iw) {
+            workers[iw].join();
+        }
     }
 
     // clamping and normalization