constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig[2];
- const int64_t i02 = tgpig[1];
- const int64_t i01 = tgpig[0];
+ threadgroup float * buf [[threadgroup(0)]],
+ uint tgpig[[threadgroup_position_in_grid]],
+ uint tpitg[[thread_position_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = (tgpig) / (ne02*ne01);
+ const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
+ const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
// parallel max
- float lmax = tpitg[0] < ne00 ? psrc0[tpitg[0]] : -INFINITY;
- for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
+ float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;
+
+ for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
lmax = MAX(lmax, psrc0[i00]);
}
- const float max = simd_max(lmax);
+
+ float max = simd_max(lmax);
+ if (tiisg == 0) {
+ buf[sgitg] = max;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // reduce the per-simdgroup maxima stored in buf into buf[0]; the number of simdgroups is taken to be ntg / 32
+ for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+ if (tpitg < i) {
+ buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
+ }
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ max = buf[0];
// parallel sum
float lsum = 0.0f;
- for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
const float exp_psrc0 = exp(psrc0[i00] - max);
lsum += exp_psrc0;
// Remember the result of exp here. exp is expensive, so we really do not
- // whish to compute it twice.
+ // wish to compute it twice.
pdst[i00] = exp_psrc0;
}
- const float sum = simd_sum(lsum);
+ float sum = simd_sum(lsum);
+ if (tiisg == 0) {
+ buf[sgitg] = sum;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // reduce the per-simdgroup partial sums stored in buf into buf[0]; the number of simdgroups is taken to be ntg / 32
+ for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+ if (tpitg < i) {
+ buf[tpitg] += buf[tpitg + i];
+ }
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ sum = buf[0];
- for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
pdst[i00] /= sum;
}
}
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig[2];
- const int64_t i02 = tgpig[1];
- const int64_t i01 = tgpig[0];
+ threadgroup float * buf [[threadgroup(0)]],
+ uint tgpig[[threadgroup_position_in_grid]],
+ uint tpitg[[thread_position_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = (tgpig) / (ne02*ne01);
+ const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
+ const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
// parallel max
- float4 lmax4 = tpitg[0] < ne00/4 ? psrc4[tpitg[0]] : -INFINITY;
- for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
+ float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;
+
+ for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
lmax4 = fmax(lmax4, psrc4[i00]);
}
- float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
- const float max = simd_max(lmax);
+ const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
+ float max = simd_max(lmax);
+ if (tiisg == 0) {
+ buf[sgitg] = max;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // reduce the per-simdgroup maxima stored in buf into buf[0]; the number of simdgroups is taken to be ntg / 32
+ for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+ if (tpitg < i) {
+ buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
+ }
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ max = buf[0];
// parallel sum
float4 lsum4 = 0.0f;
- for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
const float4 exp_psrc4 = exp(psrc4[i00] - max);
lsum4 += exp_psrc4;
pdst4[i00] = exp_psrc4;
}
- float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
- const float sum = simd_sum(lsum);
+ const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
+ float sum = simd_sum(lsum);
+ if (tiisg == 0) {
+ buf[sgitg] = sum;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // reduce the per-simdgroup partial sums stored in buf into buf[0]; the number of simdgroups is taken to be ntg / 32
+ for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+ if (tpitg < i) {
+ buf[tpitg] += buf[tpitg + i];
+ }
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ sum = buf[0];
- for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
pdst4[i00] /= sum;
}
}
dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
} else {
dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
- }
+ }
}
kernel void kernel_diag_mask_inf_8(