}
simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
-
- const short tx = tiisg%4;
- const short ty = tiisg/4;
-
- // mqk = mqk*scale
- ss[8*cc + ty*TF + 2*tx + 0] *= scale;
- ss[8*cc + ty*TF + 2*tx + 1] *= scale;
-
- if (logit_softcap != 0.0f) {
- ss[8*cc + ty*TF + 2*tx + 0] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 0]);
- ss[8*cc + ty*TF + 2*tx + 1] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 1]);
- }
-
- if (mask != q) {
- // mqk = mqk + mask*slope
- ss[8*cc + ty*TF + 2*tx + 0] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
- ss[8*cc + ty*TF + 2*tx + 1] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
- }
}
}
float ms[Q];
for (short j = 0; j < Q; ++j) {
- const short p = tiisg;
-
const float m = M[j];
- const float s = ss[j*TF + p];
+
+ // scale and apply the logitcap / mask
+ float s = ss[j*TF + tiisg]*scale;
+
+ if (logit_softcap != 0.0f) {
+ s = logit_softcap*precise::tanh(s);
+ }
+
+ if (mask != q) {
+ // mqk = mqk + mask*slope
+ s += slope*mp[ic + j*nb31/sizeof(half) + tiisg];
+ }
smax = simd_max(max(smax, s));
M[j] = simd_max(max(M[j], s));
S[j] = S[j]*ms[j] + simd_sum(vs);
// the P matrix from the paper (Q rows, C columns)
- ss[j*TF + p] = vs;
+ ss[j*TF + tiisg] = vs;
}
// create a QxQ diagonal matrix for rescaling the output