]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
vulkan : mul_mat: fix UB with small warps (ggml/952)
authorSalvatore Mesoraca <redacted>
Mon, 30 Sep 2024 07:14:09 +0000 (09:14 +0200)
committerGeorgi Gerganov <redacted>
Thu, 3 Oct 2024 09:22:17 +0000 (12:22 +0300)
When the device's warp size is less than 16,
it is possible for loadstride_a (mul_mm.comp:114)
and loadstride_b (mul_mm.comp:115) to be set to 0.
Because they are calculated as: the workgroup size,
multiplied by LOAD_VEC_* (which can be 1) and divided by 16.
And the workgroup size is set to be the same as the
warp/subgroup size.

The loadstride_* variables are used as increments in the
loops that populate the buffers used for the multiplication.

When they are 0 they cause an infinite loop.
But infinite loops without side-effects are UB and the
values of loadstride_* are known at compile time.
So, the compiler quietly optimizes all the loops away.
As a consequence, the buffers are not populated and
the multiplication result is just a matrix with all elements
set to 0.

We prevent the UB by making sure that the workgroup size
will never be less than 16, even if our device has a
smaller warp size (e.g. 8).

Signed-off-by: Salvatore Mesoraca <redacted>
ggml/src/ggml-vulkan.cpp

index c677a27287cc0ca5d1abb6f2fcedb9dd117e9b9c..00ad13bb9567bc5509ad2c73f8096f9fe8cb527c 100644 (file)
@@ -1164,11 +1164,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
     // mulmat
     std::initializer_list<uint32_t> warptile_l = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
     std::initializer_list<uint32_t> warptile_m = { 128,  64,  64, 16, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
-    std::initializer_list<uint32_t> warptile_s = { device->subgroup_size,  32,  32, 16, 32, 32, 2, 2, 2, device->subgroup_size };
+    std::initializer_list<uint32_t> warptile_s = { std::max(device->subgroup_size, 16u),  32,  32, 16, 32, 32, 2, 2, 2, device->subgroup_size };
 
     std::initializer_list<uint32_t> warptile_mmq_l = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
     std::initializer_list<uint32_t> warptile_mmq_m = { 128,  64,  64, 32, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
-    std::initializer_list<uint32_t> warptile_mmq_s = { device->subgroup_size,  32,  32, 32, 32, 32, 2, 2, 2, device->subgroup_size };
+    std::initializer_list<uint32_t> warptile_mmq_s = { std::max(device->subgroup_size, 16u),  32,  32, 32, 32, 32, 2, 2, 2, device->subgroup_size };
 
     std::array<uint32_t, 3> l_wg_denoms = {128, 128, 1 };
     std::array<uint32_t, 3> m_wg_denoms = { 64,  64, 1 };