]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : improve F32, F16 and BF16 mat-vec multiplication (#16057)
authorGeorgi Gerganov <redacted>
Thu, 18 Sep 2025 09:33:45 +0000 (12:33 +0300)
committerGitHub <redacted>
Thu, 18 Sep 2025 09:33:45 +0000 (12:33 +0300)
* metal : improve F32, F16 and BF16 mat-vec multiplication

ggml-ci

* metal : make the NSG a function constant in mul_mv kernels

ggml-ci

ggml/src/ggml-metal/ggml-metal-device.cpp
ggml/src/ggml-metal/ggml-metal-device.h
ggml/src/ggml-metal/ggml-metal-device.m
ggml/src/ggml-metal/ggml-metal-impl.h
ggml/src/ggml-metal/ggml-metal-ops.cpp
ggml/src/ggml-metal/ggml-metal.metal

index 5f0478996718ac7ba2ccb2ab843b429e2f0eb336..bada84cef31874928d9db81c978fb4d774bf085b 100644 (file)
@@ -34,6 +34,10 @@ ggml_metal_pipelines_t ggml_metal_pipelines_init(void) {
 }
 
 void ggml_metal_pipelines_free(ggml_metal_pipelines_t ppls) {
+    if (!ppls) {
+        return;
+    }
+
     for (auto it = ppls->data.begin(); it != ppls->data.end(); ++it) {
         ggml_metal_pipeline_free(it->second);
     }
@@ -467,37 +471,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
     // use custom matrix x vector kernel
     switch (tsrc0) {
         case GGML_TYPE_F32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_BF16:
             {
-                GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
-
-                nsg = 1;
-                nr0 = 1;
-                nr1 = 4;
                 if (ne00 == 4) {
+                    nsg = 1;
                     nr0 = 32;
+                    nr1 = 4;
                     suffix = "_c4";
-                }
-            } break;
-        case GGML_TYPE_F16:
-        case GGML_TYPE_BF16:
-            {
-                nsg = 1;
-                nr0 = 1;
-                if (op->src[1]->type == GGML_TYPE_F32) {
-                    if (ne00 == 4) {
-                        nr0 = 32;
-                        nr1 = 4;
-                        suffix = "_c4";
-                    } else if (ne11 * ne12 < 4) {
-                        suffix = "_1row";
-                    } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
-                        suffix = "_l4";
-                        nr1 = ne11;
-                    } else {
-                        nr1 = 4;
-                    }
+                } else if (ne00 % 4 == 0) {
+                    nsg = N_SG_F;
+                    nr0 = N_R0_F;
+                    nr1 = 1;
+                    smem = 32*sizeof(float)*N_R0_F;
+                    suffix = "_4";
                 } else {
-                    nr1 = 4;
+                    nsg = N_SG_F;
+                    nr0 = N_R0_F;
+                    nr1 = 1;
+                    smem = 32*sizeof(float)*N_R0_F;
                 }
             } break;
         case GGML_TYPE_Q4_0:
@@ -623,7 +615,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
         return res;
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+    ggml_metal_cv_t cv = ggml_metal_cv_init();
+
+    ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
+
+    res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+    ggml_metal_cv_free(cv);
 
     ggml_metal_pipeline_set_nr0 (res, nr0);
     ggml_metal_pipeline_set_nr1 (res, nr1);
@@ -689,25 +687,26 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
     const ggml_type tsrc0 = op->src[0]->type;
     const ggml_type tsrc1 = op->src[1]->type;
 
+    const char * suffix = "";
+
         // use custom matrix x vector kernel
     switch (tsrc0) {
         case GGML_TYPE_F32:
-            {
-                GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
-                nsg = 1;
-                nr0 = 1;
-            } break;
         case GGML_TYPE_F16:
-            {
-                GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
-                nsg = 1;
-                nr0 = 1;
-            } break;
         case GGML_TYPE_BF16:
             {
-                GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
-                nsg = 1;
-                nr0 = 1;
+                if (ne00 % 4 == 0) {
+                    nsg = N_SG_F;
+                    nr0 = N_R0_F;
+                    nr1 = 1;
+                    smem = 32*sizeof(float)*N_R0_F;
+                    suffix = "_4";
+                } else {
+                    nsg = N_SG_F;
+                    nr0 = N_R0_F;
+                    nr1 = 1;
+                    smem = 32*sizeof(float)*N_R0_F;
+                }
             } break;
         case GGML_TYPE_Q4_0:
             {
@@ -824,7 +823,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
             }
     };
 
-    snprintf(base, 256, "kernel_mul_mv_id_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
+    snprintf(base, 256, "kernel_mul_mv_id_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
     snprintf(name, 256, "%s", base);
 
     ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
@@ -832,7 +831,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
         return res;
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+    ggml_metal_cv_t cv = ggml_metal_cv_init();
+
+    ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
+
+    res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+    ggml_metal_cv_free(cv);
 
     ggml_metal_pipeline_set_nr0 (res, nr0);
     ggml_metal_pipeline_set_nr1 (res, nr1);
index c48337f514b421b2978d8532128285cc179c038a..4a3a819fc02c88f32f56dc57b57ea680df4fb763 100644 (file)
@@ -22,6 +22,7 @@ typedef struct ggml_metal_cv * ggml_metal_cv_t;
 ggml_metal_cv_t ggml_metal_cv_init(void);
 void ggml_metal_cv_free(ggml_metal_cv_t cv);
 
+void ggml_metal_cv_set_int16(ggml_metal_cv_t cv, int16_t value, int32_t idx);
 void ggml_metal_cv_set_int32(ggml_metal_cv_t cv, int32_t value, int32_t idx);
 void ggml_metal_cv_set_bool (ggml_metal_cv_t cv, bool    value, int32_t idx);
 
index 8f83c5cc165075e371b272d8b885781bbcc2fe0e..67f71ace2b44499c591f77ac8a6f7c247fe76693 100644 (file)
@@ -51,6 +51,10 @@ void ggml_metal_cv_free(ggml_metal_cv_t cv) {
     free(cv);
 }
 
+void ggml_metal_cv_set_int16(ggml_metal_cv_t cv, int16_t value, int32_t idx) {
+    [cv->obj setConstantValue:&value type:MTLDataTypeShort atIndex:idx];
+}
+
 void ggml_metal_cv_set_int32(ggml_metal_cv_t cv, int32_t value, int32_t idx) {
     [cv->obj setConstantValue:&value type:MTLDataTypeInt atIndex:idx];
 }
index 0776bb6485cc9741c0d0296f363103a7dd799601..b2572965405b86d2ecbc095d64a659c65ae6fa90 100644 (file)
@@ -8,6 +8,9 @@
 //
 // TODO: for optimal performance, become function of the device and work size
 
+#define N_R0_F 2
+#define N_SG_F 4
+
 #define N_R0_Q4_0 4
 #define N_SG_Q4_0 2
 
@@ -72,6 +75,7 @@
 #define FC_FLASH_ATTN_EXT              100
 #define FC_FLASH_ATTN_EXT_VEC          200
 #define FC_FLASH_ATTN_EXT_VEC_REDUCE   300
+#define FC_MUL_MV                      400
 
 // kernel argument structs
 //
index 839c16894dea18ddb10f92cf3965032afadfb45a..a28206e78683ce20e7335d554adf40a9f6c1822b 100644 (file)
@@ -1564,7 +1564,10 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
 
         ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
 
-        if (op->src[0]->type == GGML_TYPE_Q8_0) {
+        if (op->src[0]->type == GGML_TYPE_F32 ||
+            op->src[0]->type == GGML_TYPE_F16 ||
+            op->src[0]->type == GGML_TYPE_BF16 ||
+            op->src[0]->type == GGML_TYPE_Q8_0) {
             ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + nr0 - 1)/(nr0)), ((ne11 + nr1 - 1)/nr1), ne12*ne13, 32, nsg, 1);
         } else {
             ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + nr0*nsg - 1)/(nr0*nsg)), ((ne11 + nr1 - 1)/nr1), ne12*ne13, 32, nsg, 1);
@@ -1772,7 +1775,10 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
 
         ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
 
-        if (op->src[0]->type == GGML_TYPE_Q8_0) {
+        if (op->src[0]->type == GGML_TYPE_F32 ||
+            op->src[0]->type == GGML_TYPE_F16 ||
+            op->src[0]->type == GGML_TYPE_BF16 ||
+            op->src[0]->type == GGML_TYPE_Q8_0) {
             ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nr0 - 1)/(nr0), (_ne1 + nr1 - 1)/nr1, ne123, 32, nsg, 1);
         } else {
             ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123, 32, nsg, 1);
index f34b89e590b791c46738ccf0e5f44b2b45381fca..20ceb1fe04913f390deabfc486feb112f8133272 100644 (file)
@@ -2883,7 +2883,9 @@ static inline void helper_mv_reduce_and_write(
     }
 }
 
-template<typename block_q_type, short NR0, short NSG, short NW, typename args_t>
+constant short FC_mul_mv_nsg  [[function_constant(FC_MUL_MV + 0)]];
+
+template<typename block_q_type, short NR0, short NW, typename args_t>
 void mul_vec_q_n_f32_impl(
         args_t args,
         device const char * src0,
@@ -2893,6 +2895,8 @@ void mul_vec_q_n_f32_impl(
         uint3  tgpig,
         ushort tiisg,
         ushort sgitg) {
+    const short NSG = FC_mul_mv_nsg;
+
     constexpr short NQ = 16;
 
     const int nb = args.ne00/QK4_0;
@@ -2977,7 +2981,7 @@ kernel void kernel_mul_mv_q4_0_f32(
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SG_Q4_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
 kernel void kernel_mul_mv_q4_1_f32(
@@ -2989,7 +2993,7 @@ kernel void kernel_mul_mv_q4_1_f32(
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-     mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SG_Q4_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+     mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
 kernel void kernel_mul_mv_q5_0_f32(
@@ -3001,7 +3005,7 @@ kernel void kernel_mul_mv_q5_0_f32(
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SG_Q5_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
 kernel void kernel_mul_mv_q5_1_f32(
@@ -3013,10 +3017,10 @@ kernel void kernel_mul_mv_q5_1_f32(
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SG_Q5_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-template<short NR0, short NSG, short NW, typename args_t>
+template<short NR0, short NW, typename args_t>
 void kernel_mul_mv_q8_0_f32_impl(
         args_t args,
         device const char * src0,
@@ -3026,6 +3030,8 @@ void kernel_mul_mv_q8_0_f32_impl(
         uint3  tgpig,
         ushort tiisg,
         ushort sgitg) {
+    const short NSG = FC_mul_mv_nsg;
+
     constexpr short NQ = 8;
 
     const int nb = args.ne00/QK8_0;
@@ -3097,7 +3103,7 @@ kernel void kernel_mul_mv_q8_0_f32(
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SG_Q8_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
 // mat-vec kernel processing in chunks of float4
@@ -3404,104 +3410,215 @@ template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]] kernel mul_mv_ext_q4x4
 template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>;
 template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>;
 
-#define N_MV_T_T 4
-
-template<typename T0, typename T04, typename T1, typename T14, typename args_t>
-void kernel_mul_mv_impl(
+template<typename T0, typename T1, short NR0, short NW, typename args_t>
+void kernel_mul_mv_t_t_impl(
         args_t args,
         device const char * src0,
         device const char * src1,
         device       char * dst,
+        threadgroup  char * shmem,
         uint3  tgpig,
-        ushort tiisg) {
-    const int r0 = tgpig.x;
-    const int rb = tgpig.y*N_MV_T_T;
+        ushort tiisg,
+        ushort sgitg) {
+    const short NSG = FC_mul_mv_nsg;
+
+    constexpr short NB = 32;
+    constexpr short NF = 8;
+
+    const int nb = args.ne00/NB;
+
+    const int r0 = tgpig.x*NR0;
+    const int r1 = tgpig.y;
     const int im = tgpig.z;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
 
-    const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+  //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 = r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
 
-    device const T0 * x = (device const T0 *) (src0 + offset0);
+  //device const T0 * x = (device const T0 *) (src0 + offset0);
+    device const T1 * y = (device const T1 *) (src1 + offset1);
 
-    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
+    // pointers to src0 rows
+    device const T0 * ax [NR0];
+    FOR_UNROLL (short row = 0; row < NR0; ++row) {
+        const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
 
-    if (args.ne00 < 128) {
-        for (int row = 0; row < N_MV_T_T; ++row) {
-            int r1 = rb + row;
-            if (r1 >= args.ne11) {
-                break;
-            }
+        ax[row] = (device const T0 *) ((device char *) src0 + offset0);
+    }
 
-            const uint64_t offset1 = r1*args.nb11 + (i12   )*args.nb12 + (i13   )*args.nb13;
+    float sumf[NR0] = { 0.f };
 
-            device const T1 * y = (device const T1 *) (src1 + offset1);
+    const short ix = tiisg/(NW/NF);
+    const short il = tiisg%(NW/NF);
 
-            float sumf = 0;
-            for (int i = tiisg; i < args.ne00; i += 32) {
-                sumf += (T0) x[i] * (T1) y[i];
-            }
+    const int ib0 = sgitg*NF + ix;
 
-            float sum_all = simd_sum(sumf);
-            if (tiisg == 0) {
-                dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
-            }
+    T1 yl[NF];
+
+    device const T1 * yb = y + (ib0*NB + il*NF);
+
+    for (int ib = ib0; ib < nb; ib += NSG*NF) {
+        for (short i = 0; i < NF; ++i) {
+            yl[i] = yb[i];
         }
-    } else {
-        device const T04 * x4 = (device const T04 *) x;
-        for (int row = 0; row < N_MV_T_T; ++row) {
-            int r1 = rb + row;
-            if (r1 >= args.ne11) {
-                break;
+
+        for (short row = 0; row < NR0; row++) {
+            device const T0 * xb = ax[row] + (ib*NB + il*NF);
+
+            float sumq = 0.f;
+            FOR_UNROLL (short i = 0; i < NF; ++i) {
+                sumq += xb[i] * yl[i];
             }
 
-            const uint64_t offset1 = r1*args.nb11 + (i12   )*args.nb12 + (i13   )*args.nb13;
+            sumf[row] += sumq;
+        }
+
+        yb += NSG*NF*NW;
+    }
 
-            device const T1  * y  = (device const T1  *) (src1 + offset1);
-            device const T14 * y4 = (device const T14 *) y;
+    for (int i = nb*NB + sgitg*NW + tiisg; i < args.ne00; i += NW*NSG) {
+        for (short row = 0; row < NR0; row++) {
+            sumf[row] += ax[row][i] * y[i];
+        }
+    }
 
-            float sumf = 0;
-            for (int i = tiisg; i < args.ne00/4; i += 32) {
-                sumf += dot((float4) x4[i], (float4) y4[i]);
-            }
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-            float sum_all = simd_sum(sumf);
-            if (tiisg == 0) {
-                for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]);
-                dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
+    helper_mv_reduce_and_write<NR0, NW>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
+}
+
+template<typename T0, typename T1, short NR0, short NW>
+kernel void kernel_mul_mv_t_t(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    kernel_mul_mv_t_t_impl<T0, T1, NR0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
+typedef decltype(kernel_mul_mv_t_t<half, half, N_R0_F, N_SIMDWIDTH>) mul_mv_t_t;
+
+template [[host_name("kernel_mul_mv_f32_f32")]]   kernel mul_mv_t_t kernel_mul_mv_t_t<float, float, N_R0_F, N_SIMDWIDTH>;
+template [[host_name("kernel_mul_mv_f16_f32")]]   kernel mul_mv_t_t kernel_mul_mv_t_t<half,  float, N_R0_F, N_SIMDWIDTH>;
+template [[host_name("kernel_mul_mv_f16_f16")]]   kernel mul_mv_t_t kernel_mul_mv_t_t<half,  half,  N_R0_F, N_SIMDWIDTH>;
+#if defined(GGML_METAL_HAS_BF16)
+template [[host_name("kernel_mul_mv_bf16_f32")]]  kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, float,  N_R0_F, N_SIMDWIDTH>;
+template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, bfloat, N_R0_F, N_SIMDWIDTH>;
+#endif
+
+template<typename T0, typename T04, typename T1, typename T14, short NR0, short NW, typename args_t>
+void kernel_mul_mv_t_t_4_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+    const short NSG = FC_mul_mv_nsg;
+
+    constexpr short NB  = 32;
+    constexpr short NF  = 16;
+    constexpr short NF4 = NF/4;
+
+    const int nb = args.ne00/NB;
+
+    const int r0 = tgpig.x*NR0;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+  //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 = r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const T1  * y  = (device const T1  *) (src1 + offset1);
+    device const T14 * y4 = (device const T14 *) (src1 + offset1);
+
+    // pointers to src0 rows
+    device const T0  * ax [NR0];
+    device const T04 * ax4[NR0];
+    FOR_UNROLL (short row = 0; row < NR0; ++row) {
+        const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+
+        ax [row] = (device const T0  *) ((device char *) src0 + offset0);
+        ax4[row] = (device const T04 *) ((device char *) src0 + offset0);
+    }
+
+    float sumf[NR0] = { 0.f };
+
+    const short ix = tiisg/(NW/NF);
+    const short il = tiisg%(NW/NF);
+
+    const int ib0 = sgitg*NF + ix;
+
+    T14 yl4[NF4];
+
+    device const T14 * yb4 = y4 + (ib0*NB + il*NF)/4;
+
+    for (int ib = ib0; ib < nb; ib += NSG*NF) {
+        for (short i = 0; i < NF4; ++i) {
+            yl4[i] = yb4[i];
+        }
+
+        for (short row = 0; row < NR0; row++) {
+            device const T04 * xb4 = ax4[row] + (ib*NB + il*NF)/4;
+
+            float sumq = 0.f;
+            FOR_UNROLL (short i = 0; i < NF4; ++i) {
+                sumq += dot(float4(xb4[i]), float4(yl4[i]));
             }
+
+            sumf[row] += sumq;
         }
+
+        yb4 += NSG*NF*NW/4;
     }
+
+    for (int i = nb*NB + sgitg*NW + tiisg; i < args.ne00; i += NW*NSG) {
+        for (short row = 0; row < NR0; row++) {
+            sumf[row] += ax[row][i] * y[i];
+        }
+    }
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+    helper_mv_reduce_and_write<NR0, NW>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
 }
 
-template<typename T0, typename T04, typename T1, typename T14>
-kernel void kernel_mul_mv(
+template<typename T0, typename T04, typename T1, typename T14, short NR0, short NW>
+kernel void kernel_mul_mv_t_t_4(
         constant ggml_metal_kargs_mul_mv & args,
         device const char * src0,
         device const char * src1,
         device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
         uint3  tgpig[[threadgroup_position_in_grid]],
-        ushort tiisg[[thread_index_in_simdgroup]]) {
-    kernel_mul_mv_impl<T0, T04, T1, T14, constant ggml_metal_kargs_mul_mv &>(
-        args,
-        src0,
-        src1,
-        dst,
-        tgpig,
-        tiisg);
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, NR0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-typedef decltype(kernel_mul_mv<half, half4, half, half4>) mul_mv_t;
+typedef decltype(kernel_mul_mv_t_t_4<half, half4, half, half4, N_R0_F, N_SIMDWIDTH>) mul_mv_t_t_4;
 
-template [[host_name("kernel_mul_mv_f32_f32")]]   kernel mul_mv_t kernel_mul_mv<float,  float4,  float,  float4>;
-template [[host_name("kernel_mul_mv_f16_f32")]]   kernel mul_mv_t kernel_mul_mv<half,   half4,   float,  float4>;
-template [[host_name("kernel_mul_mv_f16_f16")]]   kernel mul_mv_t kernel_mul_mv<half,   half4,   half,   half4>;
+template [[host_name("kernel_mul_mv_f32_f32_4")]]   kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<float, float4, float, float4, N_R0_F, N_SIMDWIDTH>;
+template [[host_name("kernel_mul_mv_f16_f32_4")]]   kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half,  half4,  float, float4, N_R0_F, N_SIMDWIDTH>;
+template [[host_name("kernel_mul_mv_f16_f16_4")]]   kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half,  half4,  half,  half4,  N_R0_F, N_SIMDWIDTH>;
 #if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_mul_mv_bf16_f32")]]  kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, float,  float4>;
-template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, bfloat, bfloat4>;
+template [[host_name("kernel_mul_mv_bf16_f32_4")]]  kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, float,  float4,  N_R0_F, N_SIMDWIDTH>;
+template [[host_name("kernel_mul_mv_bf16_bf16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, bfloat, bfloat4, N_R0_F, N_SIMDWIDTH>;
 #endif
 
+#define N_MV_T_T 4
+
 template<typename T04, typename T14, typename args_t>
 void kernel_mul_mv_c4_impl(
         args_t args,
@@ -3562,112 +3679,10 @@ typedef decltype(kernel_mul_mv_c4<half4, half4>) mul_mv_c4_t;
 
 template [[host_name("kernel_mul_mv_f32_f32_c4")]]  kernel mul_mv_c4_t kernel_mul_mv_c4<float4,  float4>;
 template [[host_name("kernel_mul_mv_f16_f32_c4")]]  kernel mul_mv_c4_t kernel_mul_mv_c4<half4,   float4>;
+template [[host_name("kernel_mul_mv_f16_f16_c4")]]  kernel mul_mv_c4_t kernel_mul_mv_c4<half4,   half4>;
 #if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_mul_mv_bf16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<bfloat4, float4>;
-#endif
-
-template<typename T, typename T4>
-kernel void kernel_mul_mv_1row(
-        constant ggml_metal_kargs_mul_mv & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        uint3  tgpig[[threadgroup_position_in_grid]],
-        ushort tiisg[[thread_index_in_simdgroup]]) {
-
-    const int r0 = tgpig.x;
-    const int r1 = tgpig.y;
-    const int im = tgpig.z;
-
-    const uint i12 = im%args.ne12;
-    const uint i13 = im/args.ne12;
-
-    const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
-    const uint64_t offset1 = r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
-
-    device const T     * x = (device const T     *) (src0 + offset0);
-    device const float * y = (device const float *) (src1 + offset1);
-
-    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
-
-    float sumf = 0;
-    if (args.ne00 < 128) {
-        for (int i = tiisg; i < args.ne00; i += 32) {
-            sumf += (float) x[i] * (float) y[i];
-        }
-        float sum_all = simd_sum(sumf);
-        if (tiisg == 0) {
-            dst_f32[r0] = sum_all;
-        }
-    } else {
-        device const T4     * x4 = (device const T4     *) x;
-        device const float4 * y4 = (device const float4 *) y;
-
-        for (int i = tiisg; i < args.ne00/4; i += 32) {
-            sumf += dot((float4) x4[i], y4[i]);
-        }
-
-        float sum_all = simd_sum(sumf);
-
-        if (tiisg == 0) {
-            for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]);
-            dst_f32[r0] = sum_all;
-        }
-    }
-}
-
-typedef decltype(kernel_mul_mv_1row<half, half4>) mul_mv_1row_t;
-
-template [[host_name("kernel_mul_mv_f16_f32_1row")]]  kernel mul_mv_1row_t kernel_mul_mv_1row<half,   half4>;
-#if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<bfloat, bfloat4>;
-#endif
-
-// Assumes row size (ne00) is a multiple of 4
-template<typename T, typename T4>
-kernel void kernel_mul_mv_l4(
-        constant ggml_metal_kargs_mul_mv & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        uint3  tgpig[[threadgroup_position_in_grid]],
-        ushort tiisg[[thread_index_in_simdgroup]]) {
-
-    const int nrows = args.ne11;
-    const int r0 = tgpig.x;
-    const int im = tgpig.z;
-
-    const uint i12 = im%args.ne12;
-    const uint i13 = im/args.ne12;
-
-    const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
-
-    device const T4 * x4 = (device const T4 *) (src0 + offset0);
-
-    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
-
-    for (int r1 = 0; r1 < nrows; ++r1) {
-        const uint64_t offset1 = r1*args.nb11 + (i12   )*args.nb12 + (i13   )*args.nb13;
-
-        device const float4 * y4 = (device const float4 *) (src1 + offset1);
-
-        float sumf = 0;
-        for (int i = tiisg; i < args.ne00/4; i += 32) {
-            sumf += dot((float4) x4[i], y4[i]);
-        }
-
-        float sum_all = simd_sum(sumf);
-        if (tiisg == 0) {
-            dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
-        }
-    }
-}
-
-typedef decltype(kernel_mul_mv_l4<half, half4>) mul_mv_l4_t;
-
-template [[host_name("kernel_mul_mv_f16_f32_l4")]]  kernel mul_mv_l4_t kernel_mul_mv_l4<half, half4>;
-#if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_mul_mv_bf16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<bfloat, bfloat4>;
+template [[host_name("kernel_mul_mv_bf16_f32_c4")]]  kernel mul_mv_c4_t kernel_mul_mv_c4<bfloat4, float4>;
+template [[host_name("kernel_mul_mv_bf16_bf16_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<bfloat4, bfloat4>;
 #endif
 
 static float rope_yarn_ramp(const float low, const float high, const int i0) {
@@ -5951,7 +5966,7 @@ kernel void kernel_concat(
     }
 }
 
-template<int nr0, int nsg, int nw, typename args_t>
+template<int nr0, int nw, typename args_t>
 void kernel_mul_mv_q2_K_f32_impl(
         args_t args,
         device const char * src0,
@@ -5961,13 +5976,15 @@ void kernel_mul_mv_q2_K_f32_impl(
         uint3  tgpig,
         ushort tiisg,
         ushort sgitg) {
+    const short NSG = FC_mul_mv_nsg;
 
     const int nb = args.ne00/QK_K;
+
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * nsg + sgitg) * nr0;
+    const int first_row = (r0 * NSG + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -6051,10 +6068,10 @@ kernel void kernel_mul_mv_q2_K_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q2_K_f32_impl<N_R0_Q2_K, N_SG_Q2_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q2_K_f32_impl<N_R0_Q2_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
-template<int nr0, int nsg, int nw, typename args_t>
+template<int nr0, int nw, typename args_t>
 void kernel_mul_mv_q3_K_f32_impl(
         args_t args,
         device const char * src0,
@@ -6064,6 +6081,7 @@ void kernel_mul_mv_q3_K_f32_impl(
         uint3  tgpig,
         ushort tiisg,
         ushort sgitg) {
+    const short NSG = FC_mul_mv_nsg;
 
     const int nb = args.ne00/QK_K;
 
@@ -6071,7 +6089,7 @@ void kernel_mul_mv_q3_K_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * nsg + sgitg) * nr0;
+    const int first_row = (r0 * NSG + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -6215,10 +6233,10 @@ kernel void kernel_mul_mv_q3_K_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q3_K_f32_impl<N_R0_Q3_K, N_SG_Q3_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q3_K_f32_impl<N_R0_Q3_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
-template<int nr0, int nsg, int nw, typename args_t>
+template<int nr0, int nw, typename args_t>
 void kernel_mul_mv_q4_K_f32_impl(
         args_t args,
         device const char * src0,
@@ -6228,6 +6246,8 @@ void kernel_mul_mv_q4_K_f32_impl(
         uint3  tgpig,
         ushort tiisg,
         ushort sgitg) {
+    const short NSG = FC_mul_mv_nsg;
+
     const uint16_t kmask1 = 0x3f3f;
     const uint16_t kmask2 = 0x0f0f;
     const uint16_t kmask3 = 0xc0c0;
@@ -6243,7 +6263,7 @@ void kernel_mul_mv_q4_K_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * nsg + sgitg) * nr0;
+    const int first_row = (r0 * NSG + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -6337,10 +6357,10 @@ kernel void kernel_mul_mv_q4_K_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q4_K_f32_impl<N_R0_Q4_K, N_SG_Q4_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q4_K_f32_impl<N_R0_Q4_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
-template<int nr0, int nsg, int nw, typename args_t>
+template<int nr0, int nw, typename args_t>
 void kernel_mul_mv_q5_K_f32_impl(
         args_t args,
         device const char * src0,
@@ -6350,6 +6370,7 @@ void kernel_mul_mv_q5_K_f32_impl(
         uint3  tgpig,
         ushort tiisg,
         ushort sgitg) {
+    const short NSG = FC_mul_mv_nsg;
 
     const int nb = args.ne00/QK_K;
 
@@ -6357,7 +6378,7 @@ void kernel_mul_mv_q5_K_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * nsg + sgitg) * nr0;
+    const int first_row = (r0 * NSG + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -6468,10 +6489,10 @@ kernel void kernel_mul_mv_q5_K_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q5_K_f32_impl<N_R0_Q5_K, N_SG_Q5_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q5_K_f32_impl<N_R0_Q5_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
-template<int nr0, int nsg, int nw, typename args_t>
+template<int nr0, int nw, typename args_t>
 void kernel_mul_mv_q6_K_f32_impl(
         args_t args,
         device const char * src0,
@@ -6481,6 +6502,7 @@ void kernel_mul_mv_q6_K_f32_impl(
         uint3  tgpig,
         ushort tiisg,
         ushort sgitg) {
+    const short NSG = FC_mul_mv_nsg;
 
     const uint8_t kmask1 = 0x03;
     const uint8_t kmask2 = 0x0C;
@@ -6493,7 +6515,7 @@ void kernel_mul_mv_q6_K_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * nsg + sgitg) * nr0;
+    const int first_row = (r0 * NSG + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -6577,12 +6599,12 @@ kernel void kernel_mul_mv_q6_K_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q6_K_f32_impl<N_R0_Q6_K, N_SG_Q6_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q6_K_f32_impl<N_R0_Q6_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
 // ======================= "True" 2-bit
 
-template<int nr0, int nsg, int nw, typename args_t>
+template<int nr0, int nw, typename args_t>
 void kernel_mul_mv_iq2_xxs_f32_impl(
         args_t args,
         device const char * src0,
@@ -6592,13 +6614,15 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
         uint3  tgpig,
         ushort tiisg,
         ushort sgitg) {
+    const short NSG = FC_mul_mv_nsg;
 
     const int nb = args.ne00/QK_K;
+
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * nsg + sgitg) * nr0;
+    const int first_row = (r0 * NSG + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -6685,10 +6709,10 @@ kernel void kernel_mul_mv_iq2_xxs_f32(
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, N_SG_IQ2_XXS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-template<int nr0, int nsg, int nw, typename args_t>
+template<int nr0, int nw, typename args_t>
 void kernel_mul_mv_iq2_xs_f32_impl(
         args_t args,
         device const char * src0,
@@ -6698,13 +6722,15 @@ void kernel_mul_mv_iq2_xs_f32_impl(
         uint3  tgpig,
         ushort tiisg,
         ushort sgitg) {
+    const short NSG = FC_mul_mv_nsg;
 
     const int nb = args.ne00/QK_K;
+
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * nsg + sgitg) * nr0;
+    const int first_row = (r0 * NSG + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -6802,10 +6828,10 @@ kernel void kernel_mul_mv_iq2_xs_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq2_xs_f32_impl<N_R0_IQ2_XS, N_SG_IQ2_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq2_xs_f32_impl<N_R0_IQ2_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-template<int nr0, int nsg, int nw, typename args_t>
+template<int nr0, int nw, typename args_t>
 void kernel_mul_mv_iq3_xxs_f32_impl(
         args_t args,
         device const char * src0,
@@ -6815,13 +6841,15 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
         uint3  tgpig,
         ushort tiisg,
         ushort sgitg) {
+    const short NSG = FC_mul_mv_nsg;
 
     const int nb = args.ne00/QK_K;
+
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * nsg + sgitg) * nr0;
+    const int first_row = (r0 * NSG + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -6912,10 +6940,10 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, N_SG_IQ3_XXS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-template<int nr0, int nsg, int nw, typename args_t>
+template<int nr0, int nw, typename args_t>
 void kernel_mul_mv_iq3_s_f32_impl(
         args_t args,
         device const char * src0,
@@ -6925,13 +6953,15 @@ void kernel_mul_mv_iq3_s_f32_impl(
         uint3  tgpig,
         ushort tiisg,
         ushort sgitg) {
+    const short NSG = FC_mul_mv_nsg;
 
     const int nb = args.ne00/QK_K;
+
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * nsg + sgitg) * nr0;
+    const int first_row = (r0 * NSG + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -7022,10 +7052,10 @@ kernel void kernel_mul_mv_iq3_s_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq3_s_f32_impl<N_R0_IQ3_S, N_SG_IQ3_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq3_s_f32_impl<N_R0_IQ3_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-template<int nr0, int nsg, int nw, typename args_t>
+template<int nr0, int nw, typename args_t>
 void kernel_mul_mv_iq2_s_f32_impl(
         args_t args,
         device const char * src0,
@@ -7035,13 +7065,15 @@ void kernel_mul_mv_iq2_s_f32_impl(
         uint3  tgpig,
         ushort tiisg,
         ushort sgitg) {
+    const short NSG = FC_mul_mv_nsg;
 
     const int nb = args.ne00/QK_K;
+
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * nsg + sgitg) * nr0;
+    const int first_row = (r0 * NSG + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -7133,10 +7165,10 @@ kernel void kernel_mul_mv_iq2_s_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq2_s_f32_impl<N_R0_IQ2_S, N_SG_IQ2_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq2_s_f32_impl<N_R0_IQ2_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-template<int nr0, int nsg, int nw, typename args_t>
+template<int nr0, int nw, typename args_t>
 void kernel_mul_mv_iq1_s_f32_impl(
         args_t args,
         device const char * src0,
@@ -7146,13 +7178,15 @@ void kernel_mul_mv_iq1_s_f32_impl(
         uint3  tgpig,
         ushort tiisg,
         ushort sgitg) {
+    const short NSG = FC_mul_mv_nsg;
 
     const int nb = args.ne00/QK_K;
+
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * nsg + sgitg) * nr0;
+    const int first_row = (r0 * NSG + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -7230,10 +7264,10 @@ kernel void kernel_mul_mv_iq1_s_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq1_s_f32_impl<N_R0_IQ1_S, N_SG_IQ1_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq1_s_f32_impl<N_R0_IQ1_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
-template<int nr0, int nsg, int nw, typename args_t>
+template<int nr0, int nw, typename args_t>
 void kernel_mul_mv_iq1_m_f32_impl(
         args_t args,
         device const char * src0,
@@ -7243,6 +7277,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
         uint3  tgpig,
         ushort tiisg,
         ushort sgitg) {
+    const short NSG = FC_mul_mv_nsg;
 
     const int nb = args.ne00/QK_K;
 
@@ -7250,7 +7285,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * nsg + sgitg) * nr0;
+    const int first_row = (r0 * NSG + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -7338,10 +7373,10 @@ kernel void kernel_mul_mv_iq1_m_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq1_m_f32_impl<N_R0_IQ1_M, N_SG_IQ1_M, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq1_m_f32_impl<N_R0_IQ1_M, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
-template<int nr0, int nsg, int nw, typename args_t>
+template<int nr0, int nw, typename args_t>
 void kernel_mul_mv_iq4_nl_f32_impl(
         args_t args,
         device const char * src0,
@@ -7351,6 +7386,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
         uint3  tgpig,
         ushort tiisg,
         ushort sgitg) {
+    const short NSG = FC_mul_mv_nsg;
 
     threadgroup float * shmem_f32 = (threadgroup float *) shmem;
     const int nb = args.ne00/QK4_NL;
@@ -7359,7 +7395,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * nsg + sgitg) * nr0;
+    const int first_row = (r0 * NSG + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -7444,10 +7480,10 @@ kernel void kernel_mul_mv_iq4_nl_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq4_nl_f32_impl<N_R0_IQ4_NL, N_SG_IQ4_NL, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq4_nl_f32_impl<N_R0_IQ4_NL, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-template<int nr0, int nsg, int nw, typename args_t>
+template<int nr0, int nw, typename args_t>
 void kernel_mul_mv_iq4_xs_f32_impl(
         args_t args,
         device const char * src0,
@@ -7457,13 +7493,15 @@ void kernel_mul_mv_iq4_xs_f32_impl(
         uint3  tgpig,
         ushort tiisg,
         ushort sgitg) {
+    const short NSG = FC_mul_mv_nsg;
 
     threadgroup float * shmem_f32 = (threadgroup float *) shmem;
     const int nb = args.ne00/QK_K;
+
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
-    const int first_row = (r0 * nsg + sgitg) * nr0;
+    const int first_row = (r0 * NSG + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -7549,10 +7587,10 @@ kernel void kernel_mul_mv_iq4_xs_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, N_SG_IQ4_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-template<int nr0, int nsg, int nw, typename args_t>
+template<int nr0, int nw, typename args_t>
 void kernel_mul_mv_mxfp4_f32_impl(
         args_t args,
         device const char * src0,
@@ -7562,6 +7600,7 @@ void kernel_mul_mv_mxfp4_f32_impl(
         uint3  tgpig,
         ushort tiisg,
         ushort sgitg) {
+    const short NSG = FC_mul_mv_nsg;
 
     threadgroup float * shmem_f32 = (threadgroup float *) shmem;
     const int nb = args.ne00/QK_MXFP4;
@@ -7570,7 +7609,7 @@ void kernel_mul_mv_mxfp4_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * nsg + sgitg) * nr0;
+    const int first_row = (r0 * NSG + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -7638,7 +7677,7 @@ kernel void kernel_mul_mv_mxfp4_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, N_SG_MXFP4, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
@@ -8314,7 +8353,7 @@ void mmv_fn(
     impl_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4, ggml_metal_kargs_mul_mv>>) mul_mv_impl_fn_t;
+typedef decltype(mmv_fn<kernel_mul_mv_t_t_impl<half, half, N_R0_F, N_SIMDWIDTH, ggml_metal_kargs_mul_mv>>) mul_mv_impl_fn_t;
 
 template<mul_mv_impl_fn_t impl_fn>
 kernel void kernel_mul_mv_id(
@@ -8379,36 +8418,44 @@ kernel void kernel_mul_mv_id(
         sgitg);
 }
 
-typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>) kernel_mul_mv_id_t;
+typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<float, float, N_R0_F, N_SIMDWIDTH>>>) kernel_mul_mv_id_t;
+
+typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<float, float4, float, float4, N_R0_F, N_SIMDWIDTH>>>) kernel_mul_mv_id_4_t;
 
-template [[host_name("kernel_mul_mv_id_f32_f32")]]     kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>;
-template [[host_name("kernel_mul_mv_id_f16_f32")]]     kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<half, half4, float, float4>>>;
+template [[host_name("kernel_mul_mv_id_f32_f32")]]     kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<float, float, N_R0_F, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_f16_f32")]]     kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<half,  float, N_R0_F, N_SIMDWIDTH>>>;
 #if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_mul_mv_id_bf16_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<bfloat, bfloat4, float, float4>>>;
+template [[host_name("kernel_mul_mv_id_bf16_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<bfloat, float, N_R0_F, N_SIMDWIDTH>>>;
 #endif
-template [[host_name("kernel_mul_mv_id_q8_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SG_Q8_0, N_SIMDWIDTH>>>;
-
-template [[host_name("kernel_mul_mv_id_q4_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SG_Q4_0, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q4_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SG_Q4_1, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q5_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SG_Q5_0, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q5_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SG_Q5_1, N_SIMDWIDTH>>>;
-
-template [[host_name("kernel_mul_mv_id_mxfp4_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, N_SG_MXFP4, N_SIMDWIDTH>>>;
-
-template [[host_name("kernel_mul_mv_id_q2_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl   <N_R0_Q2_K,    N_SG_Q2_K,    N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q3_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl   <N_R0_Q3_K,    N_SG_Q3_K,    N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q4_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl   <N_R0_Q4_K,    N_SG_Q4_K,    N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q5_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl   <N_R0_Q5_K,    N_SG_Q5_K,    N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q6_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl   <N_R0_Q6_K,    N_SG_Q6_K,    N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_iq1_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl  <N_R0_IQ1_S,   N_SG_IQ1_S,   N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_iq1_m_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl  <N_R0_IQ1_M,   N_SG_IQ1_M,   N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, N_SG_IQ2_XXS, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl <N_R0_IQ2_XS,  N_SG_IQ2_XS,  N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, N_SG_IQ3_XXS, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_iq3_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl  <N_R0_IQ3_S,   N_SG_IQ3_S,   N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_iq2_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl  <N_R0_IQ2_S,   N_SG_IQ2_S,   N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl <N_R0_IQ4_NL,  N_SG_IQ4_NL,  N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl <N_R0_IQ4_XS,  N_SG_IQ4_XS,  N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_f32_f32_4")]]   kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<float, float4, float, float4, N_R0_F, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_f16_f32_4")]]   kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<half,  half4,  float, float4, N_R0_F, N_SIMDWIDTH>>>;
+#if defined(GGML_METAL_HAS_BF16)
+template [[host_name("kernel_mul_mv_id_bf16_f32_4")]]  kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<bfloat, bfloat4, float, float4, N_R0_F, N_SIMDWIDTH>>>;
+#endif
+
+template [[host_name("kernel_mul_mv_id_q8_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SIMDWIDTH>>>;
+
+template [[host_name("kernel_mul_mv_id_q4_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q4_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q5_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q5_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SIMDWIDTH>>>;
+
+template [[host_name("kernel_mul_mv_id_mxfp4_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, N_SIMDWIDTH>>>;
+
+template [[host_name("kernel_mul_mv_id_q2_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl   <N_R0_Q2_K,    N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q3_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl   <N_R0_Q3_K,    N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q4_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl   <N_R0_Q4_K,    N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q5_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl   <N_R0_Q5_K,    N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q6_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl   <N_R0_Q6_K,    N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq1_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl  <N_R0_IQ1_S,   N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq1_m_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl  <N_R0_IQ1_M,   N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl <N_R0_IQ2_XS,  N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq3_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl  <N_R0_IQ3_S,   N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq2_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl  <N_R0_IQ2_S,   N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl <N_R0_IQ4_NL,  N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl <N_R0_IQ4_XS,  N_SIMDWIDTH>>>;
 
 kernel void kernel_pool_2d_max_f32(
         constant    ggml_metal_kargs_pool_2d & args,