cl_kernel kernel_scale_f32, kernel_scale_f32_4;
cl_kernel kernel_sqr_cont_f32, kernel_sqr_cont_f32_4, kernel_sqr_cont_f16, kernel_sqr_cont_f16_4;
cl_kernel kernel_sqrt_cont_f32, kernel_sqrt_cont_f32_4, kernel_sqrt_cont_f16, kernel_sqrt_cont_f16_4;
- cl_kernel kernel_mean_f32;
+ cl_kernel kernel_mean_f32, kernel_mean_f32_4;
cl_kernel kernel_silu, kernel_silu_4;
cl_kernel kernel_gelu, kernel_gelu_4;
cl_kernel kernel_gelu_erf, kernel_gelu_erf_4;
cl_kernel kernel_solve_tri_f32;
cl_kernel kernel_im2col_f32, kernel_im2col_f16;
cl_kernel kernel_argsort_f32_i32;
- cl_kernel kernel_sum_rows_f32;
+ cl_kernel kernel_sum_rows_f32, kernel_sum_rows_f32_4;
cl_kernel kernel_repeat_f32;
cl_kernel kernel_pad;
cl_kernel kernel_tanh_f32, kernel_tanh_f32_4, kernel_tanh_f32_nc;
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_mean_f32 = clCreateKernel(prog, "kernel_mean_f32", &err), err));
+ CL_CHECK((backend_ctx->kernel_mean_f32_4 = clCreateKernel(prog, "kernel_mean_f32_4", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_sum_rows_f32 = clCreateKernel(backend_ctx->program_sum_rows_f32, "kernel_sum_rows_f32", &err), err));
+ CL_CHECK((backend_ctx->kernel_sum_rows_f32_4 = clCreateKernel(backend_ctx->program_sum_rows_f32, "kernel_sum_rows_f32_4", &err), err));
GGML_LOG_CONT(".");
}
}
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
- return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
+ return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_FLASH_ATTN_EXT:
{
const ggml_tensor * q = op->src[0];
GGML_UNUSED(src1);
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
- GGML_ASSERT(ggml_is_contiguous(src0));
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
const cl_ulong nb2 = dst->nb[2];
const cl_ulong nb3 = dst->nb[3];
- cl_kernel kernel = backend_ctx->kernel_mean_f32;
+ cl_kernel kernel;
+
+ const bool is_c4 = ne00 % 4 == 0;
+ if (is_c4) {
+ kernel = backend_ctx->kernel_mean_f32_4;
+ } else {
+ kernel = backend_ctx->kernel_mean_f32;
+ }
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb2));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb3));
- size_t global_work_size[] = {(size_t)ne01, (size_t)ne02, (size_t)ne03};
+ size_t global_work_size[] = {64 * (size_t)ne01, (size_t)ne02, (size_t)ne03};
size_t local_work_size[] = {(size_t)64, 1, 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
GGML_UNUSED(src1);
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
- GGML_ASSERT(ggml_is_contiguous(src0));
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
const cl_ulong nb2 = dst->nb[2];
const cl_ulong nb3 = dst->nb[3];
- cl_kernel kernel = backend_ctx->kernel_sum_rows_f32;
+ cl_kernel kernel;
+
+ const bool is_c4 = ne00 % 4 == 0;
+ if (is_c4) {
+ kernel = backend_ctx->kernel_sum_rows_f32_4;
+ } else {
+ kernel = backend_ctx->kernel_sum_rows_f32;
+ }
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb2));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb3));
- size_t global_work_size[] = {(size_t)ne01, (size_t)ne02, (size_t)ne03};
+ size_t global_work_size[] = {64 * (size_t)ne01, (size_t)ne02, (size_t)ne03};
size_t local_work_size[] = {(size_t)64, 1, 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#pragma OPENCL EXTENSION cl_khr_subgroups : enable
+// Most devices have a max workgroup size of 1024, so this is enough for subgroup
+// sizes of 16, 32, 64 and 128. Increase this value for smaller subgroup sizes.
+#define MAX_SUBGROUPS 64
kernel void kernel_mean_f32(
- global float * src0,
+ global char * src0,
ulong offset0,
- global float * dst,
+ global char * dst,
ulong offsetd,
int ne00,
int ne01,
ulong nb2,
ulong nb3
) {
- src0 = (global float *)((global char *)src0 + offset0);
- dst = (global float *)((global char *)dst + offsetd);
+ src0 = src0 + offset0;
+ dst = dst + offsetd;
- int i3 = get_global_id(2);
- int i2 = get_global_id(1);
- int i1 = get_global_id(0);
+ const int i3 = get_group_id(2);
+ const int i2 = get_group_id(1);
+ const int i1 = get_group_id(0);
+
+ const int lid = get_local_id(0);
+ const int lsize = get_local_size(0);
+
+ const uint sg_size = get_sub_group_size();
+ const uint sg_id = get_sub_group_id();
+ const uint sg_lid = get_sub_group_local_id();
+
+ __local float lmem[MAX_SUBGROUPS];
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
return;
}
- global float * src_row = (global float *) ((global char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
- global float * dst_row = (global float *) ((global char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
+ if(sg_id == 0){
+ lmem[sg_lid] = 0.0f;
+ }
- float row_sum = 0;
+ global float * src_row = (global float *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);
+ global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3);
- for (int i0 = 0; i0 < ne00; i0++) {
- row_sum += src_row[i0];
+ float sumf = 0.0f;
+
+ for (int i0 = lid; i0 < ne00; i0 += lsize) {
+ sumf += src_row[i0];
}
- dst_row[0] = row_sum / ne00;
+ sumf = sub_group_reduce_add(sumf);
+
+ barrier(CLK_LOCAL_MEM_FENCE);
+
+ if(sg_lid == 0){
+ lmem[sg_id] = sumf;
+ }
+
+ barrier(CLK_LOCAL_MEM_FENCE);
+
+ sumf = lmem[sg_lid];
+ sumf = sub_group_reduce_add(sumf);
+
+ if (lid == 0) {
+ dst_row[0] = sumf / ne00;
+ }
+}
+
+kernel void kernel_mean_f32_4( // Row mean with the row read as float4: each work-group reduces one row and stores sum / ne00
+ global char * src0, // source buffer, addressed in bytes
+ ulong offset0, // byte offset added to src0 before use
+ global char * dst, // destination buffer, addressed in bytes
+ ulong offsetd, // byte offset added to dst before use
+ int ne00, // row length in floats; only ne00 / 4 float4 elements are read, so a remainder would be skipped
+ int ne01, // upper bound for the row index i1
+ int ne02, // upper bound for i2
+ int ne03, // upper bound for i3
+ ulong nb01, // byte stride applied to i1 when locating the source row
+ ulong nb02, // byte stride applied to i2 when locating the source row
+ ulong nb03, // byte stride applied to i3 when locating the source row
+ ulong nb1, // byte stride applied to i1 when locating the output element
+ ulong nb2, // byte stride applied to i2 when locating the output element
+ ulong nb3 // byte stride applied to i3 when locating the output element
+) {
+ src0 = src0 + offset0;
+ dst = dst + offsetd;
+
+ const int i3 = get_group_id(2); // the work-group id (not the global id) selects the row
+ const int i2 = get_group_id(1);
+ const int i1 = get_group_id(0);
+
+ const int lid = get_local_id(0);
+ const int lsize = get_local_size(0);
+
+ const uint sg_size = get_sub_group_size(); // not read anywhere in this kernel
+ const uint sg_id = get_sub_group_id();
+ const uint sg_lid = get_sub_group_local_id();
+
+ __local float lmem[MAX_SUBGROUPS]; // holds one partial sum per sub-group
+
+ if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { // depends only on group ids, so every work-item of a group takes the same branch
+ return;
+ }
+
+ if(sg_id == 0){ // first sub-group clears one slot per lane so slots never written below read as 0
+ lmem[sg_lid] = 0.0f; // NOTE(review): indexed by lane id, so this assumes sub-group size <= MAX_SUBGROUPS - confirm
+ }
+
+ global float4 * src_row = (global float4 *) (src0 + i1*nb01 + i2*nb02 + i3*nb03); // NOTE(review): float4 access presumably needs a 16-byte aligned row start - confirm for offset/strided inputs
+ global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3);
+
+ float4 sum_vec = (float4)0.0f;
+
+ for (int i0 = lid; i0 < ne00 / 4; i0 += lsize) { // each work-item starts at its local id and steps by the local size
+ sum_vec += src_row[i0];
+ }
+
+ float sumf = dot(sum_vec, (float4)(1.0f)); // dot with all-ones adds the four lanes together
+ sumf = sub_group_reduce_add(sumf); // sum of the per-work-item values within this sub-group
+
+ barrier(CLK_LOCAL_MEM_FENCE); // orders the zero-fill above before the partial-sum stores below
+
+ if(sg_lid == 0){ // one work-item per sub-group publishes that sub-group's partial sum
+ lmem[sg_id] = sumf;
+ }
+
+ barrier(CLK_LOCAL_MEM_FENCE); // all partial sums must be stored before any work-item reads them
+
+ sumf = lmem[sg_lid]; // NOTE(review): only slots below the sub-group size are read; presumably assumes the number of sub-groups does not exceed it - confirm
+ sumf = sub_group_reduce_add(sumf); // adds the loaded slots to form the row total
+
+ if (lid == 0) { // a single work-item writes the result for the row
+ dst_row[0] = sumf / ne00;
+ }
}
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#pragma OPENCL EXTENSION cl_khr_subgroups : enable
+// Most devices have a max workgroup size of 1024, so this is enough for subgroup
+// sizes of 16, 32, 64 and 128. Increase this value for smaller subgroup sizes.
+#define MAX_SUBGROUPS 64
kernel void kernel_sum_rows_f32(
- global float * src0,
+ global char * src0,
ulong offset0,
- global float * dst,
+ global char * dst,
ulong offsetd,
int ne00,
int ne01,
ulong nb2,
ulong nb3
) {
- src0 = (global float *)((global char *)src0 + offset0);
- dst = (global float *)((global char *)dst + offsetd);
+ src0 = src0 + offset0;
+ dst = dst + offsetd;
- int i3 = get_global_id(2);
- int i2 = get_global_id(1);
- int i1 = get_global_id(0);
+ const int i3 = get_group_id(2);
+ const int i2 = get_group_id(1);
+ const int i1 = get_group_id(0);
+
+ const int lid = get_local_id(0);
+ const int lsize = get_local_size(0);
+
+ const uint sg_size = get_sub_group_size();
+ const uint sg_id = get_sub_group_id();
+ const uint sg_lid = get_sub_group_local_id();
+
+ __local float lmem[MAX_SUBGROUPS];
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
return;
}
- global float * src_row = (global float *) ((global char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
- global float * dst_row = (global float *) ((global char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
+ if(sg_id == 0){
+ lmem[sg_lid] = 0.0f;
+ }
- float row_sum = 0;
+ global float * src_row = (global float *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);
+ global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3);
- for (int i0 = 0; i0 < ne00; i0++) {
- row_sum += src_row[i0];
+ float sumf = 0.0f;
+
+ for (int i0 = lid; i0 < ne00; i0 += lsize) {
+ sumf += src_row[i0];
}
- dst_row[0] = row_sum;
+ sumf = sub_group_reduce_add(sumf);
+
+ barrier(CLK_LOCAL_MEM_FENCE);
+
+ if(sg_lid == 0){
+ lmem[sg_id] = sumf;
+ }
+
+ barrier(CLK_LOCAL_MEM_FENCE);
+
+ sumf = lmem[sg_lid];
+ sumf = sub_group_reduce_add(sumf);
+
+ if (lid == 0) {
+ dst_row[0] = sumf;
+ }
+}
+
+kernel void kernel_sum_rows_f32_4( // Row sum with the row read as float4: each work-group reduces one row and stores the total
+ global char * src0, // source buffer, addressed in bytes
+ ulong offset0, // byte offset added to src0 before use
+ global char * dst, // destination buffer, addressed in bytes
+ ulong offsetd, // byte offset added to dst before use
+ int ne00, // row length in floats; only ne00 / 4 float4 elements are read, so a remainder would be skipped
+ int ne01, // upper bound for the row index i1
+ int ne02, // upper bound for i2
+ int ne03, // upper bound for i3
+ ulong nb01, // byte stride applied to i1 when locating the source row
+ ulong nb02, // byte stride applied to i2 when locating the source row
+ ulong nb03, // byte stride applied to i3 when locating the source row
+ ulong nb1, // byte stride applied to i1 when locating the output element
+ ulong nb2, // byte stride applied to i2 when locating the output element
+ ulong nb3 // byte stride applied to i3 when locating the output element
+) {
+ src0 = src0 + offset0;
+ dst = dst + offsetd;
+
+ const int i3 = get_group_id(2); // the work-group id (not the global id) selects the row
+ const int i2 = get_group_id(1);
+ const int i1 = get_group_id(0);
+
+ const int lid = get_local_id(0);
+ const int lsize = get_local_size(0);
+
+ const uint sg_size = get_sub_group_size(); // not read anywhere in this kernel
+ const uint sg_id = get_sub_group_id();
+ const uint sg_lid = get_sub_group_local_id();
+
+ __local float lmem[MAX_SUBGROUPS]; // holds one partial sum per sub-group
+
+ if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { // depends only on group ids, so every work-item of a group takes the same branch
+ return;
+ }
+
+ if(sg_id == 0){ // first sub-group clears one slot per lane so slots never written below read as 0
+ lmem[sg_lid] = 0.0f; // NOTE(review): indexed by lane id, so this assumes sub-group size <= MAX_SUBGROUPS - confirm
+ }
+
+ global float4 * src_row = (global float4 *) (src0 + i1*nb01 + i2*nb02 + i3*nb03); // NOTE(review): float4 access presumably needs a 16-byte aligned row start - confirm for offset/strided inputs
+ global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3);
+
+ float4 sum_vec = (float4)0.0f;
+
+ for (int i0 = lid; i0 < ne00 / 4; i0 += lsize) { // each work-item starts at its local id and steps by the local size
+ sum_vec += src_row[i0];
+ }
+
+ float sumf = dot(sum_vec, (float4)(1.0f)); // dot with all-ones adds the four lanes together
+ sumf = sub_group_reduce_add(sumf); // sum of the per-work-item values within this sub-group
+
+ barrier(CLK_LOCAL_MEM_FENCE); // orders the zero-fill above before the partial-sum stores below
+
+ if(sg_lid == 0){ // one work-item per sub-group publishes that sub-group's partial sum
+ lmem[sg_id] = sumf;
+ }
+
+ barrier(CLK_LOCAL_MEM_FENCE); // all partial sums must be stored before any work-item reads them
+
+ sumf = lmem[sg_lid]; // NOTE(review): only slots below the sub-group size are read; presumably assumes the number of sub-groups does not exceed it - confirm
+ sumf = sub_group_reduce_add(sumf); // adds the loaded slots to form the row total
+
+ if (lid == 0) { // a single work-item writes the result for the row
+ dst_row[0] = sumf;
+ }
}