const uint32_t D_lsb = D ^ (D & (D-1));
uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);
- // Nvidia prefers shared memory use to load large tiles of K
+ // Nvidia prefers shared memory use to load large tiles of K.
+ // Switch to loading from global memory when it would use too much shared memory.
// AMD prefers loading K directly from global memory
- const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA ? 1 : 0;
+ const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 ? 1 : 0;
return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem};
};
const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br;
const uint32_t sfsh = Bc * sfshstride * acctype;
- const bool k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA;
+ const bool k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256;
const uint32_t kshstride = (k_load_shmem ? hsk_pad : MatBr) / 4 + 2;
const uint32_t vsh_stride = MatBc / 4 * row_split;
const uint32_t ksh = ((kshstride >= vsh_stride) ? (Bc * kshstride) : (Bc * vsh_stride)) * f16vec4;