From: leejet Date: Thu, 4 Sep 2025 08:38:49 +0000 (+0800) Subject: ggml: add ops for WAN video model (cuda && cpu) (#15669) X-Git-Tag: upstream/0.0.6527~149 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=0a1b3982cd0bd18730d50a693053b88c13fd04a6;p=pkg%2Fggml%2Fsources%2Fllama.cpp ggml: add ops for WAN video model (cuda && cpu) (#15669) * add conv3d support * add ggml_pad_ext for cpu & cuda backend * cuda/cpu: add im2col_3d support * cuda: make im2col a little faster * fix cuda pad/scale/im2col3d * make im2col_3d faster * gguf: support loading tensors which n_dims > GGML_MAX_DIMS * fix cuda get_rows * avoid ggml_conv_3d conflict * correct GGML_OP_COUNT assertion * avoid build failure * avoid build failure on MacOS * cuda: remove unnecessary MIN define * fix cpu im2col_3d * adjust the code style * cuda: use simpler loop in get_rows * add test_im2col_3d to test-backend-ops * test-backend-ops.cpp: remove trailing whitespace * cpu: im2col_3d support non continuous src Co-authored-by: Jeff Bolz * fix test_im2col_3d * remove unused variables * cuda: get_rows: dfloat2 -> float2 * add test_pad_ext to test-backend-ops.cpp * add gguf_init_from_file_ext impl * Revert "gguf: support loading tensors which n_dims > GGML_MAX_DIMS" This reverts commit d8377a0a37f314bd3713fe043b4333ad661610c1. * Revert "add gguf_init_from_file_ext impl" This reverts commit d9f1d13208c68ef83b3538201ac7f31614fb1994. * update ggml_backend_vk_device_supports_op * fix ggml_backend_vk_device_supports_op * update other backend supports op for ggml_pad_ext * metal/opencl/sycl/vulkan: fix GGML_OP_PAD check in supports_op --------- Co-authored-by: Jeff Bolz --- diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 7e9c3c8c..c01b98ac 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -511,6 +511,7 @@ extern "C" { GGML_OP_CONV_TRANSPOSE_1D, GGML_OP_IM2COL, GGML_OP_IM2COL_BACK, + GGML_OP_IM2COL_3D, GGML_OP_CONV_2D, GGML_OP_CONV_3D, GGML_OP_CONV_2D_DW, @@ -1870,6 +1871,41 @@ extern "C" { int d0, // dilation dimension 0 int d1); // dilation dimension 1 + GGML_API struct ggml_tensor * ggml_im2col_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int64_t IC, + int s0, // stride width + int s1, // stride height + int s2, // stride depth + int p0, // padding width + int p1, // padding height + int p2, // padding depth + int d0, // dilation width + int d1, // dilation height + int d2, // dilation depth + enum ggml_type dst_type); + + // a: [OC*IC, KD, KH, KW] + // b: [N*IC, ID, IH, IW] + // result: [N*OC, OD, OH, OW] + GGML_API struct ggml_tensor * ggml_conv_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int64_t IC, + int s0, // stride width + int s1, // stride height + int s2, // stride depth + int p0, // padding width + int p1, // padding height + int p2, // padding depth + int d0, // dilation width + int d1, // dilation height + int d2 // dilation depth + ); + // kernel size is a->ne[0] x a->ne[1] // stride is equal to kernel size // padding is zero @@ -1941,7 +1977,7 @@ extern "C" { int d0, // dilation dimension 0 int d1); // dilation dimension 1 - GGML_API struct ggml_tensor * ggml_conv_3d( + GGML_API struct ggml_tensor * ggml_conv_3d_direct( struct ggml_context * ctx, struct ggml_tensor * a, // kernel [KW, KH, KD, IC * OC] struct ggml_tensor * b, // input [W, H, D, C * N] @@ -2048,6 +2084,19 @@ extern "C" { int p2, int p3); + GGML_API struct ggml_tensor * ggml_pad_ext( + struct ggml_context * ctx, + struct ggml_tensor * a, + int lp0, + int rp0, + int lp1, + int rp1, + int lp2, + int rp2, + int lp3, + int rp3 + ); + // pad each dimension with reflection: [a, b, c, d] -> [b, a, b, c, d, c] GGML_API struct ggml_tensor * ggml_pad_reflect_1d( struct ggml_context * ctx, diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 2d81fbd5..ac2e2e1a 100755 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -589,9 +589,16 @@ void ggml_cann_pad(ggml_backend_cann_context& ctx, ggml_tensor* dst) { // the position of elements in the array means which dirction to padding, // each position means: [dim0.front, dim0.behind, dim1.front, dim1.behind, // dim2.front, dim2.behind, dim3.front, dim3.behind] - int64_t paddings[] = { - 0, dst->ne[0] - src->ne[0], 0, dst->ne[1] - src->ne[1], - 0, dst->ne[2] - src->ne[2], 0, dst->ne[3] - src->ne[3]}; + const int32_t lp0 = ggml_get_op_params_i32(dst, 0); + const int32_t rp0 = ggml_get_op_params_i32(dst, 1); + const int32_t lp1 = ggml_get_op_params_i32(dst, 2); + const int32_t rp1 = ggml_get_op_params_i32(dst, 3); + const int32_t lp2 = ggml_get_op_params_i32(dst, 4); + const int32_t rp2 = ggml_get_op_params_i32(dst, 5); + const int32_t lp3 = ggml_get_op_params_i32(dst, 6); + const int32_t rp3 = ggml_get_op_params_i32(dst, 7); + + int64_t paddings[] = {lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3}; aclnn_pad(ctx, acl_src, acl_dst, paddings); ggml_cann_release_resources(ctx, acl_src, acl_dst); } diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 78ec189d..0d35d933 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -1876,6 +1876,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_im2col_back_f32(params, tensor); } break; + case GGML_OP_IM2COL_3D: + { + ggml_compute_forward_im2col_3d(params, tensor); + } break; case GGML_OP_CONV_2D: { ggml_compute_forward_conv_2d(params, tensor); @@ -2255,6 +2259,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { } break; case GGML_OP_IM2COL: case GGML_OP_IM2COL_BACK: + case GGML_OP_IM2COL_3D: case GGML_OP_CONV_2D: case GGML_OP_CONV_3D: case GGML_OP_CONV_2D_DW: diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 8c1f7948..0bb767e0 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -7027,6 +7027,209 @@ void ggml_compute_forward_im2col_back_f32( } } + +// ggml_compute_forward_im2col_3d_f16 +// src0: kernel [OC*IC, KD, KH, KW] +// src1: image [N*IC, ID, IH, IW] +// dst: result [N*OD, OH, OW, IC * KD * KH * KW] +static void ggml_compute_forward_im2col_3d_f16( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F16); + + GGML_TENSOR_BINARY_OP_LOCALS; + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t s2 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[3]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[4]; + const int32_t p2 = ((const int32_t *)(dst->op_params))[5]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[6]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[7]; + const int32_t d2 = ((const int32_t *)(dst->op_params))[8]; + const int32_t IC = ((const int32_t *)(dst->op_params))[9]; + + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t N = ne13 / IC; + const int64_t ID = ne12; + const int64_t IH = ne11; + const int64_t IW = ne10; + + const int64_t OC = ne03 / IC; + GGML_UNUSED(OC); + const int64_t KD = ne02; + const int64_t KH = ne01; + const int64_t KW = ne00; + + const int64_t OD = ne3 / N; + const int64_t OH = ne2; + const int64_t OW = ne1; + const int64_t OH_OW = OH*OW; + const int64_t KD_KH_KW = KD*KH*KW; + const int64_t KH_KW = KH*KW; + const int64_t IC_KD_KH_KW = IC*KD*KH*KW; + + GGML_ASSERT(nb10 == sizeof(float)); + + // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW] + { + ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data; + + for (int64_t in = 0; in < N; in++) { + for (int64_t iod = 0; iod < OD; iod++) { + for (int64_t ioh = 0; ioh < OH; ioh++) { + for (int64_t iow = 0; iow < OW; iow++) { + for (int64_t iic = ith; iic < IC; iic += nth) { + + // micro kernel + ggml_fp16_t * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW] + const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW] + + for (int64_t ikd = 0; ikd < KD; ikd++) { + for (int64_t ikh = 0; ikh < KH; ikh++) { + for (int64_t ikw = 0; ikw < KW; ikw++) { + const int64_t iiw = iow*s0 + ikw*d0 - p0; + const int64_t iih = ioh*s1 + ikh*d1 - p1; + const int64_t iid = iod*s2 + ikd*d2 - p2; + + if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) { + dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0; + } else { + const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW] + dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(*s); + } + } + } + } + } + } + } + } + } + } +} + +// ggml_compute_forward_im2col_3d_f32 +// src0: kernel [OC*IC, KD, KH, KW] +// src1: image [N*IC, ID, IH, IW] +// dst: result [N*OD, OH, OW, IC * KD * KH * KW] +static void ggml_compute_forward_im2col_3d_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS; + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t s2 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[3]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[4]; + const int32_t p2 = ((const int32_t *)(dst->op_params))[5]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[6]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[7]; + const int32_t d2 = ((const int32_t *)(dst->op_params))[8]; + const int32_t IC = ((const int32_t *)(dst->op_params))[9]; + + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t N = ne13 / IC; + const int64_t ID = ne12; + const int64_t IH = ne11; + const int64_t IW = ne10; + + const int64_t OC = ne03 / IC; + GGML_UNUSED(OC); + const int64_t KD = ne02; + const int64_t KH = ne01; + const int64_t KW = ne00; + + const int64_t OD = ne3 / N; + const int64_t OH = ne2; + const int64_t OW = ne1; + + const int64_t OH_OW = OH*OW; + const int64_t KD_KH_KW = KD*KH*KW; + const int64_t KH_KW = KH*KW; + const int64_t IC_KD_KH_KW = IC*KD*KH*KW; + + GGML_ASSERT(nb10 == sizeof(float)); + + // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW] + { + float * const wdata = (float *) dst->data; + + for (int64_t in = 0; in < N; in++) { + for (int64_t iod = 0; iod < OD; iod++) { + for (int64_t ioh = 0; ioh < OH; ioh++) { + for (int64_t iow = 0; iow < OW; iow++) { + for (int64_t iic = ith; iic < IC; iic += nth) { + + // micro kernel + float * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW] + const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW] + + for (int64_t ikd = 0; ikd < KD; ikd++) { + for (int64_t ikh = 0; ikh < KH; ikh++) { + for (int64_t ikw = 0; ikw < KW; ikw++) { + const int64_t iiw = iow*s0 + ikw*d0 - p0; + const int64_t iih = ioh*s1 + ikh*d1 - p1; + const int64_t iid = iod*s2 + ikd*d2 - p2; + + if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) { + dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0; + } else { + const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW] + dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = *s; + } + } + } + } + } + } + } + } + } + } +} + + +void ggml_compute_forward_im2col_3d( + const ggml_compute_params * params, + ggml_tensor * dst) { + switch (dst->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_im2col_3d_f16(params, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_im2col_3d_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k, void * a, void * b, float * c) { const ggml_type_traits * traits = ggml_get_type_traits(type); @@ -8014,6 +8217,15 @@ static void ggml_compute_forward_pad_f32( GGML_TENSOR_UNARY_OP_LOCALS float * dst_ptr = (float *) dst->data; + const int32_t lp0 = ggml_get_op_params_i32(dst, 0); + const int32_t rp0 = ggml_get_op_params_i32(dst, 1); + const int32_t lp1 = ggml_get_op_params_i32(dst, 2); + const int32_t rp1 = ggml_get_op_params_i32(dst, 3); + const int32_t lp2 = ggml_get_op_params_i32(dst, 4); + const int32_t rp2 = ggml_get_op_params_i32(dst, 5); + const int32_t lp3 = ggml_get_op_params_i32(dst, 6); + const int32_t rp3 = ggml_get_op_params_i32(dst, 7); + // TODO: optimize @@ -8022,10 +8234,12 @@ static void ggml_compute_forward_pad_f32( for (int64_t i0 = 0; i0 < ne0; ++i0) { for (int64_t i3 = 0; i3 < ne3; ++i3) { const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0; - - const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - - if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + if ((i0 >= lp0 && i0 < ne0 - rp0) \ + && (i1 >= lp1 && i1 < ne1 - rp1) \ + && (i2 >= lp2 && i2 < ne2 - rp2) \ + && (i3 >= lp3 && i3 < ne3 - rp3)) { + const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00; + const float * src_ptr = (const float *)((char *) src0->data + src_idx); dst_ptr[dst_idx] = *src_ptr; } else { dst_ptr[dst_idx] = 0; diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h index d0ea8384..9824a03b 100644 --- a/ggml/src/ggml-cpu/ops.h +++ b/ggml/src/ggml-cpu/ops.h @@ -69,6 +69,7 @@ void ggml_compute_forward_clamp(const struct ggml_compute_params * params, struc void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_im2col_back_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_im2col_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_conv_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_conv_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_conv_transpose_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu index 3ec0e957..83d02474 100644 --- a/ggml/src/ggml-cuda/getrows.cu +++ b/ggml/src/ggml-cuda/getrows.cu @@ -2,6 +2,8 @@ #include "dequantize.cuh" #include "convert.cuh" +#define MAX_GRIDDIM_Y 65535 + template static __global__ void k_get_rows( const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst, @@ -11,32 +13,29 @@ static __global__ void k_get_rows( /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { - // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. - const int i00 = (blockIdx.y * blockDim.x + threadIdx.x)*2; - const int i10 = blockIdx.x; - const int i11 = blockIdx.z / ne12; - const int i12 = blockIdx.z % ne12; + for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) { + // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. + const int i10 = blockIdx.x; + const int i11 = blockIdx.z / ne12; + const int i12 = blockIdx.z % ne12; - if (i00 >= ne00) { - return; - } + const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; - const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; + dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; + const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03; - dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; - const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03; + const int ib = i00/qk; // block index + const int iqs = (i00%qk)/qr; // quant index + const int iybs = i00 - i00%qk; // dst block start index + const int y_offset = qr == 1 ? 1 : qk/2; - const int ib = i00/qk; // block index - const int iqs = (i00%qk)/qr; // quant index - const int iybs = i00 - i00%qk; // dst block start index - const int y_offset = qr == 1 ? 1 : qk/2; + // dequantize + float2 v; + dequantize_kernel(src0_row, ib, iqs, v); - // dequantize - float2 v; - dequantize_kernel(src0_row, ib, iqs, v); - - dst_row[iybs + iqs + 0] = ggml_cuda_cast(v.x); - dst_row[iybs + iqs + y_offset] = ggml_cuda_cast(v.y); + dst_row[iybs + iqs + 0] = ggml_cuda_cast(v.x); + dst_row[iybs + iqs + y_offset] = ggml_cuda_cast(v.y); + } } template @@ -48,22 +47,23 @@ static __global__ void k_get_rows_float( /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { - // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. - const int i00 = blockIdx.y * blockDim.x + threadIdx.x; - const int i10 = blockIdx.x; - const int i11 = blockIdx.z / ne12; - const int i12 = blockIdx.z % ne12; + for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) { + // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. + const int i10 = blockIdx.x; + const int i11 = blockIdx.z / ne12; + const int i12 = blockIdx.z % ne12; - if (i00 >= ne00) { - return; - } + if (i00 >= ne00) { + return; + } - const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; + const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; - dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; - const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03); + dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; + const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03); - dst_row[i00] = ggml_cuda_cast(src0_row[i00]); + dst_row[i00] = ggml_cuda_cast(src0_row[i00]); + } } template @@ -98,7 +98,7 @@ static void get_rows_cuda_q( cudaStream_t stream) { const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); const int block_num_y = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE); - const dim3 block_nums(ne10, block_num_y, ne11*ne12); + const dim3 block_nums(ne10, MIN(block_num_y, MAX_GRIDDIM_Y), ne11*ne12); // strides in elements // const size_t s0 = nb0 / sizeof(dst_t); @@ -131,7 +131,7 @@ static void get_rows_cuda_float( cudaStream_t stream) { const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); const int block_num_y = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE; - const dim3 block_nums(ne10, block_num_y, ne11*ne12); + const dim3 block_nums(ne10, MIN(block_num_y, MAX_GRIDDIM_Y), ne11*ne12); // strides in elements // const size_t s0 = nb0 / sizeof(dst_t); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index e06f95f0..0c01eb6f 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2452,6 +2452,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_IM2COL: ggml_cuda_op_im2col(ctx, dst); break; + case GGML_OP_IM2COL_3D: + ggml_cuda_op_im2col_3d(ctx, dst); + break; case GGML_OP_CONV_2D: ggml_cuda_op_conv2d(ctx, dst); break; @@ -3559,6 +3562,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]); } case GGML_OP_IM2COL: + case GGML_OP_IM2COL_3D: case GGML_OP_CONV_2D: case GGML_OP_CONV_2D_DW: case GGML_OP_CONV_TRANSPOSE_2D: diff --git a/ggml/src/ggml-cuda/im2col.cu b/ggml/src/ggml-cuda/im2col.cu index 16bb9bec..7737d6a5 100644 --- a/ggml/src/ggml-cuda/im2col.cu +++ b/ggml/src/ggml-cuda/im2col.cu @@ -112,3 +112,132 @@ void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream); } } + +// [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW] +template +static __global__ void im2col_3d_kernel( + const float * src, T * dst, + int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC, + int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW, + int64_t OH_OW, int64_t KD_KH_KW, int64_t ID_IH_IW, int64_t KH_KW, int64_t IH_IW, int64_t IC_ID_IH_IW, + int64_t IC_KD_KH_KW, int64_t OW_KD_KH_KW, int64_t OD_OH_OW_IC_KD_KH_KW, int64_t OH_OW_IC_KD_KH_KW, + int64_t OW_IC_KD_KH_KW, int64_t N_OD_OH, int64_t OD_OH, + int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2) { + const int64_t i = threadIdx.x + blockIdx.x * blockDim.x; + if (i >= IC_KD_KH_KW) { + return; + } + + const int64_t iic = i / KD_KH_KW; + const int64_t ikd = (i - iic * KD_KH_KW) / KH_KW; + const int64_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW; + const int64_t ikw = i % KW; + + const int64_t iow = blockIdx.y; + for (int64_t iz = blockIdx.z; iz < N_OD_OH; iz+=MAX_GRIDDIM_Z) { + const int64_t in = iz / OD_OH; + const int64_t iod = (iz - in*OD_OH) / OH; + const int64_t ioh = iz % OH; + + const int64_t iiw = iow * s0 + ikw * d0 - p0; + const int64_t iih = ioh * s1 + ikh * d1 - p1; + const int64_t iid = iod * s2 + ikd * d2 - p2; + + const int64_t offset_dst = in*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw; + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) { + dst[offset_dst] = 0.0f; + } else { + const int64_t offset_src = in*IC_ID_IH_IW + iic*ID_IH_IW + iid*IH_IW + iih*IW + iiw; + dst[offset_dst] = src[offset_src]; + } + } +} + +// [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW] +template +static void im2col_3d_cuda(const float * src, T* dst, + int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC, + int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW, + int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) { + const int64_t OH_OW = OH*OW; + const int64_t KD_KH_KW = KD*KH*KW; + const int64_t ID_IH_IW = ID*IH*IW; + const int64_t KH_KW = KH*KW; + const int64_t IH_IW = IH*IW; + const int64_t IC_KD_KH_KW = IC*KD*KH*KW; + const int64_t OW_KD_KH_KW = OW*KD*KH*KW; + const int64_t N_OD_OH = N*OD*OH; + const int64_t OD_OH = OD*OH; + const int64_t IC_ID_IH_IW = IC*ID*IH*IW; + const int64_t OD_OH_OW_IC_KD_KH_KW = OD*OH*OW*IC*KD*KH*KW; + const int64_t OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW; + const int64_t OW_IC_KD_KH_KW = OW*IC*KD*KH*KW; + const int64_t num_blocks = (IC_KD_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE; + dim3 block_nums(num_blocks, OW, MIN(N_OD_OH, MAX_GRIDDIM_Z)); + im2col_3d_kernel<<>>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, + OH_OW, KD_KH_KW, ID_IH_IW, KH_KW, IH_IW, IC_ID_IH_IW, + IC_KD_KH_KW, OW_KD_KH_KW, OD_OH_OW_IC_KD_KH_KW, + OH_OW_IC_KD_KH_KW, OW_IC_KD_KH_KW, N_OD_OH, OD_OH, + s0, s1, s2, p0, p1, p2, d0, d1, d2); +} + +static void im2col_3d_cuda_f16(const float * src, half * dst, + int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC, + int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW, + int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) { + + im2col_3d_cuda(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream); +} + +static void im2col_3d_cuda_f32(const float * src, float * dst, + int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC, + int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW, + int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) { + + im2col_3d_cuda(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream); +} + +void ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + const float * src1_d = (const float *)src1->data; + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t s2 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[3]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[4]; + const int32_t p2 = ((const int32_t *)(dst->op_params))[5]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[6]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[7]; + const int32_t d2 = ((const int32_t *)(dst->op_params))[8]; + const int32_t IC = ((const int32_t *)(dst->op_params))[9]; + + const int64_t N = ne13 / IC; + const int64_t ID = ne12; + const int64_t IH = ne11; + const int64_t IW = ne10; + + const int64_t OC = ne03 / IC; + const int64_t KD = ne02; + const int64_t KH = ne01; + const int64_t KW = ne00; + + const int64_t OD = ne3 / N; + const int64_t OH = ne2; + const int64_t OW = ne1; + + if(dst->type == GGML_TYPE_F16) { + im2col_3d_cuda_f16(src1_d, (half *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream); + } else { + im2col_3d_cuda_f32(src1_d, (float *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream); + } +} diff --git a/ggml/src/ggml-cuda/im2col.cuh b/ggml/src/ggml-cuda/im2col.cuh index 1ce8fae4..2da1223d 100644 --- a/ggml/src/ggml-cuda/im2col.cuh +++ b/ggml/src/ggml-cuda/im2col.cuh @@ -3,3 +3,4 @@ #define CUDA_IM2COL_BLOCK_SIZE 256 void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/pad.cu b/ggml/src/ggml-cuda/pad.cu index 77432b04..29aef33c 100644 --- a/ggml/src/ggml-cuda/pad.cu +++ b/ggml/src/ggml-cuda/pad.cu @@ -1,36 +1,50 @@ #include "pad.cuh" -static __global__ void pad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) { - // blockIdx.z: idx of ne2*ne3, aka ne02*ne03 - // blockIdx.y: idx of ne1 - // blockIDx.x: idx of ne0 / BLOCK_SIZE - int nidx = threadIdx.x + blockIdx.x * blockDim.x; - if (nidx >= ne0) { +static __global__ void pad_f32(const float * src, float * dst, + const int lp0, const int rp0, const int lp1, const int rp1, + const int lp2, const int rp2, const int lp3, const int rp3, + const int ne0, const int ne1, const int ne2, const int ne3) { + // blockIdx.z: i3*ne2+i2 + // blockIdx.y: i1 + // blockIDx.x: i0 / CUDA_PAD_BLOCK_SIZE + // gridDim.y: ne1 + int i0 = threadIdx.x + blockIdx.x * blockDim.x; + int i1 = blockIdx.y; + int i2 = blockIdx.z % ne2; + int i3 = blockIdx.z / ne2; + if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { return; } // operation - int offset_dst = - nidx + - blockIdx.y * ne0 + - blockIdx.z * ne0 * gridDim.y; - if (nidx < ne00 && blockIdx.y < (unsigned)ne01 && blockIdx.z < (unsigned)(ne02*ne03)) { - int offset_src = - nidx + - blockIdx.y * ne00 + - blockIdx.z * ne00 * ne01; - dst[offset_dst] = x[offset_src]; + const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0; + if ((i0 >= lp0 && i0 < ne0 - rp0) && + (i1 >= lp1 && i1 < ne1 - rp1) && + (i2 >= lp2 && i2 < ne2 - rp2) && + (i3 >= lp3 && i3 < ne3 - rp3)) { + const int64_t i00 = i0 - lp0; + const int64_t i01 = i1 - lp1; + const int64_t i02 = i2 - lp2; + const int64_t i03 = i3 - lp3; + const int64_t ne02 = ne2 - lp2 - rp2; + const int64_t ne01 = ne1 - lp1 - rp1; + const int64_t ne00 = ne0 - lp0 - rp0; + + const int64_t src_idx = i03*(ne00*ne01*ne02) + i02*(ne00*ne01) + i01*ne00 + i00; + + dst[dst_idx] = src[src_idx]; } else { - dst[offset_dst] = 0.0f; + dst[dst_idx] = 0.0f; } } -static void pad_f32_cuda(const float * x, float * dst, - const int ne00, const int ne01, const int ne02, const int ne03, +static void pad_f32_cuda(const float * src, float * dst, + const int lp0, const int rp0, const int lp1, const int rp1, + const int lp2, const int rp2, const int lp3, const int rp3, const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) { int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE; dim3 gridDim(num_blocks, ne1, ne2*ne3); - pad_f32<<>>(x, dst, ne0, ne00, ne01, ne02, ne03); + pad_f32<<>>(src, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1, ne2, ne3); } void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -41,9 +55,18 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); - GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors + GGML_ASSERT(ggml_is_contiguous(src0)); + + const int32_t lp0 = ((const int32_t*)(dst->op_params))[0]; + const int32_t rp0 = ((const int32_t*)(dst->op_params))[1]; + const int32_t lp1 = ((const int32_t*)(dst->op_params))[2]; + const int32_t rp1 = ((const int32_t*)(dst->op_params))[3]; + const int32_t lp2 = ((const int32_t*)(dst->op_params))[4]; + const int32_t rp2 = ((const int32_t*)(dst->op_params))[5]; + const int32_t lp3 = ((const int32_t*)(dst->op_params))[6]; + const int32_t rp3 = ((const int32_t*)(dst->op_params))[7]; pad_f32_cuda(src0_d, dst_d, - src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], - dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream); + lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream); } diff --git a/ggml/src/ggml-cuda/scale.cu b/ggml/src/ggml-cuda/scale.cu index 2ee9e588..0ddeff6a 100644 --- a/ggml/src/ggml-cuda/scale.cu +++ b/ggml/src/ggml-cuda/scale.cu @@ -1,18 +1,19 @@ #include "scale.cuh" -static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; +#define MAX_GRIDDIM_X 0x7FFFFFFF - if (i >= k) { - return; - } +static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int64_t nelements) { + int64_t tid = (int64_t)blockIdx.x * (int64_t)blockDim.x + (int64_t)threadIdx.x; + int64_t stride = (int64_t)blockDim.x * (int64_t)gridDim.x; - dst[i] = scale * x[i] + bias; + for (int64_t i = tid; i < nelements; i += stride) { + dst[i] = scale * x[i] + bias; + } } -static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE; - scale_f32<<>>(x, dst, scale, bias, k); +static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int64_t nelements, cudaStream_t stream) { + const int64_t num_blocks = (nelements + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE; + scale_f32<<>>(x, dst, scale, bias, nelements); } void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 3d16a1dc..9b4006d9 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -1886,7 +1886,10 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_OP_UPSCALE: return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST; case GGML_OP_POOL_2D: + return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_PAD: + return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) && + (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0); case GGML_OP_PAD_REFLECT_1D: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_ARGSORT: diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index fc54c90f..727163b7 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -2701,7 +2701,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded case GGML_OP_PAD: return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 && - op->src[0]->ne[3] == 1 && op->ne[3] == 1; + op->src[0]->ne[3] == 1 && op->ne[3] == 1 && + (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) && + (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0); case GGML_OP_UPSCALE: return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; case GGML_OP_CONV_2D: diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 18ff4e0b..877fbf7e 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -4398,7 +4398,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g return ggml_is_contiguous(op->src[0]); case GGML_OP_POOL_2D: case GGML_OP_ACC: + return true; case GGML_OP_PAD: + return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) && + (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0); case GGML_OP_LEAKY_RELU: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_RWKV_WKV6: diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 3cd5af1c..cd1c66ba 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -12076,7 +12076,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_ACC: case GGML_OP_CONCAT: case GGML_OP_SCALE: + return true; case GGML_OP_PAD: + return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) && + (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0); case GGML_OP_ROLL: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index d76ea58f..f35c3379 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -974,6 +974,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CONV_TRANSPOSE_1D", "IM2COL", "IM2COL_BACK", + "IM2COL_3D", "CONV_2D", "CONV_3D", "CONV_2D_DW", @@ -1018,7 +1019,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GLU", }; -static_assert(GGML_OP_COUNT == 89, "GGML_OP_COUNT != 89"); +static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1077,6 +1078,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "conv_transpose_1d(x)", "im2col(x)", "im2col_back(x)", + "im2col_3d(x)", "conv_2d(x)", "conv_3d(x)", "conv_2d_dw(x)", @@ -1121,7 +1123,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "glu(x)", }; -static_assert(GGML_OP_COUNT == 89, "GGML_OP_COUNT != 89"); +static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -4361,6 +4363,91 @@ struct ggml_tensor * ggml_conv_2d( return result; } +// a: [OC*IC, KD, KH, KW] +// b: [N*IC, ID, IH, IW] +// result: [N*OD, OH, OW, IC * KD * KH * KW] +struct ggml_tensor * ggml_im2col_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int64_t IC, + int s0, // stride width + int s1, // stride height + int s2, // stride depth + int p0, // padding width + int p1, // padding height + int p2, // padding depth + int d0, // dilation width + int d1, // dilation height + int d2, // dilation depth + enum ggml_type dst_type) { + const int64_t N = b->ne[3] / IC; + const int64_t ID = b->ne[2]; + const int64_t IH = b->ne[1]; + const int64_t IW = b->ne[0]; + + const int64_t OC = a->ne[3] / IC; + UNUSED(OC); + const int64_t KD = a->ne[2]; + const int64_t KH = a->ne[1]; + const int64_t KW = a->ne[0]; + const int64_t OD = ggml_calc_conv_output_size(ID, KD, s2, p2, d2); + const int64_t OH = ggml_calc_conv_output_size(IH, KH, s1, p1, d1); + const int64_t OW = ggml_calc_conv_output_size(IW, KW, s0, p0, d0); + + GGML_ASSERT((OD > 0) && "b too small compared to a"); + GGML_ASSERT((OH > 0) && "b too small compared to a"); + GGML_ASSERT((OW > 0) && "b too small compared to a"); + + + const int64_t ne[4] = {KW*KH*KD*IC, OW, OH, OD*N}; + + struct ggml_tensor * result = ggml_new_tensor(ctx, dst_type, 4, ne); + int32_t params[] = { s0, s1, s2, p0, p1, p2, d0, d1, d2, (int32_t)IC}; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_IM2COL_3D; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// a: [OC*IC, KD, KH, KW] +// b: [N*IC, ID, IH, IW] +// result: [N*OC, OD, OH, OW] +struct ggml_tensor * ggml_conv_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int64_t IC, + int s0, // stride width + int s1, // stride height + int s2, // stride depth + int p0, // padding width + int p1, // padding height + int p2, // padding depth + int d0, // dilation width + int d1, // dilation height + int d2 // dilation depth + ) { + struct ggml_tensor * im2col = ggml_im2col_3d(ctx, a, b, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, a->type); // [N*OD, OH, OW, IC * KD * KH * KW] + + int64_t OC = a->ne[3] / IC; + int64_t N = b->ne[3] / IC; + struct ggml_tensor * result = + ggml_mul_mat(ctx, + ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N*OD, OH, OW, IC * KD * KH * KW] => [N*OD*OH*OW, IC * KD * KH * KW] + ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2] * IC), OC)); // [OC*IC, KD, KH, KW] => [OC, IC * KD * KH * KW] + + int64_t OD = im2col->ne[3] / N; + result = ggml_reshape_4d(ctx, result, im2col->ne[1]*im2col->ne[2], OD, N, OC); // [OC, N*OD*OH*OW] => [OC, N, OD, OH*OW] + result = ggml_cont(ctx, ggml_permute(ctx, result, 0, 1, 3, 2)); // [N, OC, OD, OH*OW] + result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], OD, OC * N); // [N*OC, OD, OH, OW] + + return result; +} + // ggml_conv_2d_sk_p0 struct ggml_tensor * ggml_conv_2d_sk_p0( @@ -4482,9 +4569,9 @@ struct ggml_tensor * ggml_conv_2d_direct( return result; } -// ggml_conv_3d +// ggml_conv_3d_direct -struct ggml_tensor * ggml_conv_3d( +struct ggml_tensor * ggml_conv_3d_direct( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, @@ -4710,11 +4797,36 @@ struct ggml_tensor * ggml_pad( int p1, int p2, int p3) { + return ggml_pad_ext(ctx, a, 0, p0, 0, p1, 0, p2, 0, p3); +} + +struct ggml_tensor * ggml_pad_ext( + struct ggml_context * ctx, + struct ggml_tensor * a, + int lp0, + int rp0, + int lp1, + int rp1, + int lp2, + int rp2, + int lp3, + int rp3 + ) { struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, - a->ne[0] + p0, - a->ne[1] + p1, - a->ne[2] + p2, - a->ne[3] + p3); + a->ne[0] + lp0 + rp0, + a->ne[1] + lp1 + rp1, + a->ne[2] + lp2 + rp2, + a->ne[3] + lp3 + rp3); + + ggml_set_op_params_i32(result, 0, lp0); + ggml_set_op_params_i32(result, 1, rp0); + ggml_set_op_params_i32(result, 2, lp1); + ggml_set_op_params_i32(result, 3, rp1); + ggml_set_op_params_i32(result, 4, lp2); + ggml_set_op_params_i32(result, 5, rp2); + ggml_set_op_params_i32(result, 6, lp3); + ggml_set_op_params_i32(result, 7, rp3); + result->op = GGML_OP_PAD; result->src[0] = a; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 3a586210..89b812f1 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -297,6 +297,8 @@ static std::string var_to_str(ggml_scale_mode mode) { #define VARS_TO_STR11(a, b, c, d, e, f, g, h, i, j, k) VAR_TO_STR(a) + "," + VARS_TO_STR10(b, c, d, e, f, g, h, i, j, k) #define VARS_TO_STR12(a, b, c, d, e, f, g, h, i, j, k, l) VAR_TO_STR(a) + "," + VARS_TO_STR11(b, c, d, e, f, g, h, i, j, k, l) #define VARS_TO_STR13(a, b, c, d, e, f, g, h, i, j, k, l, m) VAR_TO_STR(a) + "," + VARS_TO_STR12(b, c, d, e, f, g, h, i, j, k, l, m) +#define VARS_TO_STR14(a, b, c, d, e, f, g, h, i, j, k, l, m, n) VAR_TO_STR(a) + "," + VARS_TO_STR13(b, c, d, e, f, g, h, i, j, k, l, m, n) +#define VARS_TO_STR15(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o) VAR_TO_STR(a) + "," + VARS_TO_STR14(b, c, d, e, f, g, h, i, j, k, l, m, n, o) #ifdef GGML_USE_SYCL static bool inline _isinf(float f) { @@ -4023,6 +4025,56 @@ struct test_im2col : public test_case { } }; +// GGML_OP_IM2COL_3D +struct test_im2col_3d : public test_case { + const ggml_type type_input; + const ggml_type type_kernel; + const ggml_type dst_type; + const std::array ne_input; + const std::array ne_kernel; + // stride + const int s0; + const int s1; + const int s2; + // padding + const int p0; + const int p1; + const int p2; + // dilation + const int d0; + const int d1; + const int d2; + + const int64_t IC; + + std::string vars() override { + return VARS_TO_STR15(type_input, type_kernel, dst_type, ne_input, ne_kernel, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2); + } + + test_im2col_3d(ggml_type type_input = GGML_TYPE_F32, ggml_type type_kernel = GGML_TYPE_F16, ggml_type dst_type = GGML_TYPE_F32, + std::array ne_input = {10, 10, 10, 9}, // [OC*IC, KD, KH, KW] + std::array ne_kernel = {3, 3, 3, 1}, // [N*IC, ID, IH, IW] + int64_t IC = 3, + int s0 = 1, int s1 = 1, int s2 = 1, + int p0 = 1, int p1 = 1, int p2 = 1, + int d0 = 1, int d1 = 1, int d2 = 1) + : type_input(type_input), type_kernel(type_kernel), dst_type(dst_type), ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), s1(s1), s2(s2), p0(p0), p1(p1), p2(p2), d0(d0), d1(d1), d2(d2), IC(IC) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data()); + ggml_set_param(input); + ggml_set_name(input, "input"); + + ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data()); + ggml_set_name(kernel, "kernel"); + + ggml_tensor * out = ggml_im2col_3d(ctx, kernel, input, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, dst_type); + ggml_set_name(out, "out"); + + return out; + } +}; + // CONV_2D struct test_conv_2d : public test_case { const std::array ne_input; @@ -4221,7 +4273,7 @@ struct test_conv_3d : public test_case { ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel); ggml_set_name(kernel, "kernel"); - ggml_tensor * out = ggml_conv_3d(ctx, kernel, input, s0, s1, s2, p0, p1, p2, d0, d1, d2, (int)IC, (int)N, (int)OC); + ggml_tensor * out = ggml_conv_3d_direct(ctx, kernel, input, s0, s1, s2, p0, p1, p2, d0, d1, d2, (int)IC, (int)N, (int)OC); ggml_set_name(out, "out"); return out; } @@ -4640,6 +4692,39 @@ struct test_pad : public test_case { } }; +struct test_pad_ext : public test_case { + const ggml_type type; + const std::array ne_a; + const int lp0; + const int rp0; + const int lp1; + const int rp1; + const int lp2; + const int rp2; + const int lp3; + const int rp3; + + std::string vars() override { + return VARS_TO_STR10(type, ne_a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3); + } + + test_pad_ext(ggml_type type = GGML_TYPE_F32, + std::array ne_a = {512, 512, 3, 1}, + int lp0 = 1, int rp0 = 1, int lp1 = 1, int rp1 = 1, + int lp2 = 1, int rp2 = 1, int lp3 = 1, int rp3 = 1) + : type(type), ne_a(ne_a), lp0(lp0), rp0(rp0), lp1(lp1), rp1(rp1), lp2(lp2), rp2(rp2), lp3(lp3), rp3(rp3) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data()); + ggml_set_name(a, "a"); + + ggml_tensor * out = ggml_pad_ext(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3); + ggml_set_name(out, "out"); + + return out; + } +}; + // GGML_OP_PAD_REFLECT_1D struct test_pad_reflect_1d : public test_case { const ggml_type type; @@ -5623,6 +5708,32 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2560}, {3, 3, 2, 2560}, 1, 1, 1, 1, 1, 1, true)); test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {5, 5, 1, 32}, {3, 4, 1, 32}, 1, 1, 0, 0, 1, 1, true)); + // im2col 3D + test_cases.emplace_back(new test_im2col_3d(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32)); + test_cases.emplace_back(new test_im2col_3d(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32)); + test_cases.emplace_back(new test_im2col_3d(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16)); + for (int s0 : {1, 3}) { + for (int s1 : {1, 3}) { + for (int s2 : {1, 3}) { + for (int p0 : {0, 3}) { + for (int p1 : {0, 3}) { + for (int p2 : {0, 3}) { + for (int d0 : {1, 3}) { + for (int d1 : {1, 3}) { + for (int d2 : {1, 3}) { + test_cases.emplace_back(new test_im2col_3d( + GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20, 20, 10, 3}, {3, 3, 3, 3}, + 3, s0, s1, s2, p0, p1, p2, d0, d1, d2)); + } + } + } + } + } + } + } + } + } + // Conv_2D test cases #ifdef DETAILED_TESTS // Probably we do not have enough time to execute these in the pipeline. @@ -6340,6 +6451,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_group_norm_mul_add(GGML_TYPE_F32, {9, 9, 1280, 1})); test_cases.emplace_back(new test_acc()); test_cases.emplace_back(new test_pad()); + test_cases.emplace_back(new test_pad_ext()); test_cases.emplace_back(new test_pad_reflect_1d()); test_cases.emplace_back(new test_roll()); test_cases.emplace_back(new test_arange());