cl_kernel kernel_restore_block_q4_0_noshuffle;
cl_kernel kernel_convert_block_q6_K, kernel_restore_block_q6_K;
cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat;
+ cl_kernel kernel_mul_mv_q4_K_f32;
cl_kernel kernel_mul_mv_q6_K_f32;
cl_kernel kernel_mul_mv_q6_K_f32_flat;
cl_kernel kernel_mul_mv_mxfp4_f32, kernel_mul_mv_mxfp4_f32_flat;
cl_kernel kernel_mul_mm_f32_f32_l4_lm;
cl_kernel kernel_mul_mm_f16_f32_l4_lm;
cl_kernel kernel_mul_mm_q8_0_f32_l4_lm;
+ cl_kernel kernel_mul_mm_q6_k_f32_l4_lm;
std::vector<ProfilingInfo> profiling_info;
GGML_LOG_CONT(".");
}
+ // mul_mv_q4_k_f32
+ {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+ const std::string kernel_src {
+ #include "mul_mv_q4_k_f32.cl.h"
+ };
+#else
+ const std::string kernel_src = read_file("mul_mv_q4_k_f32.cl");
+#endif
+ cl_program prog =
+ build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+ CL_CHECK((backend_ctx->kernel_mul_mv_q4_K_f32 = clCreateKernel(prog, "kernel_mul_mv_q4_K_f32", &err), err));
+ CL_CHECK(clReleaseProgram(prog));
+ GGML_LOG_CONT(".");
+ }
+
// mul_mv_q6_k_f32
{
#ifdef GGML_OPENCL_EMBED_KERNELS
GGML_LOG_CONT(".");
}
+ // mul_mm_q6_k_f32_l4_lm
+ {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+ const std::string kernel_src {
+ #include "mul_mm_q6_k_f32_l4_lm.cl.h"
+ };
+#else
+ const std::string kernel_src = read_file("mul_mm_q6_k_f32_l4_lm.cl");
+#endif
+ cl_program prog =
+ build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+ CL_CHECK((backend_ctx->kernel_mul_mm_q6_k_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q6_k_f32_l4_lm", &err), err));
+ CL_CHECK(clReleaseProgram(prog));
+ GGML_LOG_CONT(".");
+ }
+
// mul_mm_f16_f32_kq_kqv
{
#ifdef GGML_OPENCL_EMBED_KERNELS
} else if (op->src[0]->type == GGML_TYPE_F32) {
return op->src[1]->type == GGML_TYPE_F32;
} else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_MXFP4 ||
+ op->src[0]->type == GGML_TYPE_Q4_K ||
op->src[0]->type == GGML_TYPE_Q6_K) {
return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
} else if (op->src[0]->type == GGML_TYPE_Q8_0) {
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
return;
}
+ case GGML_TYPE_Q6_K: {
+ if (ne11 < 32) {
+ break;
+ }
+ if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) {
+ break;
+ }
+
+ kernel = backend_ctx->kernel_mul_mm_q6_k_f32_l4_lm;
+ nth0 = 128; // calculated as (BM*BN)/(TM*TN)
+
+ int batch_stride_a = ne00*ne01;
+ int batch_stride_b = ne10*ne11;
+ int batch_stride_d = ne0*ne1;
+
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q6_K->ql));
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q6_K->qh));
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q6_K->s));
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q6_K->d));
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra1->data_device));
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset1));
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00));
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01));
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02));
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne11));
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12));
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); // stride_a
+ CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10)); // stride_b
+ CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne01)); // stride_d
+ CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_a));
+ CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &batch_stride_b));
+ CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &batch_stride_d));
+ CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &r2));
+ CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &r3));
+
+ // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.
+ size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};
+ size_t local_work_size[] = {(size_t)nth0, 1, 1};
+
+ backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+ return;
+ }
default:
break;
}
}
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
- case GGML_TYPE_Q4_K:
+ case GGML_TYPE_Q4_K: {
+ kernel = backend_ctx->kernel_mul_mv_q4_K_f32;
+
+ if (backend_ctx->gpu_family == INTEL) {
+ nth0 = 16;
+ nth1 = 1;
+ ndst = 4;
+ } else if (backend_ctx->gpu_family == ADRENO) {
+ nth0 = 64;
+ nth1 = 1;
+ ndst = 4;
+ } else {
+ GGML_ASSERT(false && "TODO: Unknown GPU");
+ }
+
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(int), &offset0));
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &offset1));
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &offsetd));
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01));
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02));
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03));
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12));
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));
+ CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13));
+ CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne0));
+ CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne1));
+ CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2));
+ CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3));
+ break;
+ }
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
#ifdef GGML_OPENCL_SOA_Q
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
} else if (src0t == GGML_TYPE_Q4_K) {
- GGML_ASSERT(false && "not implemented");
+ size_t global_work_size[] = {(size_t)(ne01+ndst*nth1-1)/(ndst*nth1)*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13};
+ size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1};
+
+ backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
} else if (src0t == GGML_TYPE_Q3_K) {
GGML_ASSERT(false && "not implemented");
} else if (src0t == GGML_TYPE_Q5_K) {
--- /dev/null
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+#define LOAD_VEC_A 2
+#define LOAD_VEC_B 4
+
+#define BM 64
+#define BN 64
+#define BK 32
+#define TM 4
+#define TN 8
+
+kernel void kernel_mul_mm_q6_k_f32_l4_lm(
+ global uchar * src0_ql,
+ global uchar * src0_qh,
+ global char * src0_s,
+ global half * src0_d,
+ global float4 * src1,
+ ulong offset1,
+ global float * dst,
+ ulong offsetd,
+
+ int ne00,
+ int ne01,
+ int ne02,
+ int ne11,
+ int ne12,
+
+ int stride_a,
+ int stride_b,
+ int stride_d,
+
+ int batch_stride_a,
+ int batch_stride_b,
+ int batch_stride_d,
+
+ int r2,
+ int r3
+) {
+ src1 = (global float4*)((global char*)src1 + offset1);
+ dst = (global float *)((global char*)dst + offsetd);
+
+ local float buf_a[BM * BK];
+ local float buf_b[BN * BK];
+
+ const int batch_idx = get_global_id(2);
+
+ const int i13 = batch_idx / ne12;
+ const int i12 = batch_idx % ne12;
+
+ const int i03 = i13 / r3;
+ const int i02 = i12 / r2;
+
+ const int batch_idx_a = i03 * ne02 + i02;
+
+ const int ir = get_group_id(0);
+ const int ic = get_group_id(1);
+
+ const int tid = get_local_id(0);
+ const int th_r = tid % (BM / TM);
+ const int th_c = tid / (BM / TM);
+
+ const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
+ const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
+ const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
+ const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
+
+ const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
+ const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
+
+ int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
+ int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
+
+ float sums[TM * TN];
+ float cache_a[TM];
+ float cache_b[TN];
+
+ for (int i = 0; i < TM * TN; i++) {
+ sums[i] = 0.0f;
+ }
+
+ for (int block = 0; block < ne00; block += BK) {
+ for (int l = 0; l < BM; l += loadstride_a) {
+ if (ir*BM + loadc_a + l < ne01) {
+ int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
+
+ int ib = idx / 128; // 2 values per idx
+ int iqs = idx % 128; // 0..127
+
+ int n = iqs / 64; // 0,1
+ int b = (iqs % 64) / 32; // 0,1
+ int is_b = (iqs % 16) / 8; // 0,1
+ int qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
+ int is = 8 * n + qhshift + is_b; // 0..15
+ int qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126
+ int qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
+
+ float dscale = (float)src0_d[ib] * (float)src0_s[ib*16 + is];
+
+ buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = dscale * convert_float(convert_char(((src0_ql[128*ib + qsi + 0] >> (b * 4)) & 0xF) | (((src0_qh[64*ib + qhi + 0] >> qhshift) & 3) << 4)) - 32);
+ buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = dscale * convert_float(convert_char(((src0_ql[128*ib + qsi + 1] >> (b * 4)) & 0xF) | (((src0_qh[64*ib + qhi + 1] >> qhshift) & 3) << 4)) - 32);
+ } else {
+ buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f;
+ buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f;
+ }
+ }
+
+ for (int l = 0; l < BN; l += loadstride_b) {
+ if (ic*BN + loadc_b + l < ne11) {
+ int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
+ buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
+ buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
+ buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
+ buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
+ } else {
+ buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
+ buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
+ buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
+ buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
+ }
+ }
+
+ barrier(CLK_LOCAL_MEM_FENCE);
+
+ pos_a += BK / LOAD_VEC_A;
+ pos_b += BK / LOAD_VEC_B;
+
+ for (int i = 0; i < BK; i++) {
+ for (int j = 0; j < TM; j++) {
+ cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
+ }
+
+ for (int j = 0; j < TN; j++) {
+ cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
+ }
+
+ for (int cc = 0; cc < TN; cc++) {
+ for (int cr = 0; cr < TM; cr++) {
+ const int sums_idx = cc*TM + cr;
+ sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
+ }
+ }
+ }
+ barrier(CLK_LOCAL_MEM_FENCE);
+ }
+
+ const int dr = ir * BM + th_r * TM;
+ const int dc = ic * BN + th_c * TN;
+
+ const int offsets = batch_idx * batch_stride_d;
+
+ for (int cc = 0; cc < TN; cc++) {
+ for (int cr = 0; cr < TM; cr++) {
+ if (dr + cr < ne01 && dc + cc < ne11) {
+ dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
+ }
+ }
+ }
+}
--- /dev/null
+#ifdef cl_intel_required_subgroup_size
+#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
+#define INTEL_GPU 1
+#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
+#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
+#elif defined(cl_qcom_reqd_sub_group_size)
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+#define ADRENO_GPU 1
+#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
+#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
+#endif
+
+//------------------------------------------------------------------------------
+// block_q4_K
+//------------------------------------------------------------------------------
+#define QK_K 256
+#define K_SCALE_SIZE 12
+
+// 8 blocks of 32 elements each
+// weight is represented as x = a * q + b
+typedef struct {
+ half d; // super-block scale for quantized scales
+ half dmin; // super-block scale for quantized mins
+
+ uchar scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
+ uchar qs[QK_K/2]; // 4-bit quants
+} block_q4_K;
+
+#undef N_DST
+#undef N_SIMDGROUP
+#undef N_SIMDWIDTH
+
+#ifdef INTEL_GPU
+#define N_DST 4 // number of rows each SIMD group works on
+#define N_SIMDGROUP 1 // number of SIMD groups in a thread group
+#define N_SIMDWIDTH 16 // SIMD group size
+#elif defined (ADRENO_GPU)
+#define N_DST 4
+#define N_SIMDGROUP 1
+#define N_SIMDWIDTH 64
+#endif
+
+#undef BLOCK_STRIDE
+// number of (super) blocks each subgroup processes
+// each thread in a subgroup processes a block (32 weights)
+#define BLOCK_STRIDE (N_SIMDWIDTH/8)
+
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_16
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_mul_mv_q4_K_f32(
+ global char * src0,
+ int offset0,
+ global char * src1,
+ int offset1,
+ global char * dst,
+ int offsetd,
+ int ne00,
+ int ne01,
+ ulong nb01,
+ ulong nb02,
+ ulong nb03,
+ int ne12,
+ ulong nb11,
+ ulong nb12,
+ ulong nb13,
+ int ne0,
+ int ne1,
+ int r2,
+ int r3
+) {
+ src0 = src0 + offset0;
+ src1 = src1 + offset1;
+ dst = dst + offsetd;
+
+ ushort kmask1 = 0x3f3f;
+ ushort kmask2 = 0x0f0f;
+ ushort kmask3 = 0xc0c0;
+
+ int ix = get_sub_group_local_id()/8; // super block index
+ int it = get_sub_group_local_id()%8; // block index (inside super block)
+ int iq = it/4; // 0 or 1 - first or second half of the super block
+ int ir = it%4; // 0...3 - block index in the half super block
+
+ int nb = ne00/QK_K;
+
+ int r0 = get_group_id(0);
+ int r1 = get_group_id(1);
+ int im = get_group_id(2);
+ int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
+
+ int i12 = im%ne12;
+ int i13 = im/ne12;
+
+ int offset_src0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+ int offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
+
+ global block_q4_K * x = (global block_q4_K *) (src0 + offset_src0);
+ global float * y = (global float *) (src1 + offset_src1);
+
+ float yl[16];
+ float yh[16];
+ float sumf[N_DST] = {0.f};
+ float all_sum;
+
+ global float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
+
+ ushort sc16[4];
+ uchar * sc8 = (uchar *)sc16;
+
+ for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) {
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
+ for (int i = 0; i < 8; ++i) {
+ yl[i+0] = y4[i+0];
+ sumy.s0 += yl[i+0];
+
+ yl[i+8] = y4[i+32];
+ sumy.s1 += yl[i+8];
+
+ yh[i+0] = y4[i+128];
+ sumy.s2 += yh[i+0];
+
+ yh[i+8] = y4[i+160];
+ sumy.s3 += yh[i+8];
+ }
+
+ global ushort * sc = (global ushort *)x[ib].scales + iq;
+ global ushort * q1 = (global ushort *)x[ib].qs + 16 * iq + 4 * ir;
+ global half * dh = &x[ib].d;
+
+ for (int row = 0; row < N_DST; row++) {
+ sc16[0] = sc[0] & kmask1;
+ sc16[1] = sc[2] & kmask1;
+ sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
+ sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
+
+ global ushort * q2 = q1 + 32;
+
+ float4 acc1 = {0.f, 0.f, 0.f, 0.f};
+ float4 acc2 = {0.f, 0.f, 0.f, 0.f};
+ for (int i = 0; i < 8; i += 2) {
+ acc1.s0 += yl[i+0] * (q1[i/2] & 0x000F);
+ acc1.s1 += yl[i+1] * (q1[i/2] & 0x0F00);
+ acc1.s2 += yl[i+8] * (q1[i/2] & 0x00F0);
+ acc1.s3 += yl[i+9] * (q1[i/2] & 0xF000);
+ acc2.s0 += yh[i+0] * (q2[i/2] & 0x000F);
+ acc2.s1 += yh[i+1] * (q2[i/2] & 0x0F00);
+ acc2.s2 += yh[i+8] * (q2[i/2] & 0x00F0);
+ acc2.s3 += yh[i+9] * (q2[i/2] & 0xF000);
+ }
+
+ float dall = dh[0];
+ float dmin = dh[1];
+ sumf[row] += dall * ((acc1.s0 + 1.f/256.f * acc1.s1) * sc8[0] +
+ (acc1.s2 + 1.f/256.f * acc1.s3) * sc8[1] * 1.f/16.f +
+ (acc2.s0 + 1.f/256.f * acc2.s1) * sc8[4] +
+ (acc2.s2 + 1.f/256.f * acc2.s3) * sc8[5] * 1.f/16.f) -
+ dmin * (sumy.s0 * sc8[2] + sumy.s1 * sc8[3] + sumy.s2 * sc8[6] + sumy.s3 * sc8[7]);
+
+ q1 += nb01/2;
+ sc += nb01/2;
+ dh += nb01/2;
+ }
+
+ y4 += BLOCK_STRIDE * QK_K;
+ }
+
+ global float * dst_f32 = (global float *) dst + im*ne0*ne1 + r1*ne0;
+
+ for (int row = 0; row < N_DST; ++row) {
+ all_sum = sub_group_reduce_add(sumf[row]);
+ if (first_row + row < ne01) {
+ if (get_sub_group_local_id() == 0) {
+ dst_f32[first_row + row] = all_sum;
+ }
+ }
+ }
+}