]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : whisper_state/backend fixes (#2217)
authorGeorgi Gerganov <redacted>
Thu, 6 Jun 2024 15:51:36 +0000 (18:51 +0300)
committerGitHub <redacted>
Thu, 6 Jun 2024 15:51:36 +0000 (18:51 +0300)
* whisper : fixes

* ci : WHISPER_CUBLAS -> WHISPER_CUDA

.github/workflows/build.yml
whisper-mel-cuda.cu
whisper-mel.hpp
whisper.cpp

index e9bf9c2829231f5372c1ced684a09bc2540872d5..2095e70d175d29968984ab7c596e6dd6a7c5caff 100644 (file)
@@ -498,7 +498,7 @@ jobs:
         run: >
           cmake -S . -B ./build -A ${{ matrix.arch }}
           -DCMAKE_BUILD_TYPE=${{ matrix.build }}
-          -DWHISPER_CUBLAS=${{ matrix.cublas }}
+          -DWHISPER_CUDA=${{ matrix.cublas }}
           -DWHISPER_SDL2=${{ matrix.sdl2 }}
 
       - name: Build ${{ matrix.cuda-toolkit }}
index 3f3e3158d3e26f4019bba679eec26197fcb1c621..9a6f1093f8ef18c45b7059866913bf25b37c040d 100644 (file)
@@ -194,7 +194,7 @@ class mel_calc_cuda : public whisper_mel_calc {
     size_t m_log_mel_temp_storage_size = 0;
     void * m_log_mel_temp_storage = nullptr;
 public:
-    mel_calc_cuda(ggml_backend_t backend, const whisper_filters& filters)
+    mel_calc_cuda(ggml_backend_t backend, const whisper_filters & filters)
         : m_n_mel(filters.n_mel)
         , m_backend(backend)
     {
@@ -305,7 +305,7 @@ public:
         whisper_mel ret;
         // Calculate semi-padded sample length to ensure compatibility
         int n_len_org = 1 + int(samples.len + mirror_pad - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
-        ret.init(m_backend, int(n_mag_frames), n_len_org, m_n_mel);
+        whisper_mel_init(ret, m_backend, int(n_mag_frames), n_len_org, m_n_mel);
         assert(ggml_nbytes(ret.tensor) == m_n_mel * n_mag_frames * sizeof(float));
 
         float* log_mels = reinterpret_cast<float*>(ret.tensor->data);
index e52b804d9bcbd562381a091983e8c6492fc40856..1a54a23c7307e91a9c840e3fe75ebb4f77f1b0c7 100644 (file)
@@ -5,22 +5,14 @@
 struct whisper_mel {
     int n_len_org = 0;
 
-    ggml_tensor * tensor = nullptr;
     ggml_context * ctx = nullptr;
+    ggml_tensor * tensor = nullptr;
     ggml_backend_buffer_t buffer = nullptr;
+};
 
-    whisper_mel() = default;
-    ~whisper_mel();
-
-    whisper_mel(const whisper_mel &) = delete;
-    whisper_mel & operator=(const whisper_mel &) = delete;
-    whisper_mel(whisper_mel &&) noexcept;
-    whisper_mel & operator=(whisper_mel &&) noexcept;
+void whisper_mel_init(whisper_mel & mel, ggml_backend_t backend, int n_len, int n_len_org, int n_mel);
 
-    void init(ggml_backend_t backend, int n_len, int n_len_org, int n_mel);
-    void reset();
-    void take(whisper_mel & other) noexcept;
-};
+void whisper_mel_free(whisper_mel & mel);
 
 struct whisper_filters {
     int32_t n_mel;
@@ -40,6 +32,3 @@ struct whisper_mel_calc {
     virtual whisper_mel calculate(whisper_span<const float> samples, int n_threads) const = 0;
     static whisper_span<const float> hann_window();
 };
-
-// returns a new pointer which needs to be freed with delete
-whisper_mel_calc * whisper_mel_calc_create(ggml_backend_t backend, const whisper_filters & filters);
index dfbcc9d39a0ef26f07da4479135ccad9cbf0afb8..e8a1320898bb7a4ac0928afb67d4b06552871180 100644 (file)
@@ -801,6 +801,7 @@ struct whisper_state {
     whisper_kv_cache kv_pad;
 
     whisper_mel mel;
+    whisper_mel_calc * mel_calc = nullptr;
 
     whisper_batch batch;
 
@@ -870,8 +871,6 @@ struct whisper_context {
     whisper_model model;
     whisper_vocab vocab;
 
-    whisper_mel_calc * mel_calc = nullptr;
-
     whisper_state * state = nullptr;
 
     ggml_backend_t backend = nullptr;
@@ -893,7 +892,7 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
     BYTESWAP_VALUE(dest);
 }
 
-static bool kv_cache_init(
+static bool whisper_kv_cache_init(
              struct whisper_kv_cache & cache,
                       ggml_backend_t   backend,
                            ggml_type   wtype,
@@ -936,7 +935,7 @@ static bool kv_cache_init(
     return true;
 }
 
-static void kv_cache_free(struct whisper_kv_cache & cache) {
+static void whisper_kv_cache_free(struct whisper_kv_cache & cache) {
     ggml_free(cache.ctx);
     ggml_backend_buffer_free(cache.buffer);
     cache.ctx = nullptr;
@@ -1250,9 +1249,12 @@ static ggml_backend_t whisper_backend_init(const whisper_context_params & params
     }
 #endif
 
+    GGML_UNUSED(params);
+
     if (backend_gpu) {
         return backend_gpu;
     }
+
     return ggml_backend_cpu_init();
 }
 
@@ -2885,52 +2887,25 @@ struct whisper_global_cache {
 
 // Mel spectrogram
 
-whisper_mel::~whisper_mel() {
-    reset();
-}
-
-whisper_mel::whisper_mel(whisper_mel && other) noexcept {
-    take(other);
-}
-
-whisper_mel & whisper_mel::operator=(whisper_mel && other) noexcept {
-    if (this != &other) {
-        reset();
-        take(other);
-    }
-    return *this;
-}
-
-void whisper_mel::init(ggml_backend_t backend, int n_len, int n_len_org, int n_mel) {
-    this->n_len_org = n_len_org;
-    assert(!ctx);
-    ctx = ggml_init({ggml_tensor_overhead(), nullptr, true});
-    tensor = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_len, n_mel);
-    buffer = ggml_backend_alloc_buffer(backend, ggml_nbytes(tensor) + ggml_backend_get_alignment(backend));
-    auto alloc = ggml_tallocr_new(buffer);
-    ggml_tallocr_alloc(&alloc, tensor);
-}
-
-void whisper_mel::reset() {
-    ggml_free(ctx);
-    ggml_backend_buffer_free(buffer);
-
-    n_len_org = 0;
-    tensor = nullptr;
-    ctx = nullptr;
-    buffer = nullptr;
+void whisper_mel_init(whisper_mel & mel, ggml_backend_t backend, int n_len, int n_len_org, int n_mel) {
+    WHISPER_LOG_INFO("%s: n_len = %d, n_len_org = %d, n_mel = %d\n", __func__, n_len, n_len_org, n_mel);
+    mel.n_len_org = n_len_org;
+    assert(!mel.ctx);
+    mel.ctx = ggml_init({ggml_tensor_overhead(), nullptr, true});
+    mel.tensor = ggml_new_tensor_2d(mel.ctx, GGML_TYPE_F32, n_len, n_mel);
+    mel.buffer = ggml_backend_alloc_buffer(backend, ggml_nbytes(mel.tensor) + ggml_backend_get_alignment(backend));
+    auto alloc = ggml_tallocr_new(mel.buffer);
+    ggml_tallocr_alloc(&alloc, mel.tensor);
 }
 
-void whisper_mel::take(whisper_mel & other) noexcept {
-    n_len_org = other.n_len_org;
-    tensor = other.tensor;
-    ctx = other.ctx;
-    buffer = other.buffer;
+void whisper_mel_free(whisper_mel & mel) {
+    ggml_free(mel.ctx);
+    ggml_backend_buffer_free(mel.buffer);
 
-    other.n_len_org = 0;
-    other.tensor = nullptr;
-    other.ctx = nullptr;
-    other.buffer = nullptr;
+    mel.n_len_org = 0;
+    mel.ctx = nullptr;
+    mel.tensor = nullptr;
+    mel.buffer = nullptr;
 }
 
 whisper_mel_calc::~whisper_mel_calc() = default; // export vtable
@@ -3026,7 +3001,7 @@ struct whisper_mel_data {
     int n_len;
     int n_len_org;
     int n_mel;
-    float* data;
+    float * data;
 };
 
 void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
@@ -3100,7 +3075,7 @@ void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::v
 
 struct mel_calc_cpu : public whisper_mel_calc {
     ggml_backend_t m_backend;
-    const whisper_filters& m_filters;
+    const whisper_filters & m_filters;
     mel_calc_cpu(ggml_backend_t backend, const whisper_filters & filters) : m_backend(backend), m_filters(filters) {}
 
     // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157
@@ -3137,7 +3112,7 @@ struct mel_calc_cpu : public whisper_mel_calc {
         std::vector<float> host_mel_data;
 
         whisper_mel ret;
-        ret.init(m_backend, mel.n_len, mel.n_len_org, mel.n_mel);
+        whisper_mel_init(ret, m_backend, mel.n_len, mel.n_len_org, mel.n_mel);
         if (ggml_backend_buffer_is_host(ret.buffer)) {
             mel.data = reinterpret_cast<float*>(ret.tensor->data);
         } else {
@@ -3325,15 +3300,17 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         return nullptr;
     }
 
+    state->mel_calc = whisper_mel_calc_create(state->backend, ctx->model.filters);
+
     // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
     // in theory, there can be a case where this is not enough, but in practice it should always be enough
     const int factor = 3;
 
-    if (!kv_cache_init(state->kv_self, ctx->backend, ctx->itype,
+    if (!whisper_kv_cache_init(state->kv_self, state->backend, ctx->itype,
                 ctx->model.hparams.n_text_state,
                 ctx->model.hparams.n_text_layer,
                 GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) {
-        WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
+        WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
         whisper_free_state(state);
         return nullptr;
     }
@@ -3343,11 +3320,11 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         WHISPER_LOG_INFO("%s: kv self size  = %7.2f MB\n", __func__, memory_size / 1e6);
     }
 
-    if (!kv_cache_init(state->kv_cross, ctx->backend, ctx->itype,
+    if (!whisper_kv_cache_init(state->kv_cross, state->backend, ctx->itype,
                 ctx->model.hparams.n_text_state,
                 ctx->model.hparams.n_text_layer,
                 GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
-        WHISPER_LOG_ERROR("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
+        WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for cross-attention cache\n", __func__);
         whisper_free_state(state);
         return nullptr;
     }
@@ -3357,11 +3334,11 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
     }
 
-    if (!kv_cache_init(state->kv_pad, ctx->backend, ctx->itype,
+    if (!whisper_kv_cache_init(state->kv_pad, state->backend, ctx->itype,
                 ctx->model.hparams.n_audio_state,
                 1,
                 GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
-        WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
+        WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
         whisper_free_state(state);
         return nullptr;
     }
@@ -3373,7 +3350,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     // [EXPERIMENTAL] Token-level timestamps with DTW
     if (ctx->params.dtw_token_timestamps) {
-        if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, ctx->backend)) {
+        if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, state->backend)) {
             WHISPER_LOG_ERROR("%s: aheads_masks_init() failed for alignment heads masks\n", __func__);
             whisper_free_state(state);
             return nullptr;
@@ -3416,7 +3393,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     // conv allocator
     {
-        bool ok = whisper_allocr_graph_init(state->alloc_conv, ctx->backend,
+        bool ok = whisper_allocr_graph_init(state->alloc_conv, state->backend,
                 [&]() {
                     return whisper_build_graph_conv(*ctx, *state, 0);
                 });
@@ -3432,7 +3409,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     // encoder allocator
     if (!whisper_encode_external(*state)) {
-        bool ok = whisper_allocr_graph_init(state->alloc_encode, ctx->backend,
+        bool ok = whisper_allocr_graph_init(state->alloc_encode, state->backend,
                 [&]() {
                     return whisper_build_graph_encoder(*ctx, *state);
                 });
@@ -3448,7 +3425,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     // cross allocator
     {
-        bool ok = whisper_allocr_graph_init(state->alloc_cross, ctx->backend,
+        bool ok = whisper_allocr_graph_init(state->alloc_cross, state->backend,
                 [&]() {
                     return whisper_build_graph_cross(*ctx, *state);
                 });
@@ -3464,7 +3441,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     // decoder allocator
     {
-        bool ok = whisper_allocr_graph_init(state->alloc_decode, ctx->backend,
+        bool ok = whisper_allocr_graph_init(state->alloc_decode, state->backend,
                 [&]() {
                     const auto & hparams = ctx->model.hparams;
 
@@ -3660,8 +3637,6 @@ struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_
         return nullptr;
     }
 
-    ctx->mel_calc = whisper_mel_calc_create(ctx->backend, ctx->model.filters);
-
     loader->close(loader->context);
 
     return ctx;
@@ -3738,9 +3713,14 @@ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loa
 
 void whisper_free_state(struct whisper_state * state) {
     if (state) {
-        kv_cache_free(state->kv_self);
-        kv_cache_free(state->kv_cross);
-        kv_cache_free(state->kv_pad);
+        whisper_kv_cache_free(state->kv_self);
+        whisper_kv_cache_free(state->kv_cross);
+        whisper_kv_cache_free(state->kv_pad);
+
+        whisper_mel_free(state->mel);
+
+        delete state->mel_calc;
+        state->mel_calc = nullptr;
 
 #ifdef WHISPER_USE_COREML
         if (state->ctx_coreml != nullptr) {
@@ -3782,8 +3762,6 @@ void whisper_free(struct whisper_context * ctx) {
 
         ggml_backend_free(ctx->backend);
 
-        delete ctx->mel_calc;
-        ctx->mel_calc = nullptr;
         delete ctx;
     }
 }
@@ -3800,9 +3778,11 @@ void whisper_free_params(struct whisper_full_params * params) {
     }
 }
 
-int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
+int whisper_pcm_to_mel_with_state(struct whisper_context * /*ctx*/, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
     const int64_t t_start_us = ggml_time_us();
-    state->mel = ctx->mel_calc->calculate({samples, n_samples}, n_threads);
+
+    state->mel = state->mel_calc->calculate({samples, n_samples}, n_threads);
+
     state->t_mel_us += ggml_time_us() - t_start_us;
 
     // Dump log_mel_spectrogram
@@ -3834,8 +3814,9 @@ int whisper_set_mel_with_state(
         return -1;
     }
 
-    state->mel.reset();
-    state->mel.init(ctx->backend, n_len, n_len, n_mel);
+    whisper_mel_free(state->mel);
+    whisper_mel_init(state->mel, ctx->backend, n_len, n_len, n_mel);
+
     ggml_backend_tensor_set(state->mel.tensor, data, 0, ggml_nbytes(state->mel.tensor));
 
     return 0;