s_mmq_wg_denoms_k = { 32, 64, 1 };
// spec constants and tile sizes for quant matmul_id
- l_warptile_mmqid = { 256, 128, 128, 16, 0 };
- m_warptile_mmqid = { 256, 128, 64, 16, 0 };
- s_warptile_mmqid = { 256, 128, 64, 16, 0 };
+ l_warptile_mmqid = { 256, 128, 128, 16, 0, device->subgroup_size }; // sixth entry is the device subgroup size; presumably consumed by the shader as specialization constant 5 (TODO confirm)
+ m_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size };
+ s_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size }; // same values as the medium variant
l_mmqid_wg_denoms = { 128, 128, 1 };
m_mmqid_wg_denoms = { 128, 64, 1 };
s_mmqid_wg_denoms = { 128, 64, 1 };
shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE];
shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
+#define NUM_WARPS (BLOCK_SIZE / WARP)
+
#ifdef MUL_MAT_ID
shared u16vec2 row_ids[4096];
uint _ne1;
#ifdef COOPMAT
-shared uint _ne1_sh;
+shared uvec4 ballots_sh[NUM_WARPS]; // one ballot slot per subgroup; load_row_ids writes it at index gl_SubgroupID
+void load_row_ids(uint expert_idx, bool nei0_is_pow2) { // scans data_ids with the whole workgroup; stores (ii0, ii1) of every entry equal to expert_idx into row_ids and leaves the match count in _ne1
+ _ne1 = 0; // number of matches recorded so far
+ uint num_elements = p.nei1 * p.nei0; // total number of id entries to scan
+ uint nei0shift = findLSB(p.nei0); // shift that replaces the division by p.nei0; only used when nei0_is_pow2 is true
+
+ uint ids[16]; // per-invocation prefetch buffer, refilled once every 16 loop iterations
+ uint iter = 0; // next slot of ids[] to consume; wrapped to 0..15 at the end of each iteration
+
+ for (uint j = 0; j < num_elements; j += BLOCK_SIZE) { // each iteration handles BLOCK_SIZE consecutive entries, one per invocation
+ // prefetch up to 16 elements: the ids this invocation needs for this iteration and the following 15
+ if (iter == 0) {
+ [[unroll]] for (uint k = 0; k < 16; ++k) {
+ uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE; // flat index this invocation will visit k iterations from now
+ bool in_range = i < num_elements;
+ uint ii1;
+ if (nei0_is_pow2) {
+ ii1 = i >> nei0shift;
+ } else {
+ ii1 = i / p.nei0;
+ }
+ uint ii0 = i - ii1 * p.nei0; // remainder of i / p.nei0, derived from the quotient instead of a modulo
+ ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; // out-of-range slots hold 0 and are masked by in_range again at the ballot
+ }
+ }
+ uint i = j + gl_LocalInvocationIndex; // flat index handled by this invocation in the current iteration
+ bool in_range = i < num_elements;
+ uint ii1;
+ if (nei0_is_pow2) {
+ ii1 = i >> nei0shift;
+ } else {
+ ii1 = i / p.nei0;
+ }
+ uint ii0 = i - ii1 * p.nei0;
+ uint id = ids[iter++];
+ uvec4 ballot = subgroupBallot(in_range && id == expert_idx); // bit set for each invocation of this subgroup whose entry matches
+
+ ballots_sh[gl_SubgroupID] = ballot; // publish this subgroup's ballot to the rest of the workgroup
+ barrier(); // all ballots must be written before any invocation sums them
+
+ uint subgroup_base = 0; // matches found in this iteration by subgroups with a lower gl_SubgroupID
+ uint total = 0; // matches found in this iteration by the whole workgroup
+ for (uint k = 0; k < gl_NumSubgroups; ++k) {
+ if (k == gl_SubgroupID) {
+ subgroup_base = total;
+ }
+ total += subgroupBallotBitCount(ballots_sh[k]);
+ }
+ barrier(); // keeps the next iteration's write to ballots_sh from overlapping with the reads above
+
+ uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot); // rank of this invocation among this iteration's matches
+ if (in_range && id == expert_idx) {
+ row_ids[_ne1 + idx] = u16vec2(ii0, ii1);
+ }
+ _ne1 += total; // every invocation adds the same total, so each invocation's _ne1 holds the same value
+ iter &= 15;
+ }
+ barrier(); // synchronize the workgroup after the last row_ids writes, before the function returns
+}
#endif
#endif // MUL_MAT_ID
-#define NUM_WARPS (BLOCK_SIZE / WARP)
-
#ifdef COOPMAT
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
#endif
#ifdef MUL_MAT_ID
#ifdef COOPMAT
- // Spread the search across all elements in the first subgroup
- if (gl_SubgroupID == 0) {
- _ne1 = 0;
- uint num_elements = p.nei1 * p.nei0;
-
- uint ids[16];
- uint iter = 0;
-
- for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
- // prefetch up to 16 elements
- if (iter == 0) {
- [[unroll]] for (uint k = 0; k < 16; ++k) {
- uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
- bool in_range = i < num_elements;
- uint ii1 = i / p.nei0;
- uint ii0 = i % p.nei0;
- ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
- }
- }
- uint i = j + gl_SubgroupInvocationID;
- bool in_range = i < num_elements;
- uint ii1 = i / p.nei0;
- uint ii0 = i % p.nei0;
- uint id = ids[iter++];
- uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
- uint idx = subgroupBallotExclusiveBitCount(ballot);
- if (in_range && id == expert_idx) {
- row_ids[_ne1 + idx] = u16vec2(ii0, ii1);
- }
- _ne1 += subgroupBallotBitCount(ballot);
- iter &= 15;
- }
- _ne1_sh = _ne1;
+ if (bitCount(p.nei0) == 1) {
+ load_row_ids(expert_idx, true);
+ } else {
+ load_row_ids(expert_idx, false);
}
-
- barrier();
-
- _ne1 = _ne1_sh;
#else
_ne1 = 0;
for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
#endif
#include "types.comp"
+#include "utils.comp"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
};
uint _ne1;
-shared uint _ne1_sh;
+layout (constant_id = 5) const uint subgroup_size = 32; // specialization constant 5; 32 is the value used when it is not overridden
+shared uvec4 ballots_sh[BLOCK_SIZE / subgroup_size]; // one ballot slot per subgroup; load_row_ids writes it at index gl_SubgroupID
B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
return elem;
}
+void load_row_ids(uint expert_idx, bool nei0_is_pow2) { // scans data_ids with the whole workgroup; stores one row_ids entry for every id equal to expert_idx and leaves the match count in _ne1
+ _ne1 = 0; // number of matches recorded so far
+ uint num_elements = p.nei1 * p.nei0; // total number of id entries to scan
+ uint nei0shift = findLSB(p.nei0); // shift that replaces the division by p.nei0; only used when nei0_is_pow2 is true
+
+ uint ids[16]; // per-invocation prefetch buffer, refilled once every 16 loop iterations
+ uint iter = 0; // next slot of ids[] to consume; wrapped to 0..15 at the end of each iteration
+
+ for (uint j = 0; j < num_elements; j += BLOCK_SIZE) { // each iteration handles BLOCK_SIZE consecutive entries, one per invocation
+ // prefetch up to 16 elements: the ids this invocation needs for this iteration and the following 15
+ if (iter == 0) {
+ [[unroll]] for (uint k = 0; k < 16; ++k) {
+ uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE; // flat index this invocation will visit k iterations from now
+ bool in_range = i < num_elements;
+ uint ii1;
+ if (nei0_is_pow2) {
+ ii1 = i >> nei0shift;
+ } else {
+ ii1 = i / p.nei0;
+ }
+ uint ii0 = i - ii1 * p.nei0; // remainder of i / p.nei0, derived from the quotient instead of a modulo
+ ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; // out-of-range slots hold 0 and are masked by in_range again at the ballot
+ }
+ }
+ uint i = j + gl_LocalInvocationIndex; // flat index handled by this invocation in the current iteration
+ bool in_range = i < num_elements;
+ uint ii1;
+ if (nei0_is_pow2) {
+ ii1 = i >> nei0shift;
+ } else {
+ ii1 = i / p.nei0;
+ }
+ uint ii0 = i - ii1 * p.nei0;
+ uint id = ids[iter++];
+ uvec4 ballot = subgroupBallot(in_range && id == expert_idx); // bit set for each invocation of this subgroup whose entry matches
+
+ ballots_sh[gl_SubgroupID] = ballot; // publish this subgroup's ballot to the rest of the workgroup
+ barrier(); // all ballots must be written before any invocation sums them
+
+ uint subgroup_base = 0; // matches found in this iteration by subgroups with a lower gl_SubgroupID
+ uint total = 0; // matches found in this iteration by the whole workgroup
+ for (uint k = 0; k < gl_NumSubgroups; ++k) {
+ if (k == gl_SubgroupID) {
+ subgroup_base = total;
+ }
+ total += subgroupBallotBitCount(ballots_sh[k]);
+ }
+ barrier(); // keeps the next iteration's write to ballots_sh from overlapping with the reads above
+
+ uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot); // rank of this invocation among this iteration's matches
+ if (in_range && id == expert_idx) {
+ row_ids[_ne1 + idx] = u16vec4(fastmod(ii0, p.ne11), ii1, ii0, 0); // fastmod is not defined in this view; presumably from utils.comp and equivalent to ii0 % p.ne11 (TODO confirm)
+ }
+ _ne1 += total; // every invocation adds the same total, so each invocation's _ne1 holds the same value
+ iter &= 15;
+ }
+ barrier(); // synchronize the workgroup after the last row_ids writes, before the function returns
+}
#endif
void main() {
const uint ic = gl_WorkGroupID.y;
#ifdef MUL_MAT_ID
- // Spread the search across all elements in the first subgroup
- if (gl_SubgroupID == 0) {
- _ne1 = 0;
- uint num_elements = p.nei1 * p.nei0;
-
- uint ids[16];
- uint iter = 0;
-
- for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
- // prefetch up to 16 elements
- if (iter == 0) {
- [[unroll]] for (uint k = 0; k < 16; ++k) {
- uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
- bool in_range = i < num_elements;
- uint ii1 = i / p.nei0;
- uint ii0 = i % p.nei0;
- ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
- }
- }
- uint i = j + gl_SubgroupInvocationID;
- bool in_range = i < num_elements;
- uint ii1 = i / p.nei0;
- uint ii0 = i % p.nei0;
- uint id = ids[iter++];
- uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
- uint idx = subgroupBallotExclusiveBitCount(ballot);
- if (in_range && id == expert_idx) {
- row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0);
- }
- _ne1 += subgroupBallotBitCount(ballot);
- iter &= 15;
- }
- _ne1_sh = _ne1;
+ if (bitCount(p.nei0) == 1) {
+ load_row_ids(expert_idx, true);
+ } else {
+ load_row_ids(expert_idx, false);
}
- barrier();
-
- _ne1 = _ne1_sh;
-
// Workgroup has no work
if (ic * BN >= _ne1) return;
#endif