]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
opencl: add `mul_mat_f32_f32_l4_lm` and `mul_mat_f16_f32_l4_lm` (llama/14809)
authorlhez <redacted>
Wed, 30 Jul 2025 21:56:55 +0000 (14:56 -0700)
committerGeorgi Gerganov <redacted>
Sat, 2 Aug 2025 14:51:21 +0000 (17:51 +0300)
src/ggml-opencl/CMakeLists.txt
src/ggml-opencl/ggml-opencl.cpp
src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl [new file with mode: 0644]
src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl [new file with mode: 0644]

index 015fa8f06824e52ea451d3ceda7ee174a435d57b..3adea83615437b8951e7602285a8c4198e5e2c9a 100644 (file)
@@ -82,6 +82,8 @@ set(GGML_OPENCL_KERNELS
     mul_mv_q4_0_f32_1d_16x_flat
     mul_mv_q6_k
     mul_mv_id_q4_0_f32_8x_flat
+    mul_mm_f32_f32_l4_lm
+    mul_mm_f16_f32_l4_lm
     mul
     norm
     relu
index c87a32383c8735ae49bd387f16ad16c82f5a243d..984d35a2ecf762b6fe110b4af3d914e286ebe9f1 100644 (file)
@@ -33,6 +33,7 @@
 #undef MAX
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
+#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
 
 #define UNUSED(x) (void)(x)
 
@@ -396,6 +397,8 @@ struct ggml_backend_opencl_context {
     cl_program program_conv_2d_f16_f32;
     cl_program program_tsembd;
     cl_program program_mul_mv_id_q4_0_f32_8x_flat;
+    cl_program program_mul_mm_f32_f32_l4_lm;
+    cl_program program_mul_mm_f16_f32_l4_lm;
 
     cl_kernel kernel_add, kernel_add_row;
     cl_kernel kernel_mul, kernel_mul_row;
@@ -450,6 +453,8 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_conv_2d_f16_f32;
     cl_kernel kernel_timestep_embedding;
     cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;
+    cl_kernel kernel_mul_mm_f32_f32_l4_lm;
+    cl_kernel kernel_mul_mm_f16_f32_l4_lm;
 
     std::vector<ProfilingInfo> profiling_info;
 
@@ -1040,6 +1045,38 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         GGML_LOG_CONT(".");
     }
 
+    // mul_mm_f32_f32_l4_lm
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "mul_mm_f32_f32_l4_lm.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("mul_mm_f32_f32_l4_lm.cl");
+#endif
+        backend_ctx->program_mul_mm_f32_f32_l4_lm =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_mul_mm_f32_f32_l4_lm = clCreateKernel(backend_ctx->program_mul_mm_f32_f32_l4_lm, "kernel_mul_mm_f32_f32_l4_lm", &err), err));
+        GGML_LOG_CONT(".");
+    }
+
+    // mul_mm_f16_f32_l4_lm
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "mul_mm_f16_f32_l4_lm.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("mul_mm_f16_f32_l4_lm.cl");
+#endif
+        backend_ctx->program_mul_mm_f16_f32_l4_lm =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_mul_mm_f16_f32_l4_lm = clCreateKernel(backend_ctx->program_mul_mm_f16_f32_l4_lm, "kernel_mul_mm_f16_f32_l4_lm", &err), err));
+        GGML_LOG_CONT(".");
+    }
+
     // mul
     {
 #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -5297,18 +5334,6 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
 
     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
 
-     if (src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32 &&
-        src0->ne[1] > 32 &&   // M > 32
-        src1->ne[1] > 32 &&   // N > 32
-        src0->ne[0] > 32 &&   // K > 32
-        src0->ne[2] == 1 && src0->ne[3] == 1 &&
-        src1->ne[2] == 1 && src1->ne[3] == 1 &&
-        ggml_is_contiguous(src0) && ggml_is_contiguous(src1) &&
-        backend_ctx->kernel_mul_mat_f16_f32_tiled != NULL) {
-        ggml_cl_mul_mat_f16_f32_tiled(backend, src0, src1, dst);
-        return;
-    }
-
     ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
     ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
     ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
@@ -5655,6 +5680,101 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
     } // if (ne01 && ne1)
 #endif // GGML_OPENCL_USE_ADRENO_KERNELS
 
+    // GEMM using local memory
+    // Current BK = 16, so ne00 % 16 == 0
+    if (ggml_is_contiguous(src0) &&
+        ggml_is_contiguous(src1) &&
+        src1t == GGML_TYPE_F32 &&
+        ne00 % 16 == 0 &&
+        ne11 > 1) {
+        switch(src0t) {
+            case GGML_TYPE_F32: {
+                kernel = backend_ctx->kernel_mul_mm_f32_f32_l4_lm;
+                nth0 = 128; // calculated as (BM*BN)/(TM*TN)
+
+                int batch_stride_a = ne00*ne01;
+                int batch_stride_b = ne10*ne11;
+                int batch_stride_d = ne0*ne1;
+
+                CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+                CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+                CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+                CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+                CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+                CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+                CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+                CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
+                CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
+                CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne11));
+                CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne12));
+                CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne10)); // stride_a
+                CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne10)); // stride_b
+                CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne01)); // stride_d
+                CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &batch_stride_a));
+                CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &batch_stride_b));
+                CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &batch_stride_d));
+                CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &r2));
+                CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &r3));
+
+                // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.
+                size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};
+                size_t local_work_size[] = {(size_t)nth0, 1, 1};
+
+                backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+                return;
+            }
+            case GGML_TYPE_F16: {
+                kernel = backend_ctx->kernel_mul_mm_f16_f32_l4_lm;
+                nth0 = 128; // calculated as (BM*BN)/(TM*TN)
+
+                int batch_stride_a = ne00*ne01;
+                int batch_stride_b = ne10*ne11;
+                int batch_stride_d = ne0*ne1;
+
+                CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+                CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+                CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+                CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+                CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+                CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+                CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+                CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
+                CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
+                CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne11));
+                CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne12));
+                CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne10)); // stride_a
+                CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne10)); // stride_b
+                CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne01)); // stride_d
+                CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &batch_stride_a));
+                CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &batch_stride_b));
+                CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &batch_stride_d));
+                CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &r2));
+                CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &r3));
+
+                // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.
+                size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};
+                size_t local_work_size[] = {(size_t)nth0, 1, 1};
+
+                backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+                return;
+            }
+            default:
+                break;
+        }
+    }
+
+    if (src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32 &&
+        src0->ne[1] > 32 &&   // M > 32
+        src1->ne[1] > 32 &&   // N > 32
+        src0->ne[0] > 32 &&   // K > 32
+        src0->ne[2] == 1 && src0->ne[3] == 1 &&
+        src1->ne[2] == 1 && src1->ne[3] == 1 &&
+        ggml_is_contiguous(src0) && ggml_is_contiguous(src1) &&
+        backend_ctx->kernel_mul_mat_f16_f32_tiled != NULL) {
+        ggml_cl_mul_mat_f16_f32_tiled(backend, src0, src1, dst);
+        return;
+    }
+
     if (!ggml_is_transposed(src0) &&
         !ggml_is_transposed(src1) &&
         src1t == GGML_TYPE_F32 &&
diff --git a/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl b/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl
new file mode 100644 (file)
index 0000000..9599a0e
--- /dev/null
@@ -0,0 +1,132 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+#define LOAD_VEC_A 4
+#define LOAD_VEC_B 4
+
+#define BM 64
+#define BN 64
+#define BK 16
+#define TM 4
+#define TN 8
+
+kernel void kernel_mul_mm_f16_f32_l4_lm(
+    global half4 * src0,
+    ulong offset0,
+    global float4 * src1,
+    ulong offset1,
+    global float * dst,
+    ulong offsetd,
+
+    int ne00,
+    int ne01,
+    int ne02,
+    int ne11,
+    int ne12,
+
+    int stride_a,
+    int stride_b,
+    int stride_d,
+
+    int batch_stride_a,
+    int batch_stride_b,
+    int batch_stride_d,
+
+    int r2,
+    int r3
+) {
+    src0 = (global half4*)((global char*)src0 + offset0);
+    src1 = (global float4*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    local half  buf_a[BM * BK];
+    local float buf_b[BN * BK];
+
+    const int batch_idx = get_global_id(2);
+
+    const int i13 = batch_idx / ne12;
+    const int i12 = batch_idx % ne12;
+
+    const int i03 = i13 / r3;
+    const int i02 = i12 / r2;
+
+    const int batch_idx_a = i03 * ne02 + i02;
+
+    const int ir = get_group_id(0);
+    const int ic = get_group_id(1);
+
+    const int tid = get_local_id(0);
+    const int th_r  = tid % (BM / TM);
+    const int th_c  = tid / (BM / TM);
+
+    const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
+    const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
+    const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
+    const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
+
+    const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
+    const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
+
+    int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
+    int pos_b = (batch_idx   * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
+
+    float sums[TM * TN];
+    half  cache_a[TM];
+    float cache_b[TN];
+
+    for (int i = 0; i < TM * TN; i++) {
+        sums[i] = 0.0f;
+    }
+
+    for (int block = 0; block < ne00; block += BK) {
+        for (int l = 0; l < BM; l += loadstride_a) {
+            const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
+            buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
+            buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
+            buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2;
+            buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3;
+        }
+
+        for (int l = 0; l < BN; l += loadstride_b) {
+            const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
+            buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
+            buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
+            buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
+            buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
+        }
+
+        barrier(CLK_LOCAL_MEM_FENCE);
+
+        pos_a += BK / LOAD_VEC_A;
+        pos_b += BK / LOAD_VEC_B;
+
+        for (int i = 0; i < BK; i++) {
+            for (int j = 0; j < TM; j++) {
+                cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
+            }
+            for (int j = 0; j < TN; j++) {
+                cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
+            }
+
+            for (int cc = 0; cc < TN; cc++) {
+                for (int cr = 0; cr < TM; cr++) {
+                    const int sums_idx = cc*TM + cr;
+                    sums[sums_idx] = mad(convert_float(cache_a[cr]), cache_b[cc], sums[sums_idx]);
+                }
+            }
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+
+    const int dr = ir * BM + th_r * TM;
+    const int dc = ic * BN + th_c * TN;
+
+    const int offsets = batch_idx * batch_stride_d;
+
+    for (int cc = 0; cc < TN; cc++) {
+        for (int cr = 0; cr < TM; cr++) {
+            if (dr + cr < ne01 && dc + cc < ne11) {
+                dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
+            }
+        }
+    }
+}
diff --git a/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl b/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl
new file mode 100644 (file)
index 0000000..58c5178
--- /dev/null
@@ -0,0 +1,133 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+#define LOAD_VEC_A 4
+#define LOAD_VEC_B 4
+
+#define BM 64
+#define BN 64
+#define BK 16
+#define TM 4
+#define TN 8
+
+kernel void kernel_mul_mm_f32_f32_l4_lm(
+    global float4 * src0,
+    ulong offset0,
+    global float4 * src1,
+    ulong offset1,
+    global float * dst,
+    ulong offsetd,
+
+    int ne00,
+    int ne01,
+    int ne02,
+    int ne11,
+    int ne12,
+
+    int stride_a,
+    int stride_b,
+    int stride_d,
+
+    int batch_stride_a,
+    int batch_stride_b,
+    int batch_stride_d,
+
+    int r2,
+    int r3
+) {
+    src0 = (global float4*)((global char*)src0 + offset0);
+    src1 = (global float4*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    local float buf_a[BM * BK];
+    local float buf_b[BN * BK];
+
+    const int batch_idx = get_global_id(2);
+
+    const int i13 = batch_idx / ne12;
+    const int i12 = batch_idx % ne12;
+
+    const int i03 = i13 / r3;
+    const int i02 = i12 / r2;
+
+    const int batch_idx_a = i03 * ne02 + i02;
+
+    const int ir = get_group_id(0);
+    const int ic = get_group_id(1);
+
+    const int tid = get_local_id(0);
+    const int th_r  = tid % (BM / TM);
+    const int th_c  = tid / (BM / TM);
+
+    const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
+    const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
+    const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
+    const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
+
+    const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
+    const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
+
+    int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
+    int pos_b = (batch_idx   * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
+
+    float sums[TM * TN];
+    float cache_a[TM];
+    float cache_b[TN];
+
+    for (int i = 0; i < TM * TN; i++) {
+        sums[i] = 0.0f;
+    }
+
+    for (int block = 0; block < ne00; block += BK) {
+        for (int l = 0; l < BM; l += loadstride_a) {
+            const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
+            buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
+            buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
+            buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2;
+            buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3;
+        }
+
+        for (int l = 0; l < BN; l += loadstride_b) {
+            const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
+            buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
+            buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
+            buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
+            buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
+        }
+
+        barrier(CLK_LOCAL_MEM_FENCE);
+
+        pos_a += BK / LOAD_VEC_A;
+        pos_b += BK / LOAD_VEC_B;
+
+        for (int i = 0; i < BK; i++) {
+            for (int j = 0; j < TM; j++) {
+                cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
+            }
+
+            for (int j = 0; j < TN; j++) {
+                cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
+            }
+
+            for (int cc = 0; cc < TN; cc++) {
+                for (int cr = 0; cr < TM; cr++) {
+                    const int sums_idx = cc*TM + cr;
+                    sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
+                }
+            }
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+
+    const int dr = ir * BM + th_r * TM;
+    const int dc = ic * BN + th_c * TN;
+
+    const int offsets = batch_idx * batch_stride_d;
+
+    for (int cc = 0; cc < TN; cc++) {
+        for (int cr = 0; cr < TM; cr++) {
+            if (dr + cr < ne01 && dc + cc < ne11) {
+                dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
+            }
+        }
+    }
+}