// scan the blocks of the mask that are not masked
// 0 - masked (i.e. full of -INF, skip)
// 1 - not masked (i.e. at least one element of the mask is not -INF)
+// 2 - all zero
kernel void kernel_flash_attn_ext_blk(
constant ggml_metal_kargs_flash_attn_ext_blk & args,
device const char * mask,
device const half * mask_src = (device const half *) (mask + (i1*Q)*args.nb31 + i2*args.nb32 + i3*args.nb33) + i0*C + tiisg;
- // fast route
- if (res == 0) {
- if (simd_max(*mask_src) > -MAXHALF/2) {
- res = 1;
- }
- }
-
// detailed check of the elements of the block
if ((C > NW || Q > 1) && res == 0) {
- half m = -MAXHALF;
+ half mmin = MAXHALF;
+ half mmax = -MAXHALF;
FOR_UNROLL (short j = 0; j < Q; ++j) {
FOR_UNROLL (short ii = 0; ii < C/NW; ++ii) {
- m = max(m, mask_src[ii*NW]);
+ mmin = min(mmin, mask_src[ii*NW]);
+ mmax = max(mmax, mask_src[ii*NW]);
}
mask_src += args.nb31/2;
}
- if (simd_max(m) > -MAXHALF/2) {
- res = 1;
+ mmin = simd_min(mmin);
+ mmax = simd_max(mmax);
+
+ if (mmax > -MAXHALF) {
+ if (mmin == 0.0 && mmax == 0.0) {
+ res = 2;
+ } else {
+ res = 1;
+ }
}
}
ic = 0;
}
+ char blk_cur = 1;
+
// read the mask into shared mem
if (FC_flash_attn_ext_has_mask) {
- if (blk[ic0] == 0) {
+ blk_cur = blk[ic0];
+
+ if (blk_cur == 0) {
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
pm2[jj] += NW;
}
continue;
}
- FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
- const short j = jj*NSG + sgitg;
+ if (blk_cur == 1) {
+ FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
+ const short j = jj*NSG + sgitg;
- if (FC_flash_attn_ext_bc_mask) {
- sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF);
- } else {
- sm2[j*SH + tiisg] = pm2[jj][tiisg];
- }
+ if (FC_flash_attn_ext_bc_mask) {
+ sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF);
+ } else {
+ sm2[j*SH + tiisg] = pm2[jj][tiisg];
+ }
- pm2[jj] += NW;
+ pm2[jj] += NW;
+ }
+ } else if (blk_cur == 2) {
+ FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
+ pm2[jj] += NW;
+ }
}
#if 0
}
// mqk = mqk + slope*mask
- if (FC_flash_attn_ext_has_bias) {
- s2 += s2_t(sm2[j*SH + tiisg])*slope;
- } else {
- s2 += s2_t(sm2[j*SH + tiisg]);
+ if (blk_cur != 2) {
+ if (FC_flash_attn_ext_has_bias) {
+ s2 += s2_t(sm2[j*SH + tiisg])*slope;
+ } else {
+ s2 += s2_t(sm2[j*SH + tiisg]);
+ }
}
M[jj] = simd_max(max(M[jj], max(s2[0], s2[1])));