]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
opencl: add tiled mul_mat_f16_f32 (llama/14535)
authorrmatif <redacted>
Thu, 10 Jul 2025 21:58:12 +0000 (23:58 +0200)
committerGeorgi Gerganov <redacted>
Sat, 12 Jul 2025 13:05:00 +0000 (16:05 +0300)
* add tiled mul_mat_f16_f32

* fix trailing whitespace

* add insightful comments

src/ggml-opencl/CMakeLists.txt
src/ggml-opencl/ggml-opencl.cpp
src/ggml-opencl/kernels/mul_mat_f16_f32.cl [new file with mode: 0644]

index 03e77650d7ee1d00dedd9510acfc331024a9cb6b..ec5d8cf59556b9a88020edbc665c3c3a5dc180b6 100644 (file)
@@ -104,6 +104,7 @@ set(GGML_OPENCL_KERNELS
     tanh
     pad
     repeat
+    mul_mat_f16_f32
 )
 
 foreach (K ${GGML_OPENCL_KERNELS})
index 91b66c3bd74217d4beee43bec1966b5fc32b68d4..58830b733a8af444463007d3cd5b4b58f157da32 100644 (file)
@@ -368,6 +368,7 @@ struct ggml_backend_opencl_context {
     cl_program program_mul_mv_f16_f32;
     cl_program program_mul_mv_f32_f32;
     cl_program program_mul;
+    cl_program program_mul_mat_f16_f32_tiled;
     cl_program program_div;
     cl_program program_sub;
     cl_program program_norm;
@@ -422,6 +423,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_mul_mat_f16_f32_1row;
     cl_kernel kernel_mul_mat_f16_f32;
     cl_kernel kernel_mul_mat_f16_f32_l4;
+    cl_kernel kernel_mul_mat_f16_f32_tiled;
     cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v;
     cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0;
     cl_kernel kernel_mul_mat_q4_0_f32_8x_flat;
@@ -1015,6 +1017,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         GGML_LOG_CONT(".");
     }
 
+    // mul_mat_f16_f32_tiled
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "mul_mat_f16_f32.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("mul_mat_f16_f32.cl");
+#endif
+        backend_ctx->program_mul_mat_f16_f32_tiled =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_tiled = clCreateKernel(backend_ctx->program_mul_mat_f16_f32_tiled, "mul_mat_f16_f32", &err), err));
+        GGML_LOG_CONT(".");
+    }
+
     // mul
     {
 #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -4927,6 +4945,58 @@ static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor
     backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, NULL, dst);
 }
 
+static void ggml_cl_mul_mat_f16_f32_tiled(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    const int M = src0->ne[1];
+    const int N = src1->ne[1];
+    const int K = src0->ne[0];
+
+    cl_kernel kernel = backend_ctx->kernel_mul_mat_f16_f32_tiled;
+
+    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(int),      &M));
+    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(int),      &N));
+    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int),      &K));
+    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem),   &extra1->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offset1));
+    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &offsetd));
+
+    // Tiling parameters. These need to be tuned for optimal performance.
+    // They must match the #defines in the kernel mul_mat_f16_f32.cl.
+    //
+    // OPWM / OPWN: Output tile size per Work-Group. A work-group computes a tile of size OPWM x OPWN.
+    // TPWM / TPWN: Threads per Work-group. This is the work-group size.
+    // OPTM / OPTN: Output elements per Thread. Each thread computes OPTM x OPTN elements.
+    //
+    // The following relationships must hold:
+    //   OPWM = TPWM * OPTM
+    //   OPWN = TPWN * OPTN
+    //
+    const int OPWM = 64;
+    const int OPWN = 64;
+    const int TPWM = 16;
+    const int TPWN = 8;
+
+    size_t local_work_size[2] = { TPWM, TPWN };
+    size_t global_work_size[2] = {
+        (size_t) ((M + OPWM - 1) / OPWM) * TPWM,
+        (size_t) ((N + OPWN - 1) / OPWN) * TPWN,
+    };
+
+    backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
+}
+
 static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0);
     GGML_ASSERT(src0->extra);
@@ -4940,6 +5010,18 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
 
     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
 
+     if (src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32 &&
+        src0->ne[1] > 32 &&   // M > 32
+        src1->ne[1] > 32 &&   // N > 32
+        src0->ne[0] > 32 &&   // K > 32
+        src0->ne[2] == 1 && src0->ne[3] == 1 &&
+        src1->ne[2] == 1 && src1->ne[3] == 1 &&
+        ggml_is_contiguous(src0) && ggml_is_contiguous(src1) &&
+        backend_ctx->kernel_mul_mat_f16_f32_tiled != NULL) {
+        ggml_cl_mul_mat_f16_f32_tiled(backend, src0, src1, dst);
+        return;
+    }
+
     ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
     ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
     ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
diff --git a/src/ggml-opencl/kernels/mul_mat_f16_f32.cl b/src/ggml-opencl/kernels/mul_mat_f16_f32.cl
new file mode 100644 (file)
index 0000000..73a8884
--- /dev/null
@@ -0,0 +1,130 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+#if defined(cl_qcom_reqd_sub_group_size)
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
+#else
+#define REQD_SUBGROUP_SIZE_128
+#endif
+
+#define OPWM 64
+#define OPWN 64
+#define CPWK 8
+#define OPTM 4
+#define OPTN 8
+
+#define WG_M (OPWM / OPTM)
+#define WG_N (OPWN / OPTN)
+#define VEC_K (CPWK / 4)
+
+REQD_SUBGROUP_SIZE_128
+__kernel void mul_mat_f16_f32(
+    const int M, const int N, const int K,
+    __global const void* A_void, ulong A_offset,
+    __global const void* B_void, ulong B_offset,
+    __global       void* C_void, ulong C_offset) {
+
+    __global const half*  A = (__global const half* )((__global const char*)A_void + A_offset);
+    __global const float* B = (__global const float*)((__global const char*)B_void + B_offset);
+    __global       float* C = (__global       float*)((__global       char*)C_void + C_offset);
+
+    const int lidm = get_local_id(0);
+    const int lidn = get_local_id(1);
+    const int lid = lidn * WG_M + lidm;
+
+    const int offsetM = get_group_id(0) * OPWM;
+    const int offsetN = get_group_id(1) * OPWN;
+
+    __local half4  Alocal[OPWM][VEC_K];
+    __local float4 Blocal[OPWN][VEC_K];
+
+    float sum[OPTM][OPTN];
+
+    for (int wm = 0; wm < OPTM; wm++) {
+        for (int wn = 0; wn < OPTN; wn++) {
+            sum[wm][wn] = 0.0f;
+        }
+    }
+
+    const int numTiles = (K + CPWK - 1) / CPWK;
+
+    const int load_row_a = lid % OPWM;
+    const int load_vec_k_a = lid / OPWM;
+    const int global_row_a = offsetM + load_row_a;
+
+    const int load_row_b = lid % OPWN;
+    const int load_vec_k_b = lid / OPWN;
+    const int global_row_b = offsetN + load_row_b;
+
+    for (int t = 0; t < numTiles; t++) {
+        const int k_start = t * CPWK;
+        const int k_vec_start_a = k_start + load_vec_k_a * 4;
+        const int k_vec_start_b = k_start + load_vec_k_b * 4;
+
+        if (global_row_a < M && k_vec_start_a < K) {
+            if (k_vec_start_a + 3 < K) {
+                Alocal[load_row_a][load_vec_k_a] = vload4(0, A + global_row_a * K + k_vec_start_a);
+            } else {
+                half4 tempA = (half4)(0.0h);
+                if (k_vec_start_a < K) tempA.s0 = A[global_row_a * K + k_vec_start_a];
+                if (k_vec_start_a + 1 < K) tempA.s1 = A[global_row_a * K + k_vec_start_a + 1];
+                if (k_vec_start_a + 2 < K) tempA.s2 = A[global_row_a * K + k_vec_start_a + 2];
+                Alocal[load_row_a][load_vec_k_a] = tempA;
+            }
+        } else {
+            Alocal[load_row_a][load_vec_k_a] = (half4)(0.0h);
+        }
+
+        if (global_row_b < N && k_vec_start_b < K) {
+            if (k_vec_start_b + 3 < K) {
+                Blocal[load_row_b][load_vec_k_b] = vload4(0, B + global_row_b * K + k_vec_start_b);
+            } else {
+                float4 tempB = (float4)(0.0f);
+                if (k_vec_start_b < K) tempB.s0 = B[global_row_b * K + k_vec_start_b];
+                if (k_vec_start_b + 1 < K) tempB.s1 = B[global_row_b * K + k_vec_start_b + 1];
+                if (k_vec_start_b + 2 < K) tempB.s2 = B[global_row_b * K + k_vec_start_b + 2];
+                Blocal[load_row_b][load_vec_k_b] = tempB;
+            }
+        } else {
+            Blocal[load_row_b][load_vec_k_b] = (float4)(0.0f);
+        }
+
+        barrier(CLK_LOCAL_MEM_FENCE);
+
+        #pragma unroll
+        for (int k_vec = 0; k_vec < VEC_K; k_vec++) {
+            float4 a_fvecs[OPTM];
+            int current_row_a = lidm;
+            for (int wm = 0; wm < OPTM; wm++) {
+                a_fvecs[wm] = convert_float4(Alocal[current_row_a][k_vec]);
+                current_row_a += WG_M;
+            }
+
+            float4 b_fvecs[OPTN];
+            int current_row_b = lidn;
+            for (int wn = 0; wn < OPTN; wn++) {
+                b_fvecs[wn] = Blocal[current_row_b][k_vec];
+                current_row_b += WG_N;
+            }
+
+            for (int wm = 0; wm < OPTM; wm++) {
+                for (int wn = 0; wn < OPTN; wn++) {
+                    sum[wm][wn] += dot(a_fvecs[wm], b_fvecs[wn]);
+                }
+            }
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+
+    for (int wm = 0; wm < OPTM; wm++) {
+        int globalRow = offsetM + lidm + wm * WG_M;
+        if (globalRow < M) {
+            for (int wn = 0; wn < OPTN; wn++) {
+                int globalCol = offsetN + lidn + wn * WG_N;
+                if (globalCol < N) {
+                    C[globalCol * M + globalRow] = sum[wm][wn];
+                }
+            }
+        }
+    }
+}