// ne21 = n_rows
const int dst_rows = ne20*ne21;
const int dst_rows_min = n_as;
+ const int dst_rows_max = (ctx->device.maxThreadgroupMemoryLength - 32 - 8192)/4;
// max size of the rowids array in the kernel shared buffer
- GGML_ASSERT(dst_rows <= 2048);
+ GGML_ASSERT(dst_rows <= dst_rows_max);
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPUs and older A-chips will reuse the matrix-vector multiplication kernel