]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
mtmd: mtmd_audio_streaming_istft (#18645)
authorTarek Dakhran <redacted>
Tue, 6 Jan 2026 20:00:29 +0000 (21:00 +0100)
committerGitHub <redacted>
Tue, 6 Jan 2026 20:00:29 +0000 (21:00 +0100)
Change is decoupled from https://github.com/ggml-org/llama.cpp/pull/18641.

[LFM2.5-Audio-1.5B](https://huggingface.co/LiquidAI/LFM2.5-Audio-1.5B)
needs streaming istft for generating output audio.

* add streaming ISTFT class (`mtmd_audio_streaming_istft`) with overlap-add for audio reconstruction
* replace global audio cache with per-instance cache, the model requires
  two independent caches, for preprocessing (audio input) and for istft
  (audio output).
* unified templated FFT/IFFT implementation supporting both forward and inverse transforms

tools/mtmd/mtmd-audio.cpp
tools/mtmd/mtmd-audio.h

index e99101184b1a57c4d2f38298efe204bc12c3e85d..e8eef035ff57de528441f9632e1ce552958a7263 100644 (file)
 #include <fstream>
 #include <algorithm>
 
-// most of the code here is copied from whisper.cpp
+// some of the code here is copied from whisper.cpp
 
 constexpr bool DEBUG = false;
 
-struct mtmd_audio_mel_filters {
-    int32_t n_mel;
-    int32_t n_fft;
-
-    std::vector<float> data;
-};
-
-// note: this global cache is shared among all preprocessors
-//       if we want to use multiple preprocessors at the same time,
-//       we will need to enclose it in the preprocessor class in the future
-static struct mtmd_audio_global_cache {
-    // precomputed sin/cos table for FFT
-    std::vector<float> sin_vals;
-    std::vector<float> cos_vals;
-
-    // hann window
-    std::vector<float> hann_window;
-
-    // mel filter bank
-    mtmd_audio_mel_filters filters;
-
-    void fill_sin_cos_table(int n) {
-        sin_vals.resize(n);
-        cos_vals.resize(n);
-        for (int i = 0; i < n; i++) {
-            double theta = (2 * M_PI * i) / n;
-            sin_vals[i] = sinf(theta);
-            cos_vals[i] = cosf(theta);
-        }
+void mtmd_audio_cache::fill_sin_cos_table(int n) {
+    sin_vals.resize(n);
+    cos_vals.resize(n);
+    for (int i = 0; i < n; i++) {
+        double theta = (2 * M_PI * i) / n;
+        sin_vals[i]  = sinf(theta);
+        cos_vals[i]  = cosf(theta);
     }
+}
 
-    void fill_hann_window(int length, bool periodic) {
-        hann_window.resize(length);
-        int offset = -1;
-        if (periodic) {
-            offset = 0;
-        }
-        for (int i = 0; i < length; i++) {
-            hann_window[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
-        }
+void mtmd_audio_cache::fill_hann_window(int length, bool periodic) {
+    hann_window.resize(length);
+    int offset = -1;
+    if (periodic) {
+        offset = 0;
+    }
+    for (int i = 0; i < length; i++) {
+        hann_window[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
     }
+}
 
-    // Build mel filterbank matrix [n_mel × n_fft_bins] at runtime.
-    // n_fft_bins must be (N_fft / 2 + 1). Example: if N_fft=512 -> n_fft_bins=257.
-    void fill_mel_filterbank_matrix(
-        int n_mel,
-        int n_fft,
-        int sample_rate,            // e.g. 16000
-        float fmin = 0.0f,          // e.g. 0.0
-        float fmax = -1.0f,         // e.g. sr/2; pass -1 for auto
-        bool slaney_area_norm = true,
-        float scale = 1.0f          // optional extra scaling; use 1.0f/1000.0f to mimic your code
-    ) {
-        GGML_ASSERT(n_mel > 0 && n_fft > 1);
-        if (fmax <= 0.0f) {
-            fmax = 0.5f * sample_rate;
-        }
+void mtmd_audio_cache::fill_mel_filterbank_matrix(int   n_mel,
+                                                  int   n_fft,
+                                                  int   sample_rate,
+                                                  float fmin,
+                                                  float fmax,
+                                                  bool  slaney_area_norm,
+                                                  float scale) {
+    GGML_ASSERT(n_mel > 0 && n_fft > 1);
+    if (fmax <= 0.0f) {
+        fmax = 0.5f * sample_rate;
+    }
 
-        // Slaney scale (matches librosa default)
-        const double min_log_hz = 1000.0;
-        const double lin_slope = 3 / 200.;
-        const double min_log_mel = min_log_hz * lin_slope;
-        const double log_step = log(6.4) / 27.0;
-        auto hz_to_mel = [min_log_hz, lin_slope, log_step, min_log_mel](const double f_hz) -> double {
-            return (f_hz < min_log_hz) ? f_hz * lin_slope : min_log_mel + log(f_hz / min_log_hz) / log_step;
-        };
-        auto mel_to_hz = [min_log_hz, lin_slope, log_step, min_log_mel](const double m) -> double {
-            return (m < min_log_mel) ? m / lin_slope : min_log_hz * exp((m - min_log_mel) * log_step);
-        };
-
-        // infer N_fft from n_fft_bins
-        const double bin_hz_step = double(sample_rate) / double(n_fft);
-
-        // mel grid: n_mel + 2 edges
-        const double m_lo = hz_to_mel(fmin);
-        const double m_hi = hz_to_mel(fmax);
-        std::vector<double> mel_pts(n_mel + 2);
-        for (int i = 0; i < n_mel + 2; ++i) {
-            mel_pts[i] = m_lo + (m_hi - m_lo) * (double(i) / (n_mel + 1));
-        }
+    // Slaney scale (matches librosa default)
+    const double min_log_hz  = 1000.0;
+    const double lin_slope   = 3 / 200.;
+    const double min_log_mel = min_log_hz * lin_slope;
+    const double log_step    = log(6.4) / 27.0;
+    auto         hz_to_mel   = [min_log_hz, lin_slope, log_step, min_log_mel](const double f_hz) -> double {
+        return (f_hz < min_log_hz) ? f_hz * lin_slope : min_log_mel + log(f_hz / min_log_hz) / log_step;
+    };
+    auto mel_to_hz = [min_log_hz, lin_slope, log_step, min_log_mel](const double m) -> double {
+        return (m < min_log_mel) ? m / lin_slope : min_log_hz * exp((m - min_log_mel) * log_step);
+    };
+
+    // infer N_fft from n_fft_bins
+    const double bin_hz_step = double(sample_rate) / double(n_fft);
+
+    // mel grid: n_mel + 2 edges
+    const double        m_lo = hz_to_mel(fmin);
+    const double        m_hi = hz_to_mel(fmax);
+    std::vector<double> mel_pts(n_mel + 2);
+    for (int i = 0; i < n_mel + 2; ++i) {
+        mel_pts[i] = m_lo + (m_hi - m_lo) * (double(i) / (n_mel + 1));
+    }
 
-        // convert to Hz
-        std::vector<double> hz_pts(n_mel + 2);
-        for (int i = 0; i < n_mel + 2; ++i) {
-            hz_pts[i] = mel_to_hz(mel_pts[i]);
-        }
+    // convert to Hz
+    std::vector<double> hz_pts(n_mel + 2);
+    for (int i = 0; i < n_mel + 2; ++i) {
+        hz_pts[i] = mel_to_hz(mel_pts[i]);
+    }
 
-        const int n_fft_bins = n_fft / 2 + 1;
-
-        // filterbank
-        std::vector<float> out(n_mel * n_fft_bins, 0);
-        for (int m = 0; m < n_mel; ++m) {
-            const double f_left   = hz_pts[m];
-            const double f_center = hz_pts[m + 1];
-            const double f_right  = hz_pts[m + 2];
-
-            const double denom_l = std::max(1e-30, f_center - f_left);
-            const double denom_r = std::max(1e-30, f_right  - f_center);
-            const double enorm   = slaney_area_norm ? (2.0 / std::max(1e-30, f_right - f_left)) : 1.0;
-
-            for (int k = 0; k < n_fft_bins; ++k) {
-                const double f = k * bin_hz_step;
-                double w = 0.0;
-                if (f >= f_left && f <= f_center) {
-                    w = (f - f_left) / denom_l;
-                } else if (f > f_center && f <= f_right) {
-                    w = (f_right - f) / denom_r;
-                }
-                out[size_t(m) * size_t(n_fft_bins) + size_t(k)] = float(w * enorm * scale);
+    const int n_fft_bins = n_fft / 2 + 1;
+
+    // filterbank
+    std::vector<float> out(n_mel * n_fft_bins, 0);
+    for (int m = 0; m < n_mel; ++m) {
+        const double f_left   = hz_pts[m];
+        const double f_center = hz_pts[m + 1];
+        const double f_right  = hz_pts[m + 2];
+
+        const double denom_l = std::max(1e-30, f_center - f_left);
+        const double denom_r = std::max(1e-30, f_right - f_center);
+        const double enorm   = slaney_area_norm ? (2.0 / std::max(1e-30, f_right - f_left)) : 1.0;
+
+        for (int k = 0; k < n_fft_bins; ++k) {
+            const double f = k * bin_hz_step;
+            double       w = 0.0;
+            if (f >= f_left && f <= f_center) {
+                w = (f - f_left) / denom_l;
+            } else if (f > f_center && f <= f_right) {
+                w = (f_right - f) / denom_r;
             }
+            out[size_t(m) * size_t(n_fft_bins) + size_t(k)] = float(w * enorm * scale);
         }
+    }
 
-        filters.n_mel = n_mel;
-        filters.n_fft = n_fft;
-        filters.data  = std::move(out);
+    filters.n_mel = n_mel;
+    filters.n_fft = n_fft;
+    filters.data  = std::move(out);
 
-        if (DEBUG) { // debug
-            for (size_t i = 0; i < filters.data.size(); ++i) {
-                if (filters.data[i] != 0.0f) {
-                    printf("filters[%zu] = %f\n", i, filters.data[i] * 1000.0f);
-                }
+    if (DEBUG) {  // debug
+        for (size_t i = 0; i < filters.data.size(); ++i) {
+            if (filters.data[i] != 0.0f) {
+                printf("filters[%zu] = %f\n", i, filters.data[i] * 1000.0f);
             }
         }
     }
-} g_cache;
+}
 
-// naive Discrete Fourier Transform
-// input is real-valued
-// output is complex-valued
-static void dft(const float * in, int N, float * out) {
-    const int n_sin_cos_vals = g_cache.sin_vals.size();
-    const int sin_cos_step = n_sin_cos_vals / N;
+// Unified DFT implementation for both forward and inverse transforms
+// Template parameters:
+//   Inverse: false = DFT with exp(-2πi·k·n/N), no scaling
+//            true  = IDFT with exp(+2πi·k·n/N), scales by 1/N
+//   RealInput: true = input is real-valued (stride 1), avoids imaginary computations
+//              false = input is complex-valued (interleaved real/imag, stride 2)
+template <bool Inverse, bool RealInput>
+static void dft_impl(const mtmd_audio_cache & cache, const float * in, int N, float * out) {
+    const int n_sin_cos_vals = cache.sin_vals.size();
+    const int sin_cos_step   = n_sin_cos_vals / N;
+
+    constexpr float sign  = Inverse ? 1.0f : -1.0f;
+    const float     scale = Inverse ? (1.0f / N) : 1.0f;
 
     for (int k = 0; k < N; k++) {
         float re = 0;
         float im = 0;
 
         for (int n = 0; n < N; n++) {
-            int idx = (k * n * sin_cos_step) % (n_sin_cos_vals); // t = 2*M_PI*k*n/N
-            re += in[n] * g_cache.cos_vals[idx]; // cos(t)
-            im -= in[n] * g_cache.sin_vals[idx]; // sin(t)
+            int   idx     = (k * n * sin_cos_step) % n_sin_cos_vals;
+            float cos_val = cache.cos_vals[idx];
+            float sin_val = cache.sin_vals[idx];
+
+            if constexpr (RealInput) {
+                // Real input: in_im = 0, simplifies to:
+                // re += in_re * cos_val
+                // im += sign * in_re * sin_val
+                float in_re = in[n];
+                re += in_re * cos_val;
+                im += sign * in_re * sin_val;
+            } else {
+                float in_re = in[n * 2 + 0];
+                float in_im = in[n * 2 + 1];
+                // (a + bi) * (cos + sign*i*sin) = (a*cos - sign*b*sin) + (sign*a*sin + b*cos)i
+                re += in_re * cos_val - sign * in_im * sin_val;
+                im += sign * in_re * sin_val + in_im * cos_val;
+            }
         }
 
-        out[k*2 + 0] = re;
-        out[k*2 + 1] = im;
+        out[k * 2 + 0] = re * scale;
+        out[k * 2 + 1] = im * scale;
     }
 }
 
-// Cooley-Tukey FFT
-// poor man's implementation - use something better
-// input is real-valued
-// output is complex-valued
-static void fft(float * in, int N, float * out) {
-    const int n_sin_cos_vals = g_cache.sin_vals.size();
+// Cooley-Tukey FFT/IFFT unified implementation
+// Template parameters:
+//   Inverse: false = FFT with exp(-2πi·k/N), no scaling
+//            true  = IFFT with exp(+2πi·k/N), scales by 0.5 at each level
+//   RealInput: true = input is real-valued (stride 1)
+//              false = input is complex-valued (interleaved real/imag, stride 2)
+template <bool Inverse, bool RealInput>
+static void fft_impl(const mtmd_audio_cache & cache, float * in, int N, float * out) {
+    const int n_sin_cos_vals = cache.sin_vals.size();
+
     if (N == 1) {
         out[0] = in[0];
-        out[1] = 0;
+        if constexpr (RealInput) {
+            out[1] = 0.0f;
+        } else {
+            out[1] = in[1];
+        }
         return;
     }
 
     const int half_N = N / 2;
-    if (N - half_N*2 == 1) {
-        dft(in, N, out);
+    if (N - half_N * 2 == 1) {
+        // Odd N: fall back to DFT
+        dft_impl<Inverse, RealInput>(cache, in, N, out);
         return;
     }
 
-    float* even = in + N;
-    for (int i = 0; i < half_N; ++i) {
-        even[i]= in[2*i];
-    }
-    float* even_fft = out + 2 * N;
-    fft(even, half_N, even_fft);
+    // Split into even and odd
+    if constexpr (RealInput) {
+        // Real input: stride is 1, copy only real values
+        float * even = in + N;
+        for (int i = 0; i < half_N; ++i) {
+            even[i] = in[2 * i];
+        }
+        float * even_fft = out + 2 * N;
+        fft_impl<Inverse, true>(cache, even, half_N, even_fft);
+
+        float * odd = even;
+        for (int i = 0; i < half_N; ++i) {
+            odd[i] = in[2 * i + 1];
+        }
+        float * odd_fft = even_fft + N;
+        fft_impl<Inverse, true>(cache, odd, half_N, odd_fft);
+    } else {
+        // Complex input: stride is 2, copy complex pairs
+        float * even = in + N * 2;
+        for (int i = 0; i < half_N; ++i) {
+            even[i * 2 + 0] = in[2 * i * 2 + 0];
+            even[i * 2 + 1] = in[2 * i * 2 + 1];
+        }
+        float * even_fft = out + 2 * N;
+        fft_impl<Inverse, false>(cache, even, half_N, even_fft);
 
-    float* odd = even;
-    for (int i = 0; i < half_N; ++i) {
-        odd[i] = in[2*i + 1];
+        float * odd = even;
+        for (int i = 0; i < half_N; ++i) {
+            odd[i * 2 + 0] = in[(2 * i + 1) * 2 + 0];
+            odd[i * 2 + 1] = in[(2 * i + 1) * 2 + 1];
+        }
+        float * odd_fft = even_fft + N;
+        fft_impl<Inverse, false>(cache, odd, half_N, odd_fft);
     }
-    float* odd_fft = even_fft + N;
-    fft(odd, half_N, odd_fft);
+
+    float * even_fft = out + 2 * N;
+    float * odd_fft  = even_fft + N;
 
     const int sin_cos_step = n_sin_cos_vals / N;
+
+    constexpr float sign  = Inverse ? 1.0f : -1.0f;
+    constexpr float scale = Inverse ? 0.5f : 1.0f;
+
     for (int k = 0; k < half_N; k++) {
-        int idx = k * sin_cos_step; // t = 2*M_PI*k/N
-        float re =  g_cache.cos_vals[idx]; // cos(t)
-        float im = -g_cache.sin_vals[idx]; // sin(t)
+        int   idx = k * sin_cos_step;  // t = 2*M_PI*k/N
+        float re  = cache.cos_vals[idx];
+        float im  = sign * cache.sin_vals[idx];
 
-        float re_odd = odd_fft[2*k + 0];
-        float im_odd = odd_fft[2*k + 1];
+        float re_odd = odd_fft[2 * k + 0];
+        float im_odd = odd_fft[2 * k + 1];
 
-        out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
-        out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
+        out[2 * k + 0] = scale * (even_fft[2 * k + 0] + re * re_odd - im * im_odd);
+        out[2 * k + 1] = scale * (even_fft[2 * k + 1] + re * im_odd + im * re_odd);
 
-        out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
-        out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
+        out[2 * (k + half_N) + 0] = scale * (even_fft[2 * k + 0] - re * re_odd + im * im_odd);
+        out[2 * (k + half_N) + 1] = scale * (even_fft[2 * k + 1] - re * im_odd - im * re_odd);
     }
 }
 
+// Forward FFT for real input (used by mel spectrogram)
+static void fft(const mtmd_audio_cache & cache, float * in, int N, float * out) {
+    fft_impl<false, true>(cache, in, N, out);
+}
+
+// Inverse FFT for complex input
+static void ifft(const mtmd_audio_cache & cache, float * in, int N, float * out) {
+    fft_impl<true, false>(cache, in, N, out);
+}
+
 struct filter_params {
     int32_t n_mel;
     int32_t n_fft_bins;
@@ -222,20 +265,27 @@ struct filter_params {
     bool    norm_per_feature = false;
 };
 
-static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
-                                              int n_samples, int frame_size, int frame_step, int n_threads,
-                                              const filter_params & params, mtmd_audio_mel & out) {
+static void log_mel_spectrogram_worker_thread(int                        ith,
+                                              const float *              hann,
+                                              const std::vector<float> & samples,
+                                              int                        n_samples,
+                                              int                        frame_size,
+                                              int                        frame_step,
+                                              int                        n_threads,
+                                              const filter_params &      params,
+                                              const mtmd_audio_cache &   cache,
+                                              mtmd_audio_mel &           out) {
     std::vector<float> fft_in(frame_size * 2, 0.0);
     std::vector<float> fft_out(frame_size * 2 * 2 * 2);
 
     int n_fft_bins = params.n_fft_bins;
     int i = ith;
 
-    const auto & filters = g_cache.filters;
+    const auto & filters = cache.filters;
 
     // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
     GGML_ASSERT(n_fft_bins == 1 + (frame_size / 2));
-    GGML_ASSERT(g_cache.sin_vals.size() == g_cache.cos_vals.size());
+    GGML_ASSERT(cache.sin_vals.size() == cache.cos_vals.size());
     // calculate FFT only when fft_in are not all zero
     for (; i < std::min(n_samples / frame_step + 1, out.n_len); i += n_threads) {
         const int offset = i * frame_step;
@@ -251,7 +301,7 @@ static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const
         }
 
         // FFT
-        fft(fft_in.data(), frame_size, fft_out.data());
+        fft(cache, fft_in.data(), frame_size, fft_out.data());
 
         // Calculate modulus^2 of complex numbers
         // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
@@ -298,6 +348,7 @@ static bool log_mel_spectrogram(
         const int     n_samples_in,
         const int     n_threads,
         const filter_params & params,
+        const mtmd_audio_cache & cache,
         mtmd_audio_mel & out) {
     //const int64_t t_start_us = ggml_time_us();
 
@@ -305,9 +356,9 @@ static bool log_mel_spectrogram(
     int n_samples = n_samples_in;
 
     // Hann window
-    const float * hann = g_cache.hann_window.data();
-    const int frame_size = (params.n_fft_bins - 1) * 2;
-    const int frame_step = params.hop_length;
+    const float * hann       = cache.hann_window.data();
+    const int     frame_size = (params.n_fft_bins - 1) * 2;
+    const int     frame_step = params.hop_length;
 
     // Padding
     std::vector<float> samples_padded;
@@ -335,9 +386,9 @@ static bool log_mel_spectrogram(
 
     // preemphasis
     if (params.preemph) {
-        const int pad_amount = frame_size / 2;
+        const int   pad_amount = frame_size / 2;
         const float preemph = 0.97f;
-        float prev = samples_padded[pad_amount];
+        float       prev = samples_padded[pad_amount];
         for (int i = pad_amount + 1; i + pad_amount < n_samples; ++i) {
             float cur = samples_padded[i];
             samples_padded[i] = cur - preemph * prev;
@@ -372,14 +423,14 @@ static bool log_mel_spectrogram(
     {
         std::vector<std::thread> workers(n_threads - 1);
         for (int iw = 0; iw < n_threads - 1; ++iw) {
-            workers[iw] = std::thread(
-                    log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded),
-                    n_samples, frame_size, frame_step, n_threads,
-                    std::cref(params), std::ref(out));
+            workers[iw] =
+                std::thread(log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded), n_samples,
+                            frame_size, frame_step, n_threads, std::cref(params), std::cref(cache), std::ref(out));
         }
 
         // main thread
-        log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples, frame_size, frame_step, n_threads, params, out);
+        log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples, frame_size, frame_step, n_threads, params,
+                                          cache, out);
         for (int iw = 0; iw < n_threads - 1; ++iw) {
             workers[iw].join();
         }
@@ -404,7 +455,7 @@ static bool log_mel_spectrogram(
 
             for (int j = 0; j < effective_n_len; ++j) {
                 auto &value = out.data[i * out.n_len + j];
-                value = (value - mean) / mstd;
+                value        = (value - mean) / mstd;
             }
 
             // pad the rest with zeros
@@ -450,18 +501,14 @@ static bool log_mel_spectrogram(
 //
 
 void mtmd_audio_preprocessor_whisper::initialize() {
-    g_cache.fill_sin_cos_table(hparams.audio_n_fft);
-    g_cache.fill_hann_window(hparams.audio_window_len, true);
-    g_cache.fill_mel_filterbank_matrix(
-        hparams.n_mel_bins,
-        hparams.audio_n_fft,
-        hparams.audio_sample_rate);
+    cache.fill_sin_cos_table(hparams.audio_n_fft);
+    cache.fill_hann_window(hparams.audio_window_len, true);
+    cache.fill_mel_filterbank_matrix(hparams.n_mel_bins, hparams.audio_n_fft, hparams.audio_sample_rate);
 }
 
-bool mtmd_audio_preprocessor_whisper::preprocess(
-        const float * samples,
-        size_t n_samples,
-        std::vector<mtmd_audio_mel> & output) {
+bool mtmd_audio_preprocessor_whisper::preprocess(const float *                 samples,
+                                                 size_t                        n_samples,
+                                                 std::vector<mtmd_audio_mel> & output) {
     if (n_samples == 0) {
         // empty audio
         return false;
@@ -471,7 +518,7 @@ bool mtmd_audio_preprocessor_whisper::preprocess(
     // if input is too short, pad with zeros
     // this is to avoid potential issues with stage1/2 padding in log_mel_spectrogram
     // TODO: maybe handle this better
-    size_t min_samples = (size_t)hparams.audio_sample_rate * (hparams.audio_chunk_len + 1); // +1 second margin
+    size_t min_samples = (size_t) hparams.audio_sample_rate * (hparams.audio_chunk_len + 1);  // +1 second margin
     if (n_samples < min_samples) {
         smpl.resize(min_samples, 0.0f);
         std::memcpy(smpl.data(), samples, n_samples * sizeof(float));
@@ -486,22 +533,19 @@ bool mtmd_audio_preprocessor_whisper::preprocess(
     params.hop_length       = hparams.audio_hop_len;
     params.sample_rate      = hparams.audio_sample_rate;
     params.center_padding   = false;
-    params.preemph          = 0.0f; // disabled
+    params.preemph          = 0.0f;  // disabled
     params.use_natural_log  = false;
     params.norm_per_feature = false;
 
-    // make sure the global cache is initialized
-    GGML_ASSERT(!g_cache.sin_vals.empty());
-    GGML_ASSERT(!g_cache.cos_vals.empty());
-    GGML_ASSERT(!g_cache.filters.data.empty());
+    // make sure the cache is initialized
+    GGML_ASSERT(!cache.sin_vals.empty());
+    GGML_ASSERT(!cache.cos_vals.empty());
+    GGML_ASSERT(!cache.filters.data.empty());
 
     mtmd_audio_mel out_full;
-    bool ok = log_mel_spectrogram(
-                samples,
-                n_samples,
-                4, // n_threads
-                params,
-                out_full);
+    bool           ok = log_mel_spectrogram(samples, n_samples,
+                                            4,  // n_threads
+                                            params, cache, out_full);
     if (!ok) {
         return false;
     }
@@ -512,21 +556,21 @@ bool mtmd_audio_preprocessor_whisper::preprocess(
         printf("output: n_mel = %d, n_len = %d\n", out_full.n_mel, out_full.n_len);
     }
     const size_t frames_per_chunk = 3000;
-    GGML_ASSERT((size_t)out_full.n_len > frames_per_chunk);
-    for (size_t off = 0; off < (size_t)out_full.n_len; off += frames_per_chunk) {
-        int n_len = std::min(frames_per_chunk, (size_t)out_full.n_len - off);
-        if ((size_t)n_len < frames_per_chunk) {
-            break; // last uncomplete chunk will always be a padded chunk, safe to ignore
+    GGML_ASSERT((size_t) out_full.n_len > frames_per_chunk);
+    for (size_t off = 0; off < (size_t) out_full.n_len; off += frames_per_chunk) {
+        int n_len = std::min(frames_per_chunk, (size_t) out_full.n_len - off);
+        if ((size_t) n_len < frames_per_chunk) {
+            break;  // last uncomplete chunk will always be a padded chunk, safe to ignore
         }
 
         mtmd_audio_mel out_chunk;
         out_chunk.n_len     = n_len;
         out_chunk.n_mel     = out_full.n_mel;
-        out_chunk.n_len_org = out_full.n_mel; // unused
+        out_chunk.n_len_org = out_full.n_mel;  // unused
         out_chunk.data.reserve(out_chunk.n_mel * out_chunk.n_len);
 
         for (int i = 0; i < out_full.n_mel; i++) {
-            auto src = out_full.data.begin() + i*out_full.n_len + off;
+            auto src = out_full.data.begin() + i * out_full.n_len + off;
             out_chunk.data.insert(out_chunk.data.end(), src, src + frames_per_chunk);
         }
 
@@ -541,18 +585,14 @@ bool mtmd_audio_preprocessor_whisper::preprocess(
 //
 
 void mtmd_audio_preprocessor_conformer::initialize() {
-    g_cache.fill_sin_cos_table(hparams.audio_n_fft);
-    g_cache.fill_hann_window(hparams.audio_window_len, true);
-    g_cache.fill_mel_filterbank_matrix(
-        hparams.n_mel_bins,
-        hparams.audio_n_fft,
-        hparams.audio_sample_rate);
+    cache.fill_sin_cos_table(hparams.audio_n_fft);
+    cache.fill_hann_window(hparams.audio_window_len, true);
+    cache.fill_mel_filterbank_matrix(hparams.n_mel_bins, hparams.audio_n_fft, hparams.audio_sample_rate);
 }
 
-bool mtmd_audio_preprocessor_conformer::preprocess(
-        const float * samples,
-        size_t n_samples,
-        std::vector<mtmd_audio_mel> & output) {
+bool mtmd_audio_preprocessor_conformer::preprocess(const float *                 samples,
+                                                   size_t                        n_samples,
+                                                   std::vector<mtmd_audio_mel> & output) {
     // empty audio
     if (n_samples == 0) {
         return false;
@@ -569,18 +609,15 @@ bool mtmd_audio_preprocessor_conformer::preprocess(
     params.use_natural_log  = true;
     params.norm_per_feature = true;
 
-    // make sure the global cache is initialized
-    GGML_ASSERT(!g_cache.sin_vals.empty());
-    GGML_ASSERT(!g_cache.cos_vals.empty());
-    GGML_ASSERT(!g_cache.filters.data.empty());
+    // make sure the cache is initialized
+    GGML_ASSERT(!cache.sin_vals.empty());
+    GGML_ASSERT(!cache.cos_vals.empty());
+    GGML_ASSERT(!cache.filters.data.empty());
 
     mtmd_audio_mel out_full;
-    bool ok = log_mel_spectrogram(
-                samples,
-                n_samples,
-                4, // n_threads
-                params,
-                out_full);
+    bool           ok = log_mel_spectrogram(samples, n_samples,
+                                            4,  // n_threads
+                                            params, cache, out_full);
     if (!ok) {
         return false;
     }
@@ -588,3 +625,106 @@ bool mtmd_audio_preprocessor_conformer::preprocess(
     output.push_back(std::move(out_full));
     return true;
 }
+
+//
+// mtmd_audio_streaming_istft implementation
+//
+
+mtmd_audio_streaming_istft::mtmd_audio_streaming_istft(int n_fft, int hop_length) :
+    n_fft(n_fft),
+    hop_length(hop_length),
+    n_fft_bins(n_fft / 2 + 1),
+    overlap_buffer(n_fft, 0.0f),
+    window_sum_buffer(n_fft, 0.0f),
+    padding_to_remove((n_fft - hop_length) / 2),
+    ifft_in(n_fft * 2 * 4, 0.0f),  // extra space for recursive IFFT
+    ifft_out(n_fft * 2 * 4, 0.0f) {
+    cache.fill_sin_cos_table(n_fft);
+    cache.fill_hann_window(n_fft, true);
+}
+
+void mtmd_audio_streaming_istft::reset() {
+    std::fill(overlap_buffer.begin(), overlap_buffer.end(), 0.0f);
+    std::fill(window_sum_buffer.begin(), window_sum_buffer.end(), 0.0f);
+    padding_to_remove = (n_fft - hop_length) / 2;
+}
+
+std::vector<float> mtmd_audio_streaming_istft::process_frame(const float * frame_spectrum) {
+    std::vector<float> output(hop_length);
+
+    // copy frequencies
+    for (int j = 0; j < n_fft_bins; j++) {
+        ifft_in[j * 2 + 0] = frame_spectrum[j * 2 + 0];
+        ifft_in[j * 2 + 1] = frame_spectrum[j * 2 + 1];
+    }
+
+    // mirror negative frequencies
+    for (int j = 1; j < n_fft_bins - 1; j++) {
+        int mirror_idx              = n_fft - j;
+        ifft_in[mirror_idx * 2 + 0] = ifft_in[j * 2 + 0];
+        ifft_in[mirror_idx * 2 + 1] = -ifft_in[j * 2 + 1];  // conjugate
+    }
+
+    ifft(cache, ifft_in.data(), n_fft, ifft_out.data());
+
+    // update window sum and overlap buffer
+    for (int j = 0; j < n_fft; j++) {
+        window_sum_buffer[j] += cache.hann_window[j] * cache.hann_window[j];
+        overlap_buffer[j] += ifft_out[j * 2] * cache.hann_window[j];
+    }
+
+    // extract hop_length samples with normalization
+    for (int i = 0; i < hop_length; i++) {
+        if (window_sum_buffer[i] > 1e-8f) {
+            output[i] = overlap_buffer[i] / window_sum_buffer[i];
+        } else {
+            output[i] = overlap_buffer[i];
+        }
+    }
+
+    // shift buffers left by hop_length
+    std::copy(overlap_buffer.begin() + hop_length, overlap_buffer.end(), overlap_buffer.begin());
+    std::fill(overlap_buffer.end() - hop_length, overlap_buffer.end(), 0.0f);
+
+    std::copy(window_sum_buffer.begin() + hop_length, window_sum_buffer.end(), window_sum_buffer.begin());
+    std::fill(window_sum_buffer.end() - hop_length, window_sum_buffer.end(), 0.0f);
+
+    // Remove padding if needed
+    int to_remove = std::min(padding_to_remove, (int) output.size());
+    padding_to_remove -= to_remove;
+    output.erase(output.begin(), output.begin() + to_remove);
+
+    return output;
+}
+
+std::vector<float> mtmd_audio_streaming_istft::flush() {
+    std::vector<float> output;
+
+    // Extract remaining samples from overlap buffer
+    // Continue until we've extracted all meaningful samples
+    int remaining = n_fft - hop_length;
+    while (remaining > 0) {
+        int chunk_size = std::min(remaining, hop_length);
+
+        for (int i = 0; i < chunk_size; i++) {
+            float sample;
+            if (window_sum_buffer[i] > 1e-8f) {
+                sample = overlap_buffer[i] / window_sum_buffer[i];
+            } else {
+                sample = overlap_buffer[i];
+            }
+            output.push_back(sample);
+        }
+
+        // Shift buffers
+        std::copy(overlap_buffer.begin() + chunk_size, overlap_buffer.end(), overlap_buffer.begin());
+        std::fill(overlap_buffer.end() - chunk_size, overlap_buffer.end(), 0.0f);
+
+        std::copy(window_sum_buffer.begin() + chunk_size, window_sum_buffer.end(), window_sum_buffer.begin());
+        std::fill(window_sum_buffer.end() - chunk_size, window_sum_buffer.end(), 0.0f);
+
+        remaining -= chunk_size;
+    }
+
+    return output;
+}
index d484c9d0301a9821648eefd25c121e343a77da1c..016c7392e4fcbbff774578947c47d9f8454ddab0 100644 (file)
@@ -17,6 +17,38 @@ struct mtmd_audio_mel {
     std::vector<float> data;
 };
 
+struct mtmd_audio_mel_filters {
+    int32_t n_mel;
+    int32_t n_fft;
+
+    std::vector<float> data;
+};
+
+// cache for audio processing, each processor instance owns its own cache
+struct mtmd_audio_cache {
+    std::vector<float> sin_vals;
+    std::vector<float> cos_vals;
+
+    std::vector<float> hann_window;
+
+    mtmd_audio_mel_filters filters;
+
+    void fill_sin_cos_table(int n);
+
+    void fill_hann_window(int length, bool periodic);
+
+    // Build mel filterbank matrix [n_mel × n_fft_bins] at runtime.
+    // n_fft_bins must be (N_fft / 2 + 1). Example: if N_fft=512 -> n_fft_bins=257.
+    void fill_mel_filterbank_matrix(int   n_mel,
+                                    int   n_fft,
+                                    int   sample_rate,               // e.g. 16000
+                                    float fmin             = 0.0f,   // e.g. 0.0
+                                    float fmax             = -1.0f,  // e.g. sr/2; pass -1 for auto
+                                    bool  slaney_area_norm = true,
+                                    float scale = 1.0f  // optional extra scaling
+    );
+};
+
 struct mtmd_audio_preprocessor {
     const clip_hparams & hparams;
 
@@ -31,10 +63,51 @@ struct mtmd_audio_preprocessor_whisper : mtmd_audio_preprocessor {
     mtmd_audio_preprocessor_whisper(const clip_ctx * ctx) : mtmd_audio_preprocessor(ctx) {}
     void initialize() override;
     bool preprocess(const float * samples, size_t n_samples, std::vector<mtmd_audio_mel> & output) override;
+
+  private:
+    mtmd_audio_cache cache;
 };
 
 struct mtmd_audio_preprocessor_conformer : mtmd_audio_preprocessor {
     mtmd_audio_preprocessor_conformer(const clip_ctx * ctx) : mtmd_audio_preprocessor(ctx) {}
     void initialize() override;
     bool preprocess(const float * samples, size_t n_samples, std::vector<mtmd_audio_mel> & output) override;
+
+  private:
+    mtmd_audio_cache cache;
+};
+
+//
+// streaming ISTFT - converts spectrogram frames back to audio one frame at a time
+//
+struct mtmd_audio_streaming_istft {
+    mtmd_audio_streaming_istft(int n_fft, int hop_length);
+
+    // reset streaming state
+    void reset();
+
+    // process a single STFT frame (streaming)
+    // frame_spectrum: [n_fft_bins x 2] interleaved real/imag
+    // returns: up to hop_length samples
+    std::vector<float> process_frame(const float * frame_spectrum);
+
+    // flush remaining samples at end of stream
+    std::vector<float> flush();
+
+  private:
+    int n_fft;
+    int hop_length;
+    int n_fft_bins;
+
+    // Own cache for output processing
+    mtmd_audio_cache cache;
+
+    // Streaming state
+    std::vector<float> overlap_buffer;
+    std::vector<float> window_sum_buffer;
+    int                padding_to_remove;
+
+    // Working buffers for IFFT
+    std::vector<float> ifft_in;
+    std::vector<float> ifft_out;
 };