]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
opencl: add conv2d kernel (llama/14403)
authorrmatif <redacted>
Mon, 21 Jul 2025 17:03:19 +0000 (19:03 +0200)
committerGeorgi Gerganov <redacted>
Thu, 24 Jul 2025 17:57:40 +0000 (20:57 +0300)
* add conv2d kernel

* fix trailing whitespace

* whitespace fixe

* handle f16 input and f16 kernel, more opt

* resolve conflicts

* use enqueue_ndrange_kernel

src/ggml-opencl/CMakeLists.txt
src/ggml-opencl/ggml-opencl.cpp
src/ggml-opencl/kernels/conv2d.cl [new file with mode: 0644]
src/ggml-opencl/kernels/conv2d_f16_f32.cl [new file with mode: 0644]

index ec5d8cf59556b9a88020edbc665c3c3a5dc180b6..015fa8f06824e52ea451d3ceda7ee174a435d57b 100644 (file)
@@ -105,6 +105,8 @@ set(GGML_OPENCL_KERNELS
     pad
     repeat
     mul_mat_f16_f32
+    conv2d
+    conv2d_f16_f32
 )
 
 foreach (K ${GGML_OPENCL_KERNELS})
index 3388259152b4617acf1727f27f18b3849a1377d0..a31483b61085a0dabadd6492ed1c5a4c66d53ca5 100644 (file)
@@ -390,6 +390,9 @@ struct ggml_backend_opencl_context {
     cl_program program_tanh;
     cl_program program_upscale;
     cl_program program_concat;
+    cl_program program_conv_2d_f16;
+    cl_program program_conv_2d_f32;
+    cl_program program_conv_2d_f16_f32;
     cl_program program_tsembd;
     cl_program program_mul_mv_id_q4_0_f32_8x_flat;
 
@@ -441,6 +444,9 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_upscale_bilinear;
     cl_kernel kernel_concat_f32_contiguous;
     cl_kernel kernel_concat_f32_non_contiguous;
+    cl_kernel kernel_conv_2d_f16;
+    cl_kernel kernel_conv_2d_f32;
+    cl_kernel kernel_conv_2d_f16_f32;
     cl_kernel kernel_timestep_embedding;
     cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;
 
@@ -1478,6 +1484,47 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         GGML_LOG_CONT(".");
     }
 
+     // conv2d
+     {
+        #ifdef GGML_OPENCL_EMBED_KERNELS
+                const std::string kernel_src {
+                    #include "conv2d.cl.h"
+                };
+                const std::string kernel_src_f16_f32 {
+                    #include "conv2d_f16_f32.cl.h"
+                };
+        #else
+                const std::string kernel_src = read_file("conv2d.cl");
+                const std::string kernel_src_f16_f32 = read_file("conv2d_f16_f32.cl");
+        #endif
+                if (!kernel_src.empty()) {
+                    backend_ctx->program_conv_2d_f16 =
+                        build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), (std::string(compile_opts) + " -DUSE_FP16=1").c_str());
+                    CL_CHECK((backend_ctx->kernel_conv_2d_f16 = clCreateKernel(backend_ctx->program_conv_2d_f16, "kernel_conv_2d", &err), err));
+                    GGML_LOG_CONT(".");
+                    backend_ctx->program_conv_2d_f32 =
+                        build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+                    CL_CHECK((backend_ctx->kernel_conv_2d_f32 = clCreateKernel(backend_ctx->program_conv_2d_f32, "kernel_conv_2d", &err), err));
+                    GGML_LOG_CONT(".");
+                } else {
+                    GGML_LOG_WARN("ggml_opencl: conv2d kernel source not found or empty. This op will not be available.\n");
+                    backend_ctx->program_conv_2d_f16 = nullptr;
+                    backend_ctx->kernel_conv_2d_f16 = nullptr;
+                    backend_ctx->program_conv_2d_f32 = nullptr;
+                    backend_ctx->kernel_conv_2d_f32 = nullptr;
+                }
+                if (!kernel_src_f16_f32.empty()) {
+                    backend_ctx->program_conv_2d_f16_f32 =
+                        build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16_f32.c_str(), compile_opts);
+                    CL_CHECK((backend_ctx->kernel_conv_2d_f16_f32 = clCreateKernel(backend_ctx->program_conv_2d_f16_f32, "kernel_conv_2d", &err), err));
+                    GGML_LOG_CONT(".");
+                } else {
+                    GGML_LOG_WARN("ggml_opencl: conv2d_f16_f32 kernel source not found or empty. This op will not be available.\n");
+                    backend_ctx->program_conv_2d_f16_f32 = nullptr;
+                    backend_ctx->kernel_conv_2d_f16_f32 = nullptr;
+                }
+    }
+
     // mul_mv_id_q4_0_f32_8x_flat
     {
 #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -2361,6 +2408,10 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
                    op->src[0]->ne[3] == 1 && op->ne[3] == 1;
         case GGML_OP_UPSCALE:
             return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
+        case GGML_OP_CONV_2D:
+            return (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16) ||
+                   (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
+                   (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32);
         case GGML_OP_CONCAT:
             return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
         case GGML_OP_TIMESTEP_EMBEDDING:
@@ -4998,6 +5049,83 @@ static void ggml_cl_mul_mat_f16_f32_tiled(ggml_backend_t backend, const ggml_ten
     backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
 }
 
+static void ggml_cl_conv_2d(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_TENSOR_BINARY_OP_LOCALS;
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    const cl_uint Cout = ne03; const cl_uint Cin = ne02; const cl_uint N = ne13;
+    const cl_uint KW = ne00; const cl_uint KH = ne01; const cl_uint W = ne10; const cl_uint H = ne11; const cl_uint OW = ne0; const cl_uint OH = ne1;
+
+    const cl_uint s0 = dst->op_params[0]; const cl_uint s1 = dst->op_params[1];
+    const cl_uint p0 = dst->op_params[2]; const cl_uint p1 = dst->op_params[3];
+    const cl_uint d0 = dst->op_params[4]; const cl_uint d1 = dst->op_params[5];
+
+    const cl_uint cl_nb01 = nb01/ggml_type_size(src0->type); const cl_uint cl_nb02 = nb02/ggml_type_size(src0->type); const cl_uint cl_nb03 = nb03/ggml_type_size(src0->type);
+    const cl_uint cl_nb11 = nb11/ggml_type_size(src1->type); const cl_uint cl_nb12 = nb12/ggml_type_size(src1->type); const cl_uint cl_nb13 = nb13/ggml_type_size(src1->type);
+    const cl_uint cl_nb1 = nb1/ggml_type_size(dst->type); const cl_uint cl_nb2 = nb2/ggml_type_size(dst->type); const cl_uint cl_nb3 = nb3/ggml_type_size(dst->type);
+
+    const int64_t NPQ = (int64_t)N * OW * OH;
+
+    const uint32_t BS_K = 64;
+    const uint32_t BS_NPQ = 64;
+    const uint32_t BS_CRS = 16;
+    const uint32_t VEC_SIZE = 4;
+
+    const uint32_t TS_K = 4;
+    const uint32_t TS_NPQ = 8;
+
+    const uint32_t WG_K = BS_K / TS_K;
+    const uint32_t WG_NPQ = BS_NPQ / TS_NPQ;
+
+    auto splitWork = [](uint32_t work_size, uint32_t block_size) { return (block_size + work_size - 1) / block_size; };
+    const uint32_t NB_K = splitWork(Cout, BS_K);
+    const uint32_t NB_NPQ = splitWork(NPQ, BS_NPQ);
+
+    cl_kernel kernel;
+    size_t shmem_size;
+
+    if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
+        kernel = backend_ctx->kernel_conv_2d_f16;
+        shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_half4));
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
+        kernel = backend_ctx->kernel_conv_2d_f32;
+        shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_float) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_float4));
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
+        kernel = backend_ctx->kernel_conv_2d_f16_f32;
+        shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_float4));
+    } else {
+        GGML_ASSERT(false && "Unsupported data type combination for conv2d");
+        return;
+    }
+
+    cl_uint idx = 0;
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra1->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offset1));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extrad->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel, idx++, shmem_size, NULL));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &Cout)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &Cin)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &N));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &KW)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &KH)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &W)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &H));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &OW)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &OH));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &s0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &s1)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &p0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &p1));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &d0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &d1));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb01)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb02)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb03));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb11)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb12)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb13));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb1)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb2)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb3));
+
+    size_t global_work_size[] = { (size_t)NB_K * WG_K, (size_t)NB_NPQ * WG_NPQ, 1 };
+    size_t local_work_size[] = { (size_t)WG_K, (size_t)WG_NPQ, 1 };
+
+    backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
+}
+
 static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0);
     GGML_ASSERT(src0->extra);
@@ -6752,6 +6880,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
             }
             ggml_cl_upscale(backend, tensor->src[0], tensor);
             return true;
+        case GGML_OP_CONV_2D:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_conv_2d;
+            break;
         case GGML_OP_CONCAT:
             if (!any_on_device) {
                 return false;
diff --git a/src/ggml-opencl/kernels/conv2d.cl b/src/ggml-opencl/kernels/conv2d.cl
new file mode 100644 (file)
index 0000000..e339c90
--- /dev/null
@@ -0,0 +1,185 @@
+#ifdef USE_FP16
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#define T_FLOAT half
+#define T_FLOAT4 half4
+#define VSTORE_T_FLOAT4(data, offset, p) vstore_half4_rte(data, offset, p)
+#else
+#define T_FLOAT float
+#define T_FLOAT4 float4
+#define VSTORE_T_FLOAT4(data, offset, p) vstore4(data, offset, p)
+#endif
+
+#if defined(cl_qcom_reqd_sub_group_size)
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
+#else
+#define REQD_SUBGROUP_SIZE_128
+#endif
+
+#define T_ACCUM float4
+#define VEC_SIZE 4
+
+#define BS_K 64
+#define BS_NPQ 64
+#define BS_CRS 16
+
+#define TS_K 4
+#define TS_NPQ 8
+
+#define WG_K (BS_K / TS_K)
+#define WG_NPQ (BS_NPQ / TS_NPQ)
+
+#define BS_NPQ_VEC (BS_NPQ / VEC_SIZE)
+#define TS_NPQ_VEC (TS_NPQ / VEC_SIZE)
+
+static inline uint splitWork(uint work_size, uint block_size){
+    return (work_size + block_size - 1) / block_size;
+}
+
+REQD_SUBGROUP_SIZE_128
+kernel void kernel_conv_2d(
+    global void* p_knl,
+    ulong off_knl,
+    global void* p_src,
+    ulong off_src,
+    global void* p_dst,
+    ulong off_dst,
+    local void* shared,
+    uint Cout, uint Cin, uint N,
+    uint KW, uint KH, uint W, uint H, uint OW, uint OH,
+    uint s0, uint s1, uint p0, uint p1, uint d0, uint d1,
+    uint nb01, uint nb02, uint nb03,
+    uint nb11, uint nb12, uint nb13,
+    uint nb1, uint nb2, uint nb3
+) {
+    global T_FLOAT* knl_data = (global T_FLOAT*) ((global char*)p_knl + off_knl);
+    global T_FLOAT* src_data = (global T_FLOAT*) ((global char*)p_src + off_src);
+    global T_FLOAT* dst_data = (global T_FLOAT*) ((global char*)p_dst + off_dst);
+
+    const uint K = Cout;
+    const uint CRS = Cin*KH*KW;
+    const uint NPQ = N*OH*OW;
+
+    const uint lid_k = get_local_id(0);
+    const uint lid_npq = get_local_id(1);
+    const uint tid = lid_npq * WG_K + lid_k;
+
+    const uint B_idx_K = get_group_id(0);
+    const uint B_idx_NPQ = get_group_id(1);
+
+    const uint offset_k = B_idx_K * BS_K;
+    const uint offset_npq = B_idx_NPQ * BS_NPQ;
+
+    local T_FLOAT* Ash = (local T_FLOAT*)shared;
+    local T_FLOAT4* Bsh = (local T_FLOAT4*) &Ash[BS_K * BS_CRS];
+
+    T_ACCUM regC[TS_K][TS_NPQ_VEC];
+    for (int i = 0; i < TS_K; ++i) {
+        for (int j = 0; j < TS_NPQ_VEC; ++j) {
+            regC[i][j] = (T_ACCUM)(0.0f);
+        }
+    }
+
+    const uint NB_CRS = splitWork(CRS, BS_CRS);
+
+    for (uint B_idx_CRS = 0; B_idx_CRS < NB_CRS; ++B_idx_CRS) {
+        const uint offset_crs = B_idx_CRS * BS_CRS;
+
+        for (int i = tid; i < BS_K * BS_CRS; i += (WG_K * WG_NPQ)) {
+            const uint k_l = i / BS_CRS;
+            const uint crs_l = i % BS_CRS;
+            const uint k_g = offset_k + k_l;
+            const uint crs_g = offset_crs + crs_l;
+
+            if (k_g < K && crs_g < CRS) {
+                const uint Cin_idx = crs_g / (KW*KH);
+                const uint KH_idx = (crs_g - Cin_idx*KW*KH) / KW;
+                const uint KW_idx = crs_g - Cin_idx*KW*KH - KH_idx*KW;
+                const uint knl_idx = KW_idx + KH_idx*nb01 + Cin_idx*nb02 + k_g*nb03;
+                Ash[k_l * BS_CRS + crs_l] = knl_data[knl_idx];
+            } else {
+                Ash[k_l * BS_CRS + crs_l] = (T_FLOAT)0.0f;
+            }
+        }
+
+        for (int i = tid; i < BS_CRS * BS_NPQ_VEC; i += (WG_K * WG_NPQ)) {
+            const uint crs_l = i / BS_NPQ_VEC;
+            const uint npq_l_vec = i % BS_NPQ_VEC;
+            const uint crs_g = offset_crs + crs_l;
+
+            T_FLOAT4 val = (T_FLOAT4)(0.0f);
+            if (crs_g < CRS) {
+                const uint Cin_idx = crs_g / (KW * KH);
+                const uint KH_idx = (crs_g - Cin_idx * KW * KH) / KW;
+                const uint KW_idx = crs_g - Cin_idx * KW * KH - KH_idx * KW;
+                for (int v = 0; v < VEC_SIZE; ++v) {
+                    const uint npq_g = offset_npq + npq_l_vec * VEC_SIZE + v;
+                    if (npq_g < NPQ) {
+                        const uint N_idx = npq_g / (OH * OW);
+                        const uint pq_idx = npq_g % (OH * OW);
+                        const uint OH_idx = pq_idx / OW;
+                        const uint OW_idx = pq_idx % OW;
+                        const int H_idx = (int)(OH_idx * s1 + KH_idx * d1 - p1);
+                        const int W_idx = (int)(OW_idx * s0 + KW_idx * d0 - p0);
+
+                        if (H_idx >= 0 && H_idx < H && W_idx >= 0 && W_idx < W) {
+                            const uint src_idx = W_idx + H_idx * nb11 + Cin_idx * nb12 + N_idx * nb13;
+                            ((T_FLOAT*)&val)[v] = src_data[src_idx];
+                        }
+                    }
+                }
+            }
+            Bsh[crs_l * BS_NPQ_VEC + npq_l_vec] = val;
+        }
+
+        barrier(CLK_LOCAL_MEM_FENCE);
+
+        #pragma unroll
+        for (uint crs_l = 0; crs_l < BS_CRS; ++crs_l) {
+            T_FLOAT regA[TS_K];
+            for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
+                regA[k_l_reg] = Ash[(lid_k * TS_K + k_l_reg) * BS_CRS + crs_l];
+            }
+
+            for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) {
+                T_FLOAT4 regB = Bsh[crs_l * BS_NPQ_VEC + lid_npq * TS_NPQ_VEC + npq_l_vec_reg];
+                for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
+                    regC[k_l_reg][npq_l_vec_reg] = mad(convert_float(regA[k_l_reg]), convert_float4(regB), regC[k_l_reg][npq_l_vec_reg]);
+                }
+            }
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+
+    for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
+        const uint k_g = offset_k + lid_k * TS_K + k_l_reg;
+        if (k_g >= K) continue;
+
+        for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) {
+            const uint npq_g_base = offset_npq + (lid_npq * TS_NPQ_VEC + npq_l_vec_reg) * VEC_SIZE;
+
+            const uint N_idx = npq_g_base / (OH * OW);
+            const uint pq_idx = npq_g_base % (OH * OW);
+            const uint OH_idx = pq_idx / OW;
+            const uint OW_idx = pq_idx % OW;
+
+            if (nb1 == OW && OW_idx + VEC_SIZE <= OW && npq_g_base + VEC_SIZE <= NPQ) {
+                const uint dst_idx = OW_idx + OH_idx*nb1 + k_g*nb2 + N_idx*nb3;
+                VSTORE_T_FLOAT4(regC[k_l_reg][npq_l_vec_reg], 0, &dst_data[dst_idx]);
+            } else {
+                T_ACCUM res = regC[k_l_reg][npq_l_vec_reg];
+                for (int v = 0; v < VEC_SIZE; ++v) {
+                    const uint npq_g = npq_g_base + v;
+                    if (npq_g < NPQ) {
+                        const uint N_idx_s = npq_g / (OH*OW);
+                        const uint pq_idx_s = npq_g % (OH*OW);
+                        const uint OH_idx_s = pq_idx_s / OW;
+                        const uint OW_idx_s = pq_idx_s % OW;
+                        const uint dst_idx_s = OW_idx_s + OH_idx_s*nb1 + k_g*nb2 + N_idx_s*nb3;
+                        dst_data[dst_idx_s] = (T_FLOAT)(((float*)&res)[v]);
+                    }
+                }
+            }
+        }
+    }
+}
diff --git a/src/ggml-opencl/kernels/conv2d_f16_f32.cl b/src/ggml-opencl/kernels/conv2d_f16_f32.cl
new file mode 100644 (file)
index 0000000..cb05637
--- /dev/null
@@ -0,0 +1,176 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+#if defined(cl_qcom_reqd_sub_group_size)
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
+#else
+#define REQD_SUBGROUP_SIZE_128
+#endif
+
+#define T_ACCUM float4
+#define VEC_SIZE 4
+
+#define BS_K 64
+#define BS_NPQ 64
+#define BS_CRS 16
+
+#define TS_K 4
+#define TS_NPQ 8
+
+#define WG_K (BS_K / TS_K)
+#define WG_NPQ (BS_NPQ / TS_NPQ)
+
+#define BS_NPQ_VEC (BS_NPQ / VEC_SIZE)
+#define TS_NPQ_VEC (TS_NPQ / VEC_SIZE)
+
+static inline uint splitWork(uint work_size, uint block_size){
+    return (work_size + block_size - 1) / block_size;
+}
+
+REQD_SUBGROUP_SIZE_128
+kernel void kernel_conv_2d(
+    global void* p_knl,
+    ulong off_knl,
+    global void* p_src,
+    ulong off_src,
+    global void* p_dst,
+    ulong off_dst,
+    local void* shared,
+    uint Cout, uint Cin, uint N,
+    uint KW, uint KH, uint W, uint H, uint OW, uint OH,
+    uint s0, uint s1, uint p0, uint p1, uint d0, uint d1,
+    uint nb01, uint nb02, uint nb03,
+    uint nb11, uint nb12, uint nb13,
+    uint nb1, uint nb2, uint nb3
+) {
+    global half* knl_data = (global half*) ((global char*)p_knl + off_knl);
+    global float* src_data = (global float*) ((global char*)p_src + off_src);
+    global float* dst_data = (global float*) ((global char*)p_dst + off_dst);
+
+    const uint K = Cout;
+    const uint CRS = Cin*KH*KW;
+    const uint NPQ = N*OH*OW;
+
+    const uint lid_k = get_local_id(0);
+    const uint lid_npq = get_local_id(1);
+    const uint tid = lid_npq * WG_K + lid_k;
+
+    const uint B_idx_K = get_group_id(0);
+    const uint B_idx_NPQ = get_group_id(1);
+
+    const uint offset_k = B_idx_K * BS_K;
+    const uint offset_npq = B_idx_NPQ * BS_NPQ;
+
+    local half* Ash = (local half*)shared;
+    local float4* Bsh = (local float4*) &Ash[BS_K * BS_CRS];
+
+    T_ACCUM regC[TS_K][TS_NPQ_VEC];
+    for (int i = 0; i < TS_K; ++i) {
+        for (int j = 0; j < TS_NPQ_VEC; ++j) {
+            regC[i][j] = (T_ACCUM)(0.0f);
+        }
+    }
+
+    const uint NB_CRS = splitWork(CRS, BS_CRS);
+
+    for (uint B_idx_CRS = 0; B_idx_CRS < NB_CRS; ++B_idx_CRS) {
+        const uint offset_crs = B_idx_CRS * BS_CRS;
+
+        for (int i = tid; i < BS_K * BS_CRS; i += (WG_K * WG_NPQ)) {
+            const uint k_l = i / BS_CRS;
+            const uint crs_l = i % BS_CRS;
+            const uint k_g = offset_k + k_l;
+            const uint crs_g = offset_crs + crs_l;
+
+            if (k_g < K && crs_g < CRS) {
+                const uint Cin_idx = crs_g / (KW*KH);
+                const uint KH_idx = (crs_g - Cin_idx*KW*KH) / KW;
+                const uint KW_idx = crs_g - Cin_idx*KW*KH - KH_idx*KW;
+                const uint knl_idx = KW_idx + KH_idx*nb01 + Cin_idx*nb02 + k_g*nb03;
+                Ash[k_l * BS_CRS + crs_l] = knl_data[knl_idx];
+            } else {
+                Ash[k_l * BS_CRS + crs_l] = (half)0.0f;
+            }
+        }
+
+        for (int i = tid; i < BS_CRS * BS_NPQ_VEC; i += (WG_K * WG_NPQ)) {
+            const uint crs_l = i / BS_NPQ_VEC;
+            const uint npq_l_vec = i % BS_NPQ_VEC;
+            const uint crs_g = offset_crs + crs_l;
+
+            float4 val = (float4)(0.0f);
+            if (crs_g < CRS) {
+                const uint Cin_idx = crs_g / (KW * KH);
+                const uint KH_idx = (crs_g - Cin_idx * KW * KH) / KW;
+                const uint KW_idx = crs_g - Cin_idx * KW * KH - KH_idx * KW;
+                for (int v = 0; v < VEC_SIZE; ++v) {
+                    const uint npq_g = offset_npq + npq_l_vec * VEC_SIZE + v;
+                    if (npq_g < NPQ) {
+                        const uint N_idx = npq_g / (OH * OW);
+                        const uint pq_idx = npq_g % (OH * OW);
+                        const uint OH_idx = pq_idx / OW;
+                        const uint OW_idx = pq_idx % OW;
+                        const int H_idx = (int)(OH_idx * s1 + KH_idx * d1 - p1);
+                        const int W_idx = (int)(OW_idx * s0 + KW_idx * d0 - p0);
+
+                        if (H_idx >= 0 && H_idx < H && W_idx >= 0 && W_idx < W) {
+                            const uint src_idx = W_idx + H_idx * nb11 + Cin_idx * nb12 + N_idx * nb13;
+                            ((float*)&val)[v] = src_data[src_idx];
+                        }
+                    }
+                }
+            }
+            Bsh[crs_l * BS_NPQ_VEC + npq_l_vec] = val;
+        }
+
+        barrier(CLK_LOCAL_MEM_FENCE);
+
+        #pragma unroll
+        for (uint crs_l = 0; crs_l < BS_CRS; ++crs_l) {
+            half regA[TS_K];
+            for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
+                regA[k_l_reg] = Ash[(lid_k * TS_K + k_l_reg) * BS_CRS + crs_l];
+            }
+
+            for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) {
+                float4 regB = Bsh[crs_l * BS_NPQ_VEC + lid_npq * TS_NPQ_VEC + npq_l_vec_reg];
+                for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
+                    regC[k_l_reg][npq_l_vec_reg] = mad(convert_float(regA[k_l_reg]), regB, regC[k_l_reg][npq_l_vec_reg]);
+                }
+            }
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+
+    for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
+        const uint k_g = offset_k + lid_k * TS_K + k_l_reg;
+        if (k_g >= K) continue;
+
+        for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) {
+            const uint npq_g_base = offset_npq + (lid_npq * TS_NPQ_VEC + npq_l_vec_reg) * VEC_SIZE;
+
+            const uint N_idx = npq_g_base / (OH * OW);
+            const uint pq_idx = npq_g_base % (OH * OW);
+            const uint OH_idx = pq_idx / OW;
+            const uint OW_idx = pq_idx % OW;
+
+            if (nb1 == OW && OW_idx + VEC_SIZE <= OW && npq_g_base + VEC_SIZE <= NPQ) {
+                const uint dst_idx = OW_idx + OH_idx*nb1 + k_g*nb2 + N_idx*nb3;
+                vstore4(regC[k_l_reg][npq_l_vec_reg], 0, &dst_data[dst_idx]);
+            } else {
+                T_ACCUM res = regC[k_l_reg][npq_l_vec_reg];
+                for (int v = 0; v < VEC_SIZE; ++v) {
+                    const uint npq_g = npq_g_base + v;
+                    if (npq_g < NPQ) {
+                        const uint N_idx_s = npq_g / (OH*OW);
+                        const uint pq_idx_s = npq_g % (OH*OW);
+                        const uint OH_idx_s = pq_idx_s / OW;
+                        const uint OW_idx_s = pq_idx_s % OW;
+                        const uint dst_idx_s = OW_idx_s + OH_idx_s*nb1 + k_g*nb2 + N_idx_s*nb3;
+                        dst_data[dst_idx_s] = ((float*)&res)[v];
+                    }
+                }
+            }
+        }
+    }
+}