GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16,
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16,
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, mul_mm_id_map1_f32, has_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1, mul_mm_id_map0_f16_ne20_1, has_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2, mul_mm_id_map0_f16_ne20_2, has_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4, mul_mm_id_map0_f16_ne20_4, has_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6, mul_mm_id_map0_f16_ne20_6, has_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8, mul_mm_id_map0_f16_ne20_8, has_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16, mul_mm_id_map0_f16_ne20_16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, mul_mm_id_bf16_f16, has_simdgroup_mm && use_bfloat);
default: break;
}
- const int64_t neh10 = ne10; // n_embd
- const int64_t neh11 = ne21; // n_tokens
- const int64_t neh12 = ne02; // n_expert
-
- const uint64_t nbh10 = ggml_type_size(GGML_TYPE_F16);
- const uint64_t nbh11 = nbh10*neh10;
- const uint64_t nbh12 = nbh11*neh11;
- const uint64_t nbh13 = nbh12*neh12;
-
- const size_t s_src1 = ggml_type_size(GGML_TYPE_F16)*neh10*neh11*neh12;
- id<MTLBuffer> h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1);
- if (!h_src1) {
- GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1);
- return 0;
- }
-
- const int64_t neh0 = ne0;
- const int64_t neh1 = ne21;
- const int64_t neh2 = ne02;
-
- const uint64_t nbh0 = ggml_type_size(GGML_TYPE_F32);
- const uint64_t nbh1 = nbh0*neh0;
- const uint64_t nbh2 = nbh1*neh1;
- //const uint64_t nbh3 = nbh2*neh2;
-
- const size_t s_dst = ggml_type_size(GGML_TYPE_F32)*neh0*neh1*neh2;
- id<MTLBuffer> h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst);
- if (!h_dst) {
- GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst);
- return 0;
- }
-
// tokens per expert
const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02;
id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
}
// id map
- // [n_expert_used, n_tokens]
- const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne20*ne21;
+ // [n_tokens, n_expert]
+ const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne21*ne02;
id<MTLBuffer> h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids);
if (!h_ids) {
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
}
{
- const int nth = MIN(1024, ne10/4);
-
ggml_metal_kargs_mul_mm_id_map0 args = {
+ ne02,
ne10,
- ne11, // n_expert_used (bcast)
+ ne11, // n_expert_used (bcast)
nb11,
nb12,
- neh11, // n_tokens
- nbh11,
- ne20, // n_expert_used
+ ne21, // n_tokens
+ ne20, // n_expert_used
nb21,
};
id<MTLComputePipelineState> pipeline = nil;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16].pipeline;
+ pipeline = nil;
+
+ switch (ne20) {
+ case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1 ].pipeline; break;
+ case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2 ].pipeline; break;
+ case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4 ].pipeline; break;
+ case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6 ].pipeline; break;
+ case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8 ].pipeline; break;
+ case 16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16].pipeline; break;
+ default: GGML_ABORT("missing specialization for ne20 = %d", (int) ne20);
+ }
+
+ GGML_ASSERT(ne02 <= (int) pipeline.maxTotalThreadsPerThreadgroup);
+
+ const size_t smem = ne02*ne20*sizeof(uint16_t);
+
+ GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
[encoder setComputePipelineState:pipeline];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
- [encoder setBuffer: h_src1 offset:0 atIndex:3];
- [encoder setBuffer: h_tpe offset:0 atIndex:4];
- [encoder setBuffer: h_ids offset:0 atIndex:5];
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:1];
+ [encoder setBuffer: h_tpe offset:0 atIndex:2];
+ [encoder setBuffer: h_ids offset:0 atIndex:3];
+ [encoder setThreadgroupMemoryLength:smem atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake(ne02, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(ne02, 1, 1)];
}
{
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
- /*.neh12 =*/ neh12,
- /*.nbh10 =*/ nbh10,
- /*.nbh11 =*/ nbh11,
- /*.nbh12 =*/ nbh12,
- /*.nbh13 =*/ nbh13,
- /*.neh0 =*/ neh0,
- /*.neh1 =*/ neh1,
+ /*.ne11 =*/ ne11, // n_expert_used (bcast)
+ /*.nb10 =*/ nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.ne20 =*/ ne20, // n_expert_used
+ /*.ne21 =*/ ne21, // n_tokens
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
/*.r2 =*/ r2,
/*.r3 =*/ r3,
};
[encoder setComputePipelineState:pipeline];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
- [encoder setBuffer: h_src1 offset:0 atIndex:2];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
[encoder setBuffer: h_tpe offset:0 atIndex:3];
- [encoder setBuffer: h_dst offset:0 atIndex:4];
+ [encoder setBuffer: h_ids offset:0 atIndex:4];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:5];
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, ne02) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
}
-
- {
- GGML_ASSERT(ne0 % 4 == 0);
-
- const int nth = MIN(1024, ne0/4);
-
- ggml_metal_kargs_mul_mm_id_map1 args = {
- ne20, // n_expert_used
- neh0,
- neh1,
- nbh1,
- nbh2,
- ne0,
- nb1,
- nb2,
- };
-
- id<MTLComputePipelineState> pipeline = nil;
-
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBytes:&args length:sizeof(args) atIndex:0];
- [encoder setBuffer: h_dst offset:0 atIndex:1];
- [encoder setBuffer: h_ids offset:0 atIndex:2];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne20, ne21, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- }
} else {
id<MTLComputePipelineState> pipeline = nil;
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
- for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
- const int i10 = i0%args.ne10;
- *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10));
+ if (args.ne10 == 1) {
+ const float x = *((device float *)(src1_ptr));
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+ *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
+ }
+ } else {
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+ const int i10 = i0%args.ne10;
+ *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10));
+ }
}
}
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
- for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
- const int i10 = i0%args.ne10;
- *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
+ if (args.ne10 == 1) {
+ const float x = 1.0f / *((device float *)(src1_ptr));
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+ *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
+ }
+ } else {
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+ const int i10 = i0%args.ne10;
+ *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
+ }
}
}
}
}
-template<typename T4>
+template<short ne20> // n_expert_used
kernel void kernel_mul_mm_id_map0(
constant ggml_metal_kargs_mul_mm_id_map0 & args,
- device const char * src1,
device const char * src2,
- device char * hsrc1,
device char * htpe,
device char * hids,
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort3 tpitg[[thread_position_in_threadgroup]],
- ushort3 ntg[[threads_per_threadgroup]]) {
- const int ide = tgpig[0]; // expert id
-
- int n_all = 0;
+ threadgroup char * shmem [[threadgroup(0)]],
+ ushort tpitg[[thread_position_in_threadgroup]],
+ ushort ntg[[threads_per_threadgroup]]) {
+ const short ide = tpitg; // expert id
- device int32_t * ids_i32 = (device int32_t *) (hids);
+ uint32_t n_all = 0;
- for (int i21 = 0; i21 < args.neh11; i21++) { // n_tokens
- device const int32_t * src2_i32 = (device const int32_t *) (src2 + i21*args.nb21);
+ device int32_t * ids_i32 = (device int32_t *) hids + ide*args.ne21;
- for (int i20 = 0; i20 < args.ne20; i20++) { // n_expert_used
- if (src2_i32[i20] != ide) {
- continue;
- }
+ for (int i21 = 0; i21 < args.ne21; i21 += ntg) { // n_tokens
+ if (i21 + tpitg < args.ne21) {
+ device const int32_t * src2_i32 = (device const int32_t *) (src2 + (i21 + tpitg)*args.nb21);
- device const float4 * src1_f32x4 = (device const float4 *) ( src1 + i21*args.nb12 + (i20%args.ne11)*args.nb11);
- device T4 * hsrc1_f32x4 = (device T4 *) (hsrc1 + (ide*args.neh11 + n_all)*args.nbh11);
+ threadgroup uint16_t * sids = (threadgroup uint16_t *) shmem + tpitg*ne20;
- for (int64_t i00 = tpitg.x; i00 < args.ne10/4; i00 += ntg.x) {
- hsrc1_f32x4[i00] = (T4) (src1_f32x4[i00]);
+ #pragma unroll(ne20)
+ for (short i20 = 0; i20 < ne20; i20++) {
+ sids[i20] = src2_i32[i20];
}
-
- if (tpitg.x == 0) {
- ids_i32[i21*args.ne20 + i20] = ide*args.neh11 + n_all;
- }
-
- ++n_all;
}
- }
- if (tpitg.x == 0) {
- device int32_t * tpe_i32 = (device int32_t *) (htpe);
- tpe_i32[ide] = n_all;
- }
-}
-
-typedef decltype(kernel_mul_mm_id_map0<half4>) kernel_mul_mm_id_map0_t;
-
-template [[host_name("kernel_mul_mm_id_map0_f16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<half4>;
+ threadgroup_barrier(mem_flags::mem_threadgroup);
-template<typename T>
-kernel void kernel_mul_mm_id_map1(
- constant ggml_metal_kargs_mul_mm_id_map1 & args,
- device const char * hdst,
- device const char * hids,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort3 tpitg[[thread_position_in_threadgroup]],
- ushort3 ntg[[threads_per_threadgroup]]) {
- const int i20 = tgpig[0]; // used expert
- const int i21 = tgpig[1]; // token
+ for (short t = 0; t < ntg; t++) {
+ if (i21 + t >= args.ne21) {
+ break;
+ }
- device const int32_t * ids_i32 = (device const int32_t *) (hids);
- device float4 * dst_f32x4 = (device float4 *) (dst + i20*args.nb1 + i21*args.nb2);
+ threadgroup const uint16_t * sids = (threadgroup const uint16_t *) shmem + t*ne20;
- const int id = ids_i32[i21*args.ne20 + i20];
+ short sel = 0;
+ #pragma unroll(ne20)
+ for (short i20 = 0; i20 < ne20; i20++) {
+ sel += (sids[i20] == ide)*(i20 + 1);
+ }
- const int ide = id / args.neh1;
- const int idt = id % args.neh1;
+ ids_i32[n_all] = (i21 + t)*ne20 + sel - 1;
- device const float4 * hdst_f32x4 = (device const float4 *) (hdst + idt*args.nbh1 + ide*args.nbh2);
+ n_all += sel > 0;
+ }
- for (int64_t i0 = tpitg.x; i0 < args.neh0/4; i0 += ntg.x) {
- dst_f32x4[i0] = hdst_f32x4[i0];
+ threadgroup_barrier(mem_flags::mem_threadgroup);
}
+
+ device uint32_t * tpe_u32 = (device uint32_t *) (htpe);
+ tpe_u32[ide] = n_all;
}
-typedef decltype(kernel_mul_mm_id_map1<float>) kernel_mul_mm_id_map1_t;
+typedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t;
-template [[host_name("kernel_mul_mm_id_map1_f32")]] kernel kernel_mul_mm_id_map1_t kernel_mul_mm_id_map1<float>;
+template [[host_name("kernel_mul_mm_id_map0_f16_ne20_1" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>;
+template [[host_name("kernel_mul_mm_id_map0_f16_ne20_2" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>;
+template [[host_name("kernel_mul_mm_id_map0_f16_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>;
+template [[host_name("kernel_mul_mm_id_map0_f16_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>;
+template [[host_name("kernel_mul_mm_id_map0_f16_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
+template [[host_name("kernel_mul_mm_id_map0_f16_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
kernel void kernel_mul_mm_id(
constant ggml_metal_kargs_mul_mm_id & args,
device const char * src0,
device const char * src1,
- device const char * tpe,
+ device const char * htpe,
+ device const char * hids,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]],
+ ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
threadgroup T * sa = (threadgroup T *)(shmem);
const int r0 = tgpig.y;
const int r1 = tgpig.x;
- const int im = tgpig.z;
+ const int im = tgpig.z; // expert
- device const int32_t * tpe_i32 = (device const int32_t *) (tpe);
+ device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe);
+ device const int32_t * ids_i32 = (device const int32_t *) (hids);
- const int neh1 = tpe_i32[im];
+ const int32_t neh1 = tpe_u32[im];
if (r1*BLOCK_SIZE_N >= neh1) {
return;
}
// if this block is of 64x32 shape or smaller
- const short n_rows = (args.neh0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.neh0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
- const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
+ const short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
+ const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
// a thread shouldn't load data outside of the matrix
const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
short il = (tiitg % THREAD_PER_ROW);
- const int i12 = im%args.neh12;
- const int i13 = im/args.neh12;
+ const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + thread_col];
- const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const short i11 = (id % args.ne20) % args.ne11;
+ const short i12 = (id / args.ne20);
+ const short i13 = 0;
+
+ const uint64_t offset0 = im*args.nb02 + i13*args.nb03;
const short offset1 = il/nl;
device const block_q * x = (device const block_q *)(src0
+ args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
- device const half * y = (device const half *)(src1
- + args.nbh13*i13
- + args.nbh12*i12
- + args.nbh11*(r1*BLOCK_SIZE_N + thread_col)
- + args.nbh10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
+ device const float * y = (device const float *)(src1
+ + args.nb13*i13
+ + args.nb12*i12
+ + args.nb11*i11
+ + args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
// load data and store to threadgroup memory
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
}
- *(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device half2x4 *) y);
+ *(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (half2x4)(*((device float2x4 *) y));
il = (il + 2 < nl) ? il + 2 : il % 2;
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
}
}
- if ((r0 + 1) * BLOCK_SIZE_M <= args.neh0 && (r1 + 1) * BLOCK_SIZE_N <= neh1) {
- device float * C = (device float *) dst +
- (BLOCK_SIZE_M * r0 + 32*(sgitg & 1)) + \
- (BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.neh0 + im*args.neh1*args.neh0;
+ threadgroup_barrier(mem_flags::mem_threadgroup);
- for (short i = 0; i < 8; i++) {
- simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.neh0 * (i/4), args.neh0);
- }
- } else {
- // block is smaller than 64x32, we should avoid writing data outside of the matrix
- threadgroup_barrier(mem_flags::mem_threadgroup);
- threadgroup float * temp_str = ((threadgroup float *) shmem) \
- + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
- for (short i = 0; i < 8; i++) {
- simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
- }
+ threadgroup float * temp_str = ((threadgroup float *) shmem) \
+ + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ #pragma unroll(8)
+ for (short i = 0; i < 8; i++) {
+ simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
+ }
- if (sgitg == 0) {
- for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
- device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.neh0 + im*args.neh1*args.neh0;
- device float4 * D4 = (device float4 *) D;
+ threadgroup_barrier(mem_flags::mem_threadgroup);
- threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
- threadgroup float4 * C4 = (threadgroup float4 *) C;
+ for (short j = sgitg; j < n_cols; j += 4) {
+ const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + j];
- int i = 0;
- for (; i < n_rows/4; i++) {
- *(D4 + i) = *(C4 + i);
- }
+ const short ide = id % args.ne20;
+ const short idt = id / args.ne20;
- i *= 4;
- for (; i < n_rows; i++) {
- *(D + i) = *(C + i);
- }
- }
+ device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + ide*args.ne0 + idt*args.ne1*args.ne0;
+ device float4 * D4 = (device float4 *) D;
+
+ threadgroup float * C = (threadgroup float *) shmem + (j*BLOCK_SIZE_M);
+ threadgroup float4 * C4 = (threadgroup float4 *) C;
+
+ int i = tiisg;
+ for (; i < n_rows/4; i += 32) {
+ *(D4 + i) = *(C4 + i);
+ }
+
+ i = (4*(n_rows/4)) + tiisg;
+ for (; i < n_rows; i += 32) {
+ *(D + i) = *(C + i);
}
}
}