row_ids only needs to hold the BN rows for the current tile: each workgroup computes one BN-wide slice of the output, so load_row_ids now takes the tile index ic and stores just the ids falling in [ic*BN, (ic+1)*BN), at tile-local offsets, breaking out early once the window is full. This removes the fixed 4096-entry shared array, along with the nei0*nei1 <= 4096 assert and the host-side splitting loop it forced, and the shared-memory estimate now also counts the ballots_sh array used by the subgroup path.
const uint32_t warps = warptile[0] / warptile[10];
const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
- const uint32_t mmid_row_ids = mul_mat_id ? (4096 * sizeof(uint32_t) + 4/*_ne1*/) : 0;
+ const uint32_t mmid_row_ids = mul_mat_id ? (warptile[2] * 2 * sizeof(uint16_t)) : 0;
const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
+ const uint32_t ballots_sh = mul_mat_id ? (warps * 4 * sizeof(uint32_t)) : 0;
- const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size;
+ const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size + ballots_sh;
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
VK_LOG_DEBUG("ggml_vk_matmul_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), "
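To put numbers on the saving, here is a standalone sketch of the before/after row_ids accounting. The BN = 128 and 4-warp values are assumptions for illustration; the real values come from the selected warptile:

#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t BN    = 128; // warptile[2], assumed tile height
    const uint32_t warps = 4;   // warptile[0] / warptile[10], assumed

    // Before: a fixed 4096-entry row_ids array plus 4 bytes for _ne1.
    const uint32_t old_row_ids = 4096 * sizeof(uint32_t) + 4;   // 16388 bytes
    // After: one u16vec2 per row of the current BN-wide tile,
    // plus one uvec4 of ballot storage per warp (now counted explicitly).
    const uint32_t new_row_ids = BN * 2 * sizeof(uint16_t);     // 512 bytes
    const uint32_t ballots_sh  = warps * 4 * sizeof(uint32_t);  // 64 bytes

    printf("before: %u bytes, after: %u bytes\n", old_row_ids, new_row_ids + ballots_sh);
    return 0;
}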
const uint64_t nei0 = ids->ne[0];
const uint64_t nei1 = ids->ne[1];
- GGML_ASSERT(nei0 * nei1 <= 4096);
const uint32_t nbi1 = ids->nb[1];
const uint32_t nbi2 = ids->nb[2];
if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
} else {
- // Split based on number of ids, to fit in shared memory
- const uint32_t nei0 = (uint32_t)src2->ne[0];
- const uint32_t nei1 = (uint32_t)src2->ne[1];
-
- GGML_ASSERT(nei0 <= 4096);
- const uint32_t split_size = std::min(nei1, 4096u / nei0);
-
- if (split_size == nei1) {
- ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
- } else {
- ggml_tensor src1_copy = *src1;
- ggml_tensor src2_copy = *src2;
- ggml_tensor dst_copy = *dst;
-
- for (uint32_t token_start = 0; token_start < nei1; token_start += split_size) {
- const uint32_t n_tokens = std::min(split_size, nei1 - token_start);
-
- src1_copy.view_offs = src1->view_offs + token_start * src1_copy.nb[2];
- src2_copy.view_offs = src2->view_offs + token_start * src2_copy.nb[1];
- dst_copy.view_offs = dst->view_offs + token_start * dst_copy.nb[2];
-
- src1_copy.ne[2] = n_tokens;
- src2_copy.ne[1] = n_tokens;
- dst_copy.ne[2] = n_tokens;
-
- ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, &src1_copy, &src2_copy, &dst_copy, dryrun);
- // invalidate cached prealloc_y, can't cache based on the copy of the ggml_tensor
- ctx->prealloc_y_last_pipeline_used = {};
- ctx->prealloc_y_last_tensor_used = nullptr;
- }
- }
+ ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
}
}
#define NUM_WARPS (BLOCK_SIZE / WARP)
#ifdef MUL_MAT_ID
-shared u16vec2 row_ids[4096];
+shared u16vec2 row_ids[BN];
uint _ne1;
#ifdef MUL_MAT_ID_USE_SUBGROUPS
shared uvec4 ballots_sh[NUM_WARPS];
-void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
+void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
_ne1 = 0;
uint num_elements = p.nei1 * p.nei0;
uint nei0shift = findLSB(p.nei0);
barrier();
uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
- if (in_range && id == expert_idx) {
- row_ids[_ne1 + idx] = u16vec2(ii0, ii1);
+ if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
+ row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
}
_ne1 += total;
iter &= 15;
+ if (_ne1 >= (ic + 1) * BN) {
+ break;
+ }
}
barrier();
}
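The subgroup path compacts matches with a ballot: every lane whose id equals expert_idx writes to slot _ne1 + exclusive-popcount of the ballot bits below its lane, and the new guard simply drops slots outside [ic*BN, (ic+1)*BN). A minimal host-side model of that compaction (the 32-lane subgroup and toy predicate are assumptions, and subgroup_base is folded into _ne1):

#include <cstdint>
#include <cstdio>

static uint32_t popcount32(uint32_t x) {
    uint32_t n = 0;
    while (x) { x &= x - 1; n++; }
    return n;
}

int main() {
    const int lanes = 32;           // assumed subgroup size
    bool match[32];
    uint32_t ballot = 0;
    for (int l = 0; l < lanes; l++) {
        match[l] = (l % 3) == 0;    // stand-in for data_ids[...] == expert_idx
        if (match[l]) ballot |= 1u << l;
    }
    const uint32_t ne1 = 0;         // matches found before this batch
    for (int l = 0; l < lanes; l++) {
        if (!match[l]) continue;
        // models subgroupBallotExclusiveBitCount(ballot) for lane l
        const uint32_t idx = ne1 + popcount32(ballot & ((1u << l) - 1));
        printf("lane %2d -> row_ids slot %u\n", l, idx);
    }
    return 0;
}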
#ifdef MUL_MAT_ID
#ifdef MUL_MAT_ID_USE_SUBGROUPS
if (bitCount(p.nei0) == 1) {
- load_row_ids(expert_idx, true);
+ load_row_ids(expert_idx, true, ic);
} else {
- load_row_ids(expert_idx, false);
+ load_row_ids(expert_idx, false, ic);
}
#else
_ne1 = 0;
- for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
- for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
+ for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) {
+ for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) {
if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
- row_ids[_ne1] = u16vec2(ii0, ii1);
+ if (_ne1 >= ic * BN) {
+ row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1);
+ }
_ne1++;
}
}
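The scalar fallback is easy to check on the host: matches are numbered in (ii1, ii0) order, workgroup ic keeps only matches [ic*BN, (ic+1)*BN), and both loops stop once the window is full. A minimal reference with toy sizes (BN = 4, a 3x6 ids matrix, and every token selecting expert 0 are assumptions):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const uint32_t BN = 4;                 // assumed tile height
    const uint32_t nei0 = 3, nei1 = 6;     // assumed ids shape
    const int expert_idx = 0;
    const uint32_t ic = 1;                 // this workgroup's tile index
    std::vector<int> ids(nei0 * nei1, 0);  // every token selects expert 0

    uint16_t row_ids[BN][2];               // BN tile-local slots
    uint32_t ne1 = 0;
    for (uint32_t ii1 = 0; ii1 < nei1 && ne1 < (ic + 1) * BN; ii1++) {
        for (uint32_t ii0 = 0; ii0 < nei0 && ne1 < (ic + 1) * BN; ii0++) {
            if (ids[ii1 * nei0 + ii0] != expert_idx) continue;
            if (ne1 >= ic * BN) {          // match falls inside this tile's window
                row_ids[ne1 - ic * BN][0] = (uint16_t)ii0;
                row_ids[ne1 - ic * BN][1] = (uint16_t)ii1;
            }
            ne1++;
        }
    }
    for (uint32_t i = 0; ic * BN + i < ne1 && i < BN; i++) {
        printf("slot %u -> (ii0=%u, ii1=%u)\n", i, row_ids[i][0], row_ids[i][1]);
    }
    return 0;
}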
[[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
#if LOAD_VEC_B == 8
#ifdef MUL_MAT_ID
- const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
+ const u16vec2 row_idx = row_ids[loadc_b + l];
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
#else
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w);
#elif LOAD_VEC_B == 4
#ifdef MUL_MAT_ID
- const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
+ const u16vec2 row_idx = row_ids[loadc_b + l];
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
#else
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
#else
const uint row_i = ic * BN + loadc_b + l;
if (row_i < _ne1 && block + loadr_b < end_k) {
- const u16vec2 row_idx = row_ids[row_i];
+ const u16vec2 row_idx = row_ids[loadc_b + l];
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
} else {
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
const uint row_i = dc + cm_col * TN + col + store_c;
if (row_i >= _ne1) break;
- const u16vec2 row_idx = row_ids[row_i];
+ const u16vec2 row_idx = row_ids[row_i - ic * BN];
if (dr + cm_row * TM + store_r < p.M) {
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
const uint row_i = dc_warp + cc;
if (row_i >= _ne1) break;
- const u16vec2 row_idx = row_ids[row_i];
+ const u16vec2 row_idx = row_ids[row_i - ic * BN];
#endif // MUL_MAT_ID
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
#ifdef MUL_MAT_ID
#ifdef MUL_MAT_ID
layout (binding = 3) readonly buffer IDS {int data_ids[];};
-shared u16vec4 row_ids[4096];
+shared u16vec4 row_ids[BN];
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB {
B_TYPE b[];
return B_TYPE(0.0);
}
- const u16vec4 row_idx = row_ids[row_i];
+ const u16vec4 row_idx = row_ids[row_i & (BN - 1)];
B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]];
return ret;
uint dc = ic * BN + c;
if (dr < p.M && dc < _ne1) {
- uint row_i = dc;
+ uint row_i = c;
const u16vec4 row_idx = row_ids[row_i];
data_d[row_idx.y * p.batch_stride_d + row_idx.z * p.stride_d + dr] = elem;
}
return elem;
}
-void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
+void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
_ne1 = 0;
uint num_elements = p.nei1 * p.nei0;
uint nei0shift = findLSB(p.nei0);
barrier();
uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
- if (in_range && id == expert_idx) {
- row_ids[_ne1 + idx] = u16vec4(fastmod(ii0, p.ne11), ii1, ii0, 0);
+ if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
+ row_ids[_ne1 + idx - ic * BN] = u16vec4(fastmod(ii0, p.ne11), ii1, ii0, 0);
}
_ne1 += total;
iter &= 15;
+ if (_ne1 >= (ic + 1) * BN) {
+ break;
+ }
}
barrier();
}
#ifdef MUL_MAT_ID
if (bitCount(p.nei0) == 1) {
- load_row_ids(expert_idx, true);
+ load_row_ids(expert_idx, true, ic);
} else {
- load_row_ids(expert_idx, false);
+ load_row_ids(expert_idx, false, ic);
}
// Workgroup has no work
// test large experts*tokens
for (bool b : {false, true}) {
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, b, 32, 1024, 16));
+ test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 2, 2, b, 32, 8192, 64)); // 2*8192 = 16384 ids, beyond the old 4096 limit
}
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 1, 1, false, 8, 16, 1));