]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
metal : mark FA blocks (llama/16372)
authorGeorgi Gerganov <redacted>
Wed, 8 Oct 2025 07:57:53 +0000 (10:57 +0300)
committerGeorgi Gerganov <redacted>
Sun, 12 Oct 2025 04:57:25 +0000 (07:57 +0300)
* metal : better unroll in the FA kernels

* metal : index FA blocks

* tests : restore [no ci]

* metal : prevent division by zero in FA kernels

* metal : fix -INF detection logic

src/ggml-metal/ggml-metal-device.cpp
src/ggml-metal/ggml-metal-device.h
src/ggml-metal/ggml-metal-impl.h
src/ggml-metal/ggml-metal-ops.cpp
src/ggml-metal/ggml-metal-ops.h
src/ggml-metal/ggml-metal.cpp
src/ggml-metal/ggml-metal.metal

index 46cc51345969bf4180943939d1eee7af51b8a55b..e23abdda97405bc53af5d789f1b9d0b451af8ceb 100644 (file)
@@ -959,7 +959,53 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
   //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
   //ggml_metal_cv_set_int32(cv, nsg,  FC_FLASH_ATTN_EXT_PAD + 22);
   //ggml_metal_cv_set_int32(cv, nwg,  FC_FLASH_ATTN_EXT_PAD + 23);
-    ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 24);
+  //ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24);
+    ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25);
+
+    res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+    ggml_metal_cv_free(cv);
+
+    return res;
+}
+
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk(
+        ggml_metal_library_t lib,
+        const struct ggml_tensor * op,
+        int32_t nqptg,
+        int32_t ncpsg) {
+    assert(op->op == GGML_OP_FLASH_ATTN_EXT);
+    GGML_UNUSED(op);
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, 256, "kernel_%s",
+            "flash_attn_ext_blk");
+
+    snprintf(name, 256, "%s_nqptg=%d_ncpsg=%d",
+            base,
+            nqptg,
+            ncpsg);
+
+    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    ggml_metal_cv_t cv = ggml_metal_cv_init();
+
+  //ggml_metal_cv_set_bool(cv, has_mask,  FC_FLASH_ATTN_EXT_BLK + 0);
+  //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1);
+  //ggml_metal_cv_set_bool(cv, has_bias,  FC_FLASH_ATTN_EXT_BLK + 2);
+  //ggml_metal_cv_set_bool(cv, has_scap,  FC_FLASH_ATTN_EXT_BLK + 3);
+
+  //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20);
+  //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21);
+  //ggml_metal_cv_set_int32(cv, nsg,  FC_FLASH_ATTN_EXT_BLK + 22);
+  //ggml_metal_cv_set_int32(cv, nwg,  FC_FLASH_ATTN_EXT_BLK + 23);
+    ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24);
+    ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25);
 
     res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
 
index ef049507384d85cf3778cb7a78024a57762081e2..1034e4bbf65960af5c157f75a8e031670044a898 100644 (file)
@@ -141,6 +141,12 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
         bool    has_mask,
         int32_t ncpsg);
 
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk(
+        ggml_metal_library_t lib,
+        const struct ggml_tensor * op,
+        int32_t nqptg,
+        int32_t ncpsg);
+
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
         ggml_metal_library_t lib,
         const struct ggml_tensor * op,
index 1524b3ab518cbf2704acefd745eea1996e210eea..c9dff873058697d8fdbd75c6e16191821bffe475 100644 (file)
 
 // function constants offsets
 #define FC_FLASH_ATTN_EXT_PAD          100
-#define FC_FLASH_ATTN_EXT              200
-#define FC_FLASH_ATTN_EXT_VEC          300
-#define FC_FLASH_ATTN_EXT_VEC_REDUCE   400
-#define FC_MUL_MV                      500
-#define FC_MUL_MM                      600
+#define FC_FLASH_ATTN_EXT_BLK          200
+#define FC_FLASH_ATTN_EXT              300
+#define FC_FLASH_ATTN_EXT_VEC          400
+#define FC_FLASH_ATTN_EXT_VEC_REDUCE   500
+#define FC_MUL_MV                      600
+#define FC_MUL_MM                      700
+
+// op-specific constants
+#define OP_FLASH_ATTN_EXT_NQPTG 8
+#define OP_FLASH_ATTN_EXT_NCPSG 64
+
+#define OP_FLASH_ATTN_EXT_VEC_NQPTG 1
+#define OP_FLASH_ATTN_EXT_VEC_NCPSG 32
 
 // kernel argument structs
 //
@@ -263,6 +271,17 @@ typedef struct {
     uint64_t nb33;
 } ggml_metal_kargs_flash_attn_ext_pad;
 
+typedef struct {
+    int32_t  ne01;
+    int32_t  ne30;
+    int32_t  ne31;
+    int32_t  ne32;
+    int32_t  ne33;
+    uint64_t nb31;
+    uint64_t nb32;
+    uint64_t nb33;
+} ggml_metal_kargs_flash_attn_ext_blk;
+
 typedef struct {
     int32_t  ne01;
     int32_t  ne02;
index 125cc64dc5295f2bfb6708543bf3514b3f8550c7..1137e210773af54dd1969c5a1249773745b18484 100644 (file)
@@ -1918,19 +1918,19 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
     const bool has_mask = op->src[3] != nullptr;
 
     if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
-        const bool has_kvpad = ne11 % 32 != 0;
+        const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
 
         if (has_kvpad) {
-            res += 32*(
+            res += OP_FLASH_ATTN_EXT_VEC_NCPSG*(
                 nb11*ne12*ne13 +
                 nb21*ne22*ne23 +
                 (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
         }
     } else {
-        const bool has_kvpad = ne11 % 64 != 0;
+        const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
 
         if (has_kvpad) {
-            res += 64*(
+            res += OP_FLASH_ATTN_EXT_NCPSG*(
                 nb11*ne12*ne13 +
                 nb21*ne22*ne23 +
                 (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
@@ -1940,6 +1940,44 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
     return res;
 }
 
+size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) {
+    assert(op->op == GGML_OP_FLASH_ATTN_EXT);
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+  //GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+  //GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+  //GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+  //GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
+  //GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
+
+    size_t res = 0;
+
+    const bool has_mask = op->src[3] != nullptr;
+
+    if (!has_mask) {
+        return res;
+    }
+
+    const bool is_vec = ggml_metal_op_flash_attn_ext_use_vec(op);
+
+    // this optimization is not useful for the vector kernels
+    if (is_vec) {
+        return res;
+    }
+
+    const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG;
+    const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
+
+    const int64_t ne1 = (ne01 + nqptg - 1)/nqptg;
+    const int64_t ne0 = (ne30 + ncpsg - 1)/ncpsg;
+
+    res += GGML_PAD(ggml_type_size(GGML_TYPE_I8)*ne0*ne1*ne32*ne33, 32);
+
+    return res;
+}
+
 size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {
     assert(op->op == GGML_OP_FLASH_ATTN_EXT);
 
@@ -2034,18 +2072,23 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
     ggml_metal_buffer_id bid_pad = bid_dst;
     bid_pad.offs += ggml_nbytes(op);
 
-    ggml_metal_buffer_id bid_tmp = bid_pad;
-    bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_pad(op);
+    ggml_metal_buffer_id bid_blk = bid_pad;
+    bid_blk.offs += ggml_metal_op_flash_attn_ext_extra_pad(op);
+
+    ggml_metal_buffer_id bid_tmp = bid_blk;
+    bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_blk(op);
 
     if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {
         // half8x8 kernel
-        const int64_t nqptg = 8;  // queries per threadgroup    !! sync with kernel template arguments !!
-        const int64_t ncpsg = 64; // cache values per simdgroup !! sync with kernel template arguments !!
+        const int nqptg = OP_FLASH_ATTN_EXT_NQPTG; // queries per threadgroup
+        const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup
 
         GGML_ASSERT(nqptg <= 32);
         GGML_ASSERT(nqptg  % 8  == 0);
         GGML_ASSERT(ncpsg  % 32 == 0);
 
+        bool need_sync = false;
+
         const bool has_kvpad = ne11 % ncpsg != 0;
 
         if (has_kvpad) {
@@ -2083,11 +2126,46 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
 
             ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
 
-            ggml_metal_op_concurrency_reset(ctx);
+            need_sync = true;
         } else {
             assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
         }
 
+        if (has_mask) {
+            assert(ggml_metal_op_flash_attn_ext_extra_blk(op) != 0);
+
+            ggml_metal_kargs_flash_attn_ext_blk args0 = {
+                /*.ne01 =*/ ne01,
+                /*.ne30 =*/ ne30,
+                /*.ne31 =*/ ne31,
+                /*.ne32 =*/ ne32,
+                /*.ne33 =*/ ne33,
+                /*.nb31 =*/ nb31,
+                /*.nb32 =*/ nb32,
+                /*.nb33 =*/ nb33,
+            };
+
+            ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg);
+
+            ggml_metal_encoder_set_pipeline(enc, pipeline0);
+            ggml_metal_encoder_set_bytes   (enc, &args0, sizeof(args0), 0);
+            ggml_metal_encoder_set_buffer  (enc, bid_src3, 1);
+            ggml_metal_encoder_set_buffer  (enc, bid_blk,  2);
+
+            const int32_t nblk1 = ((ne01 + nqptg - 1)/nqptg);
+            const int32_t nblk0 = ((ne30 + ncpsg - 1)/ncpsg);
+
+            ggml_metal_encoder_dispatch_threadgroups(enc, nblk0, nblk1, ne32*ne33, 32, 1, 1);
+
+            need_sync = true;
+        } else {
+            assert(ggml_metal_op_flash_attn_ext_extra_blk(op) == 0);
+        }
+
+        if (need_sync) {
+            ggml_metal_op_concurrency_reset(ctx);
+        }
+
         const int is_q = ggml_is_quantized(op->src[1]->type) ? 1 : 0;
 
         // 2*(2*ncpsg)
@@ -2164,7 +2242,8 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
         ggml_metal_encoder_set_buffer  (enc, bid_src3, 4);
         ggml_metal_encoder_set_buffer  (enc, bid_src4, 5);
         ggml_metal_encoder_set_buffer  (enc, bid_pad,  6);
-        ggml_metal_encoder_set_buffer  (enc, bid_dst,  7);
+        ggml_metal_encoder_set_buffer  (enc, bid_blk,  7);
+        ggml_metal_encoder_set_buffer  (enc, bid_dst,  8);
 
         ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
 
@@ -2172,14 +2251,16 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
 #undef FATTN_SMEM
     } else {
         // half4x4 kernel
-        const int64_t nqptg = 1;  // queries per threadgroup    !! sync with kernel template arguments !!
-        const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
-        const int64_t nkpsg = 1*ncpsg;
+        const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPTG; // queries per threadgroup
+        const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !!
+        const int nkpsg = 1*ncpsg;
 
         GGML_ASSERT(nqptg <= 32);
         GGML_ASSERT(nqptg  % 1  == 0);
         GGML_ASSERT(ncpsg  % 32 == 0);
 
+        bool need_sync = false;
+
         const bool has_kvpad = ne11 % ncpsg != 0;
 
         if (has_kvpad) {
@@ -2217,11 +2298,15 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
 
             ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
 
-            ggml_metal_op_concurrency_reset(ctx);
+            need_sync = true;
         } else {
             assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
         }
 
+        if (need_sync) {
+            ggml_metal_op_concurrency_reset(ctx);
+        }
+
         // ne00 + 2*ncpsg*(nsg)
         // for each query, we load it as f16 in shared memory (ne00)
         // and store the soft_max values and the mask
index 6a6d8a7977a7c25b2073629b2e0f9a25ed78d14b..d4cb9446212d90c5f89191878669dff0ce03a6a2 100644 (file)
@@ -40,6 +40,7 @@ size_t ggml_metal_op_mul_mat_id_extra_ids(const struct ggml_tensor * op);
 bool ggml_metal_op_flash_attn_ext_use_vec(const struct ggml_tensor * op);
 
 size_t ggml_metal_op_flash_attn_ext_extra_pad(const struct ggml_tensor * op);
+size_t ggml_metal_op_flash_attn_ext_extra_blk(const struct ggml_tensor * op);
 size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op);
 
 int ggml_metal_op_concat            (ggml_metal_op_t ctx, int idx);
index e53f37b29c1a4a48d5c85b8f03e87c70048280a8..7afc881fa7012f414533e32425b0daeff2787312 100644 (file)
@@ -194,6 +194,7 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_
         case GGML_OP_FLASH_ATTN_EXT:
             {
                 res += ggml_metal_op_flash_attn_ext_extra_pad(tensor);
+                res += ggml_metal_op_flash_attn_ext_extra_blk(tensor);
                 res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
             } break;
         default:
index c52c6b48ad900a015fba9769cf6d2e8585ab1d69..45d91def88bf211ab0381117ff2b672480391233 100644 (file)
@@ -4351,7 +4351,7 @@ kernel void kernel_leaky_relu_f32_4(
 
 constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]];
 
-constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 24)]];
+constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 25)]];
 
 // pad the last chunk of C elements of k and v into a an extra pad buffer
 kernel void kernel_flash_attn_ext_pad(
@@ -4419,6 +4419,65 @@ kernel void kernel_flash_attn_ext_pad(
     }
 }
 
+constant int32_t FC_flash_attn_ext_blk_nqptg [[function_constant(FC_FLASH_ATTN_EXT_BLK + 24)]];
+constant int32_t FC_flash_attn_ext_blk_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_BLK + 25)]];
+
+// scan the blocks of the mask that are not masked
+// 0 -     masked (i.e. full of -INF, skip)
+// 1 - not masked (i.e. at least one element of the mask is not -INF)
+kernel void kernel_flash_attn_ext_blk(
+        constant ggml_metal_kargs_flash_attn_ext_blk & args,
+        device const char * mask,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]]) {
+    // block size C x Q
+    const int32_t Q = FC_flash_attn_ext_blk_nqptg;
+    const int32_t C = FC_flash_attn_ext_blk_ncpsg;
+
+    constexpr short NW  = N_SIMDWIDTH;
+
+    const int32_t i3 = tgpig[2]/args.ne32;
+    const int32_t i2 = tgpig[2]%args.ne32;
+    const int32_t i1 = tgpig[1];
+    const int32_t i0 = tgpig[0];
+
+    char res = i0*C + C > args.ne30 ? 1 : 0;
+
+    device const half * mask_src = (device const half *) (mask + (i1*Q)*args.nb31 + i2*args.nb32 + i3*args.nb33) + i0*C + tiisg;
+
+    // fast route
+    if (res == 0) {
+        if (simd_max(*mask_src) > -MAXHALF/2) {
+            res = 1;
+        }
+    }
+
+    // detailed check of the elements of the block
+    if ((C > NW || Q > 1) && res == 0) {
+        half m = -MAXHALF;
+
+        FOR_UNROLL (short j = 0; j < Q; ++j) {
+            FOR_UNROLL (short ii = 0; ii < C/NW; ++ii) {
+                m = max(m, mask_src[ii*NW]);
+            }
+
+            mask_src += args.nb31/2;
+        }
+
+        if (simd_max(m) > -MAXHALF/2) {
+            res = 1;
+        }
+    }
+
+    const int32_t nblk1 = ((args.ne01 + Q - 1)/Q);
+    const int32_t nblk0 = ((args.ne30 + C - 1)/C);
+
+    if (tiisg == 0) {
+        dst[((i3*args.ne32 + i2)*nblk1 + i1)*nblk0 + i0] = res;
+    }
+}
+
 constant bool FC_flash_attn_ext_has_mask  [[function_constant(FC_FLASH_ATTN_EXT + 0)]];
 constant bool FC_flash_attn_ext_has_sinks [[function_constant(FC_FLASH_ATTN_EXT + 1)]];
 constant bool FC_flash_attn_ext_has_bias  [[function_constant(FC_FLASH_ATTN_EXT + 2)]];
@@ -4473,6 +4532,7 @@ void kernel_flash_attn_ext_impl(
         device const char * mask,
         device const char * sinks,
         device const char * pad,
+        device const char * blk,
         device       char * dst,
         threadgroup  half * shmem_f16,
         uint3   tgpig,
@@ -4538,6 +4598,13 @@ void kernel_flash_attn_ext_impl(
         pm2[jj] = (device const half2 *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
     }
 
+    {
+        const int32_t nblk1 = ((args.ne01 + Q - 1)/Q);
+        const int32_t nblk0 = ((args.ne11 + C - 1)/C);
+
+        blk += (((iq3%args.ne33)*args.ne32 + (iq2%args.ne32))*nblk1 + iq1/Q)*nblk0;
+    }
+
     {
         q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03;
 
@@ -4597,11 +4664,14 @@ void kernel_flash_attn_ext_impl(
 
         // loop over the KV cache
         // each simdgroup handles blocks of Q rows and C columns
-        for (int ic0 = 0; ic0 < args.ne11; ic0 += C) {
-            int ic = ic0;
+        for (int ic0 = 0; ; ++ic0) {
+            int ic = ic0*C;
+            if (ic >= args.ne11) {
+                break;
+            }
 
             // the last partial chunk uses the pad buffer as source
-            if (FC_flash_attn_ext_has_kvpad && ic0 + C > args.ne11) {
+            if (FC_flash_attn_ext_has_kvpad && ic + C > args.ne11) {
                 k    = pad;
                 v    = k + args.nb11*C*args.ne_12_2*args.ne_12_3;
                 mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3;
@@ -4640,6 +4710,14 @@ void kernel_flash_attn_ext_impl(
 
             // read the mask into shared mem
             if (FC_flash_attn_ext_has_mask) {
+                if (blk[ic0] == 0) {
+                    FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
+                        pm2[jj] += NW;
+                    }
+
+                    continue;
+                }
+
                 FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
                     const short j = jj*NSG + sgitg;
 
@@ -4652,6 +4730,9 @@ void kernel_flash_attn_ext_impl(
                     pm2[jj] += NW;
                 }
 
+#if 0
+                // note: old -INF block optimization - obsoleted by pre-computing non-masked blocks
+
                 threadgroup_barrier(mem_flags::mem_threadgroup);
 
                 // used to detect blocks full of -INF
@@ -4670,6 +4751,7 @@ void kernel_flash_attn_ext_impl(
 
                     continue;
                 }
+#endif
             }
 
             // Q*K^T
@@ -4687,26 +4769,24 @@ void kernel_flash_attn_ext_impl(
 
                 constexpr short NC = (C/8)/NSG;
 
-                // TODO: not good to unroll for large contexts - not sure why?
+                // note: do not unroll for large heads
+                #pragma unroll (DK <= 64 ? NC : 1)
                 for (short cc = 0; cc < NC; ++cc) {
                     qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
 
-                    if (DK8 % 16 != 0) {
+                    if (DK % 16 != 0) {
                         k8x8_t mk;
                         q8x8_t mq;
 
                         FOR_UNROLL (short i = 0; i < DK8; ++i) {
                             simdgroup_barrier(mem_flags::mem_none);
 
-                            simdgroup_load(mk, pk, NS10, 0, true);
-                            simdgroup_load(mq, pq, DK);
+                            simdgroup_load(mk, pk + 8*i, NS10, 0, true);
+                            simdgroup_load(mq, pq + 8*i, DK);
 
                             simdgroup_barrier(mem_flags::mem_none);
 
                             simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
-
-                            pk += 8;
-                            pq += 8;
                         }
                     } else {
                         k8x8_t mk[2];
@@ -4715,26 +4795,22 @@ void kernel_flash_attn_ext_impl(
                         FOR_UNROLL (short i = 0; i < DK8/2; ++i) {
                             simdgroup_barrier(mem_flags::mem_none);
 
-                            simdgroup_load(mk[0], pk + 0*8, NS10, 0, true);
-                            simdgroup_load(mk[1], pk + 1*8, NS10, 0, true);
+                            simdgroup_load(mq[0], pq + 0*8 + 16*i, DK);
+                            simdgroup_load(mq[1], pq + 1*8 + 16*i, DK);
 
-                            simdgroup_load(mq[0], pq + 0*8, DK);
-                            simdgroup_load(mq[1], pq + 1*8, DK);
+                            simdgroup_load(mk[0], pk + 0*8 + 16*i, NS10, 0, true);
+                            simdgroup_load(mk[1], pk + 1*8 + 16*i, NS10, 0, true);
 
                             simdgroup_barrier(mem_flags::mem_none);
 
                             simdgroup_multiply_accumulate(mqk, mq[0], mk[0], mqk);
                             simdgroup_multiply_accumulate(mqk, mq[1], mk[1], mqk);
-
-                            pk += 16;
-                            pq += 16;
                         }
                     }
 
                     simdgroup_store(mqk, ps, SH, 0, false);
 
-                    pk += 8*(NSG*NS10 - DK8);
-                    pq += 8*(NSG*0    - DK8);
+                    pk += 8*(NSG*NS10);
                     ps += 8*(NSG);
                 }
             } else {
@@ -4868,27 +4944,50 @@ void kernel_flash_attn_ext_impl(
                     }
 
                     {
-                        auto sst = ss;
-
                         device const v_t * pv = (device const v_t *) (v + ic*args.nb21);
 
                         pv += 8*sgitg;
 
-                        FOR_UNROLL (short cc = 0; cc < C/8; ++cc) {
-                            s8x8_t vs;
-                            simdgroup_load(vs, sst, SH, 0, false);
+                        if (DV <= 64) {
+                            FOR_UNROLL (short cc = 0; cc < C/8; ++cc) {
+                                s8x8_t vs;
+                                simdgroup_load(vs, ss + 8*cc, SH, 0, false);
 
-                            FOR_UNROLL (short ii = 0; ii < NO; ++ii) {
-                                v8x8_t mv;
+                                FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) {
+                                    v8x8_t mv[2];
 
-                                simdgroup_load(mv, pv, NS20, 0, false);
-                                simdgroup_multiply_accumulate(lo[ii], vs, mv, lo[ii]);
+                                    simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG, NS20, 0, false);
+                                    simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG, NS20, 0, false);
 
-                                pv += 8*NSG;
+                                    simdgroup_multiply_accumulate(lo[2*ii + 0], vs, mv[0], lo[2*ii + 0]);
+                                    simdgroup_multiply_accumulate(lo[2*ii + 1], vs, mv[1], lo[2*ii + 1]);
+                                }
+
+                                pv  += 8*NS20;
                             }
+                        } else {
+                            FOR_UNROLL (short cc = 0; cc < (C/8)/2; ++cc) {
+                                s8x8_t vs[2];
+
+                                simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false);
+                                simdgroup_load(vs[1], ss + 16*cc + 8, SH, 0, false);
 
-                            pv  += 8*(NS20 - NO*NSG);
-                            sst += 8;
+                                FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) {
+                                    v8x8_t mv[4];
+
+                                    simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false);
+                                    simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false);
+                                    simdgroup_load(mv[2], pv + 0*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false);
+                                    simdgroup_load(mv[3], pv + 8*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false);
+
+                                    simdgroup_multiply_accumulate(lo[2*ii + 0], vs[0], mv[0], lo[2*ii + 0]);
+                                    simdgroup_multiply_accumulate(lo[2*ii + 1], vs[0], mv[1], lo[2*ii + 1]);
+                                    simdgroup_multiply_accumulate(lo[2*ii + 0], vs[1], mv[2], lo[2*ii + 0]);
+                                    simdgroup_multiply_accumulate(lo[2*ii + 1], vs[1], mv[3], lo[2*ii + 1]);
+                                }
+
+                                pv  += 2*8*NS20;
+                            }
                         }
                     }
 
@@ -5002,7 +5101,7 @@ void kernel_flash_attn_ext_impl(
 
         device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4;
 
-        const float scale = 1.0f/S[jj];
+        const float scale = S[jj] == 0.0 ? 0.0f : 1.0f/S[jj];
 
         if (DV4 % NW == 0) {
             FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) {
@@ -5047,8 +5146,8 @@ template<
     void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
     short DK,         // K head size
     short DV,         // V head size
-    short Q  = 8,     // queries per threadgroup
-    short C  = 64>    // cache items per threadgroup
+    short Q  = OP_FLASH_ATTN_EXT_NQPTG, // queries per threadgroup
+    short C  = OP_FLASH_ATTN_EXT_NCPSG> // cache items per threadgroup
 kernel void kernel_flash_attn_ext(
         constant ggml_metal_kargs_flash_attn_ext & args,
         device const char * q,
@@ -5057,13 +5156,14 @@ kernel void kernel_flash_attn_ext(
         device const char * mask,
         device const char * sinks,
         device const char * pad,
+        device const char * blk,
         device       char * dst,
         threadgroup  half * shmem_f16 [[threadgroup(0)]],
         uint3   tgpig[[threadgroup_position_in_grid]],
         ushort  tiisg[[thread_index_in_simdgroup]],
         ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
 #define FWD_TMPL q_t, q4_t, q8x8_t, k_t, k4x4_t, k8x8_t, v_t, v4x4_t, v8x8_t, qk_t, qk8x8_t, s_t, s2_t, s8x8_t, o_t, o4_t, o8x8_t, kd4x4_t, nl_k, deq_k, vd4x4_t, nl_v, deq_v, DK, DV, Q, C
-#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg
+#define FWD_ARGS args, q, k, v, mask, sinks, pad, blk, dst, shmem_f16, tgpig, tiisg, sgitg
     switch (FC_flash_attn_ext_nsg) {
       // note: disabled cases to reduce library load time
       //case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
@@ -5210,9 +5310,9 @@ template<
     void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
     short DK,       // K head size
     short DV,       // V head size
-    short NE = 4,   // head elements per thread
-    short Q  = 1,   // queries per threadgroup
-    short C  = 32,  // cache items per threadgroup
+    short NE,       // head elements per thread
+    short Q,        // queries per threadgroup
+    short C,        // cache items per threadgroup
     short NSG>      // number of simd groups
 void kernel_flash_attn_ext_vec_impl(
         constant ggml_metal_kargs_flash_attn_ext_vec & args,
@@ -5327,8 +5427,8 @@ void kernel_flash_attn_ext_vec_impl(
 
         // loop over the KV cache
         // each simdgroup handles blocks of Q rows and C columns
-        for (int ic0 = (int) iwg*C*NSG; ic0 < args.ne11; ic0 += (int) NWG*C*NSG) {
-            int ic = ic0 + C*sgitg;
+        for (int ic0 = iwg*NSG + sgitg; ; ic0 += NWG*NSG) {
+            int ic = ic0*C;
             if (ic >= args.ne11) {
                 break;
             }
@@ -5621,7 +5721,7 @@ void kernel_flash_attn_ext_vec_impl(
         device float4 * dst4 = (device float4 *) dst;
         device float  * dst1 = (device float  *) dst + nrows*DV*NWG; // the S and M are stored after the results
 
-        const float S = NWG == 1 ? 1.0f/ss[0] : 1.0f;
+        const float S = NWG == 1 ? (ss[0] == 0.0f ? 0.0f : 1.0f/ss[0]) : 1.0f;
 
         // interleave the workgroup data
         for (short i = tiisg; i < DV4; i += NW) {
@@ -5659,8 +5759,8 @@ template<
     short DK,       // K head size
     short DV,       // V head size
     short NE = 4,   // head elements per thread
-    short Q  = 1,   // queries per threadgroup
-    short C  = 32>  // cache items per threadgroup
+    short Q  = OP_FLASH_ATTN_EXT_VEC_NQPTG,  // queries per threadgroup
+    short C  = OP_FLASH_ATTN_EXT_VEC_NCPSG>  // cache items per threadgroup
 kernel void kernel_flash_attn_ext_vec(
         constant ggml_metal_kargs_flash_attn_ext_vec & args,
         device const char * q,
@@ -5799,7 +5899,8 @@ kernel void kernel_flash_attn_ext_vec_reduce(
     const float m  = simd_max(M);
     const float ms = exp(M - m);
 
-    S = 1.0f/simd_sum(S*ms);
+    S = simd_sum(S*ms);
+    S = S == 0.0f ? 0.0f : 1.0f/S;
 
     const short DV4 = DV/4;