device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
float sumf = 0;
- for (int i = tiisg; i < ne00; i += 32) {
- sumf += (float) x[i] * (float) y[i];
+ if (ne00 < 128) {
+ for (int i = tiisg; i < ne00; i += 32) {
+ sumf += (float) x[i] * (float) y[i];
+ }
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
+ } else {
+ device const half4 * x4 = (device const half4 *) x;
+ device const float4 * y4 = (device const float4 *) y;
+ for (int i = tiisg; i < ne00/4; i += 32) {
+ for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k];
+ }
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
}
- float all_sum = simd_sum(sumf);
- if (tiisg == 0) {
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
- }
}
#define N_F16_F32 4
uint tiisg[[thread_index_in_simdgroup]]) {
const int64_t r0 = tgpig.x;
- const int64_t rb = N_F16_F32*tgpig.y;
+ const int64_t rb = tgpig.y*N_F16_F32;
const int64_t im = tgpig.z;
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
- for (int row = 0; row < N_F16_F32; ++row) {
- int r1 = rb + row;
- if (r1 >= ne11) {
- break;
- }
+ if (ne00 < 128) {
+ for (int row = 0; row < N_F16_F32; ++row) {
+ int r1 = rb + row;
+ if (r1 >= ne11) {
+ break;
+ }
- device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+ device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
- float sumf = 0;
- for (int i = tiisg; i < ne00; i += 32) {
- sumf += (float) x[i] * (float) y[i];
+ float sumf = 0;
+ for (int i = tiisg; i < ne00; i += 32) {
+ sumf += (float) x[i] * (float) y[i];
+ }
+
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
}
+ } else {
+ device const half4 * x4 = (device const half4 *)x;
+ for (int row = 0; row < N_F16_F32; ++row) {
+ int r1 = rb + row;
+ if (r1 >= ne11) {
+ break;
+ }
- float all_sum = simd_sum(sumf);
- if (tiisg == 0) {
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+ device const float4 * y4 = (device const float4 *) y;
+
+ float sumf = 0;
+ for (int i = tiisg; i < ne00/4; i += 32) {
+ for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+ }
+
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
}
}
+
}
kernel void kernel_alibi_f32(