int32_t s0; int32_t s1;
int32_t p0; int32_t p1;
int32_t d0; int32_t d1;
+ uint32_t batch_IC;
};
struct vk_op_im2col_3d_push_constants {
std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.range << "), ";
}
std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))");
+ GGML_ASSERT(wg0 <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] &&
+ wg1 <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
+ wg2 <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size());
GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT);
GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size());
const uint32_t batch = src1->ne[is_2D ? 3 : 2];
elements = { OW * KW * KH, OH, batch * IC };
+ elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
+ elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
} break;
case GGML_OP_IM2COL_3D:
{
const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
const uint32_t pelements = OW * KW * KH;
+ const uint32_t batch = src1->ne[is_2D ? 3 : 2];
const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
const vk_buffer d_buf = d_buf_ctx->dev_buffer;
IC, IW, IH, OW, OH, KW, KH,
pelements,
IC * KH * KW,
- s0, s1, p0, p1, d0, d1,
+ s0, s1, p0, p1, d0, d1, batch * IC
});
}
int s0; int s1;
int p0; int p1;
int d0; int d1;
+ uint batch_IC;
} p;
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
layout (buffer_reference) buffer D_ptr {D_TYPE d;};
#endif
-void main() {
+void im2col(const uint y, const uint z) {
const uint gidx = gl_GlobalInvocationID.x;
- const uint oh = gl_GlobalInvocationID.y;
- const uint batch = gl_GlobalInvocationID.z / p.IC;
- const uint ic = gl_GlobalInvocationID.z % p.IC;
+ const uint oh = y;
+ const uint batch = z / p.IC;
+ const uint ic = z % p.IC;
const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH);
#endif
}
}
+
+void main() { // Shader entry point: calls im2col for every (y, z) pair this invocation is responsible for, looping instead of relying on one invocation per pair.
+ uint y = gl_GlobalInvocationID.y; // first y index for this invocation; im2col uses y as the output row `oh`
+ while (y < p.OH) { // keep going until y reaches the OH push constant
+ uint z = gl_GlobalInvocationID.z; // z restarts from this invocation's own z index for each y
+ while (z < p.batch_IC) { // batch_IC is the push constant the host fills with batch * IC
+ im2col(y, z); // im2col splits z into batch = z / p.IC and ic = z % p.IC
+ z += gl_NumWorkGroups.z; // NOTE(review): stepping by the workgroup count covers each z exactly once only if local_size_z is 1 - confirm against the layout declaration
+ }
+ y += gl_NumWorkGroups.y; // NOTE(review): same assumption for local_size_y - confirm
+ }
+}