]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
metal : restore im2col perf (llama/16219)
authorGeorgi Gerganov <redacted>
Thu, 25 Sep 2025 08:29:08 +0000 (11:29 +0300)
committerGeorgi Gerganov <redacted>
Mon, 29 Sep 2025 12:18:10 +0000 (15:18 +0300)
ggml/src/ggml-metal/ggml-metal-device.cpp
ggml/src/ggml-metal/ggml-metal-ops.cpp
ggml/src/ggml-metal/ggml-metal.metal

index 9f91662cbd876d447b02eae2e968d6178748e6b5..8aeefd2c6830f14eea7d42f51815dd052dc6348d 100644 (file)
@@ -1237,7 +1237,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col(ggml_metal_library_
     char base[256];
     char name[256];
 
-    snprintf(base, 256, "kernel_im2col_ext_%s", ggml_type_name(op->type));
+    snprintf(base, 256, "kernel_im2col_%s", ggml_type_name(op->type));
     snprintf(name, 256, "%s", base);
 
     ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
index 3b163d9a38e75afbc9ab324650181264894401e5..15cea725135365496cd082cd3715b3452bb2d66b 100644 (file)
@@ -2768,7 +2768,6 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
     const uint64_t ofs0 = op->src[1]->nb[is_2D ? 3 : 2] / 4;
     const uint64_t ofs1 = op->src[1]->nb[is_2D ? 2 : 1] / 4;
 
-
     ggml_metal_kargs_im2col args = {
         /*.ofs0 =*/ ofs0,
         /*.ofs1 =*/ ofs1,
@@ -2789,15 +2788,16 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
 
     ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_im2col(lib, op);
 
-    const uint64_t n_threads = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), N);
-    const int64_t  quotient  = N / n_threads + (N % n_threads > 0 ? 1 : 0);
+    GGML_ASSERT(KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+
+    const uint64_t ntptg0 = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)/(KH*KW), N);
 
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 1);
     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
 
-    ggml_metal_encoder_dispatch_threadgroups(enc, quotient * CHW, OH, OW, n_threads, 1, 1);
+    ggml_metal_encoder_dispatch_threadgroups(enc, IC, OH, OW, ntptg0, KH, KW);
 
     return 1;
 }
index 2ba4cb50b9ec224121caa50410f599bc69f6ea18..339cbf91fbe5797ac140fa304fd4f13a6da2c41e 100644 (file)
@@ -3987,8 +3987,72 @@ template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kerne
 template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
 template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
 
+typedef void (im2col_t)(
+        constant ggml_metal_kargs_im2col & args,
+        device const float * x,
+        device        char * dst,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3  tgpg[[threadgroups_per_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]);
+
+template <typename T>
+kernel void kernel_im2col(
+        constant ggml_metal_kargs_im2col & args,
+        device const float * x,
+        device        char * dst,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3  tgpg[[threadgroups_per_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+//    const int64_t IC = tgpg[0];
+    const int64_t OH = tgpg[1];
+    const int64_t OW = tgpg[2];
+
+    const int64_t KH = ntg[1];
+    const int64_t KW = ntg[2];
+
+          int64_t in  = tpitg[0];
+    const int64_t ikh = tpitg[1];
+    const int64_t ikw = tpitg[2];
+
+    const int64_t iic = tgpig[0];
+    const int64_t ioh = tgpig[1];
+    const int64_t iow = tgpig[2];
+
+    const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0;
+    const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1;
+
+    int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw);
+
+    device T * pdst = (device T *) (dst);
+
+    if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
+        while (in < args.N) {
+            pdst[offset_dst] = 0.0f;
+            offset_dst += ntg[0]*args.CHW*OH*OW;
+
+            in += ntg[0];
+        }
+    } else {
+        int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw;
+
+        while (in < args.N) {
+            pdst[offset_dst] = x[offset_src];
+
+            offset_dst += ntg[0]*args.CHW*OH*OW;
+            offset_src += ntg[0]*args.ofs0;
+
+            in += ntg[0];
+        }
+    }
+}
+
+template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
+template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
+
 // TODO: obolete -- remove
-//typedef void (im2col_t)(
+//typedef void (im2col_ext_t)(
 //        constant ggml_metal_kargs_im2col & args,
 //        device const float * x,
 //        device        char * dst,
@@ -3998,100 +4062,48 @@ template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t ker
 //        uint3   ntg[[threads_per_threadgroup]]);
 //
 //template <typename T>
-//kernel void kernel_im2col(
+//kernel void kernel_im2col_ext(
 //        constant ggml_metal_kargs_im2col & args,
 //        device const float * x,
 //        device        char * dst,
 //        uint3 tgpig[[threadgroup_position_in_grid]],
-//        uint3  tgpg[[threadgroups_per_grid]],
+//        uint3  tgpg[[threadgroups_per_grid]],      // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
 //        uint3 tpitg[[thread_position_in_threadgroup]],
-//        uint3   ntg[[threads_per_threadgroup]]) {
-////    const int64_t IC = tgpg[0];
-//    const int64_t OH = tgpg[1];
-//    const int64_t OW = tgpg[2];
+//        uint3   ntg[[threads_per_threadgroup]]) {  // [M, 1, 1]
+//    const int64_t KHW = (int64_t)args.KHW;
 //
-////    const int64_t N  = ntg[0];
-//    const int64_t KH = ntg[1];
-//    const int64_t KW = ntg[2];
+//    const int64_t d   = tgpig[0] / args.CHW;
+//    const int64_t chw = tgpig[0] % args.CHW;
+//    const int64_t tgpig_0 = chw / KHW;  // 0 ~ (IC - 1)
+//    const int64_t HW = tgpig[0] % KHW;
 //
-//    const int64_t in  = tpitg[0];
-//    const int64_t ikh = tpitg[1];
-//    const int64_t ikw = tpitg[2];
+//    const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
+//    if (tpitg_0 >= args.N) {
+//        return;
+//    }
 //
-//    const int64_t iic = tgpig[0];
-//    const int64_t ioh = tgpig[1];
-//    const int64_t iow = tgpig[2];
+//    const int64_t tpitg_1 = HW / args.KW;
+//    const int64_t tpitg_2 = HW % args.KW;
 //
-//    const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0;
-//    const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1;
+//    const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;
+//    const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;
 //
-//    const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw);
+//    const int64_t offset_dst =
+//        (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
+//        (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
 //
 //    device T * pdst = (device T *) (dst);
 //
 //    if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
 //        pdst[offset_dst] = 0.0f;
 //    } else {
-//        const int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw;
-//        pdst[offset_dst] = x[offset_src];
+//        const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1;
+//        pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
 //    }
 //}
 //
-//template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
-//template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
-
-typedef void (im2col_ext_t)(
-        constant ggml_metal_kargs_im2col & args,
-        device const float * x,
-        device        char * dst,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3  tgpg[[threadgroups_per_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]);
-
-template <typename T>
-kernel void kernel_im2col_ext(
-        constant ggml_metal_kargs_im2col & args,
-        device const float * x,
-        device        char * dst,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3  tgpg[[threadgroups_per_grid]],      // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {  // [M, 1, 1]
-    const int64_t KHW = (int64_t)args.KHW;
-
-    const int64_t d   = tgpig[0] / args.CHW;
-    const int64_t chw = tgpig[0] % args.CHW;
-    const int64_t tgpig_0 = chw / KHW;  // 0 ~ (IC - 1)
-    const int64_t HW = tgpig[0] % KHW;
-
-    const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
-    if (tpitg_0 >= args.N) {
-        return;
-    }
-
-    const int64_t tpitg_1 = HW / args.KW;
-    const int64_t tpitg_2 = HW % args.KW;
-
-    const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;
-    const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;
-
-    const int64_t offset_dst =
-        (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
-        (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
-
-    device T * pdst = (device T *) (dst);
-
-    if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
-        pdst[offset_dst] = 0.0f;
-    } else {
-        const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1;
-        pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
-    }
-}
-
-template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
-template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
+//template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
+//template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
 
 typedef void (conv_transpose_1d_t)(
         constant ggml_metal_kargs_conv_transpose_1d & args,