From: Steward Garcia Date: Sun, 12 Nov 2023 13:34:04 +0000 (-0500) Subject: ggml : replace conv 1D - 2D stage_0 and stage_1 with im2col and mul_mat (#564) X-Git-Tag: upstream/0.0.1642~1199 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=ba779f117ecc20f605df6f75dc7d92fb114bf17e;p=pkg%2Fggml%2Fsources%2Fggml ggml : replace conv 1D - 2D stage_0 and stage_1 with im2col and mul_mat (#564) * added conv2d stage 0 - 1 cuda kernels * add im2col + refactor conv1d and conv2d * fix params invalid index * add conv1d and conv2d unit tests * resolving wrong values and fix mul_mat validation * improve tests + reduce code duplication * add cuda kernels * more data test * fix ggml_op_count to 70 * add temp test - gemm != mul_mat * tests : fix test-mul-mat matrix multiplication * test-mul-mat match gemm == ggml_mul_mat with conv2d op * replaced gemm by ggml_mul_mat * ggml_mul_mat cpu backend support fp16 src1 * ggml_mul_mat cuda backend fp16 fixed * remove unnecessary ggml_cont and removed conv1d-2d functions deprecated * some fixes * explain conv1d reshapes * ggml : fix tests on Arm + do not use BLAS for F16 data * tests : fix FP16 handling on Arm * ggml : avoid ggml_cont and ggml_transpose in ggml_conv_xd * ci : switch back to release * cuda : fix wrong pointer usage * ggml : add metal support for im2col and f16xf16 mul mat * ggml : im2col opts * Update src/ggml-cuda.cu Co-authored-by: slaren --------- Co-authored-by: Georgi Gerganov Co-authored-by: slaren --- diff --git a/include/ggml/ggml.h b/include/ggml/ggml.h index e56a8337..52ae6755 100644 --- a/include/ggml/ggml.h +++ b/include/ggml/ggml.h @@ -403,13 +403,8 @@ extern "C" { GGML_OP_ROPE_BACK, GGML_OP_ALIBI, GGML_OP_CLAMP, - GGML_OP_CONV_1D, - GGML_OP_CONV_1D_STAGE_0, // internal - GGML_OP_CONV_1D_STAGE_1, // internal GGML_OP_CONV_TRANSPOSE_1D, - GGML_OP_CONV_2D, - GGML_OP_CONV_2D_STAGE_0, // internal - GGML_OP_CONV_2D_STAGE_1, // internal + GGML_OP_IM2COL, GGML_OP_CONV_TRANSPOSE_2D, GGML_OP_POOL_1D, GGML_OP_POOL_2D, @@ -1398,6 +1393,18 @@ extern "C" { float min, float max); + GGML_API struct ggml_tensor * ggml_im2col( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1, + bool is_2D); + GGML_API struct ggml_tensor * ggml_conv_1d( struct ggml_context * ctx, struct ggml_tensor * a, diff --git a/src/ggml-cuda.cu b/src/ggml-cuda.cu index adc34aab..ce4feeec 100644 --- a/src/ggml-cuda.cu +++ b/src/ggml-cuda.cu @@ -39,6 +39,7 @@ #define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess +#define cudaDeviceGetMemPool hipDeviceGetMemPool #define cudaDeviceProp hipDeviceProp_t #define cudaDeviceSynchronize hipDeviceSynchronize #define cudaError_t hipError_t @@ -48,6 +49,7 @@ #define cudaEvent_t hipEvent_t #define cudaEventDestroy hipEventDestroy #define cudaFree hipFree +#define cudaFreeAsync hipFreeAsync #define cudaFreeHost hipHostFree #define cudaGetDevice hipGetDevice #define cudaGetDeviceCount hipGetDeviceCount @@ -55,6 +57,7 @@ #define cudaGetErrorString hipGetErrorString #define cudaGetLastError hipGetLastError #define cudaMalloc hipMalloc +#define cudaMallocFromPoolAsync hipMallocFromPoolAsync #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault) #define cudaMemcpy hipMemcpy #define cudaMemcpy2DAsync hipMemcpy2DAsync @@ -63,6 +66,9 @@ #define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost #define cudaMemcpyHostToDevice hipMemcpyHostToDevice #define cudaMemcpyKind hipMemcpyKind +#define cudaMemPool_t hipMemPool_t +#define cudaMemPoolAttrReleaseThreshold hipMemPoolAttrReleaseThreshold +#define cudaMemPoolSetAttribute hipMemPoolSetAttribute #define cudaMemset hipMemset #define cudaMemsetAsync hipMemsetAsync #define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize @@ -4470,6 +4476,13 @@ static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) { *dsti = __float2half(*xi); } +static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) { + const half * xi = (const half *) cxi; + half * dsti = (half *) cdsti; + + *dsti = *xi; +} + template static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, @@ -4723,6 +4736,25 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min, dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]); } +static __global__ void im2col_f32_f16( + const float * x, half * dst, + int ofs0, int ofs1, int IW, int IH, int CHW, + int s0, int s1, int p0, int p1, int d0, int d1) { + const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0; + const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1; + + const int offset_dst = + (threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW + + (blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z); + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst[offset_dst] = __float2half(0.0f); + } else { + const int offset_src = threadIdx.x * ofs0 + blockIdx.x * ofs1; + dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]); + } +} + template static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) { const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); @@ -5612,6 +5644,16 @@ static void ggml_cpy_f32_f16_cuda( (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12); } +static void ggml_cpy_f16_f16_cuda( + const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, + const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) { + + const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; + cpy_f32_f16<<>> + (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12); +} + static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE; scale_f32<<>>(x, dst, scale, k); @@ -5695,6 +5737,15 @@ static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, c soft_max_f32<<>>(x, dst, ncols_x); } +static void im2col_f32_f16_cuda(const float * x, half * dst, + int OH, int IW, int IH, int OW, int IC, + int KH, int KW, int N, int ofs0, int ofs1, + int s0, int s1, int p0, int p1, int d0, int d1, cudaStream_t stream) { + dim3 block_nums(IC, OH, OW); + dim3 block_dims(N, KH, KW); + im2col_f32_f16<<>>(x, dst, ofs0, ofs1, IW, IH, (IC * KH * KW), s0, s1, p0, p1, d0, d1); +} + // buffer pool for cuda #define MAX_CUDA_BUFFERS 256 @@ -6477,7 +6528,7 @@ inline void ggml_cuda_op_mul_mat_cublas( src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src1_as, id, stream); to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream); } - const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16; + const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16; size_t dst_f16_as = 0; half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(row_diff*src1_ncols * sizeof(half), &dst_f16_as, id, stream); @@ -6653,6 +6704,45 @@ inline void ggml_cuda_op_alibi( (void) src1_dd; } +inline void ggml_cuda_op_im2col( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F16); + + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t*)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t*)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t*)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t*)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t*)(dst->op_params))[5]; + + const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1; + + const int64_t N = src1->ne[is_2D ? 3 : 2]; + const int64_t IC = src1->ne[is_2D ? 2 : 1]; + const int64_t IH = is_2D ? src1->ne[1] : 1; + const int64_t IW = src1->ne[0]; + + const int64_t KH = is_2D ? src0->ne[1] : 1; + const int64_t KW = src0->ne[0]; + + const int64_t OH = is_2D ? dst->ne[2] : 1; + const int64_t OW = dst->ne[1]; + + const size_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32 + const size_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 + + im2col_f32_f16_cuda(src1_dd, (half*) dst_dd, + OH, IW, IH, OW, IC, KH, KW, N, + ofs0, ofs1, s0, s1, p0, p1, d0, d1, main_stream); + + (void) src0; + (void) src0_dd; +} + inline void ggml_cuda_op_diag_mask_inf( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { @@ -7543,6 +7633,9 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream); + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { + ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, + ne10, ne11, nb10, nb11, nb12, main_stream); } else { fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type), ggml_type_name(src1->type)); @@ -7574,6 +7667,10 @@ static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi); } +void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col); +} + static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { (void) src0; (void) src1; @@ -7937,6 +8034,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_ case GGML_OP_ALIBI: func = ggml_cuda_alibi; break; + case GGML_OP_IM2COL: + func = ggml_cuda_im2col; + break; default: return false; } diff --git a/src/ggml-metal.m b/src/ggml-metal.m index 43d0dff0..148c12b1 100644 --- a/src/ggml-metal.m +++ b/src/ggml-metal.m @@ -86,6 +86,7 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(rms_norm); GGML_METAL_DECL_KERNEL(norm); GGML_METAL_DECL_KERNEL(mul_mv_f32_f32); + GGML_METAL_DECL_KERNEL(mul_mv_f16_f16); GGML_METAL_DECL_KERNEL(mul_mv_f16_f32); GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row); GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4); @@ -114,6 +115,7 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(rope_f32); GGML_METAL_DECL_KERNEL(rope_f16); GGML_METAL_DECL_KERNEL(alibi_f32); + GGML_METAL_DECL_KERNEL(im2col_f16); GGML_METAL_DECL_KERNEL(cpy_f32_f16); GGML_METAL_DECL_KERNEL(cpy_f32_f32); GGML_METAL_DECL_KERNEL(cpy_f16_f16); @@ -287,6 +289,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(rms_norm); GGML_METAL_ADD_KERNEL(norm); GGML_METAL_ADD_KERNEL(mul_mv_f32_f32); + GGML_METAL_ADD_KERNEL(mul_mv_f16_f16); GGML_METAL_ADD_KERNEL(mul_mv_f16_f32); GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row); GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4); @@ -317,6 +320,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(rope_f32); GGML_METAL_ADD_KERNEL(rope_f16); GGML_METAL_ADD_KERNEL(alibi_f32); + GGML_METAL_ADD_KERNEL(im2col_f16); GGML_METAL_ADD_KERNEL(cpy_f32_f16); GGML_METAL_ADD_KERNEL(cpy_f32_f32); GGML_METAL_ADD_KERNEL(cpy_f16_f16); @@ -386,6 +390,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(rms_norm); GGML_METAL_DEL_KERNEL(norm); GGML_METAL_DEL_KERNEL(mul_mv_f32_f32); + GGML_METAL_DEL_KERNEL(mul_mv_f16_f16); GGML_METAL_DEL_KERNEL(mul_mv_f16_f32); GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row); GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4); @@ -416,6 +421,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(rope_f32); GGML_METAL_DEL_KERNEL(rope_f16); GGML_METAL_DEL_KERNEL(alibi_f32); + GGML_METAL_DEL_KERNEL(im2col_f16); GGML_METAL_DEL_KERNEL(cpy_f32_f16); GGML_METAL_DEL_KERNEL(cpy_f32_f32); GGML_METAL_DEL_KERNEL(cpy_f16_f16); @@ -1030,7 +1036,7 @@ void ggml_metal_graph_compute( [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0]; + [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; @@ -1139,6 +1145,7 @@ void ggml_metal_graph_compute( switch (src0t) { case GGML_TYPE_F32: { + GGML_ASSERT(src1t == GGML_TYPE_F32); [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32]; nrows = 4; } break; @@ -1146,13 +1153,18 @@ void ggml_metal_graph_compute( { nth0 = 32; nth1 = 1; - if (ne11 * ne12 < 4) { - [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row]; - } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { - [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4]; - nrows = ne11; + if (src1t == GGML_TYPE_F32) { + if (ne11 * ne12 < 4) { + [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row]; + } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { + [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4]; + nrows = ne11; + } else { + [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32]; + nrows = 4; + } } else { - [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f16]; nrows = 4; } } break; @@ -1342,7 +1354,7 @@ void ggml_metal_graph_compute( [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; [encoder setBytes:&eps length:sizeof( float) atIndex:4]; - [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0]; + [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0]; const int64_t nrows = ggml_nrows(src0); @@ -1361,7 +1373,7 @@ void ggml_metal_graph_compute( [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; [encoder setBytes:&eps length:sizeof( float) atIndex:4]; - [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0]; + [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0]; const int64_t nrows = ggml_nrows(src0); @@ -1464,6 +1476,58 @@ void ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; + case GGML_OP_IM2COL: + { + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F16); + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; + const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; + + const int32_t N = src1->ne[is_2D ? 3 : 2]; + const int32_t IC = src1->ne[is_2D ? 2 : 1]; + const int32_t IH = is_2D ? src1->ne[1] : 1; + const int32_t IW = src1->ne[0]; + + const int32_t KH = is_2D ? src0->ne[1] : 1; + const int32_t KW = src0->ne[0]; + + const int32_t OH = is_2D ? dst->ne[2] : 1; + const int32_t OW = dst->ne[1]; + + const int32_t CHW = IC * KH * KW; + + const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; + const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; + + switch (src0->type) { + case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break; + case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_im2col_f16]; break; + default: GGML_ASSERT(false); + }; + + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2]; + [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3]; + [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4]; + [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5]; + [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6]; + [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7]; + [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8]; + [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9]; + [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10]; + [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11]; + [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12]; + + [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)]; + } break; case GGML_OP_DUP: case GGML_OP_CPY: case GGML_OP_CONT: diff --git a/src/ggml-metal.metal b/src/ggml-metal.metal index 7c35f23a..5d1357cd 100644 --- a/src/ggml-metal.metal +++ b/src/ggml-metal.metal @@ -792,7 +792,7 @@ kernel void kernel_mul_mv_f32_f32( constant int64_t & ne0, constant int64_t & ne1, uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { + uint tiisg[[thread_index_in_simdgroup]]) { const int64_t r0 = tgpig.x; const int64_t rb = tgpig.y*N_F32_F32; @@ -844,6 +844,79 @@ kernel void kernel_mul_mv_f32_f32( } } +#define N_F16_F16 4 + +kernel void kernel_mul_mv_f16_f16( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + + const int64_t r0 = tgpig.x; + const int64_t rb = tgpig.y*N_F16_F16; + const int64_t im = tgpig.z; + + device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); + + if (ne00 < 128) { + for (int row = 0; row < N_F16_F16; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12); + + float sumf = 0; + for (int i = tiisg; i < ne00; i += 32) { + sumf += (half) x[i] * (half) y[i]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } else { + device const half4 * x4 = (device const half4 *)x; + for (int row = 0; row < N_F16_F16; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12); + device const half4 * y4 = (device const half4 *) y; + + float sumf = 0; + for (int i = tiisg; i < ne00/4; i += 32) { + for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i]; + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } +} + kernel void kernel_mul_mv_f16_f32_1row( device const char * src0, device const char * src1, @@ -1229,6 +1302,39 @@ kernel void kernel_rope( template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope; template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope; +kernel void kernel_im2col_f16( + device const float * x, + device half * dst, + constant int32_t & ofs0, + constant int32_t & ofs1, + constant int32_t & IW, + constant int32_t & IH, + constant int32_t & CHW, + constant int32_t & s0, + constant int32_t & s1, + constant int32_t & p0, + constant int32_t & p1, + constant int32_t & d0, + constant int32_t & d1, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0; + const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1; + + const int32_t offset_dst = + (tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW + + (tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]); + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst[offset_dst] = 0.0f; + } else { + const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1; + dst[offset_dst] = x[offset_src + iih * IW + iiw]; + } +} + kernel void kernel_cpy_f16_f16( device const half * src0, device half * dst, diff --git a/src/ggml.c b/src/ggml.c index 018f0ce0..584ee468 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -143,12 +143,6 @@ void ggml_print_backtrace(void) { } #endif -#undef MIN -#undef MAX - -#define MIN(a, b) ((a) < (b) ? (a) : (b)) -#define MAX(a, b) ((a) > (b) ? (a) : (b)) - /*#define GGML_PERF*/ #define GGML_DEBUG 0 #define GGML_GELU_FP16 @@ -277,6 +271,12 @@ inline static void * ggml_aligned_malloc(size_t size) { // floating point type used to accumulate sums typedef double ggml_float; +#undef MIN +#undef MAX + +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define MAX(a, b) ((a) > (b) ? (a) : (b)) + // // global data // @@ -1634,13 +1634,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "ROPE_BACK", "ALIBI", "CLAMP", - "CONV_1D", - "CONV_1D_STAGE_0", - "CONV_1D_STAGE_1", "CONV_TRANSPOSE_1D", - "CONV_2D", - "CONV_2D_STAGE_0", - "CONV_2D_STAGE_1", + "IM2COL", "CONV_TRANSPOSE_2D", "POOL_1D", "POOL_2D", @@ -1671,7 +1666,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73"); +static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1721,13 +1716,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "rope_back(x)", "alibi(x)", "clamp(x)", - "conv_1d(x)", - "conv_1d_stage_0(x)", - "conv_1d_stage_1(x)", "conv_transpose_1d(x)", - "conv_2d(x)", - "conv_2d_stage_0(x)", - "conv_2d_stage_1(x)", + "im2col(x)", "conv_transpose_2d(x)", "pool_1d(x)", "pool_2d(x)", @@ -1758,7 +1748,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73"); +static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -1786,13 +1776,7 @@ static void ggml_setup_op_has_task_pass(void) { p[GGML_OP_GET_ROWS_BACK ] = true; p[GGML_OP_DIAG_MASK_INF ] = true; p[GGML_OP_DIAG_MASK_ZERO ] = true; - p[GGML_OP_CONV_1D ] = true; - p[GGML_OP_CONV_1D_STAGE_0 ] = true; - p[GGML_OP_CONV_1D_STAGE_1 ] = true; p[GGML_OP_CONV_TRANSPOSE_1D ] = true; - p[GGML_OP_CONV_2D ] = true; - p[GGML_OP_CONV_2D_STAGE_0 ] = true; - p[GGML_OP_CONV_2D_STAGE_1 ] = true; p[GGML_OP_CONV_TRANSPOSE_2D ] = true; p[GGML_OP_FLASH_ATTN_BACK ] = true; p[GGML_OP_CROSS_ENTROPY_LOSS ] = true; @@ -5137,82 +5121,6 @@ static int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, return (ins + 2 * p - d * (ks - 1) - 1) / s + 1; } -// im2col: [N, IC, IL] => [N, OL, IC*K] -// a: [OC,IC, K] -// b: [N, IC, IL] -// result: [N, OL, IC*K] -static struct ggml_tensor * ggml_conv_1d_stage_0( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int s0, - int p0, - int d0) { - GGML_ASSERT(a->ne[1] == b->ne[1]); - bool is_node = false; - - if (a->grad || b->grad) { - GGML_ASSERT(false); // TODO: implement backward - is_node = true; - } - - const int64_t OL = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0); - - const int64_t ne[4] = { - a->ne[1] * a->ne[0], - OL, - b->ne[2], - 1, - }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne); - - int32_t params[] = { s0, p0, d0 }; - ggml_set_op_params(result, params, sizeof(params)); - - result->op = GGML_OP_CONV_1D_STAGE_0; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -// ggml_conv_1d_stage_1 - -// gemm: [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K] -// a: [OC, IC, K] -// b: [N, OL, IC * K] -// result: [N, OC, OL] -static struct ggml_tensor * ggml_conv_1d_stage_1( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - - bool is_node = false; - - if (a->grad || b->grad) { - GGML_ASSERT(false); // TODO: implement backward - is_node = true; - } - - const int64_t ne[4] = { - b->ne[1], - a->ne[2], - b->ne[2], - 1, - }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - - result->op = GGML_OP_CONV_1D_STAGE_1; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -// ggml_conv_1d - GGML_API struct ggml_tensor * ggml_conv_1d( struct ggml_context * ctx, struct ggml_tensor * a, @@ -5220,43 +5128,17 @@ GGML_API struct ggml_tensor * ggml_conv_1d( int s0, int p0, int d0) { - struct ggml_tensor * result = ggml_conv_1d_stage_0(ctx, a, b, s0, p0, d0); - result = ggml_conv_1d_stage_1(ctx, a, result); - return result; -} + struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false); // [N, OL, IC * K] -// GGML_API struct ggml_tensor * ggml_conv_1d( -// struct ggml_context * ctx, -// struct ggml_tensor * a, -// struct ggml_tensor * b, -// int s0, -// int p0, -// int d0) { -// GGML_ASSERT(ggml_is_matrix(b)); -// GGML_ASSERT(a->ne[1] == b->ne[1]); -// bool is_node = false; + struct ggml_tensor * result = + ggml_mul_mat(ctx, + ggml_reshape_2d(ctx, im2col, im2col->ne[0], (im2col->ne[2] * im2col->ne[1])), // [N, OL, IC * K] => [N*OL, IC * K] + ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1]), a->ne[2])); // [OC,IC, K] => [OC, IC * K] -// if (a->grad || b->grad) { -// GGML_ASSERT(false); // TODO: implement backward -// is_node = true; -// } + result = ggml_reshape_3d(ctx, result, im2col->ne[1], a->ne[2], im2col->ne[2]); // [N, OC, OL] -// const int64_t ne[4] = { -// ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0), -// a->ne[2], 1, 1, -// }; -// struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne); - -// int32_t params[] = { s0, p0, d0 }; -// ggml_set_op_params(result, params, sizeof(params)); - -// result->op = GGML_OP_CONV_1D; -// result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; -// result->src[0] = a; -// result->src[1] = b; - -// return result; -// } + return result; +} // ggml_conv_1d_ph @@ -5319,7 +5201,7 @@ GGML_API struct ggml_tensor * ggml_conv_transpose_1d( // a: [OC,IC, KH, KW] // b: [N, IC, IH, IW] // result: [N, OH, OW, IC*KH*KW] -static struct ggml_tensor * ggml_conv_2d_stage_0( +struct ggml_tensor * ggml_im2col( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, @@ -5328,9 +5210,14 @@ static struct ggml_tensor * ggml_conv_2d_stage_0( int p0, int p1, int d0, - int d1) { + int d1, + bool is_2D) { - GGML_ASSERT(a->ne[2] == b->ne[2]); + if(is_2D) { + GGML_ASSERT(a->ne[2] == b->ne[2]); + } else { + GGML_ASSERT(a->ne[1] == b->ne[1]); + } bool is_node = false; if (a->grad || b->grad) { @@ -5338,81 +5225,51 @@ static struct ggml_tensor * ggml_conv_2d_stage_0( is_node = true; } - const int64_t OH = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1); - const int64_t OW = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0); + const int64_t OH = is_2D ? ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0; + const int64_t OW = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0); const int64_t ne[4] = { - a->ne[2] * a->ne[1] * a->ne[0], + is_2D ? (a->ne[2] * a->ne[1] * a->ne[0]) : a->ne[1] * a->ne[0], OW, - OH, - b->ne[3], + is_2D ? OH : b->ne[2], + is_2D ? b->ne[3] : 1, }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne); - int32_t params[] = { s0, s1, p0, p1, d0, d1 }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne); + int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) }; ggml_set_op_params(result, params, sizeof(params)); - result->op = GGML_OP_CONV_2D_STAGE_0; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - - return result; - -} - -// gemm: [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW] -// a: [OC, IC, KH, KW] -// b: [N, OH, OW, IC * KH * KW] -// result: [N, OC, OH, OW] -static struct ggml_tensor * ggml_conv_2d_stage_1( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - - bool is_node = false; - - if (a->grad || b->grad) { - GGML_ASSERT(false); // TODO: implement backward - is_node = true; - } - - const int64_t ne[4] = { - b->ne[1], - b->ne[2], - a->ne[3], - b->ne[3], - }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - - result->op = GGML_OP_CONV_2D_STAGE_1; + result->op = GGML_OP_IM2COL; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; result->src[1] = b; return result; - } // a: [OC,IC, KH, KW] // b: [N, IC, IH, IW] // result: [N, OC, OH, OW] struct ggml_tensor * ggml_conv_2d( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int s0, - int s1, - int p0, - int p1, - int d0, - int d1) { + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1) { + struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true); // [N, OH, OW, IC * KH * KW] - struct ggml_tensor * result = ggml_conv_2d_stage_0(ctx, a, b, s0, s1, p0, p1, d0, d1); // [N, OH, OW, IC * KH * KW] - result = ggml_conv_2d_stage_1(ctx, a, result); + struct ggml_tensor * result = + ggml_mul_mat(ctx, + ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N, OH, OW, IC * KH * KW] => [N*OH*OW, IC * KH * KW] + ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2]), a->ne[3])); // [OC,IC, KH, KW] => [OC, IC * KH * KW] - return result; + result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], a->ne[3], im2col->ne[3]); // [N, OC, OH, OW] + return result; } // ggml_conv_2d_sk_p0 @@ -9507,6 +9364,8 @@ static bool ggml_compute_forward_mul_mat_use_blas( // TODO: find the optimal values for these if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && + src0->type == GGML_TYPE_F32 && + src1->type == GGML_TYPE_F32 && (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) { /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/ @@ -9517,6 +9376,7 @@ static bool ggml_compute_forward_mul_mat_use_blas( } #endif + static void ggml_compute_forward_mul_mat( const struct ggml_compute_params * params, const struct ggml_tensor * src0, @@ -9545,7 +9405,7 @@ static void ggml_compute_forward_mul_mat( // we don't support permuted src0 or src1 GGML_ASSERT(nb00 == ggml_type_size(type)); - GGML_ASSERT(nb10 == sizeof(float)); + GGML_ASSERT(nb10 == ggml_type_size(src1->type)); // dst cannot be transposed or permuted GGML_ASSERT(nb0 == sizeof(float)); @@ -11637,9 +11497,9 @@ static void ggml_compute_forward_rope_back( } } -// ggml_compute_forward_conv_1d +// ggml_compute_forward_conv_transpose_1d -static void ggml_compute_forward_conv_1d_f16_f32( +static void ggml_compute_forward_conv_transpose_1d_f16_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, @@ -11656,14 +11516,7 @@ static void ggml_compute_forward_conv_1d_f16_f32( const int ith = params->ith; const int nth = params->nth; - const int nk = ne00; - - // size of the convolution row - the kernel size unrolled across all input channels - const int ew0 = nk*ne01; - - const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; - const int32_t p0 = ((const int32_t*)(dst->op_params))[1]; - const int32_t d0 = ((const int32_t*)(dst->op_params))[2]; + const int nk = ne00*ne01*ne02; GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); GGML_ASSERT(nb10 == sizeof(float)); @@ -11671,23 +11524,37 @@ static void ggml_compute_forward_conv_1d_f16_f32( if (params->type == GGML_TASK_INIT) { memset(params->wdata, 0, params->wsize); - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; + // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout) + { + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; - for (int64_t i11 = 0; i11 < ne11; i11++) { - const float * const src = (float *)((char *) src1->data + i11*nb11); - ggml_fp16_t * dst_data = wdata; + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01); + ggml_fp16_t * dst_data = wdata + i01*ne00*ne02; + for (int64_t i00 = 0; i00 < ne00; i00++) { + dst_data[i00*ne02 + i02] = src[i00]; + } + } + } + } - for (int64_t i0 = 0; i0 < ne0; i0++) { - for (int64_t ik = 0; ik < nk; ik++) { - const int idx0 = i0*s0 + ik*d0 - p0; + // permute source data (src1) from (L x Cin) to (Cin x L) + { + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk; + ggml_fp16_t * dst_data = wdata; - if(!(idx0 < 0 || idx0 >= ne10)) { - dst_data[i0*ew0 + i11*nk + ik] = GGML_FP32_TO_FP16(src[idx0]); - } + for (int64_t i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + i11*nb11); + for (int64_t i10 = 0; i10 < ne10; i10++) { + dst_data[i10*ne11 + i11] = GGML_FP32_TO_FP16(src[i10]); } } } + // need to zero dst since we are accumulating into it + memset(dst->data, 0, ggml_nbytes(dst)); + return; } @@ -11695,8 +11562,10 @@ static void ggml_compute_forward_conv_1d_f16_f32( return; } + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + // total rows in dst - const int nr = ne2; + const int nr = ne1; // rows per thread const int dr = (nr + nth - 1)/nth; @@ -11705,22 +11574,26 @@ static void ggml_compute_forward_conv_1d_f16_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; - - for (int i2 = 0; i2 < ne2; i2++) { - for (int i1 = ir0; i1 < ir1; i1++) { - float * dst_data = (float *)((char *) dst->data + i2*nb2 + i1*nb1); + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; + ggml_fp16_t * const wdata_src = wdata + nk; - for (int i0 = 0; i0 < ne0; i0++) { - ggml_vec_dot_f16(ew0, dst_data + i0, - (ggml_fp16_t *) ((char *) src0->data + i1*nb02), - (ggml_fp16_t *) wdata + i2*nb2 + i0*ew0); + for (int i1 = ir0; i1 < ir1; i1++) { + float * dst_data = (float *)((char *) dst->data + i1*nb1); + ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00; + for (int i10 = 0; i10 < ne10; i10++) { + const int i1n = i10*ne11; + for (int i00 = 0; i00 < ne00; i00++) { + float v = 0; + ggml_vec_dot_f16(ne02, &v, + (ggml_fp16_t *) wdata_src + i1n, + (ggml_fp16_t *) wdata_kernel + i00*ne02); + dst_data[i10*s0 + i00] += v; } } } } -static void ggml_compute_forward_conv_1d_f32( +static void ggml_compute_forward_conv_transpose_1d_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, @@ -11737,13 +11610,7 @@ static void ggml_compute_forward_conv_1d_f32( const int ith = params->ith; const int nth = params->nth; - const int nk = ne00; - - const int ew0 = nk*ne01; - - const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; - const int32_t p0 = ((const int32_t*)(dst->op_params))[1]; - const int32_t d0 = ((const int32_t*)(dst->op_params))[2]; + const int nk = ne00*ne01*ne02; GGML_ASSERT(nb00 == sizeof(float)); GGML_ASSERT(nb10 == sizeof(float)); @@ -11751,23 +11618,37 @@ static void ggml_compute_forward_conv_1d_f32( if (params->type == GGML_TASK_INIT) { memset(params->wdata, 0, params->wsize); - float * const wdata = (float *) params->wdata + 0; + // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout) + { + float * const wdata = (float *) params->wdata + 0; - for (int64_t i11 = 0; i11 < ne11; i11++) { - const float * const src = (float *)((char *) src1->data + i11*nb11); - float * dst_data = wdata; + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01); + float * dst_data = wdata + i01*ne00*ne02; + for (int64_t i00 = 0; i00 < ne00; i00++) { + dst_data[i00*ne02 + i02] = src[i00]; + } + } + } + } - for (int64_t i0 = 0; i0 < ne0; i0++) { - for (int64_t ik = 0; ik < nk; ik++) { - const int idx0 = i0*s0 + ik*d0 - p0; + // prepare source data (src1) + { + float * const wdata = (float *) params->wdata + nk; + float * dst_data = wdata; - if(!(idx0 < 0 || idx0 >= ne10)) { - dst_data[i0*ew0 + i11*nk + ik] = src[idx0]; - } + for (int64_t i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + i11*nb11); + for (int64_t i10 = 0; i10 < ne10; i10++) { + dst_data[i10*ne11 + i11] = src[i10]; } } } + // need to zero dst since we are accumulating into it + memset(dst->data, 0, ggml_nbytes(dst)); + return; } @@ -11775,8 +11656,10 @@ static void ggml_compute_forward_conv_1d_f32( return; } + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + // total rows in dst - const int nr = ne02; + const int nr = ne1; // rows per thread const int dr = (nr + nth - 1)/nth; @@ -11785,94 +11668,50 @@ static void ggml_compute_forward_conv_1d_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - float * const wdata = (float *) params->wdata + 0; - - for (int i2 = 0; i2 < ne2; i2++) { - for (int i1 = ir0; i1 < ir1; i1++) { - float * dst_data = (float *)((char *) dst->data + i2*nb2 + i1*nb1); + float * const wdata = (float *) params->wdata + 0; + float * const wdata_src = wdata + nk; - for (int i0 = 0; i0 < ne0; i0++) { - ggml_vec_dot_f32(ew0, dst_data + i0, - (float *) ((char *) src0->data + i1*nb02), - (float *) wdata + i2*nb2 + i0*ew0); + for (int i1 = ir0; i1 < ir1; i1++) { + float * dst_data = (float *)((char *) dst->data + i1*nb1); + float * wdata_kernel = wdata + i1*ne02*ne00; + for (int i10 = 0; i10 < ne10; i10++) { + const int i1n = i10*ne11; + for (int i00 = 0; i00 < ne00; i00++) { + float v = 0; + ggml_vec_dot_f32(ne02, &v, + wdata_src + i1n, + wdata_kernel + i00*ne02); + dst_data[i10*s0 + i00] += v; } } } } -// TODO: reuse ggml_mul_mat or implement ggml_im2col and remove stage_0 and stage_1 -static void gemm_f16_out_f32(int64_t m, int64_t n, int64_t k, - ggml_fp16_t * A, - ggml_fp16_t * B, - float * C, - const int ith, const int nth) { - // does not seem to make a difference - int64_t m0, m1, n0, n1; - // patches per thread - if (m > n) { - n0 = 0; - n1 = n; - - // total patches in dst - const int np = m; - - // patches per thread - const int dp = (np + nth - 1)/nth; - - // patch range for this thread - m0 = dp*ith; - m1 = MIN(m0 + dp, np); - } else { - m0 = 0; - m1 = m; - - // total patches in dst - const int np = n; - - // patches per thread - const int dp = (np + nth - 1)/nth; - - // patch range for this thread - n0 = dp*ith; - n1 = MIN(n0 + dp, np); - } - - // block-tiling attempt - int64_t blck_n = 16; - int64_t blck_m = 16; - - // int64_t CACHE_SIZE = 2 * 1024 * 1024; // 2MB - // int64_t blck_size = CACHE_SIZE / (sizeof(float) + 2 * sizeof(ggml_fp16_t) * K); - // if (blck_size > 0) { - // blck_0 = 4; - // blck_1 = blck_size / blck_0; - // if (blck_1 < 0) { - // blck_1 = 1; - // } - // // blck_0 = (int64_t)sqrt(blck_size); - // // blck_1 = blck_0; - // } - // // printf("%zd %zd %zd %zd\n", blck_size, K, blck_0, blck_1); - - for (int j = n0; j < n1; j+=blck_n) { - for (int i = m0; i < m1; i+=blck_m) { - // printf("i j k => %d %d %d\n", i, j, K); - for (int ii = i; ii < i + blck_m && ii < m1; ii++) { - for (int jj = j; jj < j + blck_n && jj < n1; jj++) { - ggml_vec_dot_f16(k, - C + ii*n + jj, - A + ii * k, - B + jj * k); - } - } - } +static void ggml_compute_forward_conv_transpose_1d( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_conv_transpose_1d_f16_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_conv_transpose_1d_f32(params, src0, src1, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; } } -// src0: kernel [OC, IC, K] -// src1: signal [N, IC, IL] -// dst: result [N, OL, IC*K] -static void ggml_compute_forward_conv_1d_stage_0_f32( +// src0: kernel [OC, IC, KH, KW] +// src1: image [N, IC, IH, IW] +// dst: result [N, OH, OW, IC*KH*KW] +static void ggml_compute_forward_im2col_f16( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, @@ -11886,425 +11725,35 @@ static void ggml_compute_forward_conv_1d_stage_0_f32( GGML_TENSOR_BINARY_OP_LOCALS; - const int64_t N = ne12; - const int64_t IC = ne11; - const int64_t IL = ne10; - - const int64_t K = ne00; - - const int64_t OL = ne1; + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; + const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; const int ith = params->ith; const int nth = params->nth; - const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; - const int32_t p0 = ((const int32_t*)(dst->op_params))[1]; - const int32_t d0 = ((const int32_t*)(dst->op_params))[2]; + const int64_t N = is_2D ? ne13 : ne12; + const int64_t IC = is_2D ? ne12 : ne11; + const int64_t IH = is_2D ? ne11 : 1; + const int64_t IW = ne10; + + const int64_t KH = is_2D ? ne01 : 1; + const int64_t KW = ne00; + + const int64_t OH = is_2D ? ne2 : 1; + const int64_t OW = ne1; + + int ofs0 = is_2D ? nb13 : nb12; + int ofs1 = is_2D ? nb12 : nb11; GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); GGML_ASSERT(nb10 == sizeof(float)); if (params->type == GGML_TASK_INIT) { - memset(dst->data, 0, ggml_nbytes(dst)); - return; - } - - if (params->type == GGML_TASK_FINALIZE) { - return; - } - - // im2col: [N, IC, IL] => [N, OL, IC*K] - { - ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data; - - for (int64_t in = 0; in < N; in++) { - for (int64_t iol = 0; iol < OL; iol++) { - for (int64_t iic = ith; iic < IC; iic+=nth) { - - // micro kernel - ggml_fp16_t * dst_data = wdata + (in*OL + iol)*(IC*K); // [IC, K] - const float * const src_data = (float *)((char *) src1->data + in*nb12 + iic*nb11); // [IL] - - for (int64_t ik = 0; ik < K; ik++) { - const int64_t iil = iol*s0 + ik*d0 - p0; - - if (!(iil < 0 || iil >= IL)) { - dst_data[iic*K + ik] = GGML_FP32_TO_FP16(src_data[iil]); - } - } - } - } - } - } -} - -// gemm: [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K] -// src0: [OC, IC, K] -// src1: [N, OL, IC * K] -// result: [N, OC, OL] -static void ggml_compute_forward_conv_1d_stage_1_f16( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F16); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - - if (params->type == GGML_TASK_INIT) { - return; - } - - if (params->type == GGML_TASK_FINALIZE) { - return; - } - - GGML_TENSOR_BINARY_OP_LOCALS; - - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nb10 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nb0 == sizeof(float)); - - const int N = ne12; - const int OL = ne11; - - const int OC = ne02; - const int IC = ne01; - const int K = ne00; - - const int ith = params->ith; - const int nth = params->nth; - - int64_t m = OC; - int64_t n = OL; - int64_t k = IC * K; - - // [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K] - for (int i = 0; i < N; i++) { - ggml_fp16_t * A = (ggml_fp16_t *)src0->data; // [m, k] - ggml_fp16_t * B = (ggml_fp16_t *)src1->data + i * m * k; // [n, k] - float * C = (float *)dst->data + i * m * n; // [m, n] - - gemm_f16_out_f32(m, n, k, A, B, C, ith, nth); - } -} - -static void ggml_compute_forward_conv_1d( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - switch(src0->type) { - case GGML_TYPE_F16: - { - ggml_compute_forward_conv_1d_f16_f32(params, src0, src1, dst); - } break; - case GGML_TYPE_F32: - { - ggml_compute_forward_conv_1d_f32(params, src0, src1, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -static void ggml_compute_forward_conv_1d_stage_0( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - switch(src0->type) { - case GGML_TYPE_F16: - { - ggml_compute_forward_conv_1d_stage_0_f32(params, src0, src1, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -static void ggml_compute_forward_conv_1d_stage_1( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - switch(src0->type) { - case GGML_TYPE_F16: - { - ggml_compute_forward_conv_1d_stage_1_f16(params, src0, src1, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_conv_transpose_1d - -static void ggml_compute_forward_conv_transpose_1d_f16_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - - GGML_TENSOR_BINARY_OP_LOCALS - - const int ith = params->ith; - const int nth = params->nth; - - const int nk = ne00*ne01*ne02; - - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nb10 == sizeof(float)); - - if (params->type == GGML_TASK_INIT) { - memset(params->wdata, 0, params->wsize); - - // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout) - { - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; - - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01); - ggml_fp16_t * dst_data = wdata + i01*ne00*ne02; - for (int64_t i00 = 0; i00 < ne00; i00++) { - dst_data[i00*ne02 + i02] = src[i00]; - } - } - } - } - - // permute source data (src1) from (L x Cin) to (Cin x L) - { - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk; - ggml_fp16_t * dst_data = wdata; - - for (int64_t i11 = 0; i11 < ne11; i11++) { - const float * const src = (float *)((char *) src1->data + i11*nb11); - for (int64_t i10 = 0; i10 < ne10; i10++) { - dst_data[i10*ne11 + i11] = GGML_FP32_TO_FP16(src[i10]); - } - } - } - - // need to zero dst since we are accumulating into it - memset(dst->data, 0, ggml_nbytes(dst)); - - return; - } - - if (params->type == GGML_TASK_FINALIZE) { - return; - } - - const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; - - // total rows in dst - const int nr = ne1; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; - ggml_fp16_t * const wdata_src = wdata + nk; - - for (int i1 = ir0; i1 < ir1; i1++) { - float * dst_data = (float *)((char *) dst->data + i1*nb1); - ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00; - for (int i10 = 0; i10 < ne10; i10++) { - const int i1n = i10*ne11; - for (int i00 = 0; i00 < ne00; i00++) { - float v = 0; - ggml_vec_dot_f16(ne02, &v, - (ggml_fp16_t *) wdata_src + i1n, - (ggml_fp16_t *) wdata_kernel + i00*ne02); - dst_data[i10*s0 + i00] += v; - } - } - } -} - -static void ggml_compute_forward_conv_transpose_1d_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - - GGML_TENSOR_BINARY_OP_LOCALS - - const int ith = params->ith; - const int nth = params->nth; - - const int nk = ne00*ne01*ne02; - - GGML_ASSERT(nb00 == sizeof(float)); - GGML_ASSERT(nb10 == sizeof(float)); - - if (params->type == GGML_TASK_INIT) { - memset(params->wdata, 0, params->wsize); - - // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout) - { - float * const wdata = (float *) params->wdata + 0; - - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01); - float * dst_data = wdata + i01*ne00*ne02; - for (int64_t i00 = 0; i00 < ne00; i00++) { - dst_data[i00*ne02 + i02] = src[i00]; - } - } - } - } - - // prepare source data (src1) - { - float * const wdata = (float *) params->wdata + nk; - float * dst_data = wdata; - - for (int64_t i11 = 0; i11 < ne11; i11++) { - const float * const src = (float *)((char *) src1->data + i11*nb11); - for (int64_t i10 = 0; i10 < ne10; i10++) { - dst_data[i10*ne11 + i11] = src[i10]; - } - } - } - - // need to zero dst since we are accumulating into it - memset(dst->data, 0, ggml_nbytes(dst)); - - return; - } - - if (params->type == GGML_TASK_FINALIZE) { - return; - } - - const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; - - // total rows in dst - const int nr = ne1; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - float * const wdata = (float *) params->wdata + 0; - float * const wdata_src = wdata + nk; - - for (int i1 = ir0; i1 < ir1; i1++) { - float * dst_data = (float *)((char *) dst->data + i1*nb1); - float * wdata_kernel = wdata + i1*ne02*ne00; - for (int i10 = 0; i10 < ne10; i10++) { - const int i1n = i10*ne11; - for (int i00 = 0; i00 < ne00; i00++) { - float v = 0; - ggml_vec_dot_f32(ne02, &v, - wdata_src + i1n, - wdata_kernel + i00*ne02); - dst_data[i10*s0 + i00] += v; - } - } - } -} - -static void ggml_compute_forward_conv_transpose_1d( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F16: - { - ggml_compute_forward_conv_transpose_1d_f16_f32(params, src0, src1, dst); - } break; - case GGML_TYPE_F32: - { - ggml_compute_forward_conv_transpose_1d_f32(params, src0, src1, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_conv_2d - -// src0: kernel [OC, IC, KH, KW] -// src1: image [N, IC, IH, IW] -// dst: result [N, OH, OW, IC*KH*KW] -static void ggml_compute_forward_conv_2d_stage_0_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F16); - - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - - GGML_TENSOR_BINARY_OP_LOCALS; - - const int64_t N = ne13; - const int64_t IC = ne12; - const int64_t IH = ne11; - const int64_t IW = ne10; - - // const int64_t OC = ne03; - // const int64_t IC = ne02; - const int64_t KH = ne01; - const int64_t KW = ne00; - - const int64_t OH = ne2; - const int64_t OW = ne1; - - const int ith = params->ith; - const int nth = params->nth; - - const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; - const int32_t s1 = ((const int32_t*)(dst->op_params))[1]; - const int32_t p0 = ((const int32_t*)(dst->op_params))[2]; - const int32_t p1 = ((const int32_t*)(dst->op_params))[3]; - const int32_t d0 = ((const int32_t*)(dst->op_params))[4]; - const int32_t d1 = ((const int32_t*)(dst->op_params))[5]; - - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nb10 == sizeof(float)); - - if (params->type == GGML_TASK_INIT) { - memset(dst->data, 0, ggml_nbytes(dst)); return; } @@ -12317,20 +11766,22 @@ static void ggml_compute_forward_conv_2d_stage_0_f32( ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data; for (int64_t in = 0; in < N; in++) { - for (int64_t ioh = 0; ioh < OH; ioh++) { + for (int64_t ioh = 0; ioh < OH; ioh++) { // 1 for (int64_t iow = 0; iow < OW; iow++) { - for (int64_t iic = ith; iic < IC; iic+=nth) { + for (int64_t iic = ith; iic < IC; iic += nth) { // micro kernel ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW] - const float * const src_data = (float *)((char *) src1->data + in*nb13 + iic*nb12); // [IH, IW] + const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW] - for (int64_t ikh = 0; ikh < KH; ikh++) { + for (int64_t ikh = 0; ikh < KH; ikh++) { // 1 for (int64_t ikw = 0; ikw < KW; ikw++) { const int64_t iiw = iow*s0 + ikw*d0 - p0; const int64_t iih = ioh*s1 + ikh*d1 - p1; - if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) { + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0; + } else { dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_FP32_TO_FP16(src_data[iih*IW + iiw]); } } @@ -12342,223 +11793,7 @@ static void ggml_compute_forward_conv_2d_stage_0_f32( } } -// gemm: [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW] -// src0: [OC, IC, KH, KW] -// src1: [N, OH, OW, IC * KH * KW] -// result: [N, OC, OH, OW] -static void ggml_compute_forward_conv_2d_stage_1_f16( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F16); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - - if (params->type == GGML_TASK_INIT) { - return; - } - - if (params->type == GGML_TASK_FINALIZE) { - return; - } - - GGML_TENSOR_BINARY_OP_LOCALS; - - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nb10 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nb0 == sizeof(float)); - - const int N = ne13; - const int OH = ne12; - const int OW = ne11; - - const int OC = ne03; - const int IC = ne02; - const int KH = ne01; - const int KW = ne00; - - const int ith = params->ith; - const int nth = params->nth; - - int64_t m = OC; - int64_t n = OH * OW; - int64_t k = IC * KH * KW; - - // [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW] - for (int i = 0; i < N; i++) { - ggml_fp16_t * A = (ggml_fp16_t *)src0->data; // [m, k] - ggml_fp16_t * B = (ggml_fp16_t *)src1->data + i * m * k; // [n, k] - float * C = (float *)dst->data + i * m * n; // [m, n] - - gemm_f16_out_f32(m, n, k, A, B, C, ith, nth); - } -} - -static void ggml_compute_forward_conv_2d_f16_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - - GGML_TENSOR_BINARY_OP_LOCALS - - // src1: image [N, IC, IH, IW] - // src0: kernel [OC, IC, KH, KW] - // dst: result [N, OC, OH, OW] - // ne12: IC - // ne0: OW - // ne1: OH - // nk0: KW - // nk1: KH - // ne13: N - - const int N = ne13; - const int IC = ne12; - const int IH = ne11; - const int IW = ne10; - - const int OC = ne03; - // const int IC = ne02; - const int KH = ne01; - const int KW = ne00; - - const int OH = ne1; - const int OW = ne0; - - const int ith = params->ith; - const int nth = params->nth; - - // const int nk0 = ne00; - // const int nk1 = ne01; - - // size of the convolution row - the kernel size unrolled across all channels - // const int ew0 = nk0*nk1*ne02; - // ew0: IC*KH*KW - - const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; - const int32_t s1 = ((const int32_t*)(dst->op_params))[1]; - const int32_t p0 = ((const int32_t*)(dst->op_params))[2]; - const int32_t p1 = ((const int32_t*)(dst->op_params))[3]; - const int32_t d0 = ((const int32_t*)(dst->op_params))[4]; - const int32_t d1 = ((const int32_t*)(dst->op_params))[5]; - - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nb10 == sizeof(float)); - - if (params->type == GGML_TASK_INIT) { - memset(params->wdata, 0, params->wsize); - - // prepare source data (src1) - // im2col: [N, IC, IH, IW] => [N*OH*OW, IC*KH*KW] - - { - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; - - for (int in = 0; in < N; in++) { - for (int iic = 0; iic < IC; iic++) { - for (int ioh = 0; ioh < OH; ioh++) { - for (int iow = 0; iow < OW; iow++) { - - // micro kernel - ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW] - const float * const src_data = (float *)((char *) src1->data + in*nb13 + iic*nb12); // [IH, IW] - - for (int ikh = 0; ikh < KH; ikh++) { - for (int ikw = 0; ikw < KW; ikw++) { - const int iiw = iow*s0 + ikw*d0 - p0; - const int iih = ioh*s1 + ikh*d1 - p1; - - if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) { - dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_FP32_TO_FP16(src_data[iih*IW + iiw]); - } - } - } - } - } - } - } - } - - return; - } - - if (params->type == GGML_TASK_FINALIZE) { - return; - } - - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; - // wdata: [N*OH*OW, IC*KH*KW] - // dst: result [N, OC, OH, OW] - // src0: kernel [OC, IC, KH, KW] - - int64_t m = OC; - int64_t n = OH * OW; - int64_t k = IC * KH * KW; - - // [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW] - for (int i = 0; i < N; i++) { - ggml_fp16_t * A = (ggml_fp16_t *)src0->data; // [m, k] - ggml_fp16_t * B = (ggml_fp16_t *)wdata + i * m * k; // [n, k] - float * C = (float *)dst->data + i * m * n; // [m * k] - - gemm_f16_out_f32(m, n, k, A, B, C, ith, nth); - } -} - -static void ggml_compute_forward_conv_2d( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F16: - { - ggml_compute_forward_conv_2d_f16_f32(params, src0, src1, dst); - } break; - case GGML_TYPE_F32: - { - //ggml_compute_forward_conv_2d_f32(params, src0, src1, dst); - GGML_ASSERT(false); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -static void ggml_compute_forward_conv_2d_stage_0( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F16: - { - ggml_compute_forward_conv_2d_stage_0_f32(params, src0, src1, dst); - } break; - case GGML_TYPE_F32: - { - GGML_ASSERT(false); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -static void ggml_compute_forward_conv_2d_stage_1( +static void ggml_compute_forward_im2col( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, @@ -12566,7 +11801,7 @@ static void ggml_compute_forward_conv_2d_stage_1( switch (src0->type) { case GGML_TYPE_F16: { - ggml_compute_forward_conv_2d_stage_1_f16(params, src0, src1, dst); + ggml_compute_forward_im2col_f16(params, src0, src1, dst); } break; case GGML_TYPE_F32: { @@ -14783,33 +14018,13 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_clamp(params, tensor->src[0], tensor); } break; - case GGML_OP_CONV_1D: - { - ggml_compute_forward_conv_1d(params, tensor->src[0], tensor->src[1], tensor); - } break; - case GGML_OP_CONV_1D_STAGE_0: - { - ggml_compute_forward_conv_1d_stage_0(params, tensor->src[0], tensor->src[1], tensor); - } break; - case GGML_OP_CONV_1D_STAGE_1: - { - ggml_compute_forward_conv_1d_stage_1(params, tensor->src[0], tensor->src[1], tensor); - } break; case GGML_OP_CONV_TRANSPOSE_1D: { ggml_compute_forward_conv_transpose_1d(params, tensor->src[0], tensor->src[1], tensor); } break; - case GGML_OP_CONV_2D: - { - ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor); - } break; - case GGML_OP_CONV_2D_STAGE_0: - { - ggml_compute_forward_conv_2d_stage_0(params, tensor->src[0], tensor->src[1], tensor); - } break; - case GGML_OP_CONV_2D_STAGE_1: + case GGML_OP_IM2COL: { - ggml_compute_forward_conv_2d_stage_1(params, tensor->src[0], tensor->src[1], tensor); + ggml_compute_forward_im2col(params, tensor->src[0], tensor->src[1], tensor); } break; case GGML_OP_CONV_TRANSPOSE_2D: { @@ -15780,31 +14995,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ASSERT(false); // TODO: not implemented } break; - case GGML_OP_CONV_1D: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_OP_CONV_1D_STAGE_0: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_OP_CONV_1D_STAGE_1: - { - GGML_ASSERT(false); // TODO: not implemented - } break; case GGML_OP_CONV_TRANSPOSE_1D: { GGML_ASSERT(false); // TODO: not implemented } break; - case GGML_OP_CONV_2D: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_OP_CONV_2D_STAGE_0: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_OP_CONV_2D_STAGE_1: + case GGML_OP_IM2COL: { GGML_ASSERT(false); // TODO: not implemented } break; @@ -16533,31 +15728,11 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { { n_tasks = 1; //TODO } break; - case GGML_OP_CONV_1D: - { - n_tasks = n_threads; - } break; - case GGML_OP_CONV_1D_STAGE_0: - { - n_tasks = n_threads; - } break; - case GGML_OP_CONV_1D_STAGE_1: - { - n_tasks = n_threads; - } break; case GGML_OP_CONV_TRANSPOSE_1D: { n_tasks = n_threads; } break; - case GGML_OP_CONV_2D: - { - n_tasks = n_threads; - } break; - case GGML_OP_CONV_2D_STAGE_0: - { - n_tasks = n_threads; - } break; - case GGML_OP_CONV_2D_STAGE_1: + case GGML_OP_IM2COL: { n_tasks = n_threads; } break; @@ -16642,6 +15817,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { } break; default: { + printf("%s: op %s not implemented\n", __func__, ggml_op_name(node->op)); GGML_ASSERT(false); } break; } @@ -16844,38 +16020,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks; } } break; - case GGML_OP_CONV_1D: - { - GGML_ASSERT(node->src[0]->ne[3] == 1); - GGML_ASSERT(node->src[1]->ne[2] == 1); - GGML_ASSERT(node->src[1]->ne[3] == 1); - - const int64_t ne00 = node->src[0]->ne[0]; - const int64_t ne01 = node->src[0]->ne[1]; - const int64_t ne02 = node->src[0]->ne[2]; - - const int64_t ne10 = node->src[1]->ne[0]; - const int64_t ne11 = node->src[1]->ne[1]; - - const int64_t ne0 = node->ne[0]; - const int64_t ne1 = node->ne[1]; - const int64_t nk = ne00; - const int64_t ew0 = nk * ne01; - - UNUSED(ne02); - UNUSED(ne10); - UNUSED(ne11); - - if (node->src[0]->type == GGML_TYPE_F16 && - node->src[1]->type == GGML_TYPE_F32) { - cur = sizeof(ggml_fp16_t)*(ne0*ne1*ew0); - } else if (node->src[0]->type == GGML_TYPE_F32 && - node->src[1]->type == GGML_TYPE_F32) { - cur = sizeof(float)*(ne0*ne1*ew0); - } else { - GGML_ASSERT(false); - } - } break; case GGML_OP_CONV_TRANSPOSE_1D: { GGML_ASSERT(node->src[0]->ne[3] == 1); @@ -16901,37 +16045,9 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { GGML_ASSERT(false); } } break; - case GGML_OP_CONV_2D: + case GGML_OP_IM2COL: { - const int64_t ne00 = node->src[0]->ne[0]; // W - const int64_t ne01 = node->src[0]->ne[1]; // H - const int64_t ne02 = node->src[0]->ne[2]; // C - const int64_t ne03 = node->src[0]->ne[3]; // N - - const int64_t ne10 = node->src[1]->ne[0]; // W - const int64_t ne11 = node->src[1]->ne[1]; // H - const int64_t ne12 = node->src[1]->ne[2]; // C - - const int64_t ne0 = node->ne[0]; - const int64_t ne1 = node->ne[1]; - const int64_t ne2 = node->ne[2]; - const int64_t ne3 = node->ne[3]; - const int64_t nk = ne00*ne01; - const int64_t ew0 = nk * ne02; - - UNUSED(ne03); - UNUSED(ne2); - - if (node->src[0]->type == GGML_TYPE_F16 && - node->src[1]->type == GGML_TYPE_F32) { - // im2col: [N*OH*OW, IC*KH*KW] - cur = sizeof(ggml_fp16_t)*(ne3*ne0*ne1*ew0); - } else if (node->src[0]->type == GGML_TYPE_F32 && - node->src[1]->type == GGML_TYPE_F32) { - cur = sizeof(float)* (ne10*ne11*ne12); - } else { - GGML_ASSERT(false); - } + n_tasks = n_threads; } break; case GGML_OP_CONV_TRANSPOSE_2D: { diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index e0130cda..e069e4e6 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -355,3 +355,32 @@ add_executable(${TEST_TARGET} ${TEST_TARGET}.c) target_link_libraries(${TEST_TARGET} PRIVATE ggml) add_test(NAME ${TEST_TARGET} COMMAND $) set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw") + +# +# test-conv1d + +set(TEST_TARGET test-conv1d) +add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp) +target_link_libraries(${TEST_TARGET} PRIVATE ggml) +add_test(NAME ${TEST_TARGET} COMMAND $) +set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw") + + +# +# test-conv2d + +set(TEST_TARGET test-conv2d) +add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp) +target_link_libraries(${TEST_TARGET} PRIVATE ggml) +add_test(NAME ${TEST_TARGET} COMMAND $) +set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw") + + +# +# test-mul-mat + +set(TEST_TARGET test-mul-mat) +add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp) +target_link_libraries(${TEST_TARGET} PRIVATE ggml) +add_test(NAME ${TEST_TARGET} COMMAND $) +set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw") diff --git a/tests/test-conv1d.cpp b/tests/test-conv1d.cpp new file mode 100644 index 00000000..3067f9df --- /dev/null +++ b/tests/test-conv1d.cpp @@ -0,0 +1,298 @@ +#include "ggml.h" +#include "ggml/ggml-alloc.h" +#include "ggml/ggml-backend.h" + +// #define GGML_USE_CUBLAS + +#ifdef GGML_USE_CUBLAS +#include "ggml-cuda.h" +#endif + +#ifdef GGML_USE_METAL +#include "ggml-metal.h" +#endif + +#include +#include +#include +#include +#include +#include +#include +#include + +struct test_model { + struct ggml_tensor * a; + struct ggml_tensor * b; + ggml_backend_t backend = NULL; + ggml_backend_buffer_t buffer; + struct ggml_context * ctx; +}; + +void load_model(test_model & model, bool use_gpu = false) { + // create data + int K = 3, IC = 10, OC = 10; + int IL = 8, N = 1; + + // Initialize adata + float* adata = new float[K * IC * OC]; + for (size_t i = 0; i < K * IC * OC; i++) { + adata[i] = 4.5f; + } + + // Convert adata to fp16 format + std::vector hadata(K * IC * OC); + ggml_fp32_to_fp16_row(adata, hadata.data(), K * IC * OC); + + // Initialize bdata + float* bdata = new float[IL * IC * N]; + for (size_t i = 0; i < IL * IC * N; i++) { + bdata[i] = 2.5f; + } + + size_t buffer_size = 0; + { + buffer_size += K * IC * OC * ggml_type_sizef(GGML_TYPE_F16); // tensor a + buffer_size += IL * IC * N * ggml_type_sizef(GGML_TYPE_F32); // tensor b + buffer_size += 1024; // overhead + } + + printf("%s: ggml tensor size = %d bytes\n", __func__, (int) sizeof(ggml_tensor)); + printf("%s: backend buffer size = %0.2f MB\n", __func__, (buffer_size/ 1024.f/ 1024.f)); + + int num_tensors = 2; + struct ggml_init_params params { + /*.mem_size =*/ ggml_tensor_overhead() * num_tensors, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + // initialize the backend +#ifdef GGML_USE_CUBLAS + if (use_gpu) { + fprintf(stderr, "%s: using CUDA backend\n", __func__); + model.backend = ggml_backend_cuda_init(); + if (!model.backend) { + fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); + } + } +#endif + +#ifdef GGML_USE_METAL + if (use_gpu) { + fprintf(stderr, "%s: using Metal backend\n", __func__); + ggml_metal_log_set_callback(ggml_log_callback_default, nullptr); + model.backend = ggml_backend_metal_init(); + if (!model.backend) { + fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__); + } + } +#endif + + if(!model.backend) { + // fallback to CPU backend + model.backend = ggml_backend_cpu_init(); + } + + model.buffer = ggml_backend_alloc_buffer(model.backend, buffer_size); + + // create context + model.ctx = ggml_init(params); + + // create tensors + model.a = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F16, K, IC, OC); + model.b = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, IL, IC, N); + + // create a allocator + ggml_allocr * alloc = ggml_allocr_new_from_buffer(model.buffer); + + // alloc memory + ggml_allocr_alloc(alloc, model.a); + + // load data to buffer + if(ggml_backend_is_cpu(model.backend)) { + memcpy(model.a->data, hadata.data(), ggml_nbytes(model.a)); + } else { + ggml_backend_tensor_set(model.a, hadata.data(), 0, ggml_nbytes(model.a)); + } + + // alloc memory + ggml_allocr_alloc(alloc, model.b); + + if(ggml_backend_is_cpu(model.backend) +#ifdef GGML_USE_METAL + || ggml_backend_is_metal(model.backend) +#endif + ) { + memcpy(model.b->data, bdata, ggml_nbytes(model.b)); + } else { + ggml_backend_tensor_set(model.b, bdata, 0, ggml_nbytes(model.b)); + } + + ggml_allocr_free(alloc); +} + +struct ggml_cgraph * build_graph(const test_model& model, struct ggml_allocr * allocr) { + static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(); + static std::vector buf(buf_size); + + struct ggml_init_params params0 = { + /*.mem_size =*/ buf_size, + /*.mem_buffer =*/ buf.data(), + /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_allocr_alloc_graph() + }; + + // create a temporally context to build the graph + struct ggml_context * ctx0 = ggml_init(params0); + + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + int s0 = 1; + int p0 = 1; + int d0 = 1; + + // split conv1d in fundamental methods for test unit + struct ggml_tensor* im2col_0 = ggml_im2col(ctx0, model.a, model.b, s0, 0, p0, 0, d0, 0, false); + ggml_set_name(im2col_0, "im2col_res"); + ggml_build_forward_expand(gf, im2col_0); + + struct ggml_tensor* conv1d_res = ggml_conv_1d(ctx0, model.a, model.b, s0, p0, d0); + ggml_set_name(conv1d_res, "conv1d_res"); + ggml_build_forward_expand(gf, conv1d_res); + + // delete the temporally context used to build the graph + ggml_free(ctx0); + return gf; +} + +struct ggml_cgraph* compute_graph(const test_model & model, struct ggml_allocr * allocr) { + // reset the allocator to free all the memory allocated during the previous inference + ggml_allocr_reset(allocr); + + struct ggml_cgraph * gf = build_graph(model, allocr); + + // allocate tensors + ggml_allocr_alloc_graph(allocr, gf); + int n_threads = 1; + + if (ggml_backend_is_cpu(model.backend)) { + ggml_backend_cpu_set_n_threads(model.backend, n_threads); + } + +#ifdef GGML_USE_METAL + if (ggml_backend_is_metal(model.backend)) { + ggml_backend_metal_set_n_cb(model.backend, n_threads); + } +#endif + + ggml_backend_graph_compute(model.backend, gf); + + //ggml_graph_print(gf); + + // in this case, the output tensor is the last one in the graph + return gf; +} + +int main(void) +{ + ggml_time_init(); + + test_model model; + load_model(model, true); + + ggml_backend_buffer_t buf_compute; // for compute + struct ggml_allocr * allocr = NULL; + + { + size_t align = ggml_backend_get_alignment(model.backend); + allocr = ggml_allocr_new_measure(align); + + //create the worst case graph for memory usage estimation + struct ggml_cgraph * gf = build_graph(model, allocr); + size_t mem_size = ggml_allocr_alloc_graph(allocr, gf); + ggml_allocr_free(allocr); + + // compute the required memory + buf_compute = ggml_backend_alloc_buffer(model.backend, mem_size); + allocr = ggml_allocr_new_from_buffer(buf_compute); + fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f); + } + + struct ggml_cgraph * gf_res = compute_graph(model, allocr); + + struct ggml_tensor * im2col_res = NULL; + struct ggml_tensor * conv1d_res = NULL; + + for(int i = 0; i < gf_res->n_nodes; i++) { + if(strcmp(ggml_get_name(gf_res->nodes[i]), "im2col_res") == 0) { + im2col_res = gf_res->nodes[i]; + } else if(strcmp(ggml_get_name(gf_res->nodes[i]), "conv1d_res") == 0) { + conv1d_res = gf_res->nodes[i]; + } + } + + uint16_t* im2col_data = new uint16_t[ggml_nelements(im2col_res)]; + float* conv2d_data = new float[ggml_nelements(conv1d_res)]; + + ggml_backend_tensor_get(im2col_res, im2col_data, 0, ggml_nbytes(im2col_res)); + ggml_backend_tensor_get(conv1d_res, conv2d_data, 0, ggml_nbytes(conv1d_res)); + + const int n_conv1d_test = 80; + const int n_im2col_test = 240; + + float expected_conv1d[n_conv1d_test] = { + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f + }; + // first im2col test + + uint16_t expected_im2col[n_conv1d_test] = { + 0, 16640, 16640, 0, 16640, 16640, 0, 16640, + 16640, 0, 16640, 16640, 0, 16640, 16640, 0, + 16640, 16640, 0, 16640, 16640, 0, 16640, 16640, + 0, 16640, 16640, 0, 16640, 16640, 16640, 16640, + 16640, 16640, 16640, 16640, 16640, 16640, 16640, 16640, + 16640, 16640, 16640, 16640, 16640, 16640, 16640, 16640, + 16640, 16640, 16640, 16640, 16640, 16640, 16640, 16640, + 16640, 16640, 16640, 16640, 16640, 16640, 16640, 16640, + 16640, 16640, 16640, 16640, 16640, 16640, 16640, 16640, + 16640, 16640, 16640, 16640, 16640, 16640, 16640, 16640 + }; + + printf("\nPerforming test:\n"); + + bool passed = true; + for(int i = 0; i < n_conv1d_test; i++) { + if( + im2col_data[i] != expected_im2col[i]) { + passed = false; + break; + } + } + + printf("ggml_im2col (%i): %s\n", ggml_nelements(im2col_res), passed && (ggml_nelements(im2col_res) == n_im2col_test) ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m"); + + passed = true; + for(int i = 0; i < n_conv1d_test; i++) { + if(conv2d_data[i] != expected_conv1d[i]) { + passed = false; + break; + } + } + + printf("ggml_conv1d (%i): %s\n", ggml_nelements(conv1d_res), passed && (ggml_nelements(conv1d_res) == n_conv1d_test) ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m"); + ggml_free(model.ctx); + + ggml_backend_buffer_free(model.buffer); + ggml_backend_buffer_free(buf_compute); + ggml_backend_free(model.backend); + return 0; +} diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp new file mode 100644 index 00000000..4ad830b4 --- /dev/null +++ b/tests/test-conv2d.cpp @@ -0,0 +1,401 @@ +#include "ggml.h" +#include "ggml/ggml-alloc.h" +#include "ggml/ggml-backend.h" + +// #define GGML_USE_CUBLAS + +#ifdef GGML_USE_CUBLAS +#include "ggml-cuda.h" +#endif + +#ifdef GGML_USE_METAL +#include "ggml-metal.h" +#endif + +#include +#include +#include +#include +#include +#include +#include +#include + +struct test_model { + struct ggml_tensor * a; + struct ggml_tensor * b; + ggml_backend_t backend = NULL; + ggml_backend_buffer_t buffer; + struct ggml_context * ctx; +}; + +void load_model(test_model & model, bool use_gpu = false) { + // create data + int KW = 3, KH = 3, IC = 10, OC = 10; + int IW = 8, IH = 6, N = 1; + + // Initialize adata + float* adata = new float[KW * KH * IC * OC]; + for (size_t i = 0; i < KW * KH * IC * OC; i++) { + adata[i] = 2.5f; + } + + // Convert adata to fp16 format + std::vector hadata(KW * KH * IC * OC); + ggml_fp32_to_fp16_row(adata, hadata.data(), KW * KH * IC * OC); + + // Initialize bdata + float* bdata = new float[IW * IH * IC * N]; + for (size_t i = 0; i < IW * IH * IC * N; i++) { + bdata[i] = 1.5f; + } + + size_t buffer_size = 0; + { + buffer_size += KW * KH * IC * OC * ggml_type_sizef(GGML_TYPE_F16); // tensor a + buffer_size += IW * IH * IC * N * ggml_type_sizef(GGML_TYPE_F32); // tensor b + buffer_size += 1024; // overhead + } + + printf("%s: ggml tensor size = %d bytes\n", __func__, (int) sizeof(ggml_tensor)); + printf("%s: backend buffer size = %0.2f MB\n", __func__, (buffer_size/ 1024.f/ 1024.f)); + + int num_tensors = 2; + struct ggml_init_params params { + /*.mem_size =*/ ggml_tensor_overhead() * num_tensors, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + // initialize the backend +#ifdef GGML_USE_CUBLAS + if (use_gpu) { + fprintf(stderr, "%s: using CUDA backend\n", __func__); + model.backend = ggml_backend_cuda_init(); + if (!model.backend) { + fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); + } + } +#endif + +#ifdef GGML_USE_METAL + if (use_gpu) { + fprintf(stderr, "%s: using Metal backend\n", __func__); + ggml_metal_log_set_callback(ggml_log_callback_default, nullptr); + model.backend = ggml_backend_metal_init(); + if (!model.backend) { + fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__); + } + } +#endif + + if(!model.backend) { + // fallback to CPU backend + model.backend = ggml_backend_cpu_init(); + } + + model.buffer = ggml_backend_alloc_buffer(model.backend, buffer_size); + + // create context + model.ctx = ggml_init(params); + + // create tensors + model.a = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F16, KW, KH, IC, OC); + model.b = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F32, IW, IH, IC, N); + + // create a allocator + ggml_allocr * alloc = ggml_allocr_new_from_buffer(model.buffer); + + // alloc memory + ggml_allocr_alloc(alloc, model.a); + + // load data to buffer + if(ggml_backend_is_cpu(model.backend)) { + memcpy(model.a->data, hadata.data(), ggml_nbytes(model.a)); + } else { + ggml_backend_tensor_set(model.a, hadata.data(), 0, ggml_nbytes(model.a)); + } + + // alloc memory + ggml_allocr_alloc(alloc, model.b); + + if(ggml_backend_is_cpu(model.backend) +#ifdef GGML_USE_METAL + || ggml_backend_is_metal(model.backend) +#endif + ) { + memcpy(model.b->data, bdata, ggml_nbytes(model.b)); + } else { + ggml_backend_tensor_set(model.b, bdata, 0, ggml_nbytes(model.b)); + } + + ggml_allocr_free(alloc); +} + +struct ggml_cgraph * build_graph(const test_model& model, struct ggml_allocr * allocr) { + static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(); + static std::vector buf(buf_size); + + struct ggml_init_params params0 = { + /*.mem_size =*/ buf_size, + /*.mem_buffer =*/ buf.data(), + /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_allocr_alloc_graph() + }; + + // create a temporally context to build the graph + struct ggml_context * ctx0 = ggml_init(params0); + + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + int s0 = 1; + int s1 = 1; + int p0 = 1; + int p1 = 1; + int d0 = 1; + int d1 = 1; + + // split conv2d in fundamental methods for test unit + struct ggml_tensor* im2col_0 = ggml_im2col(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1, true); + ggml_set_name(im2col_0, "im2col_res"); + ggml_build_forward_expand(gf, im2col_0); + + // recalculate for avoid fragmentation + struct ggml_tensor* conv2d_res = ggml_conv_2d(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); + ggml_set_name(conv2d_res, "conv2d_res"); + ggml_build_forward_expand(gf, conv2d_res); + + // delete the temporally context used to build the graph + ggml_free(ctx0); + return gf; +} + +struct ggml_cgraph * compute_graph(const test_model & model, struct ggml_allocr * allocr) { + // reset the allocator to free all the memory allocated during the previous inference + ggml_allocr_reset(allocr); + + struct ggml_cgraph * gf = build_graph(model, allocr); + + // allocate tensors + ggml_allocr_alloc_graph(allocr, gf); + int n_threads = 1; + + if (ggml_backend_is_cpu(model.backend)) { + ggml_backend_cpu_set_n_threads(model.backend, n_threads); + } + +#ifdef GGML_USE_METAL + if (ggml_backend_is_metal(model.backend)) { + ggml_backend_metal_set_n_cb(model.backend, n_threads); + } +#endif + + ggml_backend_graph_compute(model.backend, gf); + + //ggml_graph_print(gf); + + // in this case, the output tensor is the last one in the graph + return gf; +} + +int main(void) +{ + ggml_time_init(); + + test_model model; + load_model(model, true); + + ggml_backend_buffer_t buf_compute; // for compute + struct ggml_allocr * allocr = NULL; + + { + size_t align = ggml_backend_get_alignment(model.backend); + allocr = ggml_allocr_new_measure(align); + + //create the worst case graph for memory usage estimation + struct ggml_cgraph * gf = build_graph(model, allocr); + size_t mem_size = ggml_allocr_alloc_graph(allocr, gf); + ggml_allocr_free(allocr); + + // compute the required memory + buf_compute = ggml_backend_alloc_buffer(model.backend, mem_size); + allocr = ggml_allocr_new_from_buffer(buf_compute); + fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f); + } + + struct ggml_cgraph * gf_res = compute_graph(model, allocr); + + struct ggml_tensor * im2col_res = NULL; + struct ggml_tensor * conv2d_res = NULL; + + for(int i = 0; i < gf_res->n_nodes; i++) { + if(strcmp(ggml_get_name(gf_res->nodes[i]), "im2col_res") == 0) { + im2col_res = gf_res->nodes[i]; + } else if(strcmp(ggml_get_name(gf_res->nodes[i]), "conv2d_res") == 0) { + conv2d_res = gf_res->nodes[i]; + } + } + + uint16_t* im2col_data = new uint16_t[ggml_nelements(im2col_res)]; + float* conv2d_data = new float[ggml_nelements(conv2d_res)]; + + ggml_backend_tensor_get(im2col_res, im2col_data, 0, ggml_nbytes(im2col_res)); + ggml_backend_tensor_get(conv2d_res, conv2d_data, 0, ggml_nbytes(conv2d_res)); + + const int n_conv2d_test = 480; + const int n_im2col_test = 4320; + + float expected_conv2d [n_conv2d_test] = { + 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, + 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, + 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, + 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, + 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, + 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, + 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, + 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, + 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, + 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f }; + + uint16_t expected_im2col[n_conv2d_test] = { + 0, 0, 0, 0, 15872, 15872, 0, 15872, + 15872, 0, 0, 0, 0, 15872, 15872, 0, + 15872, 15872, 0, 0, 0, 0, 15872, 15872, + 0, 15872, 15872, 0, 0, 0, 0, 15872, + 15872, 0, 15872, 15872, 0, 0, 0, 0, + 15872, 15872, 0, 15872, 15872, 0, 0, 0, + 0, 15872, 15872, 0, 15872, 15872, 0, 0, + 0, 0, 15872, 15872, 0, 15872, 15872, 0, + 0, 0, 0, 15872, 15872, 0, 15872, 15872, + 0, 0, 0, 0, 15872, 15872, 0, 15872, + 15872, 0, 0, 0, 0, 15872, 15872, 0, + 15872, 15872, 0, 0, 0, 15872, 15872, 15872, + 15872, 15872, 15872, 0, 0, 0, 15872, 15872, + 15872, 15872, 15872, 15872, 0, 0, 0, 15872, + 15872, 15872, 15872, 15872, 15872, 0, 0, 0, + 15872, 15872, 15872, 15872, 15872, 15872, 0, 0, + 0, 15872, 15872, 15872, 15872, 15872, 15872, 0, + 0, 0, 15872, 15872, 15872, 15872, 15872, 15872, + 0, 0, 0, 15872, 15872, 15872, 15872, 15872, + 15872, 0, 0, 0, 15872, 15872, 15872, 15872, + 15872, 15872, 0, 0, 0, 15872, 15872, 15872, + 15872, 15872, 15872, 0, 0, 0, 15872, 15872, + 15872, 15872, 15872, 15872, 0, 0, 0, 15872, + 15872, 15872, 15872, 15872, 15872, 0, 0, 0, + 15872, 15872, 15872, 15872, 15872, 15872, 0, 0, + 0, 15872, 15872, 15872, 15872, 15872, 15872, 0, + 0, 0, 15872, 15872, 15872, 15872, 15872, 15872, + 0, 0, 0, 15872, 15872, 15872, 15872, 15872, + 15872, 0, 0, 0, 15872, 15872, 15872, 15872, + 15872, 15872, 0, 0, 0, 15872, 15872, 15872, + 15872, 15872, 15872, 0, 0, 0, 15872, 15872, + 15872, 15872, 15872, 15872, 0, 0, 0, 15872, + 15872, 15872, 15872, 15872, 15872, 0, 0, 0, + 15872, 15872, 15872, 15872, 15872, 15872, 0, 0, + 0, 15872, 15872, 15872, 15872, 15872, 15872, 0, + 0, 0, 15872, 15872, 15872, 15872, 15872, 15872, + 0, 0, 0, 15872, 15872, 15872, 15872, 15872, + 15872, 0, 0, 0, 15872, 15872, 15872, 15872, + 15872, 15872, 0, 0, 0, 15872, 15872, 15872, + 15872, 15872, 15872, 0, 0, 0, 15872, 15872, + 15872, 15872, 15872, 15872, 0, 0, 0, 15872, + 15872, 15872, 15872, 15872, 15872, 0, 0, 0, + 15872, 15872, 15872, 15872, 15872, 15872, 0, 0, + 0, 15872, 15872, 15872, 15872, 15872, 15872, 0, + 0, 0, 15872, 15872, 15872, 15872, 15872, 15872, + 0, 0, 0, 15872, 15872, 15872, 15872, 15872, + 15872, 0, 0, 0, 15872, 15872, 15872, 15872, + 15872, 15872, 0, 0, 0, 15872, 15872, 15872, + 15872, 15872, 15872, 0, 0, 0, 15872, 15872, + 15872, 15872, 15872, 15872, 0, 0, 0, 15872, + 15872, 15872, 15872, 15872, 15872, 0, 0, 0, + 15872, 15872, 15872, 15872, 15872, 15872, 0, 0, + 0, 15872, 15872, 15872, 15872, 15872, 15872, 0, + 0, 0, 15872, 15872, 15872, 15872, 15872, 15872, + 0, 0, 0, 15872, 15872, 15872, 15872, 15872, + 15872, 0, 0, 0, 15872, 15872, 15872, 15872, + 15872, 15872, 0, 0, 0, 15872, 15872, 15872, + 15872, 15872, 15872, 0, 0, 0, 15872, 15872, + 15872, 15872, 15872, 15872, 0, 0, 0, 15872, + 15872, 15872, 15872, 15872, 15872, 0, 0, 0 + }; + + printf("\nPerforming test:\n"); + + bool passed = true; + for(int i = 0; i < n_conv2d_test; i++) { + if( + im2col_data[i] != expected_im2col[i]) { + passed = false; + break; + } + } + + printf("ggml_im2col (%d): %s\n", (int) ggml_nelements(im2col_res), passed && (ggml_nelements(im2col_res) == n_im2col_test) ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m"); + + passed = true; + for(int i = 0; i < n_conv2d_test; i++) { + if(conv2d_data[i] != expected_conv2d[i]) { + passed = false; + break; + } + } + + printf("ggml_conv2d (%d): %s\n", (int) ggml_nelements(conv2d_res), passed && (ggml_nelements(conv2d_res) == n_conv2d_test) ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m"); + + ggml_free(model.ctx); + + ggml_backend_buffer_free(model.buffer); + ggml_backend_buffer_free(buf_compute); + ggml_backend_free(model.backend); + return 0; +} diff --git a/tests/test-mul-mat.cpp b/tests/test-mul-mat.cpp new file mode 100644 index 00000000..36b7c6bc --- /dev/null +++ b/tests/test-mul-mat.cpp @@ -0,0 +1,363 @@ +#include "ggml.h" +#include "ggml/ggml-alloc.h" +#include "ggml/ggml-backend.h" + +//#define GGML_USE_CUBLAS // uncomment this to use cuda backend, make sure build ggml lib with GGML_CUBLAS=ON + +#ifdef GGML_USE_CUBLAS +#include "ggml-cuda.h" +#endif + +#ifdef GGML_USE_METAL +#include "ggml-metal.h" +#endif + +#include +#include +#include +#include +#include +#include +#include +#include + +struct test_model { + struct ggml_tensor * a; + struct ggml_tensor * b; + ggml_backend_t backend = NULL; + ggml_backend_buffer_t buffer; + struct ggml_context * ctx; +}; + +void load_model(test_model & model, float* a, float* b, int M, int N, int K, bool use_gpu = false) { + size_t buffer_size = 0; + { + buffer_size += (M * N) * ggml_type_sizef(GGML_TYPE_F32); // tensor a + buffer_size += (N * K) * ggml_type_sizef(GGML_TYPE_F32); // tensor b + buffer_size += 1024; // overhead + } + + printf("%s: ggml tensor size = %d bytes\n", __func__, (int) sizeof(ggml_tensor)); + printf("%s: backend buffer size = %d bytes\n", __func__, (int) buffer_size); + + int num_tensors = 2; + struct ggml_init_params params { + /*.mem_size =*/ ggml_tensor_overhead() * num_tensors, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + // initialize the backend +#ifdef GGML_USE_CUBLAS + if (use_gpu) { + fprintf(stderr, "%s: using CUDA backend\n", __func__); + model.backend = ggml_backend_cuda_init(); + if (!model.backend) { + fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); + } + } +#endif + +#ifdef GGML_USE_METAL + if (use_gpu) { + fprintf(stderr, "%s: using Metal backend\n", __func__); + ggml_metal_log_set_callback(ggml_log_callback_default, nullptr); + model.backend = ggml_backend_metal_init(); + if (!model.backend) { + fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__); + } + } +#endif + + if(!model.backend) { + // fallback to CPU backend + model.backend = ggml_backend_cpu_init(); + } + + model.buffer = ggml_backend_alloc_buffer(model.backend, buffer_size); + + // create context + model.ctx = ggml_init(params); + + // create tensors + model.a = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, K, M); + printf("Matrix A: [%i, %i]\n", K, M); + model.b = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, K, N); + printf("Matrix B: [%i, %i]\n", K, N); + + // create a allocator + ggml_allocr * alloc = ggml_allocr_new_from_buffer(model.buffer); + + // alloc memory + ggml_allocr_alloc(alloc, model.a); + + // load data to buffer + if(ggml_backend_is_cpu(model.backend) +#ifdef GGML_USE_METAL + || ggml_backend_is_metal(model.backend) +#endif + ) { + memcpy(model.a->data, a, ggml_nbytes(model.a)); + } else { + ggml_backend_tensor_set(model.a, a, 0, ggml_nbytes(model.a)); // cuda requires copy the data directly to device + } + + // alloc memory + ggml_allocr_alloc(alloc, model.b); + + if(ggml_backend_is_cpu(model.backend) +#ifdef GGML_USE_METAL + || ggml_backend_is_metal(model.backend) +#endif + ) { + memcpy(model.b->data, b, ggml_nbytes(model.b)); + } else { + ggml_backend_tensor_set(model.b, b, 0, ggml_nbytes(model.b)); // cuda requires copy the data directly to device + } + + ggml_allocr_free(alloc); +} + +struct ggml_cgraph * build_graph(const test_model& model, struct ggml_allocr * allocr) { + static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(); + static std::vector buf(buf_size); + + struct ggml_init_params params0 = { + /*.mem_size =*/ buf_size, + /*.mem_buffer =*/ buf.data(), + /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_allocr_alloc_graph() + }; + + // create a temporally context to build the graph + struct ggml_context * ctx0 = ggml_init(params0); + + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + // zT = x @ yT + struct ggml_tensor * result = ggml_mul_mat(ctx0, model.a, ggml_cont(ctx0, model.b)); + + // z = (zT)T + ggml_build_forward_expand(gf, ggml_cont(ctx0, ggml_transpose(ctx0, result))); + + // delete the temporally context used to build the graph + ggml_free(ctx0); + return gf; +} + +struct ggml_tensor* compute(const test_model & model, struct ggml_allocr * allocr) { + // reset the allocator to free all the memory allocated during the previous inference + ggml_allocr_reset(allocr); + + struct ggml_cgraph * gf = build_graph(model, allocr); + + // allocate tensors + ggml_allocr_alloc_graph(allocr, gf); + int n_threads = 1; + + if (ggml_backend_is_cpu(model.backend)) { + ggml_backend_cpu_set_n_threads(model.backend, n_threads); + } + +#ifdef GGML_USE_METAL + if (ggml_backend_is_metal(model.backend)) { + ggml_backend_metal_set_n_cb(model.backend, n_threads); + } +#endif + + ggml_backend_graph_compute(model.backend, gf); + + //ggml_graph_print(gf); + + // in this case, the output tensor is the last one in the graph + return gf->nodes[gf->n_nodes - 1]; +} + + +static void ggml_vec_dot_f16(const int n, float * s, float * x, float * y) { + float sumf = 0.0; + for (int i = 0; i < n; ++i) { + sumf += x[i] * y[i]; + } + *s = sumf; +} + +static void gemm_f16_out_f32(int m, int n, int k, + float * A, + float * B, + float * C, + const int ith, const int nth) { + // does not seem to make a difference + int m0, m1, n0, n1; + // patches per thread + if (m > n) { + n0 = 0; + n1 = n; + + // total patches in dst + const int np = m; + + // patches per thread + const int dp = (np + nth - 1)/nth; + + // patch range for this thread + m0 = dp*ith; + m1 = std::min(m0 + dp, np); + } else { + m0 = 0; + m1 = m; + + // total patches in dst + const int np = n; + + // patches per thread + const int dp = (np + nth - 1)/nth; + + // patch range for this thread + n0 = dp*ith; + n1 = std::min(n0 + dp, np); + } + + // block-tiling attempt + int64_t blck_n = 16; + int64_t blck_m = 16; + + for (int j = n0; j < n1; j+=blck_n) { + for (int i = m0; i < m1; i+=blck_m) { + // printf("i j k => %d %d %d\n", i, j, K); + for (int ii = i; ii < i + blck_m && ii < m1; ii++) { + for (int jj = j; jj < j + blck_n && jj < n1; jj++) { + ggml_vec_dot_f16(k, + C + ii*n + jj, + A + ii * k, + B + jj * k); + } + } + } + } +} + + +void perform_gemm_test(float* a, float* b, float* expected, int M, int N, int K) { + printf("\nPerforming gemm_f16_out_f32 test:\n"); + + float* gemm_out = new float[M * N]; + gemm_f16_out_f32(M, N, K, a, b, gemm_out, 0, 1); + + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + printf("%.1ff,", gemm_out[i * N + j]); + } + printf("\n"); + } + + bool passed = true; + + for(int i = 0; i < M * N; i++) { + if(gemm_out[i] != expected[i]) { + passed = false; + break; + } + } + + printf("gemm_mult (%i): %s\n", (M * N), passed ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m"); +} + +int main(void) +{ + ggml_time_init(); + const int M = 4, N = 16, K = 36; // a conv2d expected matrix multiplication + + // matrix A (4 X 36) + float matrixA[M * K] = { + 2.0f, 9.0f, 2.0f, 10.0f, 6.0f, 4.0f, 3.0f, 6.0f, 3.0f, 6.0f, 9.0f, 7.0f, 8.0f, 8.0f, 3.0f, 3.0f, 10.0f, 5.0f, 2.0f, 10.0f, 7.0f, 10.0f, 9.0f, 3.0f, 6.0f, 6.0f, 5.0f, 10.0f, 2.0f, 3.0f, 6.0f, 1.0f, 9.0f, 4.0f, 10.0f, 4.0f, + 10.0f, 7.0f, 8.0f, 10.0f, 10.0f, 8.0f, 7.0f, 10.0f, 4.0f, 6.0f, 8.0f, 7.0f, 7.0f, 6.0f, 9.0f, 3.0f, 6.0f, 5.0f, 5.0f, 2.0f, 7.0f, 2.0f, 7.0f, 4.0f, 4.0f, 6.0f, 6.0f, 4.0f, 3.0f, 9.0f, 3.0f, 6.0f, 4.0f, 7.0f, 2.0f, 9.0f, + 7.0f, 3.0f, 2.0f, 5.0f, 7.0f, 3.0f, 10.0f, 2.0f, 6.0f, 1.0f, 4.0f, 7.0f, 5.0f, 10.0f, 3.0f, 10.0f, 4.0f, 5.0f, 5.0f, 1.0f, 6.0f, 10.0f, 7.0f, 4.0f, 5.0f, 3.0f, 9.0f, 9.0f, 8.0f, 6.0f, 9.0f, 2.0f, 3.0f, 6.0f, 8.0f, 5.0f, + 5.0f, 5.0f, 5.0f, 5.0f, 3.0f, 10.0f, 4.0f, 1.0f, 8.0f, 8.0f, 9.0f, 8.0f, 4.0f, 1.0f, 4.0f, 9.0f, 3.0f, 6.0f, 3.0f, 1.0f, 4.0f, 8.0f, 3.0f, 10.0f, 8.0f, 6.0f, 4.0f, 5.0f, 4.0f, 3.0f, 2.0f, 2.0f, 4.0f, 3.0f, 6.0f, 4.0f, + }; + + // matrix B (16 X 36) + float matrixB[N * K] = { + 9.0f, 7.0f, 1.0f, 3.0f, 5.0f, 9.0f, 7.0f, 6.0f, 1.0f, 10.0f, 1.0f, 1.0f, 7.0f, 2.0f, 4.0f, 9.0f, 10.0f, 4.0f, 5.0f, 5.0f, 7.0f, 1.0f, 7.0f, 7.0f, 2.0f, 9.0f, 5.0f, 10.0f, 7.0f, 4.0f, 8.0f, 9.0f, 9.0f, 3.0f, 10.0f, 2.0f, + 4.0f, 6.0f, 10.0f, 9.0f, 5.0f, 1.0f, 8.0f, 7.0f, 4.0f, 7.0f, 2.0f, 6.0f, 5.0f, 3.0f, 1.0f, 10.0f, 8.0f, 4.0f, 8.0f, 3.0f, 7.0f, 1.0f, 2.0f, 7.0f, 6.0f, 8.0f, 6.0f, 5.0f, 2.0f, 3.0f, 1.0f, 1.0f, 2.0f, 5.0f, 7.0f, 1.0f, + 8.0f, 2.0f, 8.0f, 8.0f, 8.0f, 8.0f, 4.0f, 4.0f, 6.0f, 10.0f, 10.0f, 9.0f, 2.0f, 9.0f, 3.0f, 7.0f, 7.0f, 1.0f, 4.0f, 9.0f, 1.0f, 2.0f, 3.0f, 6.0f, 1.0f, 10.0f, 5.0f, 8.0f, 9.0f, 4.0f, 6.0f, 2.0f, 3.0f, 1.0f, 2.0f, 7.0f, + 5.0f, 1.0f, 7.0f, 2.0f, 9.0f, 10.0f, 9.0f, 5.0f, 2.0f, 5.0f, 4.0f, 10.0f, 9.0f, 9.0f, 1.0f, 9.0f, 8.0f, 8.0f, 9.0f, 4.0f, 9.0f, 4.0f, 8.0f, 2.0f, 1.0f, 8.0f, 4.0f, 5.0f, 10.0f, 7.0f, 6.0f, 2.0f, 1.0f, 10.0f, 10.0f, 7.0f, + 9.0f, 4.0f, 5.0f, 9.0f, 5.0f, 10.0f, 10.0f, 3.0f, 6.0f, 6.0f, 4.0f, 4.0f, 4.0f, 8.0f, 5.0f, 4.0f, 9.0f, 1.0f, 9.0f, 9.0f, 1.0f, 7.0f, 9.0f, 2.0f, 10.0f, 9.0f, 10.0f, 8.0f, 3.0f, 3.0f, 9.0f, 3.0f, 9.0f, 10.0f, 1.0f, 8.0f, + 9.0f, 2.0f, 6.0f, 9.0f, 7.0f, 2.0f, 3.0f, 5.0f, 3.0f, 6.0f, 9.0f, 7.0f, 3.0f, 7.0f, 6.0f, 4.0f, 10.0f, 3.0f, 5.0f, 7.0f, 2.0f, 9.0f, 3.0f, 2.0f, 2.0f, 10.0f, 8.0f, 7.0f, 3.0f, 10.0f, 6.0f, 3.0f, 1.0f, 1.0f, 4.0f, 10.0f, + 2.0f, 9.0f, 2.0f, 10.0f, 6.0f, 4.0f, 3.0f, 6.0f, 3.0f, 6.0f, 9.0f, 7.0f, 8.0f, 8.0f, 3.0f, 3.0f, 10.0f, 5.0f, 2.0f, 10.0f, 7.0f, 10.0f, 9.0f, 3.0f, 6.0f, 6.0f, 5.0f, 10.0f, 2.0f, 3.0f, 6.0f, 1.0f, 9.0f, 4.0f, 10.0f, 4.0f, + 10.0f, 7.0f, 8.0f, 10.0f, 10.0f, 8.0f, 7.0f, 10.0f, 4.0f, 6.0f, 8.0f, 7.0f, 7.0f, 6.0f, 9.0f, 3.0f, 6.0f, 5.0f, 5.0f, 2.0f, 7.0f, 2.0f, 7.0f, 4.0f, 4.0f, 6.0f, 6.0f, 4.0f, 3.0f, 9.0f, 3.0f, 6.0f, 4.0f, 7.0f, 2.0f, 9.0f, + 7.0f, 3.0f, 2.0f, 5.0f, 7.0f, 3.0f, 10.0f, 2.0f, 6.0f, 1.0f, 4.0f, 7.0f, 5.0f, 10.0f, 3.0f, 10.0f, 4.0f, 5.0f, 5.0f, 1.0f, 6.0f, 10.0f, 7.0f, 4.0f, 5.0f, 3.0f, 9.0f, 9.0f, 8.0f, 6.0f, 9.0f, 2.0f, 3.0f, 6.0f, 8.0f, 5.0f, + 5.0f, 5.0f, 5.0f, 5.0f, 3.0f, 10.0f, 4.0f, 1.0f, 8.0f, 8.0f, 9.0f, 8.0f, 4.0f, 1.0f, 4.0f, 9.0f, 3.0f, 6.0f, 3.0f, 1.0f, 4.0f, 8.0f, 3.0f, 10.0f, 8.0f, 6.0f, 4.0f, 5.0f, 4.0f, 3.0f, 2.0f, 2.0f, 4.0f, 3.0f, 6.0f, 4.0f, + 6.0f, 2.0f, 3.0f, 3.0f, 3.0f, 7.0f, 5.0f, 1.0f, 8.0f, 1.0f, 4.0f, 5.0f, 1.0f, 1.0f, 6.0f, 4.0f, 2.0f, 1.0f, 7.0f, 8.0f, 6.0f, 1.0f, 1.0f, 5.0f, 6.0f, 5.0f, 10.0f, 6.0f, 7.0f, 5.0f, 9.0f, 3.0f, 2.0f, 7.0f, 9.0f, 4.0f, + 2.0f, 5.0f, 9.0f, 5.0f, 10.0f, 3.0f, 1.0f, 8.0f, 1.0f, 7.0f, 1.0f, 8.0f, 1.0f, 6.0f, 7.0f, 8.0f, 4.0f, 9.0f, 5.0f, 10.0f, 3.0f, 7.0f, 6.0f, 8.0f, 8.0f, 5.0f, 6.0f, 8.0f, 10.0f, 9.0f, 4.0f, 1.0f, 3.0f, 3.0f, 4.0f, 7.0f, + 8.0f, 2.0f, 6.0f, 6.0f, 5.0f, 1.0f, 3.0f, 7.0f, 1.0f, 7.0f, 2.0f, 2.0f, 2.0f, 8.0f, 4.0f, 1.0f, 1.0f, 5.0f, 9.0f, 4.0f, 1.0f, 2.0f, 3.0f, 10.0f, 1.0f, 4.0f, 9.0f, 9.0f, 6.0f, 8.0f, 8.0f, 1.0f, 9.0f, 10.0f, 4.0f, 1.0f, + 8.0f, 5.0f, 8.0f, 9.0f, 4.0f, 8.0f, 2.0f, 1.0f, 1.0f, 9.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 5.0f, 6.0f, 7.0f, 3.0f, 1.0f, 4.0f, 6.0f, 7.0f, 7.0f, 7.0f, 8.0f, 7.0f, 8.0f, 8.0f, 2.0f, 10.0f, 2.0f, 7.0f, 3.0f, 8.0f, 3.0f, + 8.0f, 7.0f, 6.0f, 2.0f, 4.0f, 10.0f, 10.0f, 6.0f, 10.0f, 3.0f, 7.0f, 6.0f, 4.0f, 3.0f, 5.0f, 5.0f, 5.0f, 3.0f, 8.0f, 10.0f, 3.0f, 4.0f, 8.0f, 4.0f, 2.0f, 6.0f, 8.0f, 9.0f, 6.0f, 9.0f, 4.0f, 3.0f, 5.0f, 2.0f, 2.0f, 6.0f, + 10.0f, 6.0f, 2.0f, 1.0f, 7.0f, 5.0f, 6.0f, 4.0f, 1.0f, 9.0f, 10.0f, 2.0f, 4.0f, 5.0f, 8.0f, 5.0f, 7.0f, 4.0f, 7.0f, 6.0f, 3.0f, 9.0f, 2.0f, 1.0f, 4.0f, 2.0f, 6.0f, 6.0f, 3.0f, 3.0f, 2.0f, 8.0f, 5.0f, 9.0f, 3.0f, 4.0f, + }; + + // matrix C (4 x 16) + float expected_result[M * N] = { + 1224.0f, 1023.0f, 1158.0f,1259.0f,1359.0f,1194.0f,1535.0f,1247.0f,1185.0f,1029.0f,889.0f,1182.0f,955.0f,1179.0f,1147.0f,1048.0f, + 1216.0f, 1087.0f, 1239.0f,1361.0f,1392.0f,1260.0f,1247.0f,1563.0f,1167.0f,1052.0f,942.0f,1214.0f,1045.0f,1134.0f,1264.0f,1126.0f, + 1125.0f, 966.0f, 1079.0f,1333.0f,1287.0f,1101.0f,1185.0f,1167.0f,1368.0f,990.0f,967.0f,1121.0f,971.0f,1086.0f,1130.0f,980.0f, + 999.0f, 902.0f, 1020.0f,1056.0f,1076.0f,929.0f,1029.0f,1052.0f,990.0f,1108.0f,823.0f,989.0f,759.0f,1041.0f,1003.0f,870.0f + }; + + bool passed = true; + + perform_gemm_test(matrixA, matrixB, expected_result, M, N, K); + + test_model model; + load_model(model, matrixA, matrixB, M, N, K, true); + + ggml_backend_buffer_t buf_compute; // for compute + struct ggml_allocr * allocr = NULL; + + { + size_t align = ggml_backend_get_alignment(model.backend); + allocr = ggml_allocr_new_measure(align); + + //create the worst case graph for memory usage estimation + struct ggml_cgraph * gf = build_graph(model, allocr); + size_t mem_size = ggml_allocr_alloc_graph(allocr, gf); + ggml_allocr_free(allocr); + + // compute the required memory + buf_compute = ggml_backend_alloc_buffer(model.backend, mem_size); + allocr = ggml_allocr_new_from_buffer(buf_compute); + fprintf(stderr, "%s: compute buffer size: %.4f KB\n", __func__, mem_size/1024.0); + } + + struct ggml_tensor * result = compute(model, allocr); + + float* out_data = new float[ggml_nelements(result)]; + + ggml_backend_tensor_get(result, out_data, 0, ggml_nbytes(result)); + + printf("\nPerforming ggml_mul_mat test:\n"); + + passed = true; + for(int i = 0; i < M * N; i++) { + if(out_data[i] != expected_result[i]) { + passed = false; + break; + } + } + + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + printf("%.1f ", out_data[i * N + j]); + } + printf("\n"); + } + + printf("ggml_mul_mat (%d): %s\n", (int) ggml_nelements(result), passed && (ggml_nelements(result) == M * N) ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m"); + + // free memory + ggml_free(model.ctx); + + ggml_backend_buffer_free(model.buffer); + ggml_backend_buffer_free(buf_compute); + ggml_backend_free(model.backend); + return 0; +}