From: Acly Date: Thu, 17 Apr 2025 12:16:45 +0000 (+0200) Subject: ggml : Depthwise 2D convolution (#1152) X-Git-Tag: upstream/0.0.1982~31 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=eb22d6d7d40c267c1d024db5452976db55b7f28f;p=pkg%2Fggml%2Fsources%2Fggml ggml : Depthwise 2D convolution (#1152) * ggml-cpu : kernels for faster depthwise 2D convolution * fix compile: remove static after moving to ops.cpp * add dilation for depthwise_conv_2d * review: rename to ggml_conv_2d_dw_direct, remove redundant struct keywords, pass by ref, whitespace * review: rename depthwise_conv_2d -> conv_2d_dw everywhere --- diff --git a/include/ggml.h b/include/ggml.h index 8fcc16df..51aa5b3a 100644 --- a/include/ggml.h +++ b/include/ggml.h @@ -481,6 +481,7 @@ extern "C" { GGML_OP_CONV_TRANSPOSE_1D, GGML_OP_IM2COL, GGML_OP_IM2COL_BACK, + GGML_OP_CONV_2D_DW, GGML_OP_CONV_TRANSPOSE_2D, GGML_OP_POOL_1D, GGML_OP_POOL_2D, @@ -677,6 +678,9 @@ extern "C" { GGML_API bool ggml_is_contiguous_1(const struct ggml_tensor * tensor); // contiguous for dims >= 1 GGML_API bool ggml_is_contiguous_2(const struct ggml_tensor * tensor); // contiguous for dims >= 2 + // true for tensor that is stored in memory as CxWxHxN and has been permuted to WxHxCxN + GGML_API bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor); + GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1); GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1); @@ -1660,7 +1664,7 @@ extern "C" { struct ggml_tensor * a, struct ggml_tensor * b); - // depthwise + // depthwise (via im2col and mul_mat) GGML_API struct ggml_tensor * ggml_conv_2d_dw( struct ggml_context * ctx, struct ggml_tensor * a, // convolution kernel @@ -1672,6 +1676,22 @@ extern "C" { int d0, // dilation dimension 0 int d1); // dilation dimension 1 + // Depthwise 2D convolution + // may be faster than ggml_conv_2d_dw, but not available in all backends + // a: KW KH 1 C convolution kernel + // b: W H C N input data + // res: W_out H_out C N + GGML_API struct ggml_tensor * ggml_conv_2d_dw_direct( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int stride0, + int stride1, + int pad0, + int pad1, + int dilation0, + int dilation1); + GGML_API struct ggml_tensor * ggml_conv_transpose_2d_p0( struct ggml_context * ctx, struct ggml_tensor * a, diff --git a/src/ggml-cpu/ggml-cpu.c b/src/ggml-cpu/ggml-cpu.c index 50400328..dbad8f61 100644 --- a/src/ggml-cpu/ggml-cpu.c +++ b/src/ggml-cpu/ggml-cpu.c @@ -1932,6 +1932,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_im2col_back_f32(params, tensor); } break; + case GGML_OP_CONV_2D_DW: + { + ggml_compute_forward_conv_2d_dw(params, tensor); + } break; case GGML_OP_CONV_TRANSPOSE_2D: { ggml_compute_forward_conv_transpose_2d(params, tensor); @@ -2268,6 +2272,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { } break; case GGML_OP_IM2COL: case GGML_OP_IM2COL_BACK: + case GGML_OP_CONV_2D_DW: case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_CONV_TRANSPOSE_2D: { diff --git a/src/ggml-cpu/ops.cpp b/src/ggml-cpu/ops.cpp index 6050147b..3c2adb21 100644 --- a/src/ggml-cpu/ops.cpp +++ b/src/ggml-cpu/ops.cpp @@ -6064,6 +6064,178 @@ void ggml_compute_forward_conv_transpose_2d( } } +// ggml_compute_forward_conv_2d_dw + +struct ggml_conv_2d_dw_params { + int64_t channels; + int64_t batch; + int64_t src_w; + int64_t src_h; + int64_t dst_w; + int64_t dst_h; + 
int64_t knl_w; + int64_t knl_h; + int stride_x; + int stride_y; + int pad_x; + int pad_y; + int dilation_x; + int dilation_y; +}; + +static void ggml_compute_forward_conv_2d_dw_cwhn( + const ggml_compute_params * params, + const ggml_tensor * src, + const ggml_tensor * kernel, + ggml_tensor * dst, + const ggml_conv_2d_dw_params & p) { + + const int64_t c = p.channels; + const float * knl_data = (const float *)kernel->data; + + const int64_t rows_total = p.dst_h * p.batch; + const int64_t rows_per_thread = (rows_total + params->nth - 1) / params->nth; + const int64_t row_start = params->ith * rows_per_thread; + const int64_t row_end = MIN(row_start + rows_per_thread, rows_total); + +#ifdef GGML_SIMD + const int64_t pkg_size = GGML_F32_EPR; + const int64_t pkg_count = c / pkg_size; + const int64_t c_pkg_end = pkg_count * pkg_size; +#else + const int64_t c_pkg_end = 0; +#endif + + for (int64_t row = row_start; row < row_end; ++row) { + const int64_t dst_y = row % p.dst_h; + const float * src_data = (const float *)src->data + (row / p.dst_h) * p.src_w * p.src_h * c; + for (int64_t dst_x = 0; dst_x < p.dst_w; ++dst_x) { + float * dst_data = (float *)dst->data + (row * p.dst_w + dst_x) * c; + const int64_t src_y_base = dst_y * p.stride_y - p.pad_y; + const int64_t src_x_base = dst_x * p.stride_x - p.pad_x; + +#ifdef GGML_SIMD + // Vectorized loop + for (int64_t c_i = 0; c_i < c_pkg_end; c_i += pkg_size) { + GGML_F32_VEC sum = GGML_F32_VEC_ZERO; + for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) { + const int64_t src_y = src_y_base + knl_y * p.dilation_y; + if (src_y < 0 || src_y >= p.src_h) { + continue; + } + for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) { + const int64_t src_x = src_x_base + knl_x * p.dilation_x; + if (src_x < 0 || src_x >= p.src_w) { + continue; + } + GGML_F32_VEC k = GGML_F32_VEC_LOAD(knl_data + (knl_y * p.knl_w + knl_x) * c + c_i); + GGML_F32_VEC s = GGML_F32_VEC_LOAD(src_data + (src_y * p.src_w + src_x) * c + c_i); + sum = GGML_F32_VEC_FMA(sum, k, s); + } + } + GGML_F32_VEC_STORE(dst_data + c_i, sum); + } +#endif + // Scalar loop + for (int64_t c_i = c_pkg_end; c_i < c; ++c_i) { + float sum = 0.0f; + for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) { + const int64_t src_y = src_y_base + knl_y * p.dilation_y; + if (src_y < 0 || src_y >= p.src_h) { + continue; + } + for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) { + const int64_t src_x = src_x_base + knl_x * p.dilation_x; + if (src_x < 0 || src_x >= p.src_w) { + continue; + } + sum += knl_data[(knl_y * p.knl_w + knl_x) * c + c_i] + * src_data[(src_y * p.src_w + src_x) * c + c_i]; + } + } + dst_data[c_i] = sum; + } + } + } +} + +static void ggml_compute_forward_conv_2d_dw_whcn( + const ggml_compute_params * params, + const ggml_tensor * src, + const ggml_tensor * kernel, + ggml_tensor * dst, + const ggml_conv_2d_dw_params & p) { + + const int64_t n = p.channels * p.batch; + const int64_t per_thread = (n + params->nth - 1) / params->nth; + const int64_t start = params->ith * per_thread; + const int64_t end = MIN(start + per_thread, n); + + for (int64_t i = start; i < end; ++i) { + const float * knl_data = (const float *)kernel->data + (i % p.channels) * p.knl_w * p.knl_h; + const float * src_data = (const float *)src->data + i * p.src_w * p.src_h; + float * dst_data = (float *)dst->data + i * p.dst_w * p.dst_h; + + for (int64_t dst_y = 0; dst_y < p.dst_h; ++dst_y) { + for (int64_t dst_x = 0; dst_x < p.dst_w; ++dst_x) { + + float sum = 0.0f; + for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) { + const int64_t src_y = 
dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y; + if (src_y < 0 || src_y >= p.src_h) { + continue; + } + for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) { + const int64_t src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x; + if (src_x < 0 || src_x >= p.src_w) { + continue; + } + sum += knl_data[knl_y * p.knl_w + knl_x] + * src_data[src_y * p.src_w + src_x]; + } + } + dst_data[dst_y * p.dst_w + dst_x] = sum; + } + } + } +} + +void ggml_compute_forward_conv_2d_dw( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * kernel = dst->src[0]; + const ggml_tensor * src = dst->src[1]; + ggml_conv_2d_dw_params p; + p.channels = src->ne[2]; + p.batch = src->ne[3]; + p.src_w = src->ne[0]; + p.src_h = src->ne[1]; + p.dst_w = dst->ne[0]; + p.dst_h = dst->ne[1]; + p.knl_w = kernel->ne[0]; + p.knl_h = kernel->ne[1]; + p.stride_x = dst->op_params[0]; + p.stride_y = dst->op_params[1]; + p.pad_x = dst->op_params[2]; + p.pad_y = dst->op_params[3]; + p.dilation_x = dst->op_params[4]; + p.dilation_y = dst->op_params[5]; + + GGML_ASSERT(kernel->ne[3] == p.channels); + GGML_ASSERT(dst->ne[3] == p.batch); + + if (ggml_is_contiguous(src)) { + ggml_compute_forward_conv_2d_dw_whcn(params, src, kernel, dst, p); + } else if (ggml_is_contiguous_channels(src)) { + // kernel should also have channels most contiguous in memory + GGML_ASSERT(kernel->nb[0] >= kernel->nb[2] && kernel->nb[1] >= kernel->nb[0]); + ggml_compute_forward_conv_2d_dw_cwhn(params, src, kernel, dst, p); + } else { + GGML_ABORT("non-contiguous memory layout not supported"); + } +} + // ggml_compute_forward_pool_1d_sk_p0 static void ggml_compute_forward_pool_1d_sk_p0( diff --git a/src/ggml-cpu/ops.h b/src/ggml-cpu/ops.h index 410a3720..dc081b9e 100644 --- a/src/ggml-cpu/ops.h +++ b/src/ggml-cpu/ops.h @@ -65,6 +65,7 @@ void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * p void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_im2col_back_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_conv_transpose_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_conv_2d_dw(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_pool_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_pool_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_pool_2d_back(const struct ggml_compute_params * params, struct ggml_tensor * dst); diff --git a/src/ggml.c b/src/ggml.c index 950772c7..c8b2feff 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -956,6 +956,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CONV_TRANSPOSE_1D", "IM2COL", "IM2COL_BACK", + "CONV_2D_DW", "CONV_TRANSPOSE_2D", "POOL_1D", "POOL_2D", @@ -993,7 +994,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "OPT_STEP_ADAMW", }; -static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81"); +static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1050,6 +1051,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "conv_transpose_1d(x)", "im2col(x)", "im2col_back(x)", + "conv_2d_dw(x)", "conv_transpose_2d(x)", "pool_1d(x)", "pool_2d(x)", @@ -1087,7 +1089,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "adamw(x)", }; -static_assert(GGML_OP_COUNT == 81, 
"GGML_OP_COUNT != 81"); +static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -1344,6 +1346,13 @@ bool ggml_is_permuted(const struct ggml_tensor * tensor) { return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3]; } +bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor) { + return + tensor->nb[0] > tensor->nb[2] && + tensor->nb[1] > tensor->nb[0] && + tensor->nb[2] == ggml_type_size(tensor->type); +} + static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) { static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); @@ -4050,6 +4059,46 @@ struct ggml_tensor * ggml_conv_2d_dw( return result; } +// ggml_conv_2d_dw_direct + +struct ggml_tensor * ggml_conv_2d_dw_direct( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int stride0, + int stride1, + int pad0, + int pad1, + int dilation0, + int dilation1) { + GGML_ASSERT(a->ne[2] == 1); + GGML_ASSERT(a->ne[3] == b->ne[2]); + int64_t ne[4]; + ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], stride0, pad0, dilation0); + ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], stride1, pad1, dilation1); + ne[2] = b->ne[2]; + ne[3] = b->ne[3]; + + struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne); + + if (ggml_is_contiguous_channels(b)) { + // Result will be permuted the same way as input (CWHN order) + const int64_t type_size = ggml_type_size(result->type); + GGML_ASSERT(ggml_blck_size(result->type) == 1); + result->nb[0] = result->ne[2] * type_size; + result->nb[1] = result->ne[0] * result->nb[0]; + result->nb[2] = type_size; + } + + int32_t params[] = { stride0, stride1, pad0, pad1, dilation0, dilation1 }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_CONV_2D_DW; + result->src[0] = a; + result->src[1] = b; + return result; +} + // ggml_conv_transpose_2d_p0 static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) { diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 5db778cd..27f398ac 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -384,6 +384,16 @@ add_test(NAME ${TEST_TARGET} COMMAND $) set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw") +# +# test-conv2d-dw + +set(TEST_TARGET test-conv2d-dw) +add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp) +target_link_libraries(${TEST_TARGET} PRIVATE ggml) +add_test(NAME ${TEST_TARGET} COMMAND $) +set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw") + + # # test-mul-mat diff --git a/tests/test-conv2d-dw.cpp b/tests/test-conv2d-dw.cpp new file mode 100644 index 00000000..d4c02908 --- /dev/null +++ b/tests/test-conv2d-dw.cpp @@ -0,0 +1,153 @@ +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +std::vector f32_range(int n, float start, float end) { + std::vector values(n); + float step = (end - start) / n; + for (int i = 0; i < n; i++) { + values[i] = start + i * step; + } + return values; +} + +// Most straightforward implementation without any optimizations +std::vector conv_2d_dw_reference( + int src_w, int src_h, const float * src_data, + int knl_w, int knl_h, const float * knl_data, + int channels, int batch, int stride, int pad, int dilation) { + + int dst_w = (src_w + 2 * pad - dilation * (knl_w - 1) - 1) / stride + 1; + int dst_h = (src_h + 
2 * pad - dilation * (knl_h - 1) - 1) / stride + 1;
+    std::vector<float> dst_data(dst_w * dst_h * channels * batch);
+
+    for (int b = 0; b < batch; b++) {
+        const float * src_base = src_data + b * src_w * src_h * channels;
+        float * dst_base = dst_data.data() + b * dst_w * dst_h * channels;
+        for (int c = 0; c < channels; c++) {
+            for (int y = 0; y < dst_h; y++) {
+                for (int x = 0; x < dst_w; x++) {
+                    float sum = 0;
+                    for (int knl_y = 0; knl_y < knl_h; knl_y++) {
+                        for (int knl_x = 0; knl_x < knl_w; knl_x++) {
+                            int src_x = x * stride + knl_x * dilation - pad;
+                            int src_y = y * stride + knl_y * dilation - pad;
+                            if (src_x >= 0 && src_x < src_w && src_y >= 0 && src_y < src_h) {
+                                sum += src_base[c * src_w * src_h + src_y * src_w + src_x] *
+                                    knl_data[c * knl_w * knl_h + knl_y * knl_w + knl_x];
+                            }
+                        }
+                    }
+                    dst_base[c * dst_w * dst_h + y * dst_w + x] = sum;
+                }
+            }
+        }
+    }
+    return dst_data;
+}
+
+bool check_equal(const std::vector<float> & result, const std::vector<float> & expected) {
+    if (result.size() != expected.size()) {
+        printf("result.size() = %d, expected.size() = %d\n", (int)result.size(), (int)expected.size());
+        return false;
+    }
+    for (int i = 0; i < result.size(); i++) {
+        if(std::abs(result[i] - expected[i]) > 1e-5) {
+            printf("result[%d] %f != %f expected[%d]\n", i, result[i], expected[i], i);
+            return false;
+        }
+    }
+    return true;
+}
+
+bool test_conv_2d_dw(
+    int channels,
+    int kernel_size,
+    int stride,
+    int pad,
+    int dilation,
+    bool contiguous_channels) {
+    ggml_time_init();
+
+    const int batch = 2;
+    const int src_w = 8;
+    const int src_h = 6;
+    const int knl_w = kernel_size;
+    const int knl_h = kernel_size;
+
+    ggml_init_params params {
+        /*.mem_size   =*/ 64 * ggml_tensor_overhead() + ggml_graph_overhead(),
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ true
+    };
+
+    ggml_context_ptr ctx_ptr{ggml_init(params)};
+    ggml_context * ctx = ctx_ptr.get();
+    ggml_cgraph * gf = ggml_new_graph(ctx);
+
+    // Build graph
+    ggml_tensor * src_input = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, src_w, src_h, channels, batch);
+    ggml_tensor * knl_input = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, knl_w, knl_h, 1, channels);
+    ggml_tensor * src = src_input;
+    ggml_tensor * knl = knl_input;
+    if (contiguous_channels) {
+        // Convert tensor to [C, W, H, N] layout in memory, then permute strides back to [W, H, C, N]
+        src = ggml_cont(ctx, ggml_permute(ctx, src, 1, 2, 0, 3));
+        src = ggml_permute(ctx, src, 2, 0, 1, 3);
+        knl = ggml_cont(ctx, ggml_permute(ctx, knl, 2, 3, 1, 0));
+        knl = ggml_permute(ctx, knl, 3, 2, 0, 1);
+    }
+    ggml_tensor * res = ggml_conv_2d_dw_direct(
+        ctx, knl, src, stride, stride, pad, pad, dilation, dilation);
+    if (contiguous_channels) {
+        res = ggml_cont(ctx, res);
+    }
+    ggml_build_forward_expand(gf, res);
+
+    // Create backend & allocate buffers
+    ggml_backend_ptr backend_ptr{ggml_backend_cpu_init()};
+    ggml_backend_t backend = backend_ptr.get();
+    ggml_backend_cpu_set_n_threads(backend, 2);
+    ggml_backend_buffer_ptr buffer{ggml_backend_alloc_ctx_tensors(ctx, backend)};
+
+    std::vector<float> src_values = f32_range(ggml_nelements(src), -1.f, 1.f);
+    std::vector<float> knl_values = f32_range(ggml_nelements(knl), -1.f, 1.f);
+    ggml_backend_tensor_set(src_input, src_values.data(), 0, ggml_nbytes(src));
+    ggml_backend_tensor_set(knl_input, knl_values.data(), 0, ggml_nbytes(knl));
+
+    ggml_backend_graph_compute(backend, gf);
+
+    std::vector<float> res_values(ggml_nelements(res));
+    ggml_backend_tensor_get(res, res_values.data(), 0, ggml_nbytes(res));
+
+    std::vector<float> expected = conv_2d_dw_reference(
+        src_w, src_h, 
src_values.data(), + knl_w, knl_h, knl_values.data(), + channels, batch, stride, pad, dilation); + + bool passed = check_equal(res_values, expected); + + printf("ggml_conv_2d_dw(channels=%d, kernel=%dx%d, stride=%d, pad=%d, dilation=%d, layout=%s): %s\n", + channels, kernel_size, kernel_size, stride, pad, dilation, contiguous_channels ? "CWHN" : "WHCN", + passed ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m"); + return passed; +} + +int main(int argc, char ** argv) { + bool passed = true; + passed = test_conv_2d_dw(3, 1, 1, 0, 1, false) && passed; + passed = test_conv_2d_dw(3, 1, 1, 0, 1, true) && passed; + passed = test_conv_2d_dw(42, 3, 2, 1, 1, false) && passed; + passed = test_conv_2d_dw(42, 3, 2, 1, 1, true) && passed; + passed = test_conv_2d_dw(8, 5, 1, 2, 2, false) && passed; + passed = test_conv_2d_dw(8, 5, 1, 2, 2, true) && passed; + return passed ? 0 : 1; +}
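
For illustration, a minimal caller of the new ggml_conv_2d_dw_direct operator with the default WHCN layout might look like the following sketch. The tensor sizes, fill values and the single-context CPU helper (ggml_graph_compute_with_ctx) are illustrative assumptions, not part of this change:

    // Sketch: 3x3 depthwise convolution over a 64x64, 8-channel, batch-1 input (WHCN).
    #include "ggml.h"
    #include "ggml-cpu.h"
    #include <stdio.h>

    int main(void) {
        struct ggml_init_params ip = {
            /*.mem_size   =*/ 64*1024*1024,  // generous scratch for tensors, graph and work buffer
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(ip);

        struct ggml_tensor * kernel = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 3, 3, 1, 8);   // KW KH 1 C
        struct ggml_tensor * input  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 64, 64, 8, 1); // W  H  C N
        for (int64_t i = 0; i < ggml_nelements(input);  ++i) { ((float *)input->data)[i]  = 1.0f; }
        for (int64_t i = 0; i < ggml_nelements(kernel); ++i) { ((float *)kernel->data)[i] = 0.1f; }

        // stride 1, pad 1, dilation 1 -> the output keeps the 64 x 64 x 8 x 1 shape
        struct ggml_tensor * out = ggml_conv_2d_dw_direct(ctx, kernel, input, 1, 1, 1, 1, 1, 1);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, out);
        ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/2);

        printf("out[0] = %f\n", ((float *)out->data)[0]);
        ggml_free(ctx);
        return 0;
    }

The channels-contiguous (CWHN) fast path detected by ggml_is_contiguous_channels is reached from ordinary WHCN tensors with the same permute/cont pattern the test above uses. A sketch of that graph construction, assuming an existing context; the helper name and sizes are illustrative:

    // Sketch: input stored as C x W x H x N in memory but still viewed as W x H x C x N.
    static struct ggml_tensor * conv_2d_dw_cwhn_sketch(struct ggml_context * ctx) {
        struct ggml_tensor * input  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 64, 64, 8, 1); // W  H  C N
        struct ggml_tensor * kernel = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 3, 3, 1, 8);   // KW KH 1 C

        // Store the input channels-first, then permute the strides back to a WHCN view,
        // mirroring tests/test-conv2d-dw.cpp above.
        struct ggml_tensor * src = ggml_cont(ctx, ggml_permute(ctx, input, 1, 2, 0, 3));
        src = ggml_permute(ctx, src, 2, 0, 1, 3);

        // The kernel gets the same treatment so its channels are contiguous as well.
        struct ggml_tensor * knl = ggml_cont(ctx, ggml_permute(ctx, kernel, 2, 3, 1, 0));
        knl = ggml_permute(ctx, knl, 3, 2, 0, 1);

        // ggml_is_contiguous_channels(src) now holds, so the CWHN kernel is selected.
        // The result keeps the permuted layout; make it contiguous again before feeding
        // it to ops that expect the standard WHCN memory order.
        struct ggml_tensor * res = ggml_conv_2d_dw_direct(ctx, knl, src, 1, 1, 1, 1, 1, 1);
        return ggml_cont(ctx, res);
    }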