vk_pipeline pipeline_rwkv_wkv6_f32;
vk_pipeline pipeline_rwkv_wkv7_f32;
vk_pipeline pipeline_opt_step_adamw_f32;
+ vk_pipeline pipeline_conv2d_dw_whcn_f32;
+ vk_pipeline pipeline_conv2d_dw_cwhn_f32;
// [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned,aligned}
vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
uint32_t H;
};
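+// push constants for GGML_OP_CONV_2D_DW; field order and types must match
+// the push_constant block in conv2d_dw.comp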
+struct vk_op_conv2d_dw_push_constants {
+ uint32_t ne;
+ uint32_t batches;
+ uint32_t channels;
+ uint32_t dst_w;
+ uint32_t dst_h;
+ uint32_t src_w;
+ uint32_t src_h;
+ uint32_t knl_w;
+ uint32_t knl_h;
+ int32_t stride_x;
+ int32_t stride_y;
+ int32_t pad_x;
+ int32_t pad_y;
+ int32_t dilation_x;
+ int32_t dilation_y;
+};
+
struct vk_op_upscale_push_constants {
uint32_t ne; uint32_t a_offset; uint32_t d_offset;
uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
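+ // 3 storage buffer bindings (kernel, source, destination), 512 invocations per workgroup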
+ ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
+
for (auto &c : compiles) {
c.wait();
}
return ctx->device->pipeline_leaky_relu_f32;
}
return nullptr;
+ case GGML_OP_CONV_2D_DW:
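+ // src1 contiguous in the standard layout selects the WHCN shader;
+ // channel-contiguous (permuted) src1 selects the CWHN shader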
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ if (ggml_is_contiguous(src1)) {
+ return ctx->device->pipeline_conv2d_dw_whcn_f32;
+ } else if (ggml_is_contiguous_channels(src1)) {
+ return ctx->device->pipeline_conv2d_dw_cwhn_f32;
+ }
+ }
+ return nullptr;
default:
return nullptr;
}
case GGML_OP_REPEAT_BACK:
case GGML_OP_ROPE:
case GGML_OP_RMS_NORM:
+ case GGML_OP_CONV_2D_DW:
return true;
default:
return false;
case GGML_OP_CONCAT:
case GGML_OP_UPSCALE:
case GGML_OP_UNARY:
+ case GGML_OP_CONV_2D_DW:
{
const uint32_t ne = ggml_nelements(dst);
if (ne > 262144) {
}, dryrun);
}
+static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+ vk_op_conv2d_dw_push_constants p{};
+ p.ne = ggml_nelements(dst);
+ p.channels = dst->ne[2];
+ p.batches = dst->ne[3];
+ p.dst_w = dst->ne[0];
+ p.dst_h = dst->ne[1];
+ p.src_w = src1->ne[0];
+ p.src_h = src1->ne[1];
+ p.knl_w = src0->ne[0];
+ p.knl_h = src0->ne[1];
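+ // GGML_OP_CONV_2D_DW op_params: [stride_x, stride_y, pad_x, pad_y, dilation_x, dilation_y]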
+ p.stride_x = dst->op_params[0];
+ p.stride_y = dst->op_params[1];
+ p.pad_x = dst->op_params[2];
+ p.pad_y = dst->op_params[3];
+ p.dilation_x = dst->op_params[4];
+ p.dilation_y = dst->op_params[5];
+
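+ // depthwise: kernel is expected as (knl_w, knl_h, 1, channels), src1 as (src_w, src_h, channels, batches)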
+ GGML_ASSERT(src0->ne[3] == p.channels);
+ GGML_ASSERT(src1->ne[3] == p.batches);
+
+ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D_DW, std::move(p), dryrun);
+}
+
static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
const float * op_params = (const float *)dst->op_params;
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun);
case GGML_OP_IM2COL:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_POOL_2D:
+ case GGML_OP_CONV_2D_DW:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
case GGML_OP_LEAKY_RELU:
case GGML_OP_IM2COL:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_POOL_2D:
+ case GGML_OP_CONV_2D_DW:
case GGML_OP_LEAKY_RELU:
{
// These operations all go through ggml_vk_op_f32, so short-circuit and
case GGML_OP_POOL_2D:
ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun);
+ break;
+ case GGML_OP_CONV_2D_DW:
+ ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node, dryrun);
break;
case GGML_OP_LEAKY_RELU:
ggml_vk_leaky_relu(ctx, compute_ctx, src0, node, dryrun);
case GGML_OP_IM2COL:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_POOL_2D:
+ case GGML_OP_CONV_2D_DW:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
case GGML_OP_LEAKY_RELU:
case GGML_OP_COUNT_EQUAL:
case GGML_OP_IM2COL:
case GGML_OP_TIMESTEP_EMBEDDING:
+ case GGML_OP_CONV_2D_DW:
case GGML_OP_POOL_2D:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
--- /dev/null
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp
+#version 450
+
+#include "types.comp"
+
+layout (push_constant) uniform parameter
+{
+ uint ne;
+ uint batches;
+ uint channels;
+ uint dst_w;
+ uint dst_h;
+ uint src_w;
+ uint src_h;
+ uint knl_w;
+ uint knl_h;
+ int stride_x;
+ int stride_y;
+ int pad_x;
+ int pad_y;
+ int dilation_x;
+ int dilation_y;
+} p;
+
+layout (binding = 0) readonly buffer A {A_TYPE knl_data[];};
+layout (binding = 1) readonly buffer B {B_TYPE src_data[];};
+layout (binding = 2) writeonly buffer D {D_TYPE dst_data[];};
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
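+// WHCN layout: width is the fastest-moving dimension, so the flat index
+// decomposes into (dst_x, dst_y, c, n) from fastest to slowest.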
+FLOAT_TYPE conv_2d_dw_whcn(uint idx) {
+ uint i0 = idx / p.dst_w;
+ uint dst_x = idx - i0 * p.dst_w;
+ uint i1 = i0 / p.dst_h;
+ uint dst_y = i0 - i1 * p.dst_h;
+ uint n = i1 / p.channels;
+ uint c = i1 - n * p.channels;
+
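+ // base offsets of the (n, c) source plane and the kernel plane for channel c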
+ uint src_i = n * p.channels * p.src_h * p.src_w + c * p.src_h * p.src_w;
+ uint knl_i = c * p.knl_h * p.knl_w;
+
+ FLOAT_TYPE sum = 0.0;
+ for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) {
+ uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
+ if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int
+ continue;
+ }
+ for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) {
+ uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
+ if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int
+ continue;
+ }
+ FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * p.src_w + src_x]);
+ FLOAT_TYPE k = FLOAT_TYPE(knl_data[knl_i + knl_y * p.knl_w + knl_x]);
+ sum = fma(v, k, sum);
+ }
+ }
+ return sum;
+}
+
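+// CWHN layout: channels are the fastest-moving dimension, so consecutive
+// invocations read consecutive channels of the same pixel.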
+FLOAT_TYPE conv_2d_dw_cwhn(uint idx) {
+ uint i0 = idx / p.channels;
+ uint c = idx - i0 * p.channels;
+ uint i1 = i0 / p.dst_w;
+ uint dst_x = i0 - i1 * p.dst_w;
+ uint n = i1 / p.dst_h;
+ uint dst_y = i1 - n * p.dst_h;
+
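+ // batch offset and row strides for the channel-interleaved layout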
+ uint src_i = n * p.channels * p.src_h * p.src_w;
+ uint src_row = p.src_w * p.channels;
+ uint knl_row = p.knl_w * p.channels;
+
+ FLOAT_TYPE sum = 0.0;
+ for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) {
+ uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
+ if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int
+ continue;
+ }
+ for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) {
+ uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
+ if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int
+ continue;
+ }
+ FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * src_row + src_x * p.channels + c]);
+ FLOAT_TYPE k = FLOAT_TYPE(knl_data[ knl_y * knl_row + knl_x * p.channels + c]);
+ sum = fma(v, k, sum);
+ }
+ }
+ return sum;
+}
+
+void main() {
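+ // reconstruct the flat element index: when ne > 262144 the host dispatches
+ // {512, 512, ceil(ne / 262144)} invocations, so x and y both stay below 512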
+ uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+ if (idx >= p.ne) {
+ return;
+ }
+
+ FLOAT_TYPE result =
+#ifdef WHCN
+ conv_2d_dw_whcn(idx);
+#else
+ conv_2d_dw_cwhn(idx);
+#endif
+ dst_data[idx] = D_TYPE(result);
+}
+
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
+ string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
+ string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
+
for (auto &c : compiles) {
c.wait();
}
-f3a375f20bf56860b30e7c511d03593a1e393345
+0482de9c63b9134eb462c7732888c0ee0dbc2755
}
};
+// GGML_OP_CONV_2D_DW
+struct test_conv_2d_dw : public test_case {
+ const std::array<int64_t, 4> ne_input;
+ const std::array<int64_t, 4> ne_kernel;
+ const int stride;
+ const int padding;
+ const int dilation;
+ const bool cwhn;
+
+ std::string vars() override {
+ return VARS_TO_STR6(ne_input, ne_kernel, stride, padding, dilation, cwhn);
+ }
+
+ test_conv_2d_dw(std::array<int64_t, 4> ne_input = {64, 64, 16, 1},
+ std::array<int64_t, 4> ne_kernel = {3, 3, 1, 16},
+ int stride = 1, int padding = 0, int dilation = 1, bool cwhn = false)
+ : ne_input(ne_input), ne_kernel(ne_kernel), stride(stride), padding(padding), dilation(dilation), cwhn(cwhn) {}
+
+ ggml_tensor * build_graph(ggml_context * ctx) override {
+ ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());
+ ggml_set_name(input, "input");
+
+ ggml_tensor * kernel = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_kernel.data());
+ ggml_set_name(kernel, "kernel");
+
+ if (cwhn) {
+ // change memory layout to channel-most-contiguous (CWHN),
+ // then permute it back so NE matches the original input
+ input = ggml_cont(ctx, ggml_permute(ctx, input, 1, 2, 0, 3));
+ input = ggml_permute(ctx, input, 2, 0, 1, 3);
+ kernel = ggml_cont(ctx, ggml_permute(ctx, kernel, 2, 3, 1, 0));
+ kernel = ggml_permute(ctx, kernel, 3, 2, 0, 1);
+ }
+
+ ggml_tensor * out = ggml_conv_2d_dw_direct(
+ ctx, kernel, input,
+ stride, stride, padding, padding, dilation, dilation);
+ ggml_set_name(out, "out");
+ return out;
+ }
+};
+
// GGML_OP_CONCAT
struct test_concat : public test_case {
const ggml_type type;
// test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {1024, 1024, 256, 1}, {3, 3, 256, 1}, 1, 1, 1, 1, 1, 1, true));
// test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {1024, 1024, 256, 1}, {3, 3, 256, 1}, 1, 1, 1, 1, 1, 1, true));
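+ // depthwise conv: (W, H, C, N) input with a (KW, KH, 1, C) kernel, in both WHCN and CWHN layouts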
+ test_cases.emplace_back(new test_conv_2d_dw({17, 34, 9, 1}, {3, 3, 1, 9}, 1, 0, 1, false));
+ test_cases.emplace_back(new test_conv_2d_dw({17, 34, 9, 1}, {3, 3, 1, 9}, 1, 0, 1, true));
+ test_cases.emplace_back(new test_conv_2d_dw({32, 8, 64, 1}, {3, 3, 1, 64}, 2, 1, 1, false));
+ test_cases.emplace_back(new test_conv_2d_dw({32, 8, 64, 1}, {3, 3, 1, 64}, 2, 1, 1, true));
+
test_cases.emplace_back(new test_conv_transpose_1d());
test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 3, 0, 1));
test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 2, 0, 1));
}
}
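+ // large depthwise conv cases: 512x512 input, 256 channels, WHCN and CWHN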
+ test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, false));
+ test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, true));
+
return test_cases;
}