]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : add `CONV_TRANSPOSE_2D` (#16542)
authorIlia Ilmer <redacted>
Fri, 17 Oct 2025 06:33:58 +0000 (02:33 -0400)
committerGitHub <redacted>
Fri, 17 Oct 2025 06:33:58 +0000 (09:33 +0300)
* initial: headers and metal-device.cpp updates

* adding conv_transpose_2d

* fix type

* fix type: int32->int64

* Update ggml/src/ggml-metal/ggml-metal.metal

Co-authored-by: Georgi Gerganov <redacted>
* Update ggml/src/ggml-metal/ggml-metal.metal

Co-authored-by: Georgi Gerganov <redacted>
* Update ggml/src/ggml-metal/ggml-metal.metal

Co-authored-by: Georgi Gerganov <redacted>
* add checks for src[0] and src[1]; add type checks

* Update ggml-metal.metal

Co-authored-by: Georgi Gerganov <redacted>
* add more tests, add optimization to threading

* add dynamic memory allocation in metal

---------

Co-authored-by: Georgi Gerganov <redacted>
ggml/src/ggml-metal/ggml-metal-device.cpp
ggml/src/ggml-metal/ggml-metal-device.h
ggml/src/ggml-metal/ggml-metal-device.m
ggml/src/ggml-metal/ggml-metal-impl.h
ggml/src/ggml-metal/ggml-metal-ops.cpp
ggml/src/ggml-metal/ggml-metal-ops.h
ggml/src/ggml-metal/ggml-metal.metal
tests/test-backend-ops.cpp

index 866cd2da585761d14bef55bbac68e2348efe57e9..75811634227b352987f8b75f37836d52ea79f8cd 100644 (file)
@@ -1406,6 +1406,31 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_met
     return res;
 }
 
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d(ggml_metal_library_t lib, const ggml_tensor * op) {
+    assert(op->op == GGML_OP_CONV_TRANSPOSE_2D);
+
+    GGML_ASSERT(ggml_is_contiguous(op->src[0]));
+    GGML_ASSERT(ggml_is_contiguous(op->src[1]));
+    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
+    GGML_ASSERT(op->type         == GGML_TYPE_F32);
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, 256, "kernel_conv_transpose_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
+    snprintf(name, 256, "%s", base);
+
+    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+
+    return res;
+}
+
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) {
     assert(op->op == GGML_OP_UPSCALE);
 
index 28ae2e1765146f7e450b75b27a2763f88abec2da..4d58297481813689681aebc15ef1587b3d7bf3cd 100644 (file)
@@ -130,6 +130,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm              (ggml_me
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope              (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col            (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale           (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad               (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d    (ggml_metal_library_t lib, const struct ggml_tensor * op);
index c3c83abe4e63e55fd214281d0196ce73a32972a9..360fbe19f0fb68131cc6725b039594c21b65a440 100644 (file)
@@ -653,6 +653,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_SCALE:
         case GGML_OP_CONV_TRANSPOSE_1D:
             return true;
+        case GGML_OP_CONV_TRANSPOSE_2D:
+            return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) &&
+                (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) &&
+                op->src[1]->type == GGML_TYPE_F32 &&
+                op->type == GGML_TYPE_F32;
         case GGML_OP_CLAMP:
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_SQR:
index fa2d82cefb40eb366fcfc66a74c0ae250fb2754f..96f43d260a3c33dc539b4a6e16f259f3b66e2bf3 100644 (file)
@@ -514,6 +514,19 @@ typedef struct {
     uint64_t nb1;
 } ggml_metal_kargs_conv_transpose_1d;
 
+typedef struct {
+    int32_t  IC;
+    int32_t  IH;
+    int32_t  IW;
+    int32_t  KH;
+    int32_t  KW;
+    int32_t  OC;
+    int32_t  s0;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+} ggml_metal_kargs_conv_transpose_2d;
+
 typedef struct {
     uint64_t  ofs0;
     uint64_t  ofs1;
index 4f9f6bda00a7901f21b206f308e51d5a9b7c6652..7a85edbdcdb84385f4903f7d8e519988faf1e630 100644 (file)
@@ -368,6 +368,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_conv_transpose_1d(ctx, idx);
             } break;
+        case GGML_OP_CONV_TRANSPOSE_2D:
+            {
+                n_fuse = ggml_metal_op_conv_transpose_2d(ctx, idx);
+            } break;
         case GGML_OP_UPSCALE:
             {
                 n_fuse = ggml_metal_op_upscale(ctx, idx);
@@ -3118,6 +3122,62 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
+int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
+    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+
+    const int32_t s0 = ((const int32_t *)(op->op_params))[0];
+
+    const int32_t IC = op->src[1]->ne[2];
+    const int32_t IH = op->src[1]->ne[1];
+    const int32_t IW = op->src[1]->ne[0];
+
+    const int32_t KH = op->src[0]->ne[1];
+    const int32_t KW = op->src[0]->ne[0];
+
+    const int32_t OW = op->ne[0];
+    const int32_t OH = op->ne[1];
+    const int32_t OC = op->ne[2];
+
+    ggml_metal_kargs_conv_transpose_2d args = {
+        /*.IC  =*/ IC,
+        /*.IH  =*/ IH,
+        /*.IW  =*/ IW,
+        /*.KH  =*/ KH,
+        /*.KW  =*/ KW,
+        /*.OC  =*/ OC,
+        /*.s0  =*/ s0,
+        /*.nb0 =*/ nb0,
+        /*.nb1 =*/ nb1,
+        /*.nb2 =*/ nb2,
+    };
+
+    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op);
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         3);
+
+    // Metal requires buffer size to be multiple of 16 bytes
+    const size_t smem = GGML_PAD(KW * KH * sizeof(float), 16);
+    ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, OW, OH, OC, KW, KH, 1);
+
+    return 1;
+}
+
 int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
 
index f352738698beb9b00dcf09ea4080be41ffad569c..0d9cb8af7c1d0aedc64c9ea626da8160578c7088 100644 (file)
@@ -71,6 +71,7 @@ int ggml_metal_op_norm              (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_rope              (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_im2col            (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_conv_transpose_1d (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_conv_transpose_2d (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_upscale           (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_pad               (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_pad_reflect_1d    (ggml_metal_op_t ctx, int idx);
index 496610b154b6d7642be0e0d151c0cc4f7af872d0..2c2f0141514ca252a86f8b02818f90620015d887 100644 (file)
@@ -4179,6 +4179,97 @@ kernel void kernel_conv_transpose_1d<half>(
     uint3   tgpig[[threadgroup_position_in_grid]],
     uint3    tgpg[[threadgroups_per_grid]]);
 
+
+typedef void (conv_transpose_2d_t)(
+        constant ggml_metal_kargs_conv_transpose_2d & args,
+        device const float * src0,
+        device const float * src1,
+        device        char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        uint3    tgpg[[threadgroups_per_grid]]);
+
+template <typename T>
+kernel void kernel_conv_transpose_2d(
+        constant ggml_metal_kargs_conv_transpose_2d & args,
+        device const T * src0,
+        device const float * src1,
+        device        char * dst,
+        threadgroup float * shared_sum [[threadgroup(0)]],
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        uint3   tpitg[[thread_position_in_threadgroup]],
+        uint3     ntg[[threads_per_threadgroup]]) {
+
+    const int64_t out_x = tgpig[0];
+    const int64_t out_y = tgpig[1];
+    const int64_t out_c = tgpig[2];
+
+    const int64_t kw = tpitg[0];
+    const int64_t kh = tpitg[1];
+
+    float v = 0.0f;
+
+    for (int64_t in_c = 0; in_c < args.IC; in_c++) {
+        int64_t in_y = out_y - kh;
+
+        if (in_y < 0 || in_y % args.s0) continue;
+
+        in_y /= args.s0;
+
+        if (in_y >= args.IH) continue;
+
+        int64_t in_x = out_x - kw;
+
+        if (in_x < 0 || in_x % args.s0) continue;
+
+        in_x /= args.s0;
+
+        if (in_x >= args.IW) continue;
+
+        const int64_t input_idx = (args.IW * args.IH) * in_c + (args.IW) * in_y + in_x;
+        const int64_t kernel_idx = (args.KH * args.KW * args.OC) * in_c + (args.KH * args.KW) * out_c + (args.KW) * kh + kw;
+
+        v += (float)src0[kernel_idx] * src1[input_idx];
+    }
+
+    const uint tid = tpitg.y * ntg.x + tpitg.x;
+    shared_sum[tid] = v;
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    if (tid == 0) {
+        float total = 0.0f;
+        const uint num_threads = ntg.x * ntg.y;
+        for (uint i = 0; i < num_threads; i++) {
+            total += shared_sum[i];
+        }
+
+        device float * dst_ptr = (device float *) (dst + out_x*args.nb0 + out_y * args.nb1 + out_c*args.nb2);
+        dst_ptr[0] = total;
+    }
+}
+
+template [[host_name("kernel_conv_transpose_2d_f32_f32")]]
+kernel void kernel_conv_transpose_2d<float>(
+    constant ggml_metal_kargs_conv_transpose_2d & args,
+    device const float * src0,
+    device const float * src1,
+    device        char * dst,
+    threadgroup float * shared_sum [[threadgroup(0)]],
+    uint3   tgpig[[threadgroup_position_in_grid]],
+    uint3   tpitg[[thread_position_in_threadgroup]],
+    uint3     ntg[[threads_per_threadgroup]]);
+
+template [[host_name("kernel_conv_transpose_2d_f16_f32")]]
+kernel void kernel_conv_transpose_2d<half>(
+    constant ggml_metal_kargs_conv_transpose_2d & args,
+    device const half  * src0,
+    device const float * src1,
+    device        char * dst,
+    threadgroup float * shared_sum [[threadgroup(0)]],
+    uint3   tgpig[[threadgroup_position_in_grid]],
+    uint3   tpitg[[thread_position_in_threadgroup]],
+    uint3     ntg[[threads_per_threadgroup]]);
+
 kernel void kernel_upscale_f32(
     constant ggml_metal_kargs_upscale & args,
     device  const char * src0,
index d5c5a2a6656eec87b309f6825517828d1d2a80a4..82bb55ea0e18463517dc2e7b3e3f0bc5534fa755 100644 (file)
@@ -6989,6 +6989,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, true));
 
     test_cases.emplace_back(new test_conv_transpose_2d({256, 256, 256, 1}, {3, 3, 16, 256}, 1));
+    test_cases.emplace_back(new test_conv_transpose_2d({16, 16, 16, 1}, {3, 3, 8, 16}, 1));
+    test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2));
 
     test_cases.emplace_back(new test_mean(GGML_TYPE_F32, {256, 256, 3, 1}));