]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
vulkan: Set k_load_shmem to false when K is too large (llama/19301)
authorJeff Bolz <redacted>
Thu, 5 Feb 2026 07:48:33 +0000 (01:48 -0600)
committerGeorgi Gerganov <redacted>
Sat, 7 Feb 2026 08:37:38 +0000 (10:37 +0200)
src/ggml-vulkan/ggml-vulkan.cpp

index af57685a37dda37d1d7c10fa4e496c4788ea4c66..2f6570181ad7626356483959929229e782ffa3f3 100644 (file)
@@ -3204,9 +3204,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
         const uint32_t D_lsb = D ^ (D & (D-1));
         uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);
 
-        // Nvidia prefers shared memory use to load large tiles of K
+        // Nvidia prefers shared memory use to load large tiles of K.
+        // Switch to loading from global memory when it would use too much shared memory.
         // AMD prefers loading K directly from global memory
-        const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA ? 1 : 0;
+        const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 ? 1 : 0;
 
         return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem};
     };
@@ -8412,7 +8413,7 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
     const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br;
     const uint32_t sfsh = Bc * sfshstride * acctype;
 
-    const bool k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA;
+    const bool k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256;
     const uint32_t kshstride = (k_load_shmem ? hsk_pad : MatBr) / 4 + 2;
     const uint32_t vsh_stride = MatBc / 4 * row_split;
     const uint32_t ksh = ((kshstride >= vsh_stride) ? (Bc * kshstride) : (Bc * vsh_stride)) * f16vec4;