#define CUDA_SILU_BLOCK_SIZE 256
#define CUDA_TANH_BLOCK_SIZE 256
#define CUDA_RELU_BLOCK_SIZE 256
+#define CUDA_HARDSIGMOID_BLOCK_SIZE 256
+#define CUDA_HARDSWISH_BLOCK_SIZE 256
#define CUDA_SQR_BLOCK_SIZE 256
#define CUDA_CPY_BLOCK_SIZE 32
#define CUDA_SCALE_BLOCK_SIZE 256
#define CUDA_PAD_BLOCK_SIZE 256
#define CUDA_ACC_BLOCK_SIZE 256
#define CUDA_IM2COL_BLOCK_SIZE 256
+#define CUDA_POOL2D_BLOCK_SIZE 256
#define CUDA_Q8_0_NE_ALIGN 2048
dst[i] = fmaxf(x[i], 0);
}
+static __global__ void hardsigmoid_f32(const float * x, float * dst, const int k) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+ dst[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
+}
+
+static __global__ void hardswish_f32(const float * x, float * dst, const int k) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+ dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
+}
+
static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
}
static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) {
- const int row = blockIdx.y;
+ const int row = blockIdx.x;
const int col = threadIdx.x;
float sum = 0.0f;
dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
}
-static __global__ void im2col_f32_f16(
- const float * x, half * dst,
- int offset_delta, int IW, int IH, int OW, int KW, int KH, int pelements, int CHW,
+template <typename T>
+static __global__ void im2col_kernel(
+ const float * x, T * dst, int batch_offset,
+ int offset_delta, int IC, int IW, int IH, int OH, int OW, int KW, int KH, int pelements, int CHW,
int s0, int s1, int p0, int p1, int d0, int d1) {
const int i = threadIdx.x + blockIdx.x * blockDim.x;
if (i >= pelements) {
const int ky = (i - kd) / OW;
const int ix = i % OW;
+ const int oh = blockIdx.y;
+ const int batch = blockIdx.z / IC;
+ const int ic = blockIdx.z % IC;
+
const int64_t iiw = ix * s0 + kx * d0 - p0;
- const int64_t iih = blockIdx.y * s1 + ky * d1 - p1;
+ const int64_t iih = oh * s1 + ky * d1 - p1;
const int64_t offset_dst =
- (blockIdx.y * OW + ix) * CHW +
- (blockIdx.z * (KW * KH) + ky * KW + kx);
+ ((batch * OH + oh) * OW + ix) * CHW +
+ (ic * (KW * KH) + ky * KW + kx);
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
- dst[offset_dst] = __float2half(0.0f);
+ dst[offset_dst] = 0.0f;
} else {
- const int64_t offset_src = blockIdx.z * offset_delta;
- dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
+ const int64_t offset_src = ic * offset_delta + batch * batch_offset;
+ dst[offset_dst] = x[offset_src + iih * IW + iiw];
}
}
+template <typename Ti, typename To>
+static __global__ void pool2d_nchw_kernel(
+ const int ih, const int iw, const int oh, const int ow,
+ const int kh, const int kw, const int sh, const int sw,
+ const int ph, const int pw, const int parallel_elements,
+ const Ti* src, To* dst, const enum ggml_op_pool op) {
+ int idx = threadIdx.x + blockIdx.x * blockDim.x;
+ if (idx >= parallel_elements) {
+ return;
+ }
+
+ const int I_HW = ih * iw;
+ const int O_HW = oh * ow;
+ const int nc = idx / O_HW;
+ const int cur_oh = idx % O_HW / ow;
+ const int cur_ow = idx % O_HW % ow;
+ const Ti* i_ptr = src + nc * I_HW;
+ To* o_ptr = dst + nc * O_HW;
+ const int start_h = cur_oh * sh - ph;
+ const int bh = max(0, start_h);
+ const int eh = min(ih, start_h + kh);
+ const int start_w = cur_ow * sw - pw;
+ const int bw = max(0, start_w);
+ const int ew = min(iw, start_w + kw);
+ const To scale = 1. / (kh * kw);
+ To res = 0;
+
+ switch (op) {
+ case GGML_OP_POOL_AVG: res = 0; break;
+ case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
+ }
+
+ for (int i = bh; i < eh; i += 1) {
+ for (int j = bw; j < ew; j += 1) {
+ #if __CUDA_ARCH__ >= 350
+ Ti cur = __ldg(i_ptr + i * iw + j);
+ #else
+ Ti cur = i_ptr[i * iw + j];
+ #endif
+ switch (op) {
+ case GGML_OP_POOL_AVG: res += cur * scale; break;
+ case GGML_OP_POOL_MAX: res = max(res, (To)cur); break;
+ }
+ }
+ }
+ o_ptr[cur_oh * ow + cur_ow] = res;
+}
+
template<int qk, int qr, dequantize_kernel_t dq>
static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
+static void hardsigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_HARDSIGMOID_BLOCK_SIZE - 1) / CUDA_HARDSIGMOID_BLOCK_SIZE;
+ hardsigmoid_f32<<<num_blocks, CUDA_HARDSIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void hardswish_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_HARDSWISH_BLOCK_SIZE - 1) / CUDA_HARDSWISH_BLOCK_SIZE;
+ hardswish_f32<<<num_blocks, CUDA_HARDSWISH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
leaky_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
static void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
const dim3 block_dims(WARP_SIZE, 1, 1);
- const dim3 block_nums(1, nrows, 1);
+ const dim3 block_nums(nrows, 1, 1);
k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
}
}
}
-static void im2col_f32_f16_cuda(const float* x, half* dst,
+template <typename T>
+static void im2col_cuda(const float* x, T* dst,
int IW, int IH, int OW, int OH, int KW, int KH, int IC,
- int offset_delta,
+ int batch, int batch_offset, int offset_delta,
int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
const int parallel_elements = OW * KW * KH;
const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
- dim3 block_nums(num_blocks, OH, IC);
- im2col_f32_f16<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, offset_delta, IW, IH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
+ dim3 block_nums(num_blocks, OH, batch * IC);
+ im2col_kernel<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
}
// buffer pool for cuda
(void) src1_dd;
}
+static void ggml_cuda_op_hardsigmoid(
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ hardsigmoid_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+ (void) src1;
+ (void) dst;
+ (void) src1_dd;
+}
+
+static void ggml_cuda_op_hardswish(
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ hardswish_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+ (void) src1;
+ (void) dst;
+ (void) src1_dd;
+}
+
static void ggml_cuda_op_leaky_relu(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
(void) src1_dd;
}
+static void ggml_cuda_op_pool2d(
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ const int32_t * opts = (const int32_t *)dst->op_params;
+ enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
+ const int k0 = opts[1];
+ const int k1 = opts[2];
+ const int s0 = opts[3];
+ const int s1 = opts[4];
+ const int p0 = opts[5];
+ const int p1 = opts[6];
+
+ const int64_t IH = src0->ne[1];
+ const int64_t IW = src0->ne[0];
+
+ const int64_t N = dst->ne[3];
+ const int64_t OC = dst->ne[2];
+ const int64_t OH = dst->ne[1];
+ const int64_t OW = dst->ne[0];
+
+ const int parallel_elements = N * OC * OH * OW;
+ const int num_blocks = (parallel_elements + CUDA_POOL2D_BLOCK_SIZE - 1) / CUDA_POOL2D_BLOCK_SIZE;
+ dim3 block_nums(num_blocks);
+ pool2d_nchw_kernel<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, main_stream>>>(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0, parallel_elements, src0_dd, dst_dd, op);
+
+ (void) src1;
+ (void) src1_dd;
+}
+
static void ggml_cuda_op_im2col(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
- GGML_ASSERT( dst->type == GGML_TYPE_F16);
+ GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
const int64_t OW = dst->ne[1];
const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
+ const int64_t batch = src1->ne[3];
+ const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
- im2col_f32_f16_cuda(src1_dd, (half*) dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
+ if(dst->type == GGML_TYPE_F16) {
+ im2col_cuda(src1_dd, (half*) dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
+ } else {
+ im2col_cuda(src1_dd, (float*) dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
+ }
(void) src0;
(void) src0_dd;
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_relu);
}
+static void ggml_cuda_hardsigmoid(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_hardsigmoid);
+}
+
+static void ggml_cuda_hardswish(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_hardswish);
+}
static void ggml_cuda_leaky_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_leaky_relu);
}
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
}
+static void ggml_cuda_pool2d(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_pool2d);
+}
+
static void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col);
}
case GGML_UNARY_OP_RELU:
func = ggml_cuda_relu;
break;
+ case GGML_UNARY_OP_HARDSIGMOID:
+ func = ggml_cuda_hardsigmoid;
+ break;
+ case GGML_UNARY_OP_HARDSWISH:
+ func = ggml_cuda_hardswish;
+ break;
default:
return false;
}
case GGML_OP_IM2COL:
func = ggml_cuda_im2col;
break;
+ case GGML_OP_POOL_2D:
+ func = ggml_cuda_pool2d;
+ break;
case GGML_OP_SUM_ROWS:
func = ggml_cuda_sum_rows;
break;
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_RELU:
+ case GGML_UNARY_OP_HARDSIGMOID:
+ case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_TANH:
return true;
case GGML_OP_ROPE:
case GGML_OP_ALIBI:
case GGML_OP_IM2COL:
+ case GGML_OP_POOL_2D:
case GGML_OP_SUM_ROWS:
case GGML_OP_ARGSORT:
case GGML_OP_ACC:
int s0,
int p0,
int d0) {
- struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false); // [N, OL, IC * K]
+ struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16); // [N, OL, IC * K]
struct ggml_tensor * result =
ggml_mul_mat(ctx,
int p1,
int d0,
int d1) {
+
struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]);
struct ggml_tensor * im2col = ggml_im2col(ctx, new_a,
ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]),
- s0, s1, p0, p1, d0, d1, true); // [N * IC, OH, OW, KH * KW]
-
- struct ggml_tensor * result =
- ggml_mul_mat(ctx,
- ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2], new_a->ne[3], 1), // [OC,1, KH, KW] => [1, OC, 1, KH * KW]
- ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3])); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW]
+ s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N * IC, OH, OW, KH * KW]
+ struct ggml_tensor * new_b = ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3]); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW]
+ new_a = ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2], new_a->ne[3], 1); // [OC,1, KH, KW] => [1, OC, 1, KH * KW]
+ struct ggml_tensor * result = ggml_mul_mat(ctx, new_a, new_b);
result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], b->ne[2], b->ne[3]); // [N, OC, OH, OW]
return result;
int p1,
int d0,
int d1,
- bool is_2D) {
+ bool is_2D,
+ enum ggml_type dst_type) {
if(is_2D) {
GGML_ASSERT(a->ne[2] == b->ne[2]);
is_2D ? b->ne[3] : 1,
};
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne);
+ struct ggml_tensor * result = ggml_new_tensor(ctx, dst_type, 4, ne);
int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };
ggml_set_op_params(result, params, sizeof(params));
int p1,
int d0,
int d1) {
- struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true); // [N, OH, OW, IC * KH * KW]
+ struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N, OH, OW, IC * KH * KW]
struct ggml_tensor * result =
ggml_mul_mat(ctx,
is_node = true;
}
+ struct ggml_tensor * result;
const int64_t ne[3] = {
ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
ggml_calc_pool_output_size(a->ne[1], k1, s1, p1),
a->ne[2],
};
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);
+ result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);
int32_t params[] = { op, k0, k1, s0, s1, p0, p1 };
ggml_set_op_params(result, params, sizeof(params));
result->op = GGML_OP_POOL_2D;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
-
return result;
}
}
}
+// src0: kernel [OC, IC, KH, KW]
+// src1: image [N, IC, IH, IW]
+// dst: result [N, OH, OW, IC*KH*KW]
+static void ggml_compute_forward_im2col_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ int64_t t0 = ggml_perf_time_us();
+ UNUSED(t0);
+
+ GGML_TENSOR_BINARY_OP_LOCALS;
+
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+ const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+ const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+ const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+ const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int64_t N = is_2D ? ne13 : ne12;
+ const int64_t IC = is_2D ? ne12 : ne11;
+ const int64_t IH = is_2D ? ne11 : 1;
+ const int64_t IW = ne10;
+
+ const int64_t KH = is_2D ? ne01 : 1;
+ const int64_t KW = ne00;
+
+ const int64_t OH = is_2D ? ne2 : 1;
+ const int64_t OW = ne1;
+
+ int ofs0 = is_2D ? nb13 : nb12;
+ int ofs1 = is_2D ? nb12 : nb11;
+
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+ GGML_ASSERT(nb10 == sizeof(float));
+
+ if (params->type == GGML_TASK_INIT) {
+ return;
+ }
+
+ if (params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
+ {
+ float * const wdata = (float *) dst->data;
+
+ for (int64_t in = 0; in < N; in++) {
+ for (int64_t ioh = 0; ioh < OH; ioh++) { // 1
+ for (int64_t iow = 0; iow < OW; iow++) {
+ for (int64_t iic = ith; iic < IC; iic += nth) {
+
+ // micro kernel
+ float * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
+ const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
+
+ for (int64_t ikh = 0; ikh < KH; ikh++) { // 1
+ for (int64_t ikw = 0; ikw < KW; ikw++) {
+ const int64_t iiw = iow*s0 + ikw*d0 - p0;
+ const int64_t iih = ioh*s1 + ikh*d1 - p1;
+
+ if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+ dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
+ } else {
+ dst_data[iic*(KH*KW) + ikh*KW + ikw] = (src_data[iih*IW + iiw]);
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
+
// src0: kernel [OC, IC, KH, KW]
// src1: image [N, IC, IH, IW]
// dst: result [N, OH, OW, IC*KH*KW]
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
- switch (src0->type) {
+ switch (dst->type) {
case GGML_TYPE_F16:
{
ggml_compute_forward_im2col_f16(params, src0, src1, dst);
} break;
case GGML_TYPE_F32:
{
- GGML_ASSERT(false);
+ ggml_compute_forward_im2col_f32(params, src0, src1, dst);
} break;
default:
{
const struct ggml_compute_params * params,
const struct ggml_tensor * src,
struct ggml_tensor * dst) {
- assert(src->type == GGML_TYPE_F32);
- assert(params->ith == 0);
+ GGML_ASSERT(src->type == GGML_TYPE_F32);
+ GGML_ASSERT(params->ith == 0);
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;