GGML_OP_CONV_TRANSPOSE_1D,
GGML_OP_IM2COL,
GGML_OP_IM2COL_BACK,
+ GGML_OP_CONV_2D_DW,
GGML_OP_CONV_TRANSPOSE_2D,
GGML_OP_POOL_1D,
GGML_OP_POOL_2D,
GGML_API bool ggml_is_contiguous_1(const struct ggml_tensor * tensor); // contiguous for dims >= 1
GGML_API bool ggml_is_contiguous_2(const struct ggml_tensor * tensor); // contiguous for dims >= 2
+ // true for tensor that is stored in memory as CxWxHxN and has been permuted to WxHxCxN
+ GGML_API bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor);
+
GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1);
GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1);
struct ggml_tensor * a,
struct ggml_tensor * b);
- // depthwise
+ // depthwise (via im2col and mul_mat)
GGML_API struct ggml_tensor * ggml_conv_2d_dw(
struct ggml_context * ctx,
struct ggml_tensor * a, // convolution kernel
int d0, // dilation dimension 0
int d1); // dilation dimension 1
+ // Depthwise 2D convolution
+ // may be faster than ggml_conv_2d_dw, but not available in all backends
+ // a: KW KH 1 C convolution kernel
+ // b: W H C N input data
+ // res: W_out H_out C N
+ GGML_API struct ggml_tensor * ggml_conv_2d_dw_direct(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ int stride0,
+ int stride1,
+ int pad0,
+ int pad1,
+ int dilation0,
+ int dilation1);
+
GGML_API struct ggml_tensor * ggml_conv_transpose_2d_p0(
struct ggml_context * ctx,
struct ggml_tensor * a,
}
}
+// ggml_compute_forward_conv_2d_dw
+
+struct ggml_conv_2d_dw_params {
+ int64_t channels;
+ int64_t batch;
+ int64_t src_w;
+ int64_t src_h;
+ int64_t dst_w;
+ int64_t dst_h;
+ int64_t knl_w;
+ int64_t knl_h;
+ int stride_x;
+ int stride_y;
+ int pad_x;
+ int pad_y;
+ int dilation_x;
+ int dilation_y;
+};
+
+static void ggml_compute_forward_conv_2d_dw_cwhn(
+ const ggml_compute_params * params,
+ const ggml_tensor * src,
+ const ggml_tensor * kernel,
+ ggml_tensor * dst,
+ const ggml_conv_2d_dw_params & p) {
+
+ const int64_t c = p.channels;
+ const float * knl_data = (const float *)kernel->data;
+
+ const int64_t rows_total = p.dst_h * p.batch;
+ const int64_t rows_per_thread = (rows_total + params->nth - 1) / params->nth;
+ const int64_t row_start = params->ith * rows_per_thread;
+ const int64_t row_end = MIN(row_start + rows_per_thread, rows_total);
+
+#ifdef GGML_SIMD
+ const int64_t pkg_size = GGML_F32_EPR;
+ const int64_t pkg_count = c / pkg_size;
+ const int64_t c_pkg_end = pkg_count * pkg_size;
+#else
+ const int64_t c_pkg_end = 0;
+#endif
+
+ for (int64_t row = row_start; row < row_end; ++row) {
+ const int64_t dst_y = row % p.dst_h;
+ const float * src_data = (const float *)src->data + (row / p.dst_h) * p.src_w * p.src_h * c;
+ for (int64_t dst_x = 0; dst_x < p.dst_w; ++dst_x) {
+ float * dst_data = (float *)dst->data + (row * p.dst_w + dst_x) * c;
+ const int64_t src_y_base = dst_y * p.stride_y - p.pad_y;
+ const int64_t src_x_base = dst_x * p.stride_x - p.pad_x;
+
+#ifdef GGML_SIMD
+ // Vectorized loop
+ for (int64_t c_i = 0; c_i < c_pkg_end; c_i += pkg_size) {
+ GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
+ for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
+ const int64_t src_y = src_y_base + knl_y * p.dilation_y;
+ if (src_y < 0 || src_y >= p.src_h) {
+ continue;
+ }
+ for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
+ const int64_t src_x = src_x_base + knl_x * p.dilation_x;
+ if (src_x < 0 || src_x >= p.src_w) {
+ continue;
+ }
+ GGML_F32_VEC k = GGML_F32_VEC_LOAD(knl_data + (knl_y * p.knl_w + knl_x) * c + c_i);
+ GGML_F32_VEC s = GGML_F32_VEC_LOAD(src_data + (src_y * p.src_w + src_x) * c + c_i);
+ sum = GGML_F32_VEC_FMA(sum, k, s);
+ }
+ }
+ GGML_F32_VEC_STORE(dst_data + c_i, sum);
+ }
+#endif
+ // Scalar loop
+ for (int64_t c_i = c_pkg_end; c_i < c; ++c_i) {
+ float sum = 0.0f;
+ for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
+ const int64_t src_y = src_y_base + knl_y * p.dilation_y;
+ if (src_y < 0 || src_y >= p.src_h) {
+ continue;
+ }
+ for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
+ const int64_t src_x = src_x_base + knl_x * p.dilation_x;
+ if (src_x < 0 || src_x >= p.src_w) {
+ continue;
+ }
+ sum += knl_data[(knl_y * p.knl_w + knl_x) * c + c_i]
+ * src_data[(src_y * p.src_w + src_x) * c + c_i];
+ }
+ }
+ dst_data[c_i] = sum;
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_conv_2d_dw_whcn(
+ const ggml_compute_params * params,
+ const ggml_tensor * src,
+ const ggml_tensor * kernel,
+ ggml_tensor * dst,
+ const ggml_conv_2d_dw_params & p) {
+
+ const int64_t n = p.channels * p.batch;
+ const int64_t per_thread = (n + params->nth - 1) / params->nth;
+ const int64_t start = params->ith * per_thread;
+ const int64_t end = MIN(start + per_thread, n);
+
+ for (int64_t i = start; i < end; ++i) {
+ const float * knl_data = (const float *)kernel->data + (i % p.channels) * p.knl_w * p.knl_h;
+ const float * src_data = (const float *)src->data + i * p.src_w * p.src_h;
+ float * dst_data = (float *)dst->data + i * p.dst_w * p.dst_h;
+
+ for (int64_t dst_y = 0; dst_y < p.dst_h; ++dst_y) {
+ for (int64_t dst_x = 0; dst_x < p.dst_w; ++dst_x) {
+
+ float sum = 0.0f;
+ for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
+ const int64_t src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
+ if (src_y < 0 || src_y >= p.src_h) {
+ continue;
+ }
+ for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
+ const int64_t src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
+ if (src_x < 0 || src_x >= p.src_w) {
+ continue;
+ }
+ sum += knl_data[knl_y * p.knl_w + knl_x]
+ * src_data[src_y * p.src_w + src_x];
+ }
+ }
+ dst_data[dst_y * p.dst_w + dst_x] = sum;
+ }
+ }
+ }
+}
+
+void ggml_compute_forward_conv_2d_dw(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * kernel = dst->src[0];
+ const ggml_tensor * src = dst->src[1];
+ ggml_conv_2d_dw_params p;
+ p.channels = src->ne[2];
+ p.batch = src->ne[3];
+ p.src_w = src->ne[0];
+ p.src_h = src->ne[1];
+ p.dst_w = dst->ne[0];
+ p.dst_h = dst->ne[1];
+ p.knl_w = kernel->ne[0];
+ p.knl_h = kernel->ne[1];
+ p.stride_x = dst->op_params[0];
+ p.stride_y = dst->op_params[1];
+ p.pad_x = dst->op_params[2];
+ p.pad_y = dst->op_params[3];
+ p.dilation_x = dst->op_params[4];
+ p.dilation_y = dst->op_params[5];
+
+ GGML_ASSERT(kernel->ne[3] == p.channels);
+ GGML_ASSERT(dst->ne[3] == p.batch);
+
+ if (ggml_is_contiguous(src)) {
+ ggml_compute_forward_conv_2d_dw_whcn(params, src, kernel, dst, p);
+ } else if (ggml_is_contiguous_channels(src)) {
+ // kernel should also have channels most contiguous in memory
+ GGML_ASSERT(kernel->nb[0] >= kernel->nb[2] && kernel->nb[1] >= kernel->nb[0]);
+ ggml_compute_forward_conv_2d_dw_cwhn(params, src, kernel, dst, p);
+ } else {
+ GGML_ABORT("non-contiguous memory layout not supported");
+ }
+}
+
// ggml_compute_forward_pool_1d_sk_p0
static void ggml_compute_forward_pool_1d_sk_p0(
"CONV_TRANSPOSE_1D",
"IM2COL",
"IM2COL_BACK",
+ "CONV_2D_DW",
"CONV_TRANSPOSE_2D",
"POOL_1D",
"POOL_2D",
"OPT_STEP_ADAMW",
};
-static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
+static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
"conv_transpose_1d(x)",
"im2col(x)",
"im2col_back(x)",
+ "conv_2d_dw(x)",
"conv_transpose_2d(x)",
"pool_1d(x)",
"pool_2d(x)",
"adamw(x)",
};
-static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
+static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3];
}
+bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor) {
+ return
+ tensor->nb[0] > tensor->nb[2] &&
+ tensor->nb[1] > tensor->nb[0] &&
+ tensor->nb[2] == ggml_type_size(tensor->type);
+}
+
static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
return result;
}
+// ggml_conv_2d_dw_direct
+
+struct ggml_tensor * ggml_conv_2d_dw_direct(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ int stride0,
+ int stride1,
+ int pad0,
+ int pad1,
+ int dilation0,
+ int dilation1) {
+ GGML_ASSERT(a->ne[2] == 1);
+ GGML_ASSERT(a->ne[3] == b->ne[2]);
+ int64_t ne[4];
+ ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], stride0, pad0, dilation0);
+ ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], stride1, pad1, dilation1);
+ ne[2] = b->ne[2];
+ ne[3] = b->ne[3];
+
+ struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne);
+
+ if (ggml_is_contiguous_channels(b)) {
+ // Result will be permuted the same way as input (CWHN order)
+ const int64_t type_size = ggml_type_size(result->type);
+ GGML_ASSERT(ggml_blck_size(result->type) == 1);
+ result->nb[0] = result->ne[2] * type_size;
+ result->nb[1] = result->ne[0] * result->nb[0];
+ result->nb[2] = type_size;
+ }
+
+ int32_t params[] = { stride0, stride1, pad0, pad1, dilation0, dilation1 };
+ ggml_set_op_params(result, params, sizeof(params));
+
+ result->op = GGML_OP_CONV_2D_DW;
+ result->src[0] = a;
+ result->src[1] = b;
+ return result;
+}
+
// ggml_conv_transpose_2d_p0
static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {