]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
metal : skip loading all-zero mask (llama/19337)
authorGeorgi Gerganov <redacted>
Fri, 6 Feb 2026 07:25:11 +0000 (09:25 +0200)
committerGeorgi Gerganov <redacted>
Sat, 7 Feb 2026 08:37:38 +0000 (10:37 +0200)
* metal : skip loading all-zero mask

* cont : minor

src/ggml-metal/ggml-metal.metal

index e54cdab39ddcc4476b20d3c6bbf0ba7dd00191ae..612a42a1ea86b69c495d8693209b9e32a596c29c 100644 (file)
@@ -5285,6 +5285,7 @@ constant int32_t FC_flash_attn_ext_blk_ncpsg [[function_constant(FC_FLASH_ATTN_E
 // scan the blocks of the mask that are not masked
 // 0 -     masked (i.e. full of -INF, skip)
 // 1 - not masked (i.e. at least one element of the mask is not -INF)
+// 2 - all zero
 kernel void kernel_flash_attn_ext_blk(
         constant ggml_metal_kargs_flash_attn_ext_blk & args,
         device const char * mask,
@@ -5306,27 +5307,29 @@ kernel void kernel_flash_attn_ext_blk(
 
     device const half * mask_src = (device const half *) (mask + (i1*Q)*args.nb31 + i2*args.nb32 + i3*args.nb33) + i0*C + tiisg;
 
-    // fast route
-    if (res == 0) {
-        if (simd_max(*mask_src) > -MAXHALF/2) {
-            res = 1;
-        }
-    }
-
     // detailed check of the elements of the block
     if ((C > NW || Q > 1) && res == 0) {
-        half m = -MAXHALF;
+        half mmin =  MAXHALF;
+        half mmax = -MAXHALF;
 
         FOR_UNROLL (short j = 0; j < Q; ++j) {
             FOR_UNROLL (short ii = 0; ii < C/NW; ++ii) {
-                m = max(m, mask_src[ii*NW]);
+                mmin = min(mmin, mask_src[ii*NW]);
+                mmax = max(mmax, mask_src[ii*NW]);
             }
 
             mask_src += args.nb31/2;
         }
 
-        if (simd_max(m) > -MAXHALF/2) {
-            res = 1;
+        mmin = simd_min(mmin);
+        mmax = simd_max(mmax);
+
+        if (mmax > -MAXHALF) {
+            if (mmin == 0.0 && mmax == 0.0) {
+                res = 2;
+            } else {
+                res = 1;
+            }
         }
     }
 
@@ -5568,9 +5571,13 @@ void kernel_flash_attn_ext_impl(
                 ic = 0;
             }
 
+            char blk_cur = 1;
+
             // read the mask into shared mem
             if (FC_flash_attn_ext_has_mask) {
-                if (blk[ic0] == 0) {
+                blk_cur = blk[ic0];
+
+                if (blk_cur == 0) {
                     FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
                         pm2[jj] += NW;
                     }
@@ -5578,16 +5585,22 @@ void kernel_flash_attn_ext_impl(
                     continue;
                 }
 
-                FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
-                    const short j = jj*NSG + sgitg;
+                if (blk_cur == 1) {
+                    FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
+                        const short j = jj*NSG + sgitg;
 
-                    if (FC_flash_attn_ext_bc_mask) {
-                        sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF);
-                    } else {
-                        sm2[j*SH + tiisg] = pm2[jj][tiisg];
-                    }
+                        if (FC_flash_attn_ext_bc_mask) {
+                            sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF);
+                        } else {
+                            sm2[j*SH + tiisg] = pm2[jj][tiisg];
+                        }
 
-                    pm2[jj] += NW;
+                        pm2[jj] += NW;
+                    }
+                } else if (blk_cur == 2) {
+                    FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
+                        pm2[jj] += NW;
+                    }
                 }
 
 #if 0
@@ -5752,10 +5765,12 @@ void kernel_flash_attn_ext_impl(
                 }
 
                 // mqk = mqk + slope*mask
-                if (FC_flash_attn_ext_has_bias) {
-                    s2 += s2_t(sm2[j*SH + tiisg])*slope;
-                } else {
-                    s2 += s2_t(sm2[j*SH + tiisg]);
+                if (blk_cur != 2) {
+                    if (FC_flash_attn_ext_has_bias) {
+                        s2 += s2_t(sm2[j*SH + tiisg])*slope;
+                    } else {
+                        s2 += s2_t(sm2[j*SH + tiisg]);
+                    }
                 }
 
                 M[jj] = simd_max(max(M[jj], max(s2[0], s2[1])));