uint32_t fusion_flags;
uint32_t nei0;
uint32_t ne11;
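+ // current index in the second ids dimension (one entry per token), and the ids row stride in elements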
+ uint32_t expert_i1;
+ uint32_t nbi1;
};
struct vk_flash_attn_push_constants {
const uint64_t nei0 = ids->ne[0];
const uint64_t nei1 = ids->ne[1];
-
- GGML_ASSERT(nei1 == 1);
+ const uint32_t nbi1 = (uint32_t)(ids->nb[1] / sizeof(int32_t)); // ids is GGML_TYPE_I32
const uint64_t ne20 = dst->ne[0];
const uint64_t ne21 = dst->ne[1];
if (quantize_y) {
ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
}
- ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
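+ // the loop below issues one dispatch per ids row, so reserve nei1 descriptor sets instead of 1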
+ ggml_pipeline_request_descriptor_sets(ctx, dmmv, (uint32_t)nei1);
}
vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
uint32_t stride_batch_y = ne10*ne11;
if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
- stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
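+ // each token now reads its own dim-2 slice of src1, so the batch stride comes from nb[2]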
+ stride_batch_y = src1->nb[2] / ggml_type_size(src1->type);
}
const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0];
fusion_flags |= MAT_VEC_FUSION_FLAGS_SCALE1;
}
- // compute
- const vk_mat_vec_id_push_constants pc = {
- (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
- (uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21),
- fusion_flags,
- (uint32_t)nei0, (uint32_t)ne11,
- };
- ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
- {
- d_X,
- d_Y,
- d_D,
- d_F0,
- d_F1,
- d_ids,
- },
- pc, { groups_x, (uint32_t)nei0, groups_z });
+ // Dispatch once per ids row (the nei1 dimension); expert_i1 selects that row's expert ids and batch slices
+ for (uint32_t expert_i1 = 0; expert_i1 < nei1; ++expert_i1) {
+ const vk_mat_vec_id_push_constants pc = {
+ (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
+ (uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21),
+ fusion_flags,
+ (uint32_t)nei0, (uint32_t)ne11, expert_i1, nbi1
+ };
+ ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
+ {
+ d_X,
+ d_Y,
+ d_D,
+ d_F0,
+ d_F1,
+ d_ids,
+ },
+ pc, { groups_x, (uint32_t)nei0, groups_z });
+ }
if (x_non_contig) {
ctx->prealloc_x_need_sync = true;
ggml_tensor * dst = cgraph->nodes[node_idx];
ggml_tensor * src0 = dst->src[0];
ggml_tensor * src2 = dst->src[2];
- return src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type));
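+ // relax the previous single-row (ne[1] == 1) restriction: allow up to 8 ids rows on this path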
+ return (src2->ne[1] <= 8) && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type));
}
static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
#ifdef MUL_MAT_ID
uint nei0;
uint ne11;
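+ // mirrors the host-side push constants: current ids-row index and the ids row stride in elements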
+ uint expert_i1;
+ uint nbi1;
#else
uint ne02;
uint ne12;
void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
#ifdef MUL_MAT_ID
- const uint expert_idx = gl_GlobalInvocationID.y;
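+ // dim-0 slot within the current ids row; pairs with p.expert_i1, the dim-1 (row) index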
+ const uint expert_i0 = gl_GlobalInvocationID.y;
#else
const uint batch_idx = gl_GlobalInvocationID.y;
#endif
batch_idx_a = i03 * p.ne02 + i02;
}
#else
- expert_id = data_ids[expert_idx];
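+ // select this slot's expert from the current token's row of ids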
+ expert_id = data_ids[expert_i0 + p.expert_i1 * p.nbi1];
#endif
a_offset =
#endif
b_offset =
#ifdef MUL_MAT_ID
- (expert_idx % p.ne11) * p.stride_b;
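+ // B: step into the current token's batch slice, then to this slot's column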
+ (expert_i0 % p.ne11) * p.stride_b + p.expert_i1 * p.batch_stride_b;
#else
batch_idx * p.batch_stride_b;
#endif
d_offset =
#ifdef MUL_MAT_ID
- expert_idx * p.stride_d;
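+ // D: one output row block per slot, one batch slice per token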
+ expert_i0 * p.stride_d + p.expert_i1 * p.batch_stride_d;
#else
batch_idx * p.batch_stride_d;
#endif
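+ // Net addressing for MUL_MAT_ID after this change (sketch; assumes the a_offset
+ // path outside this hunk still selects the expert's weight slice via expert_id):
+ //   expert_id = ids[expert_i0 + expert_i1 * nbi1]
+ //   b_offset  = (expert_i0 % ne11) * stride_b + expert_i1 * batch_stride_b
+ //   d_offset  = expert_i0 * stride_d  + expert_i1 * batch_stride_d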
temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
}
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
- const uint expert_idx = gl_GlobalInvocationID.y;
- temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]);
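+ // the dispatch's y extent is nei0, so gl_GlobalInvocationID.y is the dim-0 slot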
+ const uint expert_i0 = gl_GlobalInvocationID.y;
+ temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_i0]);
}
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
- const uint expert_idx = gl_GlobalInvocationID.y;
- temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]);
+ const uint expert_i0 = gl_GlobalInvocationID.y;
+ temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_i0]);
}
#else
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
}
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
- const uint expert_idx = gl_GlobalInvocationID.y;
- temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]);
+ const uint expert_i0 = gl_GlobalInvocationID.y;
+ temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_i0]);
}
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
- const uint expert_idx = gl_GlobalInvocationID.y;
- temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]);
+ const uint expert_i0 = gl_GlobalInvocationID.y;
+ temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_i0]);
}
#else
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
tmpsh[j][n][0] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
}
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
- const uint expert_idx = gl_GlobalInvocationID.y;
- tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse0[expert_idx]);
+ const uint expert_i0 = gl_GlobalInvocationID.y;
+ tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse0[expert_i0]);
}
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
- const uint expert_idx = gl_GlobalInvocationID.y;
- tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse1[expert_idx]);
+ const uint expert_i0 = gl_GlobalInvocationID.y;
+ tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse1[expert_i0]);
}
#else
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {