uint nbi1;
uint ne11;
#else
+ uint base_work_group_z;
+ uint num_batches;
uint k_split;
uint ne02;
uint ne12;
const uint ic = gl_WorkGroupID.y;
#ifdef MUL_MAT_ID
- const uint expert_idx = gl_GlobalInvocationID.z;
+ const uint expert_idx = gl_WorkGroupID.z;
if (ic * BN >= data_expert_count[expert_idx]) {
return;
}
#endif
#ifndef MUL_MAT_ID
- const uint batch_idx = gl_GlobalInvocationID.z;
+ const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z;
const uint i13 = batch_idx / p.ne12;
const uint i12 = batch_idx % p.ne12;
const uint dc = ic * BN + warp_c * WN;
#ifndef MUL_MAT_ID
- const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
+ const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches;
#endif
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {