From: Ouadie EL FAROUKI Date: Wed, 7 Aug 2024 10:25:36 +0000 (+0100) Subject: Updated SYCL device filtering (llama/8901) X-Git-Tag: upstream/0.0.1642~460 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=02a0b27f60c3e2d4fcf3b06ef9c14cc9ae421161;p=pkg%2Fggml%2Fsources%2Fggml Updated SYCL device filtering (llama/8901) * Updated device filter to depend on default_selector (fixes non-intel device issues) * Small related update to example/sycl Readme --- diff --git a/src/ggml-sycl/dpct/helper.hpp b/src/ggml-sycl/dpct/helper.hpp index ef4609e3..fe4a8f74 100644 --- a/src/ggml-sycl/dpct/helper.hpp +++ b/src/ggml-sycl/dpct/helper.hpp @@ -874,7 +874,7 @@ namespace dpct inline std::string get_preferred_gpu_platform_name() { std::string result; - std::string filter = "level-zero"; + std::string filter = ""; char* env = getenv("ONEAPI_DEVICE_SELECTOR"); if (env) { if (std::strstr(env, "level_zero")) { @@ -892,11 +892,24 @@ namespace dpct else { throw std::runtime_error("invalid device filter: " + std::string(env)); } + } else { + auto default_device = sycl::device(sycl::default_selector_v); + auto default_platform_name = default_device.get_platform().get_info(); + + if (std::strstr(default_platform_name.c_str(), "Level-Zero") || default_device.is_cpu()) { + filter = "level-zero"; + } + else if (std::strstr(default_platform_name.c_str(), "CUDA")) { + filter = "cuda"; + } + else if (std::strstr(default_platform_name.c_str(), "HIP")) { + filter = "hip"; + } } - auto plaform_list = sycl::platform::get_platforms(); + auto platform_list = sycl::platform::get_platforms(); - for (const auto& platform : plaform_list) { + for (const auto& platform : platform_list) { auto devices = platform.get_devices(); auto gpu_dev = std::find_if(devices.begin(), devices.end(), [](const sycl::device& d) { return d.is_gpu();