]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : fix gpu device selection (#2728)
authorGeorgi Gerganov <redacted>
Mon, 13 Jan 2025 11:11:37 +0000 (13:11 +0200)
committerGitHub <redacted>
Mon, 13 Jan 2025 11:11:37 +0000 (13:11 +0200)
src/whisper.cpp

index f90d3c1ae8776edb7b896270c7b086fa0a6a1bb6..11077d5b6870a3d10b34a84ffbd0ab5fdb86707e 100644 (file)
@@ -1235,21 +1235,36 @@ static size_t aheads_masks_nbytes(struct whisper_aheads_masks & aheads_masks) {
 static ggml_backend_t whisper_backend_init_gpu(const whisper_context_params & params) {
     ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);
 
+    ggml_backend_dev_t dev = nullptr;
+
+    int cnt = 0;
     if (params.use_gpu) {
         for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
-            ggml_backend_dev_t dev = ggml_backend_dev_get(i);
-            if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
-                WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
-                ggml_backend_t result = ggml_backend_dev_init(dev, nullptr);
-                if (!result) {
-                    WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
+            ggml_backend_dev_t dev_cur = ggml_backend_dev_get(i);
+            if (ggml_backend_dev_type(dev_cur) == GGML_BACKEND_DEVICE_TYPE_GPU) {
+                if (cnt == 0 || cnt == params.gpu_device) {
+                    dev = dev_cur;
+                }
+
+                if (++cnt > params.gpu_device) {
+                    break;
                 }
-                return result;
             }
         }
     }
 
-    return nullptr;
+    if (dev == nullptr) {
+        WHISPER_LOG_INFO("%s: no GPU found\n", __func__);
+        return nullptr;
+    }
+
+    WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
+    ggml_backend_t result = ggml_backend_dev_init(dev, nullptr);
+    if (!result) {
+        WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
+    }
+
+    return result;
 }
 
 static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_params & params) {
@@ -1283,20 +1298,27 @@ static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_pa
 }
 
 static ggml_backend_buffer_type_t whisper_default_buffer_type(const whisper_context_params & params) {
+    ggml_backend_buffer_type_t result = ggml_backend_cpu_buffer_type();
+
     if (!params.use_gpu) {
-        return ggml_backend_cpu_buffer_type();
+        return result;
     }
 
-    // if we have a GPU device - use it
+    int cnt = 0;
     for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
         ggml_backend_dev_t dev = ggml_backend_dev_get(i);
         if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
-            WHISPER_LOG_INFO("%s: using device %s (%s)\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev));
-            return ggml_backend_dev_buffer_type(dev);
+            if (cnt == 0 || cnt == params.gpu_device) {
+                result = ggml_backend_dev_buffer_type(dev);
+            }
+
+            if (++cnt > params.gpu_device) {
+                break;
+            }
         }
     }
 
-    return ggml_backend_cpu_buffer_type();
+    return result;
 }
 
 // load the model from a ggml file