]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml: adds CONV_2D op and direct GEMM Vulkan implementation (#14316)
authorErvin Áron Tasnádi <redacted>
Sat, 19 Jul 2025 19:59:08 +0000 (21:59 +0200)
committerGitHub <redacted>
Sat, 19 Jul 2025 19:59:08 +0000 (21:59 +0200)
* ggml/ggml-vulkan/test-backend-ops: adds CONV_2D for Vulkan

* ggml-vulkan: adds f32 scalar shader to compute 2D convolution directly
with gemm (no need for im2col),

* test-backend-ops: adds test_case_ref to check the validity/performance of ops
against reference implementations having different graphs, adds tests

* * Performance fixes: minimized branch divergence, uses collectives to
  eliminate redundant calculation, macros removed.

* Kernel shared memory size check

* Updates test-backend-ops to support graphs for performance
  measurement.

* * Apple/Win32 compile errors fixed

* Subgroup size used to determine tile size -> fixes llvmpipe errors.

* Collectives disabled by default.

* Intel support is disabled as the performance is poor.

* Conv2d enabled for Intel with disabled collectives, disabled for Apple

* test-backend-ops modifications are reverted

* Trailing spaces and missing override fixed.

* Triggering pipeline relaunch.

* Code formatted with .clang-format.

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp [new file with mode: 0644]
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
tests/test-backend-ops.cpp

index 0707d71bb6c55a407febb84243672ff08bd52a0b..c3f1369b66315eaef2a7ca656f35453ba1efcc68 100644 (file)
@@ -483,6 +483,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_rwkv_wkv6_f32;
     vk_pipeline pipeline_rwkv_wkv7_f32;
     vk_pipeline pipeline_opt_step_adamw_f32;
+    vk_pipeline pipeline_conv2d_f32;
     vk_pipeline pipeline_conv2d_dw_whcn_f32;
     vk_pipeline pipeline_conv2d_dw_cwhn_f32;
 
@@ -876,6 +877,38 @@ struct vk_op_rwkv_wkv7_push_constants {
     uint32_t H;
 };
 
+struct vk_op_conv2d_push_constants {
+    uint32_t Cout;
+    uint32_t Cin;
+    uint32_t N;
+
+    uint32_t KW;
+    uint32_t KH;
+    uint32_t W;
+    uint32_t H;
+    uint32_t OW;
+    uint32_t OH;
+
+    uint32_t s0;
+    uint32_t s1;
+    uint32_t p0;
+    uint32_t p1;
+    uint32_t d0;
+    uint32_t d1;
+
+    uint32_t nb01;
+    uint32_t nb02;
+    uint32_t nb03;
+
+    uint32_t nb11;
+    uint32_t nb12;
+    uint32_t nb13;
+
+    uint32_t nb1;
+    uint32_t nb2;
+    uint32_t nb3;
+};
+
 struct vk_op_conv2d_dw_push_constants {
     uint32_t ne;
     uint32_t batches;
@@ -975,18 +1008,45 @@ private:
 #endif // GGML_VULKAN_MEMORY_DEBUG
 
 class vk_perf_logger {
-public:
+  public:
     void print_timings() {
+        if (timings.empty()) {
+            return;
+        }
+        uint64_t total_all_op_times = 0;
         std::cerr << "----------------\nVulkan Timings:" << std::endl;
-        for (const auto& t : timings) {
-            uint64_t total = 0;
-            for (const auto& time : t.second) {
-                total += time;
+        for (const auto & t : timings) {
+            uint64_t total_op_times = 0;
+            for (const auto & time : t.second) {
+                total_op_times += time;
+            }
+            std::cerr << t.first << ": " << t.second.size() << " x " << (total_op_times / t.second.size() / 1000.0)
+                      << " us";
+
+            // If we have as many flops entries as timing entries for the op, then compute and log the flops/S.
+            auto it = flops.find(t.first);
+            if (it != flops.end() && (it->second).size() == t.second.size()) {
+                uint64_t total_op_flops = 0;
+                for (const auto & elem : it->second) {
+                    total_op_flops += elem;
+                }
+                std::cerr << " ("
+                          << (double(total_op_flops) / (1000.0 * 1000.0 * 1000.0)) /
+                                 (double(total_op_times) / (1000.0 * 1000.0 * 1000.0))
+                          << " GFLOPS/s)";
             }
-            std::cerr << t.first << ": " << t.second.size() << " x " << (total / t.second.size() / 1000.0) << " us" << std::endl;
+
+            total_all_op_times += total_op_times;
+
+            std::cerr << std::endl;
+        }
+
+        if (timings.size() > 0) {
+            std::cerr << "Total time: " << total_all_op_times / 1000.0 << " us." << std::endl;
         }
 
         timings.clear();
+        flops.clear();
     }
 
     void log_timing(const ggml_tensor * node, uint64_t time) {
@@ -995,22 +1055,45 @@ public:
             return;
         }
         if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
-            const uint64_t m = node->src[0]->ne[1];
-            const uint64_t n = node->src[1]->ne[1];
-            const uint64_t k = node->src[1]->ne[0];
-            std::string name = ggml_op_name(node->op);
+            const uint64_t m    = node->src[0]->ne[1];
+            const uint64_t n    = node->src[1]->ne[1];
+            const uint64_t k    = node->src[1]->ne[0];
+            std::string    name = ggml_op_name(node->op);
             if (n == 1) {
                 name += "_VEC m=" + std::to_string(m) + " k=" + std::to_string(k);
             } else {
                 name += " m=" + std::to_string(m) + " n=" + std::to_string(n) + " k=" + std::to_string(k);
             }
             timings[name].push_back(time);
+            flops[name].push_back(m * n * (k + (k - 1)));
+            return;
+        }
+        if (node->op == GGML_OP_CONV_2D) {
+            std::string   name    = ggml_op_name(node->op);
+            ggml_tensor * knl     = node->src[0];
+            uint64_t      OW      = node->ne[0];
+            uint64_t      OH      = node->ne[1];
+            uint64_t      N       = node->ne[3];
+            uint64_t      Cout    = node->ne[2];
+            uint64_t      KW      = knl->ne[0];
+            uint64_t      KH      = knl->ne[1];
+            uint64_t      Cin     = knl->ne[2];
+            // KxCRS @ CRSxNPQ = KxNPQ -> M=K, K=CRS, N=NPQ
+            uint64_t      size_M  = Cout;
+            uint64_t      size_K  = Cin * KW * KH;
+            uint64_t      size_N  = N * OW * OH;
+            uint64_t      n_flops = size_M * size_N * (size_K + (size_K - 1));
+            name += " M=Cout=" + std::to_string(size_M) + ", K=Cin*KW*KH=" + std::to_string(size_K) +
+                    ", N=N*OW*OH=" + std::to_string(size_N);
+            flops[name].push_back(n_flops);
+            timings[name].push_back(time);
             return;
         }
         timings[ggml_op_name(node->op)].push_back(time);
     }
-private:
+  private:
     std::map<std::string, std::vector<uint64_t>> timings;
+    std::map<std::string, std::vector<uint64_t>> flops;
 };
 
 struct ggml_backend_vk_context {
@@ -2113,6 +2196,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
             }
             compile_count++;
         }
+
         compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint,
                                       parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
     };
@@ -2962,6 +3046,42 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
+    // conv2d
+    uint32_t conv2d_WG_SIZE  = 256;
+    uint32_t conv2d_BS_K     = 128;
+    uint32_t conv2d_BS_CRS   = 16;
+    uint32_t use_collectives = 0;  // Enables subgroup ops for preventing the re-calculation of indices.
+    if (device->subgroup_shuffle &&
+        device->vendor_id != VK_VENDOR_ID_INTEL) {  // Do not enable collectives on Intel, see PR 14316
+        use_collectives = 1;
+        conv2d_BS_CRS   = std::min(
+            device->subgroup_size,
+            conv2d_BS_CRS);  // CRS block size should be capped at sugroup size for correctness when shuffle is used.
+    }
+    uint32_t conv2d_BS_NPQ = 128;
+    uint32_t conv2d_TS_K   = 8;
+    uint32_t conv2d_shmem_req =
+        (conv2d_BS_K * (conv2d_BS_CRS + 1) + conv2d_BS_CRS * (conv2d_BS_NPQ + 1)) * sizeof(float);
+    if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) {
+        conv2d_BS_CRS = 8;
+        if (use_collectives) {
+            conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS);
+        }
+    }
+
+    if (use_collectives) {
+        ggml_vk_create_pipeline(
+            device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
+            sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
+            { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, true);
+    } else {
+        ggml_vk_create_pipeline(
+            device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
+            sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
+            { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true,
+            false);
+    }
+
     ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
 
@@ -6837,6 +6957,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_leaky_relu_f32;
         }
         return nullptr;
+    case GGML_OP_CONV_2D:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
+            ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
+            return ctx->device->pipeline_conv2d_f32;
+        }
+        return nullptr;
     case GGML_OP_CONV_2D_DW:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             if (ggml_is_contiguous(src1)) {
@@ -7159,6 +7285,31 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
             const uint32_t OW = dst->ne[0];
             elements = { N * OC * OH * OW, 1, 1};
         } break;
+    case GGML_OP_CONV_2D:
+        {
+            // src0 - kernel:   [KW, KH, Cin, Cout]
+            // src1 - input:    [W, H, Cin, N]
+            // dst - result:    [OW, OH, Cout, N]
+
+            // Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d)
+            auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
+                return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
+            };
+            // parallelize in {OW/BS_K, OH/BS_NPQ, 1}
+            int64_t W    = src1->ne[0];
+            int64_t H    = src1->ne[1];
+            int64_t KW   = src0->ne[0];
+            int64_t KH   = src0->ne[1];
+            int64_t Cout = src0->ne[3];
+            int64_t N    = src1->ne[3];
+            int64_t OH   = calc_conv_output_size(H, KH, dst->op_params[1], dst->op_params[3], dst->op_params[5]);
+            int64_t OW   = calc_conv_output_size(W, KW, dst->op_params[0], dst->op_params[2], dst->op_params[4]);
+            int64_t NPQ  = N * OW * OH;
+
+            // Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups
+            elements = { static_cast<uint32_t>(Cout), static_cast<uint32_t>(NPQ), 1 };
+        }
+        break;
     case GGML_OP_ADD:
     case GGML_OP_SUB:
     case GGML_OP_DIV:
@@ -8025,6 +8176,55 @@ static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c
     }, dryrun);
 }
 
+static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
+                            const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(nb00 == sizeof(float));
+    GGML_ASSERT(nb10 == sizeof(float));
+    GGML_ASSERT(nb0 == sizeof(float));
+
+    vk_op_conv2d_push_constants p{};
+    p.Cout = static_cast<uint32_t>(ne03);
+    p.Cin  = static_cast<uint32_t>(ne02);
+    p.N    = static_cast<uint32_t>(ne13);
+
+    p.KW = static_cast<uint32_t>(ne00);
+    p.KH = static_cast<uint32_t>(ne01);
+    p.W  = static_cast<uint32_t>(ne10);
+    p.H  = static_cast<uint32_t>(ne11);
+    p.OW = static_cast<uint32_t>(ne0);
+    p.OH = static_cast<uint32_t>(ne1);
+
+    p.s0 = static_cast<uint32_t>(dst->op_params[0]);
+    p.s1 = static_cast<uint32_t>(dst->op_params[1]);
+    p.p0 = static_cast<uint32_t>(dst->op_params[2]);
+    p.p1 = static_cast<uint32_t>(dst->op_params[3]);
+    p.d0 = static_cast<uint32_t>(dst->op_params[4]);
+    p.d1 = static_cast<uint32_t>(dst->op_params[5]);
+
+    p.nb01 = static_cast<uint32_t>(nb01 / nb00);
+    p.nb02 = static_cast<uint32_t>(nb02 / nb00);
+    p.nb03 = static_cast<uint32_t>(nb03 / nb00);
+
+    p.nb11 = static_cast<uint32_t>(nb11 / nb10);
+    p.nb12 = static_cast<uint32_t>(nb12 / nb10);
+    p.nb13 = static_cast<uint32_t>(nb13 / nb10);
+
+    p.nb1 = static_cast<uint32_t>(nb1 / nb0);
+    p.nb2 = static_cast<uint32_t>(nb2 / nb0);
+    p.nb3 = static_cast<uint32_t>(nb3 / nb0);
+
+    GGML_ASSERT(ne03 == ne2);
+    GGML_ASSERT(ne02 == ne12);
+
+    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D, std::move(p), dryrun);
+}
+
 static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     vk_op_conv2d_dw_push_constants p{};
     p.ne = ggml_nelements(dst);
@@ -9087,6 +9287,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_CONV_TRANSPOSE_1D:
     case GGML_OP_POOL_2D:
+    case GGML_OP_CONV_2D:
     case GGML_OP_CONV_2D_DW:
     case GGML_OP_RWKV_WKV6:
     case GGML_OP_RWKV_WKV7:
@@ -9154,6 +9355,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_CONV_TRANSPOSE_1D:
         case GGML_OP_POOL_2D:
+        case GGML_OP_CONV_2D:
         case GGML_OP_CONV_2D_DW:
         case GGML_OP_LEAKY_RELU:
             {
@@ -9360,6 +9562,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_POOL_2D:
         ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun);
 
+        break;
+    case GGML_OP_CONV_2D:
+        ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_CONV_2D_DW:
         ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -9490,6 +9696,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_CONV_TRANSPOSE_1D:
     case GGML_OP_POOL_2D:
+    case GGML_OP_CONV_2D:
     case GGML_OP_CONV_2D_DW:
     case GGML_OP_RWKV_WKV6:
     case GGML_OP_RWKV_WKV7:
@@ -10071,6 +10278,12 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
         ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
         if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
             total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
+        } else if (cgraph->nodes[i]->op == GGML_OP_CONV_2D) {
+            // Return CRSxNPQxsizeof(*) to account as many bytes as mul_mat has in im2col->mul_mat mode.
+            auto CRS_size =
+                cgraph->nodes[i]->src[0]->ne[0] * cgraph->nodes[i]->src[0]->ne[1] * cgraph->nodes[i]->src[0]->ne[2];
+            auto NPQ_size = cgraph->nodes[i]->ne[0] * cgraph->nodes[i]->ne[1] * cgraph->nodes[i]->ne[3];
+            total_mat_mul_bytes += NPQ_size * CRS_size * ggml_type_size(cgraph->nodes[i]->type);
         }
         i += ctx->num_additional_fused_ops;
         ctx->num_additional_fused_ops = 0;
@@ -10647,6 +10860,20 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             return true;
         case GGML_OP_CONV_TRANSPOSE_1D:
             return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
+        case GGML_OP_CONV_2D:
+            {
+                // Op is disabled for Apple because it segfaults at pipeline create time on MoltenVK
+                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+                const vk_device& device = ggml_vk_get_device(ctx->device);
+                bool is_Apple = ggml_vk_get_device(ctx->device)->vendor_id == VK_VENDOR_ID_APPLE;
+                // Channel-contiguous format is not supported yet.
+                return (op->src[0]->type == GGML_TYPE_F32 &&
+                    op->src[1]->type == GGML_TYPE_F32 &&
+                    op->type == GGML_TYPE_F32 &&
+                    ggml_is_contiguous(op->src[0]) &&
+                    ggml_is_contiguous(op->src[1]) &&
+                    ggml_is_contiguous(op)) && !is_Apple;
+            }
         default:
             return false;
     }
@@ -11205,6 +11432,14 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
         const int32_t p1 = tensor->op_params[6];
 
         tensor_clone = ggml_pool_2d(ggml_ctx, src_clone[0], op, k0, k1, s0, s1, p0, p1);
+    } else if (tensor->op == GGML_OP_CONV_2D) {
+        const int32_t s0 = tensor->op_params[0];
+        const int32_t s1 = tensor->op_params[1];
+        const int32_t p0 = tensor->op_params[2];
+        const int32_t p1 = tensor->op_params[3];
+        const int32_t d0 = tensor->op_params[4];
+        const int32_t d1 = tensor->op_params[5];
+        tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1);
     } else if (tensor->op == GGML_OP_LEAKY_RELU) {
         const float * op_params = (const float *)tensor->op_params;
         tensor_clone = ggml_leaky_relu(ggml_ctx, src_clone[0], op_params[0], false);
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp
new file mode 100644 (file)
index 0000000..481940a
--- /dev/null
@@ -0,0 +1,265 @@
+#version 450
+
+#ifdef USE_COLLECTIVES
+#    extension GL_KHR_shader_subgroup_shuffle : enable
+#endif
+
+#include "types.comp"
+
+// Make spec constant
+#define SHMEM_PAD 0
+
+// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j
+layout(binding = 0) readonly buffer A {
+    A_TYPE knl_data[];
+};  // src0 - kernel:   [KW, KH, Cin, Cout]
+
+layout(binding = 1) readonly buffer B {
+    B_TYPE src_data[];
+};  // src1 - input:    [W, H, Cin, N] -- channel_first format
+
+layout(binding = 2) writeonly buffer D {
+    D_TYPE dst_data[];
+};  // dst - result:    [OW, OH, Cout, N]
+
+layout(push_constant) uniform parameter {
+    // I/O channels, batch size
+    uint32_t Cout;
+    uint32_t Cin;
+    uint32_t N;
+
+    // Tensor spatial sizes: kernel, input, output
+    uint32_t KW;
+    uint32_t KH;
+    uint32_t W;
+    uint32_t H;
+    uint32_t OW;
+    uint32_t OH;
+
+    // Parameters: stride, padding, dilation - 0=y, 1=x
+    uint32_t s0;
+    uint32_t s1;
+    uint32_t p0;
+    uint32_t p1;
+    uint32_t d0;
+    uint32_t d1;
+
+    // Strides in elements
+    uint32_t nb01;
+    uint32_t nb02;
+    uint32_t nb03;
+
+    uint32_t nb11;
+    uint32_t nb12;
+    uint32_t nb13;
+
+    uint32_t nb1;
+    uint32_t nb2;
+    uint32_t nb3;
+}
+
+p;
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+// Blocktile sizes
+layout(constant_id = 1) const uint BS_K            = 128;
+layout(constant_id = 2) const uint BS_CRS          = 16;
+layout(constant_id = 3) const uint BS_NPQ          = 128;
+// Thread-tile sizes
+layout(constant_id = 4) const uint TS_K            = 8;
+layout(constant_id = 5) const uint use_collectives = 1;
+
+uint32_t       tid     = gl_LocalInvocationID.x;
+const uint32_t WG_SIZE = gl_WorkGroupSize.x;
+
+uint splitWork(uint work_size, uint block_size) {
+    return (block_size + work_size - 1) / block_size;
+}
+
+uint32_t K   = p.Cout;
+uint32_t CRS = p.Cin * p.KH * p.KW;
+uint32_t NPQ = p.N * p.OH * p.OW;
+
+uint32_t n_elems_out = K * NPQ;
+
+// Number of blocktiles per input
+uint32_t NB_CRS = splitWork(CRS, BS_CRS);
+
+const uint32_t Ash_stride = BS_CRS + SHMEM_PAD;
+const uint32_t Bsh_stride = BS_NPQ + SHMEM_PAD;
+
+const uint32_t Ash_numel = BS_K * BS_CRS;
+const uint32_t Bsh_numel = BS_CRS * BS_NPQ;
+
+const uint32_t Ash_len = BS_K * Ash_stride;
+const uint32_t Bsh_len = BS_CRS * Bsh_stride;
+
+shared float Ash[Ash_len];  // K x CRS
+shared float Bsh[Bsh_len];  // CRS x NPQ
+
+// Threadtile sizes
+const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K;
+
+// Number of threadtiles per blocktile
+const uint32_t NT_K   = BS_K / TS_K;
+const uint32_t NT_NPQ = BS_NPQ / TS_NPQ;
+
+float regA[TS_K];
+float regB[TS_NPQ];
+float regC[TS_K][TS_NPQ];
+
+/*
+Compute
+KxCRS @ CRSxNPQ = K x NPQ
+K=Cout
+C=Cin
+R,S=KH,KW
+P,Q=OH,OW
+*/
+
+uint32_t B_idx_K   = gl_WorkGroupID.x;
+uint32_t B_idx_NPQ = gl_WorkGroupID.y;
+
+uint32_t T_y = tid / NT_NPQ;
+uint32_t T_x = tid % NT_NPQ;
+
+uint32_t       Ar    = tid / BS_CRS;
+uint32_t       Ac    = tid % BS_CRS;
+const uint32_t ArpWg = WG_SIZE / BS_CRS;
+
+uint32_t       Br    = tid / BS_NPQ;
+uint32_t       Bc    = tid % BS_NPQ;
+const uint32_t BrpWg = WG_SIZE / BS_NPQ;
+
+void main() {
+    for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
+        for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
+            regC[T_ly][T_lx] = 0.0;
+        }
+    }
+    /* Advance block in CRS dim */
+    for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) {
+        uint32_t CRS_idx_a;
+        uint32_t Cin_idx_a;
+        uint32_t KH_idx_a;
+        uint32_t KW_idx_a;
+
+#ifdef USE_COLLECTIVES
+        uint32_t cached_CRS_idx;
+        uint32_t cached_Cin_idx;
+        uint32_t cached_KH_idx;
+        uint32_t cached_KW_idx;
+        if (use_collectives == 1) {
+            cached_CRS_idx                = B_idx_CRS * BS_CRS + gl_SubgroupInvocationID;
+            cached_Cin_idx                = cached_CRS_idx / (p.KW * p.KH);
+            uint32_t cached_CRS_remainder = (cached_CRS_idx - cached_Cin_idx * p.KW * p.KH);
+            cached_KH_idx                 = cached_CRS_remainder / p.KW;
+            cached_KW_idx                 = cached_CRS_remainder - cached_KH_idx * p.KW;
+
+            CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac);
+            Cin_idx_a = subgroupShuffle(cached_Cin_idx, Ac);
+            KH_idx_a  = subgroupShuffle(cached_KH_idx, Ac);
+            KW_idx_a  = subgroupShuffle(cached_KW_idx, Ac);
+        } else {
+            CRS_idx_a              = B_idx_CRS * BS_CRS + Ac;  // Global CRS_idx_a (column index of A)
+            Cin_idx_a              = CRS_idx_a / (p.KW * p.KH);
+            uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
+            KH_idx_a               = CRS_remainder / p.KW;
+            KW_idx_a               = CRS_remainder - KH_idx_a * p.KW;
+        }
+#else
+        CRS_idx_a     = B_idx_CRS * BS_CRS + Ac;  // Global CRS_idx_a (column index of A)
+        Cin_idx_a     = CRS_idx_a / (p.KW * p.KH);
+        CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
+        KH_idx_a      = CRS_remainder / p.KW;
+        KW_idx_a      = CRS_remainder - KH_idx_a * p.KW;
+#endif
+
+        /* Load kernel to A_block: (BS_K x BS_CRS)*/
+        for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) {
+            uint32_t B_ly    = r_offset + Ar;
+            uint32_t B_lx    = Ac;
+            uint32_t K_idx   = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/
+            uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03, K * CRS - 1);
+            float    val     = knl_data[knl_idx];
+            if (K_idx >= K || CRS_idx_a >= CRS) {
+                val = 0.0;
+            }
+            Ash[B_ly * Ash_stride + B_lx] = val;
+        }
+        /* Load input to B_block: (BS_CRS x BS_NPQ) */
+        for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
+            uint32_t B_ly          = r_offset + Br;             /* Row index of B block */
+            uint32_t B_lx          = Bc;
+            uint32_t NPQ_idx       = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */
+            uint32_t N_idx         = NPQ_idx / (p.OH * p.OW);
+            uint32_t NPQ_remainder = NPQ_idx - N_idx * p.OH * p.OW;
+            uint32_t OH_idx        = NPQ_remainder / p.OW;
+            uint32_t OW_idx        = NPQ_remainder - OH_idx * p.OW;
+
+            uint32_t CRS_idx_b;
+            uint32_t Cin_idx_b;
+            uint32_t KH_idx_b;
+            uint32_t KW_idx_b;
+#ifdef USE_COLLECTIVES
+            if (use_collectives == 1) {
+                CRS_idx_b = subgroupShuffle(cached_CRS_idx, r_offset + Br);
+                Cin_idx_b = subgroupShuffle(cached_Cin_idx, r_offset + Br);
+                KH_idx_b  = subgroupShuffle(cached_KH_idx, r_offset + Br);
+                KW_idx_b  = subgroupShuffle(cached_KW_idx, r_offset + Br);
+            } else {
+                CRS_idx_b              = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
+                Cin_idx_b              = CRS_idx_b / (p.KW * p.KH);
+                uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
+                KH_idx_b               = CRS_remainder / p.KW;
+                KW_idx_b               = CRS_remainder - KH_idx_b * p.KW;
+            }
+#else
+            CRS_idx_b              = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
+            Cin_idx_b              = CRS_idx_b / (p.KW * p.KH);
+            uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
+            KH_idx_b               = CRS_remainder / p.KW;
+            KW_idx_b               = CRS_remainder - KH_idx_b * p.KW;
+#endif
+
+            uint32_t H_idx = OH_idx * p.s1 + KH_idx_b * p.d1 - p.p1;
+            uint32_t W_idx = OW_idx * p.s0 + KW_idx_b * p.d0 - p.p0;
+            uint32_t src_idx =
+                min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1);
+            float val = src_data[src_idx];
+            if (CRS_idx_b >= CRS || NPQ_idx >= NPQ || H_idx < 0 || H_idx >= p.H || W_idx < 0 || W_idx >= p.W) {
+                val = 0.0;
+            }
+            Bsh[B_ly * Bsh_stride + B_lx] = val;
+        }
+        barrier();
+        for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
+            for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
+                regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
+            }
+            for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
+                regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx];
+            }
+            for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
+                for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
+                    regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
+                }
+            }
+        }
+        barrier();
+    }
+    /* Save C* */
+    for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
+        for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
+            uint32_t K_idx   = B_idx_K * BS_K + T_y * TS_K + T_ly;
+            uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
+            uint32_t N_idx   = NPQ_idx / (p.OH * p.OW);
+            uint32_t OH_idx  = (NPQ_idx - N_idx * p.OH * p.OW) / p.OW;
+            uint32_t OW_idx  = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW;
+            uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3;
+            if (K_idx < K && NPQ_idx < NPQ) {
+                dst_data[dst_idx] = regC[T_ly][T_lx];
+            }
+        }
+    }
+}
index b1457583a4b598f6d127d793e205563e1c153898..598f0370fb8710d10cb04b14b6a51b7e469be178 100644 (file)
@@ -655,6 +655,8 @@ void process_shaders() {
 
     string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
 
+    string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}});
+
     string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
     string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
 
index bc732df5bb454c417ac0c79e8e32b28f40fe7ddf..731b4980af9473072fc10f0dcd4244708fbb8a73 100644 (file)
@@ -3699,6 +3699,93 @@ struct test_im2col : public test_case {
     }
 };
 
+// CONV_2D
+struct test_conv_2d : public test_case {
+    const std::array<int64_t, 4> ne_input;
+    const std::array<int64_t, 4> ne_kernel;
+    const int                    stride0;
+    const int                    stride1;
+    const int                    padding0;
+    const int                    padding1;
+    const int                    dilation0;
+    const int                    dilation1;
+    // Whether the inputs are contiguous in the channel dim or the width dim
+    const bool                   cwhn;
+
+    // If true, the direct CONV_2D will be used in the graph, otherwise it
+    // uses ggml_conv_2d:
+    // * if the program is called with -o CONV_2D_DIRECT_IMPL, the
+    // CONV_2D graph will be built, while
+    // * if the program is called with -o CONV_2D_INDIRECT_IMPL, the
+    // IM2COL -> MUL_MM graph will be built.
+
+    std::string vars() override {
+        return VARS_TO_STR9(ne_input, ne_kernel, stride0, stride1, padding0, padding1, dilation0, dilation1, cwhn);
+    }
+
+    uint64_t op_flops(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        // Just counting matmul costs:
+        // KxCRS @ CRSxNPQ = KxNPQ --> KxNPQx(CRS+CRS-1) flops
+
+        // Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d)
+        auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
+            return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
+        };
+
+        int64_t W    = ne_input[0];
+        int64_t H    = ne_input[1];
+        int64_t KW   = ne_kernel[0];
+        int64_t KH   = ne_kernel[1];
+        int64_t Cin  = ne_kernel[2];
+        int64_t Cout = ne_kernel[3];
+        int64_t N    = ne_input[3];
+        int64_t OH   = calc_conv_output_size(H, KH, stride0, padding0, dilation0);
+        int64_t OW   = calc_conv_output_size(W, KW, stride0, padding0, dilation0);
+
+        int64_t K   = Cout;
+        int64_t CRS = Cin * KH * KW;
+        int64_t NPQ = N * OH * OW;
+
+        return K * NPQ * (2 * CRS - 1);
+    }
+
+    test_conv_2d(std::array<int64_t, 4> ne_input  = { 64, 64, 16, 1 },
+                 std::array<int64_t, 4> ne_kernel = { 3, 3, 1, 16 }, int stride0 = 1, int stride1 = 1, int padding0 = 0,
+                 int padding1 = 0, int dilation0 = 1, int dilation1 = 1, bool cwhn = false) :
+        ne_input(ne_input),
+        ne_kernel(ne_kernel),
+        stride0(stride0),
+        stride1(stride1),
+        padding0(padding0),
+        padding1(padding1),
+        dilation0(dilation0),
+        dilation1(dilation1),
+        cwhn(cwhn) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());
+        ggml_set_name(input, "input");
+
+        ggml_tensor * kernel = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_kernel.data());
+        ggml_set_name(kernel, "kernel");
+
+        if (cwhn) {
+            // change memory layout to channel-most-contiguous (CWHN),
+            // then permute it back so NE matches the original input
+            input  = ggml_cont(ctx, ggml_permute(ctx, input, 1, 2, 0, 3));
+            input  = ggml_permute(ctx, input, 2, 0, 1, 3);
+            kernel = ggml_cont(ctx, ggml_permute(ctx, kernel, 2, 3, 1, 0));
+            kernel = ggml_permute(ctx, kernel, 3, 2, 0, 1);
+        }
+
+        ggml_tensor * out =
+            ggml_conv_2d_direct(ctx, kernel, input, stride0, stride1, padding0, padding1, dilation0, dilation1);
+        ggml_set_name(out, "out");
+        return out;
+    }
+};
+
 // GGML_OP_CONV_2D_DW
 struct test_conv_2d_dw : public test_case {
     const std::array<int64_t, 4> ne_input;
@@ -5007,6 +5094,80 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2560}, {3, 3, 1, 2560}, 1, 1, 1, 1, 1, 1, true));
     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2560}, {3, 3, 2, 2560}, 1, 1, 1, 1, 1, 1, true));
 
+// Conv_2D test cases
+#ifdef DETAILED_TESTS
+    // Probably we do not have enough time to execute these in the pipeline.
+    uint32_t iwh_idx  = 0;
+    uint32_t kwh_idx  = 1;
+    uint32_t Cout_idx = 2;
+    uint32_t Cin_idx  = 3;
+    uint32_t B_idx    = 4;
+
+    std::vector<std::array<int, 5>> cases = {
+  //{IWH, KWH, Cout, Cin, B}
+  // K=CRS=NPQ=4096 conv_2d matmul performance
+        {19,   4, 4096, 256, 16},
+ // K=128, CRS=128, NPQ=4096
+        { 19,  4, 128,  8,   16},
+ // K=130, CRS=128, NPQ=4096
+        { 19,  4, 130,  8,   16},
+ // Edge case: K x CRS is small
+        { 19,  2, 4,    4,   16},
+ // A ConvNet's first layer
+        { 224, 3, 8,    3,   1 },
+ // A ConvNet's first layer with 2x2 convolution, and 1 channel
+        { 224, 2, 8,    1,   1 },
+ // A ConvNet's first layer with 2x2 convolution, and 1 channel, several images in the batch
+        { 224, 2, 8,    1,   8 },
+ // A middle layer of a ConvNet
+        { 58,  3, 64,   32,  1 },
+ // A middle layer of a ConvNet, several images in the batch
+        { 58,  3, 64,   32,  8 },
+ // A deep layer of a ConvNet, several images in the batch
+        { 16,  3, 256,  128, 8 }
+    };
+
+    for (auto act_case : cases) {
+        test_cases.emplace_back(new test_conv_2d(
+            { act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] },
+            { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] }, 1, 1, 0, 0, 1, 1, false));
+    }
+#endif
+
+    // CONV_2D:
+    auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
+        return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
+    };
+
+    //uint32_t s0 = 3;
+    uint32_t s1 = 5;
+    uint32_t p0 = 5;
+    //uint32_t p1 = 2;
+    uint32_t d0 = 2;
+    uint32_t d1 = 4;
+
+    for (uint32_t s0 : { 1, 3 }) {
+        for (uint32_t p1 : { 2, 5 }) {
+            for (uint32_t Cin : { 1, 25 }) {
+                for (uint32_t Cout : { 1, 12 }) {
+                    for (uint32_t KH : { 1, 2, 3, 11 }) {
+                        for (uint32_t KW : { 1, 2, 3, 11 }) {
+                            for (uint32_t H : { 1, 133 }) {
+                                for (uint32_t W : { 1, 141 }) {
+                                    if (calc_conv_output_size(W, KW, s0, p0, d0) > 0 &&
+                                        calc_conv_output_size(H, KH, s1, p1, d1) > 0) {
+                                        test_cases.emplace_back(new test_conv_2d(
+                                            { W, H, Cin, 2 }, { KW, KH, Cin, Cout }, s0, s1, p0, p1, d0, d1, false));
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+
     // sycl backend will limit task global_range < MAX_INT
     // test cases for 2D im2col with large input W and H (occurs in stable-diffusion)
     // however these cases need to alloc more memory which may fail in some devices (Intel Arc770, etc.)
@@ -5610,6 +5771,43 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
 static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     std::vector<std::unique_ptr<test_case>> test_cases;
 
+    // Conv2d: K=CRS=NPQ=4096 matmul performance
+    uint32_t                        iwh_idx  = 0;
+    uint32_t                        kwh_idx  = 1;
+    uint32_t                        Cout_idx = 2;
+    uint32_t                        Cin_idx  = 3;
+    uint32_t                        B_idx    = 4;
+    std::vector<std::array<int, 5>> cases    = {
+  //{IWH, KWH, Cout, Cin, B}
+  // K=CRS=NPQ=4096 conv2d matmul performance
+        {19,   4, 4096, 256, 16},
+ // K=128, CRS=128, NPQ=4096
+        { 19,  4, 128,  8,   16},
+ // K=130, CRS=128, NPQ=4096
+        { 19,  4, 130,  8,   16},
+ // Edge case: K x CRS is small
+        { 19,  2, 4,    4,   16},
+ // A ConvNet's first layer
+        { 224, 3, 8,    3,   1 },
+ // A ConvNet's first layer with 2x2 convolution, and 1 channel
+        { 224, 2, 8,    1,   1 },
+ // A ConvNet's first layer with 2x2 convolution, and 1 channel, several images in the batch
+        { 224, 2, 8,    1,   8 },
+ // A middle layer of a ConvNet
+        { 58,  3, 64,   32,  1 },
+ // A middle layer of a ConvNet, several images in the batch
+        { 58,  3, 64,   32,  8 },
+ // A deep layer of a ConvNet, several images in the batch
+        { 16,  3, 512,  128, 8 },
+    };
+
+    for (auto act_case : cases) {
+        // Direct CONV_2D
+        test_cases.emplace_back(new test_conv_2d(
+            { act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] },
+            { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] }, 1, 1, 0, 0, 1, 1, false));
+    }
+
     test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1,   1, 1, 1}));
     test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1}));