git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
metal : optimize FA kernels (llama/10171)
author    Georgi Gerganov <redacted>
Fri, 8 Nov 2024 11:47:22 +0000 (13:47 +0200)
committer Georgi Gerganov <redacted>
Fri, 15 Nov 2024 13:21:04 +0000 (15:21 +0200)
* ggml : add ggml_flash_attn_ext_get_prec

* metal : use F16 precision in FA kernels

ggml-ci

* metal : minor clean-up

* metal : compile-guard bf16 FA kernels

ggml-ci

* build : remove obsolete compile flag [no ci]

* metal : prevent int overflows [no ci]

* cuda : disable BF16 FA

ggml-ci

* metal : fix BF16 requirement for FA kernels

ggml-ci

* make : clean-up [no ci]

ggml/include/ggml.h
ggml/src/ggml-cuda.cu
ggml/src/ggml-cuda/fattn.cu
ggml/src/ggml-metal.m
ggml/src/ggml-metal.metal
ggml/src/ggml.c

ggml/include/ggml.h
index 0d143d2fe0a20103e8feb99835bf3860fdd09a1c..73ede181331ed14d957db17383316be695d542bd 100644 (file)
@@ -1746,6 +1746,9 @@ extern "C" {
             struct ggml_tensor * a,
             enum ggml_prec       prec);
 
+    GGML_API enum ggml_prec ggml_flash_attn_ext_get_prec(
+            const struct ggml_tensor * a);
+
     // TODO: needs to be adapted to ggml_flash_attn_ext
     GGML_API struct ggml_tensor * ggml_flash_attn_back(
            struct ggml_context * ctx,
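For context, the pair of accessors is used as follows: graph-building code can opt a flash-attention node into higher precision with the existing ggml_flash_attn_ext_set_prec setter, and backends now query that choice through the new getter instead of reading op_params directly. A minimal usage sketch (tensor names illustrative, not from this commit):

    struct ggml_tensor * kqv = ggml_flash_attn_ext(ctx, q, k, v, mask,
                                                   scale, max_bias, logit_softcap);

    // request F32 accumulation for this node (after this change, the Metal
    // kernels default to F16)
    ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);

    // backend dispatch:
    if (ggml_flash_attn_ext_get_prec(kqv) == GGML_PREC_DEFAULT) {
        // fast half-precision kernel
    } else {
        // full-precision fallback
    }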
ggml/src/ggml-cuda.cu
index e27c8e87d5023ab4e45fb4ed4c13d1f1d69e15b1..357cee660cd38f2ea80840ed21c9cb63056cc9e8 100644 (file)
@@ -3159,6 +3159,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
 #ifndef FLASH_ATTN_AVAILABLE
             return false;
 #endif
+            if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
+                return false;
+            }
             if (op->src[0]->ne[0] ==  64 && op->src[1]->type == GGML_TYPE_F16) {
                 return true;
             }
ggml/src/ggml-cuda/fattn.cu
index 83e5589a1cc244e3182a183f580b87fb13d0adb6..0e7ebbc5393523a823e48e6b3d16cff88d0b6c83 100644 (file)
@@ -13,9 +13,9 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q   = dst->src[0];
 
-    const int32_t precision = KQV->op_params[3];
+    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
-    if (precision != GGML_PREC_DEFAULT) {
+    if (prec != GGML_PREC_DEFAULT) {
         if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
             constexpr int cols_per_block = 16;
             switch (Q->ne[0]) {
@@ -301,11 +301,11 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
 
     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
-    const int32_t precision = KQV->op_params[3];
+    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
     // On AMD the tile kernels perform poorly, use the vec kernel instead:
     if (cc >= CC_OFFSET_AMD) {
-        if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
+        if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
         } else {
             ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
@@ -332,7 +332,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     }
 
     if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
-        if (precision == GGML_PREC_DEFAULT) {
+        if (prec == GGML_PREC_DEFAULT) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
             return;
         } else if(Q->ne[0] <= 128) {
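The ggml.c side of the getter is not shown in this excerpt, but since both CUDA call sites above previously read KQV->op_params[3] directly, it is presumably a typed accessor over that same op_params slot, along the lines of:

    // hedged sketch; the real definition lives in ggml/src/ggml.c
    enum ggml_prec ggml_flash_attn_ext_get_prec(
            const struct ggml_tensor * a) {
        GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);

        return (enum ggml_prec) ggml_get_op_params_i32(a, 3);
    }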
ggml/src/ggml-metal.m
index f13adee38435062a0978449a76de84916f02ca81..e19397fd2de70702c4a8807ccd4c7a36ec3aaa94 100644 (file)
@@ -269,6 +269,12 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
@@ -300,12 +306,14 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
@@ -585,6 +593,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
             struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
             id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
             kernel->pipeline = [device newComputePipelineStateWithFunction:metal_function error:&error]; \
+            GGML_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
+                    (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
+                    (int) kernel->pipeline.threadExecutionWidth); \
             [metal_function release]; \
             if (error) { \
                 GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
@@ -777,6 +788,12 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,       flash_attn_ext_f16_h112,        has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,       flash_attn_ext_f16_h128,        has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,       flash_attn_ext_f16_h256,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,       flash_attn_ext_bf16_h64,        has_simdgroup_mm && has_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,       flash_attn_ext_bf16_h80,        has_simdgroup_mm && has_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,       flash_attn_ext_bf16_h96,        has_simdgroup_mm && has_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112,      flash_attn_ext_bf16_h112,       has_simdgroup_mm && has_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128,      flash_attn_ext_bf16_h128,       has_simdgroup_mm && has_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,      flash_attn_ext_bf16_h256,       has_simdgroup_mm && has_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,       flash_attn_ext_q4_0_h64,        has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,       flash_attn_ext_q4_0_h80,        has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,       flash_attn_ext_q4_0_h96,        has_simdgroup_mm);
@@ -808,12 +825,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,      flash_attn_ext_q8_0_h128,       has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,      flash_attn_ext_q8_0_h256,       has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,   flash_attn_ext_vec_f16_h128,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,  flash_attn_ext_vec_bf16_h128,   has_simdgroup_reduction && has_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,  flash_attn_ext_vec_q4_0_h128,   has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,  flash_attn_ext_vec_q4_1_h128,   has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,  flash_attn_ext_vec_q5_0_h128,   has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,  flash_attn_ext_vec_q5_1_h128,   has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,  flash_attn_ext_vec_q8_0_h128,   has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,   flash_attn_ext_vec_f16_h256,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256,  flash_attn_ext_vec_bf16_h256,   has_simdgroup_reduction && has_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,  flash_attn_ext_vec_q4_0_h256,   has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,  flash_attn_ext_vec_q4_1_h256,   has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,  flash_attn_ext_vec_q5_0_h256,   has_simdgroup_reduction);
@@ -1111,7 +1130,7 @@ static void ggml_metal_encode_node(
     const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
     const uint64_t nb21 = src2 ? src2->nb[1] : 0;
     const uint64_t nb22 = src2 ? src2->nb[2] : 0;
-    const uint64_t nb23 = src2 ? src2->nb[3] : 0;
+    const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
 
     const int64_t  ne0  =  dst ?  dst->ne[0] : 0;
     const int64_t  ne1  =  dst ?  dst->ne[1] : 0;
@@ -3033,6 +3052,23 @@ static void ggml_metal_encode_node(
                                               }
                                 }
                             } break;
+                        case GGML_TYPE_BF16:
+                            {
+                                switch (ne00) {
+                                    case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
+                                    case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break;
+                                    case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break;
+                                    case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break;
+                                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break;
+                                    case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break;
+                                    default:
+                                              {
+                                                  GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                                  GGML_LOG_ERROR("add template specialization for this size\n");
+                                                  GGML_ABORT("add template specialization for this size");
+                                              }
+                                }
+                            } break;
                         case GGML_TYPE_Q4_0:
                             {
                                 switch (ne00) {
@@ -3133,6 +3169,7 @@ static void ggml_metal_encode_node(
                             {
                                 switch (src1->type) {
                                     case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
+                                    case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128].pipeline; break;
                                     case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break;
                                     case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break;
                                     case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break;
@@ -3150,6 +3187,7 @@ static void ggml_metal_encode_node(
                             {
                                 switch (src1->type) {
                                     case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+                                    case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256].pipeline; break;
                                     case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break;
                                     case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break;
                                     case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break;
@@ -3194,18 +3232,15 @@ static void ggml_metal_encode_node(
                 [encoder setBytes:&nb11          length:sizeof(uint64_t)      atIndex:14];
                 [encoder setBytes:&nb12          length:sizeof(uint64_t)      atIndex:15];
                 [encoder setBytes:&nb13          length:sizeof(uint64_t)      atIndex:16];
-                [encoder setBytes:&nb21          length:sizeof(uint64_t)      atIndex:17];
-                [encoder setBytes:&nb22          length:sizeof(uint64_t)      atIndex:18];
-                [encoder setBytes:&nb23          length:sizeof(uint64_t)      atIndex:19];
-                [encoder setBytes:&nb31          length:sizeof(uint64_t)      atIndex:20];
-                [encoder setBytes:&ne1           length:sizeof( int64_t)      atIndex:21];
-                [encoder setBytes:&ne2           length:sizeof( int64_t)      atIndex:22];
-                [encoder setBytes:&scale         length:sizeof(   float)      atIndex:23];
-                [encoder setBytes:&max_bias      length:sizeof(   float)      atIndex:24];
-                [encoder setBytes:&m0            length:sizeof(m0)            atIndex:25];
-                [encoder setBytes:&m1            length:sizeof(m1)            atIndex:26];
-                [encoder setBytes:&n_head_log2   length:sizeof(n_head_log2)   atIndex:27];
-                [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:28];
+                [encoder setBytes:&nb31          length:sizeof(uint64_t)      atIndex:17];
+                [encoder setBytes:&ne1           length:sizeof( int64_t)      atIndex:18];
+                [encoder setBytes:&ne2           length:sizeof( int64_t)      atIndex:19];
+                [encoder setBytes:&scale         length:sizeof(   float)      atIndex:20];
+                [encoder setBytes:&max_bias      length:sizeof(   float)      atIndex:21];
+                [encoder setBytes:&m0            length:sizeof(m0)            atIndex:22];
+                [encoder setBytes:&m1            length:sizeof(m1)            atIndex:23];
+                [encoder setBytes:&n_head_log2   length:sizeof(n_head_log2)   atIndex:24];
+                [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:25];
 
                 if (!use_vec_kernel) {
                     // half8x8 kernel
@@ -3216,11 +3251,14 @@ static void ggml_metal_encode_node(
                     GGML_ASSERT(nqptg  % 8  == 0);
                     GGML_ASSERT(ncpsg  % 32 == 0);
 
+                    // 2*(2*ncpsg + nqptg)*(nsg)
+                    // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
+                    //
                     // 16*32*(nsg)
                     // the shared memory needed for the simdgroups to load the KV cache
                     // each thread loads (dequantizes) 16 head elements, there are 32 threads in the SG
                     //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
 
                     int64_t nsgmax = 2;
 
@@ -3254,12 +3292,12 @@ static void ggml_metal_encode_node(
 
                     // ne00 + 2*ncpsg*(nsg)
                     // for each query, we load it as f16 in shared memory (ne00)
-                    // and store the attention scores (nqptg x ncpsg) as f32
+                    // and store the soft_max values and the mask
                     //
-                    // 2*ne00*(nsg)
-                    // each simdgroup has a full f32 head vector in shared mem to accumulate results
+                    // ne00*(nsg)
+                    // each simdgroup has a full f16 head vector in shared mem to accumulate results
                     //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + 2*ne00*(nsg))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + ne00*(nsg))*(sizeof(float)/2), 16))
 
                     int64_t nsgmax = 2;
 
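To make the revised shared-memory accounting concrete, here is the non-vec FATTN_SMEM macro evaluated by hand. The macro counts half-sized elements and the trailing *(sizeof(float)/2) converts to bytes; nqptg = 8 and ncpsg = 32 match the granularity asserted above, while ne00 = 128 and nsg = 4 are illustrative:

    // FATTN_SMEM(nsg) = GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*nsg) + 16*32*nsg)*(sizeof(float)/2), 16)
    //
    // queries:        nqptg*ne00                    =  8*128        = 1024 halves
    // scores + mask:  nqptg*2*(2*ncpsg + nqptg)*nsg =  8*2*(64+8)*4 = 4608 halves
    // KV staging:     16*32*nsg                     =  2048 halves
    //
    // total: (1024 + 4608 + 2048)*2 bytes = 15360 bytes; GGML_PAD(15360, 16) = 15360,
    // so the alignment padding is a no-op for this configuration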
ggml/src/ggml-metal.metal
index 16b5da3ff3f022b475e776e3c7220a9861b1e2ec..edce741088f05d1ef4875940021cc0b8f0eaad29 100644 (file)
@@ -57,10 +57,14 @@ void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg
     const ushort mask0 = il ? 0x00F0 : 0x000F;
     const ushort mask1 = mask0 << 8;
 
-    for (int i=0;i<8;i++) {
-        reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
-        reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
+    float4x4 reg_f;
+
+    for (int i = 0; i < 8; i++) {
+        reg_f[i/2][2*(i%2) + 0] = d1 * (qs[i] & mask0) + md;
+        reg_f[i/2][2*(i%2) + 1] = d2 * (qs[i] & mask1) + md;
     }
+
+    reg = (type4x4) reg_f;
 }
 
 template <typename type4x4>
@@ -72,10 +76,14 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
     const ushort mask0 = il ? 0x00F0 : 0x000F;
     const ushort mask1 = mask0 << 8;
 
-    for (int i=0;i<8;i++) {
-        reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
-        reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
+    float4x4 reg_f;
+
+    for (int i = 0; i < 8; i++) {
+        reg_f[i/2][2*(i%2) + 0] = ((qs[i] & mask0) * d1) + m;
+        reg_f[i/2][2*(i%2) + 1] = ((qs[i] & mask1) * d2) + m;
     }
+
+    reg = (type4x4) reg_f;
 }
 
 template <typename type4x4>
@@ -92,6 +100,8 @@ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg
     const int gh_mv = il ? 12 : 0;
     const int gh_bk = il ?  0 : 4;
 
+    float4x4 reg_f;
+
     for (int i = 0; i < 8; i++) {
         // extract the 5-th bits for x0 and x1
         const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
@@ -101,9 +111,11 @@ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg
         const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
         const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
 
-        reg[i/2][2*(i%2)+0] = d * x0 + md;
-        reg[i/2][2*(i%2)+1] = d * x1 + md;
+        reg_f[i/2][2*(i%2) + 0] = d * x0 + md;
+        reg_f[i/2][2*(i%2) + 1] = d * x1 + md;
     }
+
+    reg = (type4x4) reg_f;
 }
 
 template <typename type4x4>
@@ -120,6 +132,8 @@ void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg
     const int gh_mv = il ? 12 : 0;
     const int gh_bk = il ?  0 : 4;
 
+    float4x4 reg_f;
+
     for (int i = 0; i < 8; i++) {
         // extract the 5-th bits for x0 and x1
         const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
@@ -129,9 +143,11 @@ void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg
         const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
         const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
 
-        reg[i/2][2*(i%2)+0] = d * x0 + m;
-        reg[i/2][2*(i%2)+1] = d * x1 + m;
+        reg_f[i/2][2*(i%2) + 0] = d * x0 + m;
+        reg_f[i/2][2*(i%2) + 1] = d * x1 + m;
     }
+
+    reg = (type4x4) reg_f;
 }
 
 template <typename type4x4>
@@ -139,9 +155,13 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
     device const int8_t * qs = ((device const int8_t *)xb->qs);
     const half d = xb->d;
 
+    float4x4 reg_f;
+
     for (int i = 0; i < 16; i++) {
-        reg[i/4][i%4] = (qs[i + 16*il] * d);
+        reg_f[i/4][i%4] = (qs[i + 16*il] * d);
     }
+
+    reg = (type4x4) reg_f;
 }
 
 template <typename type4x4>
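All of the dequantize_* helpers above follow the same new pattern: do the arithmetic in a float4x4 scratch and narrow to the template's output type once at the end, so that a single template can feed half, bfloat and float consumers. The same idea in plain C (names hypothetical, not from the Metal source):

    // accumulate in float, convert once to the destination type T
    #define DEQUANTIZE_TO(T, dst, qs, d, m, n)      \
        do {                                        \
            float tmp[(n)];                         \
            for (int i = 0; i < (n); ++i) {         \
                tmp[i] = (float) (qs)[i]*(d) + (m); \
            }                                       \
            for (int i = 0; i < (n); ++i) {         \
                (dst)[i] = (T) tmp[i];              \
            }                                       \
        } while (0)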
@@ -2755,44 +2775,65 @@ kernel void kernel_leaky_relu_f32(
 }
 
 // ref: https://arxiv.org/pdf/2307.08691.pdf
-// D - head size, Q - queries per threadgroup, KV - key/value processed per each simdgroup, C - cache items per threadgroup
-template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &), short D, short Q = 8, short KV = 8, short C = 32>
+template<
+    typename q_t,     // query types in shared memory
+    typename q4_t,
+    typename q8x8_t,
+    typename k_t,     // key types in shared memory
+    typename k4x4_t,
+    typename k8x8_t,
+    typename v_t,     // value types in shared memory
+    typename v4x4_t,
+    typename v8x8_t,
+    typename qk_t,    // Q*K types
+    typename qk8x8_t,
+    typename s_t,     // soft-max types
+    typename s8x8_t,
+    typename o_t,     // attention accumulation types
+    typename o4_t,
+    typename o8x8_t,
+    typename kd4x4_t, // key type in device memory
+    short nl_k,
+    void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
+    typename vd4x4_t, // value type in device memory
+    short nl_v,
+    void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
+    short D,         // head size
+    short Q  = 8,    // queries per threadgroup
+    short KV = 8,    // key/value processed per each simdgroup
+    short C  = 32>   // cache items per threadgroup
 kernel void kernel_flash_attn_ext(
         device const  char * q,
         device const  char * k,
         device const  char * v,
         device const  char * mask,
         device       float * dst,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne03,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant   int64_t & ne13,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant  uint64_t & nb21,
-        constant  uint64_t & nb22,
-        constant  uint64_t & nb23,
-        constant  uint64_t & nb31,
-        constant   int64_t & ne1,
-        constant   int64_t & ne2,
+        constant   int32_t & ne01,
+        constant   int32_t & ne02,
+        constant   int32_t & ne03,
+        constant  uint32_t & nb01,
+        constant  uint32_t & nb02,
+        constant  uint32_t & nb03,
+        constant   int32_t & ne11,
+        constant   int32_t & ne_12_2, // assume K and V are same shape
+        constant   int32_t & ne_12_3,
+        constant  uint32_t & nb_12_1,
+        constant  uint32_t & nb_12_2,
+        constant  uint32_t & nb_12_3,
+        constant  uint32_t & nb31,
+        constant   int32_t & ne1,
+        constant   int32_t & ne2,
         constant     float & scale,
         constant     float & max_bias,
         constant     float & m0,
         constant     float & m1,
-        constant  uint32_t & n_head_log2,
+        constant  uint16_t & n_head_log2,
         constant     float & logit_softcap,
         threadgroup   half * shared [[threadgroup(0)]],
-        uint3  tgpig[[threadgroup_position_in_grid]],
-        uint3  tpitg[[thread_position_in_threadgroup]],
-        uint3    ntg[[threads_per_threadgroup]],
-        ushort tiisg[[thread_index_in_simdgroup]],
-        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+        ushort3  tgpig[[threadgroup_position_in_grid]],
+        ushort3    ntg[[threads_per_threadgroup]],
+        ushort   tiisg[[thread_index_in_simdgroup]],
+        ushort   sgitg[[simdgroup_index_in_threadgroup]]) {
     const short nsg = ntg.y; // number of simdgroups
 
     const int iq3 = tgpig[2];
@@ -2803,21 +2844,25 @@ kernel void kernel_flash_attn_ext(
     const short D8  = D/8;
     const short D16 = D/16;
     const short NW  = N_SIMDWIDTH;
-    const short SH  = (C + Q); // shared memory per simdgroup in (half)
+    const short SH  = (2*C + Q); // shared memory per simdgroup (s_t == float)
 
-    const short T  = D + 2*nsg*SH; // shared memory size per query in (half)
-    const short TF = T/2;        // shared memory size per query in (float)
-    const short T4 = T/4;        // shared memory size per query in (half4)
+    const short TS = nsg*SH;   // shared memory size per query in (s_t == float)
+    const short T  = D + 2*TS; // shared memory size per query in (half)
 
-    threadgroup half  * sq  = (threadgroup half  *) (shared +              0*D); // holds the query data
-    threadgroup half4 * sq4 = (threadgroup half4 *) (shared +              0*D); // same as above but in half4
-    threadgroup float * ss  = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
+    threadgroup q_t  * sq  = (threadgroup q_t  *) (shared +              0*D); // holds the query data
+    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared +              0*D); // same as above but in q4_t
+    threadgroup o_t  * so  = (threadgroup o_t  *) (shared +              0*D); // reuse query data for accumulation
+    threadgroup o4_t * so4 = (threadgroup o4_t *) (shared +              0*D); // same as above but in o4_t
+    threadgroup s_t  * ss  = (threadgroup s_t  *) (shared + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix
 
-    threadgroup half    * skv  = (threadgroup half    *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K and V in shared memory
-    threadgroup half4x4 * skv4 = (threadgroup half4x4 *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in half4x4
+    threadgroup k_t    * sk    = (threadgroup k_t    *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
+    threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
+
+    threadgroup v_t    * sv    = (threadgroup v_t    *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory
+    threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t
 
     // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
-    simdgroup_half8x8 lo[D8];
+    o8x8_t lo[D8];
 
     // load heads from Q to shared memory
     for (short j = sgitg; j < Q; j += nsg) {
@@ -2825,71 +2870,61 @@ kernel void kernel_flash_attn_ext(
 
         for (short i = tiisg; i < D4; i += NW) {
             if (iq1 + j < ne01) {
-                sq4[j*T4 + i] = (half4) q4[i];
+                sq4[j*D4 + i] = (q4_t) q4[i];
             } else {
-                sq4[j*T4 + i] = 0.0h;
+                sq4[j*D4 + i] = (q4_t) 0.0f;
             }
         }
     }
 
     // zero out lo
     for (short i = 0; i < D8; ++i) {
-        lo[i] = make_filled_simdgroup_matrix<half, 8>(0.0h);
+        lo[i] = make_filled_simdgroup_matrix<o_t, 8>((o_t) 0.0f);
     }
 
     // zero out shared memory SH
     for (short j = 0; j < Q; ++j) {
         for (short i = tiisg; i < SH; i += NW) {
-            ss[j*TF + i] = 0.0f;
+            ss[j*TS + i] = 0.0f;
         }
     }
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     {
-        float S[Q] = { [0 ... Q-1] = 0.0f };
-        float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 };
+        half S[Q] = { [0 ... Q-1] = 0.0f };
+        half M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
 
         // thread indices inside the simdgroup
+        // TODO: see if we can utilize quad-group functions for better performance
+        //       https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (6.9.3)
         const short tx = tiisg%4;
         const short ty = tiisg/4;
 
-        // assume K and V are same shape
-        const short ne22 = ne12;
-        const short ne23 = ne13;
-
-        // broadcast k
-        const short rk2 = ne02/ne12;
-        const short rk3 = ne03/ne13;
-
-        const short ik2 = iq2/rk2;
-        const short ik3 = iq3/rk3;
+        // broadcast kv
+        //const short rk2 = ne02/ne12;
+        //const short rk3 = ne03/ne13;
 
-        // broadcast v
-        const short rv2 = ne02/ne22;
-        const short rv3 = ne03/ne23;
-
-        const short iv2 = iq2/rv2;
-        const short iv3 = iq3/rv3;
+        const short ikv2 = iq2/(ne02/ne_12_2);
+        const short ikv3 = iq3/(ne03/ne_12_3);
 
         // load the queries from shared memory into local memory
-        simdgroup_half8x8 mq[D8];
+        q8x8_t mq[D8];
 
         for (short i = 0; i < D8; ++i) {
-            simdgroup_load(mq[i], sq + i*8, T);
+            simdgroup_load(mq[i], sq + i*8, D);
         }
 
-        // pointer to the mask
-        device const half * mp = (device const half *) (mask + iq1*nb31);
+        const bool has_mask = mask != q;
 
-        float slope = 1.0f;
+        half slope = 1.0f;
 
         // ALiBi
         if (max_bias > 0.0f) {
-            const uint32_t h = iq2;
+            const short h = iq2;
 
-            const float base = h < n_head_log2 ? m0 : m1;
-            const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+            const half  base = h < n_head_log2 ? m0 : m1;
+            const short exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
 
             slope = pow(base, exph);
         }
@@ -2902,120 +2937,137 @@ kernel void kernel_flash_attn_ext(
                 break;
             }
 
+            if (has_mask) {
+                // used to detect blocks full of -INF
+                half smax = -INFINITY;
+
+                // load the mask in shared memory
+                for (short j = 0; j < Q; ++j) {
+                    device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*nb31);
+
+                    const half m = pm[ic + tiisg];
+
+                    ss[j*TS + C + tiisg] = m;
+                    smax = max(smax, m);
+                }
+
+                smax = simd_max(smax);
+
+                if (smax == -INFINITY) {
+                    continue;
+                }
+            }
+
             // Q*K^T
             {
                 for (short cc = 0; cc < C/8; ++cc) {
-                    simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.h);
+                    qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
 
                    // this is a compile-time check, so it does not have runtime overhead
-                    if (is_same<block_q, half4x4>::value) {
+                    if (is_same<kd4x4_t, k4x4_t>::value) {
                         // we can read directly from global memory
-                        device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
+                        device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
 
+#pragma unroll
                         for (short i = 0; i < D8; ++i) {
-                            simdgroup_half8x8 mk;
-                            simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
+                            k8x8_t mk;
+                            simdgroup_load(mk, pk + i*8, nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10
 
                             simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
                         }
                     } else {
                         for (short ii = 0; ii < D16; ii += 4) {
-                            device const block_q * pk4 = (device const block_q *) ((device const char *) k + ((ic + 8*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));
+                            device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
 
                             if (D16%4 == 0) {
                                 // the head is evenly divisible by 4*16 = 64, so no need for bound checks
-                                half4x4 tmp;
-                                dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
-                                skv4[4*ty + tx] = tmp;
+                                {
+                                    k4x4_t tmp;
+                                    deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
+                                    sk4x4[4*ty + tx] = tmp;
+                                }
 
                                 simdgroup_barrier(mem_flags::mem_threadgroup);
 
 #pragma unroll
                                 for (short k = 0; k < 4; ++k) {
-                                    simdgroup_half8x8 mk;
+                                    k8x8_t mk;
 
-                                    simdgroup_load(mk, skv + 16*k + 0*8, 4*16, 0, true); // transpose
+                                    simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
                                     simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
 
-                                    simdgroup_load(mk, skv + 16*k + 1*8, 4*16, 0, true); // transpose
+                                    simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
                                     simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
                                 }
                             } else {
                                 if (ii + tx < D16) {
-                                    half4x4 tmp;
-                                    dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
-                                    skv4[4*ty + tx] = tmp;
+                                    k4x4_t tmp;
+                                    deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
+                                    sk4x4[4*ty + tx] = tmp;
                                 }
 
                                 simdgroup_barrier(mem_flags::mem_threadgroup);
 
                                 for (short k = 0; k < 4 && ii + k < D16; ++k) {
-                                    simdgroup_half8x8 mk;
+                                    k8x8_t mk;
 
-                                    simdgroup_load(mk, skv + 16*k + 0*8, 4*16, 0, true); // transpose
+                                    simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
                                     simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
 
-                                    simdgroup_load(mk, skv + 16*k + 1*8, 4*16, 0, true); // transpose
+                                    simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
                                     simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
                                 }
                             }
                         }
                     }
 
-                    simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
+                    // cast qk_t -> s_t
+                    //s8x8_t mqks(1.0f);
+                    //simdgroup_multiply(mqks, mqk, mqks);
+                    //simdgroup_store(mqks, ss + 8*cc, TS, 0, false);
+
+                    simdgroup_store(mqk, ss + 8*cc, TS, 0, false);
                 }
             }
 
-            // used to detect blocks full of -INF
-            float smax = -INFINITY;
-
             // online softmax
             {
-                float ms[Q];
-
-                for (short j = 0; j < Q; ++j) {
-                    const float m = M[j];
+                for (ushort j = 0; j < Q; ++j) {
+                    const half m = M[j];
 
                     // scale and apply the logitcap / mask
-                    float s = ss[j*TF + tiisg]*scale;
+                    half s = ss[j*TS + tiisg]*scale;
 
                     if (logit_softcap != 0.0f) {
                         s = logit_softcap*precise::tanh(s);
                     }
 
-                    if (mask != q) {
-                        // mqk = mqk + mask*slope
-                        s += slope*mp[ic + j*nb31/sizeof(half) + tiisg];
-                    }
+                    // mqk = mqk + mask*slope
+                    s += slope*ss[j*TS + C + tiisg];
 
-                    smax = simd_max(max(smax, s));
                     M[j] = simd_max(max(M[j], s));
 
-                                ms[j] = exp(m - M[j]);
-                    const float vs    = exp(s - M[j]);
+                    const half ms = exp(m - M[j]);
+                    const half vs = exp(s - M[j]);
 
-                    S[j] = S[j]*ms[j] + simd_sum(vs);
+                    S[j] = S[j]*ms + simd_sum(vs);
 
                     // the P matrix from the paper (Q rows, C columns)
-                    ss[j*TF + tiisg] = vs;
-                }
+                    ss[j*TS + tiisg] = vs;
 
-                // create a QxQ diagonal matrix for rescaling the output
-                if (tiisg < Q) {
-                    ss[tiisg*TF + C + tiisg] = ms[tiisg];
+                    // create a QxQ diagonal matrix for rescaling the output
+                    if (tiisg == j) {
+                        ss[j*TS + 2*C + j] = ms;
+                    }
                 }
             }
 
-            // skip -INF blocks
-            if (smax == -INFINITY) {
-                continue;
-            }
-
             // O = diag(ms)*O
             {
-                simdgroup_float8x8 mm;
-                simdgroup_load(mm, ss + C, TF, 0, false);
+                s8x8_t mm;
+                simdgroup_load(mm, ss + 2*C, TS, 0, false);
 
+#pragma unroll
                 for (short i = 0; i < D8; ++i) {
                     simdgroup_multiply(lo[i], mm, lo[i]);
                 }
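For reference, the loop above implements the standard online-softmax recurrence from the Flash-Attention paper, now carried in half precision: for each new block of scores s_c, the running maximum M and running sum S are updated, and the accumulated output O is rescaled through a diagonal matrix of ms factors:

    M_new = max(M_old, max_c s_c)
    ms    = exp(M_old - M_new)    // diagonal rescale factor for O
    vs_c  = exp(s_c  - M_new)     // entries of the P matrix
    S_new = S_old*ms + sum_c vs_c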
@@ -3024,57 +3076,59 @@ kernel void kernel_flash_attn_ext(
             // O = O + (Q*K^T)*V
             {
                 for (short cc = 0; cc < C/8; ++cc) {
-                    simdgroup_float8x8 ms;
-                    simdgroup_load(ms, ss + 8*cc, TF, 0, false);
+                    s8x8_t ms;
+                    simdgroup_load(ms, ss + 8*cc, TS, 0, false);
 
-                    if (is_same<block_q, half4x4>::value) {
+                    if (is_same<vd4x4_t, v4x4_t>::value) {
                         // we can read directly from global memory
-                        device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
+                        device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
 #pragma unroll
                         for (short i = 0; i < D8; ++i) {
-                            simdgroup_half8x8 mv;
-                            simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false);
+                            v8x8_t mv;
+                            simdgroup_load(mv, pv + i*8, nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20
 
                             simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
                         }
                     } else {
                         for (short ii = 0; ii < D16; ii += 4) {
-                            device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 8*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
+                            device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
 
                             if (D16%4 == 0) {
                                 // no need for bound checks
-                                half4x4 tmp;
-                                dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
-                                skv4[4*ty + tx] = tmp;
+                                {
+                                    v4x4_t tmp;
+                                    deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
+                                    sv4x4[4*ty + tx] = tmp;
+                                }
 
                                 simdgroup_barrier(mem_flags::mem_threadgroup);
 
 #pragma unroll
                                 for (short k = 0; k < 4; ++k) {
-                                    simdgroup_half8x8 mv;
+                                    v8x8_t mv;
 
-                                    simdgroup_load(mv, skv + 16*k + 0*8, 4*16, 0, false);
+                                    simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
                                     simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
 
-                                    simdgroup_load(mv, skv + 16*k + 1*8, 4*16, 0, false);
+                                    simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
                                     simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
                                 }
                             } else {
                                 if (ii + tx < D16) {
-                                    half4x4 tmp;
-                                    dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
-                                    skv4[4*ty + tx] = tmp;
+                                    v4x4_t tmp;
+                                    deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
+                                    sv4x4[4*ty + tx] = tmp;
                                 }
 
                                 simdgroup_barrier(mem_flags::mem_threadgroup);
 
                                 for (short k = 0; k < 4 && ii + k < D16; ++k) {
-                                    simdgroup_half8x8 mv;
+                                    v8x8_t mv;
 
-                                    simdgroup_load(mv, skv + 16*k + 0*8, 4*16, 0, false);
+                                    simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
                                     simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
 
-                                    simdgroup_load(mv, skv + 16*k + 1*8, 4*16, 0, false);
+                                    simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
                                     simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
                                 }
                             }
@@ -3087,23 +3141,23 @@ kernel void kernel_flash_attn_ext(
         // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
         for (short j = 0; j < Q; ++j) {
             if (tiisg == 0) {
-                ss[j*TF + 0] = S[j];
-                ss[j*TF + 1] = M[j];
+                ss[j*TS + 0] = S[j];
+                ss[j*TS + 1] = M[j];
             }
         }
     }
 
     // reduce the warps sequentially
-    for (short sg = 1; sg < nsg; ++sg) {
-        float S = { 0.0f };
-        float M = { -FLT_MAX/2 };
+    for (ushort sg = 1; sg < nsg; ++sg) {
+        half S = { 0.0f };
+        half M = { -__FLT16_MAX__/2 };
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
         // each simdgroup stores its output to shared memory, reusing sq
         if (sgitg == sg) {
             for (short i = 0; i < D8; ++i) {
-                simdgroup_store(lo[i], sq + i*8, T, 0, false);
+                simdgroup_store(lo[i], so + i*8, D, 0, false);
             }
         }
 
@@ -3112,39 +3166,40 @@ kernel void kernel_flash_attn_ext(
         // the first simdgroup accumulates the results from the other simdgroups
         if (sgitg == 0) {
             for (short j = 0; j < Q; ++j) {
-                const float S0 = ss[j*TF +         0];
-                const float S1 = ss[j*TF + sg*SH + 0];
+                const half S0 = ss[j*TS +         0];
+                const half S1 = ss[j*TS + sg*SH + 0];
 
-                const float M0 = ss[j*TF +         1];
-                const float M1 = ss[j*TF + sg*SH + 1];
+                const half M0 = ss[j*TS +         1];
+                const half M1 = ss[j*TS + sg*SH + 1];
 
                 M = max(M0, M1);
 
-                const float ms0 = exp(M0 - M);
-                const float ms1 = exp(M1 - M);
+                const half ms0 = exp(M0 - M);
+                const half ms1 = exp(M1 - M);
 
                 S = S0*ms0 + S1*ms1;
 
                 if (tiisg == 0) {
-                    ss[j*TF + 0] = S;
-                    ss[j*TF + 1] = M;
+                    ss[j*TS + 0] = S;
+                    ss[j*TS + 1] = M;
 
-                    ss[j*TF + C + j        ] = ms0;
-                    ss[j*TF + C + j + sg*SH] = ms1;
+                    ss[j*TS + 2*C + j        ] = ms0;
+                    ss[j*TS + 2*C + j + sg*SH] = ms1;
                 }
             }
 
             // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
             {
-                simdgroup_half8x8 t;
-                simdgroup_float8x8 ms0;
-                simdgroup_float8x8 ms1;
+                s8x8_t ms0;
+                s8x8_t ms1;
 
-                simdgroup_load(ms0, ss + C,         TF, 0, false);
-                simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false);
+                simdgroup_load(ms0, ss + 2*C,         TS, 0, false);
+                simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);
 
                 for (short i = 0; i < D8; ++i) {
-                    simdgroup_load    (t, sq + i*8, T, 0, false);
+                    o8x8_t t;
+
+                    simdgroup_load    (t, so + i*8, D, 0, false);
                     simdgroup_multiply(t, ms1, t);
 
                     simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
@@ -3156,7 +3211,7 @@ kernel void kernel_flash_attn_ext(
     // store result to shared memory (reuse sq)
     if (sgitg == 0) {
         for (short i = 0; i < D8; ++i) {
-            simdgroup_store(lo[i], sq + i*8, T, 0, false);
+            simdgroup_store(lo[i], so + i*8, D, 0, false);
         }
     }
 
@@ -3165,98 +3220,133 @@ kernel void kernel_flash_attn_ext(
     // final rescale with 1/S and store to global memory
     if (sgitg == 0) {
         for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
-            const float S = ss[j*TF + 0];
+            const float S = ss[j*TS + 0];
 
             for (short i = tiisg; i < D4; i += NW) {
-                dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S;
+                dst4[((int64_t)iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
             }
         }
     }
 }
 
-typedef decltype(kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 64>) flash_attn_ext_t;
-
-template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 64>;
-template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 80>;
-template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 96>;
-template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 112>;
-template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 128>;
-template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 256>;
-
-template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 64>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 80>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 96>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 112>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 128>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 256>;
-
-template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 64>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 80>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 96>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 112>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 128>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 256>;
-
-template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 64>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 80>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 96>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 112>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 128>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 256>;
-
-template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 64>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 80>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 96>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 112>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 128>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 256>;
-
-template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 64>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 80>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 96>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 112>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 128>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 256>;
-
-// NOTE: can use half instead of float precision for some extra perf
-// D - head size, Q - queries per threadgroup, C - cache items per threadgroup
-template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &), short D, short Q = 1, short C = 32>
+// TODO: this is quite ugly. in the future these types will be hardcoded in the kernel, but for now keep them as
+//       templates to be able to explore different combinations
+//
+#define FA_TYPES \
+    half,  half4,   simdgroup_half8x8,  \
+    half,  half4x4, simdgroup_half8x8,  \
+    half,  half4x4, simdgroup_half8x8,  \
+    float,          simdgroup_float8x8, \
+    float,          simdgroup_float8x8, \
+    half,  half4,   simdgroup_half8x8
+
+typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>) flash_attn_ext_t;
+
+template [[host_name("kernel_flash_attn_ext_f16_h64" )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  64>;
+template [[host_name("kernel_flash_attn_ext_f16_h80" )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  80>;
+template [[host_name("kernel_flash_attn_ext_f16_h96" )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  96>;
+template [[host_name("kernel_flash_attn_ext_f16_h112")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  112>;
+template [[host_name("kernel_flash_attn_ext_f16_h128")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  128>;
+template [[host_name("kernel_flash_attn_ext_f16_h256")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  256>;
+
+#if !defined(GGML_METAL_NO_BFLOAT)
+template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 64>;
+template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 80>;
+template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 96>;
+template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 112>;
+template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 128>;
+template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 256>;
+#endif
+
+template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256>;
+
+template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256>;
+
+template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256>;
+
+template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256>;
+
+template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256>;
+
+#undef FA_TYPES
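[editor's note] The block above leans on two idioms worth spelling out: since none of the kernel's parameter types depend on the template arguments, every specialization shares a single function type, so one decltype typedef (flash_attn_ext_t) can describe all of them, and [[host_name(...)]] gives each instantiation a stable pipeline name for the host side to look up. A minimal plain-C++ sketch of the same pattern (illustrative names only, not from the codebase):

    #include <cstdio>

    // parameters deliberately do not depend on T or D, mirroring the Metal
    // kernels whose arguments are all `device const char *` and friends
    template<typename T, short D>
    void fa_kernel(const char * q, float * dst) {
        printf("sizeof(T) = %zu, head size = %d\n", sizeof(T), (int) D);
    }

    // decltype of one instantiation names the common function type
    typedef decltype(fa_kernel<float, 64>) fa_kernel_t;

    // a table of specializations, analogous to the [[host_name]] registry
    fa_kernel_t * const kernels[] = {
        fa_kernel<float,  128>,
        fa_kernel<double, 128>, // different element type, same function type
        fa_kernel<float,  256>,
    };

    int main() {
        kernels[2](nullptr, nullptr); // dispatch the h256 variant
    }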
+
+template<
+    typename q4_t,    // query types in shared memory
+    typename q4x4_t,
+    typename k4x4_t,  // key types in shared memory
+    typename v4x4_t,  // value types in shared memory
+    typename qk_t,    // Q*K types
+    typename s_t,     // soft-max types
+    typename s4_t,
+    typename s4x4_t,
+    typename o4x4_t,  // attention accumulation types
+    typename kd4x4_t, // key type in device memory
+    short nl_k,
+    void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
+    typename vd4x4_t, // value type in device memory
+    short nl_v,
+    void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
+    short D,         // head size
+    short Q  = 1,    // queries per threadgroup
+    short C  = 32>   // cache items per threadgroup
 kernel void kernel_flash_attn_ext_vec(
         device const  char * q,
         device const  char * k,
         device const  char * v,
         device const  char * mask,
         device       float * dst,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne03,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant   int64_t & ne13,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant  uint64_t & nb21,
-        constant  uint64_t & nb22,
-        constant  uint64_t & nb23,
-        constant  uint64_t & nb31,
-        constant   int64_t & ne1,
-        constant   int64_t & ne2,
+        constant   int32_t & ne01,
+        constant   int32_t & ne02,
+        constant   int32_t & ne03,
+        constant  uint32_t & nb01,
+        constant  uint32_t & nb02,
+        constant  uint32_t & nb03,
+        constant   int32_t & ne11,
+        constant   int32_t & ne_12_2, // assume K and V are same shape
+        constant   int32_t & ne_12_3,
+        constant  uint32_t & nb_12_1,
+        constant  uint32_t & nb_12_2,
+        constant  uint32_t & nb_12_3,
+        constant  uint32_t & nb31,
+        constant   int32_t & ne1,
+        constant   int32_t & ne2,
         constant     float & scale,
         constant     float & max_bias,
         constant     float & m0,
         constant     float & m1,
-        constant  uint32_t & n_head_log2,
+        constant  uint16_t & n_head_log2,
         constant     float & logit_softcap,
         threadgroup   half * shared [[threadgroup(0)]],
-        uint3  tgpig[[threadgroup_position_in_grid]],
-        uint3  tpitg[[thread_position_in_threadgroup]],
-        uint3    ntg[[threads_per_threadgroup]],
-        ushort tiisg[[thread_index_in_simdgroup]],
-        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+        ushort3  tgpig[[threadgroup_position_in_grid]],
+        ushort3  tpitg[[thread_position_in_threadgroup]],
+        ushort3    ntg[[threads_per_threadgroup]],
+        ushort   tiisg[[thread_index_in_simdgroup]],
+        ushort   sgitg[[simdgroup_index_in_threadgroup]]) {
     const short nsg = ntg.y; // number of simdgroups
 
     const int iq3 = tgpig[2];
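[editor's note] Note the systematic narrowing in this signature: the ne/nb arguments drop from 64-bit to 32-bit, n_head_log2 to 16-bit, and the dispatch coordinates from uint3 to ushort3, shrinking the kernel's constant and register footprint. The trade-off, per the "prevent int overflows" part of this change, is that index products must now be widened explicitly before they can exceed 32 bits; the destination index near the end of the kernel does exactly that with an int64_t cast. A standalone C++ sketch of the failure mode being guarded against (sizes are made up):

    #include <cstdint>
    #include <cstdio>

    int main() {
        // plausible-but-made-up extents: 100 sequences, 64 heads, 65536 rows, D16 = 8
        const int32_t iq3 = 100, ne2 = 64, ne1 = 65536, D16 = 8;

        // the product is 3'355'443'200, which does not fit in int32_t;
        // computing it in 32 bits (as unsigned, to keep the demo well-defined)
        // and reinterpreting gives a negative index
        const uint32_t u32    = (uint32_t) iq3 * ne2 * ne1 * D16;
        const int32_t  narrow = (int32_t) u32;

        // widening the first operand promotes the whole chain to 64 bits
        const int64_t  wide   = (int64_t) iq3 * ne2 * ne1 * D16;

        printf("32-bit index: %d\n64-bit index: %lld\n", narrow, (long long) wide);
    }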
@@ -3267,89 +3357,81 @@ kernel void kernel_flash_attn_ext_vec(
     const short D16 = D/16;
     const short NW  = N_SIMDWIDTH;
     const short NW4 = NW/4;
-    const short SH  = C; // shared memory per simdgroup in (half)
+    const short SH  = 2*C; // shared memory per simdgroup
 
-    const short T  = D + 2*nsg*SH; // shared memory size per query in (half)
+    const short T = D + nsg*SH; // shared memory size per query in (half)
 
-  //threadgroup half     * sq   = (threadgroup half     *) (shared +              0*D); // holds the query data
-    threadgroup half4    * sq4  = (threadgroup half4    *) (shared +              0*D); // same as above but in half4
-    threadgroup half4x4  * sq44 = (threadgroup half4x4  *) (shared +              0*D); // same as above but in half4x4
-    threadgroup float    * ss   = (threadgroup float    *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention
-    threadgroup float4   * ss4  = (threadgroup float4   *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4
-    threadgroup float4x4 * sr44 = (threadgroup float4x4 *) (shared + 2*sgitg*D  + Q*T); // scratch buffer for the results
+  //threadgroup q_t    * sq    = (threadgroup q_t    *) (shared +                0*D); // holds the query data
+    threadgroup q4_t   * sq4   = (threadgroup q4_t   *) (shared +                0*D); // same as above but in q4_t
+    threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shared +                0*D); // same as above but in q4x4_t
+    threadgroup s_t    * ss    = (threadgroup s_t    *) (shared + sgitg*SH     + Q*D); // scratch buffer for attention
+    threadgroup s4_t   * ss4   = (threadgroup s4_t   *) (shared + sgitg*SH     + Q*D); // same as above but in s4_t
+    threadgroup half   * sm    = (threadgroup half   *) (shared + sgitg*SH + C + Q*D); // scratch buffer for mask
+    threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D      + Q*T); // scratch buffer for the results
 
     // store the result for all queries in local memory in 4x4 matrices (the O matrix from the paper)
-    float4x4 lo[D16/NW4];
+    o4x4_t lo[D16/NW4];
 
     // load heads from Q to shared memory
     device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
 
     for (short i = tiisg; i < D4; i += NW) {
         if (iq1 < ne01) {
-            sq4[i] = (half4) q4[i];
+            sq4[i] = (q4_t) q4[i];
         } else {
-            sq4[i] = 0.0h;
+            sq4[i] = (q4_t) 0.0f;
         }
     }
 
     // zero out lo
     for (short i = 0; i < D16/NW4; i += NW4) {
-        lo[i] = float4x4(0.0f);
+        lo[i] = (o4x4_t) 0.0f;
     }
 
     // zero out shared memory SH
     for (short i = tiisg; i < SH/4; i += NW) {
-        ss4[i] = 0.0h;
+        ss4[i] = (s4_t) 0.0f;
     }
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     {
-        float S = 0.0f;
-        float M = -FLT_MAX/2;
+        half S = 0.0f;
+        half M = -__FLT16_MAX__/2;
 
         // thread indices inside the simdgroup
         const short tx = tiisg%8;
         const short ty = tiisg/8;
 
-        // assume K and V are same shape
-        const short ne22 = ne12;
-        const short ne23 = ne13;
-
-        // broadcast k
-        const short rk2 = ne02/ne12;
-        const short rk3 = ne03/ne13;
-
-        const short ik2 = iq2/rk2;
-        const short ik3 = iq3/rk3;
+        // broadcast kv
+        //const short rk2 = ne02/ne12;
+        //const short rk3 = ne03/ne13;
 
-        // broadcast v
-        const short rv2 = ne02/ne22;
-        const short rv3 = ne03/ne23;
-
-        const short iv2 = iq2/rv2;
-        const short iv3 = iq3/rv3;
+        const short ikv2 = iq2/(ne02/ne_12_2);
+        const short ikv3 = iq3/(ne03/ne_12_3);
 
         // load the queries from shared memory into local memory
-        float4x4 mq[D16/NW4];
+        q4x4_t mq[D16/NW4];
 
         for (short ii = 0; ii < D16; ii += NW4) {
-            mq[ii/NW4] = (float4x4) sq44[ii + tx];
+            mq[ii/NW4] = sq4x4[ii + tx];
         }
 
+        const bool has_mask = mask != q;
+
         // pointer to the mask
-        device const half * mp = (device const half *) (mask + iq1*nb31);
+        device const half * pm = (device const half *) (mask + iq1*nb31);
 
-        float slope = 1.0f;
+        half slope = 1.0f;
 
         // ALiBi
         if (max_bias > 0.0f) {
-            const uint32_t h = iq2;
+            const short h = iq2;
 
-            const float base = h < n_head_log2 ? m0 : m1;
-            const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+            const half  base = h < n_head_log2 ? m0 : m1;
+            const short exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
 
-            slope = pow(base, exp);
+            slope = pow(base, exph);
         }
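[editor's note] The ALiBi branch above reproduces the usual per-head slope schedule on the GPU, now in half precision with short indices. For reference, a host-side C++ sketch of the same schedule, assuming the common ggml convention that m0 = 2^(-max_bias/n_head_log2) and m1 = 2^(-max_bias/(2*n_head_log2)), with n_head_log2 the largest power of two not exceeding n_head (values illustrative):

    #include <cmath>
    #include <cstdio>

    int main() {
        const int   n_head      = 12;
        const float max_bias    = 8.0f; // typical ALiBi setting
        const int   n_head_log2 = 1 << (int) floorf(log2f((float) n_head)); // 8

        const float m0 = powf(2.0f, -max_bias          / n_head_log2);
        const float m1 = powf(2.0f, -(max_bias/2.0f)   / n_head_log2);

        for (int h = 0; h < n_head; ++h) {
            const float base = h < n_head_log2 ? m0 : m1;
            const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
            printf("head %2d: slope = %g\n", h, powf(base, (float) exph));
        }
    }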
 
         // loop over the KV cache
@@ -3360,20 +3442,24 @@ kernel void kernel_flash_attn_ext_vec(
                 break;
             }
 
+            if (has_mask) {
+                sm[tiisg] = pm[ic + tiisg];
+            }
+
             // Q*K^T
             {
                 // each simdgroup processes 1 query and 4 keys
                 for (short cc = 0; cc < C/4; ++cc) {
-                    float mqk = 0.0;
+                    qk_t mqk = 0.0;
 
-                    device const block_q * pk = (device const block_q *) ((device const char *) k + ((ic + 4*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));
+                    device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
 
 #pragma unroll
                     for (short ii = 0; ii < D16; ii += NW4) {
                         const short i = ii + tx;
 
-                        float4x4 mk;
-                        dequantize_func(pk + i/nl, i%nl, mk);
+                        k4x4_t mk;
+                        deq_k(pk + i/nl_k, i%nl_k, mk);
 
                         mqk +=
                             dot(mq[ii/NW4][0], mk[0]) +
@@ -3401,7 +3487,7 @@ kernel void kernel_flash_attn_ext_vec(
                             mqk = logit_softcap*precise::tanh(mqk);
                         }
 
-                        mqk += (mask != q) ? ((float) mp[ic + 4*cc + ty])*slope : (float) 0.0f;
+                        mqk += sm[4*cc + ty]*slope;
 
                         ss[4*cc + ty] = mqk;
                     }
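[editor's note] The comment "each simdgroup processes 1 query and 4 keys" is realized through the tx/ty split computed earlier (tx = tiisg%8, ty = tiisg/8): the 32 lanes form a 4x8 grid in which each row of 8 lanes owns one key slot and strides the head dimension. A small sketch that just prints the mapping (constants chosen to match D = 128):

    #include <cstdio>

    int main() {
        const int NW = 32, NW4 = NW/4, D16 = 128/16; // simd width, D/16 chunks

        for (int tiisg = 0; tiisg < NW; ++tiisg) {
            const int tx = tiisg % 8;
            const int ty = tiisg / 8;
            printf("lane %2d -> key slot ty = %d, dim chunks:", tiisg, ty);
            for (int ii = 0; ii < D16; ii += NW4) {
                printf(" %d", ii + tx); // indices into the D16 4x4 blocks
            }
            printf("\n");
        }
    }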
@@ -3412,20 +3498,18 @@ kernel void kernel_flash_attn_ext_vec(
 
             // online softmax
             {
-                const short p = tiisg;
-
-                const float m = M;
-                const float s = ss[p];
+                const half m = M;
+                const half s = ss[tiisg];
 
                 M = simd_max(max(M, s));
 
-                const float ms = exp(m - M);
-                const float vs = exp(s - M);
+                const half ms = exp(m - M);
+                const half vs = exp(s - M);
 
                 S = S*ms + simd_sum(vs);
 
                 // the P matrix from the paper (Q rows, C columns)
-                ss[p] = vs;
+                ss[tiisg] = vs;
 
                 // O = diag(ms)*O
 #pragma unroll
@@ -3440,18 +3524,18 @@ kernel void kernel_flash_attn_ext_vec(
             {
 #pragma unroll
                 for (short cc = 0; cc < C/4; ++cc) {
-                    device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 4*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
+                    device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
 
-                    const float4x4 lss(ss[4*cc + ty]);
+                    const s4x4_t ms(ss[4*cc + ty]);
 
 #pragma unroll
                     for (short ii = 0; ii < D16; ii += NW4) {
                         const short i = ii + tx;
 
-                        float4x4 mv;
-                        dequantize_func(pv4 + i/nl, i%nl, mv);
+                        v4x4_t mv;
+                        deq_v(pv4 + i/nl_v, i%nl_v, mv);
 
-                        lo[ii/NW4] += mv*lss;
+                        lo[ii/NW4] += mv*ms;
                     }
                 }
             }
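[editor's note] The two hunks above are the heart of the streaming (online) softmax: a running maximum M and normalizer S are updated as each block of keys arrives, and the output accumulator is rescaled by exp(m_old - m_new) whenever the maximum moves. Because the stored probabilities exp(s - M) never exceed 1, the intermediate state stays well inside half range, which is what makes the switch to 16-bit S, M and P values viable. A scalar C++ sketch of the recurrence (1-dimensional stand-ins for the V rows, made-up scores):

    #include <cmath>
    #include <cstdio>

    int main() {
        const float logits[] = { 1.5f, -3.0f, 7.0f, 0.25f }; // stand-in Q*K scores
        const float values[] = { 1.0f,  2.0f, 3.0f, 4.0f  }; // stand-in V rows

        float M = -1e30f, S = 0.0f, O = 0.0f;

        for (int i = 0; i < 4; ++i) {
            const float m  = M;
            M = fmaxf(M, logits[i]);
            const float ms = expf(m - M);         // rescale factor for old state
            const float vs = expf(logits[i] - M); // new probability numerator, <= 1
            S = S*ms + vs;
            O = O*ms + vs*values[i];              // O = diag(ms)*O + P*V
        }
        printf("attention output: %f\n", O/S);    // final 1/S normalization
    }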
@@ -3459,8 +3543,8 @@ kernel void kernel_flash_attn_ext_vec(
 
         // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
         if (tiisg == 0) {
-            ss[0] = S;
-            ss[1] = M;
+            ss[0] = (s_t) S;
+            ss[1] = (s_t) M;
         }
     }
 
@@ -3489,7 +3573,7 @@ kernel void kernel_flash_attn_ext_vec(
 
     // store results to shared memory
     for (short i = tiisg; i < D16; i += NW4) {
-        sr44[i] = lo[i/NW4];
+        sr4x4[i] = lo[i/NW4];
     }
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -3497,18 +3581,18 @@ kernel void kernel_flash_attn_ext_vec(
     // parallel reduce
     for (short r = nsg/2; r > 0; r >>= 1) {
         if (sgitg < r) {
-            const float S0 = ss[       0];
-            const float S1 = ss[r*SH + 0];
+            const half S0 = ss[       0];
+            const half S1 = ss[r*SH + 0];
 
-            const float M0 = ss[       1];
-            const float M1 = ss[r*SH + 1];
+            const half M0 = ss[       1];
+            const half M1 = ss[r*SH + 1];
 
-            const float M = max(M0, M1);
+            const half M = max(M0, M1);
 
-            const float ms0 = exp(M0 - M);
-            const float ms1 = exp(M1 - M);
+            const half ms0 = exp(M0 - M);
+            const half ms1 = exp(M1 - M);
 
-            const float S = S0*ms0 + S1*ms1;
+            const half S = S0*ms0 + S1*ms1;
 
             if (tiisg == 0) {
                 ss[0] = S;
@@ -3517,7 +3601,7 @@ kernel void kernel_flash_attn_ext_vec(
 
             // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
             for (short i = tiisg; i < D16; i += NW) {
-                sr44[i] = sr44[i]*ms0 + sr44[i + r*D16]*ms1;
+                sr4x4[i] = sr4x4[i]*ms0 + sr4x4[i + r*D16]*ms1;
             }
         }
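[editor's note] The reduction above merges per-simdgroup partial results pairwise; each partial is an (S, M, O) triple over a disjoint slice of the keys, and two of them combine exactly by rescaling both to the shared maximum. In isolation:

    #include <cmath>
    #include <cstdio>

    int main() {
        // made-up partials from two simdgroups
        float S0 = 2.1f, M0 = 5.0f, O0 = 1.7f;
        float S1 = 1.3f, M1 = 6.5f, O1 = 0.9f;

        const float M   = fmaxf(M0, M1);
        const float ms0 = expf(M0 - M);
        const float ms1 = expf(M1 - M);

        const float S = S0*ms0 + S1*ms1;
        const float O = O0*ms0 + O1*ms1; // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1

        printf("merged: S = %f, M = %f, O = %f\n", S, M, O);
    }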
 
@@ -3531,26 +3615,45 @@ kernel void kernel_flash_attn_ext_vec(
         const float S = ss[0];
 
         for (short i = tiisg; i < D16; i += NW) {
-            dst44[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D16 + i] = sr44[i]/S;
+            dst44[((int64_t)iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D16 + i] = (float4x4) sr4x4[i]/S;
         }
     }
 }
 
-typedef decltype(kernel_flash_attn_ext_vec<half4x4, 1, dequantize_f16, 64>) flash_attn_ext_vec_t;
+// note: I think s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem
+//       in the other (non-vec) kernel, s_t needs to be float as well because we scale during the soft_max
+//
+#define FA_TYPES \
+           half4,  half4x4, \
+                   half4x4, \
+                   half4x4, \
+    float,                  \
+    half,  half4,  half4x4, \
+                   half4x4
+
+typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>) flash_attn_ext_vec_t;
 
-template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<half4x4,    1, dequantize_f16,  128>;
-template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_0, 2, dequantize_q4_0, 128>;
-template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_1, 2, dequantize_q4_1, 128>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_0, 2, dequantize_q5_0, 128>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_1, 2, dequantize_q5_1, 128>;
-template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q8_0, 2, dequantize_q8_0, 128>;
+template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,     1, dequantize_f16,  128>;
+#if !defined(GGML_METAL_NO_BFLOAT)
+template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,   1, dequantize_bf16, 128>;
+#endif
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0,  2, dequantize_q4_0, 128>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1,  2, dequantize_q4_1, 128>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0,  2, dequantize_q5_0, 128>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1,  2, dequantize_q5_1, 128>;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0,  2, dequantize_q8_0, 128>;
+
+template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,     1, dequantize_f16,  256>;
+#if !defined(GGML_METAL_NO_BFLOAT)
+template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,   1, dequantize_bf16, 256>;
+#endif
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0,  2, dequantize_q4_0, 256>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1,  2, dequantize_q4_1, 256>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0,  2, dequantize_q5_0, 256>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1,  2, dequantize_q5_1, 256>;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0,  2, dequantize_q8_0, 256>;
 
-template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<half4x4,    1, dequantize_f16,  256>;
-template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_0, 2, dequantize_q4_0, 256>;
-template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_1, 2, dequantize_q4_1, 256>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_0, 2, dequantize_q5_0, 256>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_1, 2, dequantize_q5_1, 256>;
-template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q8_0, 2, dequantize_q8_0, 256>;
+#undef FA_TYPES
 
 template<typename T0, typename T1>
 kernel void kernel_cpy(
index bc034015f470ccbf2bb31cb8205494caba43d342..cd26a361b848f6d1916ffbbc6b3d2e825607975f 100644 (file)
@@ -4228,6 +4228,15 @@ void ggml_flash_attn_ext_set_prec(
     ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second
 }
 
+enum ggml_prec ggml_flash_attn_ext_get_prec(
+        const struct ggml_tensor * a) {
+    GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
+
+    const int32_t prec_i32 = ggml_get_op_params_i32(a, 3);
+
+    return (enum ggml_prec) prec_i32;
+}
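[editor's note] A usage sketch for the new getter, mirroring the existing setter. Here `cur` is assumed to be a tensor previously produced by ggml_flash_attn_ext (construction elided); both helpers assert that the op really is GGML_OP_FLASH_ATTN_EXT:

    #include "ggml.h"

    // hypothetical helper, not from the codebase
    static void request_f32_fa(struct ggml_tensor * cur) {
        // request full-precision accumulation for this attention node
        ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);

        // backends branch on the requested precision, as the CUDA and Metal
        // changes in this commit do
        if (ggml_flash_attn_ext_get_prec(cur) == GGML_PREC_DEFAULT) {
            // fast F16 accumulation path
        } else {
            // F32 accumulation path
        }
    }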
+
 // ggml_flash_attn_back
 
 struct ggml_tensor * ggml_flash_attn_back(