]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
vulkan: split mul_mat into multiple dispatches to avoid overflow (llama/19509)
authorJeff Bolz <redacted>
Wed, 18 Feb 2026 09:47:10 +0000 (01:47 -0800)
committerGeorgi Gerganov <redacted>
Fri, 27 Feb 2026 18:57:58 +0000 (20:57 +0200)
* vulkan: split mul_mat into multiple dispatches to avoid overflow

The batch dimensions can be greater than the max workgroup count limit,
in which case we need to split into multiple dispatches and pass the base
index through a push constant.

Fall back for the less common p021 and nc variants.

* address feedback

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp

index 114992da08d744e1941aa5245afa17b434b77ae5..a8840a0773bacf5da83edb0a1470bebd2663cf94 100644 (file)
@@ -944,6 +944,7 @@ struct vk_mat_mat_push_constants {
     uint32_t M; uint32_t N; uint32_t K;
     uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
     uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
+    uint32_t base_work_group_z; uint32_t num_batches;
     uint32_t k_split;
     uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
     uint32_t padded_N;
@@ -963,6 +964,7 @@ struct vk_mat_vec_push_constants {
     uint32_t batch_stride_b;
     uint32_t batch_stride_d;
     uint32_t fusion_flags;
+    uint32_t base_work_group_y;
     uint32_t ne02;
     uint32_t ne12;
     uint32_t broadcast2;
@@ -6773,8 +6775,16 @@ static void ggml_vk_matmul(
         uint32_t padded_n) {
         VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")");
     if (split_k == 1) {
-        const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, batch });
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline, CEIL_DIV(batch, ctx->device->properties.limits.maxComputeWorkGroupCount[2]));
+
+        uint32_t base_work_group_z = 0;
+        while (base_work_group_z < batch) {
+            uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
+
+            const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k, ne02, ne12, broadcast2, broadcast3, padded_n };
+            ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, groups_z });
+            base_work_group_z += groups_z;
+        }
         return;
     }
 
@@ -6788,9 +6798,17 @@ static void ggml_vk_matmul(
     uint32_t k_split = CEIL_DIV(k, split_k);
     k_split = ROUNDUP_POW2(k_split, 256);
 
-    const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
-    // Make sure enough workgroups get assigned for split k to work
-    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
+    ggml_pipeline_request_descriptor_sets(ctx, pipeline, CEIL_DIV(batch, ctx->device->properties.limits.maxComputeWorkGroupCount[2]));
+
+    uint32_t base_work_group_z = 0;
+    while (base_work_group_z < batch) {
+        uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
+
+        const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
+        // Make sure enough workgroups get assigned for split k to work
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, groups_z });
+        base_work_group_z += groups_z;
+    }
     ggml_vk_sync_buffers(ctx, subctx);
     const std::array<uint32_t, 2> pc2 = { (uint32_t)(m * n * batch), split_k };
     ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2, { m * n * batch, 1, 1 });
@@ -7186,7 +7204,6 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
         }
 
         // Request descriptor sets
-        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
         if (qx_needs_dequant) {
             ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);
         }
@@ -7484,7 +7501,6 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
         if (quantize_y) {
             ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
         }
-        ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
     }
 
     vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
@@ -7579,22 +7595,29 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
         fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1;
     }
 
-    // compute
-    const vk_mat_vec_push_constants pc = {
-        (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
-        stride_batch_x, stride_batch_y, stride_batch_d,
-        fusion_flags,
-        (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
-    };
-    ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
-                              {
-                                d_X,
-                                d_Y,
-                                d_D,
-                                d_F0,
-                                d_F1,
-                              },
-                              pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
+    ggml_pipeline_request_descriptor_sets(ctx, dmmv, CEIL_DIV(ne12 * ne13, ctx->device->properties.limits.maxComputeWorkGroupCount[1]));
+
+    uint32_t base_work_group_y = 0;
+    while (base_work_group_y < ne12 * ne13) {
+
+        uint32_t groups_y = std::min((uint32_t)(ne12 * ne13) - base_work_group_y, ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
+        const vk_mat_vec_push_constants pc = {
+            (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
+            stride_batch_x, stride_batch_y, stride_batch_d,
+            fusion_flags, base_work_group_y,
+            (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
+        };
+        ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
+                                  {
+                                    d_X,
+                                    d_Y,
+                                    d_D,
+                                    d_F0,
+                                    d_F1,
+                                  },
+                                  pc, { groups_x, groups_y, groups_z });
+        base_work_group_y += groups_y;
+    }
 
     if (x_non_contig) {
         ctx->prealloc_x_need_sync = true;
@@ -7832,10 +7855,15 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
         src1->nb[2] <= src1->nb[1] &&
         src1->nb[1] <= src1->nb[3] &&
         src0->ne[3] == 1 &&
-        src1->ne[3] == 1) {
+        src1->ne[3] == 1 &&
+        src0->ne[1] <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
+        src1->ne[2] <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]) {
         ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, cgraph, node_idx);
     } else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 &&
-               !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
+               !ggml_is_permuted(src0) && !ggml_is_permuted(src1) &&
+               src0->ne[3] <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] &&
+               src0->ne[1] <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
+               src1->ne[2] <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]) {
         ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, cgraph, node_idx);
     // mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four)
     // when ne12 and ne13 are one.
@@ -11560,7 +11588,6 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
         }
     }
 
-    ggml_pipeline_request_descriptor_sets(ctx, p, num_it);
     if (split_k > 1) {
         ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
 
@@ -12069,7 +12096,6 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
         // y[i] = i % k;
     }
 
-    ggml_pipeline_request_descriptor_sets(ctx, p, num_it);
     if (split_k > 1) {
         ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
 
index 4f2c7003065f627d46729ba46a3ddab0f0dcd3ac..4aeda68c7f2dbc21755ecd9be5f7d7b150861e45 100644 (file)
@@ -32,6 +32,7 @@ layout (push_constant) uniform parameter
     uint expert_i1;
     uint nbi1;
 #else
+    uint base_work_group_y;
     uint ne02;
     uint ne12;
     uint broadcast2;
@@ -45,9 +46,9 @@ uint expert_id;
 
 void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
 #ifdef MUL_MAT_ID
-    const uint expert_i0 = gl_GlobalInvocationID.y;
+    const uint expert_i0 = gl_WorkGroupID.y;
 #else
-    const uint batch_idx = gl_GlobalInvocationID.y;
+    const uint batch_idx = gl_WorkGroupID.y + p.base_work_group_y;
 #endif
 
 #ifndef MUL_MAT_ID
index 775e9a70f6d52213edd2092044772a2747d7182b..79344d33005b3ef0b7ceea5be6f7600865db6ef1 100644 (file)
@@ -90,6 +90,8 @@ layout (push_constant) uniform parameter
     uint nbi1;
     uint ne11;
 #else
+    uint base_work_group_z;
+    uint num_batches;
     uint k_split;
     uint ne02;
     uint ne12;
@@ -139,7 +141,7 @@ void main() {
     const uint ic = gl_WorkGroupID.y;
 
 #ifdef MUL_MAT_ID
-    const uint expert_idx = gl_GlobalInvocationID.z;
+    const uint expert_idx = gl_WorkGroupID.z;
     if (ic * BN >= data_expert_count[expert_idx]) {
         return;
     }
@@ -149,7 +151,7 @@ void main() {
 #endif
 
 #ifndef MUL_MAT_ID
-    const uint batch_idx = gl_GlobalInvocationID.z;
+    const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z;
 
     const uint i13 = batch_idx / p.ne12;
     const uint i12 = batch_idx % p.ne12;
@@ -366,7 +368,7 @@ void main() {
     const uint dc = ic * BN + warp_c * WN;
 
 #ifndef MUL_MAT_ID
-    const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
+    const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches;
 #endif
 
 #ifdef COOPMAT
index b6614d2fc5999d5cf90bdb4756c904cc5aa5193a..717d124e01933cf7f44ac136260ffc219cf40545 100644 (file)
@@ -53,6 +53,8 @@ layout (push_constant) uniform parameter
     uint nbi1;
     uint ne11;
 #else
+    uint base_work_group_z;
+    uint num_batches;
     uint k_split;
     uint ne02;
     uint ne12;
@@ -197,7 +199,7 @@ void main() {
     const uint ic = gl_WorkGroupID.y;
 
 #ifdef MUL_MAT_ID
-    const uint expert_idx = gl_GlobalInvocationID.z;
+    const uint expert_idx = gl_WorkGroupID.z;
     if (ic * BN >= data_expert_count[expert_idx]) {
         return;
     }
@@ -215,7 +217,7 @@ void main() {
 #endif
 
 #ifndef MUL_MAT_ID
-    const uint batch_idx = gl_GlobalInvocationID.z;
+    const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z;
 
     const uint i13 = batch_idx / p.ne12;
     const uint i12 = batch_idx % p.ne12;
@@ -255,7 +257,7 @@ void main() {
 #else
     uint pos_a = batch_idx_a * (p.batch_stride_a / QUANT_K);
     uint pos_b = batch_idx * p.batch_stride_b;
-    uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
+    uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches;
 #endif
 
     uint stride_a = p.stride_a / QUANT_K;