cl_program program_mul_mv_f32_f32;
cl_program program_mul;
cl_program program_mul_mat_f16_f32_tiled;
+ cl_program program_mul_mm_f16_f32_kqv;
+ cl_program program_mul_mm_f16_f32_kq;
cl_program program_div;
cl_program program_sub;
cl_program program_norm;
cl_kernel kernel_mul_mat_f16_f32;
cl_kernel kernel_mul_mat_f16_f32_l4;
cl_kernel kernel_mul_mat_f16_f32_tiled;
+ cl_kernel kernel_mul_mm_f16_f32_kqv;
+ cl_kernel kernel_mul_mm_f16_f32_kq;
cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v;
cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0;
cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans;
GGML_LOG_CONT(".");
}
+ // mul_mm_f16_f32_kq_kqv
+ {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+ const std::string kernel_src {
+ #include "mul_mm_f16_f32_kq_kqv.cl.h"
+ };
+#else
+ const std::string kernel_src = read_file("mul_mm_f16_f32_kq_kqv.cl");
+#endif
+ backend_ctx->program_mul_mm_f16_f32_kqv =
+        build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts + " -DKQV ");
+ backend_ctx->program_mul_mm_f16_f32_kq =
+ build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+ CL_CHECK((backend_ctx->kernel_mul_mm_f16_f32_kqv = clCreateKernel(backend_ctx->program_mul_mm_f16_f32_kqv, "mul_mm_f16_f32_kqv", &err), err));
+ CL_CHECK((backend_ctx->kernel_mul_mm_f16_f32_kq = clCreateKernel(backend_ctx->program_mul_mm_f16_f32_kq, "mul_mm_f16_f32_kq", &err), err));
+ GGML_LOG_CONT(".");
+ }
+
// mul
{
#ifdef GGML_OPENCL_EMBED_KERNELS
backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
}
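+// Adreno path for f16 x f32 mat mul as used by the attention KQ and KQV ops.
+// src0 is bound as a 1D image buffer of RGBA float texels (16 bytes per
+// texel) for wide loads, and the output is written through a CL_R float
+// image; src1 stays a plain buffer.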
+static void ggml_cl_mul_mat_kq_kqv_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+ ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+ ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+ ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+ const int ne00 = src0->ne[0];
+ const int ne01 = src0->ne[1];
+ const int ne02 = src0->ne[2];
+
+ const cl_ulong nb01 = src0->nb[1];
+ const cl_ulong nb02 = src0->nb[2];
+
+ const int ne10 = src1->ne[0];
+ const int ne11 = src1->ne[1];
+ const int ne12 = src1->ne[2];
+
+ const cl_ulong nb10 = src1->nb[0];
+
+ const int ne0 = dst->ne[0];
+ const int ne1 = dst->ne[1];
+
+ GGML_ASSERT(ne00 == ne10);
+
+ cl_kernel kernel;
+ cl_context context = backend_ctx->context;
+
+ cl_int status;
+ cl_image_format img_fmt_1d;
+ cl_image_desc img_desc_1d;
+ cl_buffer_region region;
+ cl_mem A_image1d;
+ cl_mem A_sub_buffer;
+ cl_mem B_sub_buffer;
+ cl_mem D_image1d;
+ cl_mem D_sub_buffer;
+
+ int M = ne01;
+ int N = ne1;
+ int K = ne00;
+
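+    // A row stride larger than the plane stride means src0 is a permuted
+    // view (the K tensor of the KQ mat mul); a contiguous src0 is the KQV
+    // case. The same test picks the matching sub-buffer size below.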
+ if (nb01 > nb02) {
+ // KQ
+ kernel = backend_ctx->kernel_mul_mm_f16_f32_kq;
+ } else {
+ // KQV
+ kernel = backend_ctx->kernel_mul_mm_f16_f32_kqv;
+ }
+ // create sub-buffer for A
+ // <--------------------------------------------> //
+ extra0 = src0->view_src ? (ggml_tensor_extra_cl *)src0->view_src->extra : (ggml_tensor_extra_cl *)src0->extra;
+
+ region.origin = (extra0->offset);
+ if (nb01 > nb02) {
+ // KQ
+ region.size = nb01 * ne01;
+ } else {
+ // KQV
+ region.size = nb02 * ne02;
+ }
+
+    A_sub_buffer = clCreateSubBuffer((extra0->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
+ CL_CHECK(status);
+
+ // <--------------------------------------------> //
+
+ // create sub-buffer for B
+ // <--------------------------------------------> //
+ region.origin = (extra1->offset);
+ region.size = nb10 * ne10 * ne11 * ne12;
+    B_sub_buffer = clCreateSubBuffer((extra1->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
+ CL_CHECK(status);
+ // <--------------------------------------------> //
+
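+    // Bind A as a 1D image of RGBA float texels: each texel spans 16 bytes
+    // (8 halves), so the image width is the sub-buffer size divided by 16.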
+ img_fmt_1d = {CL_RGBA, CL_FLOAT};
+ memset(&img_desc_1d, 0, sizeof(img_desc_1d));
+ img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
+ if (nb01 > nb02) {
+ img_desc_1d.image_width = (nb01 * ne01 / 4)/4;
+    } else {
+ img_desc_1d.image_width = (nb02 * ne02 / 4)/4;
+ }
+ img_desc_1d.buffer = A_sub_buffer;
+ A_image1d = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status);
+ CL_CHECK(status);
+
+ // create sub-buffer for output C
+ // <--------------------------------------------> //
+ region.origin = (extrad->offset);
+ region.size = ne0 * ne1 * dst->ne[2] * dst->nb[0]; // size of C in bytes
+    D_sub_buffer = clCreateSubBuffer((extrad->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
+ CL_CHECK(status);
+ // <--------------------------------------------> //
+
+ // create image for C output
+ // <--------------------------------------------> //
+ img_fmt_1d = {CL_R, CL_FLOAT};
+ memset(&img_desc_1d, 0, sizeof(img_desc_1d));
+ img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
+ img_desc_1d.image_width = ne0 * ne1 * dst->ne[2] * dst->nb[0] / 4;
+ img_desc_1d.buffer = D_sub_buffer;
+ D_image1d = clCreateImage(context, CL_MEM_WRITE_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status);
+ CL_CHECK(status);
+ // <--------------------------------------------> //
+
+    int offset_src0 = 0;
+    int offset_src1 = 0;
+    int offset_dst  = (int) extrad->offset; // kernel takes the dst offset as int
+    int nb01_i      = (int) nb01;           // kernel takes nb01 as int
+
+ // set kernel args
+ // <--------------------------------------------> //
+ cl_uint k_arg = 0;
+ CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &A_image1d));
+ CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &offset_src0));
+ CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &B_sub_buffer));
+ CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &offset_src1));
+ CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &D_image1d));
+    CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &offset_dst));
+ CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &M));
+ CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &K));
+ CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &N));
+ CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne02));
+ CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne12));
+    CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &nb01_i));
+
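+    // One 64-lane subgroup per 64x32 output tile: dim 1 walks the M tiles,
+    // dim 2 the N tiles for each src1 batch. The local size {64, 1, 2} packs
+    // two tiles per work-group, assuming a subgroup size of 64 on Adreno.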
+ size_t global_work_size[3] = {64, static_cast<size_t>(((M+63)/64)), static_cast<size_t>(((N+31)/32)*ne12)};
+ size_t local_work_size[3] = {64, 1, 2};
+
+ backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+
+ // deallocate sub buffers and images
+ // <--------------------------------------------> //
+ CL_CHECK(clReleaseMemObject(A_image1d));
+ CL_CHECK(clReleaseMemObject(D_image1d));
+ CL_CHECK(clReleaseMemObject(A_sub_buffer));
+ CL_CHECK(clReleaseMemObject(B_sub_buffer));
+ CL_CHECK(clReleaseMemObject(D_sub_buffer));
+}
+
static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
cl_context context = backend_ctx->context;
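+    // f16 x f32 mat muls that are large enough for 64x32 tiles (ne01 >= 64,
+    // ne1 >= 32, ne00 >= 16) and batch-broadcastable take the KQ/KQV path.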
+    if (src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32) {
+        if (ne01 >= 64 && ne1 >= 32 && ne00 >= 16 && (ne12 % ne02) == 0) {
+ ggml_cl_mul_mat_kq_kqv_adreno(backend, src0, src1, dst);
+ return;
+ }
+ }
+
if (ne01 && ne1 && use_adreno_kernels(backend_ctx, src0)) {
// init CL objects
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#pragma OPENCL EXTENSION cl_khr_subgroups : enable
+
+#define LM_FIRST_256B 0
+#define LM_SECOND_256B 64
+#define LM_THIRD_256B 128
+#define LM_FOURTH_256B 192
+
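+// Load 16 consecutive f16 elements of A for this lane through two 16-byte
+// RGBA texel reads and widen them to float. For KQV the per-row step is the
+// tensor's nb01; for KQ it is the line stride computed by the caller.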
+inline float16 mm_load_a(
+ image1d_buffer_t matrix_A,
+ uint subMatrixAStartInElements,
+ int nb01,
+ int line_stride_matrix_A_in_bytes
+) {
+ __private float8 regA;
+ size_t sub_block_id_m = get_local_id(0);
+
+#ifdef KQV
+ uint a_texCoord = subMatrixAStartInElements/2 + (sub_block_id_m * nb01/4);
+#else // KQ
+ uint a_texCoord = subMatrixAStartInElements/2 + (sub_block_id_m * line_stride_matrix_A_in_bytes/4);
+#endif
+
+ regA.s0123 = read_imagef(matrix_A, a_texCoord/4);
+ regA.s4567 = read_imagef(matrix_A, (a_texCoord+4)/4);
+
+ return convert_float16(as_half16(regA));
+}
+
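+// Accumulate four output columns for this lane: each float4 mad adds one
+// K-slice element of A times four staged B values. The interleaved component
+// order (s0, s1, s4, s5, ...) matches the swizzled layout written by mm_mad.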
+inline float4 alu_32(
+ float16 regA,
+ __local float4* matrix_B_vec
+) {
+
+ __private float4 rC = 0;
+ int i = get_sub_group_id() * 64;
+
+ rC += regA.s0 * matrix_B_vec[i];
+ rC += regA.s1 * matrix_B_vec[i + 16];
+ rC += regA.s4 * matrix_B_vec[i + 1];
+ rC += regA.s5 * matrix_B_vec[i + 17];
+ rC += regA.s8 * matrix_B_vec[i + 2];
+ rC += regA.s9 * matrix_B_vec[i + 18];
+ rC += regA.sc * matrix_B_vec[i + 3];
+ rC += regA.sd * matrix_B_vec[i + 19];
+
+ i += 32;
+
+ rC += regA.s2 * matrix_B_vec[i];
+ rC += regA.s3 * matrix_B_vec[i + 16];
+ rC += regA.s6 * matrix_B_vec[i + 1];
+ rC += regA.s7 * matrix_B_vec[i + 17];
+ rC += regA.sa * matrix_B_vec[i + 2];
+ rC += regA.sb * matrix_B_vec[i + 18];
+ rC += regA.se * matrix_B_vec[i + 3];
+ rC += regA.sf * matrix_B_vec[i + 19];
+
+ return rC;
+}
+
+inline float16 alu_16(
+ float16 regA,
+ __local float* matrix_B_local
+) {
+ float16 out;
+ __local float4* matrix_B_vec = (__local float4*)matrix_B_local;
+
+ out.s0123 = alu_32(regA, matrix_B_vec);
+ out.s4567 = alu_32(regA, matrix_B_vec + 4);
+ out.s89ab = alu_32(regA, matrix_B_vec + 8);
+ out.scdef = alu_32(regA, matrix_B_vec + 12);
+
+ return out;
+}
+
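+// Stage one half of regB into this subgroup's 256-word region of local
+// memory (one float per 256B block per lane), accumulate 16 outputs, then
+// repeat with the other half. There is no barrier between the stores and the
+// reads in alu_16: the code relies on lockstep execution within a subgroup.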
+inline void mm_mad(
+ __local float* matrix_B_local,
+ float16 regA,
+ float8 regB,
+ uint b_localOffsetInWords,
+ float16* regC0_ptr,
+ float16* regC1_ptr
+) {
+ int offset = b_localOffsetInWords + get_sub_group_id() * 256;
+
+ matrix_B_local[offset + LM_FIRST_256B] = regB.s0;
+ matrix_B_local[offset + LM_SECOND_256B] = regB.s1;
+ matrix_B_local[offset + LM_THIRD_256B] = regB.s2;
+ matrix_B_local[offset + LM_FOURTH_256B] = regB.s3;
+
+ float16 add0 = alu_16(regA, matrix_B_local);
+ *regC0_ptr += add0;
+
+ matrix_B_local[offset + LM_FIRST_256B] = regB.s4;
+ matrix_B_local[offset + LM_SECOND_256B] = regB.s5;
+ matrix_B_local[offset + LM_THIRD_256B] = regB.s6;
+ matrix_B_local[offset + LM_FOURTH_256B] = regB.s7;
+
+ float16 add1 = alu_16(regA, matrix_B_local);
+ *regC1_ptr += add1;
+}
+
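+// Fully unrolled store of one lane's 32 results (one M position across 32 N
+// columns of C); mask limits the writes to the valid columns of a partial
+// final tile.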
+inline void mm_store_c_N(
+ __write_only image1d_buffer_t matrix_C,
+ float16 regC0,
+ float16 regC1,
+ uint subMatrixCStartInElements,
+ int line_stride_matrix_C_in_bytes,
+ int mask
+) {
+ size_t sub_block_id_m = get_local_id(0);
+
+ uint strideInWords = line_stride_matrix_C_in_bytes/4;
+ uint c_coordInWords_0 = (subMatrixCStartInElements + sub_block_id_m);
+
+ uint c_coordInWords_1 = c_coordInWords_0 + 1 * strideInWords;
+ uint c_coordInWords_2 = c_coordInWords_0 + 2 * strideInWords;
+ uint c_coordInWords_3 = c_coordInWords_0 + 3 * strideInWords;
+ uint c_coordInWords_4 = c_coordInWords_0 + 4 * strideInWords;
+ uint c_coordInWords_5 = c_coordInWords_0 + 5 * strideInWords;
+ uint c_coordInWords_6 = c_coordInWords_0 + 6 * strideInWords;
+ uint c_coordInWords_7 = c_coordInWords_0 + 7 * strideInWords;
+ uint c_coordInWords_8 = c_coordInWords_0 + 8 * strideInWords;
+ uint c_coordInWords_9 = c_coordInWords_0 + 9 * strideInWords;
+ uint c_coordInWords_10 = c_coordInWords_0 + 10 * strideInWords;
+ uint c_coordInWords_11 = c_coordInWords_0 + 11 * strideInWords;
+ uint c_coordInWords_12 = c_coordInWords_0 + 12 * strideInWords;
+ uint c_coordInWords_13 = c_coordInWords_0 + 13 * strideInWords;
+ uint c_coordInWords_14 = c_coordInWords_0 + 14 * strideInWords;
+ uint c_coordInWords_15 = c_coordInWords_0 + 15 * strideInWords;
+ uint c_coordInWords_16 = c_coordInWords_0 + 16 * strideInWords;
+ uint c_coordInWords_17 = c_coordInWords_0 + 17 * strideInWords;
+ uint c_coordInWords_18 = c_coordInWords_0 + 18 * strideInWords;
+ uint c_coordInWords_19 = c_coordInWords_0 + 19 * strideInWords;
+ uint c_coordInWords_20 = c_coordInWords_0 + 20 * strideInWords;
+ uint c_coordInWords_21 = c_coordInWords_0 + 21 * strideInWords;
+ uint c_coordInWords_22 = c_coordInWords_0 + 22 * strideInWords;
+ uint c_coordInWords_23 = c_coordInWords_0 + 23 * strideInWords;
+ uint c_coordInWords_24 = c_coordInWords_0 + 24 * strideInWords;
+ uint c_coordInWords_25 = c_coordInWords_0 + 25 * strideInWords;
+ uint c_coordInWords_26 = c_coordInWords_0 + 26 * strideInWords;
+ uint c_coordInWords_27 = c_coordInWords_0 + 27 * strideInWords;
+ uint c_coordInWords_28 = c_coordInWords_0 + 28 * strideInWords;
+ uint c_coordInWords_29 = c_coordInWords_0 + 29 * strideInWords;
+ uint c_coordInWords_30 = c_coordInWords_0 + 30 * strideInWords;
+ uint c_coordInWords_31 = c_coordInWords_0 + 31 * strideInWords;
+
+ if (mask > 0) { write_imagef(matrix_C, c_coordInWords_0, regC0.s0); }
+ if (mask > 1) { write_imagef(matrix_C, c_coordInWords_1, regC0.s1); }
+ if (mask > 2) { write_imagef(matrix_C, c_coordInWords_2, regC0.s2); }
+ if (mask > 3) { write_imagef(matrix_C, c_coordInWords_3, regC0.s3); }
+ if (mask > 4) { write_imagef(matrix_C, c_coordInWords_4, regC0.s4); }
+ if (mask > 5) { write_imagef(matrix_C, c_coordInWords_5, regC0.s5); }
+ if (mask > 6) { write_imagef(matrix_C, c_coordInWords_6, regC0.s6); }
+ if (mask > 7) { write_imagef(matrix_C, c_coordInWords_7, regC0.s7); }
+ if (mask > 8) { write_imagef(matrix_C, c_coordInWords_8, regC0.s8); }
+ if (mask > 9) { write_imagef(matrix_C, c_coordInWords_9, regC0.s9); }
+ if (mask > 10) { write_imagef(matrix_C, c_coordInWords_10, regC0.sa); }
+ if (mask > 11) { write_imagef(matrix_C, c_coordInWords_11, regC0.sb); }
+ if (mask > 12) { write_imagef(matrix_C, c_coordInWords_12, regC0.sc); }
+ if (mask > 13) { write_imagef(matrix_C, c_coordInWords_13, regC0.sd); }
+ if (mask > 14) { write_imagef(matrix_C, c_coordInWords_14, regC0.se); }
+ if (mask > 15) { write_imagef(matrix_C, c_coordInWords_15, regC0.sf); }
+ if (mask > 16) { write_imagef(matrix_C, c_coordInWords_16, regC1.s0); }
+ if (mask > 17) { write_imagef(matrix_C, c_coordInWords_17, regC1.s1); }
+ if (mask > 18) { write_imagef(matrix_C, c_coordInWords_18, regC1.s2); }
+ if (mask > 19) { write_imagef(matrix_C, c_coordInWords_19, regC1.s3); }
+ if (mask > 20) { write_imagef(matrix_C, c_coordInWords_20, regC1.s4); }
+ if (mask > 21) { write_imagef(matrix_C, c_coordInWords_21, regC1.s5); }
+ if (mask > 22) { write_imagef(matrix_C, c_coordInWords_22, regC1.s6); }
+ if (mask > 23) { write_imagef(matrix_C, c_coordInWords_23, regC1.s7); }
+ if (mask > 24) { write_imagef(matrix_C, c_coordInWords_24, regC1.s8); }
+ if (mask > 25) { write_imagef(matrix_C, c_coordInWords_25, regC1.s9); }
+ if (mask > 26) { write_imagef(matrix_C, c_coordInWords_26, regC1.sa); }
+ if (mask > 27) { write_imagef(matrix_C, c_coordInWords_27, regC1.sb); }
+ if (mask > 28) { write_imagef(matrix_C, c_coordInWords_28, regC1.sc); }
+ if (mask > 29) { write_imagef(matrix_C, c_coordInWords_29, regC1.sd); }
+ if (mask > 30) { write_imagef(matrix_C, c_coordInWords_30, regC1.se); }
+ if (mask > 31) { write_imagef(matrix_C, c_coordInWords_31, regC1.sf); }
+}
+
+#define TILESIZE_K 16
+#define TILESIZE_M 64
+#define TILESIZE_N 32
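+// Each subgroup produces one TILESIZE_M x TILESIZE_N (64x32) tile of C,
+// stepping through K in TILESIZE_K (16) element slices.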
+#ifdef KQV
+__kernel void mul_mm_f16_f32_kqv(
+#else
+__kernel void mul_mm_f16_f32_kq(
+#endif
+ __read_only image1d_buffer_t matrix_A,
+ int offset0,
+ __global float* matrix_B,
+ int offset1,
+ __write_only image1d_buffer_t matrix_C,
+ int offsetd,
+ int M, int K, int N,
+ int D_A,
+ int D_B,
+ int nb01
+) {
+
+ uint block_id_m = get_global_id(1);
+ uint block_id_n = get_global_id(2) % ((N+TILESIZE_N-1)/TILESIZE_N);
+ uint block_id_d = get_global_id(2) / ((N+TILESIZE_N-1)/TILESIZE_N);
+
+ __private float16 regA;
+ __private float8 regB;
+    __private float16 regC0 = (float16)(0.0f);
+    __private float16 regC1 = (float16)(0.0f);
+
+ const uint col = block_id_m * TILESIZE_M;
+ const uint row = block_id_n * TILESIZE_N;
+ const uint depth_A = block_id_d / (D_B/D_A);
+ const uint depth_B = block_id_d;
+
+#ifdef KQV
+ int line_stride_matrix_A_in_bytes = nb01 * M;
+ int line_stride_matrix_B_in_bytes = K * N * 4;
+#else
+ int line_stride_matrix_A_in_bytes = K * D_A * 2;
+ int line_stride_matrix_B_in_bytes = K * D_B * 4;
+#endif
+
+ int line_stride_matrix_C_in_bytes = M * 4;
+
+ const uint strideAinElements = line_stride_matrix_A_in_bytes / 2;
+ const uint strideBinElements = line_stride_matrix_B_in_bytes / 4;
+
+ size_t sub_block_id_m = get_local_id(0);
+
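+    // Per-lane local-memory write offset: keep the 16-lane group (bits 4-5)
+    // and swap the two bit pairs of the low nibble (b3b2b1b0 -> b1b0b3b2) so
+    // the stores line up with the read pattern in alu_32.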
+ uint b_localOffsetInWords = (sub_block_id_m/16)*16
+ + ((((sub_block_id_m)>>0)&1)<<2)
+ + ((((sub_block_id_m)>>1)&1)<<3)
+ + ((((sub_block_id_m)>>2)&1)<<0)
+ + ((((sub_block_id_m)>>3)&1)<<1);
+
+    uint2 b_globalOffsetInWords_xy = (uint2)((sub_block_id_m % 4) * 4, sub_block_id_m >> 2);
+ uint b_globalOffsetInWords00, b_globalOffsetInWords16;
+#ifdef KQV
+ b_globalOffsetInWords00 = b_globalOffsetInWords_xy.x + b_globalOffsetInWords_xy.y*K;
+ b_globalOffsetInWords16 = b_globalOffsetInWords00 + (16 * K);
+ uint subMatrixAStartInElements = depth_A * strideAinElements + col * nb01 / 2;
+ uint subMatrixBStartInElements = depth_B * strideBinElements + row * K;
+#else
+ b_globalOffsetInWords00 = b_globalOffsetInWords_xy.x + b_globalOffsetInWords_xy.y*line_stride_matrix_B_in_bytes/4;
+ b_globalOffsetInWords16 = b_globalOffsetInWords00 + (16 * line_stride_matrix_B_in_bytes/4);
+ uint subMatrixAStartInElements = col * strideAinElements + depth_A * K;
+ uint subMatrixBStartInElements = row * strideBinElements + depth_B * K;
+#endif
+
+ __local float matrix_B_local[1024];
+
+    for (uint step = 0; step < K; step += TILESIZE_K) {
+        regA = mm_load_a(matrix_A, subMatrixAStartInElements, nb01, line_stride_matrix_A_in_bytes);
+
+ uint b_coordInWords00 = subMatrixBStartInElements + b_globalOffsetInWords00;
+ uint b_coordInWords16 = subMatrixBStartInElements + b_globalOffsetInWords16;
+
+ regB.s0123 = vload4(b_coordInWords00/4, matrix_B);
+ regB.s4567 = vload4(b_coordInWords16/4, matrix_B);
+
+        mm_mad(matrix_B_local, regA, regB, b_localOffsetInWords, &regC0, &regC1);
+
+ subMatrixAStartInElements += TILESIZE_K;
+ subMatrixBStartInElements += TILESIZE_K;
+ }
+
+ uint subMatrixCStartInElements = depth_B * N * M + row * M + col;
+    mm_store_c_N(matrix_C, regC0, regC1, subMatrixCStartInElements, line_stride_matrix_C_in_bytes, (N - block_id_n * TILESIZE_N));
+}
+