]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : Depthwise 2D convolution (#1152)
authorAcly <redacted>
Thu, 17 Apr 2025 12:16:45 +0000 (14:16 +0200)
committerGitHub <redacted>
Thu, 17 Apr 2025 12:16:45 +0000 (15:16 +0300)
* ggml-cpu : kernels for faster depthwise 2D convolution

* fix compile: remove static after moving to ops.cpp

* add dilation for depthwise_conv_2d

* review: rename to ggml_conv_2d_dw_direct, remove redundant struct keywords, pass by ref, whitespace

* review: rename depthwise_conv_2d -> conv_2d_dw everywhere

include/ggml.h
src/ggml-cpu/ggml-cpu.c
src/ggml-cpu/ops.cpp
src/ggml-cpu/ops.h
src/ggml.c
tests/CMakeLists.txt
tests/test-conv2d-dw.cpp [new file with mode: 0644]

index 8fcc16df998bee004bb4d7b3c8739e5032c387ec..51aa5b3a0ab44309a1c6dbc0de0ebdd438b8e6c1 100644 (file)
@@ -481,6 +481,7 @@ extern "C" {
         GGML_OP_CONV_TRANSPOSE_1D,
         GGML_OP_IM2COL,
         GGML_OP_IM2COL_BACK,
+        GGML_OP_CONV_2D_DW,
         GGML_OP_CONV_TRANSPOSE_2D,
         GGML_OP_POOL_1D,
         GGML_OP_POOL_2D,
@@ -677,6 +678,9 @@ extern "C" {
     GGML_API bool ggml_is_contiguous_1(const struct ggml_tensor * tensor); // contiguous for dims >= 1
     GGML_API bool ggml_is_contiguous_2(const struct ggml_tensor * tensor); // contiguous for dims >= 2
 
+    // true for tensor that is stored in memory as CxWxHxN and has been permuted to WxHxCxN
+    GGML_API bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor);
+
     GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1);
     GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1);
 
@@ -1660,7 +1664,7 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
-    // depthwise
+    // depthwise (via im2col and mul_mat)
     GGML_API struct ggml_tensor * ggml_conv_2d_dw(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,  // convolution kernel
@@ -1672,6 +1676,22 @@ extern "C" {
             int                  d0,  // dilation dimension 0
             int                  d1); // dilation dimension 1
 
+    // Depthwise 2D convolution
+    // may be faster than ggml_conv_2d_dw, but not available in all backends
+    // a:   KW    KH    1    C    convolution kernel
+    // b:   W     H     C    N    input data
+    // res: W_out H_out C    N
+    GGML_API struct ggml_tensor * ggml_conv_2d_dw_direct(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            int                   stride0,
+            int                   stride1,
+            int                   pad0,
+            int                   pad1,
+            int                   dilation0,
+            int                   dilation1);
+
     GGML_API struct ggml_tensor * ggml_conv_transpose_2d_p0(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
index 50400328738efed580c5fc21564475846b905196..dbad8f61a1e92ff5288eaf34aaaa7b48b462e1be 100644 (file)
@@ -1932,6 +1932,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_im2col_back_f32(params, tensor);
             } break;
+        case GGML_OP_CONV_2D_DW:
+            {
+                ggml_compute_forward_conv_2d_dw(params, tensor);
+            } break;
         case GGML_OP_CONV_TRANSPOSE_2D:
             {
                 ggml_compute_forward_conv_transpose_2d(params, tensor);
@@ -2268,6 +2272,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             } break;
         case GGML_OP_IM2COL:
         case GGML_OP_IM2COL_BACK:
+        case GGML_OP_CONV_2D_DW:
         case GGML_OP_CONV_TRANSPOSE_1D:
         case GGML_OP_CONV_TRANSPOSE_2D:
             {
index 6050147be70accd7fd25df343012ca4e54b48702..3c2adb217267b599ba003dc3f1bda0fbc7765c74 100644 (file)
@@ -6064,6 +6064,178 @@ void ggml_compute_forward_conv_transpose_2d(
     }
 }
 
+// ggml_compute_forward_conv_2d_dw
+
+struct ggml_conv_2d_dw_params {
+    int64_t channels;
+    int64_t batch;
+    int64_t src_w;
+    int64_t src_h;
+    int64_t dst_w;
+    int64_t dst_h;
+    int64_t knl_w;
+    int64_t knl_h;
+    int stride_x;
+    int stride_y;
+    int pad_x;
+    int pad_y;
+    int dilation_x;
+    int dilation_y;
+};
+
+static void ggml_compute_forward_conv_2d_dw_cwhn(
+        const ggml_compute_params * params,
+        const ggml_tensor * src,
+        const ggml_tensor * kernel,
+        ggml_tensor * dst,
+        const ggml_conv_2d_dw_params & p) {
+
+    const int64_t c = p.channels;
+    const float * knl_data = (const float *)kernel->data;
+
+    const int64_t rows_total = p.dst_h * p.batch;
+    const int64_t rows_per_thread = (rows_total + params->nth - 1) / params->nth;
+    const int64_t row_start = params->ith * rows_per_thread;
+    const int64_t row_end = MIN(row_start + rows_per_thread, rows_total);
+
+#ifdef GGML_SIMD
+    const int64_t pkg_size = GGML_F32_EPR;
+    const int64_t pkg_count = c / pkg_size;
+    const int64_t c_pkg_end = pkg_count * pkg_size;
+#else
+    const int64_t c_pkg_end = 0;
+#endif
+
+    for (int64_t row = row_start; row < row_end; ++row) {
+        const int64_t dst_y = row % p.dst_h;
+        const float * src_data = (const float *)src->data + (row / p.dst_h) * p.src_w * p.src_h * c;
+        for (int64_t dst_x = 0; dst_x < p.dst_w; ++dst_x) {
+            float * dst_data = (float *)dst->data + (row * p.dst_w + dst_x) * c;
+            const int64_t src_y_base = dst_y * p.stride_y - p.pad_y;
+            const int64_t src_x_base = dst_x * p.stride_x - p.pad_x;
+
+#ifdef GGML_SIMD
+            // Vectorized loop
+            for (int64_t c_i = 0; c_i < c_pkg_end; c_i += pkg_size) {
+                GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
+                for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
+                    const int64_t src_y = src_y_base + knl_y * p.dilation_y;
+                    if (src_y < 0 || src_y >= p.src_h) {
+                        continue;
+                    }
+                    for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
+                        const int64_t src_x = src_x_base + knl_x * p.dilation_x;
+                        if (src_x < 0 || src_x >= p.src_w) {
+                            continue;
+                        }
+                        GGML_F32_VEC k = GGML_F32_VEC_LOAD(knl_data + (knl_y * p.knl_w + knl_x) * c + c_i);
+                        GGML_F32_VEC s = GGML_F32_VEC_LOAD(src_data + (src_y * p.src_w + src_x) * c + c_i);
+                        sum = GGML_F32_VEC_FMA(sum, k, s);
+                    }
+                }
+                GGML_F32_VEC_STORE(dst_data + c_i, sum);
+            }
+#endif
+            // Scalar loop
+            for (int64_t c_i = c_pkg_end; c_i < c; ++c_i) {
+                float sum = 0.0f;
+                for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
+                    const int64_t src_y = src_y_base + knl_y * p.dilation_y;
+                    if (src_y < 0 || src_y >= p.src_h) {
+                        continue;
+                    }
+                    for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
+                        const int64_t src_x = src_x_base + knl_x * p.dilation_x;
+                        if (src_x < 0 || src_x >= p.src_w) {
+                            continue;
+                        }
+                        sum += knl_data[(knl_y * p.knl_w + knl_x) * c + c_i]
+                             * src_data[(src_y * p.src_w + src_x) * c + c_i];
+                    }
+                }
+                dst_data[c_i] = sum;
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_conv_2d_dw_whcn(
+        const ggml_compute_params * params,
+        const ggml_tensor * src,
+        const ggml_tensor * kernel,
+        ggml_tensor * dst,
+        const ggml_conv_2d_dw_params & p) {
+
+    const int64_t n = p.channels * p.batch;
+    const int64_t per_thread = (n + params->nth - 1) / params->nth;
+    const int64_t start = params->ith * per_thread;
+    const int64_t end = MIN(start + per_thread, n);
+
+    for (int64_t i = start; i < end; ++i) {
+        const float * knl_data = (const float *)kernel->data + (i % p.channels) * p.knl_w * p.knl_h;
+        const float * src_data = (const float *)src->data + i * p.src_w * p.src_h;
+        float * dst_data = (float *)dst->data + i * p.dst_w * p.dst_h;
+
+        for (int64_t dst_y = 0; dst_y < p.dst_h; ++dst_y) {
+            for (int64_t dst_x = 0; dst_x < p.dst_w; ++dst_x) {
+
+                float sum = 0.0f;
+                for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
+                    const int64_t src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
+                    if (src_y < 0 || src_y >= p.src_h) {
+                        continue;
+                    }
+                    for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
+                        const int64_t src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
+                        if (src_x < 0 || src_x >= p.src_w) {
+                            continue;
+                        }
+                        sum += knl_data[knl_y * p.knl_w + knl_x]
+                             * src_data[src_y * p.src_w + src_x];
+                    }
+                }
+                dst_data[dst_y * p.dst_w + dst_x] = sum;
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_conv_2d_dw(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * kernel = dst->src[0];
+    const ggml_tensor * src = dst->src[1];
+    ggml_conv_2d_dw_params p;
+    p.channels = src->ne[2];
+    p.batch = src->ne[3];
+    p.src_w = src->ne[0];
+    p.src_h = src->ne[1];
+    p.dst_w = dst->ne[0];
+    p.dst_h = dst->ne[1];
+    p.knl_w = kernel->ne[0];
+    p.knl_h = kernel->ne[1];
+    p.stride_x = dst->op_params[0];
+    p.stride_y = dst->op_params[1];
+    p.pad_x = dst->op_params[2];
+    p.pad_y = dst->op_params[3];
+    p.dilation_x = dst->op_params[4];
+    p.dilation_y = dst->op_params[5];
+
+    GGML_ASSERT(kernel->ne[3] == p.channels);
+    GGML_ASSERT(dst->ne[3] == p.batch);
+
+    if (ggml_is_contiguous(src)) {
+        ggml_compute_forward_conv_2d_dw_whcn(params, src, kernel, dst, p);
+    } else if (ggml_is_contiguous_channels(src)) {
+        // kernel should also have channels most contiguous in memory
+        GGML_ASSERT(kernel->nb[0] >= kernel->nb[2] && kernel->nb[1] >= kernel->nb[0]);
+        ggml_compute_forward_conv_2d_dw_cwhn(params, src, kernel, dst, p);
+    } else {
+        GGML_ABORT("non-contiguous memory layout not supported");
+    }
+}
+
 // ggml_compute_forward_pool_1d_sk_p0
 
 static void ggml_compute_forward_pool_1d_sk_p0(
index 410a372047a01032fc974ed4fb3fe2368a13d125..dc081b9e663970458606651a43bd640a187641c2 100644 (file)
@@ -65,6 +65,7 @@ void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * p
 void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_im2col_back_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_conv_transpose_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_conv_2d_dw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_pool_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_pool_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_pool_2d_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
index 950772c75cb32867c5621b33fde4ca95eacc2fa6..c8b2feff4251da94ecdc74236eb064fb8fc025dc 100644 (file)
@@ -956,6 +956,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CONV_TRANSPOSE_1D",
     "IM2COL",
     "IM2COL_BACK",
+    "CONV_2D_DW",
     "CONV_TRANSPOSE_2D",
     "POOL_1D",
     "POOL_2D",
@@ -993,7 +994,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "OPT_STEP_ADAMW",
 };
 
-static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
+static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1050,6 +1051,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "conv_transpose_1d(x)",
     "im2col(x)",
     "im2col_back(x)",
+    "conv_2d_dw(x)",
     "conv_transpose_2d(x)",
     "pool_1d(x)",
     "pool_2d(x)",
@@ -1087,7 +1089,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "adamw(x)",
 };
 
-static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
+static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -1344,6 +1346,13 @@ bool ggml_is_permuted(const struct ggml_tensor * tensor) {
     return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3];
 }
 
+bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor) {
+    return
+        tensor->nb[0] > tensor->nb[2] &&
+        tensor->nb[1] > tensor->nb[0] &&
+        tensor->nb[2] == ggml_type_size(tensor->type);
+}
+
 static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
@@ -4050,6 +4059,46 @@ struct ggml_tensor * ggml_conv_2d_dw(
     return result;
 }
 
+// ggml_conv_2d_dw_direct
+
+struct ggml_tensor * ggml_conv_2d_dw_direct(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int                   stride0,
+        int                   stride1,
+        int                   pad0,
+        int                   pad1,
+        int                   dilation0,  
+        int                   dilation1) {
+    GGML_ASSERT(a->ne[2] == 1);
+    GGML_ASSERT(a->ne[3] == b->ne[2]);
+    int64_t ne[4];
+    ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], stride0, pad0, dilation0);
+    ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], stride1, pad1, dilation1);
+    ne[2] = b->ne[2];
+    ne[3] = b->ne[3];
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne);
+
+    if (ggml_is_contiguous_channels(b)) {
+        // Result will be permuted the same way as input (CWHN order)
+        const int64_t type_size = ggml_type_size(result->type);
+        GGML_ASSERT(ggml_blck_size(result->type) == 1);
+        result->nb[0] = result->ne[2] * type_size;
+        result->nb[1] = result->ne[0] * result->nb[0];
+        result->nb[2] = type_size;
+    }
+
+    int32_t params[] = { stride0, stride1, pad0, pad1, dilation0, dilation1 };
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op     = GGML_OP_CONV_2D_DW;
+    result->src[0] = a;
+    result->src[1] = b;
+    return result;
+}
+
 // ggml_conv_transpose_2d_p0
 
 static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {
index 5db778cd861b589385f2dbbfa99a35305eea161a..27f398ac9bca5109ebffbd65f92e4dda2215567c 100644 (file)
@@ -384,6 +384,16 @@ add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
 set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw")
 
 
+#
+# test-conv2d-dw
+
+set(TEST_TARGET test-conv2d-dw)
+add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml)
+add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
+set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw")
+
+
 #
 # test-mul-mat
 
diff --git a/tests/test-conv2d-dw.cpp b/tests/test-conv2d-dw.cpp
new file mode 100644 (file)
index 0000000..d4c0290
--- /dev/null
@@ -0,0 +1,153 @@
+#include <ggml.h>
+#include <ggml-cpu.h>
+#include <ggml-alloc.h>
+#include <ggml-backend.h>
+#include <ggml-cpp.h>
+
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <vector>
+
+std::vector<float> f32_range(int n, float start, float end) {
+    std::vector<float> values(n);
+    float step = (end - start) / n;
+    for (int i = 0; i < n; i++) {
+        values[i] = start + i * step;
+    }
+    return values;
+}
+
+// Most straightforward implementation without any optimizations
+std::vector<float> conv_2d_dw_reference(
+        int src_w, int src_h, const float * src_data,
+        int knl_w, int knl_h, const float * knl_data,
+        int channels, int batch, int stride, int pad, int dilation) {
+
+    int dst_w = (src_w + 2 * pad - dilation * (knl_w - 1) - 1) / stride + 1;
+    int dst_h = (src_h + 2 * pad - dilation * (knl_h - 1) - 1) / stride + 1;
+    std::vector<float> dst_data(dst_w * dst_h * channels * batch);
+
+    for (int b = 0; b < batch; b++) {
+        const float * src_base = src_data + b * src_w * src_h * channels;
+        float * dst_base = dst_data.data() + b * dst_w * dst_h * channels;
+        for (int c = 0; c < channels; c++) {
+            for (int y = 0; y < dst_h; y++) {
+                for (int x = 0; x < dst_w; x++) {
+                    float sum = 0;
+                    for (int knl_y = 0; knl_y < knl_h; knl_y++) {
+                        for (int knl_x = 0; knl_x < knl_w; knl_x++) {
+                            int src_x = x * stride + knl_x * dilation - pad;
+                            int src_y = y * stride + knl_y * dilation - pad;
+                            if (src_x >= 0 && src_x < src_w && src_y >= 0 && src_y < src_h) {
+                                sum += src_base[c * src_w * src_h + src_y * src_w + src_x] *
+                                       knl_data[c * knl_w * knl_h + knl_y * knl_w + knl_x];
+                            }
+                        }
+                    }
+                    dst_base[c * dst_w * dst_h + y * dst_w + x] = sum;
+                }
+            }
+        }
+    }
+    return dst_data;
+}
+
+bool check_equal(const std::vector<float> & result, const std::vector<float> & expected) {
+    if (result.size() != expected.size()) {
+        printf("result.size() = %d, expected.size() = %d\n", (int)result.size(), (int)expected.size());
+        return false;
+    }
+    for (int i = 0; i < result.size(); i++) {
+        if(std::abs(result[i] - expected[i]) > 1e-5) {
+            printf("result[%d] %f != %f expected[%d]\n", i, result[i], expected[i], i);
+            return false;
+        }
+    }
+    return true;
+}
+
+bool test_conv_2d_dw(
+        int channels,
+        int kernel_size,
+        int stride,
+        int pad,
+        int dilation,
+        bool contiguous_channels) {
+    ggml_time_init();
+
+    const int batch = 2;
+    const int src_w = 8;
+    const int src_h = 6;
+    const int knl_w = kernel_size;
+    const int knl_h = kernel_size;
+
+    ggml_init_params params {
+        /*.mem_size   =*/ 64 * ggml_tensor_overhead() + ggml_graph_overhead(),
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ true
+    };
+
+    ggml_context_ptr ctx_ptr{ggml_init(params)};
+    ggml_context * ctx = ctx_ptr.get();
+    ggml_cgraph * gf = ggml_new_graph(ctx);
+
+    // Build graph
+    ggml_tensor * src_input = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, src_w, src_h, channels, batch);
+    ggml_tensor * knl_input = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, knl_w, knl_h, 1, channels);
+    ggml_tensor * src = src_input;
+    ggml_tensor * knl = knl_input;
+    if (contiguous_channels) {
+        // Convert tensor to [C, W, H, N] layout in memory, then permute strides back to [W, H, C, N]
+        src = ggml_cont(ctx, ggml_permute(ctx, src, 1, 2, 0, 3));
+        src = ggml_permute(ctx, src, 2, 0, 1, 3);
+        knl = ggml_cont(ctx, ggml_permute(ctx, knl, 2, 3, 1, 0));
+        knl = ggml_permute(ctx, knl, 3, 2, 0, 1);
+    }
+    ggml_tensor * res = ggml_conv_2d_dw_direct(
+        ctx, knl, src, stride, stride, pad, pad, dilation, dilation);
+    if (contiguous_channels) {
+        res = ggml_cont(ctx, res);
+    }
+    ggml_build_forward_expand(gf, res);
+
+    // Create backend & allocate buffers
+    ggml_backend_ptr backend_ptr{ggml_backend_cpu_init()};
+    ggml_backend_t backend = backend_ptr.get();
+    ggml_backend_cpu_set_n_threads(backend, 2);
+    ggml_backend_buffer_ptr buffer{ggml_backend_alloc_ctx_tensors(ctx, backend)};
+
+    std::vector<float> src_values = f32_range(ggml_nelements(src), -1.f, 1.f);
+    std::vector<float> knl_values = f32_range(ggml_nelements(knl), -1.f, 1.f);
+    ggml_backend_tensor_set(src_input, src_values.data(), 0, ggml_nbytes(src));
+    ggml_backend_tensor_set(knl_input, knl_values.data(), 0, ggml_nbytes(knl));
+
+    ggml_backend_graph_compute(backend, gf);
+
+    std::vector<float> res_values(ggml_nelements(res));
+    ggml_backend_tensor_get(res, res_values.data(), 0, ggml_nbytes(res));
+
+    std::vector<float> expected = conv_2d_dw_reference(
+        src_w, src_h, src_values.data(),
+        knl_w, knl_h, knl_values.data(),
+        channels, batch, stride, pad, dilation);
+
+    bool passed = check_equal(res_values, expected);
+
+    printf("ggml_conv_2d_dw(channels=%d, kernel=%dx%d, stride=%d, pad=%d, dilation=%d, layout=%s): %s\n",
+        channels, kernel_size, kernel_size, stride, pad, dilation, contiguous_channels ? "CWHN" : "WHCN",
+        passed ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m");
+    return passed;
+}
+
+int main(int argc, char ** argv) {
+    bool passed = true;
+    passed = test_conv_2d_dw(3, 1, 1, 0, 1, false) && passed;
+    passed = test_conv_2d_dw(3, 1, 1, 0, 1, true) && passed;
+    passed = test_conv_2d_dw(42, 3, 2, 1, 1, false) && passed;
+    passed = test_conv_2d_dw(42, 3, 2, 1, 1, true) && passed;
+    passed = test_conv_2d_dw(8, 5, 1, 2, 2, false) && passed;
+    passed = test_conv_2d_dw(8, 5, 1, 2, 2, true) && passed;
+    return passed ? 0 : 1;
+}