git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
metal : add quantized FA support (llama/10149)
author    Georgi Gerganov <redacted>
          Wed, 6 Nov 2024 08:24:23 +0000 (10:24 +0200)
committer Georgi Gerganov <redacted>
          Fri, 15 Nov 2024 13:21:04 +0000 (15:21 +0200)
* metal : add quantized FA (vec) support

ggml-ci

* metal : add quantized FA (non-vec) support

* metal : fix support check

ggml-ci

* metal : clean-up

* metal : clean-up (cont)

* metal : fix shared memory calc + reduce smem + comments

* metal : float-correctness

* metal : minor [no ci]

ggml/src/ggml-metal.m
ggml/src/ggml-metal.metal

diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index f9bd6faa49a34d1824b84774e5d886bd1d55623a..aee354cdd04168252abad104ce75c678e4a3788d 100644
@@ -255,9 +255,49 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
-  //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,     // https://github.com/ggerganov/llama.cpp/issues/7261
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
-  //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
     GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
@@ -710,9 +750,49 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,        flash_attn_ext_f16_h96,         support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,       flash_attn_ext_f16_h112,        support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,       flash_attn_ext_f16_h128,        support_simdgroup_mm);
-      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,       flash_attn_ext_f16_h256,        support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,       flash_attn_ext_f16_h256,        support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,       flash_attn_ext_q4_0_h64,        support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,       flash_attn_ext_q4_0_h80,        support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,       flash_attn_ext_q4_0_h96,        support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112,      flash_attn_ext_q4_0_h112,       support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128,      flash_attn_ext_q4_0_h128,       support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,      flash_attn_ext_q4_0_h256,       support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,       flash_attn_ext_q4_1_h64,        support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,       flash_attn_ext_q4_1_h80,        support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,       flash_attn_ext_q4_1_h96,        support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112,      flash_attn_ext_q4_1_h112,       support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128,      flash_attn_ext_q4_1_h128,       support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,      flash_attn_ext_q4_1_h256,       support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,       flash_attn_ext_q5_0_h64,        support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,       flash_attn_ext_q5_0_h80,        support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,       flash_attn_ext_q5_0_h96,        support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112,      flash_attn_ext_q5_0_h112,       support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128,      flash_attn_ext_q5_0_h128,       support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,      flash_attn_ext_q5_0_h256,       support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,       flash_attn_ext_q5_1_h64,        support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,       flash_attn_ext_q5_1_h80,        support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,       flash_attn_ext_q5_1_h96,        support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112,      flash_attn_ext_q5_1_h112,       support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128,      flash_attn_ext_q5_1_h128,       support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,      flash_attn_ext_q5_1_h256,       support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,       flash_attn_ext_q8_0_h64,        support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,       flash_attn_ext_q8_0_h80,        support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,       flash_attn_ext_q8_0_h96,        support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112,      flash_attn_ext_q8_0_h112,       support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,      flash_attn_ext_q8_0_h128,       support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,      flash_attn_ext_q8_0_h256,       support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,   flash_attn_ext_vec_f16_h128,    support_simdgroup_reduction);
-      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,   flash_attn_ext_vec_f16_h256,    support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,  flash_attn_ext_vec_q4_0_h128,   support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,  flash_attn_ext_vec_q4_1_h128,   support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,  flash_attn_ext_vec_q5_0_h128,   support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,  flash_attn_ext_vec_q5_1_h128,   support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,  flash_attn_ext_vec_q8_0_h128,   support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,   flash_attn_ext_vec_f16_h256,    support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,  flash_attn_ext_vec_q4_0_h256,   support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,  flash_attn_ext_vec_q4_1_h256,   support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,  flash_attn_ext_vec_q5_0_h256,   support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,  flash_attn_ext_vec_q5_1_h256,   support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,  flash_attn_ext_vec_q8_0_h256,   support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16,                   cpy_f32_f16,                    true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32,                   cpy_f32_f32,                    true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16,                   cpy_f16_f16,                    true);
@@ -869,13 +949,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_LEAKY_RELU:
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
-            if (op->src[1]->type != GGML_TYPE_F16) {
-                return false;
-            }
-            if (op->src[2]->type != GGML_TYPE_F16) {
-                return false;
-            }
-            if (op->src[0]->ne[0] == 256) {
+            if (op->src[1]->type != op->src[2]->type) {
                 return false;
             }
             return support_simdgroup_mm; // TODO: over-restricted for vec-kernels
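
With the relaxed check, the only hard requirement on the KV cache is that K and V share a type; whether a given quantized type is actually supported is then decided by the kernel table. A minimal host-side sketch of driving the new path, assuming the public ggml API of this period (shapes and the Q8_0 choice are illustrative only):

    #include <math.h>
    #include "ggml.h"

    // Sketch: build a FLASH_ATTN_EXT op over a quantized KV cache.
    static struct ggml_tensor * build_fattn_q8(struct ggml_context * ctx,
                                               struct ggml_tensor  * q, // GGML_TYPE_F32, per the assert in this diff
                                               int64_t D, int64_t n_kv, int64_t n_head_kv) {
        // n_kv must be a multiple of 32 (see the ne11 assert below);
        // K and V may now be quantized, as long as their types match
        struct ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_Q8_0, D, n_kv, n_head_kv);
        struct ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_Q8_0, D, n_kv, n_head_kv);

        const float scale = 1.0f/sqrtf((float) D);

        // no mask, no ALiBi (max_bias = 0), no logit softcap
        return ggml_flash_attn_ext(ctx, q, k, v, NULL, scale, 0.0f, 0.0f);
    }
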
@@ -2822,6 +2896,7 @@ static void ggml_metal_encode_node(
                 GGML_ASSERT(ne11 % 32 == 0);
 
                 GGML_ASSERT(src0->type == GGML_TYPE_F32);
+                GGML_ASSERT(src1->type == src2->type);
 
                 GGML_ASSERT(ggml_are_same_shape (src1, src2));
 
@@ -2869,26 +2944,154 @@ static void ggml_metal_encode_node(
                 bool use_vec_kernel = false;
 
                 if (ne01 >= 4 || (ne00%128 != 0)) {
-                    switch (ne00) {
-                        case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
-                        case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
-                        case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
-                        case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
-                        case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
-                                  //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
+                    switch (src1->type) {
+                        case GGML_TYPE_F16:
+                            {
+                                switch (ne00) {
+                                    case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
+                                    case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
+                                    case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
+                                    case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
+                                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
+                                    case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
+                                    default:
+                                              {
+                                                  GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                                  GGML_LOG_ERROR("add template specialization for this size\n");
+                                                  GGML_ABORT("add template specialization for this size");
+                                              }
+                                }
+                            } break;
+                        case GGML_TYPE_Q4_0:
+                            {
+                                switch (ne00) {
+                                    case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
+                                    case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break;
+                                    case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break;
+                                    case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112].pipeline; break;
+                                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128].pipeline; break;
+                                    case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256].pipeline; break;
+                                    default:
+                                              {
+                                                  GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                                  GGML_LOG_ERROR("add template specialization for this size\n");
+                                                  GGML_ABORT("add template specialization for this size");
+                                              }
+                                }
+                            } break;
+                        case GGML_TYPE_Q4_1:
+                            {
+                                switch (ne00) {
+                                    case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
+                                    case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break;
+                                    case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break;
+                                    case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112].pipeline; break;
+                                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128].pipeline; break;
+                                    case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256].pipeline; break;
+                                    default:
+                                              {
+                                                  GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                                  GGML_LOG_ERROR("add template specialization for this size\n");
+                                                  GGML_ABORT("add template specialization for this size");
+                                              }
+                                }
+                            } break;
+                        case GGML_TYPE_Q5_0:
+                            {
+                                switch (ne00) {
+                                    case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
+                                    case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break;
+                                    case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break;
+                                    case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112].pipeline; break;
+                                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128].pipeline; break;
+                                    case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256].pipeline; break;
+                                    default:
+                                              {
+                                                  GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                                  GGML_LOG_ERROR("add template specialization for this size\n");
+                                                  GGML_ABORT("add template specialization for this size");
+                                              }
+                                }
+                            } break;
+                        case GGML_TYPE_Q5_1:
+                            {
+                                switch (ne00) {
+                                    case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
+                                    case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break;
+                                    case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break;
+                                    case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112].pipeline; break;
+                                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128].pipeline; break;
+                                    case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256].pipeline; break;
+                                    default:
+                                              {
+                                                  GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                                  GGML_LOG_ERROR("add template specialization for this size\n");
+                                                  GGML_ABORT("add template specialization for this size");
+                                              }
+                                }
+                            } break;
+                        case GGML_TYPE_Q8_0:
+                            {
+                                switch (ne00) {
+                                    case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
+                                    case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break;
+                                    case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break;
+                                    case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break;
+                                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break;
+                                    case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break;
+                                    default:
+                                              {
+                                                  GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                                  GGML_LOG_ERROR("add template specialization for this size\n");
+                                                  GGML_ABORT("add template specialization for this size");
+                                              }
+                                }
+                            } break;
                         default:
-                                  {
-                                      GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
-                                      GGML_LOG_ERROR("add template specialization for this size\n");
-                                      GGML_ABORT("add template specialization for this size");
-                                  }
+                            {
+                                GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+                                GGML_LOG_ERROR("add template specialization for this type\n");
+                                GGML_ABORT("add template specialization for this type");
+                            }
                     }
                 } else {
                     use_vec_kernel = true;
 
                     switch (ne00) {
-                        case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
-                                  //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+                        case 128:
+                            {
+                                switch (src1->type) {
+                                    case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
+                                    case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break;
+                                    case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break;
+                                    case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break;
+                                    case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128].pipeline; break;
+                                    case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128].pipeline; break;
+                                    default:
+                                        {
+                                            GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+                                            GGML_LOG_ERROR("add template specialization for this type\n");
+                                            GGML_ABORT("add template specialization for this type");
+                                        }
+                                }
+                            } break;
+                        case 256:
+                            {
+                                switch (src1->type) {
+                                    case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+                                    case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break;
+                                    case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break;
+                                    case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break;
+                                    case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256].pipeline; break;
+                                    case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256].pipeline; break;
+                                    default:
+                                        {
+                                            GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+                                            GGML_LOG_ERROR("add template specialization for this type\n");
+                                            GGML_ABORT("add template specialization for this type");
+                                        }
+                                }
+                            } break;
                         default:
                                   {
                                       GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
@@ -2942,10 +3145,16 @@ static void ggml_metal_encode_node(
                     GGML_ASSERT(nqptg  % 8  == 0);
                     GGML_ASSERT(ncpsg  % 32 == 0);
 
+                    // 16*32*(nsg)
+                    // the shared memory needed for the simdgroups to load the KV cache
+                    // each thread loads (dequantizes) 16 head elements, there are 32 threads in the SG
+                    //
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
+
                     int64_t nsgmax = 2;
 
                     while (true) {
-                        const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2);
+                        const size_t smem = FATTN_SMEM(nsgmax);
                         if (smem > device.maxThreadgroupMemoryLength) {
                             break;
                         }
@@ -2956,16 +3165,15 @@ static void ggml_metal_encode_node(
                     // simdgroups per threadgroup (a.k.a. warps)
                     const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
 
-                    const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);
+                    const size_t smem = FATTN_SMEM(nsg);
 
-                    //printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
+                    //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
                     GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
-
-                    [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
-
+                    [encoder setThreadgroupMemoryLength:smem atIndex:0];
+#undef FATTN_SMEM
                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
                 } else {
-                    // half1x4 kernel
+                    // half4x4 kernel
                     const int64_t nqptg = 1;  // queries per threadgroup    !! sync with kernel template arguments !!
                     const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
 
@@ -2973,8 +3181,28 @@ static void ggml_metal_encode_node(
                     GGML_ASSERT(nqptg  % 1  == 0);
                     GGML_ASSERT(ncpsg  % 32 == 0);
 
+                    // ne00 + 2*ncpsg*(nsg)
+                    // for each query, we load it as f16 in shared memory (ne00)
+                    // and store the attention scores (nqptg x ncpsg) as f32
+                    //
+                    // 2*ne00*(nsg)
+                    // each simdgroup has a full f32 head vector in shared mem to accumulate results
+                    //
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + 2*ne00*(nsg))*(sizeof(float)/2), 16))
+
+                    int64_t nsgmax = 2;
+
+                    while (true) {
+                        const size_t smem = FATTN_SMEM(nsgmax);
+                        if (smem > device.maxThreadgroupMemoryLength) {
+                            break;
+                        }
+                        nsgmax *= 2;
+                    }
+                    nsgmax /= 2;
+
                     // simdgroups per threadgroup (a.k.a. warps)
-                    const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
+                    const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
 
                     int64_t nsg = 1;
                     while (nsg <= nsgt) {
@@ -2982,12 +3210,12 @@ static void ggml_metal_encode_node(
                     }
                     nsg /= 2;
 
-                    const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
+                    const size_t smem = FATTN_SMEM(nsg);
 
-                    //printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
+                    //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
                     GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
-                    [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
-
+                    [encoder setThreadgroupMemoryLength:smem atIndex:0];
+#undef FATTN_SMEM
                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
                 }
             } break;
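
The vec-kernel budget under the same illustrative assumptions (ne00 = 128, nqptg = 1, ncpsg = 32, nsg = 4):

    //   halves = nqptg*(ne00 + 2*ncpsg*nsg) + 2*ne00*nsg
    //          = 1*(128 + 256) + 1024
    //          = 1408
    //   bytes  = 1408*2 = 2816 (a multiple of 16)
    //
    // The while-loop above doubles nsgmax until FATTN_SMEM(nsgmax) no longer
    // fits in device.maxThreadgroupMemoryLength, then halves it once, so nsg
    // is capped at the largest power of two whose shared memory still fits.
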
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index ff9d3749042be09bebf6bc55c857510eb8eba266..b9ea9f08ed0667ef5b4835b17475e183ea04ee89 100644
@@ -2723,46 +2723,10 @@ kernel void kernel_leaky_relu_f32(
     dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
 }
 
-typedef void (flash_attn_ext_f16_t)(
-        device const  char * q,
-        device const  char * k,
-        device const  char * v,
-        device const  char * mask,
-        device       float * dst,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne03,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant   int64_t & ne13,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant  uint64_t & nb21,
-        constant  uint64_t & nb22,
-        constant  uint64_t & nb23,
-        constant  uint64_t & nb31,
-        constant   int64_t & ne1,
-        constant   int64_t & ne2,
-        constant     float & scale,
-        constant     float & max_bias,
-        constant     float & m0,
-        constant     float & m1,
-        constant  uint32_t & n_head_log2,
-        constant     float & logit_softcap,
-        threadgroup   half * shared,
-        uint3  tgpig[[threadgroup_position_in_grid]],
-        uint3  tpitg[[thread_position_in_threadgroup]],
-        uint3    ntg[[threads_per_threadgroup]],
-        ushort tiisg[[thread_index_in_simdgroup]],
-        ushort sgitg[[simdgroup_index_in_threadgroup]]);
-
 // ref: https://arxiv.org/pdf/2307.08691.pdf
-template<int64_t D, int64_t Q = 8, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
-kernel void kernel_flash_attn_ext_f16(
+// D - head size, Q - queries per threadgroup, KV - key/value processed per simdgroup, C - cache items per threadgroup
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &), short D, short Q = 8, short KV = 8, short C = 32>
+kernel void kernel_flash_attn_ext(
         device const  char * q,
         device const  char * k,
         device const  char * v,
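
The new template threads three quantization parameters through every kernel: the block type block_q, a dequantize callback that expands part of a block into one half4x4 tile (16 values), and nl, the number of such tiles per block, used to split a linear tile index i into a block pointer (i/nl) and a sub-block index (i%nl). A simplified sketch of a conforming callback, assuming the usual 32-weight Q8_0 block layout (the real implementations live elsewhere in this file):

    // illustrative only -- a Q8_0 block holds 32 int8 weights and one scale,
    // so nl = 2: il in {0, 1} selects which 16-weight half fills the tile
    void dequantize_q8_0_sketch(device const block_q8_0 * xb, short il, thread half4x4 & reg) {
        device const int8_t * qs = (device const int8_t *) xb->qs;
        const half d = xb->d;

        for (short i = 0; i < 16; ++i) {
            reg[i/4][i%4] = d * qs[16*il + i]; // 16 dequantized values per call
        }
    }

This matches the instantiations at the bottom of the file: f16 uses nl = 1 (one half4x4 per "block"), while all the 32-weight quant formats use nl = 2.
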
@@ -2800,15 +2764,15 @@ kernel void kernel_flash_attn_ext_f16(
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
     const short nsg = ntg.y; // number of simdgroups
 
-    const short iq3 = tgpig[2];
-    const short iq2 = tgpig[1];
-    const short iq1 = tgpig[0]*Q;
+    const int iq3 = tgpig[2];
+    const int iq2 = tgpig[1];
+    const int iq1 = tgpig[0]*Q;
 
-    const short D4 = D/4;
-    const short D8 = D/8;
-  //const short Q8 = Q/8;
-    const short NW = N_SIMDWIDTH;
-    const short SH = (C + Q); // shared memory per simdgroup in (half)
+    const short D4  = D/4;
+    const short D8  = D/8;
+    const short D16 = D/16;
+    const short NW  = N_SIMDWIDTH;
+    const short SH  = (C + Q); // shared memory per simdgroup in (half)
 
     const short T  = D + 2*nsg*SH; // shared memory size per query in (half)
     const short TF = T/2;        // shared memory size per query in (float)
@@ -2818,6 +2782,9 @@ kernel void kernel_flash_attn_ext_f16(
     threadgroup half4 * sq4 = (threadgroup half4 *) (shared +              0*D); // same as above but in half4
     threadgroup float * ss  = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
 
+    threadgroup half    * skv  = (threadgroup half    *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K and V in shared memory
+    threadgroup half4x4 * skv4 = (threadgroup half4x4 *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in half4x4
+
     // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
     simdgroup_half8x8 lo[D8];
 
@@ -2849,25 +2816,28 @@ kernel void kernel_flash_attn_ext_f16(
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     {
-        float S[Q] = { [0 ... Q-1] = 0.0h };
+        float S[Q] = { [0 ... Q-1] = 0.0f };
         float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 };
 
+        // thread indices inside the simdgroup
+        const short tx = tiisg%4;
+        const short ty = tiisg/4;
+
         // assume K and V are same shape
         const short ne22 = ne12;
         const short ne23 = ne13;
 
-        // broadcast
+        // broadcast k
         const short rk2 = ne02/ne12;
         const short rk3 = ne03/ne13;
 
-        const short rv2 = ne02/ne22;
-        const short rv3 = ne03/ne23;
-
-        // k indices
         const short ik2 = iq2/rk2;
         const short ik3 = iq3/rk3;
 
-        // v indices
+        // broadcast v
+        const short rv2 = ne02/ne22;
+        const short rv3 = ne03/ne23;
+
         const short iv2 = iq2/rv2;
         const short iv3 = iq3/rv3;
 
@@ -2906,13 +2876,59 @@ kernel void kernel_flash_attn_ext_f16(
                 for (short cc = 0; cc < C/8; ++cc) {
                     simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.h);
 
-                    device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
+                    // this is compile-time check, so it does not have runtime overhead
+                    if (is_same<block_q, half4x4>::value) {
+                        // we can read directly from global memory
+                        device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
+
+                        for (short i = 0; i < D8; ++i) {
+                            simdgroup_half8x8 mk;
+                            simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
+
+                            simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
+                        }
+                    } else {
+                        for (short ii = 0; ii < D16; ii += 4) {
+                            device const block_q * pk4 = (device const block_q *) ((device const char *) k + ((ic + 8*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));
+
+                            if (D16%4 == 0) {
+                                // the head is evenly divisible by 4*16 = 64, so no need for bound checks
+                                half4x4 tmp;
+                                dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
+                                skv4[4*ty + tx] = tmp;
 
-                    for (short i = 0; i < D8; ++i) {
-                        simdgroup_half8x8 mk;
-                        simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
+                                simdgroup_barrier(mem_flags::mem_threadgroup);
 
-                        simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
+#pragma unroll
+                                for (short k = 0; k < 4; ++k) {
+                                    simdgroup_half8x8 mk;
+
+                                    simdgroup_load(mk, skv + 16*k + 0*8, 4*16, 0, true); // transpose
+                                    simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
+
+                                    simdgroup_load(mk, skv + 16*k + 1*8, 4*16, 0, true); // transpose
+                                    simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
+                                }
+                            } else {
+                                if (ii + tx < D16) {
+                                    half4x4 tmp;
+                                    dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
+                                    skv4[4*ty + tx] = tmp;
+                                }
+
+                                simdgroup_barrier(mem_flags::mem_threadgroup);
+
+                                for (short k = 0; k < 4 && ii + k < D16; ++k) {
+                                    simdgroup_half8x8 mk;
+
+                                    simdgroup_load(mk, skv + 16*k + 0*8, 4*16, 0, true); // transpose
+                                    simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
+
+                                    simdgroup_load(mk, skv + 16*k + 1*8, 4*16, 0, true); // transpose
+                                    simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
+                                }
+                            }
+                        }
                     }
 
                     simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
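
The is_same<block_q, half4x4>::value guard is a compile-time constant, so each template instantiation keeps only one side of the branch: the f16 kernels behave exactly as before, reading K straight from device memory, while the quantized kernels stage tiles through threadgroup memory. The pattern in isolation, with stand-in types and a hand-rolled trait for self-containment:

    #include <cstdio>

    // stand-ins for the Metal types, just to make the sketch compile on its own
    struct half4x4    {};
    struct block_q8_0 {};

    template <typename T, typename U> struct is_same      { static const bool value = false; };
    template <typename T>             struct is_same<T,T> { static const bool value = true;  };

    template <typename block_q>
    void qk_step() {
        if (is_same<block_q, half4x4>::value) {
            std::printf("f16 path: simdgroup_load K directly from device memory\n");
        } else {
            std::printf("quantized path: dequantize K into threadgroup memory first\n");
        }
    }

    int main() {
        qk_step<half4x4>();    // the else-branch is dead code here and is eliminated
        qk_step<block_q8_0>(); // and vice versa
        return 0;
    }
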
@@ -2977,16 +2993,61 @@ kernel void kernel_flash_attn_ext_f16(
             // O = O + (Q*K^T)*V
             {
                 for (short cc = 0; cc < C/8; ++cc) {
-                    device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
+                    simdgroup_float8x8 ms;
+                    simdgroup_load(ms, ss + 8*cc, TF, 0, false);
+
+                    if (is_same<block_q, half4x4>::value) {
+                        // we can read directly from global memory
+                        device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
+#pragma unroll
+                        for (short i = 0; i < D8; ++i) {
+                            simdgroup_half8x8 mv;
+                            simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false);
 
-                    for (short i = 0; i < D8; ++i) {
-                        simdgroup_half8x8 mk;
-                        simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false);
+                            simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
+                        }
+                    } else {
+                        for (short ii = 0; ii < D16; ii += 4) {
+                            device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 8*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
+
+                            if (D16%4 == 0) {
+                                // no need for bound checks
+                                half4x4 tmp;
+                                dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
+                                skv4[4*ty + tx] = tmp;
 
-                        simdgroup_float8x8 mv;
-                        simdgroup_load(mv, ss + 8*cc, TF, 0, false);
+                                simdgroup_barrier(mem_flags::mem_threadgroup);
 
-                        simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]);
+#pragma unroll
+                                for (short k = 0; k < 4; ++k) {
+                                    simdgroup_half8x8 mv;
+
+                                    simdgroup_load(mv, skv + 16*k + 0*8, 4*16, 0, false);
+                                    simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
+
+                                    simdgroup_load(mv, skv + 16*k + 1*8, 4*16, 0, false);
+                                    simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
+                                }
+                            } else {
+                                if (ii + tx < D16) {
+                                    half4x4 tmp;
+                                    dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
+                                    skv4[4*ty + tx] = tmp;
+                                }
+
+                                simdgroup_barrier(mem_flags::mem_threadgroup);
+
+                                for (short k = 0; k < 4 && ii + k < D16; ++k) {
+                                    simdgroup_half8x8 mv;
+
+                                    simdgroup_load(mv, skv + 16*k + 0*8, 4*16, 0, false);
+                                    simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
+
+                                    simdgroup_load(mv, skv + 16*k + 1*8, 4*16, 0, false);
+                                    simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
+                                }
+                            }
+                        }
                     }
                 }
             }
@@ -3003,7 +3064,7 @@ kernel void kernel_flash_attn_ext_f16(
 
     // reduce the warps sequentially
     for (short sg = 1; sg < nsg; ++sg) {
-        float S = { 0.0h };
+        float S = { 0.0f };
         float M = { -FLT_MAX/2 };
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -3082,15 +3143,54 @@ kernel void kernel_flash_attn_ext_f16(
     }
 }
 
-template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>;
-template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>;
-template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
-template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
-template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
-//template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
-
-template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
-kernel void kernel_flash_attn_ext_vec_f16(
+typedef decltype(kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 64>) flash_attn_ext_t;
+
+template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 64>;
+template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 80>;
+template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 96>;
+template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 112>;
+template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 128>;
+template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 256>;
+
+template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 64>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 80>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 96>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 112>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 128>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 256>;
+
+template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 64>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 80>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 96>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 112>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 128>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 256>;
+
+template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 64>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 80>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 96>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 112>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 128>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 256>;
+
+template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 64>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 80>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 96>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 112>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 128>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 256>;
+
+template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 64>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 80>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 96>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 112>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 128>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 256>;
+
+// NOTE: can use half instead of float precision for some extra perf
+// D - head size, Q - queries per threadgroup, C - cache items per threadgroup
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &), short D, short Q = 1, short C = 32>
+kernel void kernel_flash_attn_ext_vec(
         device const  char * q,
         device const  char * k,
         device const  char * v,
@@ -3128,36 +3228,27 @@ kernel void kernel_flash_attn_ext_vec_f16(
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
     const short nsg = ntg.y; // number of simdgroups
 
-    const short iq3 = tgpig[2];
-    const short iq2 = tgpig[1];
-    const short iq1 = tgpig[0];
+    const int iq3 = tgpig[2];
+    const int iq2 = tgpig[1];
+    const int iq1 = tgpig[0];
 
-    const short D4 = D/4;
-    const short NW = N_SIMDWIDTH;
-    const short SH = (C + Q); // shared memory per simdgroup in (half)
+    const short D4  = D/4;
+    const short D16 = D/16;
+    const short NW  = N_SIMDWIDTH;
+    const short NW4 = NW/4;
+    const short SH  = C; // shared memory per simdgroup in (half)
 
     const short T  = D + 2*nsg*SH; // shared memory size per query in (half)
 
-    float slope = 1.0f;
-
-    // ALiBi
-    if (max_bias > 0.0f) {
-        const uint32_t h = iq2;
-
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
-        slope = pow(base, exp);
-    }
-
-  //threadgroup half   * sq  = (threadgroup half   *) (shared +              0*D); // holds the query data
-    threadgroup half4  * sq4 = (threadgroup half4  *) (shared +              0*D); // same as above but in half4
-    threadgroup float  * ss  = (threadgroup float  *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
-    threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4
-    threadgroup half4  * sr4 = (threadgroup half4  *) (shared +   sgitg*D  + 1*T); // scratch buffer for the results
+  //threadgroup half     * sq   = (threadgroup half     *) (shared +              0*D); // holds the query data
+    threadgroup half4    * sq4  = (threadgroup half4    *) (shared +              0*D); // same as above but in half4
+    threadgroup half4x4  * sq44 = (threadgroup half4x4  *) (shared +              0*D); // same as above but in half4x4
+    threadgroup float    * ss   = (threadgroup float    *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention
+    threadgroup float4   * ss4  = (threadgroup float4   *) (shared + 2*sgitg*SH + 1*D); // same as above but in float4
+    threadgroup float4x4 * sr44 = (threadgroup float4x4 *) (shared + 2*sgitg*D  + Q*T); // scratch buffer for the results
 
     // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
-    half4 lo[D4/NW];
+    float4x4 lo[D16/NW4];
 
     // load heads from Q to shared memory
     device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
@@ -3171,8 +3262,8 @@ kernel void kernel_flash_attn_ext_vec_f16(
     }
 
     // zero out lo
-    for (short i = tiisg; i < D4; i += NW) {
-        lo[i/NW] = 0.0h;
+    for (short i = 0; i < D16/NW4; i += NW4) {
+        lo[i] = float4x4(0.0f);
     }
 
     // zero out shared memory SH
@@ -3183,38 +3274,52 @@ kernel void kernel_flash_attn_ext_vec_f16(
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     {
-        float S = { 0.0h };
-        float M = { -FLT_MAX/2 };
+        float S = 0.0f;
+        float M = -FLT_MAX/2;
+
+        // thread indices inside the simdgroup
+        const short tx = tiisg%8;
+        const short ty = tiisg/8;
 
         // assume K and V are same shape
         const short ne22 = ne12;
         const short ne23 = ne13;
 
-        // broadcast
+        // broadcast k
         const short rk2 = ne02/ne12;
         const short rk3 = ne03/ne13;
 
+        const short ik2 = iq2/rk2;
+        const short ik3 = iq3/rk3;
+
+        // broadcast v
         const short rv2 = ne02/ne22;
         const short rv3 = ne03/ne23;
 
-        // k indices
-        const short ik2 = iq2 / rk2;
-        const short ik3 = iq3 / rk3;
-
-        // v indices
-        const short iv2 = iq2 / rv2;
-        const short iv3 = iq3 / rv3;
+        const short iv2 = iq2/rv2;
+        const short iv3 = iq3/rv3;
 
         // load the queries from shared memory into local memory
-        float4 mq[D4/NW];
+        float4x4 mq[D16/NW4];
 
-        for (short ii = 0; ii < D4; ii += NW) {
-            short i = ii + tiisg;
-            mq[ii/NW] = (float4) sq4[i];
+        for (short ii = 0; ii < D16; ii += NW4) {
+            mq[ii/NW4] = (float4x4) sq44[ii + tx];
         }
 
         // pointer to the mask
-        device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
+        device const half * mp = (device const half *) (mask + iq1*nb31);
+
+        float slope = 1.0f;
+
+        // ALiBi
+        if (max_bias > 0.0f) {
+            const uint32_t h = iq2;
+
+            const float base = h < n_head_log2 ? m0 : m1;
+            const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+            slope = pow(base, exp);
+        }
 
         // loop over the KV cache
         // each simdgroup handles blocks of Q rows and C columns
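
For context on the slope computation just moved into this scope: m0, m1 and n_head_log2 are supplied by the host. A hypothetical helper showing how they relate to max_bias, mirroring ggml's usual ALiBi parameterization (treat the exact formulas as an assumption -- they are set on the host, not in this diff):

    #include <math.h>
    #include <stdint.h>

    static void alibi_params(int32_t n_head, float max_bias,
                             float * m0, float * m1, uint32_t * n_head_log2) {
        *n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));

        *m0 = powf(2.0f, -(max_bias       ) / *n_head_log2);
        *m1 = powf(2.0f, -(max_bias / 2.0f) / *n_head_log2);
    }

    // per-head slope, exactly as computed in the kernel above:
    //   h <  n_head_log2 : slope = m0^(h + 1)
    //   h >= n_head_log2 : slope = m1^(2*(h - n_head_log2) + 1)
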
@@ -3226,47 +3331,54 @@ kernel void kernel_flash_attn_ext_vec_f16(
 
             // Q*K^T
             {
-#pragma unroll
+                // each simdgroup processes 1 query and 4 keys
                 for (short cc = 0; cc < C/4; ++cc) {
-                    float4 mqk = { 0.0h };
+                    float mqk = 0.0;
 
-                    device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
+                    device const block_q * pk = (device const block_q *) ((device const char *) k + ((ic + 4*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));
 
 #pragma unroll
-                    for (short ii = 0; ii < D4; ii += NW) {
-                        const short i = ii + tiisg;
+                    for (short ii = 0; ii < D16; ii += NW4) {
+                        const short i = ii + tx;
 
                         float4x4 mk;
-                        mk[0] = (float4) pk4[i + 0*(nb11/8)];
-                        mk[1] = (float4) pk4[i + 1*(nb11/8)];
-                        mk[2] = (float4) pk4[i + 2*(nb11/8)];
-                        mk[3] = (float4) pk4[i + 3*(nb11/8)];
+                        dequantize_func(pk + i/nl, i%nl, mk);
 
-                        mqk += (float4) (mq[ii/NW] * mk);
+                        mqk +=
+                            dot(mq[ii/NW4][0], mk[0]) +
+                            dot(mq[ii/NW4][1], mk[1]) +
+                            dot(mq[ii/NW4][2], mk[2]) +
+                            dot(mq[ii/NW4][3], mk[3]);
                     }
 
-                    // reduce the results from the threads in the simdgroup
-                    mqk += simd_shuffle_down(mqk, 16);
-                    mqk += simd_shuffle_down(mqk,  8);
+                    // simdgroup reduce
+                    // [ 0 ..  7] -> [ 0]
+                    // [ 8 .. 15] -> [ 8]
+                    // [16 .. 23] -> [16]
+                    // [24 .. 31] -> [24]
+                  //mqk += simd_shuffle_down(mqk, 16);
+                  //mqk += simd_shuffle_down(mqk,  8);
                     mqk += simd_shuffle_down(mqk,  4);
                     mqk += simd_shuffle_down(mqk,  2);
                     mqk += simd_shuffle_down(mqk,  1);
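+                    // only offsets 4/2/1 are needed here: each key is owned by
+                    // one group of 8 consecutive lanes (fixed ty), so the 16/8
+                    // steps above are skipped since they would mix partial sums
+                    // belonging to different keys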
 
                     // mqk = mqk*scale + mask*slope
-                    if (tiisg == 0) {
+                    if (tx == 0) {
                         mqk *= scale;
 
                         if (logit_softcap != 0.0f) {
                             mqk = logit_softcap*precise::tanh(mqk);
                         }
 
-                        mqk += (mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f;
+                        mqk += (mask != q) ? ((float) mp[ic + 4*cc + ty])*slope : (float) 0.0f;
 
-                        ss4[cc] = mqk;
+                        ss[4*cc + ty] = mqk;
                     }
                 }
             }
 
+            simdgroup_barrier(mem_flags::mem_threadgroup);
+
             // online softmax
             {
                 const short p = tiisg;
@@ -3286,29 +3398,32 @@ kernel void kernel_flash_attn_ext_vec_f16(
 
                 // O = diag(ms)*O
 #pragma unroll
-                for (short ii = 0; ii < D4; ii += NW) {
-                    lo[ii/NW] *= ms;
+                for (short ii = 0; ii < D16; ii += NW4) {
+                    lo[ii/NW4] *= ms;
                 }
             }
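+            // this is the usual streaming-softmax update: with M' = max(M, m)
+            // and ms = exp(M - M'), both the running sum and the accumulated
+            // rows of O are rescaled by ms before the new block is folded in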
 
+            simdgroup_barrier(mem_flags::mem_threadgroup);
+
             // O = O + (Q*K^T)*V
             {
 #pragma unroll
                 for (short cc = 0; cc < C/4; ++cc) {
-                    device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23));
+                    device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 4*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
+
+                    const float4x4 lss(ss[4*cc + ty]);
 
 #pragma unroll
-                    for (short ii = 0; ii < D4; ii += NW) {
-                        const short i = ii + tiisg;
+                    for (short ii = 0; ii < D16; ii += NW4) {
+                        const short i = ii + tx;
+
+                        float4x4 mv;
+                        dequantize_func(pv4 + i/nl, i%nl, mv);
 
-                        lo[ii/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
-                        lo[ii/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
-                        lo[ii/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
-                        lo[ii/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
+                        lo[ii/NW4] += mv*lss;
                     }
                 }
             }
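+            // note: float4x4 lss(s) constructs diag(s,s,s,s), so mv*lss above
+            // scales each dequantized V tile by its key's softmax weight ss[4*cc + ty]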
-
         }
 
         // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
@@ -3318,10 +3433,32 @@ kernel void kernel_flash_attn_ext_vec_f16(
         }
     }
 
+    // simdgroup reduce
+    // [ 0,  8, 16, 24] -> [ 0]
+    // [ 1,  9, 17, 25] -> [ 1]
+    // [ 2, 10, 18, 26] -> [ 2]
+    // [ 3, 11, 19, 27] -> [ 3]
+    // [ 4, 12, 20, 28] -> [ 4]
+    // [ 5, 13, 21, 29] -> [ 5]
+    // [ 6, 14, 22, 30] -> [ 6]
+    // [ 7, 15, 23, 31] -> [ 7]
+    for (short ii = 0; ii < D16; ii += NW4) {
+        lo[ii/NW4][0] += simd_shuffle_down(lo[ii/NW4][0], 16);
+        lo[ii/NW4][0] += simd_shuffle_down(lo[ii/NW4][0],  8);
+
+        lo[ii/NW4][1] += simd_shuffle_down(lo[ii/NW4][1], 16);
+        lo[ii/NW4][1] += simd_shuffle_down(lo[ii/NW4][1],  8);
+
+        lo[ii/NW4][2] += simd_shuffle_down(lo[ii/NW4][2], 16);
+        lo[ii/NW4][2] += simd_shuffle_down(lo[ii/NW4][2],  8);
+
+        lo[ii/NW4][3] += simd_shuffle_down(lo[ii/NW4][3], 16);
+        lo[ii/NW4][3] += simd_shuffle_down(lo[ii/NW4][3],  8);
+    }
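+    // offsets 16 and 8 fold the four 8-lane key groups together, so after this
+    // loop lanes 0..7 hold the simdgroup's combined output tiles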
+
     // store results to shared memory
-    for (short ii = 0; ii < D4; ii += NW) {
-        short i = ii + tiisg;
-        sr4[i] = lo[ii/NW];
+    for (short i = tiisg; i < D16; i += NW4) {
+        sr44[i] = lo[i/NW4];
     }
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -3348,30 +3485,41 @@ kernel void kernel_flash_attn_ext_vec_f16(
             }
 
             // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
-            for (short ii = 0; ii < D4; ii += NW) {
-                short i = ii + tiisg;
-                sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1;
+            for (short i = tiisg; i < D16; i += NW) {
+                sr44[i] = sr44[i]*ms0 + sr44[i + r*D16]*ms1;
             }
         }
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
 
-    device float4 * dst4 = (device float4 *) dst;
+    device float4x4 * dst44 = (device float4x4 *) dst;
 
     // final rescale with 1/S and store to global memory
     if (sgitg == 0) {
         const float S = ss[0];
 
-        for (short ii = 0; ii < D4; ii += NW) {
-            short i = ii + tiisg;
-            dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S;
+        for (short i = tiisg; i < D16; i += NW) {
+            dst44[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D16 + i] = sr44[i]/S;
         }
     }
 }
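+// the final division by S == ss[0] applies the deferred softmax normalization,
+// completing O = softmax(scale*Q*K^T + slope*mask)*V (modulo logit_softcap)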
 
-template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
-//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
+typedef decltype(kernel_flash_attn_ext_vec<half4x4, 1, dequantize_f16, 64>) flash_attn_ext_vec_t;
+
+template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<half4x4,    1, dequantize_f16,  128>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_0, 2, dequantize_q4_0, 128>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_1, 2, dequantize_q4_1, 128>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_0, 2, dequantize_q5_0, 128>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_1, 2, dequantize_q5_1, 128>;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q8_0, 2, dequantize_q8_0, 128>;
+
+template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<half4x4,    1, dequantize_f16,  256>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_0, 2, dequantize_q4_0, 256>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_1, 2, dequantize_q4_1, 256>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_0, 2, dequantize_q5_0, 256>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_1, 2, dequantize_q5_1, 256>;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q8_0, 2, dequantize_q8_0, 256>;
 
 template<typename T0, typename T1>
 kernel void kernel_cpy(