const short iv3 = iq3 / rv3;
// load the queries from shared memory into local memory
- float4 mq[D4];
+ float4 mq[D4/NW];
for (short ii = 0; ii < D4; ii += NW) {
short i = ii + tiisg;
- mq[i] = (float4) sq4[i];
+ mq[ii/NW] = (float4) sq4[i];
}
// pointer to the mask
mk[2] = (float4) pk4[i + 2*(nb11/8)];
mk[3] = (float4) pk4[i + 3*(nb11/8)];
- mqk += (float4) (mq[i] * mk);
+ mqk += (float4) (mq[ii/NW] * mk);
}
// reduce the results from the threads in the simdgroup
// O = diag(ms)*O
#pragma unroll
for (short ii = 0; ii < D4; ii += NW) {
- const short i = ii + tiisg;
- lo[i/NW] *= ms;
+ lo[ii/NW] *= ms;
}
}
for (short ii = 0; ii < D4; ii += NW) {
const short i = ii + tiisg;
- lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
- lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
- lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
- lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
+ lo[ii/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
+ lo[ii/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
+ lo[ii/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
+ lo[ii/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
}
}
}