]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
metal : add POOL2D and fix IM2COL (llama/9943)
authorJun Hee Yoo <redacted>
Wed, 23 Oct 2024 10:33:45 +0000 (19:33 +0900)
committerGeorgi Gerganov <redacted>
Fri, 1 Nov 2024 08:19:05 +0000 (10:19 +0200)
* add pool_2d

Signed-off-by: Junhee Yoo <redacted>
* fix im2col and add unittest for N>=1024

Signed-off-by: Junhee Yoo <redacted>
* add tests for N % 1024 != 0

Signed-off-by: Junhee Yoo <redacted>
* remove trailing whitespaces

Signed-off-by: Junhee Yoo <redacted>
* apply suggestions

Signed-off-by: Junhee Yoo <redacted>
* apply more optimization

- original IM2COL kernel + _ext with MIN()

Signed-off-by: Junhee Yoo <redacted>
* apply review: change kernel name of pool_2d

Signed-off-by: Junhee Yoo <redacted>
* apply review

Signed-off-by: Junhee Yoo <redacted>
* fix more formatting and enhance readability

Signed-off-by: Junhee Yoo <redacted>
---------

Signed-off-by: Junhee Yoo <redacted>
ggml/src/ggml-metal.m
ggml/src/ggml-metal.metal

index 172a0f925d316968d695160cc5dfede12374dd98..e9541441c8f547af7eb5ec4d765c357296c3c9c7 100644 (file)
@@ -241,6 +241,8 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_F32,
+    GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
+    GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,
     GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
     GGML_METAL_KERNEL_TYPE_PAD_F32,
     GGML_METAL_KERNEL_TYPE_ARANGE_F32,
@@ -272,6 +274,8 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_SIN,
     GGML_METAL_KERNEL_TYPE_COS,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
+    GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
+    GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
 
     GGML_METAL_KERNEL_TYPE_COUNT
 };
@@ -685,6 +689,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,                 rope_neox_f16,                  true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16,                    im2col_f16,                     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32,                    im2col_f32,                     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,                im2col_ext_f16,                 true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,                im2col_ext_f32,                 true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32,                   upscale_f32,                    true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32,                       pad_f32,                        true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,        timestep_embedding_f32,         true);
@@ -716,6 +722,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN,                           sin,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS,                           cos,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,                      sum_rows,                       true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,               pool_2d_avg_f32,                true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,               pool_2d_max_f32,                true);
     }
 
     [metal_library release];
@@ -844,8 +852,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_IM2COL:
             return op->src[0]->type == GGML_TYPE_F16;
         case GGML_OP_POOL_1D:
-        case GGML_OP_POOL_2D:
             return false;
+        case GGML_OP_POOL_2D:
         case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
         case GGML_OP_ARANGE:
@@ -2545,6 +2553,8 @@ static void ggml_metal_encode_node(
             } break;
         case GGML_OP_IM2COL:
             {
+                GGML_ASSERT(ggml_is_contiguous(src0));
+                GGML_ASSERT(ggml_is_contiguous(src1));
                 GGML_ASSERT(src0->type == GGML_TYPE_F16);
                 GGML_ASSERT(src1->type == GGML_TYPE_F32);
                 GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
@@ -2574,30 +2584,54 @@ static void ggml_metal_encode_node(
                 const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
                 const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
 
-                id<MTLComputePipelineState> pipeline = nil;
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline;
+
+                const bool is_gt_mttpt = ((size_t)(N * KH * KW)) > pipeline.maxTotalThreadsPerThreadgroup;
 
                 switch (dst->type) {
-                    case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break;
-                    case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
+                    case GGML_TYPE_F32: {
+                        pipeline = (is_gt_mttpt ?
+                                    ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32].pipeline
+                                    :
+                                    ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline);
+                    } break;
+                    case GGML_TYPE_F16: {
+                        pipeline = (is_gt_mttpt ?
+                                    ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16].pipeline
+                                    :
+                                    ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline);
+                    } break;
                     default: GGML_ABORT("fatal error");
                 };
 
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src1 offset:offs_src1        atIndex:0];
-                [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
-                [encoder setBytes:&ofs0    length:sizeof( int32_t) atIndex:2];
-                [encoder setBytes:&ofs1    length:sizeof( int32_t) atIndex:3];
-                [encoder setBytes:&IW      length:sizeof( int32_t) atIndex:4];
-                [encoder setBytes:&IH      length:sizeof( int32_t) atIndex:5];
-                [encoder setBytes:&CHW     length:sizeof( int32_t) atIndex:6];
-                [encoder setBytes:&s0      length:sizeof( int32_t) atIndex:7];
-                [encoder setBytes:&s1      length:sizeof( int32_t) atIndex:8];
-                [encoder setBytes:&p0      length:sizeof( int32_t) atIndex:9];
-                [encoder setBytes:&p1      length:sizeof( int32_t) atIndex:10];
-                [encoder setBytes:&d0      length:sizeof( int32_t) atIndex:11];
-                [encoder setBytes:&d1      length:sizeof( int32_t) atIndex:12];
-
-                [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
+                [encoder setBuffer:id_src1 offset:offs_src1       atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst        atIndex:1];
+                [encoder setBytes:&ofs0    length:sizeof(int32_t) atIndex:2];
+                [encoder setBytes:&ofs1    length:sizeof(int32_t) atIndex:3];
+                [encoder setBytes:&IW      length:sizeof(int32_t) atIndex:4];
+                [encoder setBytes:&IH      length:sizeof(int32_t) atIndex:5];
+                [encoder setBytes:&CHW     length:sizeof(int32_t) atIndex:6];
+                [encoder setBytes:&s0      length:sizeof(int32_t) atIndex:7];
+                [encoder setBytes:&s1      length:sizeof(int32_t) atIndex:8];
+                [encoder setBytes:&p0      length:sizeof(int32_t) atIndex:9];
+                [encoder setBytes:&p1      length:sizeof(int32_t) atIndex:10];
+                [encoder setBytes:&d0      length:sizeof(int32_t) atIndex:11];
+                [encoder setBytes:&d1      length:sizeof(int32_t) atIndex:12];
+
+                if (is_gt_mttpt) {
+                    [encoder setBytes:&N   length:sizeof(int32_t) atIndex:13];
+                    [encoder setBytes:&KH  length:sizeof(int32_t) atIndex:14];
+                    [encoder setBytes:&KW  length:sizeof(int32_t) atIndex:15];
+
+                    const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N);
+
+                    const int64_t  quotient  = N / n_threads + (N % n_threads > 0 ? 1 : 0);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(quotient * CHW, OH, OW) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
+                } else {
+                    [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
+                }
             } break;
         case GGML_OP_UPSCALE:
             {
@@ -3001,6 +3035,64 @@ static void ggml_metal_encode_node(
 
                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
+        case GGML_OP_POOL_2D:
+            {
+                GGML_ASSERT(ggml_is_contiguous(src0));
+                GGML_ASSERT(src0t == GGML_TYPE_F32 && src0t == dstt);
+
+                const int32_t * opts = dst->op_params;
+                enum ggml_op_pool op = opts[0];
+
+                id<MTLComputePipelineState> pipeline = nil;
+                switch (src0t) {
+                    case GGML_TYPE_F32: {
+                        switch(op) {
+                            case GGML_OP_POOL_AVG:
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32].pipeline; break;
+                            case GGML_OP_POOL_MAX:
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32].pipeline; break;
+                            default: GGML_ASSERT(false && "not implemented");
+                        }
+                    } break;
+                    default: GGML_ASSERT(false && "not implemented");
+                }
+
+                const int32_t k0 = opts[1];
+                const int32_t k1 = opts[2];
+                const int32_t s0 = opts[3];
+                const int32_t s1 = opts[4];
+                const int32_t p0 = opts[5];
+                const int32_t p1 = opts[6];
+
+                const int64_t IH = src0->ne[1];
+                const int64_t IW = src0->ne[0];
+
+                const int64_t N  = dst->ne[3];
+                const int64_t OC = dst->ne[2];
+                const int64_t OH = dst->ne[1];
+                const int64_t OW = dst->ne[0];
+
+                const int64_t parallel_elements = N * OC * OH * OW;
+                const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
+                const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0       atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst        atIndex:1];
+                [encoder setBytes:&k0      length:sizeof(int32_t) atIndex:2];
+                [encoder setBytes:&k1      length:sizeof(int32_t) atIndex:3];
+                [encoder setBytes:&s0      length:sizeof(int32_t) atIndex:4];
+                [encoder setBytes:&s1      length:sizeof(int32_t) atIndex:5];
+                [encoder setBytes:&p0      length:sizeof(int32_t) atIndex:6];
+                [encoder setBytes:&p1      length:sizeof(int32_t) atIndex:7];
+                [encoder setBytes:&IH      length:sizeof(int64_t) atIndex:8];
+                [encoder setBytes:&IW      length:sizeof(int64_t) atIndex:9];
+                [encoder setBytes:&OH      length:sizeof(int64_t) atIndex:10];
+                [encoder setBytes:&OW      length:sizeof(int64_t) atIndex:11];
+                [encoder setBytes:&parallel_elements length:sizeof(int64_t) atIndex:12];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
+            } break;
        default:
             {
                 GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
index 2b200032394b1f041f647ec519ffc3a7fc9e7f98..71b58be1fd8a4c440879c68aa52f1b996db43f14 100644 (file)
@@ -1933,6 +1933,85 @@ kernel void kernel_im2col(
 template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
 template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
 
+typedef void (im2col_ext_t)(
+        device const float * x,
+        device        char * dst,
+        constant   int32_t & ofs0,
+        constant   int32_t & ofs1,
+        constant   int32_t & IW,
+        constant   int32_t & IH,
+        constant   int32_t & CHW,
+        constant   int32_t & s0,
+        constant   int32_t & s1,
+        constant   int32_t & p0,
+        constant   int32_t & p1,
+        constant   int32_t & d0,
+        constant   int32_t & d1,
+        constant   int32_t & N,
+        constant   int32_t & KH,
+        constant   int32_t & KW,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3  tgpg[[threadgroups_per_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]);
+
+template <typename T>
+kernel void kernel_im2col_ext(
+        device const float * x,
+        device        char * dst,
+        constant   int32_t & ofs0,
+        constant   int32_t & ofs1,
+        constant   int32_t & IW,
+        constant   int32_t & IH,
+        constant   int32_t & CHW,
+        constant   int32_t & s0,
+        constant   int32_t & s1,
+        constant   int32_t & p0,
+        constant   int32_t & p1,
+        constant   int32_t & d0,
+        constant   int32_t & d1,
+        constant   int32_t & N,
+        constant   int32_t & KH,
+        constant   int32_t & KW,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3  tgpg[[threadgroups_per_grid]],      // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {  // [M, 1, 1]
+    const int32_t KHW = KH * KW;             // KHW == ntg[1] * ntg[2], KW == ntg[2]
+
+    const int32_t d = tgpig[0] / CHW;
+    const int32_t chw = tgpig[0] % CHW;
+    const int32_t tgpig_0 = chw / KHW;  // 0 ~ (IC - 1)
+    const int32_t HW = tgpig[0] % KHW;
+
+    const int32_t tpitg_0 = (d * ntg[0]) + tpitg[0];
+    if (tpitg_0 >= N) {
+        return;
+    }
+
+    const int32_t tpitg_1 = HW / KW;
+    const int32_t tpitg_2 = HW % KW;
+
+    const int32_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0;
+    const int32_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1;
+
+    const int32_t offset_dst =
+        (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
+        (tgpig_0 * KHW + tpitg_1 * KW + tpitg_2);
+
+    device T * pdst = (device T *) (dst);
+
+    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+        pdst[offset_dst] = 0.0f;
+    } else {
+        const int32_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1;
+        pdst[offset_dst] = x[offset_src + iih * IW + iiw];
+    }
+}
+
+template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
+template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
+
 kernel void kernel_upscale_f32(
     device  const char * src0,
     device        char * dst,
@@ -6372,3 +6451,102 @@ template [[host_name("kernel_mul_mv_id_iq3_s_f32")]]   kernel kernel_mul_mv_id_t
 template [[host_name("kernel_mul_mv_id_iq2_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
 template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
 template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
+
+kernel void kernel_pool_2d_max_f32(
+        device  const float * src0,
+        device        float * dst,
+        constant    int32_t & k0,
+        constant    int32_t & k1,
+        constant    int32_t & s0,
+        constant    int32_t & s1,
+        constant    int32_t & p0,
+        constant    int32_t & p1,
+        constant    int64_t & IH,
+        constant    int64_t & IW,
+        constant    int64_t & OH,
+        constant    int64_t & OW,
+        constant    int64_t & parallel_elements,
+        uint        gid[[thread_position_in_grid]]) {
+
+    if (gid >= parallel_elements) {
+        return;
+    }
+
+    const int idx = gid;
+    const int I_HW = IH * IW;
+    const int O_HW = OH * OW;
+    const int nc = idx / O_HW;
+    const int cur_oh = idx % O_HW / OW;
+    const int cur_ow = idx % O_HW % OW;
+
+    device const float * i_ptr = src0 + nc * I_HW;
+    device       float * o_ptr = dst  + nc * O_HW;
+
+    const int start_h = cur_oh * s1 - p1;
+    const int bh = MAX(0,  start_h);
+    const int eh = MIN(IH, start_h + k1);
+    const int start_w = cur_ow * s0 - p0;
+    const int bw = MAX(0,  start_w);
+    const int ew = MIN(IW, start_w + k0);
+
+    float res = -INFINITY;
+
+    for (int i = bh; i < eh; i += 1) {
+        for (int j = bw; j < ew; j += 1) {
+            res = MAX(res, i_ptr[i * IW + j]);
+        }
+    }
+
+    o_ptr[cur_oh * OW + cur_ow] = res;
+}
+
+kernel void kernel_pool_2d_avg_f32(
+        device  const float * src0,
+        device        float * dst,
+        constant    int32_t & k0,
+        constant    int32_t & k1,
+        constant    int32_t & s0,
+        constant    int32_t & s1,
+        constant    int32_t & p0,
+        constant    int32_t & p1,
+        constant    int64_t & IH,
+        constant    int64_t & IW,
+        constant    int64_t & OH,
+        constant    int64_t & OW,
+        constant    int64_t & parallel_elements,
+        uint        gid[[thread_position_in_grid]]) {
+
+    if (gid >= parallel_elements) {
+        return;
+    }
+
+    const int idx = gid;
+    const int I_HW = IH * IW;
+    const int O_HW = OH * OW;
+    const int nc = idx / O_HW;
+    const int cur_oh = idx % O_HW / OW;
+    const int cur_ow = idx % O_HW % OW;
+
+    device const float * i_ptr = src0 + nc * I_HW;
+    device       float * o_ptr = dst  + nc * O_HW;
+
+    const int start_h = cur_oh * s1 - p1;
+    const int bh = MAX(0,  start_h);
+    const int eh = MIN(IH, start_h + k1);
+    const int start_w = cur_ow * s0 - p0;
+    const int bw = MAX(0,  start_w);
+    const int ew = MIN(IW, start_w + k0);
+    // const float scale = 1. / ((eh - bh) * (ew - bw));
+    const float scale = 1. / (k0 * k1);
+
+    float res = 0;
+
+    for (int i = bh; i < eh; i += 1) {
+        for (int j = bw; j < ew; j += 1) {
+            float cur = i_ptr[i * IW + j];
+            res += cur * scale;
+        }
+    }
+
+    o_ptr[cur_oh * OW + cur_ow] = res;
+}