]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
vulkan: fix im2col overflowing maxworkgroupcount (llama/18180)
authorJeff Bolz <redacted>
Sun, 21 Dec 2025 09:32:58 +0000 (03:32 -0600)
committerGeorgi Gerganov <redacted>
Wed, 31 Dec 2025 10:39:43 +0000 (12:39 +0200)
src/ggml-vulkan/ggml-vulkan.cpp
src/ggml-vulkan/vulkan-shaders/im2col.comp
tests/test-backend-ops.cpp

index a5308fa54426641624d1935da015ca78aab6a189..a871f85afb0856a6050b19a7d7dde3e082e5bb85 100644 (file)
@@ -1261,6 +1261,7 @@ struct vk_op_im2col_push_constants {
     int32_t s0; int32_t s1;
     int32_t p0; int32_t p1;
     int32_t d0; int32_t d1;
+    uint32_t batch_IC;
 };
 
 struct vk_op_im2col_3d_push_constants {
@@ -5902,6 +5903,9 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context&
         std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.range << "), ";
     }
     std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))");
+    GGML_ASSERT(wg0 <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] &&
+                wg1 <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
+                wg2 <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
     GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size());
     GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT);
     GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size());
@@ -9090,6 +9094,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
             const uint32_t batch = src1->ne[is_2D ? 3 : 2];
 
             elements = { OW * KW * KH, OH, batch * IC };
+            elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
+            elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
         } break;
     case GGML_OP_IM2COL_3D:
         {
@@ -10605,6 +10611,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
     const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
 
     const uint32_t pelements = OW * KW * KH;
+    const uint32_t batch = src1->ne[is_2D ? 3 : 2];
 
     const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
     const vk_buffer d_buf = d_buf_ctx->dev_buffer;
@@ -10617,7 +10624,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
         IC, IW, IH, OW, OH, KW, KH,
         pelements,
         IC * KH * KW,
-        s0, s1, p0, p1, d0, d1,
+        s0, s1, p0, p1, d0, d1, batch * IC
     });
 }
 
index 1827d647a21957e81b9895543ed64c3ba3dd94a9..db14f5a3cf3fedf792154ec67fa6f78739448815 100644 (file)
@@ -19,6 +19,7 @@ layout (push_constant) uniform parameter
     int s0; int s1;
     int p0; int p1;
     int d0; int d1;
+    uint batch_IC;
 } p;
 
 layout(constant_id = 0) const uint BLOCK_SIZE = 32;
@@ -34,12 +35,12 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
 layout (buffer_reference) buffer D_ptr {D_TYPE d;};
 #endif
 
-void main() {
+void im2col(const uint y, const uint z) {
     const uint gidx = gl_GlobalInvocationID.x;
 
-    const uint oh = gl_GlobalInvocationID.y;
-    const uint batch = gl_GlobalInvocationID.z / p.IC;
-    const uint ic = gl_GlobalInvocationID.z % p.IC;
+    const uint oh = y;
+    const uint batch = z / p.IC;
+    const uint ic = z % p.IC;
 
     const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
     const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH);
@@ -101,3 +102,15 @@ void main() {
 #endif
     }
 }
+
+void main() {
+    uint y = gl_GlobalInvocationID.y;
+    while (y < p.OH) {
+        uint z = gl_GlobalInvocationID.z;
+        while (z < p.batch_IC) {
+            im2col(y, z);
+            z += gl_NumWorkGroups.z;
+        }
+        y += gl_NumWorkGroups.y;
+    }
+}
index 5116918e7188a2024a1aaba91206498977257558..2cdbe66a84b2f49f3528e5f75d80bfb2414c4b88 100644 (file)
@@ -6930,6 +6930,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2560}, {3, 3, 1, 2560}, 1, 1, 1, 1, 1, 1, true));
     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2560}, {3, 3, 2, 2560}, 1, 1, 1, 1, 1, 1, true));
     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {5, 5, 1, 32}, {3, 4, 1, 32}, 1, 1, 0, 0, 1, 1, true));
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {2, 2, 1536, 729}, {2, 2, 1536, 4096}, 1, 1, 0, 0, 1, 1, true));
 
     // im2col 3D
     test_cases.emplace_back(new test_im2col_3d(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32));