]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : use function constants for mul_mv_ext kernels (#16074)
authorGeorgi Gerganov <redacted>
Thu, 18 Sep 2025 13:28:41 +0000 (16:28 +0300)
committerGitHub <redacted>
Thu, 18 Sep 2025 13:28:41 +0000 (16:28 +0300)
* metal : use function constants for mul_mv_ext kernels

ggml-ci

* metal : remove NW template argument

ggml-ci

* metal : adjust constants

ggml-ci

ggml/src/ggml-metal/ggml-metal-device.cpp
ggml/src/ggml-metal/ggml-metal-device.h
ggml/src/ggml-metal/ggml-metal-impl.h
ggml/src/ggml-metal/ggml-metal-ops.cpp
ggml/src/ggml-metal/ggml-metal.metal

index bada84cef31874928d9db81c978fb4d774bf085b..fe015afc54aa9fe0a791db63eb53f1069af04ac9 100644 (file)
@@ -414,19 +414,26 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int r1ptg) {
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) {
     char base[256];
     char name[256];
 
     snprintf(base, 256, "kernel_mul_mv_ext_%s_%s_r1_%d", ggml_type_name(tsrc0), ggml_type_name(tsrc1), r1ptg);
-    snprintf(name, 256, "%s", base);
+    snprintf(name, 256, "%s_nsg=%d_nxpsg=%d", base, nsg, nxpsg);
 
     ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
     if (res) {
         return res;
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+    ggml_metal_cv_t cv = ggml_metal_cv_init();
+
+    ggml_metal_cv_set_int16(cv, nsg,   FC_MUL_MV + 0);
+    ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1);
+
+    res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+    ggml_metal_cv_free(cv);
 
     return res;
 }
@@ -608,7 +615,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
     };
 
     snprintf(base, 256, "kernel_mul_mv_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
-    snprintf(name, 256, "%s", base);
+    snprintf(name, 256, "%s_nsg=%d", base, nsg);
 
     ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
     if (res) {
@@ -824,7 +831,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
     };
 
     snprintf(base, 256, "kernel_mul_mv_id_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
-    snprintf(name, 256, "%s", base);
+    snprintf(name, 256, "%s_nsg=%d", base, nsg);
 
     ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
     if (res) {
@@ -923,11 +930,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
             dk,
             dv);
 
-    snprintf(name, 256, "kernel_%s_%s_dk%d_dv%d_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
-            "flash_attn_ext",
-            ggml_type_name(op->src[1]->type),
-            dk,
-            dv,
+    snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
+            base,
             has_mask,
             has_sinks,
             has_bias,
@@ -985,11 +989,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
             dk,
             dv);
 
-    snprintf(name, 256, "kernel_%s_%s_dk%d_dv%d_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
-            "flash_attn_ext_vec",
-            ggml_type_name(op->src[1]->type),
-            dk,
-            dv,
+    snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
+            base,
             has_mask,
             has_sinks,
             has_bias,
@@ -1033,7 +1034,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
     char name[256];
 
     snprintf(base, 256, "kernel_flash_attn_ext_vec_reduce");
-    snprintf(name, 256, "kernel_flash_attn_ext_vec_reduce_dv=%d_nwg=%d", dv, nwg);
+    snprintf(name, 256, "%s_dv=%d_nwg=%d", base, dv, nwg);
 
     ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
     if (res) {
index 4a3a819fc02c88f32f56dc57b57ea680df4fb763..044d6953f6779d37e1c1a7291629125f23d54a6c 100644 (file)
@@ -114,7 +114,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max          (ggml_me
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv          (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan          (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv              (ggml_metal_library_t lib, const struct ggml_tensor * op);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext        (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int r1ptg);
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext        (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm            (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv            (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0    (ggml_metal_library_t lib, int ne02, int ne20);
index b2572965405b86d2ecbc095d64a659c65ae6fa90..3a7a4f317c638aa54ddf6e013742650842f4716d 100644 (file)
 #define N_R0_Q3_K 2
 #define N_SG_Q3_K 2
 
-#define N_R0_Q4_K 4
+#define N_R0_Q4_K 2
 #define N_SG_Q4_K 2
 
 #define N_R0_Q5_K 2
 #define N_SG_Q5_K 2
 
-#define N_R0_Q6_K 1
+#define N_R0_Q6_K 2
 #define N_SG_Q6_K 2
 
 #define N_R0_IQ1_S 4
@@ -374,9 +374,6 @@ typedef struct {
     int32_t  ne1;
     int16_t  r2;
     int16_t  r3;
-    int16_t  nsg;
-    int16_t  nxpsg;
-    int16_t  r1ptg;
 } ggml_metal_kargs_mul_mv_ext;
 
 typedef struct {
index a28206e78683ce20e7335d554adf40a9f6c1822b..04665b3d6dbb6f55a244db3d19cdd3d0c9a040ec 100644 (file)
@@ -1444,7 +1444,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
                 GGML_ABORT("unsupported ne11");
         };
 
-        ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, r1ptg);
+        ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg);
 
         ggml_metal_kargs_mul_mv_ext args = {
             /*.ne00  =*/ ne00,
@@ -1465,9 +1465,6 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
             /*.ne1   =*/ ne1,
             /*.r2    =*/ r2,
             /*.r3    =*/ r3,
-            /*.nsg   =*/ nsg,
-            /*.nxpsg =*/ nxpsg,
-            /*.r1ptg =*/ r1ptg,
         };
 
         ggml_metal_encoder_set_pipeline(enc, pipeline);
index 20ceb1fe04913f390deabfc486feb112f8133272..c7d97ba70b4c9a644127c5361c590ec887f2052f 100644 (file)
@@ -2843,7 +2843,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
     return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
 }
 
-template<short NR0, short NW>
+template<short NR0>
 static inline void helper_mv_reduce_and_write(
         device float * dst_f32,
         float sumf[NR0],
@@ -2852,6 +2852,8 @@ static inline void helper_mv_reduce_and_write(
         ushort tiisg,
         ushort sgitg,
         threadgroup char * shmem) {
+    constexpr short NW = N_SIMDWIDTH;
+
     threadgroup float * shmem_f32[NR0];
 
     for (short row = 0; row < NR0; ++row) {
@@ -2883,9 +2885,10 @@ static inline void helper_mv_reduce_and_write(
     }
 }
 
-constant short FC_mul_mv_nsg  [[function_constant(FC_MUL_MV + 0)]];
+constant short FC_mul_mv_nsg   [[function_constant(FC_MUL_MV + 0)]];
+constant short FC_mul_mv_nxpsg [[function_constant(FC_MUL_MV + 1)]];
 
-template<typename block_q_type, short NR0, short NW, typename args_t>
+template<typename block_q_type, short NR0, typename args_t>
 void mul_vec_q_n_f32_impl(
         args_t args,
         device const char * src0,
@@ -2897,6 +2900,7 @@ void mul_vec_q_n_f32_impl(
         ushort sgitg) {
     const short NSG = FC_mul_mv_nsg;
 
+    constexpr short NW = N_SIMDWIDTH;
     constexpr short NQ = 16;
 
     const int nb = args.ne00/QK4_0;
@@ -2961,7 +2965,7 @@ void mul_vec_q_n_f32_impl(
 
     device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
 
-    //helper_mv_reduce_and_write<NR0, NW>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
+    //helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
 
     for (int row = 0; row < NR0; ++row) {
         const float tot = simd_sum(sumf[row]);
@@ -2981,7 +2985,7 @@ kernel void kernel_mul_mv_q4_0_f32(
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
 kernel void kernel_mul_mv_q4_1_f32(
@@ -2993,7 +2997,7 @@ kernel void kernel_mul_mv_q4_1_f32(
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-     mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+     mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
 kernel void kernel_mul_mv_q5_0_f32(
@@ -3005,7 +3009,7 @@ kernel void kernel_mul_mv_q5_0_f32(
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
 kernel void kernel_mul_mv_q5_1_f32(
@@ -3017,10 +3021,10 @@ kernel void kernel_mul_mv_q5_1_f32(
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-template<short NR0, short NW, typename args_t>
+template<short NR0, typename args_t>
 void kernel_mul_mv_q8_0_f32_impl(
         args_t args,
         device const char * src0,
@@ -3032,6 +3036,7 @@ void kernel_mul_mv_q8_0_f32_impl(
         ushort sgitg) {
     const short NSG = FC_mul_mv_nsg;
 
+    constexpr short NW = N_SIMDWIDTH;
     constexpr short NQ = 8;
 
     const int nb = args.ne00/QK8_0;
@@ -3090,7 +3095,7 @@ void kernel_mul_mv_q8_0_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    helper_mv_reduce_and_write<NR0, NW>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
+    helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
 }
 
 [[host_name("kernel_mul_mv_q8_0_f32")]]
@@ -3103,12 +3108,12 @@ kernel void kernel_mul_mv_q8_0_f32(
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
 // mat-vec kernel processing in chunks of float4
 // chpb - chunks per quantization block
-template<short nxpsg, short r1ptg, typename q_t, short chpb, void (*deq_t4)(device const q_t *, short, thread float4 &) >
+template<short r1ptg, typename q_t, short chpb, void (*deq_t4)(device const q_t *, short, thread float4 &) >
 void kernel_mul_mv_ext_q4_f32_impl(
         constant ggml_metal_kargs_mul_mv_ext & args,
         device const char * src0,
@@ -3117,6 +3122,9 @@ void kernel_mul_mv_ext_q4_f32_impl(
         uint3   tgpig[[threadgroup_position_in_grid]],
         ushort  tiisg[[thread_index_in_simdgroup]],
         ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
+    const short NSG   = FC_mul_mv_nsg;
+    const short nxpsg = FC_mul_mv_nxpsg;
+
     const short chpt = 4; // chunks per thread
 
   //const short nxpsg = (32);
@@ -3125,7 +3133,7 @@ void kernel_mul_mv_ext_q4_f32_impl(
     const short tx = tiisg%nxpsg;
     const short ty = tiisg/nxpsg;
 
-    const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty;
+    const int i01 = tgpig.x*(nypsg*NSG) + nypsg*sgitg + ty;
     const int i11 = tgpig.y*r1ptg;
     const int i1m = tgpig.z;
 
@@ -3208,7 +3216,7 @@ void kernel_mul_mv_ext_q4_f32_impl(
 }
 
 // mat-vec kernel processing in chunks of float4x4
-template<short nxpsg, short r1ptg, typename q_t, short chpb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &) >
+template<short r1ptg, typename q_t, short chpb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &) >
 void kernel_mul_mv_ext_q4x4_f32_impl(
         constant ggml_metal_kargs_mul_mv_ext & args,
         device const char * src0,
@@ -3217,6 +3225,9 @@ void kernel_mul_mv_ext_q4x4_f32_impl(
         uint3   tgpig[[threadgroup_position_in_grid]],
         ushort  tiisg[[thread_index_in_simdgroup]],
         ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
+    const short NSG   = FC_mul_mv_nsg;
+    const short nxpsg = FC_mul_mv_nxpsg;
+
     const short chpt = 1;
 
   //const short nxpsg = (32);
@@ -3225,7 +3236,7 @@ void kernel_mul_mv_ext_q4x4_f32_impl(
     const short tx = tiisg%nxpsg;
     const short ty = tiisg/nxpsg;
 
-    const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty;
+    const int i01 = tgpig.x*(nypsg*NSG) + nypsg*sgitg + ty;
     const int i11 = tgpig.y*r1ptg;
     const int i1m = tgpig.z;
 
@@ -3322,12 +3333,7 @@ kernel void kernel_mul_mv_ext_q4_f32_disp(
         uint3   tgpig[[threadgroup_position_in_grid]],
         ushort  tiisg[[thread_index_in_simdgroup]],
         ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
-    switch (args.nxpsg) {
-        case 4:  kernel_mul_mv_ext_q4_f32_impl<4,  r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
-        case 8:  kernel_mul_mv_ext_q4_f32_impl<8,  r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
-        case 16: kernel_mul_mv_ext_q4_f32_impl<16, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
-        case 32: kernel_mul_mv_ext_q4_f32_impl<32, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
-    }
+    kernel_mul_mv_ext_q4_f32_impl<r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg);
 }
 
 template<short r1ptg, typename q_t, short epb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &)>
@@ -3339,12 +3345,7 @@ kernel void kernel_mul_mv_ext_q4x4_f32_disp(
         uint3   tgpig[[threadgroup_position_in_grid]],
         ushort  tiisg[[thread_index_in_simdgroup]],
         ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
-    switch (args.nxpsg) {
-        case 4:  kernel_mul_mv_ext_q4x4_f32_impl<4,  r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
-        case 8:  kernel_mul_mv_ext_q4x4_f32_impl<8,  r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
-        case 16: kernel_mul_mv_ext_q4x4_f32_impl<16, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
-        case 32: kernel_mul_mv_ext_q4x4_f32_impl<32, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
-    }
+    kernel_mul_mv_ext_q4x4_f32_impl<r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg);
 }
 
 typedef decltype(kernel_mul_mv_ext_q4_f32_disp  <2, block_q8_0, 32,  dequantize_q8_0_t4>) mul_mv_ext_q4_f32_t;
@@ -3410,7 +3411,7 @@ template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]] kernel mul_mv_ext_q4x4
 template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>;
 template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>;
 
-template<typename T0, typename T1, short NR0, short NW, typename args_t>
+template<typename T0, typename T1, short NR0, typename args_t>
 void kernel_mul_mv_t_t_impl(
         args_t args,
         device const char * src0,
@@ -3422,6 +3423,7 @@ void kernel_mul_mv_t_t_impl(
         ushort sgitg) {
     const short NSG = FC_mul_mv_nsg;
 
+    constexpr short NW = N_SIMDWIDTH;
     constexpr short NB = 32;
     constexpr short NF = 8;
 
@@ -3486,10 +3488,10 @@ void kernel_mul_mv_t_t_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    helper_mv_reduce_and_write<NR0, NW>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
+    helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
 }
 
-template<typename T0, typename T1, short NR0, short NW>
+template<typename T0, typename T1, short NR0>
 kernel void kernel_mul_mv_t_t(
         constant ggml_metal_kargs_mul_mv & args,
         device const char * src0,
@@ -3499,20 +3501,20 @@ kernel void kernel_mul_mv_t_t(
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    kernel_mul_mv_t_t_impl<T0, T1, NR0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_t_t_impl<T0, T1, NR0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-typedef decltype(kernel_mul_mv_t_t<half, half, N_R0_F, N_SIMDWIDTH>) mul_mv_t_t;
+typedef decltype(kernel_mul_mv_t_t<half, half, N_R0_F>) mul_mv_t_t;
 
-template [[host_name("kernel_mul_mv_f32_f32")]]   kernel mul_mv_t_t kernel_mul_mv_t_t<float, float, N_R0_F, N_SIMDWIDTH>;
-template [[host_name("kernel_mul_mv_f16_f32")]]   kernel mul_mv_t_t kernel_mul_mv_t_t<half,  float, N_R0_F, N_SIMDWIDTH>;
-template [[host_name("kernel_mul_mv_f16_f16")]]   kernel mul_mv_t_t kernel_mul_mv_t_t<half,  half,  N_R0_F, N_SIMDWIDTH>;
+template [[host_name("kernel_mul_mv_f32_f32")]]   kernel mul_mv_t_t kernel_mul_mv_t_t<float, float, N_R0_F>;
+template [[host_name("kernel_mul_mv_f16_f32")]]   kernel mul_mv_t_t kernel_mul_mv_t_t<half,  float, N_R0_F>;
+template [[host_name("kernel_mul_mv_f16_f16")]]   kernel mul_mv_t_t kernel_mul_mv_t_t<half,  half,  N_R0_F>;
 #if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_mul_mv_bf16_f32")]]  kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, float,  N_R0_F, N_SIMDWIDTH>;
-template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, bfloat, N_R0_F, N_SIMDWIDTH>;
+template [[host_name("kernel_mul_mv_bf16_f32")]]  kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, float,  N_R0_F>;
+template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, bfloat, N_R0_F>;
 #endif
 
-template<typename T0, typename T04, typename T1, typename T14, short NR0, short NW, typename args_t>
+template<typename T0, typename T04, typename T1, typename T14, short NR0, typename args_t>
 void kernel_mul_mv_t_t_4_impl(
         args_t args,
         device const char * src0,
@@ -3524,6 +3526,7 @@ void kernel_mul_mv_t_t_4_impl(
         ushort sgitg) {
     const short NSG = FC_mul_mv_nsg;
 
+    constexpr short NW = N_SIMDWIDTH;
     constexpr short NB  = 32;
     constexpr short NF  = 16;
     constexpr short NF4 = NF/4;
@@ -3591,10 +3594,10 @@ void kernel_mul_mv_t_t_4_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    helper_mv_reduce_and_write<NR0, NW>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
+    helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
 }
 
-template<typename T0, typename T04, typename T1, typename T14, short NR0, short NW>
+template<typename T0, typename T04, typename T1, typename T14, short NR0>
 kernel void kernel_mul_mv_t_t_4(
         constant ggml_metal_kargs_mul_mv & args,
         device const char * src0,
@@ -3604,17 +3607,17 @@ kernel void kernel_mul_mv_t_t_4(
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, NR0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, NR0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-typedef decltype(kernel_mul_mv_t_t_4<half, half4, half, half4, N_R0_F, N_SIMDWIDTH>) mul_mv_t_t_4;
+typedef decltype(kernel_mul_mv_t_t_4<half, half4, half, half4, N_R0_F>) mul_mv_t_t_4;
 
-template [[host_name("kernel_mul_mv_f32_f32_4")]]   kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<float, float4, float, float4, N_R0_F, N_SIMDWIDTH>;
-template [[host_name("kernel_mul_mv_f16_f32_4")]]   kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half,  half4,  float, float4, N_R0_F, N_SIMDWIDTH>;
-template [[host_name("kernel_mul_mv_f16_f16_4")]]   kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half,  half4,  half,  half4,  N_R0_F, N_SIMDWIDTH>;
+template [[host_name("kernel_mul_mv_f32_f32_4")]]   kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<float, float4, float, float4, N_R0_F>;
+template [[host_name("kernel_mul_mv_f16_f32_4")]]   kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half,  half4,  float, float4, N_R0_F>;
+template [[host_name("kernel_mul_mv_f16_f16_4")]]   kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half,  half4,  half,  half4,  N_R0_F>;
 #if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_mul_mv_bf16_f32_4")]]  kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, float,  float4,  N_R0_F, N_SIMDWIDTH>;
-template [[host_name("kernel_mul_mv_bf16_bf16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, bfloat, bfloat4, N_R0_F, N_SIMDWIDTH>;
+template [[host_name("kernel_mul_mv_bf16_f32_4")]]  kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, float,  float4,  N_R0_F>;
+template [[host_name("kernel_mul_mv_bf16_bf16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, bfloat, bfloat4, N_R0_F>;
 #endif
 
 #define N_MV_T_T 4
@@ -5966,7 +5969,7 @@ kernel void kernel_concat(
     }
 }
 
-template<int nr0, int nw, typename args_t>
+template<int nr0, typename args_t>
 void kernel_mul_mv_q2_K_f32_impl(
         args_t args,
         device const char * src0,
@@ -6068,10 +6071,10 @@ kernel void kernel_mul_mv_q2_K_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q2_K_f32_impl<N_R0_Q2_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q2_K_f32_impl<N_R0_Q2_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
-template<int nr0, int nw, typename args_t>
+template<int nr0, typename args_t>
 void kernel_mul_mv_q3_K_f32_impl(
         args_t args,
         device const char * src0,
@@ -6233,10 +6236,10 @@ kernel void kernel_mul_mv_q3_K_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q3_K_f32_impl<N_R0_Q3_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q3_K_f32_impl<N_R0_Q3_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
-template<int nr0, int nw, typename args_t>
+template<int nr0, typename args_t>
 void kernel_mul_mv_q4_K_f32_impl(
         args_t args,
         device const char * src0,
@@ -6248,9 +6251,9 @@ void kernel_mul_mv_q4_K_f32_impl(
         ushort sgitg) {
     const short NSG = FC_mul_mv_nsg;
 
-    const uint16_t kmask1 = 0x3f3f;
-    const uint16_t kmask2 = 0x0f0f;
-    const uint16_t kmask3 = 0xc0c0;
+    constexpr uint16_t kmask1 = 0x3f3f;
+    constexpr uint16_t kmask2 = 0x0f0f;
+    constexpr uint16_t kmask3 = 0xc0c0;
 
     const short ix = tiisg/8;  // 0...3
     const short it = tiisg%8;  // 0...7
@@ -6309,7 +6312,7 @@ void kernel_mul_mv_q4_K_f32_impl(
             float4 acc1 = {0.f, 0.f, 0.f, 0.f};
             float4 acc2 = {0.f, 0.f, 0.f, 0.f};
 
-            for (short i = 0; i < 4; ++i) {
+            FOR_UNROLL (short i = 0; i < 4; ++i) {
                 acc1[0] += yl[2*i + 0] * (q1[i] & 0x000F);
                 acc1[1] += yl[2*i + 1] * (q1[i] & 0x0F00);
                 acc1[2] += yl[2*i + 8] * (q1[i] & 0x00F0);
@@ -6320,14 +6323,11 @@ void kernel_mul_mv_q4_K_f32_impl(
                 acc2[3] += yh[2*i + 9] * (q2[i] & 0xF000);
             }
 
-            float dall = dh[0];
-            float dmin = dh[1];
-
-            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
-                                 (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
-                                 (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
-                                 (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
-                         dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
+            sumf[row] += dh[0] * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
+                                  (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
+                                  (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
+                                  (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
+                         dh[1] * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
 
             q1 += args.nb01/2;
             sc += args.nb01/2;
@@ -6357,10 +6357,10 @@ kernel void kernel_mul_mv_q4_K_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q4_K_f32_impl<N_R0_Q4_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q4_K_f32_impl<N_R0_Q4_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
-template<int nr0, int nw, typename args_t>
+template<int nr0, typename args_t>
 void kernel_mul_mv_q5_K_f32_impl(
         args_t args,
         device const char * src0,
@@ -6393,9 +6393,9 @@ void kernel_mul_mv_q5_K_f32_impl(
 
     float yl[16], yh[16];
 
-    const uint16_t kmask1 = 0x3f3f;
-    const uint16_t kmask2 = 0x0f0f;
-    const uint16_t kmask3 = 0xc0c0;
+    constexpr uint16_t kmask1 = 0x3f3f;
+    constexpr uint16_t kmask2 = 0x0f0f;
+    constexpr uint16_t kmask3 = 0xc0c0;
 
     const short tid = tiisg/4;
     const short ix  = tiisg%4;
@@ -6441,7 +6441,7 @@ void kernel_mul_mv_q5_K_f32_impl(
 
             float4 acc1 = {0.f};
             float4 acc2 = {0.f};
-            for (short l = 0; l < 8; ++l) {
+            FOR_UNROLL (short l = 0; l < 8; ++l) {
                 uint8_t h = qh[l];
                 acc1[0] += yl[l+0] * (q1[l] & 0x0F);
                 acc1[1] += yl[l+8] * (q1[l] & 0xF0);
@@ -6452,13 +6452,12 @@ void kernel_mul_mv_q5_K_f32_impl(
                 acc2[2] += h & hm3 ? yh[l+0] : 0.f;
                 acc2[3] += h & hm4 ? yh[l+8] : 0.f;
             }
-            const float dall = dh[0];
-            const float dmin = dh[1];
-            sumf[row] += dall * (sc8[0] * (acc1[0] +  16.f*acc2[0]) +
-                                 sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
-                                 sc8[4] * (acc1[2] +  16.f*acc2[2]) +
-                                 sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
-                         dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
+
+            sumf[row] += dh[0] * (sc8[0] * (acc1[0]      + 16.f*acc2[0]) +
+                                  sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
+                                  sc8[4] * (acc1[2]      + 16.f*acc2[2]) +
+                                  sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
+                         dh[1] * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
 
             q1 += args.nb01;
             qh += args.nb01;
@@ -6489,10 +6488,10 @@ kernel void kernel_mul_mv_q5_K_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q5_K_f32_impl<N_R0_Q5_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q5_K_f32_impl<N_R0_Q5_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
-template<int nr0, int nw, typename args_t>
+template<int nr0, typename args_t>
 void kernel_mul_mv_q6_K_f32_impl(
         args_t args,
         device const char * src0,
@@ -6504,10 +6503,10 @@ void kernel_mul_mv_q6_K_f32_impl(
         ushort sgitg) {
     const short NSG = FC_mul_mv_nsg;
 
-    const uint8_t kmask1 = 0x03;
-    const uint8_t kmask2 = 0x0C;
-    const uint8_t kmask3 = 0x30;
-    const uint8_t kmask4 = 0xC0;
+    constexpr uint8_t kmask1 = 0x03;
+    constexpr uint8_t kmask2 = 0x0C;
+    constexpr uint8_t kmask3 = 0x30;
+    constexpr uint8_t kmask4 = 0xC0;
 
     const int nb = args.ne00/QK_K;
 
@@ -6558,18 +6557,16 @@ void kernel_mul_mv_q6_K_f32_impl(
         }
 
         for (short row = 0; row < nr0; ++row) {
-            const float dall = dh[0];
-
             float4 sums = {0.f, 0.f, 0.f, 0.f};
 
-            for (short l = 0; l < 4; ++l) {
+            FOR_UNROLL (short l = 0; l < 4; ++l) {
                 sums[0] += yl[4*l + 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
                 sums[1] += yl[4*l + 1] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
                 sums[2] += yl[4*l + 2] * ((int8_t)((q1[l]  >> 4) | ((qh[l] & kmask3) << 0)) - 32);
                 sums[3] += yl[4*l + 3] * ((int8_t)((q2[l]  >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
             }
 
-            sumf[row] += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
+            sumf[row] += dh[0] * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
 
             q1 += args.nb01;
             q2 += args.nb01;
@@ -6599,12 +6596,12 @@ kernel void kernel_mul_mv_q6_K_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q6_K_f32_impl<N_R0_Q6_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q6_K_f32_impl<N_R0_Q6_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
 // ======================= "True" 2-bit
 
-template<int nr0, int nw, typename args_t>
+template<int nr0, typename args_t>
 void kernel_mul_mv_iq2_xxs_f32_impl(
         args_t args,
         device const char * src0,
@@ -6709,10 +6706,10 @@ kernel void kernel_mul_mv_iq2_xxs_f32(
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-template<int nr0, int nw, typename args_t>
+template<int nr0, typename args_t>
 void kernel_mul_mv_iq2_xs_f32_impl(
         args_t args,
         device const char * src0,
@@ -6828,10 +6825,10 @@ kernel void kernel_mul_mv_iq2_xs_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq2_xs_f32_impl<N_R0_IQ2_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq2_xs_f32_impl<N_R0_IQ2_XS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-template<int nr0, int nw, typename args_t>
+template<int nr0, typename args_t>
 void kernel_mul_mv_iq3_xxs_f32_impl(
         args_t args,
         device const char * src0,
@@ -6940,10 +6937,10 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-template<int nr0, int nw, typename args_t>
+template<int nr0, typename args_t>
 void kernel_mul_mv_iq3_s_f32_impl(
         args_t args,
         device const char * src0,
@@ -7052,10 +7049,10 @@ kernel void kernel_mul_mv_iq3_s_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq3_s_f32_impl<N_R0_IQ3_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq3_s_f32_impl<N_R0_IQ3_S, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-template<int nr0, int nw, typename args_t>
+template<int nr0, typename args_t>
 void kernel_mul_mv_iq2_s_f32_impl(
         args_t args,
         device const char * src0,
@@ -7165,10 +7162,10 @@ kernel void kernel_mul_mv_iq2_s_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq2_s_f32_impl<N_R0_IQ2_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq2_s_f32_impl<N_R0_IQ2_S, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-template<int nr0, int nw, typename args_t>
+template<int nr0, typename args_t>
 void kernel_mul_mv_iq1_s_f32_impl(
         args_t args,
         device const char * src0,
@@ -7264,10 +7261,10 @@ kernel void kernel_mul_mv_iq1_s_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq1_s_f32_impl<N_R0_IQ1_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq1_s_f32_impl<N_R0_IQ1_S, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
-template<int nr0, int nw, typename args_t>
+template<int nr0, typename args_t>
 void kernel_mul_mv_iq1_m_f32_impl(
         args_t args,
         device const char * src0,
@@ -7373,10 +7370,10 @@ kernel void kernel_mul_mv_iq1_m_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq1_m_f32_impl<N_R0_IQ1_M, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq1_m_f32_impl<N_R0_IQ1_M, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
-template<int nr0, int nw, typename args_t>
+template<int nr0, typename args_t>
 void kernel_mul_mv_iq4_nl_f32_impl(
         args_t args,
         device const char * src0,
@@ -7480,10 +7477,10 @@ kernel void kernel_mul_mv_iq4_nl_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq4_nl_f32_impl<N_R0_IQ4_NL, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq4_nl_f32_impl<N_R0_IQ4_NL, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-template<int nr0, int nw, typename args_t>
+template<int nr0, typename args_t>
 void kernel_mul_mv_iq4_xs_f32_impl(
         args_t args,
         device const char * src0,
@@ -7587,10 +7584,10 @@ kernel void kernel_mul_mv_iq4_xs_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-template<int nr0, int nw, typename args_t>
+template<int nr0, typename args_t>
 void kernel_mul_mv_mxfp4_f32_impl(
         args_t args,
         device const char * src0,
@@ -7677,7 +7674,7 @@ kernel void kernel_mul_mv_mxfp4_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
@@ -8353,7 +8350,7 @@ void mmv_fn(
     impl_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-typedef decltype(mmv_fn<kernel_mul_mv_t_t_impl<half, half, N_R0_F, N_SIMDWIDTH, ggml_metal_kargs_mul_mv>>) mul_mv_impl_fn_t;
+typedef decltype(mmv_fn<kernel_mul_mv_t_t_impl<half, half, N_R0_F, ggml_metal_kargs_mul_mv>>) mul_mv_impl_fn_t;
 
 template<mul_mv_impl_fn_t impl_fn>
 kernel void kernel_mul_mv_id(
@@ -8418,44 +8415,44 @@ kernel void kernel_mul_mv_id(
         sgitg);
 }
 
-typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<float, float, N_R0_F, N_SIMDWIDTH>>>) kernel_mul_mv_id_t;
+typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<float, float, N_R0_F>>>) kernel_mul_mv_id_t;
 
-typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<float, float4, float, float4, N_R0_F, N_SIMDWIDTH>>>) kernel_mul_mv_id_4_t;
+typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<float, float4, float, float4, N_R0_F>>>) kernel_mul_mv_id_4_t;
 
-template [[host_name("kernel_mul_mv_id_f32_f32")]]     kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<float, float, N_R0_F, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_f16_f32")]]     kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<half,  float, N_R0_F, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_f32_f32")]]     kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<float, float, N_R0_F>>>;
+template [[host_name("kernel_mul_mv_id_f16_f32")]]     kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<half,  float, N_R0_F>>>;
 #if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_mul_mv_id_bf16_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<bfloat, float, N_R0_F, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_bf16_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<bfloat, float, N_R0_F>>>;
 #endif
-template [[host_name("kernel_mul_mv_id_f32_f32_4")]]   kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<float, float4, float, float4, N_R0_F, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_f16_f32_4")]]   kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<half,  half4,  float, float4, N_R0_F, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_f32_f32_4")]]   kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<float, float4, float, float4, N_R0_F>>>;
+template [[host_name("kernel_mul_mv_id_f16_f32_4")]]   kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<half,  half4,  float, float4, N_R0_F>>>;
 #if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_mul_mv_id_bf16_f32_4")]]  kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<bfloat, bfloat4, float, float4, N_R0_F, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_bf16_f32_4")]]  kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<bfloat, bfloat4, float, float4, N_R0_F>>>;
 #endif
 
-template [[host_name("kernel_mul_mv_id_q8_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SIMDWIDTH>>>;
-
-template [[host_name("kernel_mul_mv_id_q4_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q4_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q5_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q5_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SIMDWIDTH>>>;
-
-template [[host_name("kernel_mul_mv_id_mxfp4_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, N_SIMDWIDTH>>>;
-
-template [[host_name("kernel_mul_mv_id_q2_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl   <N_R0_Q2_K,    N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q3_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl   <N_R0_Q3_K,    N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q4_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl   <N_R0_Q4_K,    N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q5_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl   <N_R0_Q5_K,    N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q6_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl   <N_R0_Q6_K,    N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_iq1_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl  <N_R0_IQ1_S,   N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_iq1_m_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl  <N_R0_IQ1_M,   N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl <N_R0_IQ2_XS,  N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_iq3_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl  <N_R0_IQ3_S,   N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_iq2_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl  <N_R0_IQ2_S,   N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl <N_R0_IQ4_NL,  N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl <N_R0_IQ4_XS,  N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q8_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0>>>;
+
+template [[host_name("kernel_mul_mv_id_q4_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0>>>;
+template [[host_name("kernel_mul_mv_id_q4_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1>>>;
+template [[host_name("kernel_mul_mv_id_q5_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0>>>;
+template [[host_name("kernel_mul_mv_id_q5_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1>>>;
+
+template [[host_name("kernel_mul_mv_id_mxfp4_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4>>>;
+
+template [[host_name("kernel_mul_mv_id_q2_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl   <N_R0_Q2_K>>>;
+template [[host_name("kernel_mul_mv_id_q3_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl   <N_R0_Q3_K>>>;
+template [[host_name("kernel_mul_mv_id_q4_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl   <N_R0_Q4_K>>>;
+template [[host_name("kernel_mul_mv_id_q5_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl   <N_R0_Q5_K>>>;
+template [[host_name("kernel_mul_mv_id_q6_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl   <N_R0_Q6_K>>>;
+template [[host_name("kernel_mul_mv_id_iq1_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl  <N_R0_IQ1_S>>>;
+template [[host_name("kernel_mul_mv_id_iq1_m_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl  <N_R0_IQ1_M>>>;
+template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS>>>;
+template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl <N_R0_IQ2_XS>>>;
+template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS>>>;
+template [[host_name("kernel_mul_mv_id_iq3_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl  <N_R0_IQ3_S>>>;
+template [[host_name("kernel_mul_mv_id_iq2_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl  <N_R0_IQ2_S>>>;
+template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl <N_R0_IQ4_NL>>>;
+template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl <N_R0_IQ4_XS>>>;
 
 kernel void kernel_pool_2d_max_f32(
         constant    ggml_metal_kargs_pool_2d & args,