]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Vulkan: add conv_transpose_2d operation (#16022)
authorShin-myoung-serp <redacted>
Mon, 22 Sep 2025 08:04:01 +0000 (17:04 +0900)
committerGitHub <redacted>
Mon, 22 Sep 2025 08:04:01 +0000 (10:04 +0200)
* Vulkan: add conv_transpose_2d operation

* Vulkan: fix typo in conv_transpose_2d shader(s0mp, s0L, s1mp, s1L)

* Vulkan: fix incorrect indentation in conv_transpose_2d shader

* Vulkan: add checking the push constants size limit and reuse conv2d_mm.comp for conv_transpose_2d operation

* Vulkan: revert the order of the index calculation and bound check in conv_2d shader

* Vulkan: explicity check push constants limit in supports_op() for conv_transpose_2d operation.

* Vulkan: remove unnecessary lower bound checks for H/W_idx in the conv_2d shader.

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
tests/test-backend-ops.cpp

index 5818f938e6bbfea4caa746c766d4f0af5b8d4430..0feaf4cb5d7c23383ff74f583a969aff5454bbab 100644 (file)
@@ -574,6 +574,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_opt_step_sgd_f32;
     vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];
     vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
+    vk_pipeline pipeline_conv_transpose_2d_f32[CONV_SHAPE_COUNT];
+    vk_pipeline pipeline_conv_transpose_2d_f16_f32[CONV_SHAPE_COUNT];
     vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32;
     vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32;
 
@@ -1117,6 +1119,56 @@ template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) {
     init_fastdiv_values(p.OW*p.OH,  p.OWOHmp,  p.OWOHL);
 }
 
+struct vk_op_conv_transpose_2d_push_constants {
+    uint32_t Cout;
+    uint32_t Cin;
+    uint32_t N;
+
+    uint32_t KW;
+    uint32_t KH;
+    uint32_t W;
+    uint32_t H;
+    uint32_t OW;
+    uint32_t OH;
+
+    uint32_t s0;
+    uint32_t s1;
+    uint32_t p0;
+    uint32_t p1;
+    uint32_t d0;
+    uint32_t d1;
+
+    uint32_t nb01;
+    uint32_t nb02;
+    uint32_t nb03;
+
+    uint32_t nb11;
+    uint32_t nb12;
+    uint32_t nb13;
+
+    uint32_t nb1;
+    uint32_t nb2;
+    uint32_t nb3;
+
+    // init_fastdiv_values constants for dividing by KW, KW*KH, OW, OW*OH, s0, s1
+    uint32_t KWmp;   uint32_t KWL;
+    uint32_t KWKHmp; uint32_t KWKHL;
+    uint32_t OWmp;   uint32_t OWL;
+    uint32_t OWOHmp; uint32_t OWOHL;
+    uint32_t s0mp; uint32_t s0L;
+    uint32_t s1mp; uint32_t s1L;
+};
+
+template <> void init_pushconst_fastdiv(vk_op_conv_transpose_2d_push_constants &p) {
+    // Compute magic values to divide by KW, KW*KH, OW, OW*OH, s0, s1
+    init_fastdiv_values(p.KW,       p.KWmp,    p.KWL);
+    init_fastdiv_values(p.KW*p.KH,  p.KWKHmp,  p.KWKHL);
+    init_fastdiv_values(p.OW,       p.OWmp,    p.OWL);
+    init_fastdiv_values(p.OW*p.OH,  p.OWOHmp,  p.OWOHL);
+    init_fastdiv_values(p.s0,       p.s0mp,    p.s0L);
+    init_fastdiv_values(p.s1,       p.s1mp,    p.s1L);
+}
+
 struct vk_op_conv2d_dw_push_constants {
     uint32_t ne;
     uint32_t batches;
@@ -1322,7 +1374,7 @@ class vk_perf_logger {
             flops[name].push_back(m * n * (k + (k - 1)) * batch);
             return;
         }
-        if (node->op == GGML_OP_CONV_2D) {
+        if (node->op == GGML_OP_CONV_2D || node->op == GGML_OP_CONV_TRANSPOSE_2D) {
             std::string   name    = ggml_op_name(node->op);
             ggml_tensor * knl     = node->src[0];
             uint64_t      OW      = node->ne[0];
@@ -1331,7 +1383,7 @@ class vk_perf_logger {
             uint64_t      Cout    = node->ne[2];
             uint64_t      KW      = knl->ne[0];
             uint64_t      KH      = knl->ne[1];
-            uint64_t      Cin     = knl->ne[2];
+            uint64_t      Cin     = node->src[1]->ne[2];
             // KxCRS @ CRSxNPQ = KxNPQ -> M=K, K=CRS, N=NPQ
             uint64_t      size_M  = Cout;
             uint64_t      size_K  = Cin * KW * KH;
@@ -3492,7 +3544,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
-    // conv2d
+    // conv2d, conv_transpose_2d
     for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
         uint32_t conv2d_WG_SIZE  = 256;
         uint32_t conv2d_BS_K     = 128;
@@ -3567,31 +3619,30 @@ static void ggml_vk_load_shaders(vk_device& device) {
         std::array<uint32_t, 3> wg_denoms = { conv2d_BS_K, conv2d_BS_NPQ, 1 };
         std::vector<uint32_t> spec_constants = { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD };
 
+#define CREATE_CONV(name, type_suffix, spv_suffix) \
+        ggml_vk_create_pipeline( \
+            device, device->pipeline_##name##type_suffix[s], #name #type_suffix, \
+            name##type_suffix##spv_suffix##_len, name##type_suffix##spv_suffix##_data, "main", 3, \
+            sizeof(vk_op_##name##_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
+#define CREATE_CONVS(spv_suffix) \
+        CREATE_CONV(conv2d, _f32, spv_suffix) \
+        CREATE_CONV(conv2d, _f16_f32, spv_suffix) \
+        if (device->properties.limits.maxPushConstantsSize >= sizeof(vk_op_conv_transpose_2d_push_constants)) { \
+            CREATE_CONV(conv_transpose_2d, _f32, spv_suffix) \
+            CREATE_CONV(conv_transpose_2d, _f16_f32, spv_suffix) \
+        }
 #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
         if (device->coopmat2) {
-            ggml_vk_create_pipeline(
-                device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_cm2_len, conv2d_f32_cm2_data, "main", 3,
-                sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
-            ggml_vk_create_pipeline(
-                device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_cm2_len, conv2d_f16_f32_cm2_data, "main", 3,
-                sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
+            CREATE_CONVS(_cm2)
         } else
 #endif
         if (conv2d_UNROLL) {
-            ggml_vk_create_pipeline(
-                device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_unroll_len, conv2d_f32_unroll_data, "main", 3,
-                sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
-            ggml_vk_create_pipeline(
-                device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_unroll_len, conv2d_f16_f32_unroll_data, "main", 3,
-                sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
+            CREATE_CONVS(_unroll)
         } else {
-            ggml_vk_create_pipeline(
-                device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
-                sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
-            ggml_vk_create_pipeline(
-                device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3,
-                sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
+            CREATE_CONVS( )
         }
+#undef CREATE_CONV
+#undef CREATE_CONVS
     }
 
     ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
@@ -7548,6 +7599,33 @@ static std::array<uint32_t, 3> ggml_vk_get_conv_elements(const ggml_tensor *dst)
     return elements;
 }
 
+static std::array<uint32_t, 3> ggml_vk_get_conv_transpose_2d_elements(const ggml_tensor *dst) {
+    const ggml_tensor *src0 = dst->src[0];
+    const ggml_tensor *src1 = dst->src[1];
+
+    // src0 - kernel:   [KW, KH, Cout, Cin]
+    // src1 - input:    [W, H, Cin, N]
+    // dst - result:    [OW, OH, Cout, N]
+
+    auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
+        return (ins - 1) * s - 2 * p + (ks - 1) * d + 1;
+    };
+    // parallelize in {OW/BS_K, OH/BS_NPQ, 1}
+    int64_t W    = src1->ne[0];
+    int64_t H    = src1->ne[1];
+    int64_t KW   = src0->ne[0];
+    int64_t KH   = src0->ne[1];
+    int64_t Cout = src0->ne[2];
+    int64_t N    = src1->ne[3];
+    int64_t OH   = calc_conv_output_size(H, KH, dst->op_params[0], 0, 1);
+    int64_t OW   = calc_conv_output_size(W, KW, dst->op_params[0], 0, 1);
+    int64_t NPQ  = N * OW * OH;
+
+    // Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups
+    std::array<uint32_t, 3> elements = { static_cast<uint32_t>(Cout), static_cast<uint32_t>(NPQ), 1 };
+    return elements;
+}
+
 static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * dst, ggml_op op) {
     switch (op) {
     case GGML_OP_GET_ROWS:
@@ -7925,9 +8003,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         }
         return nullptr;
     case GGML_OP_CONV_2D:
+    case GGML_OP_CONV_TRANSPOSE_2D:
         if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
             ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
-            auto elements = ggml_vk_get_conv_elements(dst);
+            std::array<uint32_t, 3> elements;
+            if (op == GGML_OP_CONV_2D) elements = ggml_vk_get_conv_elements(dst);
+            else if (op == GGML_OP_CONV_TRANSPOSE_2D) elements = ggml_vk_get_conv_transpose_2d_elements(dst);
             vk_conv_shapes shape;
 
             uint32_t tiles[CONV_SHAPE_COUNT];
@@ -7947,10 +8028,18 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
                 shape = CONV_SHAPE_64x32;
             }
 
-            if (src0->type == GGML_TYPE_F32) {
-                return ctx->device->pipeline_conv2d_f32[shape];
-            } else if (src0->type == GGML_TYPE_F16) {
-                return ctx->device->pipeline_conv2d_f16_f32[shape];
+            if (op == GGML_OP_CONV_2D) {
+                if (src0->type == GGML_TYPE_F32) {
+                    return ctx->device->pipeline_conv2d_f32[shape];
+                } else if (src0->type == GGML_TYPE_F16) {
+                    return ctx->device->pipeline_conv2d_f16_f32[shape];
+                }
+            } else if (op == GGML_OP_CONV_TRANSPOSE_2D) {
+                if (src0->type == GGML_TYPE_F32) {
+                    return ctx->device->pipeline_conv_transpose_2d_f32[shape];
+                } else if (src0->type == GGML_TYPE_F16) {
+                    return ctx->device->pipeline_conv_transpose_2d_f16_f32[shape];
+                }
             }
         }
         return nullptr;
@@ -8350,6 +8439,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         {
             elements = ggml_vk_get_conv_elements(dst);
         } break;
+    case GGML_OP_CONV_TRANSPOSE_2D:
+        {
+            elements = ggml_vk_get_conv_transpose_2d_elements(dst);
+        } break;
     case GGML_OP_ADD:
     case GGML_OP_SUB:
     case GGML_OP_DIV:
@@ -9523,6 +9616,55 @@ static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx,
     ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D, std::move(p), dryrun);
 }
 
+static void ggml_vk_conv_transpose_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
+                                      const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb10 == sizeof(float));
+    GGML_ASSERT(nb0 == sizeof(float));
+
+    vk_op_conv_transpose_2d_push_constants p{};
+    p.Cout = static_cast<uint32_t>(ne02);
+    p.Cin  = static_cast<uint32_t>(ne03);
+    p.N    = static_cast<uint32_t>(ne13);
+
+    p.KW = static_cast<uint32_t>(ne00);
+    p.KH = static_cast<uint32_t>(ne01);
+    p.W  = static_cast<uint32_t>(ne10);
+    p.H  = static_cast<uint32_t>(ne11);
+    p.OW = static_cast<uint32_t>(ne0);
+    p.OH = static_cast<uint32_t>(ne1);
+
+    p.s0 = static_cast<uint32_t>(dst->op_params[0]);
+    p.s1 = static_cast<uint32_t>(dst->op_params[0]);
+    p.p0 = 0;
+    p.p1 = 0;
+    p.d0 = 1;
+    p.d1 = 1;
+
+    p.nb01 = static_cast<uint32_t>(nb01 / nb00);
+    p.nb02 = static_cast<uint32_t>(nb02 / nb00);
+    p.nb03 = static_cast<uint32_t>(nb03 / nb00);
+
+    p.nb11 = static_cast<uint32_t>(nb11 / nb10);
+    p.nb12 = static_cast<uint32_t>(nb12 / nb10);
+    p.nb13 = static_cast<uint32_t>(nb13 / nb10);
+
+    p.nb1 = static_cast<uint32_t>(nb1 / nb0);
+    p.nb2 = static_cast<uint32_t>(nb2 / nb0);
+    p.nb3 = static_cast<uint32_t>(nb3 / nb0);
+
+    GGML_ASSERT(ne02 == ne2);
+    GGML_ASSERT(ne03 == ne12);
+
+    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_TRANSPOSE_2D, std::move(p), dryrun);
+}
+
 static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     vk_op_conv2d_dw_push_constants p{};
     p.ne = ggml_nelements(dst);
@@ -10615,6 +10757,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_CONV_TRANSPOSE_1D:
     case GGML_OP_POOL_2D:
     case GGML_OP_CONV_2D:
+    case GGML_OP_CONV_TRANSPOSE_2D:
     case GGML_OP_CONV_2D_DW:
     case GGML_OP_RWKV_WKV6:
     case GGML_OP_RWKV_WKV7:
@@ -10686,6 +10829,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         case GGML_OP_CONV_TRANSPOSE_1D:
         case GGML_OP_POOL_2D:
         case GGML_OP_CONV_2D:
+        case GGML_OP_CONV_TRANSPOSE_2D:
         case GGML_OP_CONV_2D_DW:
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_OPT_STEP_SGD:
@@ -10997,6 +11141,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_CONV_2D:
         ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node, dryrun);
 
+        break;
+    case GGML_OP_CONV_TRANSPOSE_2D:
+        ggml_vk_conv_transpose_2d(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_CONV_2D_DW:
         ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -11137,6 +11285,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
     case GGML_OP_CONV_TRANSPOSE_1D:
     case GGML_OP_POOL_2D:
     case GGML_OP_CONV_2D:
+    case GGML_OP_CONV_TRANSPOSE_2D:
     case GGML_OP_CONV_2D_DW:
     case GGML_OP_RWKV_WKV6:
     case GGML_OP_RWKV_WKV7:
@@ -11794,10 +11943,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
         ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
         if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
             total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
-        } else if (cgraph->nodes[i]->op == GGML_OP_CONV_2D) {
+        } else if (cgraph->nodes[i]->op == GGML_OP_CONV_2D || cgraph->nodes[i]->op == GGML_OP_CONV_TRANSPOSE_2D) {
             // Return CRSxNPQxsizeof(*) to account as many bytes as mul_mat has in im2col->mul_mat mode.
             auto CRS_size =
-                cgraph->nodes[i]->src[0]->ne[0] * cgraph->nodes[i]->src[0]->ne[1] * cgraph->nodes[i]->src[0]->ne[2];
+                cgraph->nodes[i]->src[0]->ne[0] * cgraph->nodes[i]->src[0]->ne[1] * cgraph->nodes[i]->src[1]->ne[2];
             auto NPQ_size = cgraph->nodes[i]->ne[0] * cgraph->nodes[i]->ne[1] * cgraph->nodes[i]->ne[3];
             total_mat_mul_bytes += NPQ_size * CRS_size * ggml_type_size(cgraph->nodes[i]->type);
         }
@@ -12618,10 +12767,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_CONV_TRANSPOSE_1D:
             return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
         case GGML_OP_CONV_2D:
+        case GGML_OP_CONV_TRANSPOSE_2D:
             {
                 // Op is disabled for Apple because it segfaults at pipeline create time on MoltenVK
                 ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
                 const vk_device& device = ggml_vk_get_device(ctx->device);
+                if (op->op == GGML_OP_CONV_TRANSPOSE_2D &&
+                    device->properties.limits.maxPushConstantsSize < sizeof(vk_op_conv_transpose_2d_push_constants)) {
+                    return false;
+                }
                 // Channel-contiguous format is not supported yet.
                 return ((op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
                     op->src[1]->type == GGML_TYPE_F32 &&
@@ -13240,6 +13394,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
         const int32_t d0 = tensor->op_params[4];
         const int32_t d1 = tensor->op_params[5];
         tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1);
+    } else if (tensor->op == GGML_OP_CONV_TRANSPOSE_2D) {
+        const int32_t s = tensor->op_params[0];
+        tensor_clone = ggml_conv_transpose_2d_p0(ggml_ctx, src_clone[0], src_clone[1], s);
     } else if (tensor->op == GGML_OP_LEAKY_RELU) {
         const float * op_params = (const float *)tensor->op_params;
         tensor_clone = ggml_leaky_relu(ggml_ctx, src_clone[0], op_params[0], false);
index 86bafba4a4398cc0d1446f3f63368ff4d3126df8..44a64ddc80f62c9724dc71873b6a76a1cb340a44 100644 (file)
@@ -16,7 +16,7 @@
 // shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j
 layout(binding = 0) readonly buffer A {
     A_TYPE knl_data[];
-};  // src0 - kernel:   [KW, KH, Cin, Cout]
+};  // src0 - kernel:   [KW, KH, Cin, Cout] for conv_2d, [KW, KH, Cout, Cin] for conv_transposed_2d
 
 layout(binding = 1) readonly buffer B {
     B_TYPE src_data[];
@@ -66,6 +66,10 @@ layout(push_constant) uniform parameter {
     uint32_t KWKHmp; uint32_t KWKHL;
     uint32_t OWmp;   uint32_t OWL;
     uint32_t OWOHmp; uint32_t OWOHL;
+#ifdef TRANSPOSE
+    uint32_t s0mp; uint32_t s0L;
+    uint32_t s1mp; uint32_t s1L;
+#endif
 }
 
 p;
@@ -225,7 +229,11 @@ void main() {
             uint32_t B_ly    = r_offset + Ar;
             uint32_t B_lx    = Ac;
             uint32_t K_idx   = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/
+#ifdef TRANSPOSE
+            uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + K_idx * p.nb02 + Cin_idx_a * p.nb03, K * CRS - 1);
+#else
             uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03, K * CRS - 1);
+#endif
             float    val     = knl_data[knl_idx];
             if (K_idx >= K || CRS_idx_a >= CRS) {
                 val = 0.0;
@@ -267,12 +275,24 @@ void main() {
             KW_idx_b               = CRS_remainder - KH_idx_b * p.KW;
 #endif
 
+#ifdef TRANSPOSE
+            uint32_t H_idx_x_s1 = OH_idx - KH_idx_b * p.d1 + p.p1;
+            uint32_t W_idx_x_s0 = OW_idx - KW_idx_b * p.d0 + p.p0;
+            uint32_t H_idx = fastdiv(H_idx_x_s1, p.s1mp, p.s1L);
+            uint32_t W_idx = fastdiv(W_idx_x_s0, p.s0mp, p.s0L);
+#else
             uint32_t H_idx = OH_idx * p.s1 + KH_idx_b * p.d1 - p.p1;
             uint32_t W_idx = OW_idx * p.s0 + KW_idx_b * p.d0 - p.p0;
+#endif
             uint32_t src_idx =
                 min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1);
             float val = src_data[src_idx];
-            if (CRS_idx_b >= CRS || NPQ_idx >= NPQ || H_idx < 0 || H_idx >= p.H || W_idx < 0 || W_idx >= p.W) {
+            if (CRS_idx_b >= CRS || NPQ_idx >= NPQ
+                || H_idx >= p.H || W_idx >= p.W // Lower bound checks aren't necessary. (idx >= 0x80000000 for such case)
+#ifdef TRANSPOSE
+                || (H_idx_x_s1 - H_idx * p.s1 != 0) || (W_idx_x_s0 - W_idx * p.s0 != 0)
+#endif
+                ) {
                 val = 0.0;
             }
             Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val);
index 86c873cc4af7df4284133a3be875ca63cc5d8d95..2531610e47b8980b2939561a488a8093d18483c6 100644 (file)
@@ -796,16 +796,26 @@ void process_shaders() {
     string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
     string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
 
-    string_to_spv("conv2d_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}});
-    string_to_spv("conv2d_f16_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}});
-
-    string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", ""}});
-    string_to_spv("conv2d_f16_f32", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", ""}});
-
+    for (auto transpose : {false, true}) {
+        for (auto unroll : {false, true}) {
+            for (auto a_f16 : {false, true}) {
+                std::map<std::string, std::string> defines = {
+                    {"A_TYPE", a_f16 ? "float16_t" : "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"},
+                    {"USE_COLLECTIVES", "1"}, {"UNROLL", unroll ? "[[unroll]]" : ""},
+                };
+                if (transpose) defines["TRANSPOSE"] = "1";
+                std::string name = std::string(transpose ? "conv_transpose_2d": "conv2d")
+                    + (a_f16 ? "_f16" : "") + "_f32";
+                string_to_spv(name + (unroll ? "_unroll" : ""), "conv2d_mm.comp", defines);
 #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
-    string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}, {"COOPMAT2", "1"}}, true, false, true);
-    string_to_spv("conv2d_f16_f32", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}, {"COOPMAT2", "1"}}, true, false, true);
+                if (unroll) {
+                    defines["COOPMAT2"] = "1";
+                    string_to_spv(name, "conv2d_mm.comp", defines, true, false, true);
+                }
 #endif
+            }
+        }
+    }
 
     string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
     string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
index ef6594ea5902b22c6794d3db1026e69ad09d2814..f11eecd8e71a5f86ec72141d83435725692555ca 100644 (file)
@@ -3969,6 +3969,10 @@ struct test_conv_transpose_2d : public test_case {
         return VARS_TO_STR3(ne_input, ne_kernel, stride);
     }
 
+    double max_nmse_err() override {
+        return 5e-4; // The default 1e-7 is too small for Vulkan.
+    }
+
     test_conv_transpose_2d(std::array<int64_t, 4> ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1]
                            std::array<int64_t, 4> ne_kernel = {3, 3, 3, 1}, // [kernel_width, kernel_height, input_channels, 1]
                            int stride = 1)