]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : calculate mel spectrogram directly into a ggml_tensor (#2208)
authorBorislav Stanimirov <redacted>
Thu, 6 Jun 2024 13:20:46 +0000 (16:20 +0300)
committerGitHub <redacted>
Thu, 6 Jun 2024 13:20:46 +0000 (16:20 +0300)
* whisper : calculate mel spectrogram directly into a ggml_tensor

* whisper : remove unused temp buffer from state

* whisper : fix not initializing wstate.embd_enc

whisper-mel-cuda.cu
whisper-mel.hpp
whisper.cpp

index ad36cae58300cedc7e398047a303f7d21bb95609..3f3e3158d3e26f4019bba679eec26197fcb1c621 100644 (file)
@@ -8,6 +8,7 @@
 #include <cublas_v2.h>
 #include <cuComplex.h>
 #include <cub/device/device_reduce.cuh>
+#include <device_launch_parameters.h>
 
 #include <algorithm>
 
@@ -301,27 +302,23 @@ public:
             &fzero,
             mel_data, int(n_mag_frames)));
 
-        float * log_mels = nullptr;
-        CUDA_CHECK(cudaMallocAsync(&log_mels, m_n_mel * n_mag_frames * sizeof(float), m_stream));
+        whisper_mel ret;
+        // Calculate semi-padded sample length to ensure compatibility
+        int n_len_org = 1 + int(samples.len + mirror_pad - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
+        ret.init(m_backend, int(n_mag_frames), n_len_org, m_n_mel);
+        assert(ggml_nbytes(ret.tensor) == m_n_mel * n_mag_frames * sizeof(float));
+
+        float* log_mels = reinterpret_cast<float*>(ret.tensor->data);
 
         calc_log_mel(
             mel_data, int(m_n_mel * n_mag_frames),
-            m_log_mel_temp_storage, int(m_log_mel_temp_storage_size),
+            m_log_mel_temp_storage , int(m_log_mel_temp_storage_size),
             log_mels, m_stream);
 
-        whisper_mel ret;
-        ret.n_mel = m_n_mel;
-        ret.n_len = int(n_mag_frames);
-        // Calculate semi-padded sample length to ensure compatibility
-        ret.n_len_org = 1 + int(samples.len + mirror_pad - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
-        ret.data.resize(m_n_mel * n_mag_frames);
-        CUDA_CHECK(cudaMemcpyAsync(ret.data.data(), log_mels, ret.data.size() * sizeof(float), cudaMemcpyDeviceToHost, m_stream));
-
         CUDA_CHECK(cudaStreamSynchronize(m_stream));
 
         // cleanup
         CUFFT_CHECK(cufftDestroy(plan));
-        CUDA_CHECK(cudaFreeAsync(log_mels, m_stream));
         CUDA_CHECK(cudaFreeAsync(mel_data, m_stream));
         CUDA_CHECK(cudaFreeAsync(magnitudes, m_stream));
         CUDA_CHECK(cudaFreeAsync(stft_out, m_stream));
index bc48475feecf6b5c8d1d51bc2cc03b8aa30d3c56..e52b804d9bcbd562381a091983e8c6492fc40856 100644 (file)
@@ -3,11 +3,23 @@
 #include <vector>
 
 struct whisper_mel {
-    int n_len;
-    int n_len_org;
-    int n_mel;
+    int n_len_org = 0;
 
-    std::vector<float> data;
+    ggml_tensor * tensor = nullptr;
+    ggml_context * ctx = nullptr;
+    ggml_backend_buffer_t buffer = nullptr;
+
+    whisper_mel() = default;
+    ~whisper_mel();
+
+    whisper_mel(const whisper_mel &) = delete;
+    whisper_mel & operator=(const whisper_mel &) = delete;
+    whisper_mel(whisper_mel &&) noexcept;
+    whisper_mel & operator=(whisper_mel &&) noexcept;
+
+    void init(ggml_backend_t backend, int n_len, int n_len_org, int n_mel);
+    void reset();
+    void take(whisper_mel & other) noexcept;
 };
 
 struct whisper_filters {
index 2dd2f591bd804bcd99d6cb01062c2ba7827dfd0a..dfbcc9d39a0ef26f07da4479135ccad9cbf0afb8 100644 (file)
@@ -821,7 +821,6 @@ struct whisper_state {
     struct ggml_tensor * embd_enc  = nullptr;
 
     // helpers for GPU offloading
-    std::vector<float> inp_mel;
     std::vector<float> inp_mask;
 
     // decode output (2-dimensional array: [n_tokens][n_vocab])
@@ -1815,7 +1814,8 @@ static bool whisper_encode_external(const whisper_state & wstate) {
 
 static struct ggml_cgraph * whisper_build_graph_conv(
         whisper_context & wctx,
-          whisper_state & wstate) {
+          whisper_state & wstate,
+              const int   mel_offset) {
     const auto & model   = wctx.model;
     const auto & hparams = model.hparams;
 
@@ -1834,9 +1834,32 @@ static struct ggml_cgraph * whisper_build_graph_conv(
 
     ggml_cgraph * gf = ggml_new_graph(ctx0);
 
-    struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
-    ggml_set_name(mel, "mel");
-    ggml_set_input(mel);
+    ggml_tensor * mel_inp = wstate.mel.tensor;
+    ggml_tensor * mel;
+    if (mel_inp) {
+        const int n_len = int(mel_inp->ne[0]);
+        const int out_s = 2 * n_ctx;
+        const int i0 = std::min(mel_offset, n_len);
+        const int i1 = std::min(mel_offset + out_s, n_len);
+        const int mel_s = i1 - i0;
+
+        assert(mel_inp->type == GGML_TYPE_F32);
+        assert(mel_inp->ne[1] == n_mels);
+
+        ggml_tensor * cur = ggml_view_2d(ctx0, mel_inp, out_s, n_mels, mel_inp->nb[1], ggml_row_size(mel_inp->type, i0));
+
+        if (mel_s < out_s) {
+            mel = ggml_pad(ctx0, cur, out_s - mel_s, 0, 0, 0);
+        }
+        else {
+            mel = ggml_cont(ctx0, cur);
+        }
+    }
+    else {
+        // just create some tensor so that the graph/buffer size estimation is correct
+        mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2 * n_ctx, n_mels);
+    }
+    ggml_set_name(mel, "mel"); // used with external encoding
 
     struct ggml_tensor * cur = nullptr;
 
@@ -2218,45 +2241,21 @@ static bool whisper_encode_internal(
     {
         auto & alloc = wstate.alloc_conv.alloc;
 
-        ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate);
+        ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate, mel_offset);
 
         if (!ggml_gallocr_alloc_graph(alloc, gf)) {
             // should never happen as we pre-allocate the memory
             return false;
         }
 
-        struct ggml_tensor * mel = ggml_graph_get_tensor(gf, "mel");
-
-        // set the input
-        {
-            const auto & mel_inp = wstate.mel;
-            const int n_ctx      = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : wctx.model.hparams.n_audio_ctx;
-
-            assert(mel->type == GGML_TYPE_F32);
-            assert(mel_inp.n_mel == wctx.model.hparams.n_mels);
-
-            wstate.inp_mel.resize(ggml_nelements(mel));
-
-            float * dst = wstate.inp_mel.data();
-            memset(dst, 0, ggml_nbytes(mel));
-
-            const int i0 = std::min(mel_offset,           mel_inp.n_len);
-            const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
-
-            for (int j = 0; j < mel_inp.n_mel; ++j) {
-                for (int i = i0; i < i1; ++i) {
-                    dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
-                }
-            }
-
-            ggml_backend_tensor_set(mel, wstate.inp_mel.data(), 0, ggml_nelements(mel)*sizeof(float));
+        if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
+            return false;
         }
 
-        if (!whisper_encode_external(wstate)) {
-            if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
-                return false;
-            }
-        } else {
+        if (whisper_encode_external(wstate)) {
+            ggml_tensor * mel = ggml_graph_get_tensor(gf, "mel");
+            assert(mel->ne[1] == wctx.model.hparams.n_mels);
+            GGML_UNUSED(mel);
 #if defined(WHISPER_USE_COREML)
             whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) wstate.embd_enc->data);
 #elif defined(WHISPER_USE_OPENVINO)
@@ -2886,6 +2885,54 @@ struct whisper_global_cache {
 
 // Mel spectrogram
 
+whisper_mel::~whisper_mel() {
+    reset();
+}
+
+whisper_mel::whisper_mel(whisper_mel && other) noexcept {
+    take(other);
+}
+
+whisper_mel & whisper_mel::operator=(whisper_mel && other) noexcept {
+    if (this != &other) {
+        reset();
+        take(other);
+    }
+    return *this;
+}
+
+void whisper_mel::init(ggml_backend_t backend, int n_len, int n_len_org, int n_mel) {
+    this->n_len_org = n_len_org;
+    assert(!ctx);
+    ctx = ggml_init({ggml_tensor_overhead(), nullptr, true});
+    tensor = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_len, n_mel);
+    buffer = ggml_backend_alloc_buffer(backend, ggml_nbytes(tensor) + ggml_backend_get_alignment(backend));
+    auto alloc = ggml_tallocr_new(buffer);
+    ggml_tallocr_alloc(&alloc, tensor);
+}
+
+void whisper_mel::reset() {
+    ggml_free(ctx);
+    ggml_backend_buffer_free(buffer);
+
+    n_len_org = 0;
+    tensor = nullptr;
+    ctx = nullptr;
+    buffer = nullptr;
+}
+
+void whisper_mel::take(whisper_mel & other) noexcept {
+    n_len_org = other.n_len_org;
+    tensor = other.tensor;
+    ctx = other.ctx;
+    buffer = other.buffer;
+
+    other.n_len_org = 0;
+    other.tensor = nullptr;
+    other.ctx = nullptr;
+    other.buffer = nullptr;
+}
+
 whisper_mel_calc::~whisper_mel_calc() = default; // export vtable
 
 whisper_span<const float> whisper_mel_calc::hann_window() {
@@ -2973,9 +3020,18 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
     }
 }
 
-static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
+namespace {
+
+struct whisper_mel_data {
+    int n_len;
+    int n_len_org;
+    int n_mel;
+    float* data;
+};
+
+void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
                                               int n_samples, int n_threads,
-                                              const whisper_filters & filters, whisper_mel & mel) {
+                                              const whisper_filters & filters, whisper_mel_data & mel) {
     const auto frame_size = WHISPER_N_FFT;
     const auto frame_step = WHISPER_HOP_LENGTH;
     std::vector<float> fft_in(frame_size, 0.0);
@@ -3041,10 +3097,11 @@ static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const
         }
     }
 }
-namespace {
+
 struct mel_calc_cpu : public whisper_mel_calc {
+    ggml_backend_t m_backend;
     const whisper_filters& m_filters;
-    mel_calc_cpu(const whisper_filters & filters) : m_filters(filters) {}
+    mel_calc_cpu(ggml_backend_t backend, const whisper_filters & filters) : m_backend(backend), m_filters(filters) {}
 
     // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157
     whisper_mel calculate(whisper_span<const float> ssamples, int n_threads) const override {
@@ -3069,15 +3126,24 @@ struct mel_calc_cpu : public whisper_mel_calc {
         // reflective pad 200 samples at the beginning of audio
         std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin());
 
-        whisper_mel mel;
+        whisper_mel_data mel;
         mel.n_mel     = m_filters.n_mel;
         // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936
         // Calculate number of frames + remove the last frame
         mel.n_len     = (samples_padded.size() - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
         // Calculate semi-padded sample length to ensure compatibility
         mel.n_len_org = 1 + (n_samples + stage_2_pad - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
-        mel.data.resize(mel.n_mel * mel.n_len);
 
+        std::vector<float> host_mel_data;
+
+        whisper_mel ret;
+        ret.init(m_backend, mel.n_len, mel.n_len_org, mel.n_mel);
+        if (ggml_backend_buffer_is_host(ret.buffer)) {
+            mel.data = reinterpret_cast<float*>(ret.tensor->data);
+        } else {
+            host_mel_data.resize(mel.n_len * mel.n_mel);
+            mel.data = host_mel_data.data();
+        }
 
         {
             std::vector<std::thread> workers(n_threads - 1);
@@ -3114,7 +3180,12 @@ struct mel_calc_cpu : public whisper_mel_calc {
             mel.data[i] = (mel.data[i] + 4.0)/4.0;
         }
 
-        return mel;
+        if (!host_mel_data.empty()) {
+            // the ret buffer is not host-accessible so we used this temporary buffer and now we need to upload it
+            ggml_backend_tensor_set(ret.tensor, host_mel_data.data(), 0, ggml_nbytes(ret.tensor));
+        }
+
+        return ret;
     }
 };
 }
@@ -3129,7 +3200,7 @@ whisper_mel_calc * whisper_mel_calc_create(ggml_backend_t backend, const whisper
         return ret;
     } else
 #endif
-        return new mel_calc_cpu(filters);
+        return new mel_calc_cpu(backend, filters);
 }
 
 // split text into tokens
@@ -3347,7 +3418,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
     {
         bool ok = whisper_allocr_graph_init(state->alloc_conv, ctx->backend,
                 [&]() {
-                    return whisper_build_graph_conv(*ctx, *state);
+                    return whisper_build_graph_conv(*ctx, *state, 0);
                 });
 
         if (!ok) {
@@ -3763,12 +3834,9 @@ int whisper_set_mel_with_state(
         return -1;
     }
 
-    state->mel.n_len     = n_len;
-    state->mel.n_len_org = n_len;
-    state->mel.n_mel     = n_mel;
-
-    state->mel.data.resize(n_len*n_mel);
-    memcpy(state->mel.data.data(), data, n_len*n_mel*sizeof(float));
+    state->mel.reset();
+    state->mel.init(ctx->backend, n_len, n_len, n_mel);
+    ggml_backend_tensor_set(state->mel.tensor, data, 0, ggml_nbytes(state->mel.tensor));
 
     return 0;
 }