dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
}
-static void im2col_f32_f16(const float *x, sycl::half *dst, int offset_delta,
+template <typename T>
+static void im2col_kernel(const float *x, T *dst, int offset_delta,
int IW, int IH, int OW, int KW, int KH,
int pelements, int CHW, int s0, int s1, int p0,
int p1, int d0, int d1,
});
}
-static void im2col_f32_f16_sycl(const float *x, sycl::half *dst, int IW, int IH,
+template <typename T>
+static void im2col_sycl(const float *x, T *dst, int IW, int IH,
int OW, int OH, int KW, int KH, int IC,
int offset_delta, int s0, int s1, int p0,
int p1, int d0, int d1,
sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
- im2col_f32_f16(x, dst, offset_delta, IW, IH, OW, KW, KH,
+ im2col_kernel(x, dst, offset_delta, IW, IH, OW, KW, KH,
parallel_elements, (IC * KH * KW), s0, s1, p0,
p1, d0, d1, item_ct1);
});
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
- GGML_ASSERT( dst->type == GGML_TYPE_F16);
+ GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
- im2col_f32_f16_sycl(src1_dd, (sycl::half *)dst_dd, IW, IH, OW, OH, KW, KH,
- IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
+ if (dst->type == GGML_TYPE_F16) {
+ im2col_sycl(src1_dd, (sycl::half *)dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
+ } else {
+ im2col_sycl(src1_dd, (float *)dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
+ }
(void) src0;
(void) src0_dd;