--- /dev/null
+#include "conv2d.cuh"
+
+struct conv_params {
+ const int64_t IW, IH;
+ const int64_t OW, OH;
+ const int64_t KW, KH;
+ const int64_t ST_X, ST_Y;
+ const int64_t PD_X, PD_Y;
+ const int64_t DL_X, DL_Y;
+ const int64_t IC, OC;
+ const int64_t B;
+ const int64_t TOTAL;
+};
+
+struct kernel_bounds {
+ int64_t y_min, y_max;
+ int64_t x_min, x_max;
+};
+
+__device__ __forceinline__ int64_t max64(int64_t a, int64_t b) {
+ return (a > b) ? a : b;
+}
+
+__device__ __forceinline__ int64_t min64(int64_t a, int64_t b) {
+ return (a < b) ? a : b;
+}
+
+__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int64_t out_x, int64_t out_y, const conv_params & P) {
+ kernel_bounds bounds;
+ bounds.y_min = max64(0, (P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y);
+ bounds.y_max = min64(P.KH, (P.IH + P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y);
+ bounds.x_min = max64(0, (P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X);
+ bounds.x_max = min64(P.KW, (P.IW + P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X);
+ return bounds;
+}
+
+__device__ __forceinline__ int calculate_input_coord(int64_t out_coord,
+ int64_t kern_coord,
+ int64_t stride,
+ int64_t dilation,
+ int64_t padding) {
+ return out_coord * stride + kern_coord * dilation - padding;
+}
+
+struct whcn_layout {
+ __device__ static int64_t input_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) {
+ return n * (P.IC * P.IW * P.IH) + c * P.IW * P.IH + y * P.IW + x;
+ }
+
+ __device__ static int64_t kernel_index(int64_t c_out, int64_t c_in, int64_t ky, int64_t kx, const conv_params & P) {
+ return c_out * (P.IC * P.KH * P.KW) + c_in * (P.KH * P.KW) + ky * P.KW + kx;
+ }
+
+ __device__ static int64_t output_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) {
+ return n * (P.OC * P.OW * P.OH) + c * P.OW * P.OH + y * P.OW + x;
+ }
+
+ __device__ static void unpack_indices(int64_t global_idx,
+ const conv_params & P,
+ int64_t & n,
+ int64_t & c,
+ int64_t & out_y,
+ int64_t & out_x) {
+ out_x = global_idx % P.OW;
+ out_y = (global_idx / P.OW) % P.OH;
+ c = (global_idx / (P.OW * P.OH)) % P.OC;
+ n = global_idx / (P.OW * P.OH * P.OC);
+ }
+};
+
+template <typename T, typename Layout>
+static __global__ void conv2d_kernel(const float * __restrict__ input,
+ const T * __restrict__ kernel,
+ float * __restrict__ output,
+ const conv_params P) {
+ const int64_t global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+
+ if (global_idx >= P.TOTAL) {
+ return;
+ }
+
+ int64_t n, c_out, out_y, out_x;
+ Layout::unpack_indices(global_idx, P, n, c_out, out_y, out_x);
+
+ T acc = 0;
+
+ for (int64_t c_in = 0; c_in < P.IC; ++c_in) {
+ kernel_bounds bounds = calculate_kernel_bounds(out_x, out_y, P);
+
+ for (int64_t ky = bounds.y_min; ky < bounds.y_max; ++ky) {
+ const int64_t in_y = calculate_input_coord(out_y, ky, P.ST_Y, P.DL_Y, P.PD_Y);
+
+ for (int64_t kx = bounds.x_min; kx < bounds.x_max; ++kx) {
+ const int64_t in_x = calculate_input_coord(out_x, kx, P.ST_X, P.DL_X, P.PD_X);
+
+ T input_val;
+ if (std::is_same<T, half>::value) {
+ input_val = __float2half(input[Layout::input_index(n, c_in, in_y, in_x, P)]);
+ } else {
+ input_val = input[Layout::input_index(n, c_in, in_y, in_x, P)];
+ }
+
+ T kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)];
+ acc += (input_val * kernel_val);
+ }
+ }
+ }
+
+ // [N, OC, OH, OW]
+ output[Layout::output_index(n, c_out, out_y, out_x, P)] = (float) acc;
+}
+
+template <typename T>
+static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
+ const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE;
+ conv2d_kernel<T, whcn_layout><<<blocks, CUDA_CONV2D_BLOCK_SIZE, 0, st>>>(X_D, K_D, Y_D, P);
+}
+
+static void conv2d_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
+ conv2d_cuda<half>(X_D, K_D, Y_D, P, st);
+}
+
+static void conv2d_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
+ conv2d_cuda<float>(X_D, K_D, Y_D, P, st);
+}
+
+void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * kernel = dst->src[0];
+ const ggml_tensor * input = dst->src[1];
+ float * K_D = (float *) kernel->data;
+ const float * X_D = (const float *) input->data;
+ float * Y_D = (float *) dst->data;
+
+ GGML_ASSERT(ggml_is_contiguous(kernel));
+ GGML_ASSERT(kernel->type == GGML_TYPE_F16 || kernel->type == GGML_TYPE_F32);
+
+ // same number of input channels
+ GGML_ASSERT(input->ne[2] == kernel->ne[2]);
+
+ cudaStream_t st = ctx.stream();
+
+ const int32_t * p = (const int32_t *) dst->op_params;
+ const int ST_X = p[0]; // stride_x
+ const int ST_Y = p[1]; // stride_y
+ const int PD_X = p[2]; // padding_x
+ const int PD_Y = p[3]; // padding_y
+ const int DL_X = p[4]; // dilation_x
+ const int DL_Y = p[5]; // dilation_y
+
+ // No cwhn
+ GGML_ASSERT(p[6] == false);
+
+ const int IW = input->ne[0]; // input_w
+ const int IH = input->ne[1]; // input_h
+ const int OW = dst->ne[0]; // output_w
+ const int OH = dst->ne[1]; // output_h
+ const int KW = kernel->ne[0]; // kernel_w
+ const int KH = kernel->ne[1]; // kernel_h
+ const int IC = input->ne[2]; // input_channels
+ const int OC = kernel->ne[3]; // ouptut_chanles
+ const int B = input->ne[3]; // n_batches
+
+ const int64_t total = B * OC * OH * OW;
+ conv_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total };
+
+ if (kernel->type == GGML_TYPE_F16) {
+ conv2d_cuda_f16(X_D, (half *) K_D, Y_D, params, st);
+ } else {
+ conv2d_cuda_f32(X_D, K_D, Y_D, params, st);
+ }
+}