]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : use ggml-cuda in mel calc, set appropriate device (#2236)
authorBorislav Stanimirov <redacted>
Thu, 13 Jun 2024 10:16:07 +0000 (13:16 +0300)
committerGitHub <redacted>
Thu, 13 Jun 2024 10:16:07 +0000 (13:16 +0300)
* whisper : use ggml-cuda in mel calc, set appropriate device

* whisper : forbid cuda mel calc on devices with compute < 600, workaround for #2230

whisper-mel-cuda.cu
whisper.cpp

index cc44556f4263f9ba429c4d849c6378dc4bacc491..8d6741427314908cc9a5bfd31f6607a531e0d0f4 100644 (file)
@@ -2,6 +2,9 @@
 #include "whisper-mel-cuda.hpp"
 #include "whisper.h"
 
+#include <ggml-cuda/common.cuh>
+#include <ggml-backend-impl.h>
+
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <cufft.h>
 #pragma warning(disable: 4324) // added padding
 #endif
 
-#ifndef NDEBUG
-#   define DO_CHECKS 1
-#else
-#   define DO_CHECKS 0
-#endif
-
 namespace {
 
-#if DO_CHECKS
-const char* cufftGetErrorString(cufftResult_t res) {
+static const char* cufftGetErrorString(cufftResult_t res) {
     switch (res) {
     case CUFFT_SUCCESS: return "The cuFFT operation was successful";
     case CUFFT_INVALID_PLAN: return "cuFFT was passed an invalid plan handle";
@@ -48,19 +44,6 @@ const char* cufftGetErrorString(cufftResult_t res) {
     }
 }
 
-#   define CUDA_CHECK_GEN(err, success, error_fn)                                            \
-         do {                                                                                \
-            auto err_ = (err);                                                               \
-            if (err_ != (success)) {                                                         \
-                fprintf(stderr, "%s %s:%d - %s\n", #err, __FILE__, __LINE__, error_fn(err_)); \
-            }                                                                                \
-        } while (0)
-#else
-#   define CUDA_CHECK_GEN(err, success, error_fn) err
-#endif
-
-#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
-#define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublasGetStatusString)
 #define CUFFT_CHECK(err) CUDA_CHECK_GEN(err, CUFFT_SUCCESS, cufftGetErrorString)
 
 __global__ void k_fill_stft_input(
@@ -81,7 +64,7 @@ __global__ void k_fill_stft_input(
 }
 
 __global__ void k_calc_magnitudes(
-    const cuComplex* stft_out,
+    const cuComplex * stft_out,
     const int n_frames,
     float * magnitudes
 ) {
@@ -133,7 +116,7 @@ void fill_stft_input(
 }
 
 void calc_magnitudes(
-    const cuComplex* stft_out,
+    const cuComplex * stft_out,
     int n_frames,
     float * magnitudes,
     cudaStream_t stream
@@ -169,6 +152,7 @@ class mel_calc_cuda : public whisper_mel_calc {
     const int m_n_mel;
 
     ggml_backend_t m_backend = nullptr;
+    int m_device = -1;
 
     cudaStream_t m_stream = nullptr;
     cublasHandle_t m_cublas_handle = nullptr;
@@ -190,6 +174,18 @@ public:
         : m_n_mel(filters.n_mel)
         , m_backend(backend)
     {
+        ggml_backend_cuda_context* cuda_ctx = (ggml_backend_cuda_context*)m_backend->context;
+        m_device = cuda_ctx->device;
+
+        if (ggml_cuda_info().devices[m_device].cc < 600) {
+            // we've only tesed on 6.0 and higher and we've had reports of crashes on 5.0:
+            // https://github.com/ggerganov/whisper.cpp/issues/2230
+            // to be safe forbid anything below 6.0
+            throw std::runtime_error("CUDA compute capability 6.0 or higher is required");
+        }
+
+        ggml_cuda_set_device(m_device);
+
         if (filters.n_fft != WHISPER_N_FFT_HALF) {
             throw std::invalid_argument("MelFilters n_frames must be WHISPER_N_FFT_HALF");
         }
@@ -219,6 +215,7 @@ public:
     }
 
     ~mel_calc_cuda() {
+        ggml_cuda_set_device(m_device);
         CUDA_CHECK(cudaStreamSynchronize(m_stream));
         CUDA_CHECK(cudaStreamDestroy(m_stream));
         CUDA_CHECK(cudaFree(m_hann_window));
@@ -268,6 +265,7 @@ public:
     }
 
     virtual whisper_mel calculate(whisper_span<const float> samples, int /*n_threads*/) override {
+        ggml_cuda_set_device(m_device);
         ensure_working_areas(samples.len);
 
         const size_t mirror_pad = WHISPER_N_FFT / 2;
@@ -356,8 +354,11 @@ public:
 }
 
 whisper_mel_calc * whisper_mel_calc_create_cuda(ggml_backend_t backend, const whisper_filters & filters) {
-    if (filters.n_fft != WHISPER_N_FFT_HALF) {
+    try {
+        return new mel_calc_cuda(backend, filters);
+    }
+    catch (...) {
+        // TODO: log error (but for this we would have to expose the log state to be accessible here)
         return nullptr;
     }
-    return new mel_calc_cuda(backend, filters);
 }
index a08f15ff1026476193144bfd4303c8e56635bc79..71eb06733ec06fb5d63065207152eaeca4b80be0 100644 (file)
@@ -3170,13 +3170,18 @@ whisper_mel_calc * whisper_mel_calc_create(ggml_backend_t backend, const whisper
 #if defined(GGML_USE_CUDA) && !defined(GGML_USE_HIPBLAS)
     if (ggml_backend_is_cuda(backend)) {
         auto ret = whisper_mel_calc_create_cuda(backend, filters);
-        // run a warmup to avoid the first kernel launch overhead (thus we get the best perf even on the first run)
-        const float warmup[256] = {0};
-        ret->calculate({warmup, 256}, 1);
-        return ret;
-    } else
+        if (ret) {
+            // run a warmup to avoid the first kernel launch overhead (thus we get the best perf even on the first run)
+            const float warmup[256] = { 0 };
+            ret->calculate({ warmup, 256 }, 1);
+            return ret;
+        }
+    }
 #endif
-        return new mel_calc_cpu(backend, filters);
+
+    // a specialized mel_calc could not be created
+    // fall back to CPU
+    return new mel_calc_cpu(backend, filters);
 }
 
 // split text into tokens