From: Jeff Bolz Date: Sun, 21 Dec 2025 09:32:58 +0000 (-0600) Subject: vulkan: fix im2col overflowing maxworkgroupcount (llama/18180) X-Git-Tag: upstream/1.8.3~105 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=b893e0813abd9d352647c58a29c8d8f2eb8c8ca6;p=pkg%2Fggml%2Fsources%2Fwhisper.cpp vulkan: fix im2col overflowing maxworkgroupcount (llama/18180) --- diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index a5308fa5..a871f85a 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1261,6 +1261,7 @@ struct vk_op_im2col_push_constants { int32_t s0; int32_t s1; int32_t p0; int32_t p1; int32_t d0; int32_t d1; + uint32_t batch_IC; }; struct vk_op_im2col_3d_push_constants { @@ -5902,6 +5903,9 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.range << "), "; } std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))"); + GGML_ASSERT(wg0 <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] && + wg1 <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] && + wg2 <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]); GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size()); GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT); GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size()); @@ -9090,6 +9094,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co const uint32_t batch = src1->ne[is_2D ? 3 : 2]; elements = { OW * KW * KH, OH, batch * IC }; + elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]); + elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]); } break; case GGML_OP_IM2COL_3D: { @@ -10605,6 +10611,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32 const uint32_t pelements = OW * KW * KH; + const uint32_t batch = src1->ne[is_2D ? 3 : 2]; const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; const vk_buffer d_buf = d_buf_ctx->dev_buffer; @@ -10617,7 +10624,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co IC, IW, IH, OW, OH, KW, KH, pelements, IC * KH * KW, - s0, s1, p0, p1, d0, d1, + s0, s1, p0, p1, d0, d1, batch * IC }); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp index 1827d647..db14f5a3 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp @@ -19,6 +19,7 @@ layout (push_constant) uniform parameter int s0; int s1; int p0; int p1; int d0; int d1; + uint batch_IC; } p; layout(constant_id = 0) const uint BLOCK_SIZE = 32; @@ -34,12 +35,12 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; layout (buffer_reference) buffer D_ptr {D_TYPE d;}; #endif -void main() { +void im2col(const uint y, const uint z) { const uint gidx = gl_GlobalInvocationID.x; - const uint oh = gl_GlobalInvocationID.y; - const uint batch = gl_GlobalInvocationID.z / p.IC; - const uint ic = gl_GlobalInvocationID.z % p.IC; + const uint oh = y; + const uint batch = z / p.IC; + const uint ic = z % p.IC; const uint src_base = ic * p.offset_delta + batch * p.batch_offset; const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH); @@ -101,3 +102,15 @@ void main() { #endif } } + +void main() { + uint y = gl_GlobalInvocationID.y; + while (y < p.OH) { + uint z = gl_GlobalInvocationID.z; + while (z < p.batch_IC) { + im2col(y, z); + z += gl_NumWorkGroups.z; + } + y += gl_NumWorkGroups.y; + } +}