return res;
}
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_3d(ggml_metal_library_t lib, const ggml_tensor * op) {
+ assert(op->op == GGML_OP_CONV_3D);
+
+ GGML_ASSERT(ggml_is_contiguous(op->src[0]));
+ GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
+ GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
+ GGML_ASSERT(op->type == GGML_TYPE_F32);
+
+ char base[256];
+ char name[256];
+
+ snprintf(base, 256, "kernel_conv_3d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
+ snprintf(name, 256, "%s", base);
+
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+ if (!res.pipeline) {
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+ }
+
+ return res;
+}
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_UPSCALE);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_3d (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) &&
op->src[1]->type == GGML_TYPE_F32 &&
op->type == GGML_TYPE_F32;
+ case GGML_OP_CONV_3D:
+ return ggml_is_contiguous(op->src[0]) &&
+ ggml_is_contiguous(op->src[1]) &&
+ (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) &&
+ op->src[1]->type == GGML_TYPE_F32 &&
+ op->type == GGML_TYPE_F32;
case GGML_OP_SUM:
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
case GGML_OP_TRI:
int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources
} ggml_metal_kargs_im2col;
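+// args for kernel_conv_3d: extents (I*/O*/K*), stride/padding/dilation (s*/p*/d*), channels/batch (IC/N/OC), byte strides (nb*)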
+typedef struct {
+ int32_t IW;
+ int32_t IH;
+ int32_t ID;
+ int32_t OW;
+ int32_t OH;
+ int32_t OD;
+ int32_t KW;
+ int32_t KH;
+ int32_t KD;
+ int32_t s0;
+ int32_t s1;
+ int32_t s2;
+ int32_t p0;
+ int32_t p1;
+ int32_t p2;
+ int32_t d0;
+ int32_t d1;
+ int32_t d2;
+ int32_t IC;
+ int32_t N;
+ int32_t OC;
+ uint64_t nb00;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ uint64_t nb10;
+ uint64_t nb11;
+ uint64_t nb12;
+ uint64_t nb13;
+ uint64_t nb0;
+ uint64_t nb1;
+ uint64_t nb2;
+ uint64_t nb3;
+} ggml_metal_kargs_conv_3d;
+
typedef struct{
int32_t ne00;
uint64_t nb01;
{
n_fuse = ggml_metal_op_conv_transpose_2d(ctx, idx);
} break;
+ case GGML_OP_CONV_3D:
+ {
+ n_fuse = ggml_metal_op_conv_3d(ctx, idx);
+ } break;
case GGML_OP_UPSCALE:
{
n_fuse = ggml_metal_op_upscale(ctx, idx);
return 1;
}
+int ggml_metal_op_conv_3d(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ // 1. Extract byte strides: nb00..nb03 (weights), nb10..nb13 (input), nb0..nb3 (output)
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ // 2. Extract hyperparams from op_params
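+ // (the index order below must match how ggml_conv_3d packs op_params on the CPU side)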
+ const int32_t s0 = ((const int32_t *)(op->op_params))[0];
+ const int32_t s1 = ((const int32_t *)(op->op_params))[1];
+ const int32_t s2 = ((const int32_t *)(op->op_params))[2];
+ const int32_t p0 = ((const int32_t *)(op->op_params))[3];
+ const int32_t p1 = ((const int32_t *)(op->op_params))[4];
+ const int32_t p2 = ((const int32_t *)(op->op_params))[5];
+ const int32_t d0 = ((const int32_t *)(op->op_params))[6];
+ const int32_t d1 = ((const int32_t *)(op->op_params))[7];
+ const int32_t d2 = ((const int32_t *)(op->op_params))[8];
+ const int32_t IC = ((const int32_t *)(op->op_params))[9];
+ const int32_t N = ((const int32_t *)(op->op_params))[10];
+ const int32_t OC = ((const int32_t *)(op->op_params))[11];
+
+ // 3. Build the parameter struct using the macro-generated variables
+ ggml_metal_kargs_conv_3d args = {
+ /*.IW =*/ (int32_t)op->src[1]->ne[0],
+ /*.IH =*/ (int32_t)op->src[1]->ne[1],
+ /*.ID =*/ (int32_t)op->src[1]->ne[2],
+ /*.OW =*/ (int32_t)op->ne[0],
+ /*.OH =*/ (int32_t)op->ne[1],
+ /*.OD =*/ (int32_t)op->ne[2],
+ /*.KW =*/ (int32_t)op->src[0]->ne[0],
+ /*.KH =*/ (int32_t)op->src[0]->ne[1],
+ /*.KD =*/ (int32_t)op->src[0]->ne[2],
+ s0, s1, s2,
+ p0, p1, p2,
+ d0, d1, d2,
+ IC, N, OC,
+ nb00, nb01, nb02, nb03, // Weight strides
+ nb10, nb11, nb12, nb13, // Input strides
+ nb0, nb1, nb2, nb3 // Output strides
+ };
+
+ // 4. Fetch the pipeline (compiled on first use)
+ auto pipeline = ggml_metal_library_get_pipeline_conv_3d(lib, op);
+
+ // 5. Grid mapping: x covers the flattened OW*OH*OD volume, y covers output channels, z covers the batch
+ int nth0 = 32; // SIMD width on Apple GPUs; must match the hardcoded factor in kernel_conv_3d
+ int nth1 = 1;
+ int nth2 = 1;
+
+ int64_t spatial_volume = (int64_t) args.OW * args.OH * args.OD; // promote before multiplying to avoid int32 overflow
+
+ int ntg0 = (spatial_volume + nth0 - 1) / nth0;
+ int ntg1 = args.OC;
+ int ntg2 = args.N;
+
+ // 6. Bind and Dispatch via the ggml C wrapper
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, ntg0, ntg1, ntg2, nth0, nth1, nth2);
+
+ return 1;
+}
+
int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_im2col (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_conv_2d (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_conv_3d (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_conv_transpose_1d (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_conv_transpose_2d (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_upscale (ggml_metal_op_t ctx, int idx);
}
}
+template <typename T>
+kernel void kernel_conv_3d(
+ constant ggml_metal_kargs_conv_3d & args,
+ device const char * src0, // Weights [IC * OC, KD, KH, KW]
+ device const char * src1, // Inputs [IC * N, ID, IH, IW]
+ device char * dst, // Outputs [OC * N, OD, OH, OW]
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]]) {
+
+ // 1. Un-flatten the spatial dimension from Grid X (the factor 32 must match nth0 on the host side)
+ int64_t spatial_idx = tgpig.x * 32 + tpitg.x;
+
+ if (spatial_idx >= (int64_t) args.OW * args.OH * args.OD) {
+ return; // Thread falls outside the spatial volume
+ }
+
+ int64_t od = spatial_idx / (args.OW * args.OH);
+ int64_t oh = (spatial_idx / args.OW) % args.OH;
+ int64_t ow = spatial_idx % args.OW;
+
+ // 2. Map Y to Channels, Z to Batch
+ int64_t oc = tgpig.y;
+ int64_t batch_idx = tgpig.z;
+
+ // 3. Calculate anchor coordinates in the Input volume
+ int64_t i_w_base = ow * args.s0 - args.p0;
+ int64_t i_h_base = oh * args.s1 - args.p1;
+ int64_t i_d_base = od * args.s2 - args.p2;
+
+ float sum = 0.0f;
+
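+ // Naive direct convolution: each thread accumulates one output element; no threadgroup memory or simdgroup reductions are used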
+ // 4. Gather Loop (Iterate over Input Channels -> Depth -> Height -> Width)
+ for (int64_t ic = 0; ic < args.IC; ++ic) {
+
+ // ggml packs batch and channel together in the 4th dimension
+ int64_t src_cn_idx = batch_idx * args.IC + ic;
+ int64_t w_cn_idx = oc * args.IC + ic;
+
+ for (int64_t kz = 0; kz < args.KD; ++kz) {
+ int64_t id = i_d_base + kz * args.d2;
+ if (id < 0 || id >= args.ID) continue; // Boundary check (Padding)
+
+ for (int64_t ky = 0; ky < args.KH; ++ky) {
+ int64_t ih = i_h_base + ky * args.d1;
+ if (ih < 0 || ih >= args.IH) continue;
+
+ for (int64_t kx = 0; kx < args.KW; ++kx) {
+ int64_t iw = i_w_base + kx * args.d0;
+ if (iw < 0 || iw >= args.IW) continue;
+
+ // Convert multi-dimensional coordinates to flat byte offsets
+ int64_t w_idx = kx*args.nb00 + ky*args.nb01 + kz*args.nb02 + w_cn_idx*args.nb03;
+ int64_t i_idx = iw*args.nb10 + ih*args.nb11 + id*args.nb12 + src_cn_idx*args.nb13;
+
+ // Dereference memory and cast weights to f32 if they were f16
+ float w_val = (float)*(device const T*)((device const char*)src0 + w_idx);
+ float i_val = *(device const float*)((device const char*)src1 + i_idx);
+
+ sum += w_val * i_val;
+ }
+ }
+ }
+ }
+
+ // 5. Write the accumulated value out to RAM
+ int64_t dst_cn_idx = batch_idx * args.OC + oc;
+ int64_t d_idx = ow*args.nb0 + oh*args.nb1 + od*args.nb2 + dst_cn_idx*args.nb3;
+
+ *(device float*)(dst + d_idx) = sum;
+}
+
+// Explicit instantiations exposed via [[host_name]] so the host can look the pipelines up by name
+template [[host_name("kernel_conv_3d_f32_f32")]]
+kernel void kernel_conv_3d<float>(
+ constant ggml_metal_kargs_conv_3d & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]]);
+
+// Explicit instantiation for f16 weights
+template [[host_name("kernel_conv_3d_f16_f32")]]
+kernel void kernel_conv_3d<half>(
+ constant ggml_metal_kargs_conv_3d & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]]);
+
+
static inline float bicubic_weight1(float x) {
const float a = -0.75f;
return ((a + 2) * x - (a + 3)) * x * x + 1;