]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : expose CUDA device setting in public API (#1840)
authorDidzis Gosko <redacted>
Fri, 9 Feb 2024 15:27:47 +0000 (17:27 +0200)
committerGitHub <redacted>
Fri, 9 Feb 2024 15:27:47 +0000 (17:27 +0200)
* Makefile : allow to override CUDA_ARCH_FLAG

* whisper : allow to select GPU (CUDA) device from public API

Makefile
whisper.cpp
whisper.h

index 284b0f1e9eae6336e4c701dc886809ed3edd633d..4a676f1ff6b330dbb472bf95ce02bc7060b4f417 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -215,9 +215,9 @@ endif
 
 ifdef WHISPER_CUBLAS
        ifeq ($(shell expr $(NVCC_VERSION) \>= 11.6), 1)
-               CUDA_ARCH_FLAG=native
+               CUDA_ARCH_FLAG ?= native
        else
-               CUDA_ARCH_FLAG=all
+               CUDA_ARCH_FLAG ?= all
        endif
 
        CFLAGS      += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
index ba867b09bd07c45f4a0191ee45dd2f1857e3b2d1..59d5cff1df51393bbd78f1e750989b599e0cb28f 100644 (file)
@@ -1060,7 +1060,7 @@ static ggml_backend_t whisper_backend_init(const whisper_context_params & params
 #ifdef GGML_USE_CUBLAS
     if (params.use_gpu && ggml_cublas_loaded()) {
         WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
-        backend_gpu = ggml_backend_cuda_init(0);
+        backend_gpu = ggml_backend_cuda_init(params.gpu_device);
         if (!backend_gpu) {
             WHISPER_LOG_ERROR("%s: ggml_backend_cuda_init() failed\n", __func__);
         }
@@ -3213,6 +3213,7 @@ int whisper_ctx_init_openvino_encoder(
 struct whisper_context_params whisper_context_default_params() {
     struct whisper_context_params result = {
         /*.use_gpu    =*/ true,
+        /*.gpu_device =*/ 0,
     };
     return result;
 }
index 3143ceaaf180b65dcd4d892c43b130f1f309c8a3..d571a125db3c48d208c6e1c33f29f6b8367cfbab 100644 (file)
--- a/whisper.h
+++ b/whisper.h
@@ -86,6 +86,7 @@ extern "C" {
 
     struct whisper_context_params {
         bool  use_gpu;
+        int   gpu_device;  // CUDA device
     };
 
     typedef struct whisper_token_data {