#include "whisper-mel-cuda.hpp"
#include "whisper.h"
+#include <ggml-cuda/common.cuh>
+#include <ggml-backend-impl.h>
+
#include <cuda.h>
#include <cuda_runtime.h>
#include <cufft.h>
#if defined(_MSC_VER)
#pragma warning(disable: 4324) // added padding
#endif
-#ifndef NDEBUG
-# define DO_CHECKS 1
-#else
-# define DO_CHECKS 0
-#endif
-
namespace {
-#if DO_CHECKS
-const char* cufftGetErrorString(cufftResult_t res) {
+static const char * cufftGetErrorString(cufftResult_t res) {
switch (res) {
case CUFFT_SUCCESS: return "The cuFFT operation was successful";
case CUFFT_INVALID_PLAN: return "cuFFT was passed an invalid plan handle";
default: return "Unknown cuFFT error";
}
}
-# define CUDA_CHECK_GEN(err, success, error_fn) \
- do { \
- auto err_ = (err); \
- if (err_ != (success)) { \
- fprintf(stderr, "%s %s:%d - %s\n", #err, __FILE__, __LINE__, error_fn(err_)); \
- } \
- } while (0)
-#else
-# define CUDA_CHECK_GEN(err, success, error_fn) err
-#endif
-
-#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
-#define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublasGetStatusString)
#define CUFFT_CHECK(err) CUDA_CHECK_GEN(err, CUFFT_SUCCESS, cufftGetErrorString)
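+// CUDA_CHECK_GEN is now provided by <ggml-cuda/common.cuh>; a typical call
+// site (with a hypothetical plan and device buffers) looks like:
+//   CUFFT_CHECK(cufftExecR2C(plan, d_samples, d_stft));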
__global__ void k_fill_stft_input(
}
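+// collapses the complex STFT output into real per-bin magnitude values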
__global__ void k_calc_magnitudes(
- const cuComplex* stft_out,
+ const cuComplex * stft_out,
const int n_frames,
float * magnitudes
) {
}
void calc_magnitudes(
- const cuComplex* stft_out,
+ const cuComplex * stft_out,
int n_frames,
float * magnitudes,
cudaStream_t stream
const int m_n_mel;
ggml_backend_t m_backend = nullptr;
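+    // remember the backend's CUDA device so the destructor and calculate()
+    // can re-select it on multi-GPU systems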
+ int m_device = -1;
cudaStream_t m_stream = nullptr;
cublasHandle_t m_cublas_handle = nullptr;
: m_n_mel(filters.n_mel)
, m_backend(backend)
{
+    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)m_backend->context;
+ m_device = cuda_ctx->device;
+
+ if (ggml_cuda_info().devices[m_device].cc < 600) {
+        // we've only tested on 6.0 and higher and we've had reports of crashes on 5.0:
+ // https://github.com/ggerganov/whisper.cpp/issues/2230
+        // to be safe, forbid anything below 6.0
+ throw std::runtime_error("CUDA compute capability 6.0 or higher is required");
+ }
+
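+    // make sure the stream, cuBLAS handle and working buffers created below
+    // all live on the backend's device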
+ ggml_cuda_set_device(m_device);
+
if (filters.n_fft != WHISPER_N_FFT_HALF) {
throw std::invalid_argument("MelFilters n_fft must be WHISPER_N_FFT_HALF");
}
}
~mel_calc_cuda() {
+ ggml_cuda_set_device(m_device);
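+    // drain any pending work before the stream is destroyed and the device
+    // buffers it uses are freed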
CUDA_CHECK(cudaStreamSynchronize(m_stream));
CUDA_CHECK(cudaStreamDestroy(m_stream));
CUDA_CHECK(cudaFree(m_hann_window));
}
virtual whisper_mel calculate(whisper_span<const float> samples, int /*n_threads*/) override {
+ ggml_cuda_set_device(m_device);
ensure_working_areas(samples.len);
const size_t mirror_pad = WHISPER_N_FFT / 2;
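+    // the input is mirrored around its edges by half the FFT size so that
+    // every FFT frame is centered on its sample (reflect padding)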
}
whisper_mel_calc * whisper_mel_calc_create_cuda(ggml_backend_t backend, const whisper_filters & filters) {
- if (filters.n_fft != WHISPER_N_FFT_HALF) {
+ try {
+ return new mel_calc_cuda(backend, filters);
+ }
+ catch (...) {
+        // TODO: log the error (this would require exposing the logging state here)
return nullptr;
}
- return new mel_calc_cuda(backend, filters);
}
#if defined(GGML_USE_CUDA) && !defined(GGML_USE_HIPBLAS)
if (ggml_backend_is_cuda(backend)) {
auto ret = whisper_mel_calc_create_cuda(backend, filters);
- // run a warmup to avoid the first kernel launch overhead (thus we get the best perf even on the first run)
- const float warmup[256] = {0};
- ret->calculate({warmup, 256}, 1);
- return ret;
- } else
+ if (ret) {
+ // run a warmup to avoid the first kernel launch overhead (thus we get the best perf even on the first run)
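+        // (the returned mel is discarded; only the initialization side effect matters)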
+ const float warmup[256] = { 0 };
+ ret->calculate({ warmup, 256 }, 1);
+ return ret;
+ }
+ }
#endif
- return new mel_calc_cpu(backend, filters);
+
+ // a specialized mel_calc could not be created
+ // fall back to CPU
+ return new mel_calc_cpu(backend, filters);
}
// split text into tokens