From: Salvatore Mesoraca Date: Mon, 30 Sep 2024 07:14:09 +0000 (+0200) Subject: vulkan : mul_mat: fix UB with small warps (#952) X-Git-Tag: upstream/0.0.1642~324 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=d57505a2a093e13054a9f2f0e187bbbe3b2ad09e;p=pkg%2Fggml%2Fsources%2Fggml vulkan : mul_mat: fix UB with small warps (#952) When the device's warp size is less than 16, it is possible for loadstride_a (mul_mm.comp:114) and loadstride_b (mul_mm.comp:115) to be set to 0. Because they are calculated as: the workgroup size, multiplied by LOAD_VEC_* (which can be 1) and divided by 16. And the workgroup size is set to be the same as the warp/subgroup size. The loadstride_* variables are used as increments in the loops that populate the buffers used for the multiplication. When they are 0 they cause an infinite loop. But infinite loops without side-effects are UB and the values of loadstride_* are known at compile time. So, the compiler quietly optimizes all the loops away. As a consequence, the buffers are not populated and the multiplication result is just a matrix with all elements set to 0. We prevent the UB by making sure that the workgroup size will never be less than 16, even if our device has a smaller warp size (e.g. 8). Signed-off-by: Salvatore Mesoraca --- diff --git a/src/ggml-vulkan.cpp b/src/ggml-vulkan.cpp index c677a272..00ad13bb 100644 --- a/src/ggml-vulkan.cpp +++ b/src/ggml-vulkan.cpp @@ -1164,11 +1164,11 @@ static void ggml_vk_load_shaders(vk_device& device) { // mulmat std::initializer_list warptile_l = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size }; std::initializer_list warptile_m = { 128, 64, 64, 16, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size }; - std::initializer_list warptile_s = { device->subgroup_size, 32, 32, 16, 32, 32, 2, 2, 2, device->subgroup_size }; + std::initializer_list warptile_s = { std::max(device->subgroup_size, 16u), 32, 32, 16, 32, 32, 2, 2, 2, device->subgroup_size }; std::initializer_list warptile_mmq_l = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size }; std::initializer_list warptile_mmq_m = { 128, 64, 64, 32, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size }; - std::initializer_list warptile_mmq_s = { device->subgroup_size, 32, 32, 32, 32, 32, 2, 2, 2, device->subgroup_size }; + std::initializer_list warptile_mmq_s = { std::max(device->subgroup_size, 16u), 32, 32, 32, 32, 32, 2, 2, 2, device->subgroup_size }; std::array l_wg_denoms = {128, 128, 1 }; std::array m_wg_denoms = { 64, 64, 1 };