]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal: add CONV_3D (#19927)
authorRashid Ul Islam <redacted>
Mon, 23 Mar 2026 07:45:34 +0000 (13:15 +0530)
committerGitHub <redacted>
Mon, 23 Mar 2026 07:45:34 +0000 (09:45 +0200)
* Apply suggestions from code review

Co-authored-by: Georgi Gerganov <redacted>
* metal:add conv_3d backend

Rebased with master and resolved conflicts.

* Resolved issues related to changes in variable names

* kernel void kernel_upscale_bilinear_f32 was missing in my branch, added back, should pass all tests now

---------

Co-authored-by: Georgi Gerganov <redacted>
ggml/src/ggml-metal/ggml-metal-device.cpp
ggml/src/ggml-metal/ggml-metal-device.h
ggml/src/ggml-metal/ggml-metal-device.m
ggml/src/ggml-metal/ggml-metal-impl.h
ggml/src/ggml-metal/ggml-metal-ops.cpp
ggml/src/ggml-metal/ggml-metal-ops.h
ggml/src/ggml-metal/ggml-metal.metal

index 72ad876d5e48db3943b522c45c1d6416bc7ff58b..9162342ee98abeb390f25de840fbc4064ddc1c81 100644 (file)
@@ -1748,6 +1748,28 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d(ggml_met
     return res;
 }
 
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_3d(ggml_metal_library_t lib, const ggml_tensor * op) {
+    assert(op->op == GGML_OP_CONV_3D);
+
+    GGML_ASSERT(ggml_is_contiguous(op->src[0]));
+    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
+    GGML_ASSERT(op->type         == GGML_TYPE_F32);
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, 256, "kernel_conv_3d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
+    snprintf(name, 256, "%s", base);
+
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+    }
+
+    return res;
+}
+
 ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) {
     assert(op->op == GGML_OP_UPSCALE);
 
index fd2b3ddeb5589f3d33a36d3f81f979a1558fcff4..de43f819312072fc09bfa398ea9e4fdb45002f48 100644 (file)
@@ -148,6 +148,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d           (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_3d           (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale           (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad               (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d    (ggml_metal_library_t lib, const struct ggml_tensor * op);
index 82101f4714e602d363d5b765be6007f4289bceff..14144aab0878219ec524eaf654dd579ba68a2eab 100644 (file)
@@ -1077,6 +1077,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
                 (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) &&
                 op->src[1]->type == GGML_TYPE_F32 &&
                 op->type == GGML_TYPE_F32;
+        case GGML_OP_CONV_3D:
+            return ggml_is_contiguous(op->src[0]) &&
+                   ggml_is_contiguous(op->src[1]) &&
+                   (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) &&
+                   op->src[1]->type == GGML_TYPE_F32;
         case GGML_OP_SUM:
             return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
         case GGML_OP_TRI:
index 53437b23cda7bf818bccf70e8fa015d2fa653309..ea471090cd8099439239bf5d9526cf9fe89c57a5 100644 (file)
@@ -643,6 +643,42 @@ typedef struct {
     int32_t  KHW; // KH * KW, pre-computed on CPU to save GPU resources
 } ggml_metal_kargs_im2col;
 
+typedef struct {
+    int32_t  IW;
+    int32_t  IH;
+    int32_t  ID;
+    int32_t  OW;
+    int32_t  OH;
+    int32_t  OD;
+    int32_t  KW;
+    int32_t  KH;
+    int32_t  KD;
+    int32_t  s0;
+    int32_t  s1;
+    int32_t  s2;
+    int32_t  p0;
+    int32_t  p1;
+    int32_t  p2;
+    int32_t  d0;
+    int32_t  d1;
+    int32_t  d2;
+    int32_t  IC;
+    int32_t  N;
+    int32_t  OC;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_conv_3d;
+
 typedef struct{
     int32_t  ne00;
     uint64_t nb01;
index c0bcad392b9440ba18f2d82c57865c0ed50f7f2d..3cda21be43e6bb95b4916978a9ffca646ad270c4 100644 (file)
@@ -394,6 +394,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_conv_transpose_2d(ctx, idx);
             } break;
+        case GGML_OP_CONV_3D:
+            {
+                n_fuse = ggml_metal_op_conv_3d(ctx, idx);
+            } break;
         case GGML_OP_UPSCALE:
             {
                 n_fuse = ggml_metal_op_upscale(ctx, idx);
@@ -3697,6 +3701,77 @@ int ggml_metal_op_conv_2d(ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
+int ggml_metal_op_conv_3d(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    // 1. Extract standard dimensions and byte strides
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
+
+    // 2. Extract hyperparams from op_params
+    const int32_t s0 = ((const int32_t *)(op->op_params))[0];
+    const int32_t s1 = ((const int32_t *)(op->op_params))[1];
+    const int32_t s2 = ((const int32_t *)(op->op_params))[2];
+    const int32_t p0 = ((const int32_t *)(op->op_params))[3];
+    const int32_t p1 = ((const int32_t *)(op->op_params))[4];
+    const int32_t p2 = ((const int32_t *)(op->op_params))[5];
+    const int32_t d0 = ((const int32_t *)(op->op_params))[6];
+    const int32_t d1 = ((const int32_t *)(op->op_params))[7];
+    const int32_t d2 = ((const int32_t *)(op->op_params))[8];
+    const int32_t IC = ((const int32_t *)(op->op_params))[9];
+    const int32_t N  = ((const int32_t *)(op->op_params))[10];
+    const int32_t OC = ((const int32_t *)(op->op_params))[11];
+
+    // 3. Build the parameter struct using the macro-generated variables
+    ggml_metal_kargs_conv_3d args = {
+        /*.IW =*/ (int32_t)op->src[1]->ne[0],
+        /*.IH =*/ (int32_t)op->src[1]->ne[1],
+        /*.ID =*/ (int32_t)op->src[1]->ne[2],
+        /*.OW =*/ (int32_t)op->ne[0],
+        /*.OH =*/ (int32_t)op->ne[1],
+        /*.OD =*/ (int32_t)op->ne[2],
+        /*.KW =*/ (int32_t)op->src[0]->ne[0],
+        /*.KH =*/ (int32_t)op->src[0]->ne[1],
+        /*.KD =*/ (int32_t)op->src[0]->ne[2],
+        s0, s1, s2,
+        p0, p1, p2,
+        d0, d1, d2,
+        IC, N, OC,
+        nb00, nb01, nb02, nb03, // Weight strides
+        nb10, nb11, nb12, nb13, // Input strides
+        nb0,  nb1,  nb2,  nb3   // Output strides
+    };
+
+    // 4. Fetch the JIT pipeline
+    auto pipeline = ggml_metal_library_get_pipeline_conv_3d(lib, op);
+
+    // 5. Grid mapping
+    int nth0 = 32; // Standard SIMD width for Apple Silicon
+    int nth1 = 1;
+    int nth2 = 1;
+
+    int64_t spatial_volume = args.OW * args.OH * args.OD;
+
+    int ntg0 = (spatial_volume + nth0 - 1) / nth0;
+    int ntg1 = args.OC;
+    int ntg2 = args.N;
+
+    // 6. Bind and Dispatch via the ggml C wrapper
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         3);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, ntg0, ntg1, ntg2, nth0, nth1, nth2);
+
+    return 1;
+}
+
 int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
 
index 019f2fec9edded7c9a74d35f27e9d5b94603e662..50e3c5c77a1af94ad667d8fe95bd1082e409f915 100644 (file)
@@ -75,6 +75,7 @@ int ggml_metal_op_norm              (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_rope              (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_im2col            (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_conv_2d           (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_conv_3d           (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_conv_transpose_1d (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_conv_transpose_2d (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_upscale           (ggml_metal_op_t ctx, int idx);
index b2328605dd956177d5c297d6d254696dac2e3103..9c6b1c4f62b2a13a325113d3805fc9026290dcf3 100644 (file)
@@ -4883,6 +4883,98 @@ kernel void kernel_upscale_bilinear_f32(
     }
 }
 
+template <typename T>
+kernel void kernel_conv_3d(
+        constant ggml_metal_kargs_conv_3d & args,
+        device const  char * src0, // Weights [IC * OC, KD, KH, KW]
+        device const  char * src1, // Inputs  [IC * N,  ID, IH, IW]
+        device       char  * dst,  // Outputs [OC * N,  OD, OH, OW]
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]]) {
+
+    // 1. Un-flatten the spatial dimension from Grid X
+    int64_t spatial_idx = tgpig.x * 32 + tpitg.x;
+
+    if (spatial_idx >= args.OW * args.OH * args.OD) {
+        return; // Thread falls outside the spatial volume
+    }
+
+    int64_t od = spatial_idx / (args.OW * args.OH);
+    int64_t oh = (spatial_idx / args.OW) % args.OH;
+    int64_t ow = spatial_idx % args.OW;
+
+    // 2. Map Y to Channels, Z to Batch
+    int64_t oc = tgpig.y;
+    int64_t batch_idx = tgpig.z;
+
+    // 3. Calculate anchor coordinates in the Input volume
+    int64_t i_w_base = ow * args.s0 - args.p0;
+    int64_t i_h_base = oh * args.s1 - args.p1;
+    int64_t i_d_base = od * args.s2 - args.p2;
+
+    float sum = 0.0f;
+
+    // 4. Gather Loop (Iterate over Input Channels -> Depth -> Height -> Width)
+    for (int64_t ic = 0; ic < args.IC; ++ic) {
+
+        // ggml packs batch and channel together in the 4th dimension
+        int64_t src_cn_idx = batch_idx * args.IC + ic;
+        int64_t w_cn_idx   = oc * args.IC + ic;
+
+        for (int64_t kz = 0; kz < args.KD; ++kz) {
+            int64_t id = i_d_base + kz * args.d2;
+            if (id < 0 || id >= args.ID) continue; // Boundary check (Padding)
+
+            for (int64_t ky = 0; ky < args.KH; ++ky) {
+                int64_t ih = i_h_base + ky * args.d1;
+                if (ih < 0 || ih >= args.IH) continue;
+
+                for (int64_t kx = 0; kx < args.KW; ++kx) {
+                    int64_t iw = i_w_base + kx * args.d0;
+                    if (iw < 0 || iw >= args.IW) continue;
+
+                    // Convert multi-dimensional coordinates to flat byte offsets
+                    int64_t w_idx = kx*args.nb00 + ky*args.nb01 + kz*args.nb02 + w_cn_idx*args.nb03;
+                    int64_t i_idx = iw*args.nb10 + ih*args.nb11 + id*args.nb12 + src_cn_idx*args.nb13;
+
+                    // Dereference memory and cast weights to f32 if they were f16
+                    float w_val = (float)*(device const T*)((device const char*)src0 + w_idx);
+                    float i_val = *(device const float*)((device const char*)src1 + i_idx);
+
+                    sum += w_val * i_val;
+                }
+            }
+        }
+    }
+
+    // 5. Write the accumulated value out to RAM
+    int64_t dst_cn_idx = batch_idx * args.OC + oc;
+    int64_t d_idx = ow*args.nb0 + oh*args.nb1 + od*args.nb2 + dst_cn_idx*args.nb3;
+
+    *(device float*)(dst + d_idx) = sum;
+}
+
+// Explicit instantiations so the JIT compiler can find them by name
+template [[host_name("kernel_conv_3d_f32_f32")]]
+kernel void kernel_conv_3d<float>(
+    constant ggml_metal_kargs_conv_3d & args,
+    device const char * src0,
+    device const char * src1,
+    device       char  * dst,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]]);
+
+// Explicit instantiation for f16 weights
+template [[host_name("kernel_conv_3d_f16_f32")]]
+kernel void kernel_conv_3d<half>(
+    constant ggml_metal_kargs_conv_3d & args,
+    device const char  * src0,
+    device const char * src1,
+    device       char  * dst,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]]);
+
+
 static inline float bicubic_weight1(float x) {
     const float a = -0.75f;
     return ((a + 2) * x - (a + 3)) * x * x + 1;