]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : optimize fft() function (#2242)
authormky_coder <redacted>
Tue, 18 Jun 2024 15:10:33 +0000 (23:10 +0800)
committerGitHub <redacted>
Tue, 18 Jun 2024 15:10:33 +0000 (18:10 +0300)
Co-authored-by: Mike Fan <redacted>
whisper.cpp

index 4b96e8bcb6615289de30b192811339aa6ee05827..d10083ec97b9eb5cdae2d65b222f2fac42349c8f 100644 (file)
@@ -2974,10 +2974,7 @@ whisper_span<const float> whisper_mel_calc::hann_window() {
 // naive Discrete Fourier Transform
 // input is real-valued
 // output is complex-valued
-static void dft(const std::vector<float> & in, std::vector<float> & out) {
-    int N = in.size();
-
-    out.resize(N*2);
+static void dft(const float* in, int N, float* out) {
     const int sin_cos_step = SIN_COS_N_COUNT / N;
 
     for (int k = 0; k < N; k++) {
@@ -2999,44 +2996,35 @@ static void dft(const std::vector<float> & in, std::vector<float> & out) {
 // poor man's implementation - use something better
 // input is real-valued
 // output is complex-valued
-static void fft(const std::vector<float> & in, std::vector<float> & out) {
-    out.resize(in.size()*2);
-
-    int N = in.size();
-
+static void fft(float* in, int N, float* out) {
     if (N == 1) {
         out[0] = in[0];
         out[1] = 0;
         return;
     }
 
-    if (N%2 == 1) {
-        dft(in, out);
+    const int half_N = N / 2;
+    if (N - half_N*2 == 1) {
+        dft(in, N, out);
         return;
     }
 
-    std::vector<float> even;
-    std::vector<float> odd;
-
-    even.reserve(N/2);
-    odd.reserve(N/2);
-
-    for (int i = 0; i < N; i++) {
-        if (i % 2 == 0) {
-            even.push_back(in[i]);
-        } else {
-            odd.push_back(in[i]);
-        }
+    float* even = in + N;
+    for (int i = 0; i < half_N; ++i) {
+        even[i]= in[2*i];
     }
+    float* even_fft = out + 2 * N;
+    fft(even, half_N, even_fft);
 
-    std::vector<float> even_fft;
-    std::vector<float> odd_fft;
-
-    fft(even, even_fft);
-    fft(odd, odd_fft);
+    float* odd = even;
+    for (int i = 0; i < half_N; ++i) {
+        odd[i] = in[2*i + 1];
+    }
+    float* odd_fft = even_fft + N;
+    fft(odd, half_N, odd_fft);
 
     const int sin_cos_step = SIN_COS_N_COUNT / N;
-    for (int k = 0; k < N/2; k++) {
+    for (int k = 0; k < half_N; k++) {
         int idx = k * sin_cos_step; // t = 2*M_PI*k/N
         float re = global_cache.cos_vals[idx]; // cos(t)
         float im = -global_cache.sin_vals[idx]; // sin(t)
@@ -3047,8 +3035,8 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
         out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
         out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
 
-        out[2*(k + N/2) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
-        out[2*(k + N/2) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
+        out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
+        out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
     }
 }
 
@@ -3066,8 +3054,8 @@ void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::v
                                               const whisper_filters & filters, whisper_mel_data & mel) {
     const auto frame_size = WHISPER_N_FFT;
     const auto frame_step = WHISPER_HOP_LENGTH;
-    std::vector<float> fft_in(frame_size, 0.0);
-    std::vector<float> fft_out(2 * frame_size);
+    std::vector<float> fft_in(frame_size * 2, 0.0);
+    std::vector<float> fft_out(frame_size * 2 * 2 * 2);
     int n_fft = filters.n_fft;
     int i = ith;
 
@@ -3088,7 +3076,7 @@ void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::v
         }
 
         // FFT
-        fft(fft_in, fft_out);
+        fft(fft_in.data(), frame_size, fft_out.data());
 
         // Calculate modulus^2 of complex numbers
         // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.