static ggml_backend_t whisper_backend_init_gpu(const whisper_context_params & params) {
ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);
+ ggml_backend_dev_t dev = nullptr;
+
+ int cnt = 0;
if (params.use_gpu) {
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
- ggml_backend_dev_t dev = ggml_backend_dev_get(i);
- if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
- WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
- ggml_backend_t result = ggml_backend_dev_init(dev, nullptr);
- if (!result) {
- WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
+ ggml_backend_dev_t dev_cur = ggml_backend_dev_get(i);
+ if (ggml_backend_dev_type(dev_cur) == GGML_BACKEND_DEVICE_TYPE_GPU) {
+ if (cnt == 0 || cnt == params.gpu_device) {
+ dev = dev_cur;
+ }
+
+ if (++cnt > params.gpu_device) {
+ break;
}
- return result;
}
}
}
- return nullptr;
+ if (dev == nullptr) {
+ WHISPER_LOG_INFO("%s: no GPU found\n", __func__);
+ return nullptr;
+ }
+
+ WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
+ ggml_backend_t result = ggml_backend_dev_init(dev, nullptr);
+ if (!result) {
+ WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
+ }
+
+ return result;
}
static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_params & params) {
}
static ggml_backend_buffer_type_t whisper_default_buffer_type(const whisper_context_params & params) {
+ ggml_backend_buffer_type_t result = ggml_backend_cpu_buffer_type();
+
if (!params.use_gpu) {
- return ggml_backend_cpu_buffer_type();
+ return result;
}
- // if we have a GPU device - use it
+ int cnt = 0;
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
- WHISPER_LOG_INFO("%s: using device %s (%s)\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev));
- return ggml_backend_dev_buffer_type(dev);
+ if (cnt == 0 || cnt == params.gpu_device) {
+ result = ggml_backend_dev_buffer_type(dev);
+ }
+
+ if (++cnt > params.gpu_device) {
+ break;
+ }
}
}
- return ggml_backend_cpu_buffer_type();
+ return result;
}
// load the model from a ggml file