]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : replace conv 1D - 2D stage_0 and stage_1 with im2col and mul_mat (#564)
authorSteward Garcia <redacted>
Sun, 12 Nov 2023 13:34:04 +0000 (08:34 -0500)
committerGitHub <redacted>
Sun, 12 Nov 2023 13:34:04 +0000 (15:34 +0200)
* added conv2d stage 0 - 1 cuda kernels

* add im2col + refactor conv1d and conv2d

* fix params invalid index

* add conv1d and conv2d unit tests

* resolving wrong values and fix mul_mat validation

* improve tests + reduce code duplication

* add cuda kernels

* more data test

* fix ggml_op_count to 70

* add temp test - gemm != mul_mat

* tests : fix test-mul-mat matrix multiplication

* test-mul-mat match gemm == ggml_mul_mat with conv2d op

* replaced gemm by ggml_mul_mat

* ggml_mul_mat cpu backend support fp16 src1

* ggml_mul_mat cuda backend fp16 fixed

* remove unnecessary ggml_cont and removed conv1d-2d functions deprecated

* some fixes

* explain conv1d reshapes

* ggml : fix tests on Arm + do not use BLAS for F16 data

* tests : fix FP16 handling on Arm

* ggml : avoid ggml_cont and ggml_transpose in ggml_conv_xd

* ci : switch back to release

* cuda : fix wrong pointer usage

* ggml : add metal support for im2col and f16xf16 mul mat

* ggml : im2col opts

* Update src/ggml-cuda.cu

Co-authored-by: slaren <redacted>
---------

Co-authored-by: Georgi Gerganov <redacted>
Co-authored-by: slaren <redacted>
include/ggml/ggml.h
src/ggml-cuda.cu
src/ggml-metal.m
src/ggml-metal.metal
src/ggml.c
tests/CMakeLists.txt
tests/test-conv1d.cpp [new file with mode: 0644]
tests/test-conv2d.cpp [new file with mode: 0644]
tests/test-mul-mat.cpp [new file with mode: 0644]

index e56a8337a507747eae914cd355c08d791c68ba4c..52ae6755a89ad2e8746642233fea80a1d149937a 100644 (file)
@@ -403,13 +403,8 @@ extern "C" {
         GGML_OP_ROPE_BACK,
         GGML_OP_ALIBI,
         GGML_OP_CLAMP,
-        GGML_OP_CONV_1D,
-        GGML_OP_CONV_1D_STAGE_0,  // internal
-        GGML_OP_CONV_1D_STAGE_1,  // internal
         GGML_OP_CONV_TRANSPOSE_1D,
-        GGML_OP_CONV_2D,
-        GGML_OP_CONV_2D_STAGE_0, // internal
-        GGML_OP_CONV_2D_STAGE_1, // internal
+        GGML_OP_IM2COL,
         GGML_OP_CONV_TRANSPOSE_2D,
         GGML_OP_POOL_1D,
         GGML_OP_POOL_2D,
@@ -1398,6 +1393,18 @@ extern "C" {
             float                 min,
             float                 max);
 
+    GGML_API struct ggml_tensor * ggml_im2col(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            int                  s0,
+            int                  s1,
+            int                  p0,
+            int                  p1,
+            int                  d0,
+            int                  d1,
+            bool                 is_2D);
+
     GGML_API struct ggml_tensor * ggml_conv_1d(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
index adc34aab76f85e530b5c028fc23fc660d5d96fd3..ce4feeece7f631c5b07cb1c9b68de55e6397454c 100644 (file)
@@ -39,6 +39,7 @@
 #define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
 #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
 #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
+#define cudaDeviceGetMemPool hipDeviceGetMemPool
 #define cudaDeviceProp hipDeviceProp_t
 #define cudaDeviceSynchronize hipDeviceSynchronize
 #define cudaError_t hipError_t
@@ -48,6 +49,7 @@
 #define cudaEvent_t hipEvent_t
 #define cudaEventDestroy hipEventDestroy
 #define cudaFree hipFree
+#define cudaFreeAsync hipFreeAsync
 #define cudaFreeHost hipHostFree
 #define cudaGetDevice hipGetDevice
 #define cudaGetDeviceCount hipGetDeviceCount
@@ -55,6 +57,7 @@
 #define cudaGetErrorString hipGetErrorString
 #define cudaGetLastError hipGetLastError
 #define cudaMalloc hipMalloc
+#define cudaMallocFromPoolAsync hipMallocFromPoolAsync
 #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
 #define cudaMemcpy hipMemcpy
 #define cudaMemcpy2DAsync hipMemcpy2DAsync
@@ -63,6 +66,9 @@
 #define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
 #define cudaMemcpyHostToDevice hipMemcpyHostToDevice
 #define cudaMemcpyKind hipMemcpyKind
+#define cudaMemPool_t hipMemPool_t
+#define cudaMemPoolAttrReleaseThreshold hipMemPoolAttrReleaseThreshold
+#define cudaMemPoolSetAttribute hipMemPoolSetAttribute
 #define cudaMemset hipMemset
 #define cudaMemsetAsync hipMemsetAsync
 #define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
@@ -4470,6 +4476,13 @@ static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
     *dsti = __float2half(*xi);
 }
 
+static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
+    const half * xi = (const half *) cxi;
+    half * dsti = (half *) cdsti;
+
+    *dsti = *xi;
+}
+
 template <cpy_kernel_t cpy_1>
 static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
                                    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
@@ -4723,6 +4736,25 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
     dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
 }
 
+static  __global__ void im2col_f32_f16(
+        const float * x, half * dst,
+        int ofs0, int ofs1, int IW, int IH, int CHW,
+        int s0, int s1, int p0, int p1, int d0, int d1) {
+    const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
+    const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
+
+    const int offset_dst =
+        (threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW +
+        (blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z);
+
+    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+        dst[offset_dst] = __float2half(0.0f);
+    } else {
+        const int offset_src =  threadIdx.x * ofs0 + blockIdx.x * ofs1;
+        dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
+    }
+}
+
 template<int qk, int qr, dequantize_kernel_t dq>
 static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
     const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
@@ -5612,6 +5644,16 @@ static void ggml_cpy_f32_f16_cuda(
         (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
 }
 
+static void ggml_cpy_f16_f16_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+}
+
 static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
     scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
@@ -5695,6 +5737,15 @@ static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, c
     soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
 }
 
+static void im2col_f32_f16_cuda(const float * x, half * dst,
+    int OH, int IW, int IH, int OW, int IC,
+    int KH, int KW, int N,  int ofs0, int ofs1,
+    int s0, int s1, int p0, int p1, int d0, int d1, cudaStream_t stream) {
+    dim3 block_nums(IC, OH, OW);
+    dim3 block_dims(N,  KH, KW);
+    im2col_f32_f16<<<block_nums, block_dims, 0, stream>>>(x, dst, ofs0, ofs1, IW, IH, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
+}
+
 // buffer pool for cuda
 #define MAX_CUDA_BUFFERS 256
 
@@ -6477,7 +6528,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
             src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src1_as, id, stream);
             to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
         }
-        const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
+        const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16;
         size_t dst_f16_as = 0;
         half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(row_diff*src1_ncols * sizeof(half), &dst_f16_as, id, stream);
 
@@ -6653,6 +6704,45 @@ inline void ggml_cuda_op_alibi(
     (void) src1_dd;
 }
 
+inline void ggml_cuda_op_im2col(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F16);
+
+    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
+    const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
+    const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
+    const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
+    const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
+
+    const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
+
+    const int64_t N  = src1->ne[is_2D ? 3 : 2];
+    const int64_t IC = src1->ne[is_2D ? 2 : 1];
+    const int64_t IH = is_2D ? src1->ne[1] : 1;
+    const int64_t IW =         src1->ne[0];
+
+    const int64_t KH = is_2D ? src0->ne[1] : 1;
+    const int64_t KW =         src0->ne[0];
+
+    const int64_t OH = is_2D ? dst->ne[2] : 1;
+    const int64_t OW =         dst->ne[1];
+
+    const size_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
+    const size_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
+
+    im2col_f32_f16_cuda(src1_dd, (half*) dst_dd,
+        OH, IW, IH, OW, IC, KH, KW, N,
+        ofs0, ofs1, s0, s1, p0, p1, d0, d1, main_stream);
+
+    (void) src0;
+    (void) src0_dd;
+}
+
 inline void ggml_cuda_op_diag_mask_inf(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -7543,6 +7633,9 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
         ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
                               ne10, ne11, nb10, nb11, nb12, main_stream);
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
+        ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
+                              ne10, ne11, nb10, nb11, nb12, main_stream);
     } else {
         fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
                 ggml_type_name(src0->type), ggml_type_name(src1->type));
@@ -7574,6 +7667,10 @@ static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1,
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
 }
 
+void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col);
+}
+
 static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     (void) src0;
     (void) src1;
@@ -7937,6 +8034,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
         case GGML_OP_ALIBI:
             func = ggml_cuda_alibi;
             break;
+        case GGML_OP_IM2COL:
+            func = ggml_cuda_im2col;
+            break;
         default:
             return false;
     }
index 43d0dff09ecb627638318789fc003da21ecee0a0..148c12b14f8f211de78fb8d0eff109c18b2e2877 100644 (file)
@@ -86,6 +86,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(rms_norm);
     GGML_METAL_DECL_KERNEL(norm);
     GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
     GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
     GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
@@ -114,6 +115,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(rope_f32);
     GGML_METAL_DECL_KERNEL(rope_f16);
     GGML_METAL_DECL_KERNEL(alibi_f32);
+    GGML_METAL_DECL_KERNEL(im2col_f16);
     GGML_METAL_DECL_KERNEL(cpy_f32_f16);
     GGML_METAL_DECL_KERNEL(cpy_f32_f32);
     GGML_METAL_DECL_KERNEL(cpy_f16_f16);
@@ -287,6 +289,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(rms_norm);
         GGML_METAL_ADD_KERNEL(norm);
         GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
         GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
         GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
@@ -317,6 +320,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(rope_f32);
         GGML_METAL_ADD_KERNEL(rope_f16);
         GGML_METAL_ADD_KERNEL(alibi_f32);
+        GGML_METAL_ADD_KERNEL(im2col_f16);
         GGML_METAL_ADD_KERNEL(cpy_f32_f16);
         GGML_METAL_ADD_KERNEL(cpy_f32_f32);
         GGML_METAL_ADD_KERNEL(cpy_f16_f16);
@@ -386,6 +390,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(rms_norm);
     GGML_METAL_DEL_KERNEL(norm);
     GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
     GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
     GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
     GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
@@ -416,6 +421,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(rope_f32);
     GGML_METAL_DEL_KERNEL(rope_f16);
     GGML_METAL_DEL_KERNEL(alibi_f32);
+    GGML_METAL_DEL_KERNEL(im2col_f16);
     GGML_METAL_DEL_KERNEL(cpy_f32_f16);
     GGML_METAL_DEL_KERNEL(cpy_f32_f32);
     GGML_METAL_DEL_KERNEL(cpy_f16_f16);
@@ -1030,7 +1036,7 @@ void ggml_metal_graph_compute(
                             [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
                             [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
                             [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                            [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
+                            [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
 
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
@@ -1139,6 +1145,7 @@ void ggml_metal_graph_compute(
                                 switch (src0t) {
                                     case GGML_TYPE_F32:
                                         {
+                                            GGML_ASSERT(src1t == GGML_TYPE_F32);
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
                                             nrows = 4;
                                         } break;
@@ -1146,13 +1153,18 @@ void ggml_metal_graph_compute(
                                         {
                                             nth0 = 32;
                                             nth1 = 1;
-                                            if (ne11 * ne12 < 4) {
-                                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
-                                            } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
-                                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
-                                                nrows = ne11;
+                                            if (src1t == GGML_TYPE_F32) {
+                                                if (ne11 * ne12 < 4) {
+                                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
+                                                } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
+                                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
+                                                    nrows = ne11;
+                                                } else {
+                                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
+                                                    nrows = 4;
+                                                }
                                             } else {
-                                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
+                                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f16];
                                                 nrows = 4;
                                             }
                                         } break;
@@ -1342,7 +1354,7 @@ void ggml_metal_graph_compute(
                             [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
                             [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
                             [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
-                            [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
+                            [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
 
                             const int64_t nrows = ggml_nrows(src0);
 
@@ -1361,7 +1373,7 @@ void ggml_metal_graph_compute(
                             [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
                             [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:3];
                             [encoder setBytes:&eps     length:sizeof(   float) atIndex:4];
-                            [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
+                            [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0];
 
                             const int64_t nrows = ggml_nrows(src0);
 
@@ -1464,6 +1476,58 @@ void ggml_metal_graph_compute(
 
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
+                    case GGML_OP_IM2COL:
+                        {
+                            GGML_ASSERT(src0->type == GGML_TYPE_F16);
+                            GGML_ASSERT(src1->type == GGML_TYPE_F32);
+                            GGML_ASSERT( dst->type == GGML_TYPE_F16);
+
+                            const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+                            const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+                            const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+                            const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+                            const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+                            const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+                            const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+
+                            const int32_t N  = src1->ne[is_2D ? 3 : 2];
+                            const int32_t IC = src1->ne[is_2D ? 2 : 1];
+                            const int32_t IH = is_2D ? src1->ne[1] : 1;
+                            const int32_t IW =         src1->ne[0];
+
+                            const int32_t KH = is_2D ? src0->ne[1] : 1;
+                            const int32_t KW =         src0->ne[0];
+
+                            const int32_t OH = is_2D ? dst->ne[2] : 1;
+                            const int32_t OW =         dst->ne[1];
+
+                            const int32_t CHW = IC * KH * KW;
+
+                            const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
+                            const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
+
+                            switch (src0->type) {
+                                case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
+                                case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_im2col_f16]; break;
+                                default: GGML_ASSERT(false);
+                            };
+
+                            [encoder setBuffer:id_src1 offset:offs_src1        atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
+                            [encoder setBytes:&ofs0    length:sizeof( int32_t) atIndex:2];
+                            [encoder setBytes:&ofs1    length:sizeof( int32_t) atIndex:3];
+                            [encoder setBytes:&IW      length:sizeof( int32_t) atIndex:4];
+                            [encoder setBytes:&IH      length:sizeof( int32_t) atIndex:5];
+                            [encoder setBytes:&CHW     length:sizeof( int32_t) atIndex:6];
+                            [encoder setBytes:&s0      length:sizeof( int32_t) atIndex:7];
+                            [encoder setBytes:&s1      length:sizeof( int32_t) atIndex:8];
+                            [encoder setBytes:&p0      length:sizeof( int32_t) atIndex:9];
+                            [encoder setBytes:&p1      length:sizeof( int32_t) atIndex:10];
+                            [encoder setBytes:&d0      length:sizeof( int32_t) atIndex:11];
+                            [encoder setBytes:&d1      length:sizeof( int32_t) atIndex:12];
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
+                        } break;
                     case GGML_OP_DUP:
                     case GGML_OP_CPY:
                     case GGML_OP_CONT:
index 7c35f23a7612fd75362457b3fc9d137cd37e0bfa..5d1357cd72d4592782802a60222e81d2cacb8d8f 100644 (file)
@@ -792,7 +792,7 @@ kernel void kernel_mul_mv_f32_f32(
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]]) {
+        uint  tiisg[[thread_index_in_simdgroup]]) {
 
     const int64_t r0 = tgpig.x;
     const int64_t rb = tgpig.y*N_F32_F32;
@@ -844,6 +844,79 @@ kernel void kernel_mul_mv_f32_f32(
     }
 }
 
+#define N_F16_F16 4
+
+kernel void kernel_mul_mv_f16_f16(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]]) {
+
+    const int64_t r0 = tgpig.x;
+    const int64_t rb = tgpig.y*N_F16_F16;
+    const int64_t im = tgpig.z;
+
+    device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+
+    if (ne00 < 128) {
+        for (int row = 0; row < N_F16_F16; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00; i += 32) {
+                sumf += (half) x[i] * (half) y[i];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    } else {
+        device const half4 * x4 = (device const half4 *)x;
+        for (int row = 0; row < N_F16_F16; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const half  * y  = (device const half  *) (src1 + r1*nb11 + im*nb12);
+            device const half4 * y4 = (device const half4 *) y;
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00/4; i += 32) {
+                for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i];
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    }
+}
+
 kernel void kernel_mul_mv_f16_f32_1row(
         device const  char * src0,
         device const  char * src1,
@@ -1229,6 +1302,39 @@ kernel void kernel_rope(
 template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
 template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
 
+kernel void kernel_im2col_f16(
+        device const float * x,
+        device       half * dst,
+        constant   int32_t & ofs0,
+        constant   int32_t & ofs1,
+        constant   int32_t & IW,
+        constant   int32_t & IH,
+        constant   int32_t & CHW,
+        constant   int32_t & s0,
+        constant   int32_t & s1,
+        constant   int32_t & p0,
+        constant   int32_t & p1,
+        constant   int32_t & d0,
+        constant   int32_t & d1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3  tgpg[[threadgroups_per_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0;
+    const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1;
+
+    const int32_t offset_dst =
+        (tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
+        (tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
+
+    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+        dst[offset_dst] = 0.0f;
+    } else {
+        const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
+        dst[offset_dst] = x[offset_src + iih * IW + iiw];
+    }
+}
+
 kernel void kernel_cpy_f16_f16(
         device const half * src0,
         device       half * dst,
index 018f0ce0bba8ece9504c926820c299d0b3e00055..584ee4680378d9ddb3722d61e618d9837a07a88d 100644 (file)
@@ -143,12 +143,6 @@ void ggml_print_backtrace(void) {
 }
 #endif
 
-#undef MIN
-#undef MAX
-
-#define MIN(a, b) ((a) < (b) ? (a) : (b))
-#define MAX(a, b) ((a) > (b) ? (a) : (b))
-
 /*#define GGML_PERF*/
 #define GGML_DEBUG 0
 #define GGML_GELU_FP16
@@ -277,6 +271,12 @@ inline static void * ggml_aligned_malloc(size_t size) {
 // floating point type used to accumulate sums
 typedef double ggml_float;
 
+#undef MIN
+#undef MAX
+
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+
 //
 // global data
 //
@@ -1634,13 +1634,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "ROPE_BACK",
     "ALIBI",
     "CLAMP",
-    "CONV_1D",
-    "CONV_1D_STAGE_0",
-    "CONV_1D_STAGE_1",
     "CONV_TRANSPOSE_1D",
-    "CONV_2D",
-    "CONV_2D_STAGE_0",
-    "CONV_2D_STAGE_1",
+    "IM2COL",
     "CONV_TRANSPOSE_2D",
     "POOL_1D",
     "POOL_2D",
@@ -1671,7 +1666,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73");
+static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1721,13 +1716,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "rope_back(x)",
     "alibi(x)",
     "clamp(x)",
-    "conv_1d(x)",
-    "conv_1d_stage_0(x)",
-    "conv_1d_stage_1(x)",
     "conv_transpose_1d(x)",
-    "conv_2d(x)",
-    "conv_2d_stage_0(x)",
-    "conv_2d_stage_1(x)",
+    "im2col(x)",
     "conv_transpose_2d(x)",
     "pool_1d(x)",
     "pool_2d(x)",
@@ -1758,7 +1748,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73");
+static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -1786,13 +1776,7 @@ static void ggml_setup_op_has_task_pass(void) {
         p[GGML_OP_GET_ROWS_BACK          ] = true;
         p[GGML_OP_DIAG_MASK_INF          ] = true;
         p[GGML_OP_DIAG_MASK_ZERO         ] = true;
-        p[GGML_OP_CONV_1D                ] = true;
-        p[GGML_OP_CONV_1D_STAGE_0        ] = true;
-        p[GGML_OP_CONV_1D_STAGE_1        ] = true;
         p[GGML_OP_CONV_TRANSPOSE_1D      ] = true;
-        p[GGML_OP_CONV_2D                ] = true;
-        p[GGML_OP_CONV_2D_STAGE_0        ] = true;
-        p[GGML_OP_CONV_2D_STAGE_1        ] = true;
         p[GGML_OP_CONV_TRANSPOSE_2D      ] = true;
         p[GGML_OP_FLASH_ATTN_BACK        ] = true;
         p[GGML_OP_CROSS_ENTROPY_LOSS     ] = true;
@@ -5137,82 +5121,6 @@ static int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p,
     return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
 }
 
-// im2col: [N, IC, IL] => [N, OL, IC*K]
-// a: [OC,IC, K]
-// b: [N, IC, IL]
-// result: [N, OL, IC*K]
-static struct ggml_tensor * ggml_conv_1d_stage_0(
-    struct ggml_context * ctx,
-    struct ggml_tensor  * a,
-    struct ggml_tensor  * b,
-    int                   s0,
-    int                   p0,
-    int                   d0) {
-    GGML_ASSERT(a->ne[1] == b->ne[1]);
-    bool is_node = false;
-
-    if (a->grad || b->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
-        is_node = true;
-    }
-
-    const int64_t OL = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
-
-    const int64_t ne[4] = {
-        a->ne[1] * a->ne[0],
-        OL,
-        b->ne[2],
-        1,
-    };
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne);
-
-    int32_t params[] = { s0, p0, d0 };
-    ggml_set_op_params(result, params, sizeof(params));
-
-    result->op = GGML_OP_CONV_1D_STAGE_0;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-    result->src[1] = b;
-
-    return result;
-}
-
-// ggml_conv_1d_stage_1
-
-// gemm: [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K]
-// a: [OC, IC, K]
-// b: [N, OL, IC * K]
-// result: [N, OC, OL]
-static struct ggml_tensor * ggml_conv_1d_stage_1(
-    struct ggml_context * ctx,
-    struct ggml_tensor  * a,
-    struct ggml_tensor  * b) {
-
-    bool is_node = false;
-
-    if (a->grad || b->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
-        is_node = true;
-    }
-
-    const int64_t ne[4] = {
-        b->ne[1],
-        a->ne[2],
-        b->ne[2],
-        1,
-    };
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
-
-    result->op = GGML_OP_CONV_1D_STAGE_1;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-    result->src[1] = b;
-
-    return result;
-}
-
-// ggml_conv_1d
-
 GGML_API struct ggml_tensor * ggml_conv_1d(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
@@ -5220,43 +5128,17 @@ GGML_API struct ggml_tensor * ggml_conv_1d(
         int                   s0,
         int                   p0,
         int                   d0) {
-    struct ggml_tensor * result = ggml_conv_1d_stage_0(ctx, a, b, s0, p0, d0);
-    result = ggml_conv_1d_stage_1(ctx, a, result);
-    return result;
-}
+    struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false); // [N, OL, IC * K]
 
-// GGML_API struct ggml_tensor * ggml_conv_1d(
-//         struct ggml_context * ctx,
-//         struct ggml_tensor  * a,
-//         struct ggml_tensor  * b,
-//         int                   s0,
-//         int                   p0,
-//         int                   d0) {
-//     GGML_ASSERT(ggml_is_matrix(b));
-//     GGML_ASSERT(a->ne[1] == b->ne[1]);
-//     bool is_node = false;
+    struct ggml_tensor * result =
+        ggml_mul_mat(ctx,
+                ggml_reshape_2d(ctx, im2col, im2col->ne[0], (im2col->ne[2] * im2col->ne[1])), // [N, OL, IC * K] => [N*OL, IC * K]
+                ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1]), a->ne[2]));                    // [OC,IC, K] => [OC, IC * K]
 
-//     if (a->grad || b->grad) {
-//         GGML_ASSERT(false); // TODO: implement backward
-//         is_node = true;
-//     }
+    result = ggml_reshape_3d(ctx, result, im2col->ne[1], a->ne[2], im2col->ne[2]); // [N, OC, OL]
 
-//     const int64_t ne[4] = {
-//         ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0),
-//         a->ne[2], 1, 1,
-//     };
-//     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
-
-//     int32_t params[] = { s0, p0, d0 };
-//     ggml_set_op_params(result, params, sizeof(params));
-
-//     result->op = GGML_OP_CONV_1D;
-//     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-//     result->src[0] = a;
-//     result->src[1] = b;
-
-//     return result;
-// }
+    return result;
+}
 
 // ggml_conv_1d_ph
 
@@ -5319,7 +5201,7 @@ GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
 // a: [OC,IC, KH, KW]
 // b: [N, IC, IH, IW]
 // result: [N, OH, OW, IC*KH*KW]
-static struct ggml_tensor * ggml_conv_2d_stage_0(
+struct ggml_tensor * ggml_im2col(
     struct ggml_context * ctx,
     struct ggml_tensor  * a,
     struct ggml_tensor  * b,
@@ -5328,9 +5210,14 @@ static struct ggml_tensor * ggml_conv_2d_stage_0(
     int                  p0,
     int                  p1,
     int                  d0,
-    int                  d1) {
+    int                  d1,
+    bool                 is_2D) {
 
-    GGML_ASSERT(a->ne[2] == b->ne[2]);
+    if(is_2D) {
+        GGML_ASSERT(a->ne[2] == b->ne[2]);
+    } else {
+        GGML_ASSERT(a->ne[1] == b->ne[1]);
+    }
     bool is_node = false;
 
     if (a->grad || b->grad) {
@@ -5338,81 +5225,51 @@ static struct ggml_tensor * ggml_conv_2d_stage_0(
         is_node = true;
     }
 
-    const int64_t OH = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1);
-    const int64_t OW = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
+    const int64_t OH = is_2D ? ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0;
+    const int64_t OW =         ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
 
     const int64_t ne[4] = {
-        a->ne[2] * a->ne[1] * a->ne[0],
+        is_2D ? (a->ne[2] * a->ne[1] * a->ne[0]) : a->ne[1] * a->ne[0],
         OW,
-        OH,
-        b->ne[3],
+        is_2D ? OH : b->ne[2],
+        is_2D ?      b->ne[3] : 1,
     };
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne);
 
-    int32_t params[] = { s0, s1, p0, p1, d0, d1 };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne);
+    int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };
     ggml_set_op_params(result, params, sizeof(params));
 
-    result->op = GGML_OP_CONV_2D_STAGE_0;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-    result->src[1] = b;
-
-    return result;
-
-}
-
-// gemm: [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW]
-// a: [OC, IC, KH, KW]
-// b: [N, OH, OW, IC * KH * KW]
-// result: [N, OC, OH, OW]
-static struct ggml_tensor * ggml_conv_2d_stage_1(
-    struct ggml_context * ctx,
-    struct ggml_tensor  * a,
-    struct ggml_tensor  * b) {
-
-    bool is_node = false;
-
-    if (a->grad || b->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
-        is_node = true;
-    }
-
-    const int64_t ne[4] = {
-        b->ne[1],
-        b->ne[2],
-        a->ne[3],
-        b->ne[3],
-    };
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
-
-    result->op = GGML_OP_CONV_2D_STAGE_1;
+    result->op = GGML_OP_IM2COL;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = b;
 
     return result;
-
 }
 
 // a: [OC,IC, KH, KW]
 // b: [N, IC, IH, IW]
 // result: [N, OC, OH, OW]
 struct ggml_tensor * ggml_conv_2d(
-    struct ggml_context * ctx,
-    struct ggml_tensor  * a,
-    struct ggml_tensor  * b,
-    int                  s0,
-    int                  s1,
-    int                  p0,
-    int                  p1,
-    int                  d0,
-    int                  d1) {
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int                  s0,
+        int                  s1,
+        int                  p0,
+        int                  p1,
+        int                  d0,
+        int                  d1) {
+    struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true); // [N, OH, OW, IC * KH * KW]
 
-    struct ggml_tensor * result = ggml_conv_2d_stage_0(ctx, a, b, s0, s1, p0, p1, d0, d1); // [N, OH, OW, IC * KH * KW]
-    result = ggml_conv_2d_stage_1(ctx, a, result);
+    struct ggml_tensor * result =
+        ggml_mul_mat(ctx,
+                ggml_reshape_2d(ctx, im2col, im2col->ne[0],  im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N, OH, OW, IC * KH * KW] => [N*OH*OW, IC * KH * KW]
+                ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2]),  a->ne[3]));                       // [OC,IC, KH, KW] => [OC, IC * KH * KW]
 
-    return result;
+    result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], a->ne[3], im2col->ne[3]); // [N, OC, OH, OW]
 
+    return result;
 }
 
 // ggml_conv_2d_sk_p0
@@ -9507,6 +9364,8 @@ static bool ggml_compute_forward_mul_mat_use_blas(
     // TODO: find the optimal values for these
     if (ggml_is_contiguous(src0) &&
         ggml_is_contiguous(src1) &&
+        src0->type == GGML_TYPE_F32 &&
+        src1->type == GGML_TYPE_F32 &&
         (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
 
         /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
@@ -9517,6 +9376,7 @@ static bool ggml_compute_forward_mul_mat_use_blas(
 }
 #endif
 
+
 static void ggml_compute_forward_mul_mat(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -9545,7 +9405,7 @@ static void ggml_compute_forward_mul_mat(
 
     // we don't support permuted src0 or src1
     GGML_ASSERT(nb00 == ggml_type_size(type));
-    GGML_ASSERT(nb10 == sizeof(float));
+    GGML_ASSERT(nb10 == ggml_type_size(src1->type));
 
     // dst cannot be transposed or permuted
     GGML_ASSERT(nb0 == sizeof(float));
@@ -11637,9 +11497,9 @@ static void ggml_compute_forward_rope_back(
     }
 }
 
-// ggml_compute_forward_conv_1d
+// ggml_compute_forward_conv_transpose_1d
 
-static void ggml_compute_forward_conv_1d_f16_f32(
+static void ggml_compute_forward_conv_transpose_1d_f16_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -11656,14 +11516,7 @@ static void ggml_compute_forward_conv_1d_f16_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int nk = ne00;
-
-    // size of the convolution row - the kernel size unrolled across all input channels
-    const int ew0 = nk*ne01;
-
-    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
-    const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
-    const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
+    const int nk = ne00*ne01*ne02;
 
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nb10 == sizeof(float));
@@ -11671,23 +11524,37 @@ static void ggml_compute_forward_conv_1d_f16_f32(
     if (params->type == GGML_TASK_INIT) {
         memset(params->wdata, 0, params->wsize);
 
-        ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
+        // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
+        {
+            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
 
-        for (int64_t i11 = 0; i11 < ne11; i11++) {
-            const float * const src = (float *)((char *) src1->data + i11*nb11);
-            ggml_fp16_t * dst_data = wdata;
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
+                    ggml_fp16_t * dst_data = wdata + i01*ne00*ne02;
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        dst_data[i00*ne02 + i02] = src[i00];
+                    }
+                }
+            }
+        }
 
-            for (int64_t i0 = 0; i0 < ne0; i0++) {
-                for (int64_t ik = 0; ik < nk; ik++) {
-                    const int idx0 = i0*s0 + ik*d0 - p0;
+        // permute source data (src1) from (L x Cin) to (Cin x L)
+        {
+            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
+            ggml_fp16_t * dst_data = wdata;
 
-                    if(!(idx0 < 0 || idx0 >= ne10)) {
-                        dst_data[i0*ew0 + i11*nk + ik] = GGML_FP32_TO_FP16(src[idx0]);
-                    }
+            for (int64_t i11 = 0; i11 < ne11; i11++) {
+                const float * const src = (float *)((char *) src1->data + i11*nb11);
+                for (int64_t i10 = 0; i10 < ne10; i10++) {
+                    dst_data[i10*ne11 + i11] = GGML_FP32_TO_FP16(src[i10]);
                 }
             }
         }
 
+        // need to zero dst since we are accumulating into it
+        memset(dst->data, 0, ggml_nbytes(dst));
+
         return;
     }
 
@@ -11695,8 +11562,10 @@ static void ggml_compute_forward_conv_1d_f16_f32(
         return;
     }
 
+    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+
     // total rows in dst
-    const int nr = ne2;
+    const int nr = ne1;
 
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
@@ -11705,22 +11574,26 @@ static void ggml_compute_forward_conv_1d_f16_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
-
-    for (int i2 = 0; i2 < ne2; i2++) {
-        for (int i1 = ir0; i1 < ir1; i1++) {
-            float * dst_data = (float *)((char *) dst->data + i2*nb2 + i1*nb1);
+    ggml_fp16_t * const wdata     = (ggml_fp16_t *) params->wdata + 0;
+    ggml_fp16_t * const wdata_src = wdata + nk;
 
-            for (int i0 = 0; i0 < ne0; i0++) {
-                ggml_vec_dot_f16(ew0, dst_data + i0,
-                        (ggml_fp16_t *) ((char *) src0->data + i1*nb02),
-                        (ggml_fp16_t *)                wdata + i2*nb2 + i0*ew0);
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * dst_data = (float *)((char *) dst->data + i1*nb1);
+        ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00;
+        for (int i10 = 0; i10 < ne10; i10++) {
+            const int i1n = i10*ne11;
+            for (int i00 = 0; i00 < ne00; i00++) {
+                float v = 0;
+                ggml_vec_dot_f16(ne02, &v,
+                        (ggml_fp16_t *)    wdata_src + i1n,
+                        (ggml_fp16_t *) wdata_kernel + i00*ne02);
+                dst_data[i10*s0 + i00] += v;
             }
         }
     }
 }
 
-static void ggml_compute_forward_conv_1d_f32(
+static void ggml_compute_forward_conv_transpose_1d_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -11737,13 +11610,7 @@ static void ggml_compute_forward_conv_1d_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int nk = ne00;
-
-    const int ew0 = nk*ne01;
-
-    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
-    const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
-    const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
+    const int nk = ne00*ne01*ne02;
 
     GGML_ASSERT(nb00 == sizeof(float));
     GGML_ASSERT(nb10 == sizeof(float));
@@ -11751,23 +11618,37 @@ static void ggml_compute_forward_conv_1d_f32(
     if (params->type == GGML_TASK_INIT) {
         memset(params->wdata, 0, params->wsize);
 
-        float * const wdata = (float *) params->wdata + 0;
+        // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
+        {
+            float * const wdata = (float *) params->wdata + 0;
 
-        for (int64_t i11 = 0; i11 < ne11; i11++) {
-            const float * const src = (float *)((char *) src1->data + i11*nb11);
-            float * dst_data = wdata;
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
+                    float * dst_data = wdata + i01*ne00*ne02;
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        dst_data[i00*ne02 + i02] = src[i00];
+                    }
+                }
+            }
+        }
 
-            for (int64_t i0 = 0; i0 < ne0; i0++) {
-                for (int64_t ik = 0; ik < nk; ik++) {
-                    const int idx0 = i0*s0 + ik*d0 - p0;
+        // prepare source data (src1)
+        {
+            float * const wdata = (float *) params->wdata + nk;
+            float * dst_data = wdata;
 
-                    if(!(idx0 < 0 || idx0 >= ne10)) {
-                        dst_data[i0*ew0 + i11*nk + ik] = src[idx0];
-                    }
+            for (int64_t i11 = 0; i11 < ne11; i11++) {
+                const float * const src = (float *)((char *) src1->data + i11*nb11);
+                for (int64_t i10 = 0; i10 < ne10; i10++) {
+                    dst_data[i10*ne11 + i11] = src[i10];
                 }
             }
         }
 
+        // need to zero dst since we are accumulating into it
+        memset(dst->data, 0, ggml_nbytes(dst));
+
         return;
     }
 
@@ -11775,8 +11656,10 @@ static void ggml_compute_forward_conv_1d_f32(
         return;
     }
 
+    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+
     // total rows in dst
-    const int nr = ne02;
+    const int nr = ne1;
 
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
@@ -11785,94 +11668,50 @@ static void ggml_compute_forward_conv_1d_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    float * const wdata = (float *) params->wdata + 0;
-
-    for (int i2 = 0; i2 < ne2; i2++) {
-        for (int i1 = ir0; i1 < ir1; i1++) {
-            float * dst_data = (float *)((char *) dst->data + i2*nb2 + i1*nb1);
+    float * const wdata     = (float *) params->wdata + 0;
+    float * const wdata_src = wdata + nk;
 
-            for (int i0 = 0; i0 < ne0; i0++) {
-                ggml_vec_dot_f32(ew0, dst_data + i0,
-                        (float *) ((char *) src0->data + i1*nb02),
-                        (float *)                wdata + i2*nb2 + i0*ew0);
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * dst_data = (float *)((char *) dst->data + i1*nb1);
+        float * wdata_kernel = wdata + i1*ne02*ne00;
+        for (int i10 = 0; i10 < ne10; i10++) {
+            const int i1n = i10*ne11;
+            for (int i00 = 0; i00 < ne00; i00++) {
+                float v = 0;
+                ggml_vec_dot_f32(ne02, &v,
+                        wdata_src + i1n,
+                        wdata_kernel + i00*ne02);
+                dst_data[i10*s0 + i00] += v;
             }
         }
     }
 }
 
-// TODO: reuse ggml_mul_mat or implement ggml_im2col and remove stage_0 and stage_1
-static void gemm_f16_out_f32(int64_t m, int64_t n, int64_t k,
-                             ggml_fp16_t * A,
-                             ggml_fp16_t * B,
-                             float * C,
-                             const int ith, const int nth) {
-    // does not seem to make a difference
-    int64_t m0, m1, n0, n1;
-    // patches per thread
-    if (m > n) {
-        n0 = 0;
-        n1 = n;
-
-        // total patches in dst
-        const int np = m;
-
-        // patches per thread
-        const int dp = (np + nth - 1)/nth;
-
-        // patch range for this thread
-        m0 = dp*ith;
-        m1 = MIN(m0 + dp, np);
-    } else {
-        m0 = 0;
-        m1 = m;
-
-        // total patches in dst
-        const int np = n;
-
-        // patches per thread
-        const int dp = (np + nth - 1)/nth;
-
-        // patch range for this thread
-        n0 = dp*ith;
-        n1 = MIN(n0 + dp, np);
-    }
-
-    // block-tiling attempt
-    int64_t blck_n = 16;
-    int64_t blck_m = 16;
-
-    // int64_t CACHE_SIZE = 2 * 1024 * 1024; // 2MB
-    // int64_t blck_size = CACHE_SIZE / (sizeof(float) + 2 * sizeof(ggml_fp16_t) * K);
-    // if (blck_size > 0) {
-    //     blck_0 = 4;
-    //     blck_1 = blck_size / blck_0;
-    //     if (blck_1 < 0) {
-    //         blck_1 = 1;
-    //     }
-    //     // blck_0 = (int64_t)sqrt(blck_size);
-    //     // blck_1 = blck_0;
-    // }
-    // // printf("%zd %zd %zd %zd\n", blck_size, K, blck_0, blck_1);
-
-    for (int j = n0; j < n1; j+=blck_n) {
-        for (int i = m0; i < m1; i+=blck_m) {
-            // printf("i j k => %d %d %d\n", i, j, K);
-            for (int ii = i; ii < i + blck_m && ii < m1; ii++) {
-                for (int jj = j; jj < j + blck_n && jj < n1; jj++) {
-                    ggml_vec_dot_f16(k,
-                                    C + ii*n + jj,
-                                    A + ii * k,
-                                    B + jj * k);
-                }
-            }
-        }
+static void ggml_compute_forward_conv_transpose_1d(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_conv_transpose_1d_f16_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_conv_transpose_1d_f32(params, src0, src1, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
     }
 }
 
-// src0: kernel [OC, IC, K]
-// src1: signal [N, IC, IL]
-// dst:  result [N, OL, IC*K]
-static void ggml_compute_forward_conv_1d_stage_0_f32(
+// src0: kernel [OC, IC, KH, KW]
+// src1: image [N, IC, IH, IW]
+// dst:  result [N, OH, OW, IC*KH*KW]
+static void ggml_compute_forward_im2col_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -11886,425 +11725,35 @@ static void ggml_compute_forward_conv_1d_stage_0_f32(
 
     GGML_TENSOR_BINARY_OP_LOCALS;
 
-    const int64_t N  = ne12;
-    const int64_t IC = ne11;
-    const int64_t IL = ne10;
-
-    const int64_t K = ne00;
-
-    const int64_t OL = ne1;
+    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
 
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
-    const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
-    const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
+    const int64_t N  = is_2D ? ne13 : ne12;
+    const int64_t IC = is_2D ? ne12 : ne11;
+    const int64_t IH = is_2D ? ne11 : 1;
+    const int64_t IW = ne10;
+
+    const int64_t KH = is_2D ? ne01 : 1;
+    const int64_t KW = ne00;
+
+    const int64_t OH = is_2D ? ne2 : 1;
+    const int64_t OW = ne1;
+
+    int ofs0 = is_2D ? nb13 : nb12;
+    int ofs1 = is_2D ? nb12 : nb11;
 
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nb10 == sizeof(float));
 
     if (params->type == GGML_TASK_INIT) {
-        memset(dst->data, 0, ggml_nbytes(dst));
-        return;
-    }
-
-    if (params->type == GGML_TASK_FINALIZE) {
-        return;
-    }
-
-    // im2col: [N, IC, IL] => [N, OL, IC*K]
-    {
-        ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
-
-        for (int64_t in = 0; in < N; in++) {
-            for (int64_t iol = 0; iol < OL; iol++) {
-                for (int64_t iic = ith; iic < IC; iic+=nth) {
-
-                    // micro kernel
-                    ggml_fp16_t * dst_data = wdata + (in*OL + iol)*(IC*K); // [IC, K]
-                    const float * const src_data = (float *)((char *) src1->data + in*nb12 + iic*nb11); // [IL]
-
-                    for (int64_t ik = 0; ik < K; ik++) {
-                        const int64_t iil = iol*s0 + ik*d0 - p0;
-
-                        if (!(iil < 0 || iil >= IL)) {
-                            dst_data[iic*K + ik] = GGML_FP32_TO_FP16(src_data[iil]);
-                        }
-                    }
-                }
-            }
-        }
-    }
-}
-
-// gemm: [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K]
-// src0: [OC, IC, K]
-// src1: [N, OL, IC * K]
-// result: [N, OC, OL]
-static void ggml_compute_forward_conv_1d_stage_1_f16(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F16);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    if (params->type == GGML_TASK_INIT) {
-        return;
-    }
-
-    if (params->type == GGML_TASK_FINALIZE) {
-        return;
-    }
-
-    GGML_TENSOR_BINARY_OP_LOCALS;
-
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb10 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb0  == sizeof(float));
-
-    const int N = ne12;
-    const int OL = ne11;
-
-    const int OC = ne02;
-    const int IC = ne01;
-    const int K  = ne00;
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    int64_t m = OC;
-    int64_t n = OL;
-    int64_t k = IC * K;
-
-    // [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K]
-    for (int i = 0; i < N; i++) {
-        ggml_fp16_t * A = (ggml_fp16_t *)src0->data; // [m, k]
-        ggml_fp16_t * B = (ggml_fp16_t *)src1->data + i * m * k; // [n, k]
-        float * C = (float *)dst->data + i * m * n; // [m, n]
-
-        gemm_f16_out_f32(m, n, k, A, B, C, ith, nth);
-    }
-}
-
-static void ggml_compute_forward_conv_1d(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    switch(src0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_conv_1d_f16_f32(params, src0, src1, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_conv_1d_f32(params, src0, src1, dst);
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
-    }
-}
-
-static void ggml_compute_forward_conv_1d_stage_0(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    switch(src0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_conv_1d_stage_0_f32(params, src0, src1, dst);
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
-    }
-}
-
-static void ggml_compute_forward_conv_1d_stage_1(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    switch(src0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_conv_1d_stage_1_f16(params, src0, src1, dst);
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
-    }
-}
-
-// ggml_compute_forward_conv_transpose_1d
-
-static void ggml_compute_forward_conv_transpose_1d_f16_f32(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nk = ne00*ne01*ne02;
-
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb10 == sizeof(float));
-
-    if (params->type == GGML_TASK_INIT) {
-        memset(params->wdata, 0, params->wsize);
-
-        // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
-        {
-            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
-
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = 0; i01 < ne01; i01++) {
-                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
-                    ggml_fp16_t * dst_data = wdata + i01*ne00*ne02;
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        dst_data[i00*ne02 + i02] = src[i00];
-                    }
-                }
-            }
-        }
-
-        // permute source data (src1) from (L x Cin) to (Cin x L)
-        {
-            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
-            ggml_fp16_t * dst_data = wdata;
-
-            for (int64_t i11 = 0; i11 < ne11; i11++) {
-                const float * const src = (float *)((char *) src1->data + i11*nb11);
-                for (int64_t i10 = 0; i10 < ne10; i10++) {
-                    dst_data[i10*ne11 + i11] = GGML_FP32_TO_FP16(src[i10]);
-                }
-            }
-        }
-
-        // need to zero dst since we are accumulating into it
-        memset(dst->data, 0, ggml_nbytes(dst));
-
-        return;
-    }
-
-    if (params->type == GGML_TASK_FINALIZE) {
-        return;
-    }
-
-    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
-
-    // total rows in dst
-    const int nr = ne1;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    ggml_fp16_t * const wdata     = (ggml_fp16_t *) params->wdata + 0;
-    ggml_fp16_t * const wdata_src = wdata + nk;
-
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        float * dst_data = (float *)((char *) dst->data + i1*nb1);
-        ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00;
-        for (int i10 = 0; i10 < ne10; i10++) {
-            const int i1n = i10*ne11;
-            for (int i00 = 0; i00 < ne00; i00++) {
-                float v = 0;
-                ggml_vec_dot_f16(ne02, &v,
-                        (ggml_fp16_t *)    wdata_src + i1n,
-                        (ggml_fp16_t *) wdata_kernel + i00*ne02);
-                dst_data[i10*s0 + i00] += v;
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_conv_transpose_1d_f32(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nk = ne00*ne01*ne02;
-
-    GGML_ASSERT(nb00 == sizeof(float));
-    GGML_ASSERT(nb10 == sizeof(float));
-
-    if (params->type == GGML_TASK_INIT) {
-        memset(params->wdata, 0, params->wsize);
-
-        // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
-        {
-            float * const wdata = (float *) params->wdata + 0;
-
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = 0; i01 < ne01; i01++) {
-                    const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
-                    float * dst_data = wdata + i01*ne00*ne02;
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        dst_data[i00*ne02 + i02] = src[i00];
-                    }
-                }
-            }
-        }
-
-        // prepare source data (src1)
-        {
-            float * const wdata = (float *) params->wdata + nk;
-            float * dst_data = wdata;
-
-            for (int64_t i11 = 0; i11 < ne11; i11++) {
-                const float * const src = (float *)((char *) src1->data + i11*nb11);
-                for (int64_t i10 = 0; i10 < ne10; i10++) {
-                    dst_data[i10*ne11 + i11] = src[i10];
-                }
-            }
-        }
-
-        // need to zero dst since we are accumulating into it
-        memset(dst->data, 0, ggml_nbytes(dst));
-
-        return;
-    }
-
-    if (params->type == GGML_TASK_FINALIZE) {
-        return;
-    }
-
-    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
-
-    // total rows in dst
-    const int nr = ne1;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    float * const wdata     = (float *) params->wdata + 0;
-    float * const wdata_src = wdata + nk;
-
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        float * dst_data = (float *)((char *) dst->data + i1*nb1);
-        float * wdata_kernel = wdata + i1*ne02*ne00;
-        for (int i10 = 0; i10 < ne10; i10++) {
-            const int i1n = i10*ne11;
-            for (int i00 = 0; i00 < ne00; i00++) {
-                float v = 0;
-                ggml_vec_dot_f32(ne02, &v,
-                        wdata_src + i1n,
-                        wdata_kernel + i00*ne02);
-                dst_data[i10*s0 + i00] += v;
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_conv_transpose_1d(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    switch (src0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_conv_transpose_1d_f16_f32(params, src0, src1, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_conv_transpose_1d_f32(params, src0, src1, dst);
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
-    }
-}
-
-// ggml_compute_forward_conv_2d
-
-// src0: kernel [OC, IC, KH, KW]
-// src1: image [N, IC, IH, IW]
-// dst:  result [N, OH, OW, IC*KH*KW]
-static void ggml_compute_forward_conv_2d_stage_0_f32(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F16);
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    GGML_TENSOR_BINARY_OP_LOCALS;
-
-    const int64_t N = ne13;
-    const int64_t IC = ne12;
-    const int64_t IH = ne11;
-    const int64_t IW = ne10;
-
-    // const int64_t OC = ne03;
-    // const int64_t IC = ne02;
-    const int64_t KH = ne01;
-    const int64_t KW = ne00;
-
-    const int64_t OH = ne2;
-    const int64_t OW = ne1;
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
-    const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
-    const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
-    const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
-    const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
-    const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
-
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb10 == sizeof(float));
-
-    if (params->type == GGML_TASK_INIT) {
-        memset(dst->data, 0, ggml_nbytes(dst));
         return;
     }
 
@@ -12317,20 +11766,22 @@ static void ggml_compute_forward_conv_2d_stage_0_f32(
         ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
 
         for (int64_t in = 0; in < N; in++) {
-            for (int64_t ioh = 0; ioh < OH; ioh++) {
+            for (int64_t ioh = 0; ioh < OH; ioh++) { // 1
                 for (int64_t iow = 0; iow < OW; iow++) {
-                    for (int64_t iic = ith; iic < IC; iic+=nth) {
+                    for (int64_t iic = ith; iic < IC; iic += nth) {
 
                         // micro kernel
                         ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
-                        const float * const src_data = (float *)((char *) src1->data + in*nb13 + iic*nb12); // [IH, IW]
+                        const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
 
-                        for (int64_t ikh = 0; ikh < KH; ikh++) {
+                        for (int64_t ikh = 0; ikh < KH; ikh++) {  // 1
                             for (int64_t ikw = 0; ikw < KW; ikw++) {
                                 const int64_t iiw = iow*s0 + ikw*d0 - p0;
                                 const int64_t iih = ioh*s1 + ikh*d1 - p1;
 
-                                if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) {
+                                if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
+                                } else {
                                     dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_FP32_TO_FP16(src_data[iih*IW + iiw]);
                                 }
                             }
@@ -12342,223 +11793,7 @@ static void ggml_compute_forward_conv_2d_stage_0_f32(
     }
 }
 
-// gemm: [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW]
-// src0: [OC, IC, KH, KW]
-// src1: [N, OH, OW, IC * KH * KW]
-// result: [N, OC, OH, OW]
-static void ggml_compute_forward_conv_2d_stage_1_f16(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F16);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    if (params->type == GGML_TASK_INIT) {
-        return;
-    }
-
-    if (params->type == GGML_TASK_FINALIZE) {
-        return;
-    }
-
-    GGML_TENSOR_BINARY_OP_LOCALS;
-
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb10 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb0  == sizeof(float));
-
-    const int N = ne13;
-    const int OH = ne12;
-    const int OW = ne11;
-
-    const int OC = ne03;
-    const int IC = ne02;
-    const int KH = ne01;
-    const int KW = ne00;
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    int64_t m = OC;
-    int64_t n = OH * OW;
-    int64_t k = IC * KH * KW;
-
-    // [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW]
-    for (int i = 0; i < N; i++) {
-        ggml_fp16_t * A = (ggml_fp16_t *)src0->data; // [m, k]
-        ggml_fp16_t * B = (ggml_fp16_t *)src1->data + i * m * k; // [n, k]
-        float * C = (float *)dst->data + i * m * n; // [m, n]
-
-        gemm_f16_out_f32(m, n, k, A, B, C, ith, nth);
-    }
-}
-
-static void ggml_compute_forward_conv_2d_f16_f32(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    // src1: image [N, IC, IH, IW]
-    // src0: kernel [OC, IC, KH, KW]
-    // dst:  result [N, OC, OH, OW]
-    // ne12: IC
-    // ne0: OW
-    // ne1: OH
-    // nk0: KW
-    // nk1: KH
-    // ne13: N
-
-    const int N = ne13;
-    const int IC = ne12;
-    const int IH = ne11;
-    const int IW = ne10;
-
-    const int OC = ne03;
-    // const int IC = ne02;
-    const int KH = ne01;
-    const int KW = ne00;
-
-    const int OH = ne1;
-    const int OW = ne0;
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    // const int nk0 = ne00;
-    // const int nk1 = ne01;
-
-    // size of the convolution row - the kernel size unrolled across all channels
-    // const int ew0 = nk0*nk1*ne02;
-    // ew0: IC*KH*KW
-
-    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
-    const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
-    const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
-    const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
-    const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
-    const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
-
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb10 == sizeof(float));
-
-    if (params->type == GGML_TASK_INIT) {
-        memset(params->wdata, 0, params->wsize);
-
-        // prepare source data (src1)
-        // im2col: [N, IC, IH, IW] => [N*OH*OW, IC*KH*KW]
-
-        {
-            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
-
-            for (int in = 0; in < N; in++) {
-                for (int iic = 0; iic < IC; iic++) {
-                    for (int ioh = 0; ioh < OH; ioh++) {
-                        for (int iow = 0; iow < OW; iow++) {
-
-                            // micro kernel
-                            ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
-                            const float * const src_data = (float *)((char *) src1->data + in*nb13 + iic*nb12); // [IH, IW]
-
-                            for (int ikh = 0; ikh < KH; ikh++) {
-                                for (int ikw = 0; ikw < KW; ikw++) {
-                                    const int iiw = iow*s0 + ikw*d0 - p0;
-                                    const int iih = ioh*s1 + ikh*d1 - p1;
-
-                                    if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) {
-                                        dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_FP32_TO_FP16(src_data[iih*IW + iiw]);
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-            }
-        }
-
-        return;
-    }
-
-    if (params->type == GGML_TASK_FINALIZE) {
-        return;
-    }
-
-    ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
-    // wdata: [N*OH*OW, IC*KH*KW]
-    // dst: result [N, OC, OH, OW]
-    // src0: kernel [OC, IC, KH, KW]
-
-    int64_t m = OC;
-    int64_t n = OH * OW;
-    int64_t k = IC * KH * KW;
-
-    // [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW]
-    for (int i = 0; i < N; i++) {
-        ggml_fp16_t * A = (ggml_fp16_t *)src0->data; // [m, k]
-        ggml_fp16_t * B = (ggml_fp16_t *)wdata + i * m * k; // [n, k]
-        float * C = (float *)dst->data + i * m * n; // [m * k]
-
-        gemm_f16_out_f32(m, n, k, A, B, C, ith, nth);
-    }
-}
-
-static void ggml_compute_forward_conv_2d(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    switch (src0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_conv_2d_f16_f32(params, src0, src1, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                //ggml_compute_forward_conv_2d_f32(params, src0, src1, dst);
-                GGML_ASSERT(false);
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
-    }
-}
-
-static void ggml_compute_forward_conv_2d_stage_0(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    switch (src0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_conv_2d_stage_0_f32(params, src0, src1, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                GGML_ASSERT(false);
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
-    }
-}
-
-static void ggml_compute_forward_conv_2d_stage_1(
+static void ggml_compute_forward_im2col(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -12566,7 +11801,7 @@ static void ggml_compute_forward_conv_2d_stage_1(
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_conv_2d_stage_1_f16(params, src0, src1, dst);
+                ggml_compute_forward_im2col_f16(params, src0, src1, dst);
             } break;
         case GGML_TYPE_F32:
             {
@@ -14783,33 +14018,13 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_clamp(params, tensor->src[0], tensor);
             } break;
-        case GGML_OP_CONV_1D:
-            {
-                ggml_compute_forward_conv_1d(params, tensor->src[0], tensor->src[1], tensor);
-            } break;
-        case GGML_OP_CONV_1D_STAGE_0:
-            {
-                ggml_compute_forward_conv_1d_stage_0(params, tensor->src[0], tensor->src[1], tensor);
-            } break;
-        case GGML_OP_CONV_1D_STAGE_1:
-            {
-                ggml_compute_forward_conv_1d_stage_1(params, tensor->src[0], tensor->src[1], tensor);
-            } break;
         case GGML_OP_CONV_TRANSPOSE_1D:
             {
                 ggml_compute_forward_conv_transpose_1d(params, tensor->src[0], tensor->src[1], tensor);
             } break;
-        case GGML_OP_CONV_2D:
-            {
-                ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor);
-            } break;
-        case GGML_OP_CONV_2D_STAGE_0:
-            {
-                ggml_compute_forward_conv_2d_stage_0(params, tensor->src[0], tensor->src[1], tensor);
-            } break;
-        case GGML_OP_CONV_2D_STAGE_1:
+        case GGML_OP_IM2COL:
             {
-                ggml_compute_forward_conv_2d_stage_1(params, tensor->src[0], tensor->src[1], tensor);
+                ggml_compute_forward_im2col(params, tensor->src[0], tensor->src[1], tensor);
             } break;
         case GGML_OP_CONV_TRANSPOSE_2D:
             {
@@ -15780,31 +14995,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
-        case GGML_OP_CONV_1D:
-            {
-                GGML_ASSERT(false); // TODO: not implemented
-            } break;
-        case GGML_OP_CONV_1D_STAGE_0:
-            {
-                GGML_ASSERT(false); // TODO: not implemented
-            } break;
-        case GGML_OP_CONV_1D_STAGE_1:
-            {
-                GGML_ASSERT(false); // TODO: not implemented
-            } break;
         case GGML_OP_CONV_TRANSPOSE_1D:
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
-        case GGML_OP_CONV_2D:
-            {
-                GGML_ASSERT(false); // TODO: not implemented
-            } break;
-        case GGML_OP_CONV_2D_STAGE_0:
-            {
-                GGML_ASSERT(false); // TODO: not implemented
-            } break;
-        case GGML_OP_CONV_2D_STAGE_1:
+        case GGML_OP_IM2COL:
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
@@ -16533,31 +15728,11 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             {
                 n_tasks = 1; //TODO
             } break;
-        case GGML_OP_CONV_1D:
-            {
-                n_tasks = n_threads;
-            } break;
-        case GGML_OP_CONV_1D_STAGE_0:
-            {
-                n_tasks = n_threads;
-            } break;
-        case GGML_OP_CONV_1D_STAGE_1:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_CONV_TRANSPOSE_1D:
             {
                 n_tasks = n_threads;
             } break;
-        case GGML_OP_CONV_2D:
-            {
-                n_tasks = n_threads;
-            } break;
-        case GGML_OP_CONV_2D_STAGE_0:
-            {
-                n_tasks = n_threads;
-            } break;
-        case GGML_OP_CONV_2D_STAGE_1:
+        case GGML_OP_IM2COL:
             {
                 n_tasks = n_threads;
             } break;
@@ -16642,6 +15817,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             } break;
         default:
             {
+                printf("%s: op %s not implemented\n", __func__, ggml_op_name(node->op));
                 GGML_ASSERT(false);
             } break;
     }
@@ -16844,38 +16020,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                         cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
                     }
                 } break;
-            case GGML_OP_CONV_1D:
-                {
-                    GGML_ASSERT(node->src[0]->ne[3] == 1);
-                    GGML_ASSERT(node->src[1]->ne[2] == 1);
-                    GGML_ASSERT(node->src[1]->ne[3] == 1);
-
-                    const int64_t ne00 = node->src[0]->ne[0];
-                    const int64_t ne01 = node->src[0]->ne[1];
-                    const int64_t ne02 = node->src[0]->ne[2];
-
-                    const int64_t ne10 = node->src[1]->ne[0];
-                    const int64_t ne11 = node->src[1]->ne[1];
-
-                    const int64_t ne0 = node->ne[0];
-                    const int64_t ne1 = node->ne[1];
-                    const int64_t nk  = ne00;
-                    const int64_t ew0 = nk * ne01;
-
-                    UNUSED(ne02);
-                    UNUSED(ne10);
-                    UNUSED(ne11);
-
-                    if (node->src[0]->type == GGML_TYPE_F16 &&
-                        node->src[1]->type == GGML_TYPE_F32) {
-                        cur = sizeof(ggml_fp16_t)*(ne0*ne1*ew0);
-                    } else if (node->src[0]->type == GGML_TYPE_F32 &&
-                               node->src[1]->type == GGML_TYPE_F32) {
-                        cur = sizeof(float)*(ne0*ne1*ew0);
-                    } else {
-                        GGML_ASSERT(false);
-                    }
-                } break;
             case GGML_OP_CONV_TRANSPOSE_1D:
                 {
                     GGML_ASSERT(node->src[0]->ne[3] == 1);
@@ -16901,37 +16045,9 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                         GGML_ASSERT(false);
                     }
                 } break;
-            case GGML_OP_CONV_2D:
+            case GGML_OP_IM2COL:
                 {
-                    const int64_t ne00 = node->src[0]->ne[0]; // W
-                    const int64_t ne01 = node->src[0]->ne[1]; // H
-                    const int64_t ne02 = node->src[0]->ne[2]; // C
-                    const int64_t ne03 = node->src[0]->ne[3]; // N
-
-                    const int64_t ne10 = node->src[1]->ne[0]; // W
-                    const int64_t ne11 = node->src[1]->ne[1]; // H
-                    const int64_t ne12 = node->src[1]->ne[2]; // C
-
-                    const int64_t ne0 = node->ne[0];
-                    const int64_t ne1 = node->ne[1];
-                    const int64_t ne2 = node->ne[2];
-                    const int64_t ne3 = node->ne[3];
-                    const int64_t nk = ne00*ne01;
-                    const int64_t ew0 = nk * ne02;
-
-                    UNUSED(ne03);
-                    UNUSED(ne2);
-
-                    if (node->src[0]->type == GGML_TYPE_F16 &&
-                        node->src[1]->type == GGML_TYPE_F32) {
-                        // im2col: [N*OH*OW, IC*KH*KW]
-                        cur = sizeof(ggml_fp16_t)*(ne3*ne0*ne1*ew0);
-                    } else if (node->src[0]->type == GGML_TYPE_F32 &&
-                               node->src[1]->type == GGML_TYPE_F32) {
-                        cur = sizeof(float)*      (ne10*ne11*ne12);
-                    } else {
-                        GGML_ASSERT(false);
-                    }
+                    n_tasks = n_threads;
                 } break;
             case GGML_OP_CONV_TRANSPOSE_2D:
                 {
index e0130cda06f73a2689093220afaade7fa33f414a..e069e4e6deecf2c0022a99df6af6a4370cb780c9 100644 (file)
@@ -355,3 +355,32 @@ add_executable(${TEST_TARGET} ${TEST_TARGET}.c)
 target_link_libraries(${TEST_TARGET} PRIVATE ggml)
 add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
 set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw")
+
+#
+# test-conv1d
+
+set(TEST_TARGET test-conv1d)
+add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml)
+add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
+set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw")
+
+
+#
+# test-conv2d
+
+set(TEST_TARGET test-conv2d)
+add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml)
+add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
+set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw")
+
+
+#
+# test-mul-mat
+
+set(TEST_TARGET test-mul-mat)
+add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml)
+add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
+set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw")
diff --git a/tests/test-conv1d.cpp b/tests/test-conv1d.cpp
new file mode 100644 (file)
index 0000000..3067f9d
--- /dev/null
@@ -0,0 +1,298 @@
+#include "ggml.h"
+#include "ggml/ggml-alloc.h"
+#include "ggml/ggml-backend.h"
+
+// #define GGML_USE_CUBLAS
+
+#ifdef GGML_USE_CUBLAS
+#include "ggml-cuda.h"
+#endif
+
+#ifdef GGML_USE_METAL
+#include "ggml-metal.h"
+#endif
+
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <fstream>
+#include <map>
+#include <string>
+#include <vector>
+
+struct test_model {
+    struct ggml_tensor * a;
+    struct ggml_tensor * b;
+    ggml_backend_t backend = NULL;
+    ggml_backend_buffer_t buffer;
+    struct ggml_context * ctx;
+};
+
+void load_model(test_model & model, bool use_gpu = false) {
+    // create data
+    int K = 3, IC = 10, OC = 10;
+    int IL = 8, N = 1;
+
+    // Initialize adata
+    float* adata = new float[K * IC * OC];
+    for (size_t i = 0; i < K * IC * OC; i++) {
+        adata[i] = 4.5f;
+    }
+
+    // Convert adata to fp16 format
+    std::vector<ggml_fp16_t> hadata(K * IC * OC);
+    ggml_fp32_to_fp16_row(adata, hadata.data(), K * IC * OC);
+
+    // Initialize bdata
+    float* bdata =  new float[IL * IC * N];
+    for (size_t i = 0; i < IL * IC * N; i++) {
+        bdata[i] = 2.5f;
+    }
+
+    size_t buffer_size = 0;
+    {
+        buffer_size += K * IC * OC * ggml_type_sizef(GGML_TYPE_F16); // tensor a
+        buffer_size += IL * IC * N * ggml_type_sizef(GGML_TYPE_F32); // tensor b
+        buffer_size += 1024; // overhead
+    }
+
+    printf("%s: ggml tensor size    = %d bytes\n", __func__, (int) sizeof(ggml_tensor));
+    printf("%s: backend buffer size = %0.2f MB\n", __func__, (buffer_size/ 1024.f/ 1024.f));
+
+    int num_tensors = 2;
+    struct ggml_init_params params {
+            /*.mem_size   =*/ ggml_tensor_overhead() * num_tensors,
+            /*.mem_buffer =*/ NULL,
+            /*.no_alloc   =*/ true,
+    };
+
+    // initialize the backend
+#ifdef GGML_USE_CUBLAS
+    if (use_gpu) {
+        fprintf(stderr, "%s: using CUDA backend\n", __func__);
+        model.backend = ggml_backend_cuda_init();
+        if (!model.backend) {
+            fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
+        }
+    }
+#endif
+
+#ifdef GGML_USE_METAL
+    if (use_gpu) {
+        fprintf(stderr, "%s: using Metal backend\n", __func__);
+        ggml_metal_log_set_callback(ggml_log_callback_default, nullptr);
+        model.backend = ggml_backend_metal_init();
+        if (!model.backend) {
+            fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__);
+        }
+    }
+#endif
+
+    if(!model.backend) {
+        // fallback to CPU backend
+        model.backend = ggml_backend_cpu_init();
+    }
+
+    model.buffer = ggml_backend_alloc_buffer(model.backend, buffer_size);
+
+    // create context
+    model.ctx = ggml_init(params);
+
+    // create tensors
+    model.a = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F16,  K, IC, OC);
+    model.b = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, IL, IC, N);
+
+    // create a allocator
+    ggml_allocr * alloc = ggml_allocr_new_from_buffer(model.buffer);
+
+    // alloc memory
+    ggml_allocr_alloc(alloc, model.a);
+
+    // load data to buffer
+    if(ggml_backend_is_cpu(model.backend)) {
+        memcpy(model.a->data, hadata.data(), ggml_nbytes(model.a));
+    } else {
+        ggml_backend_tensor_set(model.a, hadata.data(), 0, ggml_nbytes(model.a));
+    }
+
+    // alloc memory
+    ggml_allocr_alloc(alloc, model.b);
+
+    if(ggml_backend_is_cpu(model.backend)
+#ifdef GGML_USE_METAL
+                || ggml_backend_is_metal(model.backend)
+#endif
+    ) {
+        memcpy(model.b->data, bdata, ggml_nbytes(model.b));
+    } else {
+        ggml_backend_tensor_set(model.b, bdata, 0, ggml_nbytes(model.b));
+    }
+
+    ggml_allocr_free(alloc);
+}
+
+struct ggml_cgraph * build_graph(const test_model& model, struct ggml_allocr * allocr) {
+    static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();
+    static std::vector<uint8_t> buf(buf_size);
+
+    struct ggml_init_params params0 = {
+        /*.mem_size   =*/ buf_size,
+        /*.mem_buffer =*/ buf.data(),
+        /*.no_alloc   =*/ true, // the tensors will be allocated later by ggml_allocr_alloc_graph()
+    };
+
+    // create a temporally context to build the graph
+    struct ggml_context * ctx0 = ggml_init(params0);
+
+    struct ggml_cgraph  * gf = ggml_new_graph(ctx0);
+
+    int s0 = 1;
+    int p0 = 1;
+    int d0 = 1;
+
+    // split conv1d in fundamental methods for test unit
+    struct ggml_tensor* im2col_0 = ggml_im2col(ctx0, model.a, model.b, s0, 0, p0, 0, d0, 0, false);
+    ggml_set_name(im2col_0, "im2col_res");
+    ggml_build_forward_expand(gf, im2col_0);
+
+    struct ggml_tensor* conv1d_res = ggml_conv_1d(ctx0, model.a, model.b, s0, p0, d0);
+    ggml_set_name(conv1d_res, "conv1d_res");
+    ggml_build_forward_expand(gf, conv1d_res);
+
+    // delete the temporally context used to build the graph
+    ggml_free(ctx0);
+    return gf;
+}
+
+struct ggml_cgraph* compute_graph(const test_model & model, struct ggml_allocr * allocr) {
+    // reset the allocator to free all the memory allocated during the previous inference
+    ggml_allocr_reset(allocr);
+
+    struct ggml_cgraph * gf = build_graph(model, allocr);
+
+    // allocate tensors
+    ggml_allocr_alloc_graph(allocr, gf);
+    int n_threads = 1;
+
+    if (ggml_backend_is_cpu(model.backend)) {
+        ggml_backend_cpu_set_n_threads(model.backend, n_threads);
+    }
+
+#ifdef GGML_USE_METAL
+    if (ggml_backend_is_metal(model.backend)) {
+        ggml_backend_metal_set_n_cb(model.backend, n_threads);
+    }
+#endif
+
+    ggml_backend_graph_compute(model.backend, gf);
+
+    //ggml_graph_print(gf);
+
+    // in this case, the output tensor is the last one in the graph
+    return gf;
+}
+
+int main(void)
+{
+    ggml_time_init();
+
+    test_model model;
+    load_model(model, true);
+
+    ggml_backend_buffer_t buf_compute; // for compute
+    struct ggml_allocr * allocr = NULL;
+
+    {
+        size_t align = ggml_backend_get_alignment(model.backend);
+        allocr = ggml_allocr_new_measure(align);
+
+        //create the worst case graph for memory usage estimation
+        struct ggml_cgraph * gf = build_graph(model, allocr);
+        size_t mem_size = ggml_allocr_alloc_graph(allocr, gf);
+        ggml_allocr_free(allocr);
+
+        // compute the required memory
+        buf_compute = ggml_backend_alloc_buffer(model.backend, mem_size);
+        allocr = ggml_allocr_new_from_buffer(buf_compute);
+        fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f);
+    }
+
+    struct ggml_cgraph * gf_res = compute_graph(model, allocr);
+
+     struct ggml_tensor * im2col_res = NULL;
+    struct ggml_tensor * conv1d_res = NULL;
+
+    for(int i = 0; i < gf_res->n_nodes; i++) {
+        if(strcmp(ggml_get_name(gf_res->nodes[i]), "im2col_res") == 0) {
+            im2col_res = gf_res->nodes[i];
+        } else if(strcmp(ggml_get_name(gf_res->nodes[i]), "conv1d_res") == 0) {
+            conv1d_res = gf_res->nodes[i];
+        }
+    }
+
+    uint16_t* im2col_data = new uint16_t[ggml_nelements(im2col_res)];
+    float* conv2d_data = new float[ggml_nelements(conv1d_res)];
+
+    ggml_backend_tensor_get(im2col_res, im2col_data, 0, ggml_nbytes(im2col_res));
+    ggml_backend_tensor_get(conv1d_res, conv2d_data, 0, ggml_nbytes(conv1d_res));
+
+    const int n_conv1d_test = 80;
+    const int n_im2col_test = 240;
+
+    float expected_conv1d[n_conv1d_test] = {
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f
+    };
+    // first im2col test
+
+    uint16_t expected_im2col[n_conv1d_test] = {
+        0, 16640, 16640, 0, 16640, 16640, 0, 16640,
+        16640, 0, 16640, 16640, 0, 16640, 16640, 0,
+        16640, 16640, 0, 16640, 16640, 0, 16640, 16640,
+        0, 16640, 16640, 0, 16640, 16640, 16640, 16640,
+        16640, 16640, 16640, 16640, 16640, 16640, 16640, 16640,
+        16640, 16640, 16640, 16640, 16640, 16640, 16640, 16640,
+        16640, 16640, 16640, 16640, 16640, 16640, 16640, 16640,
+        16640, 16640, 16640, 16640, 16640, 16640, 16640, 16640,
+        16640, 16640, 16640, 16640, 16640, 16640, 16640, 16640,
+        16640, 16640, 16640, 16640, 16640, 16640, 16640, 16640
+    };
+
+    printf("\nPerforming test:\n");
+
+    bool passed = true;
+    for(int i = 0; i < n_conv1d_test; i++) {
+        if(
+            im2col_data[i] != expected_im2col[i]) {
+            passed = false;
+            break;
+        }
+    }
+
+    printf("ggml_im2col (%i): %s\n", ggml_nelements(im2col_res), passed && (ggml_nelements(im2col_res) == n_im2col_test) ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m");
+
+    passed = true;
+    for(int i = 0; i < n_conv1d_test; i++) {
+        if(conv2d_data[i] != expected_conv1d[i]) {
+            passed = false;
+            break;
+        }
+    }
+
+    printf("ggml_conv1d (%i): %s\n", ggml_nelements(conv1d_res), passed && (ggml_nelements(conv1d_res) == n_conv1d_test) ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m");
+    ggml_free(model.ctx);
+
+    ggml_backend_buffer_free(model.buffer);
+    ggml_backend_buffer_free(buf_compute);
+    ggml_backend_free(model.backend);
+    return 0;
+}
diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp
new file mode 100644 (file)
index 0000000..4ad830b
--- /dev/null
@@ -0,0 +1,401 @@
+#include "ggml.h"
+#include "ggml/ggml-alloc.h"
+#include "ggml/ggml-backend.h"
+
+// #define GGML_USE_CUBLAS
+
+#ifdef GGML_USE_CUBLAS
+#include "ggml-cuda.h"
+#endif
+
+#ifdef GGML_USE_METAL
+#include "ggml-metal.h"
+#endif
+
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <fstream>
+#include <map>
+#include <string>
+#include <vector>
+
+struct test_model {
+    struct ggml_tensor * a;
+    struct ggml_tensor * b;
+    ggml_backend_t backend = NULL;
+    ggml_backend_buffer_t buffer;
+    struct ggml_context * ctx;
+};
+
+void load_model(test_model & model, bool use_gpu = false) {
+    // create data
+    int KW = 3, KH = 3, IC = 10, OC = 10;
+    int IW = 8, IH = 6, N = 1;
+
+    // Initialize adata
+    float* adata = new float[KW * KH * IC * OC];
+    for (size_t i = 0; i < KW * KH * IC * OC; i++) {
+        adata[i] = 2.5f;
+    }
+
+    // Convert adata to fp16 format
+    std::vector<ggml_fp16_t> hadata(KW * KH * IC * OC);
+    ggml_fp32_to_fp16_row(adata, hadata.data(), KW * KH * IC * OC);
+
+    // Initialize bdata
+    float* bdata =  new float[IW * IH * IC * N];
+    for (size_t i = 0; i < IW * IH * IC * N; i++) {
+        bdata[i] = 1.5f;
+    }
+
+    size_t buffer_size = 0;
+    {
+        buffer_size += KW * KH * IC * OC * ggml_type_sizef(GGML_TYPE_F16); // tensor a
+        buffer_size += IW * IH * IC * N * ggml_type_sizef(GGML_TYPE_F32); // tensor b
+        buffer_size += 1024; // overhead
+    }
+
+    printf("%s: ggml tensor size    = %d bytes\n", __func__, (int) sizeof(ggml_tensor));
+    printf("%s: backend buffer size = %0.2f MB\n", __func__, (buffer_size/ 1024.f/ 1024.f));
+
+    int num_tensors = 2;
+    struct ggml_init_params params {
+            /*.mem_size   =*/ ggml_tensor_overhead() * num_tensors,
+            /*.mem_buffer =*/ NULL,
+            /*.no_alloc   =*/ true,
+    };
+
+    // initialize the backend
+#ifdef GGML_USE_CUBLAS
+    if (use_gpu) {
+        fprintf(stderr, "%s: using CUDA backend\n", __func__);
+        model.backend = ggml_backend_cuda_init();
+        if (!model.backend) {
+            fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
+        }
+    }
+#endif
+
+#ifdef GGML_USE_METAL
+    if (use_gpu) {
+        fprintf(stderr, "%s: using Metal backend\n", __func__);
+        ggml_metal_log_set_callback(ggml_log_callback_default, nullptr);
+        model.backend = ggml_backend_metal_init();
+        if (!model.backend) {
+            fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__);
+        }
+    }
+#endif
+
+    if(!model.backend) {
+        // fallback to CPU backend
+        model.backend = ggml_backend_cpu_init();
+    }
+
+    model.buffer = ggml_backend_alloc_buffer(model.backend, buffer_size);
+
+    // create context
+    model.ctx = ggml_init(params);
+
+    // create tensors
+    model.a = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F16,  KW, KH, IC, OC);
+    model.b = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F32, IW, IH, IC, N);
+
+    // create a allocator
+    ggml_allocr * alloc = ggml_allocr_new_from_buffer(model.buffer);
+
+    // alloc memory
+    ggml_allocr_alloc(alloc, model.a);
+
+    // load data to buffer
+    if(ggml_backend_is_cpu(model.backend)) {
+        memcpy(model.a->data, hadata.data(), ggml_nbytes(model.a));
+    } else {
+        ggml_backend_tensor_set(model.a, hadata.data(), 0, ggml_nbytes(model.a));
+    }
+
+    // alloc memory
+    ggml_allocr_alloc(alloc, model.b);
+
+    if(ggml_backend_is_cpu(model.backend)
+#ifdef GGML_USE_METAL
+                || ggml_backend_is_metal(model.backend)
+#endif
+    ) {
+        memcpy(model.b->data, bdata, ggml_nbytes(model.b));
+    } else {
+        ggml_backend_tensor_set(model.b, bdata, 0, ggml_nbytes(model.b));
+    }
+
+    ggml_allocr_free(alloc);
+}
+
+struct ggml_cgraph * build_graph(const test_model& model, struct ggml_allocr * allocr) {
+    static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();
+    static std::vector<uint8_t> buf(buf_size);
+
+    struct ggml_init_params params0 = {
+        /*.mem_size   =*/ buf_size,
+        /*.mem_buffer =*/ buf.data(),
+        /*.no_alloc   =*/ true, // the tensors will be allocated later by ggml_allocr_alloc_graph()
+    };
+
+    // create a temporally context to build the graph
+    struct ggml_context * ctx0 = ggml_init(params0);
+
+    struct ggml_cgraph  * gf = ggml_new_graph(ctx0);
+
+    int s0 = 1;
+    int s1 = 1;
+    int p0 = 1;
+    int p1 = 1;
+    int d0 = 1;
+    int d1 = 1;
+
+    // split conv2d in fundamental methods for test unit
+    struct ggml_tensor* im2col_0 = ggml_im2col(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1, true);
+    ggml_set_name(im2col_0, "im2col_res");
+    ggml_build_forward_expand(gf, im2col_0);
+
+    // recalculate for avoid fragmentation
+    struct ggml_tensor* conv2d_res = ggml_conv_2d(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1);
+    ggml_set_name(conv2d_res, "conv2d_res");
+    ggml_build_forward_expand(gf, conv2d_res);
+
+    // delete the temporally context used to build the graph
+    ggml_free(ctx0);
+    return gf;
+}
+
+struct ggml_cgraph * compute_graph(const test_model & model, struct ggml_allocr * allocr) {
+    // reset the allocator to free all the memory allocated during the previous inference
+    ggml_allocr_reset(allocr);
+
+    struct ggml_cgraph * gf = build_graph(model, allocr);
+
+    // allocate tensors
+    ggml_allocr_alloc_graph(allocr, gf);
+    int n_threads = 1;
+
+    if (ggml_backend_is_cpu(model.backend)) {
+        ggml_backend_cpu_set_n_threads(model.backend, n_threads);
+    }
+
+#ifdef GGML_USE_METAL
+    if (ggml_backend_is_metal(model.backend)) {
+        ggml_backend_metal_set_n_cb(model.backend, n_threads);
+    }
+#endif
+
+    ggml_backend_graph_compute(model.backend, gf);
+
+    //ggml_graph_print(gf);
+
+    // in this case, the output tensor is the last one in the graph
+    return gf;
+}
+
+int main(void)
+{
+    ggml_time_init();
+
+    test_model model;
+    load_model(model, true);
+
+    ggml_backend_buffer_t buf_compute; // for compute
+    struct ggml_allocr * allocr = NULL;
+
+    {
+        size_t align = ggml_backend_get_alignment(model.backend);
+        allocr = ggml_allocr_new_measure(align);
+
+        //create the worst case graph for memory usage estimation
+        struct ggml_cgraph * gf = build_graph(model, allocr);
+        size_t mem_size = ggml_allocr_alloc_graph(allocr, gf);
+        ggml_allocr_free(allocr);
+
+        // compute the required memory
+        buf_compute = ggml_backend_alloc_buffer(model.backend, mem_size);
+        allocr = ggml_allocr_new_from_buffer(buf_compute);
+        fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f);
+    }
+
+    struct ggml_cgraph * gf_res = compute_graph(model, allocr);
+
+    struct ggml_tensor * im2col_res = NULL;
+    struct ggml_tensor * conv2d_res = NULL;
+
+    for(int i = 0; i < gf_res->n_nodes; i++) {
+        if(strcmp(ggml_get_name(gf_res->nodes[i]), "im2col_res") == 0) {
+            im2col_res = gf_res->nodes[i];
+        } else if(strcmp(ggml_get_name(gf_res->nodes[i]), "conv2d_res") == 0) {
+            conv2d_res = gf_res->nodes[i];
+        }
+    }
+
+    uint16_t* im2col_data = new uint16_t[ggml_nelements(im2col_res)];
+    float* conv2d_data = new float[ggml_nelements(conv2d_res)];
+
+    ggml_backend_tensor_get(im2col_res, im2col_data, 0, ggml_nbytes(im2col_res));
+    ggml_backend_tensor_get(conv2d_res, conv2d_data, 0, ggml_nbytes(conv2d_res));
+
+    const int n_conv2d_test = 480;
+    const int n_im2col_test = 4320;
+
+    float expected_conv2d [n_conv2d_test] = {
+        150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,
+        150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,
+        150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,
+        150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,
+        150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,
+        150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,
+        150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,
+        150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,
+        150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,
+        150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
+        150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f };
+
+    uint16_t expected_im2col[n_conv2d_test] = {
+            0, 0, 0, 0, 15872, 15872, 0, 15872,
+            15872, 0, 0, 0, 0, 15872, 15872, 0,
+            15872, 15872, 0, 0, 0, 0, 15872, 15872,
+            0, 15872, 15872, 0, 0, 0, 0, 15872,
+            15872, 0, 15872, 15872, 0, 0, 0, 0,
+            15872, 15872, 0, 15872, 15872, 0, 0, 0,
+            0, 15872, 15872, 0, 15872, 15872, 0, 0,
+            0, 0, 15872, 15872, 0, 15872, 15872, 0,
+            0, 0, 0, 15872, 15872, 0, 15872, 15872,
+            0, 0, 0, 0, 15872, 15872, 0, 15872,
+            15872, 0, 0, 0, 0, 15872, 15872, 0,
+            15872, 15872, 0, 0, 0, 15872, 15872, 15872,
+            15872, 15872, 15872, 0, 0, 0, 15872, 15872,
+            15872, 15872, 15872, 15872, 0, 0, 0, 15872,
+            15872, 15872, 15872, 15872, 15872, 0, 0, 0,
+            15872, 15872, 15872, 15872, 15872, 15872, 0, 0,
+            0, 15872, 15872, 15872, 15872, 15872, 15872, 0,
+            0, 0, 15872, 15872, 15872, 15872, 15872, 15872,
+            0, 0, 0, 15872, 15872, 15872, 15872, 15872,
+            15872, 0, 0, 0, 15872, 15872, 15872, 15872,
+            15872, 15872, 0, 0, 0, 15872, 15872, 15872,
+            15872, 15872, 15872, 0, 0, 0, 15872, 15872,
+            15872, 15872, 15872, 15872, 0, 0, 0, 15872,
+            15872, 15872, 15872, 15872, 15872, 0, 0, 0,
+            15872, 15872, 15872, 15872, 15872, 15872, 0, 0,
+            0, 15872, 15872, 15872, 15872, 15872, 15872, 0,
+            0, 0, 15872, 15872, 15872, 15872, 15872, 15872,
+            0, 0, 0, 15872, 15872, 15872, 15872, 15872,
+            15872, 0, 0, 0, 15872, 15872, 15872, 15872,
+            15872, 15872, 0, 0, 0, 15872, 15872, 15872,
+            15872, 15872, 15872, 0, 0, 0, 15872, 15872,
+            15872, 15872, 15872, 15872, 0, 0, 0, 15872,
+            15872, 15872, 15872, 15872, 15872, 0, 0, 0,
+            15872, 15872, 15872, 15872, 15872, 15872, 0, 0,
+            0, 15872, 15872, 15872, 15872, 15872, 15872, 0,
+            0, 0, 15872, 15872, 15872, 15872, 15872, 15872,
+            0, 0, 0, 15872, 15872, 15872, 15872, 15872,
+            15872, 0, 0, 0, 15872, 15872, 15872, 15872,
+            15872, 15872, 0, 0, 0, 15872, 15872, 15872,
+            15872, 15872, 15872, 0, 0, 0, 15872, 15872,
+            15872, 15872, 15872, 15872, 0, 0, 0, 15872,
+            15872, 15872, 15872, 15872, 15872, 0, 0, 0,
+            15872, 15872, 15872, 15872, 15872, 15872, 0, 0,
+            0, 15872, 15872, 15872, 15872, 15872, 15872, 0,
+            0, 0, 15872, 15872, 15872, 15872, 15872, 15872,
+            0, 0, 0, 15872, 15872, 15872, 15872, 15872,
+            15872, 0, 0, 0, 15872, 15872, 15872, 15872,
+            15872, 15872, 0, 0, 0, 15872, 15872, 15872,
+            15872, 15872, 15872, 0, 0, 0, 15872, 15872,
+            15872, 15872, 15872, 15872, 0, 0, 0, 15872,
+            15872, 15872, 15872, 15872, 15872, 0, 0, 0,
+            15872, 15872, 15872, 15872, 15872, 15872, 0, 0,
+            0, 15872, 15872, 15872, 15872, 15872, 15872, 0,
+            0, 0, 15872, 15872, 15872, 15872, 15872, 15872,
+            0, 0, 0, 15872, 15872, 15872, 15872, 15872,
+            15872, 0, 0, 0, 15872, 15872, 15872, 15872,
+            15872, 15872, 0, 0, 0, 15872, 15872, 15872,
+            15872, 15872, 15872, 0, 0, 0, 15872, 15872,
+            15872, 15872, 15872, 15872, 0, 0, 0, 15872,
+            15872, 15872, 15872, 15872, 15872, 0, 0, 0
+    };
+
+    printf("\nPerforming test:\n");
+
+    bool passed = true;
+    for(int i = 0; i < n_conv2d_test; i++) {
+        if(
+            im2col_data[i] != expected_im2col[i]) {
+            passed = false;
+            break;
+        }
+    }
+
+    printf("ggml_im2col (%d): %s\n", (int) ggml_nelements(im2col_res), passed && (ggml_nelements(im2col_res) == n_im2col_test) ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m");
+
+    passed = true;
+    for(int i = 0; i < n_conv2d_test; i++) {
+        if(conv2d_data[i] != expected_conv2d[i]) {
+            passed = false;
+            break;
+        }
+    }
+
+    printf("ggml_conv2d (%d): %s\n", (int) ggml_nelements(conv2d_res), passed && (ggml_nelements(conv2d_res) == n_conv2d_test) ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m");
+
+    ggml_free(model.ctx);
+
+    ggml_backend_buffer_free(model.buffer);
+    ggml_backend_buffer_free(buf_compute);
+    ggml_backend_free(model.backend);
+    return 0;
+}
diff --git a/tests/test-mul-mat.cpp b/tests/test-mul-mat.cpp
new file mode 100644 (file)
index 0000000..36b7c6b
--- /dev/null
@@ -0,0 +1,363 @@
+#include "ggml.h"
+#include "ggml/ggml-alloc.h"
+#include "ggml/ggml-backend.h"
+
+//#define GGML_USE_CUBLAS // uncomment this to use cuda backend, make sure build ggml lib with GGML_CUBLAS=ON
+
+#ifdef GGML_USE_CUBLAS
+#include "ggml-cuda.h"
+#endif
+
+#ifdef GGML_USE_METAL
+#include "ggml-metal.h"
+#endif
+
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <fstream>
+#include <map>
+#include <string>
+#include <vector>
+
+struct test_model {
+    struct ggml_tensor * a;
+    struct ggml_tensor * b;
+    ggml_backend_t backend = NULL;
+    ggml_backend_buffer_t buffer;
+    struct ggml_context * ctx;
+};
+
+void load_model(test_model & model, float* a, float* b, int M, int N, int K, bool use_gpu = false) {
+    size_t buffer_size = 0;
+    {
+        buffer_size += (M * N) * ggml_type_sizef(GGML_TYPE_F32); // tensor a
+        buffer_size += (N * K) * ggml_type_sizef(GGML_TYPE_F32); // tensor b
+        buffer_size += 1024; // overhead
+    }
+
+    printf("%s: ggml tensor size    = %d bytes\n", __func__, (int) sizeof(ggml_tensor));
+    printf("%s: backend buffer size = %d bytes\n", __func__, (int) buffer_size);
+
+    int num_tensors = 2;
+    struct ggml_init_params params {
+            /*.mem_size   =*/ ggml_tensor_overhead() * num_tensors,
+            /*.mem_buffer =*/ NULL,
+            /*.no_alloc   =*/ true,
+    };
+
+    // initialize the backend
+#ifdef GGML_USE_CUBLAS
+    if (use_gpu) {
+        fprintf(stderr, "%s: using CUDA backend\n", __func__);
+        model.backend = ggml_backend_cuda_init();
+        if (!model.backend) {
+            fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
+        }
+    }
+#endif
+
+#ifdef GGML_USE_METAL
+    if (use_gpu) {
+        fprintf(stderr, "%s: using Metal backend\n", __func__);
+        ggml_metal_log_set_callback(ggml_log_callback_default, nullptr);
+        model.backend = ggml_backend_metal_init();
+        if (!model.backend) {
+            fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__);
+        }
+    }
+#endif
+
+    if(!model.backend) {
+        // fallback to CPU backend
+        model.backend = ggml_backend_cpu_init();
+    }
+
+    model.buffer = ggml_backend_alloc_buffer(model.backend, buffer_size);
+
+    // create context
+    model.ctx = ggml_init(params);
+
+    // create tensors
+    model.a = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, K, M);
+    printf("Matrix A: [%i, %i]\n", K, M);
+    model.b = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, K, N);
+    printf("Matrix B: [%i, %i]\n", K, N);
+
+    // create a allocator
+    ggml_allocr * alloc = ggml_allocr_new_from_buffer(model.buffer);
+
+    // alloc memory
+    ggml_allocr_alloc(alloc, model.a);
+
+    // load data to buffer
+    if(ggml_backend_is_cpu(model.backend)
+#ifdef GGML_USE_METAL
+                || ggml_backend_is_metal(model.backend)
+#endif
+    ) {
+        memcpy(model.a->data, a, ggml_nbytes(model.a));
+    } else {
+        ggml_backend_tensor_set(model.a, a, 0, ggml_nbytes(model.a)); // cuda requires copy the data directly to device
+    }
+
+    // alloc memory
+    ggml_allocr_alloc(alloc, model.b);
+
+    if(ggml_backend_is_cpu(model.backend)
+#ifdef GGML_USE_METAL
+                || ggml_backend_is_metal(model.backend)
+#endif
+    ) {
+        memcpy(model.b->data, b, ggml_nbytes(model.b));
+    } else {
+        ggml_backend_tensor_set(model.b, b, 0, ggml_nbytes(model.b));  // cuda requires copy the data directly to device
+    }
+
+    ggml_allocr_free(alloc);
+}
+
+struct ggml_cgraph * build_graph(const test_model& model, struct ggml_allocr * allocr) {
+    static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();
+    static std::vector<uint8_t> buf(buf_size);
+
+    struct ggml_init_params params0 = {
+        /*.mem_size   =*/ buf_size,
+        /*.mem_buffer =*/ buf.data(),
+        /*.no_alloc   =*/ true, // the tensors will be allocated later by ggml_allocr_alloc_graph()
+    };
+
+    // create a temporally context to build the graph
+    struct ggml_context * ctx0 = ggml_init(params0);
+
+    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+    // zT = x @ yT
+    struct ggml_tensor * result = ggml_mul_mat(ctx0, model.a, ggml_cont(ctx0, model.b));
+
+    // z = (zT)T
+    ggml_build_forward_expand(gf, ggml_cont(ctx0, ggml_transpose(ctx0, result)));
+
+    // delete the temporally context used to build the graph
+    ggml_free(ctx0);
+    return gf;
+}
+
+struct ggml_tensor* compute(const test_model & model, struct ggml_allocr * allocr) {
+    // reset the allocator to free all the memory allocated during the previous inference
+    ggml_allocr_reset(allocr);
+
+    struct ggml_cgraph * gf = build_graph(model, allocr);
+
+    // allocate tensors
+    ggml_allocr_alloc_graph(allocr, gf);
+    int n_threads = 1;
+
+    if (ggml_backend_is_cpu(model.backend)) {
+        ggml_backend_cpu_set_n_threads(model.backend, n_threads);
+    }
+
+#ifdef GGML_USE_METAL
+    if (ggml_backend_is_metal(model.backend)) {
+        ggml_backend_metal_set_n_cb(model.backend, n_threads);
+    }
+#endif
+
+    ggml_backend_graph_compute(model.backend, gf);
+
+    //ggml_graph_print(gf);
+
+    // in this case, the output tensor is the last one in the graph
+    return gf->nodes[gf->n_nodes - 1];
+}
+
+
+static void ggml_vec_dot_f16(const int n, float * s, float * x, float * y) {
+    float sumf = 0.0;
+    for (int i = 0; i < n; ++i) {
+        sumf += x[i] * y[i];
+    }
+    *s = sumf;
+}
+
+static void gemm_f16_out_f32(int m, int n, int k,
+                             float * A,
+                             float * B,
+                             float * C,
+                             const int ith, const int nth) {
+    // does not seem to make a difference
+    int m0, m1, n0, n1;
+    // patches per thread
+    if (m > n) {
+        n0 = 0;
+        n1 = n;
+
+        // total patches in dst
+        const int np = m;
+
+        // patches per thread
+        const int dp = (np + nth - 1)/nth;
+
+        // patch range for this thread
+        m0 = dp*ith;
+        m1 = std::min(m0 + dp, np);
+    } else {
+        m0 = 0;
+        m1 = m;
+
+        // total patches in dst
+        const int np = n;
+
+        // patches per thread
+        const int dp = (np + nth - 1)/nth;
+
+        // patch range for this thread
+        n0 = dp*ith;
+        n1 = std::min(n0 + dp, np);
+    }
+
+    // block-tiling attempt
+    int64_t blck_n = 16;
+    int64_t blck_m = 16;
+
+    for (int j = n0; j < n1; j+=blck_n) {
+        for (int i = m0; i < m1; i+=blck_m) {
+            // printf("i j k => %d %d %d\n", i, j, K);
+            for (int ii = i; ii < i + blck_m && ii < m1; ii++) {
+                for (int jj = j; jj < j + blck_n && jj < n1; jj++) {
+                    ggml_vec_dot_f16(k,
+                                    C + ii*n + jj,
+                                    A + ii * k,
+                                    B + jj * k);
+                }
+            }
+        }
+    }
+}
+
+
+void perform_gemm_test(float* a, float* b, float* expected, int M, int N, int K) {
+    printf("\nPerforming gemm_f16_out_f32 test:\n");
+
+    float* gemm_out = new float[M * N];
+    gemm_f16_out_f32(M, N, K, a, b, gemm_out, 0, 1);
+
+    for (int i = 0; i < M; i++) {
+        for (int j = 0; j < N; j++) {
+            printf("%.1ff,", gemm_out[i * N + j]);
+        }
+        printf("\n");
+    }
+
+    bool passed = true;
+
+    for(int i = 0; i < M * N; i++) {
+        if(gemm_out[i] != expected[i]) {
+            passed = false;
+            break;
+        }
+    }
+
+    printf("gemm_mult (%i): %s\n", (M * N), passed ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m");
+}
+
+int main(void)
+{
+    ggml_time_init();
+    const int M = 4, N = 16, K = 36;  // a conv2d expected matrix multiplication
+
+    // matrix A (4 X 36)
+    float matrixA[M * K] = {
+       2.0f, 9.0f, 2.0f, 10.0f, 6.0f, 4.0f, 3.0f, 6.0f, 3.0f, 6.0f, 9.0f, 7.0f, 8.0f, 8.0f, 3.0f, 3.0f, 10.0f, 5.0f, 2.0f, 10.0f, 7.0f, 10.0f, 9.0f, 3.0f, 6.0f, 6.0f, 5.0f, 10.0f, 2.0f, 3.0f, 6.0f, 1.0f, 9.0f, 4.0f, 10.0f, 4.0f,
+        10.0f, 7.0f, 8.0f, 10.0f, 10.0f, 8.0f, 7.0f, 10.0f, 4.0f, 6.0f, 8.0f, 7.0f, 7.0f, 6.0f, 9.0f, 3.0f, 6.0f, 5.0f, 5.0f, 2.0f, 7.0f, 2.0f, 7.0f, 4.0f, 4.0f, 6.0f, 6.0f, 4.0f, 3.0f, 9.0f, 3.0f, 6.0f, 4.0f, 7.0f, 2.0f, 9.0f,
+        7.0f, 3.0f, 2.0f, 5.0f, 7.0f, 3.0f, 10.0f, 2.0f, 6.0f, 1.0f, 4.0f, 7.0f, 5.0f, 10.0f, 3.0f, 10.0f, 4.0f, 5.0f, 5.0f, 1.0f, 6.0f, 10.0f, 7.0f, 4.0f, 5.0f, 3.0f, 9.0f, 9.0f, 8.0f, 6.0f, 9.0f, 2.0f, 3.0f, 6.0f, 8.0f, 5.0f,
+        5.0f, 5.0f, 5.0f, 5.0f, 3.0f, 10.0f, 4.0f, 1.0f, 8.0f, 8.0f, 9.0f, 8.0f, 4.0f, 1.0f, 4.0f, 9.0f, 3.0f, 6.0f, 3.0f, 1.0f, 4.0f, 8.0f, 3.0f, 10.0f, 8.0f, 6.0f, 4.0f, 5.0f, 4.0f, 3.0f, 2.0f, 2.0f, 4.0f, 3.0f, 6.0f, 4.0f,
+    };
+
+    // matrix B (16 X 36)
+    float matrixB[N * K] = {
+        9.0f, 7.0f, 1.0f, 3.0f, 5.0f, 9.0f, 7.0f, 6.0f, 1.0f, 10.0f, 1.0f, 1.0f, 7.0f, 2.0f, 4.0f, 9.0f, 10.0f, 4.0f, 5.0f, 5.0f, 7.0f, 1.0f, 7.0f, 7.0f, 2.0f, 9.0f, 5.0f, 10.0f, 7.0f, 4.0f, 8.0f, 9.0f, 9.0f, 3.0f, 10.0f, 2.0f,
+        4.0f, 6.0f, 10.0f, 9.0f, 5.0f, 1.0f, 8.0f, 7.0f, 4.0f, 7.0f, 2.0f, 6.0f, 5.0f, 3.0f, 1.0f, 10.0f, 8.0f, 4.0f, 8.0f, 3.0f, 7.0f, 1.0f, 2.0f, 7.0f, 6.0f, 8.0f, 6.0f, 5.0f, 2.0f, 3.0f, 1.0f, 1.0f, 2.0f, 5.0f, 7.0f, 1.0f,
+        8.0f, 2.0f, 8.0f, 8.0f, 8.0f, 8.0f, 4.0f, 4.0f, 6.0f, 10.0f, 10.0f, 9.0f, 2.0f, 9.0f, 3.0f, 7.0f, 7.0f, 1.0f, 4.0f, 9.0f, 1.0f, 2.0f, 3.0f, 6.0f, 1.0f, 10.0f, 5.0f, 8.0f, 9.0f, 4.0f, 6.0f, 2.0f, 3.0f, 1.0f, 2.0f, 7.0f,
+        5.0f, 1.0f, 7.0f, 2.0f, 9.0f, 10.0f, 9.0f, 5.0f, 2.0f, 5.0f, 4.0f, 10.0f, 9.0f, 9.0f, 1.0f, 9.0f, 8.0f, 8.0f, 9.0f, 4.0f, 9.0f, 4.0f, 8.0f, 2.0f, 1.0f, 8.0f, 4.0f, 5.0f, 10.0f, 7.0f, 6.0f, 2.0f, 1.0f, 10.0f, 10.0f, 7.0f,
+        9.0f, 4.0f, 5.0f, 9.0f, 5.0f, 10.0f, 10.0f, 3.0f, 6.0f, 6.0f, 4.0f, 4.0f, 4.0f, 8.0f, 5.0f, 4.0f, 9.0f, 1.0f, 9.0f, 9.0f, 1.0f, 7.0f, 9.0f, 2.0f, 10.0f, 9.0f, 10.0f, 8.0f, 3.0f, 3.0f, 9.0f, 3.0f, 9.0f, 10.0f, 1.0f, 8.0f,
+        9.0f, 2.0f, 6.0f, 9.0f, 7.0f, 2.0f, 3.0f, 5.0f, 3.0f, 6.0f, 9.0f, 7.0f, 3.0f, 7.0f, 6.0f, 4.0f, 10.0f, 3.0f, 5.0f, 7.0f, 2.0f, 9.0f, 3.0f, 2.0f, 2.0f, 10.0f, 8.0f, 7.0f, 3.0f, 10.0f, 6.0f, 3.0f, 1.0f, 1.0f, 4.0f, 10.0f,
+        2.0f, 9.0f, 2.0f, 10.0f, 6.0f, 4.0f, 3.0f, 6.0f, 3.0f, 6.0f, 9.0f, 7.0f, 8.0f, 8.0f, 3.0f, 3.0f, 10.0f, 5.0f, 2.0f, 10.0f, 7.0f, 10.0f, 9.0f, 3.0f, 6.0f, 6.0f, 5.0f, 10.0f, 2.0f, 3.0f, 6.0f, 1.0f, 9.0f, 4.0f, 10.0f, 4.0f,
+        10.0f, 7.0f, 8.0f, 10.0f, 10.0f, 8.0f, 7.0f, 10.0f, 4.0f, 6.0f, 8.0f, 7.0f, 7.0f, 6.0f, 9.0f, 3.0f, 6.0f, 5.0f, 5.0f, 2.0f, 7.0f, 2.0f, 7.0f, 4.0f, 4.0f, 6.0f, 6.0f, 4.0f, 3.0f, 9.0f, 3.0f, 6.0f, 4.0f, 7.0f, 2.0f, 9.0f,
+        7.0f, 3.0f, 2.0f, 5.0f, 7.0f, 3.0f, 10.0f, 2.0f, 6.0f, 1.0f, 4.0f, 7.0f, 5.0f, 10.0f, 3.0f, 10.0f, 4.0f, 5.0f, 5.0f, 1.0f, 6.0f, 10.0f, 7.0f, 4.0f, 5.0f, 3.0f, 9.0f, 9.0f, 8.0f, 6.0f, 9.0f, 2.0f, 3.0f, 6.0f, 8.0f, 5.0f,
+        5.0f, 5.0f, 5.0f, 5.0f, 3.0f, 10.0f, 4.0f, 1.0f, 8.0f, 8.0f, 9.0f, 8.0f, 4.0f, 1.0f, 4.0f, 9.0f, 3.0f, 6.0f, 3.0f, 1.0f, 4.0f, 8.0f, 3.0f, 10.0f, 8.0f, 6.0f, 4.0f, 5.0f, 4.0f, 3.0f, 2.0f, 2.0f, 4.0f, 3.0f, 6.0f, 4.0f,
+        6.0f, 2.0f, 3.0f, 3.0f, 3.0f, 7.0f, 5.0f, 1.0f, 8.0f, 1.0f, 4.0f, 5.0f, 1.0f, 1.0f, 6.0f, 4.0f, 2.0f, 1.0f, 7.0f, 8.0f, 6.0f, 1.0f, 1.0f, 5.0f, 6.0f, 5.0f, 10.0f, 6.0f, 7.0f, 5.0f, 9.0f, 3.0f, 2.0f, 7.0f, 9.0f, 4.0f,
+        2.0f, 5.0f, 9.0f, 5.0f, 10.0f, 3.0f, 1.0f, 8.0f, 1.0f, 7.0f, 1.0f, 8.0f, 1.0f, 6.0f, 7.0f, 8.0f, 4.0f, 9.0f, 5.0f, 10.0f, 3.0f, 7.0f, 6.0f, 8.0f, 8.0f, 5.0f, 6.0f, 8.0f, 10.0f, 9.0f, 4.0f, 1.0f, 3.0f, 3.0f, 4.0f, 7.0f,
+        8.0f, 2.0f, 6.0f, 6.0f, 5.0f, 1.0f, 3.0f, 7.0f, 1.0f, 7.0f, 2.0f, 2.0f, 2.0f, 8.0f, 4.0f, 1.0f, 1.0f, 5.0f, 9.0f, 4.0f, 1.0f, 2.0f, 3.0f, 10.0f, 1.0f, 4.0f, 9.0f, 9.0f, 6.0f, 8.0f, 8.0f, 1.0f, 9.0f, 10.0f, 4.0f, 1.0f,
+        8.0f, 5.0f, 8.0f, 9.0f, 4.0f, 8.0f, 2.0f, 1.0f, 1.0f, 9.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 5.0f, 6.0f, 7.0f, 3.0f, 1.0f, 4.0f, 6.0f, 7.0f, 7.0f, 7.0f, 8.0f, 7.0f, 8.0f, 8.0f, 2.0f, 10.0f, 2.0f, 7.0f, 3.0f, 8.0f, 3.0f,
+        8.0f, 7.0f, 6.0f, 2.0f, 4.0f, 10.0f, 10.0f, 6.0f, 10.0f, 3.0f, 7.0f, 6.0f, 4.0f, 3.0f, 5.0f, 5.0f, 5.0f, 3.0f, 8.0f, 10.0f, 3.0f, 4.0f, 8.0f, 4.0f, 2.0f, 6.0f, 8.0f, 9.0f, 6.0f, 9.0f, 4.0f, 3.0f, 5.0f, 2.0f, 2.0f, 6.0f,
+        10.0f, 6.0f, 2.0f, 1.0f, 7.0f, 5.0f, 6.0f, 4.0f, 1.0f, 9.0f, 10.0f, 2.0f, 4.0f, 5.0f, 8.0f, 5.0f, 7.0f, 4.0f, 7.0f, 6.0f, 3.0f, 9.0f, 2.0f, 1.0f, 4.0f, 2.0f, 6.0f, 6.0f, 3.0f, 3.0f, 2.0f, 8.0f, 5.0f, 9.0f, 3.0f, 4.0f,
+    };
+
+    // matrix C (4 x 16)
+    float expected_result[M * N] = {
+        1224.0f, 1023.0f, 1158.0f,1259.0f,1359.0f,1194.0f,1535.0f,1247.0f,1185.0f,1029.0f,889.0f,1182.0f,955.0f,1179.0f,1147.0f,1048.0f,
+        1216.0f, 1087.0f, 1239.0f,1361.0f,1392.0f,1260.0f,1247.0f,1563.0f,1167.0f,1052.0f,942.0f,1214.0f,1045.0f,1134.0f,1264.0f,1126.0f,
+        1125.0f, 966.0f, 1079.0f,1333.0f,1287.0f,1101.0f,1185.0f,1167.0f,1368.0f,990.0f,967.0f,1121.0f,971.0f,1086.0f,1130.0f,980.0f,
+        999.0f, 902.0f, 1020.0f,1056.0f,1076.0f,929.0f,1029.0f,1052.0f,990.0f,1108.0f,823.0f,989.0f,759.0f,1041.0f,1003.0f,870.0f
+    };
+
+    bool passed = true;
+
+    perform_gemm_test(matrixA, matrixB, expected_result, M, N, K);
+
+    test_model model;
+    load_model(model, matrixA, matrixB, M, N, K, true);
+
+    ggml_backend_buffer_t buf_compute; // for compute
+    struct ggml_allocr * allocr = NULL;
+
+    {
+        size_t align = ggml_backend_get_alignment(model.backend);
+        allocr = ggml_allocr_new_measure(align);
+
+        //create the worst case graph for memory usage estimation
+        struct ggml_cgraph * gf = build_graph(model, allocr);
+        size_t mem_size = ggml_allocr_alloc_graph(allocr, gf);
+        ggml_allocr_free(allocr);
+
+        // compute the required memory
+        buf_compute = ggml_backend_alloc_buffer(model.backend, mem_size);
+        allocr = ggml_allocr_new_from_buffer(buf_compute);
+        fprintf(stderr, "%s: compute buffer size: %.4f KB\n", __func__, mem_size/1024.0);
+    }
+
+    struct ggml_tensor * result = compute(model, allocr);
+
+    float* out_data = new float[ggml_nelements(result)];
+
+    ggml_backend_tensor_get(result, out_data, 0, ggml_nbytes(result));
+
+    printf("\nPerforming ggml_mul_mat test:\n");
+
+    passed = true;
+    for(int i = 0; i < M * N; i++) {
+        if(out_data[i] != expected_result[i]) {
+            passed = false;
+            break;
+        }
+    }
+
+    for (int i = 0; i < M; i++) {
+        for (int j = 0; j < N; j++) {
+            printf("%.1f ", out_data[i * N + j]);
+        }
+        printf("\n");
+    }
+
+    printf("ggml_mul_mat (%d): %s\n", (int) ggml_nelements(result), passed && (ggml_nelements(result) == M * N) ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m");
+
+   // free memory
+    ggml_free(model.ctx);
+
+    ggml_backend_buffer_free(model.buffer);
+    ggml_backend_buffer_free(buf_compute);
+    ggml_backend_free(model.backend);
+    return 0;
+}