return res;
}
+
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_2d(ggml_metal_library_t lib, const ggml_tensor * op) {
+ assert(op->op == GGML_OP_CONV_2D);
+
+ GGML_ASSERT(ggml_is_contiguous(op->src[0]));
+ GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
+ GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
+ GGML_ASSERT(op->type == GGML_TYPE_F32);
+
+ char base[256];
+ char name[256];
+
+ snprintf(base, 256, "kernel_conv_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
+ snprintf(name, 256, "%s", base);
+
+ ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
+ if (res) {
+ return res;
+ }
+
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+
+ return res;
+}
+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_UPSCALE);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
return true;
case GGML_OP_IM2COL:
return ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32 && (op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);
+ case GGML_OP_CONV_2D:
+ return ggml_is_contiguous(op->src[0]) &&
+ op->src[1]->type == GGML_TYPE_F32 &&
+ op->type == GGML_TYPE_F32 &&
+ (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
case GGML_OP_POOL_1D:
return false;
case GGML_OP_UPSCALE:
uint64_t nb2;
} ggml_metal_kargs_conv_transpose_2d;
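+
+// kernel arguments for GGML_OP_CONV_2D
+// nb00..nb03: byte strides of the kernel (src0), nb10..nb13: byte strides of the input (src1), nb0..nb3: byte strides of the output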
+typedef struct {
+ uint64_t nb00;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ uint64_t nb10;
+ uint64_t nb11;
+ uint64_t nb12;
+ uint64_t nb13;
+ uint64_t nb0;
+ uint64_t nb1;
+ uint64_t nb2;
+ uint64_t nb3;
+ int32_t IW;
+ int32_t IH;
+ int32_t KW;
+ int32_t KH;
+ int32_t IC;
+ int32_t OC;
+ int32_t OW;
+ int32_t OH;
+ int32_t N;
+ int32_t s0;
+ int32_t s1;
+ int32_t p0;
+ int32_t p1;
+ int32_t d0;
+ int32_t d1;
+} ggml_metal_kargs_conv_2d;
+
typedef struct {
uint64_t ofs0;
uint64_t ofs1;
#include <cassert>
#include <algorithm>
+#include <limits>
static ggml_metal_buffer_id ggml_metal_get_buffer_id(const ggml_tensor * t) {
if (!t) {
{
n_fuse = ggml_metal_op_im2col(ctx, idx);
} break;
+ case GGML_OP_CONV_2D:
+ {
+ n_fuse = ggml_metal_op_conv_2d(ctx, idx);
+ } break;
case GGML_OP_CONV_TRANSPOSE_1D:
{
n_fuse = ggml_metal_op_conv_transpose_1d(ctx, idx);
nth = std::min(nth, nk0);
- if (nth*nrptg > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
- nth = ggml_metal_pipeline_max_theads_per_threadgroup(pipeline);
- nrptg = 1;
- }
-
ggml_metal_kargs_set_rows args = {
/*.nk0 =*/ nk0,
/*.ne01 =*/ ne01,
return 1;
}
+
+int ggml_metal_op_conv_2d(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ GGML_ASSERT(ggml_is_contiguous(op->src[0]));
+ GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
+ GGML_ASSERT(op->type == GGML_TYPE_F32);
+ GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
+
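+ // op_params: stride (s0, s1), padding (p0, p1), dilation (d0, d1)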
+ const int32_t s0 = ((const int32_t *) op->op_params)[0];
+ const int32_t s1 = ((const int32_t *) op->op_params)[1];
+ const int32_t p0 = ((const int32_t *) op->op_params)[2];
+ const int32_t p1 = ((const int32_t *) op->op_params)[3];
+ const int32_t d0 = ((const int32_t *) op->op_params)[4];
+ const int32_t d1 = ((const int32_t *) op->op_params)[5];
+
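+ // src0 is the kernel [KW, KH, IC, OC], src1 is the input [IW, IH, IC, N], dst is [OW, OH, OC, N]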
+ ggml_metal_kargs_conv_2d args = {
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.nb10 =*/ nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.nb0 =*/ nb0,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3,
+ /*.IW =*/ ne10,
+ /*.IH =*/ ne11,
+ /*.KW =*/ ne00,
+ /*.KH =*/ ne01,
+ /*.IC =*/ ne02,
+ /*.OC =*/ ne03,
+ /*.OW =*/ ne0,
+ /*.OH =*/ ne1,
+ /*.N =*/ ne3,
+ /*.s0 =*/ s0,
+ /*.s1 =*/ s1,
+ /*.p0 =*/ p0,
+ /*.p1 =*/ p1,
+ /*.d0 =*/ d0,
+ /*.d1 =*/ d1,
+ };
+
+ ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_2d(lib, op);
+
+ int nth = ggml_metal_pipeline_max_theads_per_threadgroup(pipeline);
+ nth = std::min(nth, 256);
+ nth = std::max(nth, 1);
+
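+ // launch roughly one thread per output element; the kernel uses a grid-stride loop,
+ // so clamping the threadgroup count to the int maximum still covers all outputs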
+ const uint64_t n_out = ggml_nelements(op);
+
+ uint64_t tg = (n_out + nth - 1)/nth;
+ tg = std::max<uint64_t>(tg, 1);
+ tg = std::min<uint64_t>(tg, (uint64_t) std::numeric_limits<int>::max());
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, tg, 1, 1, nth, 1, 1);
+
+ return 1;
+}
+
int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_im2col (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_conv_2d (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_conv_transpose_1d (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_conv_transpose_2d (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_upscale (ggml_metal_op_t ctx, int idx);
//template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
//template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
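+
+// direct 2D convolution: each thread computes one or more output elements dst[n, oc, oh, ow]
+// TK is the element type of the kernel (src0); the input (src1) and the output are F32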
+template <typename TK>
+kernel void kernel_conv_2d(
+ constant ggml_metal_kargs_conv_2d & args,
+ device const char * weights,
+ device const char * src,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tgpg[[threadgroups_per_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+
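+ // flatten the 3D launch grid into a single linear thread index and walk the output
+ // elements with a grid-stride loop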
+ const uint threads_per_tg = ntg.x * ntg.y * ntg.z;
+ const uint tg_index = (tgpig.z * tgpg.y + tgpig.y) * tgpg.x + tgpig.x;
+ const uint local_thread = tpitg.z * (ntg.x * ntg.y) + tpitg.y * ntg.x + tpitg.x;
+ const uint thread_index = tg_index * threads_per_tg + local_thread;
+ const uint64_t total_threads = (uint64_t) threads_per_tg * tgpg.x * tgpg.y * tgpg.z;
+ const uint64_t total_outputs = (uint64_t) args.N * args.OC * args.OH * args.OW;
+
+ for (uint64_t index = thread_index; index < total_outputs; index += total_threads) {
+ uint64_t tmp = index;
+
+ const int32_t ow = tmp % args.OW; tmp /= args.OW;
+ const int32_t oh = tmp % args.OH; tmp /= args.OH;
+ const int32_t oc = tmp % args.OC; tmp /= args.OC;
+ const int32_t n = tmp;
+
+ float acc = 0.0f;
+
+ const int32_t base_x = ow*args.s0 - args.p0;
+ const int32_t base_y = oh*args.s1 - args.p1;
+
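+ // clamp the kernel window so that iy = base_y + ky*d1 stays in [0, IH) and
+ // ix = base_x + kx*d0 stays in [0, IW); out-of-bounds taps are the zero padding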
+ int32_t ky_start = 0;
+ if (base_y < 0) {
+ ky_start = (-base_y + args.d1 - 1)/args.d1;
+ }
+ int32_t ky_end = args.KH;
+ const int32_t y_max = args.IH - 1 - base_y;
+ if (y_max < 0) {
+ ky_end = ky_start;
+ } else if (base_y + (args.KH - 1)*args.d1 >= args.IH) {
+ ky_end = min(ky_end, y_max/args.d1 + 1);
+ }
+
+ int32_t kx_start = 0;
+ if (base_x < 0) {
+ kx_start = (-base_x + args.d0 - 1)/args.d0;
+ }
+ int32_t kx_end = args.KW;
+ const int32_t x_max = args.IW - 1 - base_x;
+ if (x_max < 0) {
+ kx_end = kx_start;
+ } else if (base_x + (args.KW - 1)*args.d0 >= args.IW) {
+ kx_end = min(kx_end, x_max/args.d0 + 1);
+ }
+
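+ // accumulate over all input channels and the clamped kernel window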
+ if (ky_start < ky_end && kx_start < kx_end) {
+ const uint64_t src_base_n = (uint64_t) n * args.nb13;
+ const uint64_t w_base_oc = (uint64_t) oc * args.nb03;
+
+ for (int32_t ic = 0; ic < args.IC; ++ic) {
+ const uint64_t src_base_nc = src_base_n + (uint64_t) ic * args.nb12;
+ const uint64_t w_base_ocic = w_base_oc + (uint64_t) ic * args.nb02;
+
+ for (int32_t ky = ky_start; ky < ky_end; ++ky) {
+ const int32_t iy = base_y + ky*args.d1;
+ const uint64_t src_base_row = src_base_nc + (uint64_t) iy * args.nb11;
+ const uint64_t w_base_row = w_base_ocic + (uint64_t) ky * args.nb01;
+
+ for (int32_t kx = kx_start; kx < kx_end; ++kx) {
+ const int32_t ix = base_x + kx*args.d0;
+ const uint64_t src_offs = src_base_row + (uint64_t) ix * args.nb10;
+ const uint64_t w_offs = w_base_row + (uint64_t) kx * args.nb00;
+
+ const float x = *(device const float *)(src + src_offs);
+ const float w = (float) (*(device const TK *)(weights + w_offs));
+
+ acc += x * w;
+ }
+ }
+ }
+ }
+
+ const uint64_t dst_offs =
+ (uint64_t) n * args.nb3 +
+ (uint64_t) oc * args.nb2 +
+ (uint64_t) oh * args.nb1 +
+ (uint64_t) ow * args.nb0;
+
+ *(device float *)(dst + dst_offs) = acc;
+ }
+}
+
+template [[host_name("kernel_conv_2d_f32_f32")]]
+kernel void kernel_conv_2d<float>(
+ constant ggml_metal_kargs_conv_2d & args,
+ device const char * weights,
+ device const char * src,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tgpg[[threadgroups_per_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]);
+
+template [[host_name("kernel_conv_2d_f16_f32")]]
+kernel void kernel_conv_2d<half>(
+ constant ggml_metal_kargs_conv_2d & args,
+ device const char * weights,
+ device const char * src,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tgpg[[threadgroups_per_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]);
+
typedef void (conv_transpose_1d_t)(
constant ggml_metal_kargs_conv_transpose_1d & args,
device const float * src0,