]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
opencl: optimize mean and sum_row kernels (llama/19614)
authorshaofeiqi <redacted>
Tue, 17 Feb 2026 21:56:09 +0000 (13:56 -0800)
committerGeorgi Gerganov <redacted>
Wed, 25 Feb 2026 10:32:13 +0000 (12:32 +0200)
* opencl: optimize mean and sum_row kernels

* opencl: add comment for max subgroups

* opencl: format

---------

Co-authored-by: Li He <redacted>
src/ggml-opencl/ggml-opencl.cpp
src/ggml-opencl/kernels/mean.cl
src/ggml-opencl/kernels/sum_rows.cl

index ae3f79fd0d60cad424adcf491a970951264bb802..3dd12e177f3d9e7723ddaa223f5251ff4a9378bb 100644 (file)
@@ -484,7 +484,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_scale_f32, kernel_scale_f32_4;
     cl_kernel kernel_sqr_cont_f32, kernel_sqr_cont_f32_4, kernel_sqr_cont_f16, kernel_sqr_cont_f16_4;
     cl_kernel kernel_sqrt_cont_f32, kernel_sqrt_cont_f32_4, kernel_sqrt_cont_f16, kernel_sqrt_cont_f16_4;
-    cl_kernel kernel_mean_f32;
+    cl_kernel kernel_mean_f32, kernel_mean_f32_4;
     cl_kernel kernel_silu, kernel_silu_4;
     cl_kernel kernel_gelu, kernel_gelu_4;
     cl_kernel kernel_gelu_erf, kernel_gelu_erf_4;
@@ -543,7 +543,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_solve_tri_f32;
     cl_kernel kernel_im2col_f32, kernel_im2col_f16;
     cl_kernel kernel_argsort_f32_i32;
-    cl_kernel kernel_sum_rows_f32;
+    cl_kernel kernel_sum_rows_f32, kernel_sum_rows_f32_4;
     cl_kernel kernel_repeat_f32;
     cl_kernel kernel_pad;
     cl_kernel kernel_tanh_f32, kernel_tanh_f32_4, kernel_tanh_f32_nc;
@@ -1837,6 +1837,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
             build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
 
         CL_CHECK((backend_ctx->kernel_mean_f32 = clCreateKernel(prog, "kernel_mean_f32", &err), err));
+        CL_CHECK((backend_ctx->kernel_mean_f32_4 = clCreateKernel(prog, "kernel_mean_f32_4", &err), err));
 
         CL_CHECK(clReleaseProgram(prog));
         GGML_LOG_CONT(".");
@@ -1874,6 +1875,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
             build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
 
         CL_CHECK((backend_ctx->kernel_sum_rows_f32 = clCreateKernel(backend_ctx->program_sum_rows_f32, "kernel_sum_rows_f32", &err), err));
+        CL_CHECK((backend_ctx->kernel_sum_rows_f32_4 = clCreateKernel(backend_ctx->program_sum_rows_f32, "kernel_sum_rows_f32_4", &err), err));
         GGML_LOG_CONT(".");
     }
 
@@ -3587,7 +3589,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
         }
         case GGML_OP_SUM_ROWS:
         case GGML_OP_MEAN:
-            return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
+            return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_FLASH_ATTN_EXT:
             {
                 const ggml_tensor * q = op->src[0];
@@ -6400,7 +6402,6 @@ static void ggml_cl_mean(ggml_backend_t backend, const ggml_tensor * src0, const
     GGML_UNUSED(src1);
 
     GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
-    GGML_ASSERT(ggml_is_contiguous(src0));
 
     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
 
@@ -6423,7 +6424,14 @@ static void ggml_cl_mean(ggml_backend_t backend, const ggml_tensor * src0, const
     const cl_ulong nb2  = dst->nb[2];
     const cl_ulong nb3  = dst->nb[3];
 
-    cl_kernel kernel = backend_ctx->kernel_mean_f32;
+    cl_kernel kernel;
+
+    const bool is_c4 = ne00 % 4 == 0;
+    if (is_c4) {
+        kernel = backend_ctx->kernel_mean_f32_4;
+    } else {
+        kernel = backend_ctx->kernel_mean_f32;
+    }
 
     CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
     CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
@@ -6440,7 +6448,7 @@ static void ggml_cl_mean(ggml_backend_t backend, const ggml_tensor * src0, const
     CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb2));
     CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb3));
 
-    size_t global_work_size[] = {(size_t)ne01, (size_t)ne02, (size_t)ne03};
+    size_t global_work_size[] = {64 * (size_t)ne01, (size_t)ne02, (size_t)ne03};
     size_t local_work_size[] = {(size_t)64, 1, 1};
 
     backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
@@ -11088,7 +11096,6 @@ static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, c
     GGML_UNUSED(src1);
 
     GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
-    GGML_ASSERT(ggml_is_contiguous(src0));
 
     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
 
@@ -11111,7 +11118,14 @@ static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, c
     const cl_ulong nb2  = dst->nb[2];
     const cl_ulong nb3  = dst->nb[3];
 
-    cl_kernel kernel = backend_ctx->kernel_sum_rows_f32;
+    cl_kernel kernel;
+
+    const bool is_c4 = ne00 % 4 == 0;
+    if (is_c4) {
+        kernel = backend_ctx->kernel_sum_rows_f32_4;
+    } else {
+        kernel = backend_ctx->kernel_sum_rows_f32;
+    }
 
     CL_CHECK(clSetKernelArg(kernel,   0, sizeof(cl_mem),   &extra0->data_device));
     CL_CHECK(clSetKernelArg(kernel,   1, sizeof(cl_ulong), &offset0));
@@ -11128,7 +11142,7 @@ static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, c
     CL_CHECK(clSetKernelArg(kernel,  12, sizeof(cl_ulong), &nb2));
     CL_CHECK(clSetKernelArg(kernel,  13, sizeof(cl_ulong), &nb3));
 
-    size_t global_work_size[] = {(size_t)ne01, (size_t)ne02, (size_t)ne03};
+    size_t global_work_size[] = {64 * (size_t)ne01, (size_t)ne02, (size_t)ne03};
     size_t local_work_size[] = {(size_t)64, 1, 1};
 
     backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
index 5c3e8bcd863165c9b79bc268ab6c86f68aa10543..7c7e0a587eed874abd129faf75914c147080ac01 100644 (file)
@@ -1,8 +1,13 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#pragma OPENCL EXTENSION cl_khr_subgroups : enable
 
+// Most devices have max workgroup size of 1024, so this is enough for subgroup
+// sizes of 16, 32, 64 and 128. Increase this value for smaller subgroups sizes
+#define MAX_SUBGROUPS 64
 kernel void kernel_mean_f32(
-    global float *  src0,
+    global char *  src0,
     ulong           offset0,
-    global float *  dst,
+    global char *  dst,
     ulong           offsetd,
     int             ne00,
     int             ne01,
@@ -15,25 +20,121 @@ kernel void kernel_mean_f32(
     ulong           nb2,
     ulong           nb3
 ) {
-    src0 = (global float *)((global char *)src0 + offset0);
-    dst  = (global float *)((global char *)dst  + offsetd);
+    src0 = src0 + offset0;
+    dst  = dst  + offsetd;
 
-    int i3 = get_global_id(2);
-    int i2 = get_global_id(1);
-    int i1 = get_global_id(0);
+    const int i3 = get_group_id(2);
+    const int i2 = get_group_id(1);
+    const int i1 = get_group_id(0);
+
+    const int lid = get_local_id(0);
+    const int lsize = get_local_size(0);
+
+    const uint sg_size = get_sub_group_size();
+    const uint sg_id = get_sub_group_id();
+    const uint sg_lid = get_sub_group_local_id();
+
+    __local float lmem[MAX_SUBGROUPS];
 
     if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
         return;
     }
 
-    global float * src_row = (global float *) ((global char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
-    global float * dst_row = (global float *) ((global char *) dst  + i1*nb1  + i2*nb2  + i3*nb3);
+    if(sg_id == 0){
+        lmem[sg_lid] = 0.0f;
+    }
 
-    float row_sum = 0;
+    global float * src_row = (global float *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);
+    global float * dst_row = (global float *) (dst  + i1*nb1  + i2*nb2  + i3*nb3);
 
-    for (int i0 = 0; i0 < ne00; i0++) {
-        row_sum += src_row[i0];
+    float sumf = 0.0f;
+
+    for (int i0 = lid; i0 < ne00; i0 += lsize) {
+        sumf += src_row[i0];
     }
 
-    dst_row[0] = row_sum / ne00;
+    sumf = sub_group_reduce_add(sumf);
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    if(sg_lid == 0){
+        lmem[sg_id] = sumf;
+    }
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    sumf = lmem[sg_lid];
+    sumf = sub_group_reduce_add(sumf);
+
+    if (lid == 0) {
+        dst_row[0] = sumf / ne00;
+    }
+}
+
+kernel void kernel_mean_f32_4(
+    global char *  src0,
+    ulong           offset0,
+    global char *  dst,
+    ulong           offsetd,
+    int             ne00,
+    int             ne01,
+    int             ne02,
+    int             ne03,
+    ulong           nb01,
+    ulong           nb02,
+    ulong           nb03,
+    ulong           nb1,
+    ulong           nb2,
+    ulong           nb3
+) {
+    src0 = src0 + offset0;
+    dst  = dst  + offsetd;
+
+    const int i3 = get_group_id(2);
+    const int i2 = get_group_id(1);
+    const int i1 = get_group_id(0);
+
+    const int lid = get_local_id(0);
+    const int lsize = get_local_size(0);
+
+    const uint sg_size = get_sub_group_size();
+    const uint sg_id = get_sub_group_id();
+    const uint sg_lid = get_sub_group_local_id();
+
+    __local float lmem[MAX_SUBGROUPS];
+
+    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+        return;
+    }
+
+    if(sg_id == 0){
+        lmem[sg_lid] = 0.0f;
+    }
+
+    global float4 * src_row = (global float4 *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);
+    global float  * dst_row = (global float  *) (dst  + i1*nb1  + i2*nb2  + i3*nb3);
+
+    float4 sum_vec = (float4)0.0f;
+
+    for (int i0 = lid; i0 < ne00 / 4; i0 += lsize) {
+        sum_vec += src_row[i0];
+    }
+
+    float sumf = dot(sum_vec, (float4)(1.0f));
+    sumf = sub_group_reduce_add(sumf);
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    if(sg_lid == 0){
+        lmem[sg_id] = sumf;
+    }
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    sumf = lmem[sg_lid];
+    sumf = sub_group_reduce_add(sumf);
+
+    if (lid == 0) {
+        dst_row[0] = sumf / ne00;
+    }
 }
index c5f7c570f9514ba4dd9716d42633175ff44ee086..84630aa8a303b8b34b73ea90833e799f83e3a9d7 100644 (file)
@@ -1,8 +1,13 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#pragma OPENCL EXTENSION cl_khr_subgroups : enable
 
+// Most devices have max workgroup size of 1024, so this is enough for subgroup
+// sizes of 16, 32, 64 and 128. Increase this value for smaller subgroups sizes
+#define MAX_SUBGROUPS 64
 kernel void kernel_sum_rows_f32(
-    global float *  src0,
+    global char *  src0,
     ulong           offset0,
-    global float *  dst,
+    global char *  dst,
     ulong           offsetd,
     int             ne00,
     int             ne01,
@@ -15,25 +20,121 @@ kernel void kernel_sum_rows_f32(
     ulong           nb2,
     ulong           nb3
 ) {
-    src0 = (global float *)((global char *)src0 + offset0);
-    dst  = (global float *)((global char *)dst  + offsetd);
+    src0 = src0 + offset0;
+    dst  = dst  + offsetd;
 
-    int i3 = get_global_id(2);
-    int i2 = get_global_id(1);
-    int i1 = get_global_id(0);
+    const int i3 = get_group_id(2);
+    const int i2 = get_group_id(1);
+    const int i1 = get_group_id(0);
+
+    const int lid = get_local_id(0);
+    const int lsize = get_local_size(0);
+
+    const uint sg_size = get_sub_group_size();
+    const uint sg_id = get_sub_group_id();
+    const uint sg_lid = get_sub_group_local_id();
+
+    __local float lmem[MAX_SUBGROUPS];
 
     if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
         return;
     }
 
-    global float * src_row = (global float *) ((global char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
-    global float * dst_row = (global float *) ((global char *) dst  + i1*nb1  + i2*nb2  + i3*nb3);
+    if(sg_id == 0){
+        lmem[sg_lid] = 0.0f;
+    }
 
-    float row_sum = 0;
+    global float * src_row = (global float *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);
+    global float * dst_row = (global float *) (dst  + i1*nb1  + i2*nb2  + i3*nb3);
 
-    for (int i0 = 0; i0 < ne00; i0++) {
-        row_sum += src_row[i0];
+    float sumf = 0.0f;
+
+    for (int i0 = lid; i0 < ne00; i0 += lsize) {
+        sumf += src_row[i0];
     }
 
-    dst_row[0] = row_sum;
+    sumf = sub_group_reduce_add(sumf);
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    if(sg_lid == 0){
+        lmem[sg_id] = sumf;
+    }
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    sumf = lmem[sg_lid];
+    sumf = sub_group_reduce_add(sumf);
+
+    if (lid == 0) {
+        dst_row[0] = sumf;
+    }
+}
+
+kernel void kernel_sum_rows_f32_4(
+    global char *  src0,
+    ulong           offset0,
+    global char *  dst,
+    ulong           offsetd,
+    int             ne00,
+    int             ne01,
+    int             ne02,
+    int             ne03,
+    ulong           nb01,
+    ulong           nb02,
+    ulong           nb03,
+    ulong           nb1,
+    ulong           nb2,
+    ulong           nb3
+) {
+    src0 = src0 + offset0;
+    dst  = dst  + offsetd;
+
+    const int i3 = get_group_id(2);
+    const int i2 = get_group_id(1);
+    const int i1 = get_group_id(0);
+
+    const int lid = get_local_id(0);
+    const int lsize = get_local_size(0);
+
+    const uint sg_size = get_sub_group_size();
+    const uint sg_id = get_sub_group_id();
+    const uint sg_lid = get_sub_group_local_id();
+
+    __local float lmem[MAX_SUBGROUPS];
+
+    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+        return;
+    }
+
+    if(sg_id == 0){
+        lmem[sg_lid] = 0.0f;
+    }
+
+    global float4 * src_row = (global float4 *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);
+    global float  * dst_row = (global float  *) (dst  + i1*nb1  + i2*nb2  + i3*nb3);
+
+    float4 sum_vec = (float4)0.0f;
+
+    for (int i0 = lid; i0 < ne00 / 4; i0 += lsize) {
+        sum_vec += src_row[i0];
+    }
+
+    float sumf = dot(sum_vec, (float4)(1.0f));
+    sumf = sub_group_reduce_add(sumf);
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    if(sg_lid == 0){
+        lmem[sg_id] = sumf;
+    }
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    sumf = lmem[sg_lid];
+    sumf = sub_group_reduce_add(sumf);
+
+    if (lid == 0) {
+        dst_row[0] = sumf;
+    }
 }