const uint32_t warps = warptile[0] / warptile[10];
const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
- const uint32_t mmid_row_ids = mul_mat_id ? 3072 * sizeof(uint32_t) : 0;
+ const uint32_t mmid_row_ids = mul_mat_id ? 4096 * sizeof(uint32_t) : 0;
const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size;
const uint64_t nei0 = ids->ne[0];
const uint64_t nei1 = ids->ne[1];
- GGML_ASSERT(nei0 * nei1 <= 3072);
+ GGML_ASSERT(nei0 * nei1 <= 4096);
const uint32_t nbi1 = ids->nb[1];
const uint32_t nbi2 = ids->nb[2];
#ifdef MUL_MAT_ID
layout (binding = 3) readonly buffer IDS {int data_ids[];};
-shared u16vec4 row_ids[3072];
+shared u16vec4 row_ids[4096];
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB {
B_TYPE b[];