]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
metal : refactor + optimize (llama/15857)
authorGeorgi Gerganov <redacted>
Sat, 20 Sep 2025 10:04:02 +0000 (13:04 +0300)
committerGeorgi Gerganov <redacted>
Sat, 20 Sep 2025 10:33:50 +0000 (13:33 +0300)
src/ggml-metal/ggml-metal-impl.h
src/ggml-metal/ggml-metal.m
src/ggml-metal/ggml-metal.metal
tests/test-backend-ops.cpp

index b9d36394485009f48990a9775a76bd86fccf4e6f..651943fa923804677a9d093f7af962f500f07684 100644 (file)
@@ -20,8 +20,8 @@
 #define N_R0_Q5_1 4
 #define N_SG_Q5_1 2
 
-#define N_R0_Q8_0 4
-#define N_SG_Q8_0 2
+#define N_R0_Q8_0 2
+#define N_SG_Q8_0 4
 
 #define N_R0_MXFP4 2
 #define N_SG_MXFP4 2
 #define N_R0_IQ4_XS 2
 #define N_SG_IQ4_XS 2
 
+// function constants offsets
+#define FC_FLASH_ATTN_EXT              100
+#define FC_FLASH_ATTN_EXT_VEC          200
+#define FC_FLASH_ATTN_EXT_VEC_REDUCE   300
+
 // kernel argument structs
 //
 // - element counters (e.g. ne00) typically use int32_t to reduce register usage
@@ -236,9 +241,11 @@ typedef struct {
     int32_t  ne11;
     int32_t  ne_12_2; // assume K and V are same shape
     int32_t  ne_12_3;
+    int32_t  ns10;
     uint64_t nb11;
     uint64_t nb12;
     uint64_t nb13;
+    int32_t  ns20;
     uint64_t nb21;
     uint64_t nb22;
     uint64_t nb23;
@@ -258,10 +265,43 @@ typedef struct {
     float    logit_softcap;
 } ggml_metal_kargs_flash_attn_ext;
 
+typedef struct {
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne11;
+    int32_t  ne_12_2; // assume K and V are same shape
+    int32_t  ne_12_3;
+    int32_t  ns10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int32_t  ns20;
+    uint64_t nb21;
+    uint64_t nb22;
+    uint64_t nb23;
+    int32_t  ne32;
+    int32_t  ne33;
+    uint64_t nb31;
+    uint64_t nb32;
+    uint64_t nb33;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    float    scale;
+    float    max_bias;
+    float    m0;
+    float    m1;
+    int32_t  n_head_log2;
+    float    logit_softcap;
+} ggml_metal_kargs_flash_attn_ext_vec;
+
 typedef struct {
     int32_t  nrows;
-    int32_t  ne20;
-} ggml_metal_kargs_flash_attn_ext_reduce;
+} ggml_metal_kargs_flash_attn_ext_vec_reduce;
 
 typedef struct {
     int32_t  ne00;
index 578bdd6ecaa3a447df3c70ebbc0f3d85f1430df6..eeb6c9d4b70f2454899430c02aaa2fe1738b829a 100644 (file)
@@ -174,6 +174,19 @@ struct ggml_metal_kernel {
     id<MTLComputePipelineState> pipeline;
 };
 
+@interface ggml_metal_kernel_wrapper : NSObject
+
+@property (nonatomic, assign) struct ggml_metal_kernel kernel;
+
+@end
+
+@implementation ggml_metal_kernel_wrapper
+- (void) dealloc {
+    [_kernel.pipeline release];
+    [super dealloc];
+}
+@end
+
 enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_ADD,
     GGML_METAL_KERNEL_TYPE_ADD_FUSE_2,
@@ -454,126 +467,6 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
     GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
     GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H40,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H40,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H40,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H40,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H40,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H40,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H40,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE,
     GGML_METAL_KERNEL_TYPE_SET_I32,
     GGML_METAL_KERNEL_TYPE_SET_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
@@ -884,8 +777,12 @@ struct ggml_backend_metal_context {
 
     dispatch_queue_t d_queue;
 
+    // the set of pre-compiled kernels for this context
     struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT];
 
+    // additional, inference-time compiled kernels
+    NSMutableDictionary * kernels_ext;
+
     // capture state
     bool capture_next_compute;
     bool capture_started;
@@ -951,6 +848,8 @@ static void * ggml_metal_host_malloc(size_t n) {
 // - if not found, load the source and compile it
 // - if that fails, return NULL
 static id<MTLLibrary> ggml_metal_load_library(id<MTLDevice> device, bool use_bfloat) {
+    const int64_t t_start = ggml_time_us();
+
     id<MTLLibrary> metal_library = nil;
     NSError * error = nil;
     NSString * src = nil;
@@ -1074,6 +973,8 @@ static id<MTLLibrary> ggml_metal_load_library(id<MTLDevice> device, bool use_bfl
     [src release];
 #endif // GGML_METAL_EMBED_LIBRARY
 
+    GGML_LOG_INFO("%s: loaded in %.3f sec\n", __func__, (ggml_time_us() - t_start) / 1e6);
+
     return metal_library;
 }
 
@@ -1271,7 +1172,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,                   get_rows_q5_0,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,                   get_rows_q5_1,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,                   get_rows_q8_0,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4,                  get_rows_mxfp4,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4,                  get_rows_mxfp4,                  true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,                   get_rows_q2_K,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,                   get_rows_q3_K,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,                   get_rows_q4_K,                   true);
@@ -1489,126 +1390,6 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,             argsort_f32_i32_asc,             true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,            argsort_f32_i32_desc,            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,                  leaky_relu_f32,                  true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H40,          flash_attn_ext_f16_h40,          has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,          flash_attn_ext_f16_h64,          has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,          flash_attn_ext_f16_h80,          has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,          flash_attn_ext_f16_h96,          has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,         flash_attn_ext_f16_h112,         has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,         flash_attn_ext_f16_h128,         has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192,         flash_attn_ext_f16_h192,         has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128,  flash_attn_ext_f16_hk192_hv128,  has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,         flash_attn_ext_f16_h256,         has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512,  flash_attn_ext_f16_hk576_hv512,  has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H40,         flash_attn_ext_bf16_h40,         has_simdgroup_mm && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,         flash_attn_ext_bf16_h64,         has_simdgroup_mm && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,         flash_attn_ext_bf16_h80,         has_simdgroup_mm && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,         flash_attn_ext_bf16_h96,         has_simdgroup_mm && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112,        flash_attn_ext_bf16_h112,        has_simdgroup_mm && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128,        flash_attn_ext_bf16_h128,        has_simdgroup_mm && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192,        flash_attn_ext_bf16_h192,        has_simdgroup_mm && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, flash_attn_ext_bf16_hk192_hv128, has_simdgroup_mm && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,        flash_attn_ext_bf16_h256,        has_simdgroup_mm && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512, flash_attn_ext_bf16_hk576_hv512, has_simdgroup_mm && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H40,         flash_attn_ext_q4_0_h40,         has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,         flash_attn_ext_q4_0_h64,         has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,         flash_attn_ext_q4_0_h80,         has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,         flash_attn_ext_q4_0_h96,         has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112,        flash_attn_ext_q4_0_h112,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128,        flash_attn_ext_q4_0_h128,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192,        flash_attn_ext_q4_0_h192,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, flash_attn_ext_q4_0_hk192_hv128, has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,        flash_attn_ext_q4_0_h256,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512, flash_attn_ext_q4_0_hk576_hv512, has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H40,         flash_attn_ext_q4_1_h40,         has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,         flash_attn_ext_q4_1_h64,         has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,         flash_attn_ext_q4_1_h80,         has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,         flash_attn_ext_q4_1_h96,         has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112,        flash_attn_ext_q4_1_h112,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128,        flash_attn_ext_q4_1_h128,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192,        flash_attn_ext_q4_1_h192,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, flash_attn_ext_q4_1_hk192_hv128, has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,        flash_attn_ext_q4_1_h256,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512, flash_attn_ext_q4_1_hk576_hv512, has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H40,         flash_attn_ext_q5_0_h40,         has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,         flash_attn_ext_q5_0_h64,         has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,         flash_attn_ext_q5_0_h80,         has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,         flash_attn_ext_q5_0_h96,         has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112,        flash_attn_ext_q5_0_h112,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128,        flash_attn_ext_q5_0_h128,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192,        flash_attn_ext_q5_0_h192,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, flash_attn_ext_q5_0_hk192_hv128, has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,        flash_attn_ext_q5_0_h256,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512, flash_attn_ext_q5_0_hk576_hv512, has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H40,         flash_attn_ext_q5_1_h40,         has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,         flash_attn_ext_q5_1_h64,         has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,         flash_attn_ext_q5_1_h80,         has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,         flash_attn_ext_q5_1_h96,         has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112,        flash_attn_ext_q5_1_h112,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128,        flash_attn_ext_q5_1_h128,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192,        flash_attn_ext_q5_1_h192,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, flash_attn_ext_q5_1_hk192_hv128, has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,        flash_attn_ext_q5_1_h256,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512, flash_attn_ext_q5_1_hk576_hv512, has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H40,         flash_attn_ext_q8_0_h40,         has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,         flash_attn_ext_q8_0_h64,         has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,         flash_attn_ext_q8_0_h80,         has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,         flash_attn_ext_q8_0_h96,         has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112,        flash_attn_ext_q8_0_h112,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,        flash_attn_ext_q8_0_h128,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192,        flash_attn_ext_q8_0_h192,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,        flash_attn_ext_q8_0_h256,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64,      flash_attn_ext_vec_f16_h64,      has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64,     flash_attn_ext_vec_bf16_h64,     has_simdgroup_reduction && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64,     flash_attn_ext_vec_q4_0_h64,     has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64,     flash_attn_ext_vec_q4_1_h64,     has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64,     flash_attn_ext_vec_q5_0_h64,     has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64,     flash_attn_ext_vec_q5_1_h64,     has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64,     flash_attn_ext_vec_q8_0_h64,     has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96,      flash_attn_ext_vec_f16_h96,      has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96,     flash_attn_ext_vec_bf16_h96,     has_simdgroup_reduction && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96,     flash_attn_ext_vec_q4_0_h96,     has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96,     flash_attn_ext_vec_q4_1_h96,     has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96,     flash_attn_ext_vec_q5_0_h96,     has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96,     flash_attn_ext_vec_q5_1_h96,     has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96,     flash_attn_ext_vec_q8_0_h96,     has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,     flash_attn_ext_vec_f16_h128,     has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,    flash_attn_ext_vec_bf16_h128,    has_simdgroup_reduction && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,    flash_attn_ext_vec_q4_0_h128,    has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,    flash_attn_ext_vec_q4_1_h128,    has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,    flash_attn_ext_vec_q5_0_h128,    has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,    flash_attn_ext_vec_q5_1_h128,    has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,    flash_attn_ext_vec_q8_0_h128,    has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192,     flash_attn_ext_vec_f16_h192,     has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192,    flash_attn_ext_vec_bf16_h192,    has_simdgroup_reduction && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192,    flash_attn_ext_vec_q4_0_h192,    has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192,    flash_attn_ext_vec_q4_1_h192,    has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192,    flash_attn_ext_vec_q5_0_h192,    has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192,    flash_attn_ext_vec_q5_1_h192,    has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192,    flash_attn_ext_vec_q8_0_h192,    has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128,     flash_attn_ext_vec_f16_hk192_hv128,     has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128,    flash_attn_ext_vec_bf16_hk192_hv128,    has_simdgroup_reduction && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128,    flash_attn_ext_vec_q4_0_hk192_hv128,    has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128,    flash_attn_ext_vec_q4_1_hk192_hv128,    has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128,    flash_attn_ext_vec_q5_0_hk192_hv128,    has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128,    flash_attn_ext_vec_q5_1_hk192_hv128,    has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128,    flash_attn_ext_vec_q8_0_hk192_hv128,    has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,     flash_attn_ext_vec_f16_h256,     has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256,    flash_attn_ext_vec_bf16_h256,    has_simdgroup_reduction && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,    flash_attn_ext_vec_q4_0_h256,    has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,    flash_attn_ext_vec_q4_1_h256,    has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,    flash_attn_ext_vec_q5_0_h256,    has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,    flash_attn_ext_vec_q5_1_h256,    has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,    flash_attn_ext_vec_q8_0_h256,    has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512,     flash_attn_ext_vec_f16_hk576_hv512,     has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512,    flash_attn_ext_vec_bf16_hk576_hv512,    has_simdgroup_reduction && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512,    flash_attn_ext_vec_q4_0_hk576_hv512,    has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512,    flash_attn_ext_vec_q4_1_hk576_hv512,    has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512,    flash_attn_ext_vec_q5_0_hk576_hv512,    has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512,    flash_attn_ext_vec_q5_1_hk576_hv512,    has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,    flash_attn_ext_vec_q8_0_hk576_hv512,    has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE,           flash_attn_ext_reduce,           has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32,                         set_f32,                         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32,                         set_i32,                         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32,                     cpy_f32_f32,                     true);
@@ -1655,9 +1436,219 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,                 pool_2d_max_f32,                 true);
     }
 
+    ctx->kernels_ext = [[NSMutableDictionary alloc] init];
+
     return ctx;
 }
 
+static id<MTLComputePipelineState> ggml_metal_get_kernel(struct ggml_backend_metal_context * ctx, const char * name) {
+    NSString * key = [NSString stringWithUTF8String:name];
+
+    ggml_metal_kernel_wrapper * obj = [ctx->kernels_ext objectForKey:key];
+    if (obj) {
+        return obj.kernel.pipeline;
+    }
+
+    return nil;
+}
+
+static id<MTLComputePipelineState> ggml_metal_compile_kernel(ggml_backend_t backend, const char * base, const char * name, MTLFunctionConstantValues * cv) {
+    struct ggml_backend_metal_context        * ctx     = backend->context;
+    struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
+
+    id<MTLComputePipelineState> res = nil;
+
+    @autoreleasepool {
+        NSError * error = nil;
+
+        NSString * base_func = [NSString stringWithUTF8String:base];
+
+        GGML_LOG_DEBUG("%s: compiling kernel: base = '%s', name = '%s'\n", __func__, base, name);
+
+        // TODO: make sure it is thread-safe to compile kernels in parallel
+        id<MTLFunction> metal_function = [ctx_dev->mtl_library newFunctionWithName:base_func constantValues:cv error:&error];
+        if (!metal_function) {
+            GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
+
+            return nil;
+        }
+
+        struct ggml_metal_kernel kernel = {
+            /*.pipeline =*/ [ctx_dev->mtl_device newComputePipelineStateWithFunction:metal_function error:&error],
+        };
+
+        ggml_metal_kernel_wrapper * obj = [[ggml_metal_kernel_wrapper alloc] init];
+        obj.kernel = kernel;
+
+        res = obj.kernel.pipeline;
+
+        NSString * key = [NSString stringWithUTF8String:name];
+        [ctx->kernels_ext setObject:obj forKey:key];
+
+        GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name, (void *) kernel.pipeline,
+                (int) kernel.pipeline.maxTotalThreadsPerThreadgroup,
+                (int) kernel.pipeline.threadExecutionWidth);
+    }
+
+    return res;
+}
+
+static id<MTLComputePipelineState> ggml_metal_get_pipeline_flash_attn_ext(
+        ggml_backend_t backend, struct ggml_tensor * op,
+        bool    has_mask,
+        bool    has_sinks,
+        bool    has_bias,
+        bool    has_scap,
+        int32_t nsg) {
+    struct ggml_backend_metal_context * ctx = backend->context;
+
+    char base[256];
+    char name[256];
+
+    @autoreleasepool {
+        MTLFunctionConstantValues * cv = [[MTLFunctionConstantValues alloc] init];
+
+        const int32_t dk = (int32_t) op->src[1]->ne[0];
+        const int32_t dv = (int32_t) op->src[2]->ne[0];
+
+        const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0];
+        const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0];
+
+        snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d",
+                "flash_attn_ext",
+                ggml_type_name(op->src[1]->type),
+                dk,
+                dv);
+
+        snprintf(name, 256, "kernel_%s_%s_dk%d_dv%d_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
+                "flash_attn_ext",
+                ggml_type_name(op->src[1]->type),
+                dk,
+                dv,
+                has_mask,
+                has_sinks,
+                has_bias,
+                has_scap,
+                ns10,
+                ns20,
+                nsg);
+
+        id<MTLComputePipelineState> res = ggml_metal_get_kernel(ctx, name);
+        if (res) {
+            // kernel found
+            return res;
+        }
+
+        cv = [[MTLFunctionConstantValues alloc] init];
+
+        [cv setConstantValue:&has_mask  type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT + 0];
+        [cv setConstantValue:&has_sinks type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT + 1];
+        [cv setConstantValue:&has_bias  type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT + 2];
+        [cv setConstantValue:&has_scap  type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT + 3];
+
+        [cv setConstantValue:&ns10 type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT + 20];
+        [cv setConstantValue:&ns20 type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT + 21];
+        [cv setConstantValue:&nsg  type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT + 22];
+
+        return ggml_metal_compile_kernel(backend, base, name, cv);
+    }
+}
+
+static id<MTLComputePipelineState> ggml_metal_get_pipeline_flash_attn_ext_vec(
+        ggml_backend_t backend, struct ggml_tensor * op,
+        bool    has_mask,
+        bool    has_sinks,
+        bool    has_bias,
+        bool    has_scap,
+        int32_t nsg,
+        int32_t nwg) {
+    struct ggml_backend_metal_context * ctx = backend->context;
+
+    char base[256];
+    char name[256];
+
+    @autoreleasepool {
+        MTLFunctionConstantValues * cv = [[MTLFunctionConstantValues alloc] init];
+
+        const int32_t dk = (int32_t) op->src[1]->ne[0];
+        const int32_t dv = (int32_t) op->src[2]->ne[0];
+
+        const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0];
+        const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0];
+
+        snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d",
+                "flash_attn_ext_vec",
+                ggml_type_name(op->src[1]->type),
+                dk,
+                dv);
+
+        snprintf(name, 256, "kernel_%s_%s_dk%d_dv%d_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
+                "flash_attn_ext_vec",
+                ggml_type_name(op->src[1]->type),
+                dk,
+                dv,
+                has_mask,
+                has_sinks,
+                has_bias,
+                has_scap,
+                ns10,
+                ns20,
+                nsg, nwg);
+
+        id<MTLComputePipelineState> res = ggml_metal_get_kernel(ctx, name);
+        if (res) {
+            // kernel found
+            return res;
+        }
+
+        cv = [[MTLFunctionConstantValues alloc] init];
+
+        [cv setConstantValue:&has_mask  type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT_VEC + 0];
+        [cv setConstantValue:&has_sinks type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT_VEC + 1];
+        [cv setConstantValue:&has_bias  type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT_VEC + 2];
+        [cv setConstantValue:&has_scap  type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT_VEC + 3];
+
+        [cv setConstantValue:&ns10 type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC + 20];
+        [cv setConstantValue:&ns20 type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC + 21];
+        [cv setConstantValue:&nsg  type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC + 22];
+        [cv setConstantValue:&nwg  type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC + 23];
+
+        return ggml_metal_compile_kernel(backend, base, name, cv);
+    }
+}
+
+static id<MTLComputePipelineState> ggml_metal_get_pipeline_flash_attn_ext_vec_reduce(
+        ggml_backend_t backend, struct ggml_tensor * op,
+        int32_t dv,
+        int32_t nwg) {
+    struct ggml_backend_metal_context * ctx = backend->context;
+
+    char base[256];
+    char name[256];
+
+    @autoreleasepool {
+        MTLFunctionConstantValues * cv = [[MTLFunctionConstantValues alloc] init];
+
+        snprintf(base, 256, "kernel_flash_attn_ext_vec_reduce");
+        snprintf(name, 256, "kernel_flash_attn_ext_vec_reduce_dv=%d_nwg=%d", dv, nwg);
+
+        id<MTLComputePipelineState> res = ggml_metal_get_kernel(ctx, name);
+        if (res) {
+            // kernel found
+            return res;
+        }
+
+        cv = [[MTLFunctionConstantValues alloc] init];
+
+        [cv setConstantValue:&dv  type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC_REDUCE + 0];
+        [cv setConstantValue:&nwg type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC_REDUCE + 1];
+
+        return ggml_metal_compile_kernel(backend, base, name, cv);
+    }
+
+    GGML_UNUSED(op);
+}
+
 static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
     GGML_LOG_INFO("%s: deallocating\n", __func__);
 
@@ -1665,6 +1656,11 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
         [ctx->kernels[i].pipeline release];
     }
 
+    if (ctx->kernels_ext) {
+        [ctx->kernels_ext release];
+        ctx->kernels_ext = nil;
+    }
+
     Block_release(ctx->encode_async);
 
     [ctx->queue release];
@@ -3772,6 +3768,7 @@ static int ggml_metal_encode_node(
                             {
                                 nsg = N_SG_Q8_0;
                                 nr0 = N_R0_Q8_0;
+                                smem = 32*sizeof(float)*N_R0_Q8_0;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
                             } break;
                         case GGML_TYPE_MXFP4:
@@ -3908,7 +3905,12 @@ static int ggml_metal_encode_node(
                     if (smem > 0) {
                         [encoder setThreadgroupMemoryLength:smem atIndex:0];
                     }
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (ne11 + nr1 - 1)/nr1, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+
+                    if (src0t == GGML_TYPE_Q8_0) {
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0 - 1)/(nr0), (ne11 + nr1 - 1)/nr1, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+                    } else {
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (ne11 + nr1 - 1)/nr1, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+                    }
                 }
             } break;
         case GGML_OP_MUL_MAT_ID:
@@ -4129,6 +4131,7 @@ static int ggml_metal_encode_node(
                             {
                                 nsg = N_SG_Q8_0;
                                 nr0 = N_R0_Q8_0;
+                                smem = 32*sizeof(float)*N_R0_Q8_0;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
                             } break;
                         case GGML_TYPE_MXFP4:
@@ -4274,7 +4277,12 @@ static int ggml_metal_encode_node(
                     if (smem > 0) {
                         [encoder setThreadgroupMemoryLength:smem atIndex:0];
                     }
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+
+                    if (src0t == GGML_TYPE_Q8_0) {
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0 - 1)/(nr0), (_ne1 + nr1 - 1)/nr1, ne123) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+                    } else {
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+                    }
                 }
             } break;
         case GGML_OP_GET_ROWS:
@@ -5125,6 +5133,7 @@ static int ggml_metal_encode_node(
                 float scale;
                 float max_bias;
                 float logit_softcap;
+
                 memcpy(&scale,         ((const int32_t *) dst->op_params) + 0, sizeof(scale));
                 memcpy(&max_bias,      ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));
                 memcpy(&logit_softcap, ((const int32_t *) dst->op_params) + 2, sizeof(logit_softcap));
@@ -5133,398 +5142,24 @@ static int ggml_metal_encode_node(
                     scale /= logit_softcap;
                 }
 
+                const bool has_mask  = src3 != NULL;
+                const bool has_sinks = src4 != NULL;
+                const bool has_bias  = max_bias != 0.0f;
+                const bool has_scap  = logit_softcap != 0.0f;
+
                 const uint32_t n_head      = src0->ne[2];
                 const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
 
                 const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
                 const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-                id<MTLComputePipelineState> pipeline = nil;
-
-                bool use_vec_kernel = false;
+                GGML_ASSERT(ne01 < 65536);
 
                 // use non-vec kernel if the batch size is large or if the vec-kernel is not supported for this head size
-                if (ne01 >= 20 || (ne00 == 40 || ne00 == 80 || ne00 == 112)) {
-                    switch (src1->type) {
-                        case GGML_TYPE_F16:
-                            {
-                                if (ne00 == 192 && ne20 == 128) {
-                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128].pipeline;
-                                } else if (ne00 == 576 && ne20 == 512) {
-                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512].pipeline;
-                                } else {
-                                    switch (ne00) {
-                                        case 40:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H40 ].pipeline; break;
-                                        case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
-                                        case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
-                                        case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
-                                        case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
-                                        case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
-                                        case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192].pipeline; break;
-                                        case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
-                                        default:
-                                                  {
-                                                      GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
-                                                      GGML_LOG_ERROR("add template specialization for this size\n");
-                                                      GGML_ABORT("add template specialization for this size");
-                                                  }
-                                    }
-                                }
-                            } break;
-                        case GGML_TYPE_BF16:
-                            {
-                                if (ne00 == 192 && ne20 == 128) {
-                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128].pipeline;
-                                } else if (ne00 == 576 && ne20 == 512) {
-                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512].pipeline;
-                                } else {
-                                    switch (ne00) {
-                                        case 40:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H40 ].pipeline; break;
-                                        case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
-                                        case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break;
-                                        case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break;
-                                        case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break;
-                                        case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break;
-                                        case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192].pipeline; break;
-                                        case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break;
-                                        default:
-                                                  {
-                                                      GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
-                                                      GGML_LOG_ERROR("add template specialization for this size\n");
-                                                      GGML_ABORT("add template specialization for this size");
-                                                  }
-                                    }
-                                }
-                            } break;
-                        case GGML_TYPE_Q4_0:
-                            {
-                                if (ne00 == 192 && ne20 == 128) {
-                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128].pipeline;
-                                } else if (ne00 == 576 && ne20 == 512) {
-                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512].pipeline;
-                                } else {
-                                    switch (ne00) {
-                                        case 40:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H40 ].pipeline; break;
-                                        case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
-                                        case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break;
-                                        case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break;
-                                        case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112].pipeline; break;
-                                        case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128].pipeline; break;
-                                        case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192].pipeline; break;
-                                        case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256].pipeline; break;
-                                        default:
-                                                  {
-                                                      GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
-                                                      GGML_LOG_ERROR("add template specialization for this size\n");
-                                                      GGML_ABORT("add template specialization for this size");
-                                                  }
-                                    }
-                                }
-                            } break;
-                        case GGML_TYPE_Q4_1:
-                            {
-                                if (ne00 == 192 && ne20 == 128) {
-                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128].pipeline;
-                                } else if (ne00 == 576 && ne20 == 512) {
-                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512].pipeline;
-                                } else {
-                                    switch (ne00) {
-                                        case 40:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H40 ].pipeline; break;
-                                        case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
-                                        case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break;
-                                        case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break;
-                                        case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112].pipeline; break;
-                                        case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128].pipeline; break;
-                                        case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192].pipeline; break;
-                                        case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256].pipeline; break;
-                                        default:
-                                                  {
-                                                      GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
-                                                      GGML_LOG_ERROR("add template specialization for this size\n");
-                                                      GGML_ABORT("add template specialization for this size");
-                                                  }
-                                    }
-                                }
-                            } break;
-                        case GGML_TYPE_Q5_0:
-                            {
-                                if (ne00 == 192 && ne20 == 128) {
-                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128].pipeline;
-                                } else if (ne00 == 576 && ne20 == 512) {
-                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512].pipeline;
-                                } else {
-                                    switch (ne00) {
-                                        case 40:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H40 ].pipeline; break;
-                                        case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
-                                        case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break;
-                                        case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break;
-                                        case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112].pipeline; break;
-                                        case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128].pipeline; break;
-                                        case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192].pipeline; break;
-                                        case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256].pipeline; break;
-                                        default:
-                                                  {
-                                                      GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
-                                                      GGML_LOG_ERROR("add template specialization for this size\n");
-                                                      GGML_ABORT("add template specialization for this size");
-                                                  }
-                                    }
-                                }
-                            } break;
-                        case GGML_TYPE_Q5_1:
-                            {
-                                if (ne00 == 192 && ne20 == 128) {
-                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128].pipeline;
-                                } else if (ne00 == 576 && ne20 == 512) {
-                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512].pipeline;
-                                } else {
-                                    switch (ne00) {
-                                        case 40:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H40 ].pipeline; break;
-                                        case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
-                                        case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break;
-                                        case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break;
-                                        case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112].pipeline; break;
-                                        case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128].pipeline; break;
-                                        case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192].pipeline; break;
-                                        case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256].pipeline; break;
-                                        default:
-                                                  {
-                                                      GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
-                                                      GGML_LOG_ERROR("add template specialization for this size\n");
-                                                      GGML_ABORT("add template specialization for this size");
-                                                  }
-                                    }
-                                }
-                            } break;
-                        case GGML_TYPE_Q8_0:
-                            {
-                                if (ne00 == 192 && ne20 == 128) {
-                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128].pipeline;
-                                } else if (ne00 == 576 && ne20 == 512) {
-                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512].pipeline;
-                                } else {
-                                    switch (ne00) {
-                                        case 40:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H40 ].pipeline; break;
-                                        case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
-                                        case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break;
-                                        case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break;
-                                        case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break;
-                                        case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break;
-                                        case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192].pipeline; break;
-                                        case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break;
-                                        default:
-                                                  {
-                                                      GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
-                                                      GGML_LOG_ERROR("add template specialization for this size\n");
-                                                      GGML_ABORT("add template specialization for this size");
-                                                  }
-                                    }
-                                }
-                            } break;
-                        default:
-                            {
-                                GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
-                                GGML_LOG_ERROR("add template specialization for this type\n");
-                                GGML_ABORT("add template specialization for this type");
-                            }
-                    }
-                } else {
-                    use_vec_kernel = true;
-
-                    switch (ne00) {
-                        case 64:
-                            {
-                                switch (src1->type) {
-                                    case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64].pipeline; break;
-                                    case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64].pipeline; break;
-                                    case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64].pipeline; break;
-                                    case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64].pipeline; break;
-                                    case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64].pipeline; break;
-                                    case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64].pipeline; break;
-                                    case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64].pipeline; break;
-                                    default:
-                                        {
-                                            GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
-                                            GGML_LOG_ERROR("add template specialization for this type\n");
-                                            GGML_ABORT("add template specialization for this type");
-                                        }
-                                }
-                            } break;
-                        case 96:
-                            {
-                                switch (src1->type) {
-                                    case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96].pipeline; break;
-                                    case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96].pipeline; break;
-                                    case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96].pipeline; break;
-                                    case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96].pipeline; break;
-                                    case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96].pipeline; break;
-                                    case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96].pipeline; break;
-                                    case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96].pipeline; break;
-                                    default:
-                                        {
-                                            GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
-                                            GGML_LOG_ERROR("add template specialization for this type\n");
-                                            GGML_ABORT("add template specialization for this type");
-                                        }
-                                }
-                            } break;
-                        case 128:
-                            {
-                                switch (src1->type) {
-                                    case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
-                                    case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128].pipeline; break;
-                                    case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break;
-                                    case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break;
-                                    case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break;
-                                    case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128].pipeline; break;
-                                    case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128].pipeline; break;
-                                    default:
-                                        {
-                                            GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
-                                            GGML_LOG_ERROR("add template specialization for this type\n");
-                                            GGML_ABORT("add template specialization for this type");
-                                        }
-                                }
-                            } break;
-                        case 192:
-                            {
-                                if (ne20 == 128) {
-                                    switch (src1->type) {
-                                        case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128].pipeline; break;
-                                        case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128].pipeline; break;
-                                        case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128].pipeline; break;
-                                        case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128].pipeline; break;
-                                        case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128].pipeline; break;
-                                        case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128].pipeline; break;
-                                        case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128].pipeline; break;
-                                        default:
-                                            {
-                                                GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
-                                                GGML_LOG_ERROR("add template specialization for this type\n");
-                                                GGML_ABORT("add template specialization for this type");
-                                            }
-                                    }
-                                } else {
-                                    switch (src1->type) {
-                                        case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192].pipeline; break;
-                                        case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192].pipeline; break;
-                                        case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192].pipeline; break;
-                                        case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192].pipeline; break;
-                                        case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192].pipeline; break;
-                                        case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192].pipeline; break;
-                                        case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192].pipeline; break;
-                                        default:
-                                            {
-                                                GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
-                                                GGML_LOG_ERROR("add template specialization for this type\n");
-                                                GGML_ABORT("add template specialization for this type");
-                                            }
-                                    }
-                                }
-                            } break;
-                        case 256:
-                            {
-                                switch (src1->type) {
-                                    case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
-                                    case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256].pipeline; break;
-                                    case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break;
-                                    case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break;
-                                    case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break;
-                                    case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256].pipeline; break;
-                                    case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256].pipeline; break;
-                                    default:
-                                        {
-                                            GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
-                                            GGML_LOG_ERROR("add template specialization for this type\n");
-                                            GGML_ABORT("add template specialization for this type");
-                                        }
-                                }
-                            } break;
-                        case 576:
-                            {
-                                if (ne20 == 512) {
-                                    switch (src1->type) {
-                                        case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512].pipeline; break;
-                                        case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512].pipeline; break;
-                                        case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512].pipeline; break;
-                                        case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512].pipeline; break;
-                                        case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512].pipeline; break;
-                                        case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512].pipeline; break;
-                                        case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512].pipeline; break;
-                                        default:
-                                            {
-                                                GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
-                                                GGML_LOG_ERROR("add template specialization for this type\n");
-                                                GGML_ABORT("add template specialization for this type");
-                                            }
-                                    }
-                                } else {
-                                    GGML_LOG_ERROR("unsupported size: %lld\n", ne20);
-                                    GGML_LOG_ERROR("add template specialization for this size\n");
-                                    GGML_ABORT("add template specialization for this size");
-                                }
-                            } break;
-                        default:
-                            {
-                                GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
-                                GGML_LOG_ERROR("add template specialization for this size\n");
-                                GGML_ABORT("add template specialization for this size");
-                            }
-                    }
-                }
-
-                ggml_metal_kargs_flash_attn_ext args = {
-                    /*.ne01          =*/ ne01,
-                    /*.ne02          =*/ ne02,
-                    /*.ne03          =*/ ne03,
-                    /*.nb01          =*/ nb01,
-                    /*.nb02          =*/ nb02,
-                    /*.nb03          =*/ nb03,
-                    /*.ne11          =*/ ne11,
-                    /*.ne_12_2       =*/ ne12,
-                    /*.ne_12_3       =*/ ne13,
-                    /*.nb11          =*/ nb11,
-                    /*.nb12          =*/ nb12,
-                    /*.nb13          =*/ nb13,
-                    /*.nb21          =*/ nb21,
-                    /*.nb22          =*/ nb22,
-                    /*.nb23          =*/ nb23,
-                    /*.ne32          =*/ ne32,
-                    /*.ne33          =*/ ne33,
-                    /*.nb31          =*/ nb31,
-                    /*.nb32          =*/ nb32,
-                    /*.nb33          =*/ nb33,
-                    /*.ne1           =*/ ne1,
-                    /*.ne2           =*/ ne2,
-                    /*.ne3           =*/ ne3,
-                    /*.scale         =*/ scale,
-                    /*.max_bias      =*/ max_bias,
-                    /*.m0            =*/ m0,
-                    /*.m1            =*/ m1,
-                    /*.n_head_log2   =*/ n_head_log2,
-                    /*.logit_softcap =*/ logit_softcap,
-                };
-
-                [encoder setComputePipelineState:pipeline];
-                [encoder setBytes:&args length:sizeof(args)     atIndex:0];
-                [encoder setBuffer:id_src0 offset:offs_src0     atIndex:1];
-                [encoder setBuffer:id_src1 offset:offs_src1     atIndex:2];
-                [encoder setBuffer:id_src2 offset:offs_src2     atIndex:3];
-                if (id_src3) {
-                    [encoder setBuffer:id_src3 offset:offs_src3 atIndex:4];
-                } else {
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4];
-                }
-                if (id_src4) {
-                    [encoder setBuffer:id_src4 offset:offs_src4 atIndex:5];
-                } else {
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:5];
-                }
-
-                if (!use_vec_kernel) {
+                if (ne01 >= 20 || (ne00 % 32 != 0)) {
                     // half8x8 kernel
                     const int64_t nqptg = 8;  // queries per threadgroup    !! sync with kernel template arguments !!
-                    const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
+                    const int64_t ncpsg = 64; // cache values per simdgroup !! sync with kernel template arguments !!
 
                     GGML_ASSERT(nqptg <= 32);
                     GGML_ASSERT(nqptg  % 8  == 0);
@@ -5532,34 +5167,90 @@ static int ggml_metal_encode_node(
 
                     const int is_q = ggml_is_quantized(src1->type) ? 1 : 0;
 
-                    // 2*(2*ncpsg + nqptg)*(nsg)
-                    // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
+                    // 2*(2*ncpsg)
+                    // ncpsg soft_max values + ncpsg mask values
                     //
                     // 16*32*(nsg)
                     // the shared memory needed for the simdgroups to load the KV cache
                     // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
                     //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(2*ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16))
-
-                    int64_t nsgmax = 2;
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*GGML_PAD(ne20, 64) + 2*(2*ncpsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16))
 
-                    while (true) {
-                        const size_t smem = FATTN_SMEM(nsgmax);
-                        if (smem > device.maxThreadgroupMemoryLength/2) {
-                            break;
-                        }
-                        nsgmax *= 2;
-                    }
-                    nsgmax /= 2;
+                    //int64_t nsgmax = 4;
+                    //
+                    //if (is_q) {
+                    //    nsgmax = 2;
+                    //    while (true) {
+                    //        const size_t smem = FATTN_SMEM(nsgmax);
+                    //        if (smem > device.maxThreadgroupMemoryLength/2) {
+                    //            break;
+                    //        }
+                    //        nsgmax *= 2;
+                    //    }
+                    //    nsgmax /= 2;
+                    //}
 
                     // simdgroups per threadgroup (a.k.a. warps)
-                    const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
+                    //nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
+                    int32_t nsg = 4;
 
                     const size_t smem = FATTN_SMEM(nsg);
 
+                    ggml_metal_kargs_flash_attn_ext args = {
+                        /*.ne01          =*/ ne01,
+                        /*.ne02          =*/ ne02,
+                        /*.ne03          =*/ ne03,
+                        /*.nb01          =*/ nb01,
+                        /*.nb02          =*/ nb02,
+                        /*.nb03          =*/ nb03,
+                        /*.ne11          =*/ ne11,
+                        /*.ne_12_2       =*/ ne12,
+                        /*.ne_12_3       =*/ ne13,
+                        /*.ns10          =*/ nb11/nb10,
+                        /*.nb11          =*/ nb11,
+                        /*.nb12          =*/ nb12,
+                        /*.nb13          =*/ nb13,
+                        /*.ns20          =*/ nb21/nb20,
+                        /*.nb21          =*/ nb21,
+                        /*.nb22          =*/ nb22,
+                        /*.nb23          =*/ nb23,
+                        /*.ne32          =*/ ne32,
+                        /*.ne33          =*/ ne33,
+                        /*.nb31          =*/ nb31,
+                        /*.nb32          =*/ nb32,
+                        /*.nb33          =*/ nb33,
+                        /*.ne1           =*/ ne1,
+                        /*.ne2           =*/ ne2,
+                        /*.ne3           =*/ ne3,
+                        /*.scale         =*/ scale,
+                        /*.max_bias      =*/ max_bias,
+                        /*.m0            =*/ m0,
+                        /*.m1            =*/ m1,
+                        /*.n_head_log2   =*/ n_head_log2,
+                        /*.logit_softcap =*/ logit_softcap,
+                    };
+
+                    id<MTLComputePipelineState> pipeline = ggml_metal_get_pipeline_flash_attn_ext(backend, node, has_mask, has_sinks, has_bias, has_scap, nsg);
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBytes:&args length:sizeof(args)     atIndex:0];
+                    [encoder setBuffer:id_src0 offset:offs_src0     atIndex:1];
+                    [encoder setBuffer:id_src1 offset:offs_src1     atIndex:2];
+                    [encoder setBuffer:id_src2 offset:offs_src2     atIndex:3];
+                    if (id_src3) {
+                        [encoder setBuffer:id_src3 offset:offs_src3 atIndex:4];
+                    } else {
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4];
+                    }
+                    if (id_src4) {
+                        [encoder setBuffer:id_src4 offset:offs_src4 atIndex:5];
+                    } else {
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:5];
+                    }
+
                     [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
 
-                    //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
+                    //printf("smem: %zu, max: %zu, nsg = %d, ne02 = %d, ne12 = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, ne02, ne12);
                     GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
                     [encoder setThreadgroupMemoryLength:smem atIndex:0];
                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
@@ -5568,7 +5259,7 @@ static int ggml_metal_encode_node(
                     // half4x4 kernel
                     const int64_t nqptg = 1;  // queries per threadgroup    !! sync with kernel template arguments !!
                     const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
-                    const int64_t nkpsg = 1*ncpsg; // TODO: make adjustable
+                    const int64_t nkpsg = 1*ncpsg;
 
                     GGML_ASSERT(nqptg <= 32);
                     GGML_ASSERT(nqptg  % 1  == 0);
@@ -5581,8 +5272,7 @@ static int ggml_metal_encode_node(
                     // ne20*(nsg)
                     // each simdgroup has a full f32 head vector in shared mem to accumulate results
                     //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16))
-//#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*GGML_PAD(ne20, 128)*(nsg))*(sizeof(float)/2), 16))
 
                     int64_t nsgmax = 2;
                     while (true) {
@@ -5596,7 +5286,8 @@ static int ggml_metal_encode_node(
                     nsgmax /= 2;
 
                     // simdgroups per threadgroup (a.k.a. warps)
-                    const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
+                    //const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
+                    const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) 1024/32)));
 
                     int64_t nsg = 1;
                     while (nsg <= nsgt) {
@@ -5606,28 +5297,86 @@ static int ggml_metal_encode_node(
 
                     // workgroups
                     // each workgroup handles nsg*nkpsg cache values
-                    uint16_t nwg = 1;
-                    if (4*nsg*nkpsg >= ne11) {
-                        const size_t smem = FATTN_SMEM(nsg);
+                    int32_t nwg = 1;
+                    if (false) {
+                        // for small KV caches, we could launch a single workgroup and write the results directly to dst/
+                        // however, this does not lead to significant improvement, so disabled
+                        nwg = 1;
+                        nsg = 4;
+                    } else {
+                        nwg = 32;
+                        nsg = 1;
+                        while (2*nwg*nsg*nkpsg < ne11 && nsg < 4) {
+                            nsg *= 2;
+                        }
+                    }
 
-                        //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax);
-                        GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
+                    ggml_metal_kargs_flash_attn_ext_vec args = {
+                        /*.ne01          =*/ ne01,
+                        /*.ne02          =*/ ne02,
+                        /*.ne03          =*/ ne03,
+                        /*.nb01          =*/ nb01,
+                        /*.nb02          =*/ nb02,
+                        /*.nb03          =*/ nb03,
+                        /*.ne11          =*/ ne11,
+                        /*.ne_12_2       =*/ ne12,
+                        /*.ne_12_3       =*/ ne13,
+                        /*.ns10          =*/ nb11/nb10,
+                        /*.nb11          =*/ nb11,
+                        /*.nb12          =*/ nb12,
+                        /*.nb13          =*/ nb13,
+                        /*.ns20          =*/ nb21/nb20,
+                        /*.nb21          =*/ nb21,
+                        /*.nb22          =*/ nb22,
+                        /*.nb23          =*/ nb23,
+                        /*.ne32          =*/ ne32,
+                        /*.ne33          =*/ ne33,
+                        /*.nb31          =*/ nb31,
+                        /*.nb32          =*/ nb32,
+                        /*.nb33          =*/ nb33,
+                        /*.ne1           =*/ ne1,
+                        /*.ne2           =*/ ne2,
+                        /*.ne3           =*/ ne3,
+                        /*.scale         =*/ scale,
+                        /*.max_bias      =*/ max_bias,
+                        /*.m0            =*/ m0,
+                        /*.m1            =*/ m1,
+                        /*.n_head_log2   =*/ n_head_log2,
+                        /*.logit_softcap =*/ logit_softcap,
+                    };
 
-                        // using 1 workgroup -> write the result directly into dst
-                        [encoder setBuffer:id_dst offset:offs_dst      atIndex:6];
-                        [encoder setBytes:&nwg length:sizeof(uint16_t) atIndex:7];
+                    id<MTLComputePipelineState> pipeline = ggml_metal_get_pipeline_flash_attn_ext_vec(backend, node, has_mask, has_sinks, has_bias, has_scap, nsg, nwg);
 
-                        [encoder setThreadgroupMemoryLength:smem atIndex:0];
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+                    GGML_ASSERT(nsg*32 <= (int) pipeline.maxTotalThreadsPerThreadgroup);
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBytes:&args length:sizeof(args)     atIndex:0];
+                    [encoder setBuffer:id_src0 offset:offs_src0     atIndex:1];
+                    [encoder setBuffer:id_src1 offset:offs_src1     atIndex:2];
+                    [encoder setBuffer:id_src2 offset:offs_src2     atIndex:3];
+                    if (id_src3) {
+                        [encoder setBuffer:id_src3 offset:offs_src3 atIndex:4];
                     } else {
-                        nwg = 32;
-                        nsg = MIN(4, nsg);
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4];
+                    }
+                    if (id_src4) {
+                        [encoder setBuffer:id_src4 offset:offs_src4 atIndex:5];
+                    } else {
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:5];
+                    }
 
-                        const size_t smem = FATTN_SMEM(nsg);
+                    const size_t smem = FATTN_SMEM(nsg);
 
-                        //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax);
-                        GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
+                    //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax);
+                    GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
+
+                    if (nwg == 1) {
+                        // using 1 workgroup -> write the result directly into dst
+                        [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
 
+                        [encoder setThreadgroupMemoryLength:smem atIndex:0];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+                    } else {
                         // sanity checks
                         GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
                         GGML_ASSERT(ne1*ne2*ne3 <= (1u << 31));
@@ -5647,20 +5396,18 @@ static int ggml_metal_encode_node(
                         //printf("ne01 = %d, ne02 = %d, ne03 = %d, ne20 = %d\n", ne01, ne02, ne03, ne20);
                         //printf("needed memory: %.3f MiB\n", (float) (ne01*ne02*ne03*ne20*sizeof(float))/1024.0f/1024.0f);
 
-                        [encoder setBuffer:h_tmp  offset:0             atIndex:6];
-                        [encoder setBytes:&nwg length:sizeof(uint16_t) atIndex:7];
+                        [encoder setBuffer:h_tmp offset:0 atIndex:6];
 
                         [encoder setThreadgroupMemoryLength:smem atIndex:0];
                         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
 
                         // reduce the results from the workgroups
                         {
-                            ggml_metal_kargs_flash_attn_ext_reduce args0 = {
+                            ggml_metal_kargs_flash_attn_ext_vec_reduce args0 = {
                                 nrows,
-                                ne20,
                             };
 
-                            id<MTLComputePipelineState> pipeline0 = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE].pipeline;
+                            id<MTLComputePipelineState> pipeline0 = ggml_metal_get_pipeline_flash_attn_ext_vec_reduce(backend, node, ne20, nwg);
 
                             [encoder setComputePipelineState:pipeline0];
                             [encoder setBytes:&args0   length:sizeof(args0) atIndex:0];
@@ -5668,7 +5415,7 @@ static int ggml_metal_encode_node(
                             [encoder setBuffer:id_dst  offset:offs_dst      atIndex:2];
 
                             //printf("ne1 = %d, ne2 = %d, ne3 = %d, ne20 = %d\n", ne1, ne2, ne3, ne20);
-                            [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(32*32, 1, 1)];
+                            [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(32*nwg, 1, 1)];
                         }
                     }
 #undef FATTN_SMEM
index 4dc762bf1af4dc31c138c596d0259efbadf938c6..77be3c5c9d8be69625ac6d1d02fb35cf9da73b88 100644 (file)
@@ -15,6 +15,10 @@ using namespace metal;
 #define MIN(x, y) ((x) < (y) ? (x) : (y))
 #define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
 
+#define PAD2(x, n) (((x) + (n) - 1) & ~((n) - 1))
+
+#define FOR_UNROLL(x) _Pragma("clang loop unroll(full)") for (x)
+
 #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
 
 // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
@@ -2755,7 +2759,47 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
     return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
 }
 
-template<typename block_q_type, int nr0, int nsg, int nw, typename args_t>
+template<short NR0, short NW>
+static inline void helper_mv_reduce_and_write(
+        device float * dst_f32,
+        float sumf[NR0],
+        const int r0,
+        const int ne01,
+        ushort tiisg,
+        ushort sgitg,
+        threadgroup char * shmem) {
+    threadgroup float * shmem_f32[NR0];
+
+    for (short row = 0; row < NR0; ++row) {
+        shmem_f32[row] = (threadgroup float *) shmem + NW*row;
+
+        if (sgitg == 0) {
+            shmem_f32[row][tiisg] = 0.0f;
+        }
+
+        sumf[row] = simd_sum(sumf[row]);
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    for (short row = 0; row < NR0; ++row) {
+        if (tiisg == 0) {
+            shmem_f32[row][sgitg] = sumf[row];
+        }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    for (short row = 0; row < NR0 && r0 + row < ne01; ++row) {
+        float tot = simd_sum(shmem_f32[row][tiisg]);
+
+        if (tiisg == 0 && sgitg == 0) {
+            dst_f32[r0 + row] = tot;
+        }
+    }
+}
+
+template<typename block_q_type, short NR0, short NSG, short NW, typename args_t>
 void mul_vec_q_n_f32_impl(
         args_t args,
         device const char * src0,
@@ -2765,45 +2809,51 @@ void mul_vec_q_n_f32_impl(
         uint3  tgpig,
         ushort tiisg,
         ushort sgitg) {
-    const int nb = args.ne00/QK4_0;
+    constexpr short NQ = 16;
 
-    const int r0 = tgpig.x;
-    const int r1 = tgpig.y;
-    const int im = tgpig.z;
+    const int nb = args.ne00/QK4_0;
 
-    const int first_row = (r0 * nsg + sgitg) * nr0;
+    const int r0 = (tgpig.x*NSG + sgitg)*NR0;
+  //const int r0 =  tgpig.x*NR0;
+    const int r1 =  tgpig.y;
+    const int im =  tgpig.z;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
 
-  //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
-    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+  //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 = r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
 
   //device const block_q_type * x = (device const block_q_type *) (src0 + offset0);
     device const float        * y = (device const float        *) (src1 + offset1);
 
     // pointers to src0 rows
-    device const block_q_type * ax[nr0];
-    for (int row = 0; row < nr0; ++row) {
-        const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    device const block_q_type * ax[NR0];
+    FOR_UNROLL (int row = 0; row < NR0; ++row) {
+        const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
 
         ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
     }
 
-    float yl[16]; // src1 vector cache
-    float sumf[nr0] = {0.f};
+    float sumf[NR0] = {0.f};
 
-    const short ix = (tiisg/2);
-    const short il = (tiisg%2)*8;
+    const short ix = (tiisg/(NW/NQ));
+    const short il = (tiisg%(NW/NQ))*8;
 
-    device const float * yb = y + ix*QK4_0 + il;
+    //const int ib0 = sgitg*NQ + ix;
+    const int ib0 = ix;
+
+    float yl[16]; // src1 vector cache
+
+    //device const float * yb = y + ix*QK4_0 + il;
+    device const float * yb = y + ib0*QK4_0 + il;
 
     // each thread in a SIMD group deals with half a block.
-    for (int ib = ix; ib < nb; ib += nw/2) {
+    //for (int ib = ib0; ib < nb; ib += NSG*NQ) {
+    for (int ib = ib0; ib < nb; ib += NQ) {
         float sumy[2] = { 0.f, 0.f };
 
-#pragma unroll
-        for (short i = 0; i < 8; i += 2) {
+        FOR_UNROLL (short i = 0; i < 8; i += 2) {
             sumy[0]  += yb[i +  0] + yb[i +  1];
             yl[i + 0] = yb[i +  0];
             yl[i + 1] = yb[i +  1]/256.f;
@@ -2813,21 +2863,23 @@ void mul_vec_q_n_f32_impl(
             yl[i + 9] = yb[i + 17]/4096.f;
         }
 
-#pragma unroll
-        for (short row = 0; row < nr0; row++) {
+        FOR_UNROLL (short row = 0; row < NR0; row++) {
             sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il);
         }
 
         yb += QK4_0 * 16;
+        //yb += NSG*NQ*QK4_0;
     }
 
     device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
 
-    for (int row = 0; row < nr0; ++row) {
+    //helper_mv_reduce_and_write<NR0, NW>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
+
+    for (int row = 0; row < NR0; ++row) {
         const float tot = simd_sum(sumf[row]);
 
-        if (tiisg == 0 && first_row + row < args.ne01) {
-            dst_f32[first_row + row] = tot;
+        if (tiisg == 0 && r0 + row < args.ne01) {
+            dst_f32[r0 + row] = tot;
         }
     }
 }
@@ -2837,10 +2889,11 @@ kernel void kernel_mul_mv_q4_0_f32(
         device const char * src0,
         device const char * src1,
         device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SG_Q4_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SG_Q4_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
 kernel void kernel_mul_mv_q4_1_f32(
@@ -2848,10 +2901,11 @@ kernel void kernel_mul_mv_q4_1_f32(
         device const char * src0,
         device const char * src1,
         device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-     mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SG_Q4_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+     mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SG_Q4_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
 kernel void kernel_mul_mv_q5_0_f32(
@@ -2859,10 +2913,11 @@ kernel void kernel_mul_mv_q5_0_f32(
         device const char * src0,
         device const char * src1,
         device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SG_Q5_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SG_Q5_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
 kernel void kernel_mul_mv_q5_1_f32(
@@ -2870,15 +2925,14 @@ kernel void kernel_mul_mv_q5_1_f32(
         device const char * src0,
         device const char * src1,
         device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SG_Q5_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SG_Q5_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-#define NB_Q8_0 8
-
-template<int nr0, int nsg, int nw, typename args_t>
+template<short NR0, short NSG, short NW, typename args_t>
 void kernel_mul_mv_q8_0_f32_impl(
         args_t args,
         device const char * src0,
@@ -2888,66 +2942,65 @@ void kernel_mul_mv_q8_0_f32_impl(
         uint3  tgpig,
         ushort tiisg,
         ushort sgitg) {
+    constexpr short NQ = 8;
+
     const int nb = args.ne00/QK8_0;
 
-    const int r0 = tgpig.x;
+    const int r0 = tgpig.x*NR0;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * nsg + sgitg) * nr0;
-
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
 
-  //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
-    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+  //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 = r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
 
   //device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0);
     device const float      * y = (device const float      *) (src1 + offset1);
 
     // pointers to src0 rows
-    device const block_q8_0 * ax[nr0];
-    for (int row = 0; row < nr0; ++row) {
-        const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    device const block_q8_0 * ax[NR0];
+    FOR_UNROLL (short row = 0; row < NR0; ++row) {
+        const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
 
         ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0);
     }
 
-    float yl[NB_Q8_0];
-    float sumf[nr0] = { 0.f };
+    float sumf[NR0] = { 0.f };
+
+    const short ix = tiisg/(NW/NQ);
+    const short il = tiisg%(NW/NQ);
+
+    const int ib0 = sgitg*NQ + ix;
 
-    const short ix = tiisg/4;
-    const short il = tiisg%4;
+    float yl[NQ];
 
-    device const float * yb = y + ix*QK8_0 + il*NB_Q8_0;
+    device const float * yb = y + ib0*QK8_0 + il*NQ;
 
-    // each thread in a SIMD group deals with NB_Q8_0 quants at a time
-    for (int ib = ix; ib < nb; ib += nw/4) {
-        for (short i = 0; i < NB_Q8_0; ++i) {
+    // each thread in a SIMD group deals with NQ quants at a time
+    for (int ib = ib0; ib < nb; ib += NSG*NQ) {
+        for (short i = 0; i < NQ; ++i) {
             yl[i] = yb[i];
         }
 
-        for (short row = 0; row < nr0; row++) {
-            device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0;
+        for (short row = 0; row < NR0; row++) {
+            device const int8_t * qs = ax[row][ib].qs + il*NQ;
+
             float sumq = 0.f;
-            for (short iq = 0; iq < NB_Q8_0; ++iq) {
-                sumq += qs[iq] * yl[iq];
+            FOR_UNROLL (short i = 0; i < NQ; ++i) {
+                sumq += qs[i] * yl[i];
             }
+
             sumf[row] += sumq*ax[row][ib].d;
         }
 
-        yb += nw*NB_Q8_0;
+        yb += NSG*NQ*QK8_0;
     }
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < nr0; ++row) {
-        const float tot = simd_sum(sumf[row]);
-
-        if (tiisg == 0 && first_row + row < args.ne01) {
-            dst_f32[first_row + row] = tot;
-        }
-    }
+    helper_mv_reduce_and_write<NR0, NW>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
 }
 
 [[host_name("kernel_mul_mv_q8_0_f32")]]
@@ -2956,10 +3009,11 @@ kernel void kernel_mul_mv_q8_0_f32(
         device const char * src0,
         device const char * src1,
         device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SG_Q8_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SG_Q8_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
 // mat-vec kernel processing in chunks of float4
@@ -4197,6 +4251,19 @@ kernel void kernel_leaky_relu_f32(
     dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * args.slope;
 }
 
+constant bool FC_flash_attn_ext_has_mask  [[function_constant(FC_FLASH_ATTN_EXT + 0)]];
+constant bool FC_flash_attn_ext_has_sinks [[function_constant(FC_FLASH_ATTN_EXT + 1)]];
+constant bool FC_flash_attn_ext_has_bias  [[function_constant(FC_FLASH_ATTN_EXT + 2)]];
+constant bool FC_flash_attn_ext_has_scap  [[function_constant(FC_FLASH_ATTN_EXT + 3)]];
+
+//constant float FC_flash_attn_ext_scale         [[function_constant(FC_FLASH_ATTN_EXT + 10)]];
+//constant float FC_flash_attn_ext_max_bias      [[function_constant(FC_FLASH_ATTN_EXT + 11)]];
+//constant float FC_flash_attn_ext_logit_softcap [[function_constant(FC_FLASH_ATTN_EXT + 12)]];
+
+constant int32_t FC_flash_attn_ext_ns10 [[function_constant(FC_FLASH_ATTN_EXT + 20)]];
+constant int32_t FC_flash_attn_ext_ns20 [[function_constant(FC_FLASH_ATTN_EXT + 21)]];
+constant int32_t FC_flash_attn_ext_nsg  [[function_constant(FC_FLASH_ATTN_EXT + 22)]];
+
 // ref: https://arxiv.org/pdf/2307.08691.pdf
 template<
     typename q_t,     // query types in shared memory
@@ -4211,6 +4278,7 @@ template<
     typename qk_t,    // Q*K types
     typename qk8x8_t,
     typename s_t,     // soft-max types
+    typename s2_t,
     typename s8x8_t,
     typename o_t,     // attention accumulation types
     typename o4_t,
@@ -4221,12 +4289,12 @@ template<
     typename vd4x4_t, // value type in device memory
     short nl_v,
     void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
-    short DK,        // K head size
-    short DV,        // V head size
-    short Q  = 8,    // queries per threadgroup
-    short KV = 8,    // key/value processed per each simdgroup
-    short C  = 32>   // cache items per threadgroup
-kernel void kernel_flash_attn_ext(
+    short DK,         // K head size
+    short DV,         // V head size
+    short Q,          // queries per threadgroup
+    short C,          // cache items per threadgroup
+    short NSG>        // number of simd groups
+void kernel_flash_attn_ext_impl(
         constant ggml_metal_kargs_flash_attn_ext & args,
         device const char * q,
         device const char * k,
@@ -4234,46 +4302,85 @@ kernel void kernel_flash_attn_ext(
         device const char * mask,
         device const char * sinks,
         device       char * dst,
-        threadgroup  half * shmem_f16 [[threadgroup(0)]],
-        uint3   tgpig[[threadgroup_position_in_grid]],
-        ushort3   ntg[[threads_per_threadgroup]],
-        ushort  tiisg[[thread_index_in_simdgroup]],
-        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
-    const short nsg = ntg.y; // number of simdgroups
-
-    const int iq3 = tgpig[2];
-    const int iq2 = tgpig[1];
-    const int iq1 = tgpig[0]*Q;
+        threadgroup  half * shmem_f16,
+        uint3   tgpig,
+        ushort  tiisg,
+        ushort  sgitg) {
+    const ushort iq3 = tgpig[2];
+    const ushort iq2 = tgpig[1];
+    const ushort iq1 = tgpig[0]*Q;
+
+#define NS10 (FC_flash_attn_ext_ns10)
+#define NS20 (FC_flash_attn_ext_ns20)
+
+    // note: I had some concerns that using this instead of the ugly macros above was affecting performance
+    //       need to re-check carefully and if no regressions are observerd - remove the macros
+    //       the concerns is that maybe using const variables requires extra registers? but not sure if the compiler
+    //         is clever enough to avoid this. unfortunately, using constexpr is not possible with FC
+    //const short NS10 = FC_flash_attn_ext_ns10;
+    //const short NS20 = FC_flash_attn_ext_ns20;
+
+    constexpr short KV   = 8;
 
     constexpr short DK4  = DK/4;
     constexpr short DK8  = DK/8;
     constexpr short DK16 = DK/16;
     constexpr short DV4  = DV/4;
-    constexpr short DV8  = DV/8;
+  //constexpr short DV8  = DV/8;
     constexpr short DV16 = DV/16;
 
+    constexpr short PV   = PAD2(DV, 64);
+    constexpr short PV4  = PV/4;
+    constexpr short PV8  = PV/8;
+  //constexpr short PV16 = PV/16;
+
     constexpr short NW  = N_SIMDWIDTH;
-    constexpr short SH  = (2*C + Q); // shared memory per simdgroup (s_t == float)
+    constexpr short NQ  = Q/NSG;
+    constexpr short SH  = 2*C; // shared memory per simdgroup (s_t == float)
+
+    constexpr short TS = 2*SH;
+    constexpr short T  = DK + 2*PV; // shared memory size per query in (half)
+
+    threadgroup q_t  * sq  = (threadgroup q_t  *) (shmem_f16 + 0*T); // holds the query data
+    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*T); // same as above but in q4_t
+    threadgroup o_t  * so  = (threadgroup o_t  *) (shmem_f16 + 0*T + Q*DK); // the result for all queries in 8x8 matrices (the O matrix from the paper)
+    threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*T + Q*DK);
+    threadgroup s_t  * ss  = (threadgroup s_t  *) (shmem_f16 + Q*T); // scratch buffer for attention, mask and diagonal matrix
+    threadgroup s2_t * ss2 = (threadgroup s2_t *) (shmem_f16 + Q*T); // same as above but in s2_t
+
+    threadgroup k_t    * sk    = (threadgroup k_t    *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // scratch buffer to load K in shared memory
+    threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // same as above but in k4x4_t
 
-    const short TS = nsg*SH;      // shared memory size per query in (s_t == float)
-    const short T  = 2*DK + 2*TS; // shared memory size per query in (half)
+    threadgroup v_t    * sv    = (threadgroup v_t    *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // scratch buffer to load V in shared memory
+    threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // same as above but in v4x4_t
 
-    threadgroup q_t  * sq  = (threadgroup q_t  *) (shmem_f16 +                0*DK); // holds the query data
-    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 +                0*DK); // same as above but in q4_t
-    threadgroup s_t  * ss  = (threadgroup s_t  *) (shmem_f16 + 2*sgitg*SH + 2*Q*DK); // scratch buffer for attention, mask and diagonal matrix
+    // mask storage in shared mem
+    threadgroup half2 * sm2 = (threadgroup half2 *) (shmem_f16 + Q*T + 2*C);
 
-    threadgroup k_t    * sk    = (threadgroup k_t    *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
-    threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
+    // per-query mask pointers
+    device const half2 * pm2[NQ];
 
-    threadgroup v_t    * sv    = (threadgroup v_t    *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory
-    threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t
+    FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
+        const short j = jj*NSG + sgitg;
 
-    // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
-    o8x8_t lo[DV8];
+        pm2[jj] = (device const half2 *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
+    }
+
+    {
+        q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03;
+
+        const short ikv2 = iq2/(args.ne02/args.ne_12_2);
+        const short ikv3 = iq3/(args.ne03/args.ne_12_3);
+
+        k += ikv2*args.nb12 + ikv3*args.nb13;
+        v += ikv2*args.nb22 + ikv3*args.nb23;
+    }
 
     // load heads from Q to shared memory
-    for (short j = sgitg; j < Q; j += nsg) {
-        device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
+    FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
+        const short j = jj*NSG + sgitg;
+
+        device const float4 * q4 = (device const float4 *) ((device const char *) q + j*args.nb01);
 
         for (short i = tiisg; i < DK4; i += NW) {
             if (iq1 + j < args.ne01) {
@@ -4284,43 +4391,30 @@ kernel void kernel_flash_attn_ext(
         }
     }
 
-    // zero out lo
-    for (short i = 0; i < DV8; ++i) {
-        lo[i] = make_filled_simdgroup_matrix<o_t, 8>((o_t) 0.0f);
-    }
+    // zero out
+    FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
+        const short j = jj*NSG + sgitg;
+
+        for (short i = tiisg; i < DV4; i += NW) {
+            so4[j*PV4 + i] = 0;
+        }
 
-    // zero out shared memory SH
-    for (short j = 0; j < Q; ++j) {
         for (short i = tiisg; i < SH; i += NW) {
-            ss[j*TS + i] = 0.0f;
+            ss[j*SH + i] = 0.0f;
         }
     }
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    {
-        float S[Q] = { [0 ... Q-1] = 0.0f };
-        float M[Q] = { [0 ... Q-1] = -__FLT_MAX__/2 };
-
-        // thread indices inside the simdgroup
-        // TODO: see if we can utilize quad-group functions for better performance
-        //       https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (6.9.3)
-        const short tx = tiisg%4;
-        const short ty = tiisg/4;
-
-        // broadcast kv
-        //const short rk2 = args.ne02/args.ne12;
-        //const short rk3 = args.ne03/args.ne13;
-
-        const short ikv2 = iq2/(args.ne02/args.ne_12_2);
-        const short ikv3 = iq3/(args.ne03/args.ne_12_3);
+    float S[NQ] = { [0 ... NQ-1] = 0.0f };
 
-        const bool has_mask = mask != q;
+    {
+        float M[NQ] = { [0 ... NQ-1] = -FLT_MAX/2 };
 
         float slope = 1.0f;
 
         // ALiBi
-        if (args.max_bias > 0.0f) {
+        if (FC_flash_attn_ext_has_bias) {
             const short h = iq2;
 
             const float base = h < args.n_head_log2 ? args.m0 : args.m1;
@@ -4331,177 +4425,277 @@ kernel void kernel_flash_attn_ext(
 
         // loop over the KV cache
         // each simdgroup handles blocks of Q rows and C columns
-        for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) {
-            const int ic = ic0 + C*sgitg;
-            if (ic >= args.ne11) {
-                break;
-            }
-
-            if (has_mask) {
-                // used to detect blocks full of -INF
-                float smax = -INFINITY;
+        for (int ic = 0; ic < args.ne11; ic += C) {
+            // read the mask into shared mem
+            if (FC_flash_attn_ext_has_mask) {
+                FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
+                    const short j = jj*NSG + sgitg;
+
+                    sm2[j*SH + tiisg] = pm2[jj][tiisg];
+                    pm2[jj] += NW;
+                }
 
-                // load the mask in shared memory
-                #pragma unroll(Q)
-                for (short j = 0; j < Q; ++j) {
-                    device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
+                threadgroup_barrier(mem_flags::mem_threadgroup);
 
-                    const float m = pm[ic + tiisg];
+                // used to detect blocks full of -INF
+                // skip only when the entire threadgroup is masked
+                half2 smax2(-MAXHALF/2, -MAXHALF/2);
 
-                    ss[j*TS + C + tiisg] = m;
-                    smax = max(smax, m);
+                FOR_UNROLL (short j = 0; j < Q; ++j) {
+                    smax2 = max(smax2, sm2[j*SH + tiisg]);
                 }
 
-                smax = simd_max(smax);
+                smax2 = simd_max(smax2);
+
+                if (max(smax2[0], smax2[1]) <= -MAXHALF/2) {
+                    // this barrier is important
+                    threadgroup_barrier(mem_flags::mem_threadgroup);
 
-                if (smax == -INFINITY) {
                     continue;
                 }
             }
 
             // Q*K^T
-            {
-                for (short cc = 0; cc < C/8; ++cc) {
+            // this is compile-time check, so it does not have runtime overhead
+            if (is_same<kd4x4_t, k4x4_t>::value) {
+                // we can read directly from global memory
+                device      const k_t * pk = (device const k_t *) ((device const char *) k + ic*args.nb11);
+                threadgroup const q_t * pq = sq;
+                threadgroup       s_t * ps = ss;
+
+                pk += sgitg*(8*NS10);
+                ps += sgitg*(8*1);
+
+                static_assert((C/8) % NSG == 0, "");
+
+                constexpr short NC = (C/8)/NSG;
+
+                // TODO: not good to unroll for large contexts - not sure why?
+                for (short cc = 0; cc < NC; ++cc) {
                     qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
 
-                    // this is compile-time check, so it does not have runtime overhead
-                    if (is_same<kd4x4_t, k4x4_t>::value) {
-                        // we can read directly from global memory
-                        device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
+                    if (DK8 % 16 != 0) {
+                        k8x8_t mk;
+                        q8x8_t mq;
+
+                        FOR_UNROLL (short i = 0; i < DK8; ++i) {
+                            simdgroup_barrier(mem_flags::mem_none);
+
+                            simdgroup_load(mk, pk, NS10, 0, true);
+                            simdgroup_load(mq, pq, DK);
 
-                        #pragma unroll(DK8)
-                        for (short i = 0; i < DK8; ++i) {
-                            k8x8_t mk;
-                            simdgroup_load(mk, pk + i*8, args.nb11/sizeof(k_t), 0, true); // transpose // TODO: use ne10
+                            simdgroup_barrier(mem_flags::mem_none);
 
-                            q8x8_t mq;
-                            simdgroup_load(mq, sq + i*8, DK);
                             simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
+
+                            pk += 8;
+                            pq += 8;
                         }
                     } else {
-                        for (short ii = 0; ii < DK16; ii += 4) {
-                            device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
+                        k8x8_t mk[2];
+                        q8x8_t mq[2];
 
-                            if (DK16%4 == 0) {
-                                // the head is evenly divisible by 4*16 = 64, so no need for bound checks
-                                {
-                                    k4x4_t tmp;
-                                    deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
-                                    sk4x4[4*ty + tx] = tmp;
-                                }
+                        FOR_UNROLL (short i = 0; i < DK8/2; ++i) {
+                            simdgroup_barrier(mem_flags::mem_none);
 
-                                simdgroup_barrier(mem_flags::mem_threadgroup);
+                            simdgroup_load(mk[0], pk + 0*8, NS10, 0, true);
+                            simdgroup_load(mk[1], pk + 1*8, NS10, 0, true);
 
-                                #pragma unroll(4)
-                                for (short k = 0; k < 4; ++k) {
-                                    k8x8_t mk;
-                                    q8x8_t mq;
+                            simdgroup_load(mq[0], pq + 0*8, DK);
+                            simdgroup_load(mq[1], pq + 1*8, DK);
 
-                                    simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
-                                    simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
-                                    simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
+                            simdgroup_barrier(mem_flags::mem_none);
 
-                                    simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
-                                    simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
-                                    simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
-                                }
-                            } else {
-                                if (ii + tx < DK16) {
-                                    k4x4_t tmp;
-                                    deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
-                                    sk4x4[4*ty + tx] = tmp;
-                                }
+                            simdgroup_multiply_accumulate(mqk, mq[0], mk[0], mqk);
+                            simdgroup_multiply_accumulate(mqk, mq[1], mk[1], mqk);
 
-                                simdgroup_barrier(mem_flags::mem_threadgroup);
+                            pk += 16;
+                            pq += 16;
+                        }
+                    }
 
-                                for (short k = 0; k < 4 && ii + k < DK16; ++k) {
-                                    k8x8_t mk;
-                                    q8x8_t mq;
+                    simdgroup_store(mqk, ps, SH, 0, false);
 
-                                    simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
-                                    simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
-                                    simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
+                    pk += 8*(NSG*NS10 - DK8);
+                    pq += 8*(NSG*0    - DK8);
+                    ps += 8*(NSG);
+                }
+            } else {
+                // TODO: this is the quantized K cache branch - not optimized yet
+                for (short ccc = 0; ccc < (C/8)/NSG; ++ccc) {
+                    const short cc = ccc*NSG + sgitg;
 
-                                    simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
-                                    simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
-                                    simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
-                                }
+                    const short tx = tiisg%4;
+                    const short ty = tiisg/4;
+
+                    qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
+
+                    for (short ii = 0; ii < DK16; ii += 4) {
+                        device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11));
+
+                        if (DK16%4 == 0) {
+                            // the head is evenly divisible by 4*16 = 64, so no need for bound checks
+                            {
+                                k4x4_t tmp;
+                                deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
+                                sk4x4[4*ty + tx] = tmp;
+                            }
+
+                            simdgroup_barrier(mem_flags::mem_threadgroup);
+
+                            FOR_UNROLL (short k = 0; k < 4; ++k) {
+                                k8x8_t mk;
+                                q8x8_t mq;
+
+                                simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
+                                simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
+                                simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
+
+                                simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
+                                simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
+                                simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
+                            }
+                        } else {
+                            if (ii + tx < DK16) {
+                                k4x4_t tmp;
+                                deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
+                                sk4x4[4*ty + tx] = tmp;
+                            }
+
+                            simdgroup_barrier(mem_flags::mem_threadgroup);
+
+                            for (short k = 0; k < 4 && ii + k < DK16; ++k) {
+                                k8x8_t mk;
+                                q8x8_t mq;
+
+                                simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
+                                simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
+                                simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
+
+                                simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
+                                simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
+                                simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
                             }
                         }
                     }
 
-                    // cast qk_t -> s_t
-                    //s8x8_t mqks(1.0f);
-                    //simdgroup_multiply(mqks, mqk, mqks);
-                    //simdgroup_store(mqks, ss + 8*cc, TS, 0, false);
-
-                    simdgroup_store(mqk, ss + 8*cc, TS, 0, false);
+                    simdgroup_store(mqk, ss + 8*cc, SH, 0, false);
                 }
             }
 
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+
             // online softmax
-            {
-                for (ushort j = 0; j < Q; ++j) {
-                    const float m = M[j];
+            FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
+                const short j = jj*NSG + sgitg;
 
-                    // scale and apply the logitcap / mask
-                    float s = ss[j*TS + tiisg]*args.scale;
+                const float m = M[jj];
 
-                    if (args.logit_softcap != 0.0f) {
-                        s = args.logit_softcap*precise::tanh(s);
-                    }
+                // scale and apply the logitcap / mask
+                float2 s2 = ss2[j*SH/2 + tiisg]*args.scale;
+
+                if (FC_flash_attn_ext_has_scap) {
+                    s2 = args.logit_softcap*precise::tanh(s2);
+                }
 
-                    // mqk = mqk + mask*slope
-                    s += slope*ss[j*TS + C + tiisg];
+                // mqk = mqk + slope*mask
+                if (FC_flash_attn_ext_has_bias) {
+                    s2 += s2_t(sm2[j*SH + tiisg])*slope;
+                } else {
+                    s2 += s2_t(sm2[j*SH + tiisg]);
+                }
 
-                    M[j] = simd_max(max(M[j], s));
+                M[jj] = simd_max(max(M[jj], max(s2[0], s2[1])));
 
-                    const float ms = exp(m - M[j]);
-                    const float vs = exp(s - M[j]);
+                const float  ms  = exp(m  - M[jj]);
+                const float2 vs2 = exp(s2 - M[jj]);
 
-                    S[j] = S[j]*ms + simd_sum(vs);
+                S[jj] = S[jj]*ms + simd_sum(vs2[0] + vs2[1]);
 
-                    // the P matrix from the paper (Q rows, C columns)
-                    ss[j*TS + tiisg] = vs;
+                // the P matrix from the paper (Q rows, C columns)
+                ss2[j*SH/2 + tiisg] = vs2;
 
-                    // create a QxQ diagonal matrix for rescaling the output
-                    if (tiisg == j) {
-                        ss[j*TS + 2*C + j] = ms;
+                if (DV4 % NW == 0) {
+                    FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) {
+                        const short i = ii*NW + tiisg;
+
+                        so4[j*PV4 + i] *= ms;
+                    }
+                } else {
+                    for (short i = tiisg; i < DV4; i += NW) {
+                        so4[j*PV4 + i] *= ms;
                     }
                 }
             }
 
-            // O = diag(ms)*O
-            {
-                s8x8_t ms;
-                simdgroup_load(ms, ss + 2*C, TS, 0, false);
-
-                #pragma unroll(DV8)
-                for (short i = 0; i < DV8; ++i) {
-                    simdgroup_multiply(lo[i], ms, lo[i]);
-                }
-            }
+            threadgroup_barrier(mem_flags::mem_threadgroup);
 
             // O = O + (Q*K^T)*V
             {
-                for (short cc = 0; cc < C/8; ++cc) {
-                    s8x8_t vs;
-                    simdgroup_load(vs, ss + 8*cc, TS, 0, false);
+                // we can read directly from global memory
+                if (is_same<vd4x4_t, v4x4_t>::value) {
+                    static_assert(PV8 % NSG == 0, "");
+
+                    constexpr short NO = PV8/NSG;
 
-                    if (is_same<vd4x4_t, v4x4_t>::value) {
-                        // we can read directly from global memory
-                        device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
+                    o8x8_t lo[NO];
 
-                        #pragma unroll(DV8)
-                        for (short i = 0; i < DV8; ++i) {
-                            v8x8_t mv;
-                            simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20
+                    {
+                        auto sot = so + 8*sgitg;
 
-                            simdgroup_multiply_accumulate(lo[i], vs, mv, lo[i]);
+                        FOR_UNROLL (short ii = 0; ii < NO; ++ii) {
+                            simdgroup_load(lo[ii], sot, PV, 0, false);
+
+                            sot += 8*NSG;
                         }
-                    } else {
-                        for (short ii = 0; ii < DV16; ii += 4) {
-                            device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
+                    }
+
+                    {
+                        auto sst = ss;
+
+                        device const v_t * pv = (device const v_t *) ((device const char *) v + ic*args.nb21);
+
+                        pv += 8*sgitg;
+
+                        FOR_UNROLL (short cc = 0; cc < C/8; ++cc) {
+                            s8x8_t vs;
+                            simdgroup_load(vs, sst, SH, 0, false);
+
+                            FOR_UNROLL (short ii = 0; ii < NO; ++ii) {
+                                v8x8_t mv;
+
+                                simdgroup_load(mv, pv, NS20, 0, false);
+                                simdgroup_multiply_accumulate(lo[ii], vs, mv, lo[ii]);
+
+                                pv += 8*NSG;
+                            }
+
+                            pv  += 8*(NS20 - NO*NSG);
+                            sst += 8;
+                        }
+                    }
+
+                    {
+                        auto sot = so + 8*sgitg;
+
+                        FOR_UNROLL (short ii = 0; ii < NO; ++ii) {
+                            simdgroup_store(lo[ii], sot, PV, 0, false);
+
+                            sot += 8*NSG;
+                        }
+                    }
+                } else {
+                    // TODO: this is the quantized V cache branch - not optimized yet
+
+                    const short tx = tiisg%4;
+                    const short ty = tiisg/4;
+
+                    for (short cc = 0; cc < C/8; ++cc) {
+                        s8x8_t vs;
+                        simdgroup_load(vs, ss + 8*cc, SH, 0, false);
+
+                        for (short ii = 4*sgitg; ii < DV16; ii += 4*NSG) {
+                            device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21));
 
                             if (DV16%4 == 0) {
                                 // no need for bound checks
@@ -4513,15 +4707,20 @@ kernel void kernel_flash_attn_ext(
 
                                 simdgroup_barrier(mem_flags::mem_threadgroup);
 
-                                #pragma unroll(4)
-                                for (short k = 0; k < 4; ++k) {
-                                    v8x8_t mv;
+                                FOR_UNROLL (short k = 0; k < 4; ++k) {
+                                    v8x8_t mv[2];
+                                    o8x8_t lo[2];
 
-                                    simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
-                                    simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]);
+                                    simdgroup_load(mv[0], sv + 16*k + 0*8, 4*16, 0, false);
+                                    simdgroup_load(mv[1], sv + 16*k + 1*8, 4*16, 0, false);
+                                    simdgroup_load(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false);
+                                    simdgroup_load(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false);
 
-                                    simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
-                                    simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]);
+                                    simdgroup_multiply_accumulate(lo[0], vs, mv[0], lo[0]);
+                                    simdgroup_multiply_accumulate(lo[1], vs, mv[1], lo[1]);
+
+                                    simdgroup_store(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false);
+                                    simdgroup_store(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false);
                                 }
                             } else {
                                 if (ii + tx < DV16) {
@@ -4533,243 +4732,249 @@ kernel void kernel_flash_attn_ext(
                                 simdgroup_barrier(mem_flags::mem_threadgroup);
 
                                 for (short k = 0; k < 4 && ii + k < DV16; ++k) {
-                                    v8x8_t mv;
+                                    v8x8_t mv[2];
+                                    o8x8_t lo[2];
+
+                                    simdgroup_load(mv[0], sv + 16*k + 0*8, 4*16, 0, false);
+                                    simdgroup_load(mv[1], sv + 16*k + 1*8, 4*16, 0, false);
+                                    simdgroup_load(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false);
+                                    simdgroup_load(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false);
 
-                                    simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
-                                    simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]);
+                                    simdgroup_multiply_accumulate(lo[0], vs, mv[0], lo[0]);
+                                    simdgroup_multiply_accumulate(lo[1], vs, mv[1], lo[1]);
 
-                                    simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
-                                    simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]);
+                                    simdgroup_store(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false);
+                                    simdgroup_store(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false);
                                 }
                             }
                         }
                     }
                 }
             }
-        }
 
-        if (sinks != q && sgitg == 0) {
-            for (ushort j = 0; j < Q; ++j) {
-                const float m = M[j];
-                const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+        }
 
-                M[j] = simd_max(max(M[j], s));
+        if (FC_flash_attn_ext_has_sinks) {
+            FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
+                const short j = jj*NSG + sgitg;
 
-                const float ms = exp(m - M[j]);
-                const float vs = exp(s - M[j]);
+                const float m = M[jj];
+                const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;
 
-                S[j] = S[j]*ms + simd_sum(vs);
+                M[jj] = simd_max(max(M[jj], s));
 
-                if (tiisg == j) {
-                    ss[j*TS + 2*C + j] = ms;
-                }
-            }
+                const float ms = exp(m - M[jj]);
+                const float vs = exp(s - M[jj]);
 
-            // O = diag(ms)*O
-            {
-                s8x8_t ms;
-                simdgroup_load(ms, ss + 2*C, TS, 0, false);
+                S[jj] = S[jj]*ms + simd_sum(vs);
 
-                #pragma unroll(DV8)
-                for (short i = 0; i < DV8; ++i) {
-                    simdgroup_multiply(lo[i], ms, lo[i]);
+                for (short i = tiisg; i < DV4; i += NW) {
+                    so4[j*PV4 + i] *= ms;
                 }
             }
         }
-
-        // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
-        for (short j = tiisg; j < Q; j += NW) {
-            ss[j*TS + 0] = S[j];
-            ss[j*TS + 1] = M[j];
-        }
     }
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    threadgroup float  * so  = (threadgroup float  *) (shmem_f16 + 0*DK); // reuse query data for accumulation
-    threadgroup float4 * so4 = (threadgroup float4 *) (shmem_f16 + 0*DK);
-
-    // store result to shared memory in F32
-    if (sgitg == 0) {
-        for (short i = 0; i < DV8; ++i) {
-            //simdgroup_store(lo[i], so + i*8, DV, 0, false);
-            simdgroup_float8x8 t(1.0f);
-            simdgroup_multiply(t, lo[i], t);
-            simdgroup_store(t, so + i*8, DV, 0, false);
+    // store to global memory
+    for (short jj = 0; jj < NQ; ++jj) {
+        const short j = jj*NSG + sgitg;
+        if (iq1 + j >= args.ne01) {
+            break;
         }
-    }
-
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    // reduce the warps sequentially
-    for (ushort sg = 1; sg < nsg; ++sg) {
-        if (sgitg == sg) {
-            for (short j = tiisg; j < Q; j += NW) {
-                const float S0 = ss[j*TS - 1*SH + 0];
-                const float S1 = ss[j*TS        + 0];
 
-                const float M0 = ss[j*TS - 1*SH + 1];
-                const float M1 = ss[j*TS        + 1];
-
-                const float M = max(M0, M1);
-
-                float ms0 = exp(M0 - M);
-                float ms1 = exp(M1 - M);
+        device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4;
 
-                const float S = S0*ms0 + S1*ms1;
+        const float scale = 1.0f/S[jj];
 
-                ss[j*TS + 0] = S;
-                ss[j*TS + 1] = M;
+        if (DV4 % NW == 0) {
+            FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) {
+                const short i = ii*NW + tiisg;
 
-                ss[j*TS + 2*C + j - 1*SH] = ms0;
-                ss[j*TS + 2*C + j       ] = ms1;
+                dst4[i] = (float4) so4[j*PV4 + i]*scale;
             }
-
-            //simdgroup_barrier(mem_flags::mem_threadgroup);
-
-            // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
-            {
-                s8x8_t ms0;
-                s8x8_t ms1;
-
-                simdgroup_load(ms0, ss + 2*C - 1*SH, TS, 0, false);
-                simdgroup_load(ms1, ss + 2*C,        TS, 0, false);
-
-                #pragma unroll(DV8)
-                for (short i = 0; i < DV8; ++i) {
-                    simdgroup_float8x8 t;
-
-                    simdgroup_load    (t, so + i*8, DV, 0, false);
-                    simdgroup_multiply(t, ms0, t);
-
-                    simdgroup_multiply_accumulate(t, ms1, lo[i], t);
-                    simdgroup_store(t, so + i*8, DV, 0, false);
-                }
+        } else {
+            for (short i = tiisg; i < DV4; i += NW) {
+                dst4[i] = (float4) so4[j*PV4 + i]*scale;
             }
         }
-
-        threadgroup_barrier(mem_flags::mem_threadgroup);
     }
 
-    threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*(nsg-1)*SH + 2*Q*DK);
-
-    // final rescale with 1/S and store to global memory
-    for (short j = sgitg; j < Q && iq1 + j < args.ne01; j += nsg) {
-        const float S = 1.0f/sf[j*TS + 0];
-
-        device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4;
+#undef NS10
+#undef NS20
+}
 
-        for (short i = tiisg; i < DV4; i += NW) {
-            dst4[i] = (float4) so4[j*DV4 + i]*S;
-        }
+template<
+    typename q_t,     // query types in shared memory
+    typename q4_t,
+    typename q8x8_t,
+    typename k_t,     // key types in shared memory
+    typename k4x4_t,
+    typename k8x8_t,
+    typename v_t,     // value types in shared memory
+    typename v4x4_t,
+    typename v8x8_t,
+    typename qk_t,    // Q*K types
+    typename qk8x8_t,
+    typename s_t,     // soft-max types
+    typename s2_t,
+    typename s8x8_t,
+    typename o_t,     // attention accumulation types
+    typename o4_t,
+    typename o8x8_t,
+    typename kd4x4_t, // key type in device memory
+    short nl_k,
+    void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
+    typename vd4x4_t, // value type in device memory
+    short nl_v,
+    void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
+    short DK,         // K head size
+    short DV,         // V head size
+    short Q  = 8,     // queries per threadgroup
+    short C  = 64>    // cache items per threadgroup
+kernel void kernel_flash_attn_ext(
+        constant ggml_metal_kargs_flash_attn_ext & args,
+        device const char * q,
+        device const char * k,
+        device const char * v,
+        device const char * mask,
+        device const char * sinks,
+        device       char * dst,
+        threadgroup  half * shmem_f16 [[threadgroup(0)]],
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
+#define FWD_TMPL q_t, q4_t, q8x8_t, k_t, k4x4_t, k8x8_t, v_t, v4x4_t, v8x8_t, qk_t, qk8x8_t, s_t, s2_t, s8x8_t, o_t, o4_t, o8x8_t, kd4x4_t, nl_k, deq_k, vd4x4_t, nl_v, deq_v, DK, DV, Q, C
+#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg
+    switch (FC_flash_attn_ext_nsg) {
+      // note: disabled cases to reduce library load time
+      //case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
+      //case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break;
+        case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break;
     }
+#undef FWD_TMPL
+#undef FWD_ARGS
 }
 
 // TODO: this is quite ugly. in the future these types will be hardcoded in the kernel, but for now keep them as
 //       template to be able to explore different combinations
 //
 #define FA_TYPES \
-    float,  float4,    simdgroup_float8x8, \
+    half,   half4,     simdgroup_half8x8,  \
     half,   half4x4,   simdgroup_half8x8,  \
     half,   half4x4,   simdgroup_half8x8,  \
     float,             simdgroup_float8x8, \
-    float,             simdgroup_float8x8, \
-    half,   half4,     simdgroup_half8x8
-    //float,  float4,    simdgroup_float8x8
+    float,  float2,    simdgroup_float8x8, \
+    float,  float4,    simdgroup_float8x8
+    //half,   half4,     simdgroup_half8x8
 
 #define FA_TYPES_BF \
     bfloat, bfloat4,   simdgroup_bfloat8x8, \
     bfloat, bfloat4x4, simdgroup_bfloat8x8, \
     bfloat, bfloat4x4, simdgroup_bfloat8x8, \
     float,             simdgroup_float8x8,  \
-    float,             simdgroup_float8x8,  \
+    float,  float2,    simdgroup_float8x8,  \
     half,   half4,     simdgroup_half8x8
     //float,  float4,    simdgroup_float8x8
 
 typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
 
-template [[host_name("kernel_flash_attn_ext_f16_h40" )]]         kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  40,  40>;
-template [[host_name("kernel_flash_attn_ext_f16_h64" )]]         kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  64,  64>;
-template [[host_name("kernel_flash_attn_ext_f16_h80" )]]         kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  80,  80>;
-template [[host_name("kernel_flash_attn_ext_f16_h96" )]]         kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  96,  96>;
-template [[host_name("kernel_flash_attn_ext_f16_h112")]]         kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  112, 112>;
-template [[host_name("kernel_flash_attn_ext_f16_h128")]]         kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  128, 128>;
-template [[host_name("kernel_flash_attn_ext_f16_h192")]]         kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  192, 192>;
-template [[host_name("kernel_flash_attn_ext_f16_hk192_hv128")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  192, 128>;
-template [[host_name("kernel_flash_attn_ext_f16_h256")]]         kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  256, 256>;
-template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  576, 512>;
+template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  40,  40>;
+template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  64,  64>;
+template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  80,  80>;
+template [[host_name("kernel_flash_attn_ext_f16_dk96_dv96"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  96,  96>;
+template [[host_name("kernel_flash_attn_ext_f16_dk112_dv112")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  112, 112>;
+template [[host_name("kernel_flash_attn_ext_f16_dk128_dv128")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  128, 128>;
+template [[host_name("kernel_flash_attn_ext_f16_dk192_dv192")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  192, 192>;
+template [[host_name("kernel_flash_attn_ext_f16_dk192_dv128")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  192, 128>;
+template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  256, 256>;
+template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  576, 512>;
 
 #if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_bf16_h40" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 40,  40>;
-template [[host_name("kernel_flash_attn_ext_bf16_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 64,  64>;
-template [[host_name("kernel_flash_attn_ext_bf16_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 80,  80>;
-template [[host_name("kernel_flash_attn_ext_bf16_h96" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 96,  96>;
-template [[host_name("kernel_flash_attn_ext_bf16_h112")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 112, 112>;
-template [[host_name("kernel_flash_attn_ext_bf16_h128")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 128, 128>;
-template [[host_name("kernel_flash_attn_ext_bf16_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 192, 192>;
-template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 192, 128>;
-template [[host_name("kernel_flash_attn_ext_bf16_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 256, 256>;
-template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 576, 512>;
+template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 40,  40>;
+template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 64,  64>;
+template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 80,  80>;
+template [[host_name("kernel_flash_attn_ext_bf16_dk96_dv96"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 96,  96>;
+template [[host_name("kernel_flash_attn_ext_bf16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 112, 112>;
+template [[host_name("kernel_flash_attn_ext_bf16_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 128, 128>;
+template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 192, 192>;
+template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 192, 128>;
+template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 256, 256>;
+template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 576, 512>;
 #endif
 
-template [[host_name("kernel_flash_attn_ext_q4_0_h40" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40,  40>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64,  64>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80,  80>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96,  96>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h112")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112, 112>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h128")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128, 128>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 192>;
-template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 128>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
-template [[host_name("kernel_flash_attn_ext_q4_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;
-
-template [[host_name("kernel_flash_attn_ext_q4_1_h40" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40,  40>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64,  64>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80,  80>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96,  96>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h112")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112, 112>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h128")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128, 128>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 192>;
-template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 128>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
-template [[host_name("kernel_flash_attn_ext_q4_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;
-
-template [[host_name("kernel_flash_attn_ext_q5_0_h40" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 40,  40>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64,  64>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80,  80>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96,  96>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h112")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112, 112>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h128")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128, 128>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 192>;
-template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 128>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
-template [[host_name("kernel_flash_attn_ext_q5_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;
-
-template [[host_name("kernel_flash_attn_ext_q5_1_h40" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40,  40>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64,  64>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80,  80>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96,  96>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h112")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112, 112>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h128")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128, 128>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 192>;
-template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 128>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
-template [[host_name("kernel_flash_attn_ext_q5_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;
-
-template [[host_name("kernel_flash_attn_ext_q8_0_h40" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40,  40>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64,  64>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80,  80>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96,  96>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h112")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112, 112>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h128")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128, 128>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 192>;
-template [[host_name("kernel_flash_attn_ext_q8_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 128>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256, 256>;
-template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 576, 512>;
+template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40,  40>;
+template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64,  64>;
+template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80,  80>;
+template [[host_name("kernel_flash_attn_ext_q4_0_dk96_dv96"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96,  96>;
+template [[host_name("kernel_flash_attn_ext_q4_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112, 112>;
+template [[host_name("kernel_flash_attn_ext_q4_0_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128, 128>;
+template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 192>;
+template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 128>;
+template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
+template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;
+
+template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40,  40>;
+template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64,  64>;
+template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80,  80>;
+template [[host_name("kernel_flash_attn_ext_q4_1_dk96_dv96"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96,  96>;
+template [[host_name("kernel_flash_attn_ext_q4_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112, 112>;
+template [[host_name("kernel_flash_attn_ext_q4_1_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128, 128>;
+template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 192>;
+template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 128>;
+template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
+template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;
+
+template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 40,  40>;
+template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64,  64>;
+template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80,  80>;
+template [[host_name("kernel_flash_attn_ext_q5_0_dk96_dv96"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96,  96>;
+template [[host_name("kernel_flash_attn_ext_q5_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112, 112>;
+template [[host_name("kernel_flash_attn_ext_q5_0_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128, 128>;
+template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 192>;
+template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 128>;
+template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
+template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;
+
+template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40,  40>;
+template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64,  64>;
+template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80,  80>;
+template [[host_name("kernel_flash_attn_ext_q5_1_dk96_dv96"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96,  96>;
+template [[host_name("kernel_flash_attn_ext_q5_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112, 112>;
+template [[host_name("kernel_flash_attn_ext_q5_1_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128, 128>;
+template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 192>;
+template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 128>;
+template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
+template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;
+
+template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40,  40>;
+template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64,  64>;
+template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80,  80>;
+template [[host_name("kernel_flash_attn_ext_q8_0_dk96_dv96"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96,  96>;
+template [[host_name("kernel_flash_attn_ext_q8_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112, 112>;
+template [[host_name("kernel_flash_attn_ext_q8_0_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128, 128>;
+template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 192>;
+template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 128>;
+template [[host_name("kernel_flash_attn_ext_q8_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256, 256>;
+template [[host_name("kernel_flash_attn_ext_q8_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 576, 512>;
 
 #undef FA_TYPES
 #undef FA_TYPES_BF
 
+constant bool FC_flash_attn_ext_vec_has_mask  [[function_constant(FC_FLASH_ATTN_EXT_VEC + 0)]];
+constant bool FC_flash_attn_ext_vec_has_sinks [[function_constant(FC_FLASH_ATTN_EXT_VEC + 1)]];
+constant bool FC_flash_attn_ext_vec_has_bias  [[function_constant(FC_FLASH_ATTN_EXT_VEC + 2)]];
+constant bool FC_flash_attn_ext_vec_has_scap  [[function_constant(FC_FLASH_ATTN_EXT_VEC + 3)]];
+
+//constant float FC_flash_attn_ext_vec_scale         [[function_constant(FC_FLASH_ATTN_EXT_VEC + 10)]];
+//constant float FC_flash_attn_ext_vec_max_bias      [[function_constant(FC_FLASH_ATTN_EXT_VEC + 11)]];
+//constant float FC_flash_attn_ext_vec_logit_softcap [[function_constant(FC_FLASH_ATTN_EXT_VEC + 12)]];
+
+constant int32_t FC_flash_attn_ext_vec_ns10 [[function_constant(FC_FLASH_ATTN_EXT_VEC + 20)]];
+constant int32_t FC_flash_attn_ext_vec_ns20 [[function_constant(FC_FLASH_ATTN_EXT_VEC + 21)]];
+constant int32_t FC_flash_attn_ext_vec_nsg  [[function_constant(FC_FLASH_ATTN_EXT_VEC + 22)]];
+constant int32_t FC_flash_attn_ext_vec_nwg  [[function_constant(FC_FLASH_ATTN_EXT_VEC + 23)]];
+
 template<
     typename q4_t,  // query types in shared memory
     typename k4_t,  // key types in shared memory
@@ -4788,63 +4993,86 @@ template<
     short DV,       // V head size
     short NE = 4,   // head elements per thread
     short Q  = 1,   // queries per threadgroup
-    short C  = 32>  // cache items per threadgroup
-kernel void kernel_flash_attn_ext_vec(
-        constant ggml_metal_kargs_flash_attn_ext & args,
+    short C  = 32,  // cache items per threadgroup
+    short NSG>      // number of simd groups
+void kernel_flash_attn_ext_vec_impl(
+        constant ggml_metal_kargs_flash_attn_ext_vec & args,
         device const char * q,
         device const char * k,
         device const char * v,
         device const char * mask,
         device const char * sinks,
         device       char * dst,
-        constant uint16_t & nwg,
         threadgroup  half * shmem_f16 [[threadgroup(0)]],
         uint3   tgpig[[threadgroup_position_in_grid]],
-        ushort3   ntg[[threads_per_threadgroup]],
         ushort  tiisg[[thread_index_in_simdgroup]],
         ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
     static_assert(DK % 32 == 0, "DK must be divisible by 32");
     static_assert(DV % 32 == 0, "DV must be divisible by 32");
 
-    const short nsg = ntg.y; // number of simdgroups
-    const short iwg = tgpig[2]%nwg;
+#define NWG  (FC_flash_attn_ext_vec_nwg)
+
+#define NS10 (FC_flash_attn_ext_vec_ns10)
+#define NS20 (FC_flash_attn_ext_vec_ns20)
+
+    const short iwg = tgpig[2]%NWG;
 
-    const int iq3 = tgpig[2]/nwg;
-    const int iq2 = tgpig[1];
-    const int iq1 = tgpig[0];
+    const ushort iq3 = tgpig[2]/NWG;
+    const ushort iq2 = tgpig[1];
+    const ushort iq1 = tgpig[0];
 
     constexpr short DK4 = DK/4;
     constexpr short DV4 = DV/4;
+
+    constexpr short PK  = PAD2(DK, 128);
+    constexpr short PK4 = PK/4;
+
+    constexpr short PV  = PAD2(DV, 128);
+    constexpr short PV4 = PV/4;
+
     constexpr short NW  = N_SIMDWIDTH;
     constexpr short NL  = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
     constexpr short SH  = 4*C;   // shared memory per simdgroup
 
-    const short T = DK + nsg*SH; // shared memory size per query in (half)
+    static_assert(DK4 % NL == 0, "DK4 must be divisible by NL");
+    static_assert(DV4 % NL == 0, "DV4 must be divisible by NL");
+
+    const short T = PK + NSG*SH; // shared memory size per query in (half)
 
-  //threadgroup q_t   * sq  = (threadgroup q_t   *) (shmem_f16 +                    0*DK); // holds the query data
-    threadgroup q4_t  * sq4 = (threadgroup q4_t  *) (shmem_f16 +                    0*DK); // same as above but in q4_t
-    threadgroup s_t   * ss  = (threadgroup s_t   *) (shmem_f16 +   sgitg*SH       + Q*DK); // scratch buffer for attention
-    threadgroup s4_t  * ss4 = (threadgroup s4_t  *) (shmem_f16 +   sgitg*SH       + Q*DK); // same as above but in s4_t
-    threadgroup float * sm  = (threadgroup float *) (shmem_f16 +   sgitg*SH + 2*C + Q*DK); // scratch buffer for mask
-    threadgroup o4_t  * sr4 = (threadgroup o4_t  *) (shmem_f16 + 2*sgitg*DV       + Q*T);  // scratch buffer for the results
+  //threadgroup q_t   * sq  = (threadgroup q_t   *) (shmem_f16 +                    0*PK); // holds the query data
+    threadgroup q4_t  * sq4 = (threadgroup q4_t  *) (shmem_f16 +                    0*PK); // same as above but in q4_t
+    threadgroup s_t   * ss  = (threadgroup s_t   *) (shmem_f16 +   sgitg*SH       + Q*PK); // scratch buffer for attention
+    threadgroup s4_t  * ss4 = (threadgroup s4_t  *) (shmem_f16 +   sgitg*SH       + Q*PK); // same as above but in s4_t
+    threadgroup half  * sm  = (threadgroup half  *) (shmem_f16 +   sgitg*SH + 2*C + Q*PK); // scratch buffer for mask
+    threadgroup o4_t  * so4 = (threadgroup o4_t  *) (shmem_f16 + 2*sgitg*PV       + Q*T);  // scratch buffer for the results
 
-    // store the result for all queries in local memory (the O matrix from the paper)
-    o4_t lo[DV4/NL];
+    // store the result for all queries in shared memory (the O matrix from the paper)
+    so4 += tiisg;
+
+    {
+        q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03;
+
+        const short ikv2 = iq2/(args.ne02/args.ne_12_2);
+        const short ikv3 = iq3/(args.ne03/args.ne_12_3);
+
+        k += ikv2*args.nb12 + ikv3*args.nb13;
+        v += ikv2*args.nb22 + ikv3*args.nb23;
+    }
 
     // load heads from Q to shared memory
-    device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
+    device const float4 * q4 = (device const float4 *) ((device const char *) q);
 
-    for (short i = tiisg; i < DK4; i += NW) {
-        if (iq1 < args.ne01) {
+    for (short i = tiisg; i < PK4; i += NW) {
+        if (iq1 < args.ne01 && i < DK4) {
             sq4[i] = (q4_t) q4[i];
         } else {
             sq4[i] = (q4_t) 0.0f;
         }
     }
 
-    // zero out lo
+    // zero out so
     for (short i = 0; i < DV4/NL; ++i) {
-        lo[i] = (o4_t) 0.0f;
+        so4[i*NL] = (o4_t) 0.0f;
     }
 
     // zero out shared memory SH
@@ -4856,28 +5084,19 @@ kernel void kernel_flash_attn_ext_vec(
 
     {
         float S = 0.0f;
-        float M = -__FLT_MAX__/2;
+        float M = -FLT_MAX/2;
 
         // thread indices inside the simdgroup
         const short tx = tiisg%NL;
         const short ty = tiisg/NL;
 
-        // broadcast kv
-        //const short rk2 = args.ne02/args.ne12;
-        //const short rk3 = args.ne03/args.ne13;
-
-        const short ikv2 = iq2/(args.ne02/args.ne_12_2);
-        const short ikv3 = iq3/(args.ne03/args.ne_12_3);
-
-        const bool has_mask = mask != q;
-
         // pointer to the mask
         device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
 
         float slope = 1.0f;
 
         // ALiBi
-        if (args.max_bias > 0.0f) {
+        if (FC_flash_attn_ext_vec_has_bias) {
             const short h = iq2;
 
             const float base = h < args.n_head_log2 ? args.m0 : args.m1;
@@ -4888,13 +5107,13 @@ kernel void kernel_flash_attn_ext_vec(
 
         // loop over the KV cache
         // each simdgroup handles blocks of Q rows and C columns
-        for (int ic0 = (int) iwg*C*nsg; ic0 < args.ne11; ic0 += (int) nwg*C*nsg) {
+        for (int ic0 = (int) iwg*C*NSG; ic0 < args.ne11; ic0 += (int) NWG*C*NSG) {
             const int ic = ic0 + C*sgitg;
             if (ic >= args.ne11) {
                 break;
             }
 
-            if (has_mask) {
+            if (FC_flash_attn_ext_vec_has_mask) {
                 sm[tiisg] = pm[ic + tiisg];
             }
 
@@ -4905,69 +5124,81 @@ kernel void kernel_flash_attn_ext_vec(
 
             // Q*K^T
             {
-                // each simdgroup processes 1 query and NE (NW/NL) head elements
-                for (short cc = 0; cc < C/NE; ++cc) {
-                    qk_t mqk = 0.0f;
+                device      const k4_t * pk4 = (device const k4_t *) ((device const char *) k + ic*args.nb11);
+                threadgroup const q4_t * pq4 = sq4;
 
-                    device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
+                pk4 += ty*NS10/4 + tx;
+                pq4 += tx;
 
-                    #pragma unroll(DK4/NL)
-                    for (short ii = 0; ii < DK4; ii += NL) {
-                        const short i = ii + tx;
+                qk_t mqk[C/NE] = { [ 0 ... C/NE - 1] = 0.0f };
+
+                // each simdgroup processes 1 query and NE (NW/NL) cache elements
+                FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) {
+                    if (is_same<kd4_t, k4_t>::value) {
+                        FOR_UNROLL (short ii = 0; ii < DK4/NL; ++ii) {
+                            mqk[cc] += dot((float4) pk4[cc*NE*NS10/4 +  ii*NL], (float4) pq4[ii*NL]);
+                        }
+                    } else {
+                        device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11));
 
                         k4_t mk;
-                        deq_k_t4(pk + i/nl_k, i%nl_k, mk);
 
-                        // note: this is less precise than the version below
-                        //mqka[0] += dot(mq[0], mk[0]);
-                        //mqka[1] += dot(mq[1], mk[1]);
-                        //mqka[2] += dot(mq[2], mk[2]);
-                        //mqka[3] += dot(mq[3], mk[3]);
+                        FOR_UNROLL (short ii = 0; ii < DK4/NL; ++ii) {
+                            const short i = ii*NL + tx;
 
-                        //q4x4_t mq = sq4x4[i];
-                        //mqka[0] += dot((float4) mq[0], (float4) mk[0]);
-                        //mqka[1] += dot((float4) mq[1], (float4) mk[1]);
-                        //mqka[2] += dot((float4) mq[2], (float4) mk[2]);
-                        //mqka[3] += dot((float4) mq[3], (float4) mk[3]);
+                            deq_k_t4(pk + i/nl_k, i%nl_k, mk);
 
-                        mqk += dot((float4) mk, (float4) sq4[i]);
+                            mqk[cc] += dot((float4) mk, (float4) sq4[i]);
+                        }
                     }
 
-                    static_assert(NE > 1, "NE must be > 1"); // note: not sure why NE == 1 fails
+                    if (NE == 1) {
+                        mqk[cc] = simd_sum(mqk[cc]);
+                    } else {
+                        // simdgroup reduce (NE = 4)
+                        // [ 0 ..  7] -> [ 0]
+                        // [ 8 .. 15] -> [ 8]
+                        // [16 .. 23] -> [16]
+                        // [24 .. 31] -> [24]
+                        if (NE <= 1) {
+                            mqk[cc] += simd_shuffle_down(mqk[cc], 16);
+                        }
+                        if (NE <= 2) {
+                            mqk[cc] += simd_shuffle_down(mqk[cc],  8);
+                        }
+                        if (NE <= 4) {
+                            mqk[cc] += simd_shuffle_down(mqk[cc],  4);
+                        }
+                        if (NE <= 8) {
+                            mqk[cc] += simd_shuffle_down(mqk[cc],  2);
+                        }
+                        if (NE <= 16) {
+                            mqk[cc] += simd_shuffle_down(mqk[cc],  1);
+                        }
 
-                    // simdgroup reduce (NE = 4)
-                    // [ 0 ..  7] -> [ 0]
-                    // [ 8 .. 15] -> [ 8]
-                    // [16 .. 23] -> [16]
-                    // [24 .. 31] -> [24]
-                    if (NE <= 1) {
-                        mqk += simd_shuffle_down(mqk, 16);
-                    }
-                    if (NE <= 2) {
-                        mqk += simd_shuffle_down(mqk,  8);
-                    }
-                    if (NE <= 4) {
-                        mqk += simd_shuffle_down(mqk,  4);
-                    }
-                    if (NE <= 8) {
-                        mqk += simd_shuffle_down(mqk,  2);
+                        // broadcast
+                        mqk[cc] = simd_shuffle(mqk[cc], NL*ty);
                     }
-                    if (NE <= 16) {
-                        mqk += simd_shuffle_down(mqk,  1);
-                    }
-
-                    // mqk = mqk*scale + mask*slope
-                    if (tx == 0) {
-                        mqk *= args.scale;
+                }
 
-                        if (args.logit_softcap != 0.0f) {
-                            mqk = args.logit_softcap*precise::tanh(mqk);
-                        }
+                if (FC_flash_attn_ext_vec_has_mask &&
+                   !FC_flash_attn_ext_vec_has_scap &&
+                   !FC_flash_attn_ext_vec_has_bias) {
+                    ss[NE*tx + ty] = fma(mqk[tx], args.scale, (qk_t) sm[NE*tx + ty]);
+                } else {
+                    mqk[tx] *= args.scale;
 
-                        mqk += sm[NE*cc + ty]*slope;
+                    if (FC_flash_attn_ext_vec_has_scap) {
+                        mqk[tx] = args.logit_softcap*precise::tanh(mqk[tx]);
+                    }
 
-                        ss[NE*cc + ty] = mqk;
+                    if (FC_flash_attn_ext_vec_has_bias) {
+                        mqk[tx] += (qk_t) sm[NE*tx + ty]*slope;
+                    } else {
+                        mqk[tx] += (qk_t) sm[NE*tx + ty];
                     }
+
+                    ss[NE*tx + ty] = mqk[tx];
                 }
             }
 
@@ -4989,9 +5220,10 @@ kernel void kernel_flash_attn_ext_vec(
                 ss[tiisg] = vs;
 
                 // O = diag(ms)*O
-                #pragma unroll(DV4/NL)
-                for (short ii = 0; ii < DV4; ii += NL) {
-                    lo[ii/NL] *= ms;
+                if ((DV4/NL % NW == 0) || ty == 0) {
+                    FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
+                        so4[ii*NL] *= ms;
+                    }
                 }
             }
 
@@ -4999,26 +5231,84 @@ kernel void kernel_flash_attn_ext_vec(
 
             // O = O + (Q*K^T)*V
             {
-                //#pragma unroll(C/NE)
-                for (short cc = 0; cc < C/NE; ++cc) {
-                    device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
+                o4_t lo[DV4/NL];
+                FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
+                    lo[ii] = 0.0f;
+                }
 
-                    const s4_t ms(ss[NE*cc + ty]);
+                if (is_same<vd4_t, v4_t>::value) {
+                    device const v4_t * pv4 = (device const v4_t *) ((device const char *) v + ic*args.nb21);
 
-                    #pragma unroll(DV4/NL)
-                    for (short ii = 0; ii < DV4; ii += NL) {
-                        const short i = ii + tx;
+                    pv4 += ty*NS20/4 + tx;
 
-                        v4_t mv;
-                        deq_v_t4(pv4 + i/nl_v, i%nl_v, mv);
+                    const auto sst = ss + ty;
 
-                        lo[ii/NL] += o4_t(float4(mv)*float4(ms));
+                    FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) {
+                        FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
+                            lo[ii] += o4_t(float4(pv4[cc*NE*NS20/4 + ii*NL])*float4(sst[cc*NE]));
+                        }
+                    }
+                } else {
+                    FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) {
+                        device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21));
+
+                        FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
+                            const short i = ii*NL + tx;
+
+                            v4_t mv;
+                            deq_v_t4(pv4 + i/nl_v, i%nl_v, mv);
+
+                            lo[ii] += o4_t(float4(mv)*float4(ss[NE*cc + ty]));
+                        }
+                    }
+                }
+
+                FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
+                    if (NE > 1) {
+                        lo[ii][0] += simd_shuffle_down(lo[ii][0], 16);
+                        lo[ii][1] += simd_shuffle_down(lo[ii][1], 16);
+                        lo[ii][2] += simd_shuffle_down(lo[ii][2], 16);
+                        lo[ii][3] += simd_shuffle_down(lo[ii][3], 16);
+                    }
+
+                    if (NE > 2) {
+                        lo[ii][0] += simd_shuffle_down(lo[ii][0],  8);
+                        lo[ii][1] += simd_shuffle_down(lo[ii][1],  8);
+                        lo[ii][2] += simd_shuffle_down(lo[ii][2],  8);
+                        lo[ii][3] += simd_shuffle_down(lo[ii][3],  8);
+                    }
+
+                    if (NE > 4) {
+                        lo[ii][0] += simd_shuffle_down(lo[ii][0],  4);
+                        lo[ii][1] += simd_shuffle_down(lo[ii][1],  4);
+                        lo[ii][2] += simd_shuffle_down(lo[ii][2],  4);
+                        lo[ii][3] += simd_shuffle_down(lo[ii][3],  4);
+                    }
+
+                    if (NE > 8) {
+                        lo[ii][0] += simd_shuffle_down(lo[ii][0],  2);
+                        lo[ii][1] += simd_shuffle_down(lo[ii][1],  2);
+                        lo[ii][2] += simd_shuffle_down(lo[ii][2],  2);
+                        lo[ii][3] += simd_shuffle_down(lo[ii][3],  2);
+                    }
+
+                    if (NE > 16) {
+                        lo[ii][0] += simd_shuffle_down(lo[ii][0],  1);
+                        lo[ii][1] += simd_shuffle_down(lo[ii][1],  1);
+                        lo[ii][2] += simd_shuffle_down(lo[ii][2],  1);
+                        lo[ii][3] += simd_shuffle_down(lo[ii][3],  1);
+                    }
+                }
+
+                if ((DV4/NL % NW == 0) || ty == 0) {
+                    FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
+                        so4[ii*NL] += lo[ii];
                     }
                 }
             }
         }
 
-        if (sinks != q && sgitg == 0 && iwg == 0) {
+        if (FC_flash_attn_ext_vec_has_sinks && sgitg == 0 && iwg == 0) {
             const float m = M;
             const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;
 
@@ -5029,9 +5319,10 @@ kernel void kernel_flash_attn_ext_vec(
 
             S = S*ms + simd_sum(vs);
 
-#pragma unroll(DV4/NL)
-            for (short ii = 0; ii < DV4; ii += NL) {
-                lo[ii/NL] *= ms;
+            if ((DV4/NL % NW == 0) || ty == 0) {
+                FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
+                    so4[ii*NL] *= ms;
+                }
             }
         }
 
@@ -5042,63 +5333,12 @@ kernel void kernel_flash_attn_ext_vec(
         }
     }
 
-    // simdgroup reduce (NE = 4)
-    // [ 0,  8, 16, 24] -> [ 0]
-    // [ 1,  9, 17, 25] -> [ 1]
-    // [ 2, 10, 18, 26] -> [ 2]
-    // [ 3, 11, 19, 27] -> [ 3]
-    // [ 4, 12, 20, 28] -> [ 4]
-    // [ 5, 13, 21, 29] -> [ 5]
-    // [ 6, 14, 22, 30] -> [ 6]
-    // [ 7, 15, 23, 31] -> [ 7]
-    for (short ii = 0; ii < DV4; ii += NL) {
-        if (NE > 1) {
-            lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16);
-            lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16);
-            lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16);
-            lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16);
-        }
-
-        if (NE > 2) {
-            lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  8);
-            lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  8);
-            lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  8);
-            lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  8);
-        }
-
-        if (NE > 4) {
-            lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  4);
-            lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  4);
-            lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  4);
-            lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  4);
-        }
-
-        if (NE > 8) {
-            lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  2);
-            lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  2);
-            lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  2);
-            lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  2);
-        }
-
-        if (NE > 16) {
-            lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  1);
-            lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  1);
-            lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  1);
-            lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  1);
-        }
-    }
-
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    // store results to shared memory
-    for (short i = tiisg; i < DV4; i += NL) {
-        sr4[i] = lo[i/NL];
-    }
+    so4 -= tiisg;
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     // parallel reduce
-    for (short r = nsg/2; r > 0; r >>= 1) {
+    for (short r = NSG/2; r > 0; r >>= 1) {
         if (sgitg < r) {
             const float S0 = ss[           0];
             const float S1 = ss[r*(SH/2) + 0];
@@ -5120,7 +5360,7 @@ kernel void kernel_flash_attn_ext_vec(
 
             // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
             for (short i = tiisg; i < DV4; i += NW) {
-                sr4[i] = sr4[i]*ms0 + sr4[i + r*DV4]*ms1;
+                so4[i] = so4[i]*ms0 + so4[i + r*PV4]*ms1;
             }
         }
 
@@ -5133,21 +5373,73 @@ kernel void kernel_flash_attn_ext_vec(
         const int64_t rid   = iq3*args.ne2*args.ne1 + iq2 + iq1*args.ne1;
 
         device float4 * dst4 = (device float4 *) dst;
-        device float  * dst1 = (device float  *) dst + nrows*DV*nwg; // the S and M are stored after the results
+        device float  * dst1 = (device float  *) dst + nrows*DV*NWG; // the S and M are stored after the results
 
-        const float S = nwg == 1 ? 1.0f/ss[0] : 1.0f;
+        const float S = NWG == 1 ? 1.0f/ss[0] : 1.0f;
 
         // interleave the workgroup data
         for (short i = tiisg; i < DV4; i += NW) {
-            dst4[rid*DV4*nwg + nwg*i + iwg] = (float4) sr4[i]*S;
+            dst4[rid*DV4*NWG + NWG*i + iwg] = (float4) so4[i]*S;
         }
 
         // store S and M
-        if (nwg > 1 && tiisg == 0) {
-            dst1[rid*(2*nwg) + 2*iwg + 0] = ss[0];
-            dst1[rid*(2*nwg) + 2*iwg + 1] = ss[1];
+        if (NWG > 1) {
+            if (tiisg == 0) {
+                dst1[rid*(2*NWG) + 2*iwg + 0] = ss[0];
+                dst1[rid*(2*NWG) + 2*iwg + 1] = ss[1];
+            }
         }
     }
+
+#undef NWG
+#undef NS10
+#undef NS20
+}
+
+template<
+    typename q4_t,  // query types in shared memory
+    typename k4_t,  // key types in shared memory
+    typename v4_t,  // value types in shared memory
+    typename qk_t,  // Q*K types
+    typename s_t,   // soft-max types
+    typename s4_t,
+    typename o4_t,  // attention accumulation types
+    typename kd4_t, // key type in device memory
+    short nl_k,
+    void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &),
+    typename vd4_t, // value type in device memory
+    short nl_v,
+    void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
+    short DK,       // K head size
+    short DV,       // V head size
+    short NE = 4,   // head elements per thread
+    short Q  = 1,   // queries per threadgroup
+    short C  = 32>  // cache items per threadgroup
+kernel void kernel_flash_attn_ext_vec(
+        constant ggml_metal_kargs_flash_attn_ext_vec & args,
+        device const char * q,
+        device const char * k,
+        device const char * v,
+        device const char * mask,
+        device const char * sinks,
+        device       char * dst,
+        threadgroup  half * shmem_f16 [[threadgroup(0)]],
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
+#define FWD_TMPL q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t, kd4_t, nl_k, deq_k_t4, vd4_t, nl_v, deq_v_t4, DK, DV, NE, Q, C
+#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg
+    switch (FC_flash_attn_ext_vec_nsg) {
+      // note: disabled cases to reduce library load time
+        case 1:  kernel_flash_attn_ext_vec_impl<FWD_TMPL,  1>(FWD_ARGS); break;
+        case 2:  kernel_flash_attn_ext_vec_impl<FWD_TMPL,  2>(FWD_ARGS); break;
+        case 4:  kernel_flash_attn_ext_vec_impl<FWD_TMPL,  4>(FWD_ARGS); break;
+      //case 8:  kernel_flash_attn_ext_vec_impl<FWD_TMPL,  8>(FWD_ARGS); break;
+      //case 16: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 16>(FWD_ARGS); break;
+      //case 32: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 32>(FWD_ARGS); break;
+    }
+#undef FWD_TMPL
+#undef FWD_ARGS
 }
 
 // note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem
@@ -5163,111 +5455,120 @@ kernel void kernel_flash_attn_ext_vec(
 
 typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
 
-template [[host_name("kernel_flash_attn_ext_vec_f16_h64")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4,             1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  64, 64, 8>;
+template [[host_name("kernel_flash_attn_ext_vec_f16_dk64_dv64")]]    kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4,      1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  64, 64, 2>;
 #if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_vec_bf16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4,           1, dequantize_bf16_t4, bfloat4,     1, dequantize_bf16_t4, 64, 64, 8>;
+template [[host_name("kernel_flash_attn_ext_vec_bf16_dk64_dv64")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4,    1, dequantize_bf16_t4, bfloat4,     1, dequantize_bf16_t4, 64, 64, 2>;
 #endif
-template [[host_name("kernel_flash_attn_ext_vec_q4_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0,        8, dequantize_q4_0_t4, block_q4_0,  8, dequantize_q4_0_t4, 64, 64, 8>;
-template [[host_name("kernel_flash_attn_ext_vec_q4_1_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1,        8, dequantize_q4_1_t4, block_q4_1,  8, dequantize_q4_1_t4, 64, 64, 8>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0,        8, dequantize_q5_0_t4, block_q5_0,  8, dequantize_q5_0_t4, 64, 64, 8>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_1_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1,        8, dequantize_q5_1_t4, block_q5_1,  8, dequantize_q5_1_t4, 64, 64, 8>;
-template [[host_name("kernel_flash_attn_ext_vec_q8_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0,        8, dequantize_q8_0_t4, block_q8_0,  8, dequantize_q8_0_t4, 64, 64, 8>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk64_dv64")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0,  8, dequantize_q4_0_t4, 64, 64, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk64_dv64")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1,  8, dequantize_q4_1_t4, 64, 64, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk64_dv64")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0,  8, dequantize_q5_0_t4, 64, 64, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk64_dv64")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1,  8, dequantize_q5_1_t4, 64, 64, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk64_dv64")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0,  8, dequantize_q8_0_t4, 64, 64, 2>;
 
-template [[host_name("kernel_flash_attn_ext_vec_f16_h96")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4,             1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  96, 96, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_f16_dk96_dv96")]]    kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4,      1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  96, 96, 4>;
 #if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_vec_bf16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4,           1, dequantize_bf16_t4, bfloat4,     1, dequantize_bf16_t4, 96, 96, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_bf16_dk96_dv96")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4,    1, dequantize_bf16_t4, bfloat4,     1, dequantize_bf16_t4, 96, 96, 4>;
 #endif
-template [[host_name("kernel_flash_attn_ext_vec_q4_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0,        8, dequantize_q4_0_t4, block_q4_0,  8, dequantize_q4_0_t4, 96, 96, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_q4_1_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1,        8, dequantize_q4_1_t4, block_q4_1,  8, dequantize_q4_1_t4, 96, 96, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0,        8, dequantize_q5_0_t4, block_q5_0,  8, dequantize_q5_0_t4, 96, 96, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_1_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1,        8, dequantize_q5_1_t4, block_q5_1,  8, dequantize_q5_1_t4, 96, 96, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_q8_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0,        8, dequantize_q8_0_t4, block_q8_0,  8, dequantize_q8_0_t4, 96, 96, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk96_dv96")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0,  8, dequantize_q4_0_t4, 96, 96, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk96_dv96")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1,  8, dequantize_q4_1_t4, 96, 96, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk96_dv96")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0,  8, dequantize_q5_0_t4, 96, 96, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk96_dv96")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1,  8, dequantize_q5_1_t4, 96, 96, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk96_dv96")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0,  8, dequantize_q8_0_t4, 96, 96, 4>;
 
-template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4,             1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  128, 128, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_f16_dk128_dv128")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4,      1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  128, 128, 1>;
 #if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4,           1, dequantize_bf16_t4, bfloat4,     1, dequantize_bf16_t4, 128, 128, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_bf16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4,    1, dequantize_bf16_t4, bfloat4,     1, dequantize_bf16_t4, 128, 128, 1>;
 #endif
-template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0,        8, dequantize_q4_0_t4, block_q4_0,  8, dequantize_q4_0_t4, 128, 128, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1,        8, dequantize_q4_1_t4, block_q4_1,  8, dequantize_q4_1_t4, 128, 128, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0,        8, dequantize_q5_0_t4, block_q5_0,  8, dequantize_q5_0_t4, 128, 128, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1,        8, dequantize_q5_1_t4, block_q5_1,  8, dequantize_q5_1_t4, 128, 128, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0,        8, dequantize_q8_0_t4, block_q8_0,  8, dequantize_q8_0_t4, 128, 128, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0,  8, dequantize_q4_0_t4, 128, 128, 1>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1,  8, dequantize_q4_1_t4, 128, 128, 1>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0,  8, dequantize_q5_0_t4, 128, 128, 1>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1,  8, dequantize_q5_1_t4, 128, 128, 1>;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0,  8, dequantize_q8_0_t4, 128, 128, 1>;
 
-template [[host_name("kernel_flash_attn_ext_vec_f16_h192")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4,             1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  192, 192, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv192")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4,      1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  192, 192, 2>;
 #if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_vec_bf16_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4,           1, dequantize_bf16_t4, bfloat4,     1, dequantize_bf16_t4, 192, 192, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4,    1, dequantize_bf16_t4, bfloat4,     1, dequantize_bf16_t4, 192, 192, 2>;
 #endif
-template [[host_name("kernel_flash_attn_ext_vec_q4_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0,        8, dequantize_q4_0_t4, block_q4_0,  8, dequantize_q4_0_t4, 192, 192, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_q4_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1,        8, dequantize_q4_1_t4, block_q4_1,  8, dequantize_q4_1_t4, 192, 192, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0,        8, dequantize_q5_0_t4, block_q5_0,  8, dequantize_q5_0_t4, 192, 192, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1,        8, dequantize_q5_1_t4, block_q5_1,  8, dequantize_q5_1_t4, 192, 192, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_q8_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0,        8, dequantize_q8_0_t4, block_q8_0,  8, dequantize_q8_0_t4, 192, 192, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0,  8, dequantize_q4_0_t4, 192, 192, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1,  8, dequantize_q4_1_t4, 192, 192, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0,  8, dequantize_q5_0_t4, 192, 192, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1,  8, dequantize_q5_1_t4, 192, 192, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0,  8, dequantize_q8_0_t4, 192, 192, 2>;
 
-template [[host_name("kernel_flash_attn_ext_vec_f16_hk192_hv128")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4,      1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  192, 128, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv128")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4,      1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  192, 128, 2>;
 #if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_vec_bf16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4,    1, dequantize_bf16_t4, bfloat4,     1, dequantize_bf16_t4, 192, 128, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4,    1, dequantize_bf16_t4, bfloat4,     1, dequantize_bf16_t4, 192, 128, 2>;
 #endif
-template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0,  8, dequantize_q4_0_t4, 192, 128, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1,  8, dequantize_q4_1_t4, 192, 128, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0,  8, dequantize_q5_0_t4, 192, 128, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1,  8, dequantize_q5_1_t4, 192, 128, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0,  8, dequantize_q8_0_t4, 192, 128, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0,  8, dequantize_q4_0_t4, 192, 128, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1,  8, dequantize_q4_1_t4, 192, 128, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0,  8, dequantize_q5_0_t4, 192, 128, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1,  8, dequantize_q5_1_t4, 192, 128, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0,  8, dequantize_q8_0_t4, 192, 128, 2>;
 
-template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4,             1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  256, 256, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_f16_dk256_dv256")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4,      1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  256, 256, 1>;
 #if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4,           1, dequantize_bf16_t4, bfloat4,     1, dequantize_bf16_t4, 256, 256, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_bf16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4,    1, dequantize_bf16_t4, bfloat4,     1, dequantize_bf16_t4, 256, 256, 1>;
 #endif
-template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0,        8, dequantize_q4_0_t4, block_q4_0,  8, dequantize_q4_0_t4, 256, 256, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1,        8, dequantize_q4_1_t4, block_q4_1,  8, dequantize_q4_1_t4, 256, 256, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0,        8, dequantize_q5_0_t4, block_q5_0,  8, dequantize_q5_0_t4, 256, 256, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1,        8, dequantize_q5_1_t4, block_q5_1,  8, dequantize_q5_1_t4, 256, 256, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0,        8, dequantize_q8_0_t4, block_q8_0,  8, dequantize_q8_0_t4, 256, 256, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0,  8, dequantize_q4_0_t4, 256, 256, 1>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1,  8, dequantize_q4_1_t4, 256, 256, 1>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0,  8, dequantize_q5_0_t4, 256, 256, 1>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1,  8, dequantize_q5_1_t4, 256, 256, 1>;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0,  8, dequantize_q8_0_t4, 256, 256, 1>;
 
-template [[host_name("kernel_flash_attn_ext_vec_f16_hk576_hv512")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4,      1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  576, 512, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4,      1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  576, 512, 2>;
 #if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_vec_bf16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4,    1, dequantize_bf16_t4, bfloat4,     1, dequantize_bf16_t4, 576, 512, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_bf16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4,    1, dequantize_bf16_t4, bfloat4,     1, dequantize_bf16_t4, 576, 512, 2>;
 #endif
-template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0,  8, dequantize_q4_0_t4, 576, 512, 2>;
-template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1,  8, dequantize_q4_1_t4, 576, 512, 2>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0,  8, dequantize_q5_0_t4, 576, 512, 2>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1,  8, dequantize_q5_1_t4, 576, 512, 2>;
-template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0,  8, dequantize_q8_0_t4, 576, 512, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0,  8, dequantize_q4_0_t4, 576, 512, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1,  8, dequantize_q4_1_t4, 576, 512, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0,  8, dequantize_q5_0_t4, 576, 512, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1,  8, dequantize_q5_1_t4, 576, 512, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0,  8, dequantize_q8_0_t4, 576, 512, 2>;
 
 #undef FA_TYPES
 
-kernel void kernel_flash_attn_ext_reduce(
-        constant ggml_metal_kargs_flash_attn_ext_reduce & args,
+constant int32_t FC_flash_attn_ext_vec_reduce_DV  [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 0)]];
+constant int32_t FC_flash_attn_ext_vec_reduce_NWG [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 1)]];
+
+kernel void kernel_flash_attn_ext_vec_reduce(
+        constant ggml_metal_kargs_flash_attn_ext_vec_reduce & args,
         device  const char * htmp,
         device        char * dst,
         uint   tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+#define NWG (FC_flash_attn_ext_vec_reduce_NWG)
+#define DV  (FC_flash_attn_ext_vec_reduce_DV)
+
     const uint64_t rid = tgpig;
 
-    const short nwg = 32;
     const short iwg = tiisg;
-    const short DV  = args.ne20;
-    const short DV4 = DV/4;
 
-    device const float4 * htmp4 = (device const float4 *) htmp + rid*DV4*nwg;
-    device const float  * ss    = (device const float  *) htmp + (uint64_t)args.nrows*DV*nwg;
-    device       float4 * dst4  = (device       float4 *) dst  + rid*DV4;
+    device const float  * ss    = (device const float  *) htmp + (uint64_t)args.nrows*DV*NWG;
 
-    float S = ss[rid*(2*nwg) + 2*iwg + 0];
-    float M = ss[rid*(2*nwg) + 2*iwg + 1];
+    float S = ss[rid*(2*NWG) + 2*iwg + 0];
+    float M = ss[rid*(2*NWG) + 2*iwg + 1];
 
     const float m  = simd_max(M);
     const float ms = exp(M - m);
 
     S = 1.0f/simd_sum(S*ms);
 
-    for (int i = sgitg; i < DV4; i += nwg) {
-        const float4 v = simd_sum(htmp4[i*nwg + iwg]*ms);
+    const short DV4 = DV/4;
+
+    device const float4 * htmp4 = (device const float4 *) htmp + rid*DV4*NWG;
+    device       float4 * dst4  = (device       float4 *) dst  + rid*DV4;
+
+    for (short i = sgitg; i < DV4; i += NWG) {
+        const float4 v = simd_sum(htmp4[i*NWG + iwg]*ms);
 
         if (iwg == 0) {
             dst4[i] = v*S;
         }
     }
+
+#undef NWG
+#undef DV
 }
 
 template<typename T>
@@ -7397,7 +7698,7 @@ kernel void kernel_set_rows_f(
     const int32_t i10 = i01;
     const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
 
-    device T     * dst_row = (      device T     *) ((      device char *) dst  +  i1*args.nb1  + i02*args.nb2  + i03*args.nb3);
+          device T     * dst_row = (      device T     *) ((      device char *) dst  +  i1*args.nb1  + i02*args.nb2  + i03*args.nb3);
     const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
 
     for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
@@ -7496,18 +7797,20 @@ kernel void kernel_mul_mm(
 
         #pragma unroll(4)
         for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) {
+            simdgroup_barrier(mem_flags::mem_none);
+
             #pragma unroll(4)
             for (short i = 0; i < 4; i++) {
                 simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
             }
 
-            simdgroup_barrier(mem_flags::mem_none);
-
             #pragma unroll(2)
             for (short i = 0; i < 2; i++) {
                 simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
             }
 
+            simdgroup_barrier(mem_flags::mem_none);
+
             #pragma unroll(8)
             for (short i = 0; i < 8; i++){
                 simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
index e8e6e1000b2eb108dd2ecedffed954891c907fe4..cca86a8ce8842b2adb95033ad2e68501b317c10a 100644 (file)
@@ -6501,8 +6501,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, v));
     }
 
-    for (int hsk : { 40, 64, 80, 128, 192, 256, 576 }) {
-        for (int hsv : { 40, 64, 80, 128, 192, 256, 512 }) {
+    for (int hsk : { 40, 64, 80, 96, 128, 192, 256, 576 }) {
+        for (int hsv : { 40, 64, 80, 96, 128, 192, 256, 512 }) {
             if (hsk != 192 && hsk != 576 && hsk != hsv) continue;
             if (hsk == 192 && (hsv != 128 && hsv != 192)) continue;
             if (hsk == 576 && hsv != 512) continue; // DeepSeek MLA