]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
metal : template-ify some of the kernels (llama/8447)
authorGeorgi Gerganov <redacted>
Sat, 13 Jul 2024 15:32:33 +0000 (18:32 +0300)
committerGeorgi Gerganov <redacted>
Sat, 27 Jul 2024 15:26:12 +0000 (18:26 +0300)
ggml-ci

src/ggml-metal.m
src/ggml-metal.metal

index 79902c9a80616a9bb62af3f368284fb9d4981c5f..b5939efa6c279a7405c38b76a8693dad8c1d3a82 100644 (file)
@@ -193,16 +193,16 @@ enum ggml_metal_kernel_type {
   //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,     // https://github.com/ggerganov/llama.cpp/issues/7261
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
   //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
-    GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
     GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
-    GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
-    GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
     GGML_METAL_KERNEL_TYPE_CONCAT,
     GGML_METAL_KERNEL_TYPE_SQR,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
@@ -651,14 +651,14 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
       //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,   flash_attn_ext_vec_f16_h256,    ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16,                   cpy_f32_f16,                    true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32,                   cpy_f32_f32,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16,                   cpy_f16_f16,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32,                   cpy_f16_f32,                    true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,                  cpy_f32_q8_0,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,                  cpy_f32_q4_0,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,                  cpy_f32_q4_1,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,                  cpy_f32_q5_0,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,                  cpy_f32_q5_1,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,                cpy_f32_iq4_nl,                 true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16,                   cpy_f16_f16,                    true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32,                   cpy_f16_f32,                    true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT,                        concat,                         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR,                           sqr,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,                      sum_rows,                       true);
@@ -810,8 +810,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
                 switch (op->src[0]->type) {
                     case GGML_TYPE_F32:
                         switch (op->type) {
-                           case GGML_TYPE_F16:
                            case GGML_TYPE_F32:
+                           case GGML_TYPE_F16:
                            case GGML_TYPE_Q8_0:
                            case GGML_TYPE_Q4_0:
                            case GGML_TYPE_Q4_1:
@@ -824,8 +824,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
                         }
                     case GGML_TYPE_F16:
                         switch (op->type) {
-                           case GGML_TYPE_F16:
                            case GGML_TYPE_F32:
+                           case GGML_TYPE_F16:
                                 return true;
                            default:
                                 return false;
@@ -837,7 +837,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_GET_ROWS:
             {
-                return op->src[0]->type != GGML_TYPE_BF16 && op->ne[3] == 1;
+                return op->ne[3] == 1;
             }
         default:
             return false;
@@ -1580,8 +1580,8 @@ static enum ggml_status ggml_metal_graph_compute(
                             // some Metal matrix data types require aligned pointers
                             // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
                             switch (src0->type) {
-                                case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
-                                case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8  == 0); break;
+                                case GGML_TYPE_F32:  GGML_ASSERT(nb01 % 16 == 0); break;
+                                case GGML_TYPE_F16:  GGML_ASSERT(nb01 % 8  == 0); break;
                                 default: break;
                             }
 
@@ -2775,8 +2775,8 @@ static enum ggml_status ggml_metal_graph_compute(
                                     GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
 
                                     switch (dstt) {
-                                        case GGML_TYPE_F16:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline;  break;
-                                        case GGML_TYPE_F32:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;  break;
+                                        case GGML_TYPE_F32:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
+                                        case GGML_TYPE_F16:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
                                         case GGML_TYPE_Q8_0:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
                                         case GGML_TYPE_Q4_0:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
                                         case GGML_TYPE_Q4_1:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
@@ -2789,8 +2789,8 @@ static enum ggml_status ggml_metal_graph_compute(
                             case GGML_TYPE_F16:
                                 {
                                     switch (dstt) {
-                                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
-                                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
+                                        case GGML_TYPE_F32:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
+                                        case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
                                         default: GGML_ASSERT(false && "not implemented");
                                     };
                                 } break;
index c3503479b35bac3c1ad919b3b13d4e56f1df7176..2a3b0c0a69a74601f32accdad468d6b8d1aafeec 100644 (file)
@@ -1219,9 +1219,10 @@ kernel void kernel_mul_mv_q8_0_f32(
     kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
 }
 
-#define N_F32_F32 4
+#define N_MV_T_T 4
 
-void kernel_mul_mv_f32_f32_impl(
+template<typename T0, typename T04, typename T1, typename T14>
+void kernel_mul_mv_impl(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
@@ -1239,13 +1240,12 @@ void kernel_mul_mv_f32_f32_impl(
                   uint64_t   nb12,
                    int64_t   ne0,
                    int64_t   ne1,
-                     uint    r2,
-                     uint    r3,
-                     uint3   tgpig,
-                     uint    tiisg) {
-
+                   uint      r2,
+                   uint      r3,
+                   uint3     tgpig,
+                   uint      tiisg) {
     const int64_t r0 = tgpig.x;
-    const int64_t rb = tgpig.y*N_F32_F32;
+    const int64_t rb = tgpig.y*N_MV_T_T;
     const int64_t im = tgpig.z;
 
     const uint i12 = im%ne12;
@@ -1253,20 +1253,20 @@ void kernel_mul_mv_f32_f32_impl(
 
     const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
 
-    device const float * x = (device const float *) (src0 + offset0);
+    device const T0 * x = (device const T0 *) (src0 + offset0);
 
     if (ne00 < 128) {
-        for (int row = 0; row < N_F32_F32; ++row) {
+        for (int row = 0; row < N_MV_T_T; ++row) {
             int r1 = rb + row;
             if (r1 >= ne11) {
                 break;
             }
 
-            device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+            device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12);
 
             float sumf = 0;
             for (int i = tiisg; i < ne00; i += 32) {
-                sumf += (float) x[i] * (float) y[i];
+                sumf += (T0) x[i] * (T1) y[i];
             }
 
             float all_sum = simd_sum(sumf);
@@ -1275,32 +1275,32 @@ void kernel_mul_mv_f32_f32_impl(
             }
         }
     } else {
-        device const float4 * x4 = (device const float4 *)x;
-        for (int row = 0; row < N_F32_F32; ++row) {
+        device const T04 * x4 = (device const T04 *) x;
+        for (int row = 0; row < N_MV_T_T; ++row) {
             int r1 = rb + row;
             if (r1 >= ne11) {
                 break;
             }
 
-            device const float  * y  = (device const float  *) (src1 + r1*nb11 + im*nb12);
-            device const float4 * y4 = (device const float4 *) y;
+            device const T1  * y  = (device const T1  *) (src1 + r1*nb11 + im*nb12);
+            device const T14 * y4 = (device const T14 *) y;
 
             float sumf = 0;
             for (int i = tiisg; i < ne00/4; i += 32) {
-                for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+                for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
             }
 
             float all_sum = simd_sum(sumf);
             if (tiisg == 0) {
-                for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+                for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]);
                 dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
             }
         }
     }
 }
 
-[[host_name("kernel_mul_mv_f32_f32")]]
-kernel void kernel_mul_mv_f32_f32(
+template<typename T0, typename T04, typename T1, typename T14>
+kernel void kernel_mul_mv(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
@@ -1322,90 +1322,38 @@ kernel void kernel_mul_mv_f32_f32(
         constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]]) {
-    kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
+    kernel_mul_mv_impl<T0, T04, T1, T14>(
+        src0,
+        src1,
+        dst,
+        ne00,
+        ne01,
+        ne02,
+        nb00,
+        nb01,
+        nb02,
+        ne10,
+        ne11,
+        ne12,
+        nb10,
+        nb11,
+        nb12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg);
 }
 
-#define N_F16_F16 4
+typedef decltype(kernel_mul_mv<half, half4, half, half4>) mul_mv_t;
 
-kernel void kernel_mul_mv_f16_f16(
-        device const  char * src0,
-        device const  char * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]]) {
+template [[host_name("kernel_mul_mv_f32_f32")]]   kernel mul_mv_t kernel_mul_mv<float,  float4,  float,  float4>;
+template [[host_name("kernel_mul_mv_f16_f32")]]   kernel mul_mv_t kernel_mul_mv<half,   half4,   float,  float4>;
+template [[host_name("kernel_mul_mv_f16_f16")]]   kernel mul_mv_t kernel_mul_mv<half,   half4,   half,   half4>;
 
-    const int64_t r0 = tgpig.x;
-    const int64_t rb = tgpig.y*N_F16_F16;
-    const int64_t im = tgpig.z;
-
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
-
-    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
-
-    device const half * x = (device const half *) (src0 + offset0);
-
-    if (ne00 < 128) {
-        for (int row = 0; row < N_F16_F16; ++row) {
-            int r1 = rb + row;
-            if (r1 >= ne11) {
-                break;
-            }
-
-            device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
-
-            float sumf = 0;
-            for (int i = tiisg; i < ne00; i += 32) {
-                sumf += (half) x[i] * (half) y[i];
-            }
-
-            float all_sum = simd_sum(sumf);
-            if (tiisg == 0) {
-                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
-            }
-        }
-    } else {
-        device const half4 * x4 = (device const half4 *)x;
-        for (int row = 0; row < N_F16_F16; ++row) {
-            int r1 = rb + row;
-            if (r1 >= ne11) {
-                break;
-            }
-
-            device const half  * y  = (device const half  *) (src1 + r1*nb11 + im*nb12);
-            device const half4 * y4 = (device const half4 *) y;
-
-            float sumf = 0;
-            for (int i = tiisg; i < ne00/4; i += 32) {
-                for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k];
-            }
-
-            float all_sum = simd_sum(sumf);
-            if (tiisg == 0) {
-                for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i];
-                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
-            }
-        }
-    }
-}
-
-void kernel_mul_mv_f16_f32_1row_impl(
+template<typename T, typename T4>
+kernel void kernel_mul_mv_1row(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
@@ -1437,7 +1385,7 @@ void kernel_mul_mv_f16_f32_1row_impl(
 
     const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
 
-    device const half  * x = (device const half  *) (src0 + offset0);
+    device const T     * x = (device const T     *) (src0 + offset0);
     device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
 
     float sumf = 0;
@@ -1450,153 +1398,29 @@ void kernel_mul_mv_f16_f32_1row_impl(
             dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
         }
     } else {
-        device const half4  * x4 = (device const half4  *) x;
+        device const T4     * x4 = (device const T4     *) x;
         device const float4 * y4 = (device const float4 *) y;
+
         for (int i = tiisg; i < ne00/4; i += 32) {
-            for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k];
+            for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
         }
+
         float all_sum = simd_sum(sumf);
+
         if (tiisg == 0) {
-            for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+            for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]);
             dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
         }
     }
 }
 
-[[host_name("kernel_mul_mv_f16_f32_1row")]]
-kernel void kernel_mul_mv_f16_f32_1row(
-        device const  char * src0,
-        device const  char * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]]) {
-    kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
-}
-
-#define N_F16_F32 4
-
-void kernel_mul_mv_f16_f32_impl(
-        device const  char * src0,
-        device const  char * src1,
-        device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                  uint64_t   nb00,
-                  uint64_t   nb01,
-                  uint64_t   nb02,
-                   int64_t   ne10,
-                   int64_t   ne11,
-                   int64_t   ne12,
-                  uint64_t   nb10,
-                  uint64_t   nb11,
-                  uint64_t   nb12,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
-                   uint3     tgpig,
-                   uint      tiisg) {
-
-    const int64_t r0 = tgpig.x;
-    const int64_t rb = tgpig.y*N_F16_F32;
-    const int64_t im = tgpig.z;
-
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
-
-    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
-
-    device const half * x = (device const half *) (src0 + offset0);
-
-    if (ne00 < 128) {
-        for (int row = 0; row < N_F16_F32; ++row) {
-            int r1 = rb + row;
-            if (r1 >= ne11) {
-                break;
-            }
-
-            device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
-
-            float sumf = 0;
-            for (int i = tiisg; i < ne00; i += 32) {
-                sumf += (float) x[i] * (float) y[i];
-            }
-
-            float all_sum = simd_sum(sumf);
-            if (tiisg == 0) {
-                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
-            }
-        }
-    } else {
-        device const half4 * x4 = (device const half4 *)x;
-        for (int row = 0; row < N_F16_F32; ++row) {
-            int r1 = rb + row;
-            if (r1 >= ne11) {
-                break;
-            }
-
-            device const float  * y  = (device const float  *) (src1 + r1*nb11 + im*nb12);
-            device const float4 * y4 = (device const float4 *) y;
-
-            float sumf = 0;
-            for (int i = tiisg; i < ne00/4; i += 32) {
-                for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
-            }
+typedef decltype(kernel_mul_mv_1row<half, half4>) mul_mv_1row_t;
 
-            float all_sum = simd_sum(sumf);
-            if (tiisg == 0) {
-                for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
-                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
-            }
-        }
-    }
-}
-
-[[host_name("kernel_mul_mv_f16_f32")]]
-kernel void kernel_mul_mv_f16_f32(
-        device const  char * src0,
-        device const  char * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]]) {
-    kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
-}
+template [[host_name("kernel_mul_mv_f16_f32_1row")]]  kernel mul_mv_1row_t kernel_mul_mv_1row<half,   half4>;
 
 // Assumes row size (ne00) is a multiple of 4
-kernel void kernel_mul_mv_f16_f32_l4(
+template<typename T, typename T4>
+kernel void kernel_mul_mv_l4(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
@@ -1628,14 +1452,14 @@ kernel void kernel_mul_mv_f16_f32_l4(
 
     const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
 
-    device const half4 * x4 = (device const half4 *) (src0 + offset0);
+    device const T4 * x4 = (device const T4 *) (src0 + offset0);
 
     for (int r1 = 0; r1 < nrows; ++r1) {
         device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
 
         float sumf = 0;
         for (int i = tiisg; i < ne00/4; i += 32) {
-            for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+            for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
         }
 
         float all_sum = simd_sum(sumf);
@@ -1645,6 +1469,10 @@ kernel void kernel_mul_mv_f16_f32_l4(
     }
 }
 
+typedef decltype(kernel_mul_mv_l4<half, half4>) mul_mv_l4_t;
+
+template [[host_name("kernel_mul_mv_f16_f32_l4")]]  kernel mul_mv_l4_t kernel_mul_mv_l4<half, half4>;
+
 static float rope_yarn_ramp(const float low, const float high, const int i0) {
     const float y = (i0 / 2 - low) / max(0.001f, high - low);
     return 1.0f - min(1.0f, max(0.0f, y));
@@ -2765,91 +2593,10 @@ kernel void kernel_flash_attn_ext_vec_f16(
 template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
 //template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
 
-kernel void kernel_cpy_f16_f16(
-        device  const half * src0,
-        device        half * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne03,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   int64_t & ne2,
-        constant   int64_t & ne3,
-        constant  uint64_t & nb0,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
-        constant  uint64_t & nb3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig[2];
-    const int64_t i02 = tgpig[1];
-    const int64_t i01 = tgpig[0];
-
-    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-
-    const int64_t i3 = n / (ne2*ne1*ne0);
-    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
-    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
-    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
-
-    device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
-    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
-        device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
-        dst_data[i00] = src[0];
-    }
-}
-
-kernel void kernel_cpy_f16_f32(
-        device  const half * src0,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne03,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   int64_t & ne2,
-        constant   int64_t & ne3,
-        constant  uint64_t & nb0,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
-        constant  uint64_t & nb3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig[2];
-    const int64_t i02 = tgpig[1];
-    const int64_t i01 = tgpig[0];
-
-    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-
-    const int64_t i3 = n / (ne2*ne1*ne0);
-    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
-    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
-    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
-
-    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
-    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
-        device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
-        dst_data[i00] = src[0];
-    }
-}
-
-kernel void kernel_cpy_f32_f16(
-        device const float * src0,
-        device        half * dst,
+template<typename T0, typename T1>
+kernel void kernel_cpy(
+        device  const void * src0,
+        device        void * dst,
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
@@ -2880,56 +2627,20 @@ kernel void kernel_cpy_f32_f16(
     const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
     const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
 
-    device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+    device T1 * dst_data = (device T1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
 
     for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
-        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
-
-        dst_data[i00] = src[0];
+        device const T0 * src = (device T0 *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        dst_data[i00] = (T1) src[0];
     }
 }
 
-kernel void kernel_cpy_f32_f32(
-        device const float * src0,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne03,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   int64_t & ne2,
-        constant   int64_t & ne3,
-        constant  uint64_t & nb0,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
-        constant  uint64_t & nb3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig[2];
-    const int64_t i02 = tgpig[1];
-    const int64_t i01 = tgpig[0];
-
-    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
 
-    const int64_t i3 = n / (ne2*ne1*ne0);
-    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
-    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
-    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
-
-    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
-    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
-        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
-
-        dst_data[i00] = src[0];
-    }
-}
+template [[host_name("kernel_cpy_f32_f32")]]  kernel kernel_cpy_t kernel_cpy<float,  float>;
+template [[host_name("kernel_cpy_f32_f16")]]  kernel kernel_cpy_t kernel_cpy<float,  half>;
+template [[host_name("kernel_cpy_f16_f16")]]  kernel kernel_cpy_t kernel_cpy<half,   half>;
+template [[host_name("kernel_cpy_f16_f32")]]  kernel kernel_cpy_t kernel_cpy<half,   float>;
 
 kernel void kernel_cpy_f32_q8_0(
         device const float * src0,
@@ -5730,9 +5441,9 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4
 }
 
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
-kernel void kernel_get_rows(
+kernel void kernel_get_rows_q(
         device const  void * src0,
-        device const  char * src1,
+        device const  void * src1,
         device       float * dst,
         constant   int64_t & ne00,
         constant  uint64_t & nb01,
@@ -5745,55 +5456,24 @@ kernel void kernel_get_rows(
         uint3                tgpig[[threadgroup_position_in_grid]],
         uint                 tiitg[[thread_index_in_threadgroup]],
         uint3                tptg [[threads_per_threadgroup]]) {
-    //const int64_t i = tgpig;
-    //const int64_t r = ((device int32_t *) src1)[i];
-
     const int64_t i10 = tgpig.x;
     const int64_t i11 = tgpig.y;
 
-    const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
 
     const int64_t i02 = i11;
 
     for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
         float4x4 temp;
-        dequantize_func(
-            ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
+        dequantize_func(((device const block_q *) ((const device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
         *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
     }
 }
 
-kernel void kernel_get_rows_f32(
-        device const  void * src0,
-        device const  char * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant   int64_t & ne10,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
-        uint3                tgpig[[threadgroup_position_in_grid]],
-        uint                 tiitg[[thread_index_in_threadgroup]],
-        uint3                tptg [[threads_per_threadgroup]]) {
-    const int64_t i10 = tgpig.x;
-    const int64_t i11 = tgpig.y;
-
-    const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
-
-    const int64_t i02 = i11;
-
-    for (int ind = tiitg; ind < ne00; ind += tptg.x) {
-        ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
-            ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
-    }
-}
-
-kernel void kernel_get_rows_f16(
+template<typename T>
+kernel void kernel_get_rows_f(
         device const  void * src0,
-        device const  char * src1,
+        device const  void * src1,
         device       float * dst,
         constant   int64_t & ne00,
         constant  uint64_t & nb01,
@@ -5809,19 +5489,19 @@ kernel void kernel_get_rows_f16(
     const int64_t i10 = tgpig.x;
     const int64_t i11 = tgpig.y;
 
-    const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
 
     const int64_t i02 = i11;
 
     for (int ind = tiitg; ind < ne00; ind += tptg.x) {
-        ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
-            ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
+        ((      device float *) ((      device char *)  dst + i11*nb2  + i10*nb1))[ind] =
+        ((const device T     *) ((const device char *) src0 + i02*nb02 +  r*nb01))[ind];
     }
 }
 
 kernel void kernel_get_rows_i32(
         device const  void * src0,
-        device const  char * src1,
+        device const  void * src1,
         device     int32_t * dst,
         constant   int64_t & ne00,
         constant  uint64_t & nb01,
@@ -5837,13 +5517,13 @@ kernel void kernel_get_rows_i32(
     const int64_t i10 = tgpig.x;
     const int64_t i11 = tgpig.y;
 
-    const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
 
     const int64_t i02 = i11;
 
     for (int ind = tiitg; ind < ne00; ind += tptg.x) {
-        ((device int32_t *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
-            ((device int32_t *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
+        ((      device int32_t *) ((      device char *) dst  + i11*nb2 + i10*nb1))[ind] =
+        ((const device int32_t *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
     }
 }
 
@@ -5860,28 +5540,28 @@ kernel void kernel_get_rows_i32(
 #define SG_MAT_ROW 8
 
 // each block_q contains 16*nl weights
-template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
-void kernel_mul_mm_impl(device const  uchar * src0,
-                        device const  uchar * src1,
-                        device        float * dst,
-                        constant    int64_t & ne00,
-                        constant    int64_t & ne02,
-                        constant   uint64_t & nb01,
-                        constant   uint64_t & nb02,
-                        constant    int64_t & ne12,
-                        constant   uint64_t & nb10,
-                        constant   uint64_t & nb11,
-                        constant   uint64_t & nb12,
-                        constant    int64_t & ne0,
-                        constant    int64_t & ne1,
-                        constant       uint & r2,
-                        constant       uint & r3,
-                        threadgroup   uchar * shared_memory [[threadgroup(0)]],
-                        uint3                 tgpig[[threadgroup_position_in_grid]],
-                        uint                  tiitg[[thread_index_in_threadgroup]],
-                        uint                  sgitg[[simdgroup_index_in_threadgroup]]) {
+template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
+kernel void kernel_mul_mm(device const  uchar * src0,
+                          device const  uchar * src1,
+                          device        float * dst,
+                          constant    int64_t & ne00,
+                          constant    int64_t & ne02,
+                          constant   uint64_t & nb01,
+                          constant   uint64_t & nb02,
+                          constant    int64_t & ne12,
+                          constant   uint64_t & nb10,
+                          constant   uint64_t & nb11,
+                          constant   uint64_t & nb12,
+                          constant    int64_t & ne0,
+                          constant    int64_t & ne1,
+                          constant       uint & r2,
+                          constant       uint & r3,
+                          threadgroup   uchar * shared_memory [[threadgroup(0)]],
+                          uint3                 tgpig[[threadgroup_position_in_grid]],
+                          uint                  tiitg[[thread_index_in_threadgroup]],
+                          uint                  sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    threadgroup half  * sa = (threadgroup half  *)(shared_memory);
+    threadgroup T     * sa = (threadgroup T     *)(shared_memory);
     threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
 
     const uint r0 = tgpig.y;
@@ -5896,7 +5576,7 @@ void kernel_mul_mm_impl(device const  uchar * src0,
     short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
     short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
 
-    simdgroup_half8x8  ma[4];
+    simdgroup_T8x8     ma[4];
     simdgroup_float8x8 mb[2];
     simdgroup_float8x8 c_res[8];
     for (int i = 0; i < 8; i++){
@@ -5919,7 +5599,7 @@ void kernel_mul_mm_impl(device const  uchar * src0,
 
     for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
         // load data and store to threadgroup memory
-        half4x4 temp_a;
+        T4x4 temp_a;
         dequantize_func(x, il, temp_a);
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
@@ -5939,7 +5619,7 @@ void kernel_mul_mm_impl(device const  uchar * src0,
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
         // load matrices from threadgroup memory and conduct outer products
-        threadgroup half  * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
+        threadgroup T     * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
         threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
 
         #pragma unroll(4)
@@ -6115,48 +5795,6 @@ void kernel_mul_mm_id_impl(
     }
 }
 
-template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
-kernel void kernel_mul_mm(device const  uchar * src0,
-                          device const  uchar * src1,
-                          device        float * dst,
-                          constant    int64_t & ne00,
-                          constant    int64_t & ne02,
-                          constant   uint64_t & nb01,
-                          constant   uint64_t & nb02,
-                          constant    int64_t & ne12,
-                          constant   uint64_t & nb10,
-                          constant   uint64_t & nb11,
-                          constant   uint64_t & nb12,
-                          constant    int64_t & ne0,
-                          constant    int64_t & ne1,
-                          constant       uint & r2,
-                          constant       uint & r3,
-                          threadgroup   uchar * shared_memory [[threadgroup(0)]],
-                          uint3                 tgpig[[threadgroup_position_in_grid]],
-                          uint                  tiitg[[thread_index_in_threadgroup]],
-                          uint                  sgitg[[simdgroup_index_in_threadgroup]]) {
-    kernel_mul_mm_impl<block_q, nl, dequantize_func>(
-        src0,
-        src1,
-        dst,
-        ne00,
-        ne02,
-        nb01,
-        nb02,
-        ne12,
-        nb10,
-        nb11,
-        nb12,
-        ne0,
-        ne1,
-        r2,
-        r3,
-        shared_memory,
-        tgpig,
-        tiitg,
-        sgitg);
-}
-
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
 kernel void kernel_mul_mm_id(
         device const   uchar * src0s,
@@ -6237,69 +5875,60 @@ kernel void kernel_mul_mm_id(
 // get rows
 //
 
-typedef void (get_rows_t)(
-        device const void * src0,
-        device const char * src1,
-        device      float * dst,
-        constant  int64_t & ne00,
-        constant uint64_t & nb01,
-        constant uint64_t & nb02,
-        constant  int64_t & ne10,
-        constant uint64_t & nb10,
-        constant uint64_t & nb11,
-        constant uint64_t & nb1,
-        constant uint64_t & nb2,
-        uint3, uint, uint3);
-
-//template [[host_name("kernel_get_rows_f32")]]  kernel get_rows_t kernel_get_rows<float4x4,   1, dequantize_f32>;
-//template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4,    1, dequantize_f16>;
-template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
-template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
-template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
-template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>;
-template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
-template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
-template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
-template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
-template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
-template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
-template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
-template [[host_name("kernel_get_rows_iq2_xs")]]  kernel get_rows_t kernel_get_rows<block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
-template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
-template [[host_name("kernel_get_rows_iq3_s")]]   kernel get_rows_t kernel_get_rows<block_iq3_s,   QK_NL, dequantize_iq3_s>;
-template [[host_name("kernel_get_rows_iq2_s")]]   kernel get_rows_t kernel_get_rows<block_iq2_s,   QK_NL, dequantize_iq2_s>;
-template [[host_name("kernel_get_rows_iq1_s")]]   kernel get_rows_t kernel_get_rows<block_iq1_s,   QK_NL, dequantize_iq1_s>;
-template [[host_name("kernel_get_rows_iq1_m")]]   kernel get_rows_t kernel_get_rows<block_iq1_m,   QK_NL, dequantize_iq1_m>;
-template [[host_name("kernel_get_rows_iq4_nl")]]  kernel get_rows_t kernel_get_rows<block_iq4_nl,  2,     dequantize_iq4_nl>;
-template [[host_name("kernel_get_rows_iq4_xs")]]  kernel get_rows_t kernel_get_rows<block_iq4_xs,  QK_NL, dequantize_iq4_xs>;
+typedef decltype(kernel_get_rows_f<float>) get_rows_f_t;
+
+template [[host_name("kernel_get_rows_f32")]]  kernel get_rows_f_t kernel_get_rows_f<float>;
+template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_f_t kernel_get_rows_f<half>;
+
+typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
+
+template [[host_name("kernel_get_rows_q4_0")]]    kernel get_rows_q_t kernel_get_rows_q<block_q4_0,    2, dequantize_q4_0>;
+template [[host_name("kernel_get_rows_q4_1")]]    kernel get_rows_q_t kernel_get_rows_q<block_q4_1,    2, dequantize_q4_1>;
+template [[host_name("kernel_get_rows_q5_0")]]    kernel get_rows_q_t kernel_get_rows_q<block_q5_0,    2, dequantize_q5_0>;
+template [[host_name("kernel_get_rows_q5_1")]]    kernel get_rows_q_t kernel_get_rows_q<block_q5_1,    2, dequantize_q5_1>;
+template [[host_name("kernel_get_rows_q8_0")]]    kernel get_rows_q_t kernel_get_rows_q<block_q8_0,    2, dequantize_q8_0>;
+template [[host_name("kernel_get_rows_q2_K")]]    kernel get_rows_q_t kernel_get_rows_q<block_q2_K,    QK_NL, dequantize_q2_K>;
+template [[host_name("kernel_get_rows_q3_K")]]    kernel get_rows_q_t kernel_get_rows_q<block_q3_K,    QK_NL, dequantize_q3_K>;
+template [[host_name("kernel_get_rows_q4_K")]]    kernel get_rows_q_t kernel_get_rows_q<block_q4_K,    QK_NL, dequantize_q4_K>;
+template [[host_name("kernel_get_rows_q5_K")]]    kernel get_rows_q_t kernel_get_rows_q<block_q5_K,    QK_NL, dequantize_q5_K>;
+template [[host_name("kernel_get_rows_q6_K")]]    kernel get_rows_q_t kernel_get_rows_q<block_q6_K,    QK_NL, dequantize_q6_K>;
+template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
+template [[host_name("kernel_get_rows_iq2_xs")]]  kernel get_rows_q_t kernel_get_rows_q<block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
+template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
+template [[host_name("kernel_get_rows_iq3_s")]]   kernel get_rows_q_t kernel_get_rows_q<block_iq3_s,   QK_NL, dequantize_iq3_s>;
+template [[host_name("kernel_get_rows_iq2_s")]]   kernel get_rows_q_t kernel_get_rows_q<block_iq2_s,   QK_NL, dequantize_iq2_s>;
+template [[host_name("kernel_get_rows_iq1_s")]]   kernel get_rows_q_t kernel_get_rows_q<block_iq1_s,   QK_NL, dequantize_iq1_s>;
+template [[host_name("kernel_get_rows_iq1_m")]]   kernel get_rows_q_t kernel_get_rows_q<block_iq1_m,   QK_NL, dequantize_iq1_m>;
+template [[host_name("kernel_get_rows_iq4_nl")]]  kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl,  2,     dequantize_iq4_nl>;
+template [[host_name("kernel_get_rows_iq4_xs")]]  kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs,  QK_NL, dequantize_iq4_xs>;
 
 //
 // matrix-matrix multiplication
 //
 
-typedef decltype(kernel_mul_mm<float4x4, 1, dequantize_f32>) mat_mm_t;
-
-template [[host_name("kernel_mul_mm_f32_f32")]]     kernel mat_mm_t kernel_mul_mm<float4x4,      1,     dequantize_f32>;
-template [[host_name("kernel_mul_mm_f16_f32")]]     kernel mat_mm_t kernel_mul_mm<half4x4,       1,     dequantize_f16>;
-template [[host_name("kernel_mul_mm_q4_0_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q4_0,    2,     dequantize_q4_0>;
-template [[host_name("kernel_mul_mm_q4_1_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q4_1,    2,     dequantize_q4_1>;
-template [[host_name("kernel_mul_mm_q5_0_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q5_0,    2,     dequantize_q5_0>;
-template [[host_name("kernel_mul_mm_q5_1_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q5_1,    2,     dequantize_q5_1>;
-template [[host_name("kernel_mul_mm_q8_0_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q8_0,    2,     dequantize_q8_0>;
-template [[host_name("kernel_mul_mm_q2_K_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q2_K,    QK_NL, dequantize_q2_K>;
-template [[host_name("kernel_mul_mm_q3_K_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q3_K,    QK_NL, dequantize_q3_K>;
-template [[host_name("kernel_mul_mm_q4_K_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q4_K,    QK_NL, dequantize_q4_K>;
-template [[host_name("kernel_mul_mm_q5_K_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q5_K,    QK_NL, dequantize_q5_K>;
-template [[host_name("kernel_mul_mm_q6_K_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q6_K,    QK_NL, dequantize_q6_K>;
-template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
-template [[host_name("kernel_mul_mm_iq2_xs_f32")]]  kernel mat_mm_t kernel_mul_mm<block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
-template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
-template [[host_name("kernel_mul_mm_iq3_s_f32")]]   kernel mat_mm_t kernel_mul_mm<block_iq3_s,   QK_NL, dequantize_iq3_s>;
-template [[host_name("kernel_mul_mm_iq2_s_f32")]]   kernel mat_mm_t kernel_mul_mm<block_iq2_s,   QK_NL, dequantize_iq2_s>;
-template [[host_name("kernel_mul_mm_iq1_s_f32")]]   kernel mat_mm_t kernel_mul_mm<block_iq1_s,   QK_NL, dequantize_iq1_s>;
-template [[host_name("kernel_mul_mm_iq1_m_f32")]]   kernel mat_mm_t kernel_mul_mm<block_iq1_m,   QK_NL, dequantize_iq1_m>;
-template [[host_name("kernel_mul_mm_iq4_nl_f32")]]  kernel mat_mm_t kernel_mul_mm<block_iq4_nl,  2,     dequantize_iq4_nl>;
-template [[host_name("kernel_mul_mm_iq4_xs_f32")]]  kernel mat_mm_t kernel_mul_mm<block_iq4_xs,  QK_NL, dequantize_iq4_xs>;
+typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mat_mm_t;
+
+template [[host_name("kernel_mul_mm_f32_f32")]]     kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   float4x4,      1,     dequantize_f32>;
+template [[host_name("kernel_mul_mm_f16_f32")]]     kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half4x4,       1,     dequantize_f16>;
+template [[host_name("kernel_mul_mm_q4_0_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q4_0,    2,     dequantize_q4_0>;
+template [[host_name("kernel_mul_mm_q4_1_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q4_1,    2,     dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_q5_0_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q5_0,    2,     dequantize_q5_0>;
+template [[host_name("kernel_mul_mm_q5_1_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q5_1,    2,     dequantize_q5_1>;
+template [[host_name("kernel_mul_mm_q8_0_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q8_0,    2,     dequantize_q8_0>;
+template [[host_name("kernel_mul_mm_q2_K_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q2_K,    QK_NL, dequantize_q2_K>;
+template [[host_name("kernel_mul_mm_q3_K_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q3_K,    QK_NL, dequantize_q3_K>;
+template [[host_name("kernel_mul_mm_q4_K_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q4_K,    QK_NL, dequantize_q4_K>;
+template [[host_name("kernel_mul_mm_q5_K_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q5_K,    QK_NL, dequantize_q5_K>;
+template [[host_name("kernel_mul_mm_q6_K_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q6_K,    QK_NL, dequantize_q6_K>;
+template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
+template [[host_name("kernel_mul_mm_iq2_xs_f32")]]  kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
+template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
+template [[host_name("kernel_mul_mm_iq3_s_f32")]]   kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq3_s,   QK_NL, dequantize_iq3_s>;
+template [[host_name("kernel_mul_mm_iq2_s_f32")]]   kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq2_s,   QK_NL, dequantize_iq2_s>;
+template [[host_name("kernel_mul_mm_iq1_s_f32")]]   kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq1_s,   QK_NL, dequantize_iq1_s>;
+template [[host_name("kernel_mul_mm_iq1_m_f32")]]   kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq1_m,   QK_NL, dequantize_iq1_m>;
+template [[host_name("kernel_mul_mm_iq4_nl_f32")]]  kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq4_nl,  2,     dequantize_iq4_nl>;
+template [[host_name("kernel_mul_mm_iq4_xs_f32")]]  kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq4_xs,  QK_NL, dequantize_iq4_xs>;
 
 //
 // indirect matrix-matrix multiplication
@@ -6436,7 +6065,7 @@ void mmv_fn(
     impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
 }
 
-typedef decltype(mmv_fn<kernel_mul_mv_f32_f32_impl>) mul_mv_impl_fn_t;
+typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4>>) mul_mv_impl_fn_t;
 
 template<mul_mv_impl_fn_t impl_fn>
 kernel void kernel_mul_mv_id(
@@ -6514,20 +6143,20 @@ kernel void kernel_mul_mv_id(
         sgitg);
 }
 
-typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>) kernel_mul_mv_id_t;
-
-template [[host_name("kernel_mul_mv_id_f32_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_f16_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f16_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl>>;
+typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>) kernel_mul_mv_id_t;
+
+template [[host_name("kernel_mul_mv_id_f32_f32")]]     kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>;
+template [[host_name("kernel_mul_mv_id_f16_f32")]]     kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<half, half4, float, float4>>>;
+template [[host_name("kernel_mul_mv_id_q8_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q4_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q4_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q5_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q5_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q2_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q3_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q4_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q5_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q6_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl>>;
 template [[host_name("kernel_mul_mv_id_iq1_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl>>;
 template [[host_name("kernel_mul_mv_id_iq1_m_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl>>;
 template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl>>;