const int64_t OH = is_2D ? dst->ne[2] : 1;
const int64_t OW = dst->ne[1];
- const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
- const int64_t batch = src1->ne[3];
- const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
+ const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
+ const int64_t batch = src1->ne[is_2D ? 3 : 2];
+ const size_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
if(dst->type == GGML_TYPE_F16) {
im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
}
}
- test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32));
- test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32));
- test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16));
- // test cases for 1D im2col
+ // im2col 1D
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
+ for (int s0 : {1, 3}) {
+ for (int p0 : {0, 3}) {
+ for (int d0 : {1, 3}) {
+ test_cases.emplace_back(new test_im2col(
+ GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20, 2, 2, 1}, {3, 2, 2, 1},
+ s0, 0, p0, 0, d0, 0, false));
+ }
+ }
+ }
+
+ // im2col 2D
+ test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32));
+ test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32));
+ test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16));
+ for (int s0 : {1, 3}) {
+ for (int s1 : {1, 3}) {
+ for (int p0 : {0, 3}) {
+ for (int p1 : {0, 3}) {
+ for (int d0 : {1, 3}) {
+ for (int d1 : {1, 3}) {
+ test_cases.emplace_back(new test_im2col(
+ GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20, 20, 2, 2}, {3, 3, 2, 2},
+ s0, s1, p0, p1, d0, d1, true));
+ }
+ }
+ }
+ }
+ }
+ }
// The SYCL backend limits a task's global_range to less than MAX_INT.
// Test cases for 2D im2col with large input W and H (as occurs in stable-diffusion).