]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : faster ggml_conv_2d using 2-stage op (#483)
authorleejet <redacted>
Mon, 9 Oct 2023 15:18:47 +0000 (23:18 +0800)
committerGitHub <redacted>
Mon, 9 Oct 2023 15:18:47 +0000 (18:18 +0300)
* ggml : fix ggm_conv_2d impl

* ggml : make ggml_conv_2d a little faster

* ggml : reorganize ggml_conv_2d code

* ggml : make ggml_conv_2d faster

* use int64_t in conv_2d stage 0

* ggml : add TODO about im2col

---------

Co-authored-by: Georgi Gerganov <redacted>
include/ggml/ggml.h
src/ggml.c

index 3eddc44b90fdd35397f24409382f7428bf03b0cc..4b16032f02a8160e10608bbfdd174db195fb98d7 100644 (file)
@@ -400,15 +400,16 @@ extern "C" {
         GGML_OP_ALIBI,
         GGML_OP_CLAMP,
         GGML_OP_CONV_1D,
-        GGML_OP_CONV_2D,
+        GGML_OP_CONV_1D_STAGE_0,  // internal
+        GGML_OP_CONV_1D_STAGE_1,  // internal
         GGML_OP_CONV_TRANSPOSE_1D,
+        GGML_OP_CONV_2D,
+        GGML_OP_CONV_2D_STAGE_0, // internal
+        GGML_OP_CONV_2D_STAGE_1, // internal
         GGML_OP_CONV_TRANSPOSE_2D,
         GGML_OP_POOL_1D,
         GGML_OP_POOL_2D,
 
-        GGML_OP_CONV_1D_STAGE_0,  // internal
-        GGML_OP_CONV_1D_STAGE_1,  // internal
-
         GGML_OP_UPSCALE, // nearest interpolate
 
         GGML_OP_FLASH_ATTN,
@@ -1016,9 +1017,9 @@ extern "C" {
             struct ggml_tensor  * b,
             float                 eps);
 
-    // A: n columns, m rows
-    // B: n columns, p rows  (i.e. we transpose it internally)
-    // result is m columns, p rows
+    // A: k columns, n rows => [ne03, ne02, n, k]
+    // B: k columns, m rows  (i.e. we transpose it internally) => [ne03 * x, ne02 * y, m, k]
+    // result is n columns, m rows => [ne03 * x, ne02 * y, m, n]
     GGML_API struct ggml_tensor * ggml_mul_mat(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
index 6d1776ca46741f2d436a0c667e4609dd04c28a38..2d3c7b8018503d9c1c4b3ee2b0faf3dcdb17ab28 100644 (file)
@@ -571,7 +571,6 @@ int64_t ggml_cycles_per_ms(void) {
 #define ggml_perf_cycles_per_ms() 0
 #endif
 
-
 //
 // cache line
 //
@@ -1828,7 +1827,6 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
     return type_traits[type];
 }
 
-
 //
 // simd mappings
 //
@@ -4057,16 +4055,17 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "ALIBI",
     "CLAMP",
     "CONV_1D",
+    "CONV_1D_STAGE_0",
+    "CONV_1D_STAGE_1",
     "CONV_TRANSPOSE_1D",
     "CONV_2D",
+    "CONV_2D_STAGE_0",
+    "CONV_2D_STAGE_1",
     "CONV_TRANSPOSE_2D",
     "POOL_1D",
     "POOL_2D",
     "UPSCALE",
 
-    "CONV_1D_STAGE_0",
-    "CONV_1D_STAGE_1",
-
     "FLASH_ATTN",
     "FLASH_FF",
     "FLASH_ATTN_BACK",
@@ -4092,7 +4091,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 71, "GGML_OP_COUNT != 71");
+static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -4143,16 +4142,17 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "alibi(x)",
     "clamp(x)",
     "conv_1d(x)",
+    "conv_1d_stage_0(x)",
+    "conv_1d_stage_1(x)",
     "conv_transpose_1d(x)",
     "conv_2d(x)",
+    "conv_2d_stage_0(x)",
+    "conv_2d_stage_1(x)",
     "conv_transpose_2d(x)",
     "pool_1d(x)",
     "pool_2d(x)",
     "upscale(x)",
 
-    "conv_1d_stage_0(x)",
-    "conv_1d_stage_1(x)",
-
     "flash_attn(x)",
     "flash_ff(x)",
     "flash_attn_back(x)",
@@ -4178,7 +4178,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 71, "GGML_OP_COUNT != 71");
+static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -4209,8 +4209,10 @@ static void ggml_setup_op_has_task_pass(void) {
         p[GGML_OP_CONV_1D                ] = true;
         p[GGML_OP_CONV_1D_STAGE_0        ] = true;
         p[GGML_OP_CONV_1D_STAGE_1        ] = true;
-        p[GGML_OP_CONV_2D                ] = true;
         p[GGML_OP_CONV_TRANSPOSE_1D      ] = true;
+        p[GGML_OP_CONV_2D                ] = true;
+        p[GGML_OP_CONV_2D_STAGE_0        ] = true;
+        p[GGML_OP_CONV_2D_STAGE_1        ] = true;
         p[GGML_OP_CONV_TRANSPOSE_2D      ] = true;
         p[GGML_OP_FLASH_ATTN_BACK        ] = true;
         p[GGML_OP_CROSS_ENTROPY_LOSS     ] = true;
@@ -5921,7 +5923,6 @@ struct ggml_tensor * ggml_sqrt_inplace(
     return ggml_sqrt_impl(ctx, a, true);
 }
 
-
 // ggml_log
 
 static struct ggml_tensor * ggml_log_impl(
@@ -5975,7 +5976,6 @@ struct ggml_tensor * ggml_sum(
     return result;
 }
 
-
 // ggml_sum_rows
 
 struct ggml_tensor * ggml_sum_rows(
@@ -6607,7 +6607,6 @@ struct ggml_tensor * ggml_set_2d_inplace(
     return ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, false);
 }
 
-
 // ggml_cpy
 
 static struct ggml_tensor * ggml_cpy_impl(
@@ -6687,7 +6686,6 @@ struct ggml_tensor * ggml_cont_inplace(
     return ggml_cont_impl(ctx, a, true);
 }
 
-
 // make contiguous, with new shape
 GGML_API struct ggml_tensor * ggml_cont_1d(
         struct ggml_context * ctx,
@@ -7140,7 +7138,6 @@ struct ggml_tensor * ggml_diag(
     return result;
 }
 
-
 // ggml_diag_mask_inf
 
 static struct ggml_tensor * ggml_diag_mask_inf_impl(
@@ -7252,7 +7249,6 @@ struct ggml_tensor * ggml_soft_max_inplace(
     return ggml_soft_max_impl(ctx, a, true);
 }
 
-
 // ggml_soft_max_back
 
 static struct ggml_tensor * ggml_soft_max_back_impl(
@@ -7669,7 +7665,11 @@ GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
 
 // ggml_conv_2d
 
-struct ggml_tensor * ggml_conv_2d(
+// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
+// a: [OC,IC, KH, KW]
+// b: [N, IC, IH, IW]
+// result: [N, OH, OW, IC*KH*KW]
+static struct ggml_tensor * ggml_conv_2d_stage_0(
     struct ggml_context * ctx,
     struct ggml_tensor  * a,
     struct ggml_tensor  * b,
@@ -7688,17 +7688,21 @@ struct ggml_tensor * ggml_conv_2d(
         is_node = true;
     }
 
+    const int64_t OH = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1);
+    const int64_t OW = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
+
     const int64_t ne[4] = {
-        ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0),
-        ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1),
-        a->ne[3], b->ne[3],
+        a->ne[2] * a->ne[1] * a->ne[0],
+        OW,
+        OH,
+        b->ne[3],
     };
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne);
 
     int32_t params[] = { s0, s1, p0, p1, d0, d1 };
     ggml_set_op_params(result, params, sizeof(params));
 
-    result->op = GGML_OP_CONV_2D;
+    result->op = GGML_OP_CONV_2D_STAGE_0;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = b;
@@ -7707,8 +7711,61 @@ struct ggml_tensor * ggml_conv_2d(
 
 }
 
-// ggml_conv_2d_sk_p0
+// gemm: [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW]
+// a: [OC, IC, KH, KW]
+// b: [N, OH, OW, IC * KH * KW]
+// result: [N, OC, OH, OW]
+static struct ggml_tensor * ggml_conv_2d_stage_1(
+    struct ggml_context * ctx,
+    struct ggml_tensor  * a,
+    struct ggml_tensor  * b) {
+
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    const int64_t ne[4] = {
+        b->ne[1],
+        b->ne[2],
+        a->ne[3],
+        b->ne[3],
+    };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    result->op = GGML_OP_CONV_2D_STAGE_1;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+
+}
+
+// a: [OC,IC, KH, KW]
+// b: [N, IC, IH, IW]
+// result: [N, OC, OH, OW]
+struct ggml_tensor * ggml_conv_2d(
+    struct ggml_context * ctx,
+    struct ggml_tensor  * a,
+    struct ggml_tensor  * b,
+    int                  s0,
+    int                  s1,
+    int                  p0,
+    int                  p1,
+    int                  d0,
+    int                  d1) {
+
+    struct ggml_tensor * result = ggml_conv_2d_stage_0(ctx, a, b, s0, s1, p0, p1, d0, d1); // [N, OH, OW, IC * KH * KW]
+    result = ggml_conv_2d_stage_1(ctx, a, result);
+
+    return result;
+
+}
 
+// ggml_conv_2d_sk_p0
 struct ggml_tensor * ggml_conv_2d_sk_p0(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
@@ -8147,7 +8204,6 @@ static struct ggml_tensor * ggml_add_rel_pos_impl(
     return result;
 }
 
-
 struct ggml_tensor * ggml_add_rel_pos(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
@@ -8592,8 +8648,6 @@ struct ggml_tensor * ggml_map_custom3_inplace(
     return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, true);
 }
 
-
-
 // ggml_cross_entropy_loss
 
 struct ggml_tensor * ggml_cross_entropy_loss(
@@ -9794,7 +9848,6 @@ static void ggml_compute_forward_add1(
     }
 }
 
-
 // ggml_compute_forward_acc
 
 static void ggml_compute_forward_acc_f32(
@@ -9934,7 +9987,6 @@ static void ggml_compute_forward_sub_f32(
             const int i2 = (ir - i3*ne2*ne1)/ne1;
             const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
 
-
 #ifdef GGML_USE_ACCELERATE
             vDSP_vsub(
                     (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
@@ -10115,7 +10167,6 @@ static void ggml_compute_forward_div_f32(
             const int i2 = (ir - i3*ne2*ne1)/ne1;
             const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
 
-
 #ifdef GGML_USE_ACCELERATE
             UNUSED(ggml_vec_div_f32);
 
@@ -10253,7 +10304,6 @@ static void ggml_compute_forward_sqrt(
     }
 }
 
-
 // ggml_compute_forward_log
 
 static void ggml_compute_forward_log_f32(
@@ -12086,7 +12136,6 @@ static void ggml_compute_forward_out_prod_f32(
         }
     }
 
-
     //int64_t t1 = ggml_perf_time_us();
     //static int64_t acc = 0;
     //acc += t1 - t0;
@@ -12282,7 +12331,6 @@ static void ggml_compute_forward_scale_f32(
 
     const size_t nb1 = dst->nb[1];
 
-
     for (int i1 = ir0; i1 < ir1; i1++) {
         if (dst->data != src0->data) {
             // src0 is same shape as dst => same indices
@@ -12680,7 +12728,6 @@ static void ggml_compute_forward_get_rows_back_f32(
     }
 }
 
-
 static void ggml_compute_forward_get_rows_back(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -13966,6 +14013,7 @@ static void ggml_compute_forward_conv_1d_f32(
     }
 }
 
+// TODO: reuse ggml_mul_mat or implement ggml_im2col and remove stage_0 and stage_1
 static void gemm_f16_out_f32(int64_t m, int64_t n, int64_t k,
                              ggml_fp16_t * A,
                              ggml_fp16_t * B,
@@ -14419,6 +14467,144 @@ static void ggml_compute_forward_conv_transpose_1d(
 
 // ggml_compute_forward_conv_2d
 
+// src0: kernel [OC, IC, KH, KW]
+// src1: image [N, IC, IH, IW]
+// dst:  result [N, OH, OW, IC*KH*KW]
+static void ggml_compute_forward_conv_2d_stage_0_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F16);
+
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    const int64_t N = ne13;
+    const int64_t IC = ne12;
+    const int64_t IH = ne11;
+    const int64_t IW = ne10;
+
+    // const int64_t OC = ne03;
+    // const int64_t IC = ne02;
+    const int64_t KH = ne01;
+    const int64_t KW = ne00;
+
+    const int64_t OH = ne2;
+    const int64_t OW = ne1;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
+    const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
+    const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
+    const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
+    const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
+
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    if (params->type == GGML_TASK_INIT) {
+        memset(dst->data, 0, ggml_nbytes(dst));
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
+    {
+        ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
+
+        for (int64_t in = 0; in < N; in++) {
+            for (int64_t ioh = 0; ioh < OH; ioh++) {
+                for (int64_t iow = 0; iow < OW; iow++) {
+                    for (int64_t iic = ith; iic < IC; iic+=nth) {
+
+                        // micro kernel
+                        ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
+                        const float * const src_data = (float *)((char *) src1->data + in*nb13 + iic*nb12); // [IH, IW]
+
+                        for (int64_t ikh = 0; ikh < KH; ikh++) {
+                            for (int64_t ikw = 0; ikw < KW; ikw++) {
+                                const int64_t iiw = iow*s0 + ikw*d0 - p0;
+                                const int64_t iih = ioh*s1 + ikh*d1 - p1;
+
+                                if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) {
+                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_FP32_TO_FP16(src_data[iih*IW + iiw]);
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+// gemm: [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW]
+// src0: [OC, IC, KH, KW]
+// src1: [N, OH, OW, IC * KH * KW]
+// result: [N, OC, OH, OW]
+static void ggml_compute_forward_conv_2d_stage_1_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    if (params->type == GGML_TASK_INIT) {
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb10 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb0  == sizeof(float));
+
+    const int N = ne13;
+    const int OH = ne12;
+    const int OW = ne11;
+
+    const int OC = ne03;
+    const int IC = ne02;
+    const int KH = ne01;
+    const int KW = ne00;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    int64_t m = OC;
+    int64_t n = OH * OW;
+    int64_t k = IC * KH * KW;
+
+    // [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW]
+    for (int i = 0; i < N; i++) {
+        ggml_fp16_t * A = (ggml_fp16_t *)src0->data; // [m, k]
+        ggml_fp16_t * B = (ggml_fp16_t *)src1->data + i * m * k; // [n, k]
+        float * C = (float *)dst->data + i * m * n; // [m, n]
+
+        gemm_f16_out_f32(m, n, k, A, B, C, ith, nth);
+    }
+}
+
 static void ggml_compute_forward_conv_2d_f16_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -14433,14 +14619,38 @@ static void ggml_compute_forward_conv_2d_f16_f32(
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
+    // src1: image [N, IC, IH, IW]
+    // src0: kernel [OC, IC, KH, KW]
+    // dst:  result [N, OC, OH, OW]
+    // ne12: IC
+    // ne0: OW
+    // ne1: OH
+    // nk0: KW
+    // nk1: KH
+    // ne13: N
+
+    const int N = ne13;
+    const int IC = ne12;
+    const int IH = ne11;
+    const int IW = ne10;
+
+    const int OC = ne03;
+    // const int IC = ne02;
+    const int KH = ne01;
+    const int KW = ne00;
+
+    const int OH = ne1;
+    const int OW = ne0;
+
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int nk0 = ne00;
-    const int nk1 = ne01;
+    // const int nk0 = ne00;
+    // const int nk1 = ne01;
 
     // size of the convolution row - the kernel size unrolled across all channels
-    const int ew0 = nk0*nk1*ne02;
+    // const int ew0 = nk0*nk1*ne02;
+    // ew0: IC*KH*KW
 
     const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
     const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
@@ -14456,24 +14666,27 @@ static void ggml_compute_forward_conv_2d_f16_f32(
         memset(params->wdata, 0, params->wsize);
 
         // prepare source data (src1)
+        // im2col: [N, IC, IH, IW] => [N*OH*OW, IC*KH*KW]
+
         {
             ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
 
-            for (int i13 = 0; i13 < ne13; i13++) {
-                for (int i12 = 0; i12 < ne12; i12++) {
-                    const float * const src = (float *)((char *) src1->data + i13*nb13 + i12*nb12);
-                    ggml_fp16_t * dst_data = wdata + i13*(ne1*ne0*ew0);
-
-                    for (int i1 = 0; i1 < ne1; i1++) {
-                        for (int i0 = 0; i0 < ne0; i0++) {
-                            for (int ik1 = 0; ik1 < nk1; ik1++) {
-                                for (int ik0 = 0; ik0 < nk0; ik0++) {
-                                    const int idx0 = i0*s0 + ik0*d0 - p0;
-                                    const int idx1 = i1*s1 + ik1*d1 - p1;
-
-                                    if (!(idx1 < 0 || idx1 >= ne11 || idx0 < 0 || idx0 >= ne10)) {
-                                        dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] =
-                                            GGML_FP32_TO_FP16(src[idx1*ne10 + idx0]);
+            for (int in = 0; in < N; in++) {
+                for (int iic = 0; iic < IC; iic++) {
+                    for (int ioh = 0; ioh < OH; ioh++) {
+                        for (int iow = 0; iow < OW; iow++) {
+
+                            // micro kernel
+                            ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
+                            const float * const src_data = (float *)((char *) src1->data + in*nb13 + iic*nb12); // [IH, IW]
+
+                            for (int ikh = 0; ikh < KH; ikh++) {
+                                for (int ikw = 0; ikw < KW; ikw++) {
+                                    const int iiw = iow*s0 + ikw*d0 - p0;
+                                    const int iih = ioh*s1 + ikh*d1 - p1;
+
+                                    if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) {
+                                        dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_FP32_TO_FP16(src_data[iih*IW + iiw]);
                                     }
                                 }
                             }
@@ -14490,30 +14703,22 @@ static void ggml_compute_forward_conv_2d_f16_f32(
         return;
     }
 
-    // total patches in dst
-    const int np = ne2;
-
-    // patches per thread
-    const int dp = (np + nth - 1)/nth;
+    ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
+    // wdata: [N*OH*OW, IC*KH*KW]
+    // dst: result [N, OC, OH, OW]
+    // src0: kernel [OC, IC, KH, KW]
 
-    // patch range for this thread
-    const int ip0 = dp*ith;
-    const int ip1 = MIN(ip0 + dp, np);
+    int64_t m = OC;
+    int64_t n = OH * OW;
+    int64_t k = IC * KH * KW;
 
-    ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
+    // [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW]
+    for (int i = 0; i < N; i++) {
+        ggml_fp16_t * A = (ggml_fp16_t *)src0->data; // [m, k]
+        ggml_fp16_t * B = (ggml_fp16_t *)wdata + i * m * k; // [n, k]
+        float * C = (float *)dst->data + i * m * n; // [m * k]
 
-    for (int i3 = 0; i3 < ne3; i3++) {
-        for (int i2 = ip0; i2 < ip1; i2++) {
-            float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2);
-
-            for (int i1 = 0; i1 < ne1; ++i1) {
-                for (int i0 = 0; i0 < ne0; ++i0) {
-                    ggml_vec_dot_f16(ew0, dst_data + i1*ne0 + i0,
-                            (ggml_fp16_t *) ((char *) src0->data + i2*nb03),
-                            (ggml_fp16_t *)                wdata + i3*nb3 + (i1*ne0 + i0)*ew0);
-                }
-            }
-        }
+        gemm_f16_out_f32(m, n, k, A, B, C, ith, nth);
     }
 }
 
@@ -14539,6 +14744,48 @@ static void ggml_compute_forward_conv_2d(
     }
 }
 
+static void ggml_compute_forward_conv_2d_stage_0(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_conv_2d_stage_0_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                GGML_ASSERT(false);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+static void ggml_compute_forward_conv_2d_stage_1(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_conv_2d_stage_1_f16(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                GGML_ASSERT(false);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_conv_transpose_2d
 
 static void ggml_compute_forward_conv_transpose_2d(
@@ -16095,7 +16342,6 @@ static void ggml_compute_forward_add_rel_pos_f32(
     const int ip0 = dp*ith;
     const int ip1 = MIN(ip0 + dp, np);
 
-
     for (int64_t i13 = ip0; i13 < ip1; ++i13) {
         for (int64_t i12 = 0; i12 < ne12; ++i12) {
             for (int64_t i11 = 0; i11 < ne11; ++i11) {
@@ -16162,7 +16408,6 @@ static void ggml_compute_forward_map_unary_f32(
     }
 }
 
-
 static void ggml_compute_forward_map_unary(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -16210,7 +16455,6 @@ static void ggml_compute_forward_map_binary_f32(
     }
 }
 
-
 static void ggml_compute_forward_map_binary(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -16262,7 +16506,6 @@ static void ggml_compute_forward_map_custom2_f32(
     fun(dst, a, b);
 }
 
-
 // ggml_compute_forward_map_custom3
 
 static void ggml_compute_forward_map_custom3_f32(
@@ -16537,7 +16780,6 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
         ggml_vec_sub_f32(nc, ds0, ds0, s1);
         ggml_vec_scale_f32(nc, ds0, d[0] / (float) nr);
 
-
 #ifndef NDEBUG
         for (int i = 0; i < nc; ++i) {
             assert(!isnan(ds0[i]));
@@ -16565,7 +16807,6 @@ static void ggml_compute_forward_cross_entropy_loss_back(
     }
 }
 
-
 /////////////////////////////////
 
 static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
@@ -16773,6 +17014,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor);
             } break;
+        case GGML_OP_CONV_2D_STAGE_0:
+            {
+                ggml_compute_forward_conv_2d_stage_0(params, tensor->src[0], tensor->src[1], tensor);
+            } break;
+        case GGML_OP_CONV_2D_STAGE_1:
+            {
+                ggml_compute_forward_conv_2d_stage_1(params, tensor->src[0], tensor->src[1], tensor);
+            } break;
         case GGML_OP_CONV_TRANSPOSE_2D:
             {
                 ggml_compute_forward_conv_transpose_2d(params, tensor->src[0], tensor->src[1], tensor);
@@ -17702,11 +17951,19 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_CONV_TRANSPOSE_1D:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_CONV_2D:
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
-        case GGML_OP_CONV_TRANSPOSE_1D:
+        case GGML_OP_CONV_2D_STAGE_0:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_CONV_2D_STAGE_1:
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
@@ -18635,6 +18892,7 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                     const int64_t ne0 = node->ne[0];
                     const int64_t ne1 = node->ne[1];
                     const int64_t ne2 = node->ne[2];
+                    const int64_t ne3 = node->ne[3];
                     const int64_t nk = ne00*ne01;
                     const int64_t ew0 = nk * ne02;
 
@@ -18645,7 +18903,8 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
 
                     if (node->src[0]->type == GGML_TYPE_F16 &&
                         node->src[1]->type == GGML_TYPE_F32) {
-                        cur = sizeof(ggml_fp16_t)*(ne0*ne1*ew0);
+                        // im2col: [N*OH*OW, IC*KH*KW]
+                        cur = sizeof(ggml_fp16_t)*(ne3*ne0*ne1*ew0);
                     } else if (node->src[0]->type == GGML_TYPE_F32 &&
                                node->src[1]->type == GGML_TYPE_F32) {
                         cur = sizeof(float)*      (ne10*ne11*ne12);
@@ -18655,6 +18914,14 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
 
                     work_size = MAX(work_size, cur);
                 } break;
+            case GGML_OP_CONV_2D_STAGE_0:
+                {
+                    n_tasks = n_threads;
+                } break;
+            case GGML_OP_CONV_2D_STAGE_1:
+                {
+                    n_tasks = n_threads;
+                } break;
             case GGML_OP_CONV_TRANSPOSE_2D:
                 {
                     n_tasks = n_threads;
@@ -19842,7 +20109,6 @@ static enum ggml_opt_result ggml_opt_adam(
 
         opt->loss_after = fx;
 
-
         // check convergence
         if (fabsf(fx - fx_prev[0])/fx < params.adam.eps_f) {
             GGML_PRINT_DEBUG("converged\n");