#endif
#ifndef MUL_MAT_ID
- const uint i13 = batch_idx / p.ne12;
- const uint i12 = batch_idx % p.ne12;
+ uint batch_idx_a = 0;
+ if (batch_idx != 0) {
+ const uint i13 = batch_idx / p.ne12;
+ const uint i12 = batch_idx % p.ne12;
- const uint i03 = i13 / p.broadcast3;
- const uint i02 = i12 / p.broadcast2;
+ const uint i03 = i13 / p.broadcast3;
+ const uint i02 = i12 / p.broadcast2;
- const uint batch_idx_a = i03 * p.ne02 + i02;
+ batch_idx_a = i03 * p.ne02 + i02;
+ }
#else
const uint expert_id = data_ids[expert_idx];
#endif