std::cerr << "Done!" << std::endl;
}
+static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props);
+
static vk_device ggml_vk_get_device(size_t idx) {
VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")");
device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
- if (device->vendor_id == VK_VENDOR_ID_INTEL || (device->vendor_id == VK_VENDOR_ID_AMD && (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource))) {
- // Intel drivers don't support coopmat properly yet
- // Only RADV supports coopmat properly on AMD
+ if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props)) {
device->coopmat_support = false;
}
return vk_instance.devices[idx];
}
-
static void ggml_vk_print_gpu_info(size_t idx) {
GGML_ASSERT(idx < vk_instance.device_indices.size());
size_t dev_num = vk_instance.device_indices[idx];
}
}
- if (props2.properties.vendorID == VK_VENDOR_ID_INTEL || (props2.properties.vendorID == VK_VENDOR_ID_AMD && (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource))) {
- // Intel drivers don't support coopmat properly yet
- // Only RADV supports coopmat properly on AMD
+ if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props)) {
coopmat_support = false;
}
UNUSED(instance_extensions);
}
+static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props) {
+ switch (props.vendorID) {
+ case VK_VENDOR_ID_INTEL:
+ // Intel drivers don't support coopmat properly yet
+ return false;
+ case VK_VENDOR_ID_AMD:
+ if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) {
+ // Workaround for AMD proprietary driver reporting support on all GPUs
+ const std::string name = props.deviceName;
+ return name.rfind("AMD Radeon RX 7", 0) == 0 || name.rfind("AMD Radeon(TM) RX 7", 0) == 0 || // RDNA 3 consumer GPUs
+ name.rfind("AMD Radeon PRO W7", 0) == 0 || name.rfind("AMD Radeon(TM) PRO W7", 0) == 0 || // RDNA 3 workstation GPUs
+ name.rfind("AMD Radeon 7", 0) == 0 || name.rfind("AMD Radeon(TM) 7", 0) == 0; // RDNA 3 APUs
+ }
+ return true;
+ default:
+ return true;
+ }
+}
+
// checks
#ifdef GGML_VULKAN_CHECK_RESULTS