uint32_t M; uint32_t N; uint32_t K;
uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
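+ // base z index of the current dispatch chunk and the total batch count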
+ uint32_t base_work_group_z; uint32_t num_batches;
uint32_t k_split;
uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
uint32_t padded_N;
uint32_t batch_stride_b;
uint32_t batch_stride_d;
uint32_t fusion_flags;
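+ // base y index of the current dispatch chunk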
+ uint32_t base_work_group_y;
uint32_t ne02;
uint32_t ne12;
uint32_t broadcast2;
uint32_t padded_n) {
VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")");
if (split_k == 1) {
- const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, batch });
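+ // the batch (z) dimension can exceed maxComputeWorkGroupCount[2], so request one
+ // descriptor set per chunk and dispatch the batches in chunks along z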
+ ggml_pipeline_request_descriptor_sets(ctx, pipeline, CEIL_DIV(batch, ctx->device->properties.limits.maxComputeWorkGroupCount[2]));
+
+ uint32_t base_work_group_z = 0;
+ while (base_work_group_z < batch) {
+ uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
+
+ const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k, ne02, ne12, broadcast2, broadcast3, padded_n };
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, groups_z });
+ base_work_group_z += groups_z;
+ }
return;
}
uint32_t k_split = CEIL_DIV(k, split_k);
k_split = ROUNDUP_POW2(k_split, 256);
- const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
- // Make sure enough workgroups get assigned for split k to work
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
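+ // as above, split the batch (z) dimension across multiple dispatches when it
+ // exceeds the device's workgroup count limit along z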
+ ggml_pipeline_request_descriptor_sets(ctx, pipeline, CEIL_DIV(batch, ctx->device->properties.limits.maxComputeWorkGroupCount[2]));
+
+ uint32_t base_work_group_z = 0;
+ while (base_work_group_z < batch) {
+ uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
+
+ const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
+ // Make sure enough workgroups get assigned for split k to work
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, groups_z });
+ base_work_group_z += groups_z;
+ }
ggml_vk_sync_buffers(ctx, subctx);
const std::array<uint32_t, 2> pc2 = { (uint32_t)(m * n * batch), split_k };
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2, { m * n * batch, 1, 1 });
}
// Request descriptor sets
- ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
if (qx_needs_dequant) {
ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);
}
if (quantize_y) {
ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
}
- ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
}
vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1;
}
- // compute
- const vk_mat_vec_push_constants pc = {
- (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
- stride_batch_x, stride_batch_y, stride_batch_d,
- fusion_flags,
- (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
- };
- ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
- {
- d_X,
- d_Y,
- d_D,
- d_F0,
- d_F1,
- },
- pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
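+ // ne12*ne13 batches are dispatched along y, which can exceed maxComputeWorkGroupCount[1];
+ // dispatch in chunks and pass the base y index through the push constants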
+ ggml_pipeline_request_descriptor_sets(ctx, dmmv, CEIL_DIV(ne12 * ne13, ctx->device->properties.limits.maxComputeWorkGroupCount[1]));
+
+ uint32_t base_work_group_y = 0;
+ while (base_work_group_y < ne12 * ne13) {
+ uint32_t groups_y = std::min((uint32_t)(ne12 * ne13) - base_work_group_y, ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
+
+ const vk_mat_vec_push_constants pc = {
+ (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
+ stride_batch_x, stride_batch_y, stride_batch_d,
+ fusion_flags, base_work_group_y,
+ (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
+ };
+ ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
+ {
+ d_X,
+ d_Y,
+ d_D,
+ d_F0,
+ d_F1,
+ },
+ pc, { groups_x, groups_y, groups_z });
+ base_work_group_y += groups_y;
+ }
if (x_non_contig) {
ctx->prealloc_x_need_sync = true;
src1->nb[2] <= src1->nb[1] &&
src1->nb[1] <= src1->nb[3] &&
src0->ne[3] == 1 &&
- src1->ne[3] == 1) {
+ src1->ne[3] == 1 &&
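+ // only take this path when the y/z dispatch dimensions fit within the device limits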
+ src0->ne[1] <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
+ src1->ne[2] <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]) {
ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, cgraph, node_idx);
} else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 &&
- !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
+ !ggml_is_permuted(src0) && !ggml_is_permuted(src1) &&
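+ // likewise require the x/y/z dispatch dimensions to fit within the device limits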
+ src0->ne[3] <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] &&
+ src0->ne[1] <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
+ src1->ne[2] <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]) {
ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, cgraph, node_idx);
// mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four)
// when ne12 and ne13 are one.
}
}
- ggml_pipeline_request_descriptor_sets(ctx, p, num_it);
if (split_k > 1) {
ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
// y[i] = i % k;
}
- ggml_pipeline_request_descriptor_sets(ctx, p, num_it);
if (split_k > 1) {
ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);