]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : dynamic simdgroups for MV kernels (#16340)
authorGeorgi Gerganov <redacted>
Tue, 30 Sep 2025 08:03:23 +0000 (11:03 +0300)
committerGitHub <redacted>
Tue, 30 Sep 2025 08:03:23 +0000 (11:03 +0300)
* metal : dynamic simdgroups for MV kernels

* cont : minor

ggml/src/ggml-metal/ggml-metal-device.cpp
ggml/src/ggml-metal/ggml-metal-impl.h
ggml/src/ggml-metal/ggml-metal-ops.cpp
ggml/src/ggml-metal/ggml-metal.metal

index 0bf7fe9f9237a9a341e36500f7e9a1e6506f34d6..819f31c8a300cb0cc2c61789dab6c273817c20a6 100644 (file)
@@ -495,22 +495,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
         case GGML_TYPE_F16:
         case GGML_TYPE_BF16:
             {
-                if (ne00 == 4) {
+                if (ne00 < 32) {
                     nsg = 1;
                     nr0 = 32;
-                    nr1 = 4;
-                    suffix = "_c4";
-                } else if (ne00 % 4 == 0) {
-                    nsg = N_SG_F;
-                    nr0 = N_R0_F;
                     nr1 = 1;
-                    smem = 32*sizeof(float)*N_R0_F;
-                    suffix = "_4";
+                    suffix = "_short";
                 } else {
-                    nsg = N_SG_F;
-                    nr0 = N_R0_F;
+                    nsg = std::min(4, (ne00 + 127) / 128);
+                    nr0 = 2;
                     nr1 = 1;
-                    smem = 32*sizeof(float)*N_R0_F;
+                    smem = 32*sizeof(float)*nr0;
+                    suffix = ne00 % 4 == 0 ? "_4" : "";
                 }
             } break;
         case GGML_TYPE_Q4_0:
@@ -727,18 +722,11 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
         case GGML_TYPE_F16:
         case GGML_TYPE_BF16:
             {
-                if (ne00 % 4 == 0) {
-                    nsg = N_SG_F;
-                    nr0 = N_R0_F;
-                    nr1 = 1;
-                    smem = 32*sizeof(float)*N_R0_F;
-                    suffix = "_4";
-                } else {
-                    nsg = N_SG_F;
-                    nr0 = N_R0_F;
-                    nr1 = 1;
-                    smem = 32*sizeof(float)*N_R0_F;
-                }
+                nsg = std::min(4, (ne00 + 127) / 128);
+                nr0 = 2;
+                nr1 = 1;
+                smem = 32*sizeof(float)*nr0;
+                suffix = ne00 % 4 == 0 ? "_4" : "";
             } break;
         case GGML_TYPE_Q4_0:
             {
index d355c6dfc7526371dd76143da5c891faa365c058..88c98423ebec0a9fc06001b0b5d923e04b3b32ed 100644 (file)
@@ -8,9 +8,6 @@
 //
 // TODO: for optimal performance, become function of the device and work size
 
-#define N_R0_F 2
-#define N_SG_F 4
-
 #define N_R0_Q4_0 4
 #define N_SG_Q4_0 2
 
@@ -352,6 +349,7 @@ typedef struct {
     uint64_t nb13;
     int32_t  ne0;
     int32_t  ne1;
+    int32_t  nr0;
     int16_t  r2;
     int16_t  r3;
 } ggml_metal_kargs_mul_mv;
@@ -427,6 +425,7 @@ typedef struct {
     int32_t  ne0;
     int32_t  ne1;
     uint64_t nb1;
+    int32_t  nr0;
 } ggml_metal_kargs_mul_mv_id;
 
 // NORM
index d7267a6aedfff49aea434fe8243c6d834005e3e0..e85a223c01dc38cde843fc56b5a03e970d8c69f0 100644 (file)
@@ -1565,6 +1565,12 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
     } else {
         ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op);
 
+        const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
+        const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
+        const int nsg = ggml_metal_pipeline_get_nsg(pipeline);
+
+        const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
+
         ggml_metal_kargs_mul_mv args = {
             /*.ne00 =*/ ne00,
             /*.ne01 =*/ ne01,
@@ -1582,16 +1588,11 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
             /*.nb13 =*/ nb13,
             /*.ne0  =*/ ne0,
             /*.ne1  =*/ ne1,
+            /*.nr0  =*/ nr0,
             /*.r2   =*/ r2,
             /*.r3   =*/ r3,
         };
 
-        const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
-        const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
-        const int nsg = ggml_metal_pipeline_get_nsg(pipeline);
-
-        const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
-
         ggml_metal_encoder_set_pipeline(enc, pipeline);
         ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
         ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
@@ -1758,6 +1759,14 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
             ggml_metal_encoder_dispatch_threadgroups(enc, (ne21 + 31)/32, (ne01 + 63)/64, ne02, 128, 1, 1);
         }
     } else {
+        ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op);
+
+        const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
+        const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
+        const int nsg = ggml_metal_pipeline_get_nsg(pipeline);
+
+        const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
+
         ggml_metal_kargs_mul_mv_id args = {
             /*.nei0 =*/ ne20,
             /*.nei1 =*/ ne21,
@@ -1778,16 +1787,9 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
             /*.ne0  =*/ ne0,
             /*.ne1  =*/ ne1,
             /*.nb1  =*/ nb1,
+            /*.nr0  =*/ nr0,
         };
 
-        ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op);
-
-        const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
-        const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
-        const int nsg = ggml_metal_pipeline_get_nsg(pipeline);
-
-        const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
-
         if (ggml_is_quantized(op->src[0]->type)) {
             GGML_ASSERT(ne00 >= nsg*nr0);
         }
index 0271fd5b25d7357a4114b3a1eeb8cabbc62ccaf8..96df6f0ce62de18fd0e3c66aad0487a8a39d9725 100644 (file)
@@ -3531,7 +3531,25 @@ void kernel_mul_mv_t_t_impl(
     helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
 }
 
-template<typename T0, typename T1, short NR0>
+template<typename T0, typename T1, typename args_t>
+void kernel_mul_mv_t_t_disp(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+    switch (args.nr0) {
+      //case 1: kernel_mul_mv_t_t_impl<T0, T1, 1, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
+        case 2: kernel_mul_mv_t_t_impl<T0, T1, 2, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
+      //case 3: kernel_mul_mv_t_t_impl<T0, T1, 3, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
+      //case 4: kernel_mul_mv_t_t_impl<T0, T1, 4, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
+    }
+}
+
+template<typename T0, typename T1>
 kernel void kernel_mul_mv_t_t(
         constant ggml_metal_kargs_mul_mv & args,
         device const char * src0,
@@ -3541,17 +3559,17 @@ kernel void kernel_mul_mv_t_t(
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    kernel_mul_mv_t_t_impl<T0, T1, NR0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_t_t_disp<T0, T1, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-typedef decltype(kernel_mul_mv_t_t<half, half, N_R0_F>) mul_mv_t_t;
+typedef decltype(kernel_mul_mv_t_t<half, half>) mul_mv_t_t;
 
-template [[host_name("kernel_mul_mv_f32_f32")]]   kernel mul_mv_t_t kernel_mul_mv_t_t<float, float, N_R0_F>;
-template [[host_name("kernel_mul_mv_f16_f32")]]   kernel mul_mv_t_t kernel_mul_mv_t_t<half,  float, N_R0_F>;
-template [[host_name("kernel_mul_mv_f16_f16")]]   kernel mul_mv_t_t kernel_mul_mv_t_t<half,  half,  N_R0_F>;
+template [[host_name("kernel_mul_mv_f32_f32")]]   kernel mul_mv_t_t kernel_mul_mv_t_t<float, float>;
+template [[host_name("kernel_mul_mv_f16_f32")]]   kernel mul_mv_t_t kernel_mul_mv_t_t<half,  float>;
+template [[host_name("kernel_mul_mv_f16_f16")]]   kernel mul_mv_t_t kernel_mul_mv_t_t<half,  half>;
 #if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_mul_mv_bf16_f32")]]  kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, float,  N_R0_F>;
-template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, bfloat, N_R0_F>;
+template [[host_name("kernel_mul_mv_bf16_f32")]]  kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, float>;
+template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, bfloat>;
 #endif
 
 template<typename T0, typename T04, typename T1, typename T14, short NR0, typename args_t>
@@ -3637,7 +3655,25 @@ void kernel_mul_mv_t_t_4_impl(
     helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
 }
 
-template<typename T0, typename T04, typename T1, typename T14, short NR0>
+template<typename T0, typename T04, typename T1, typename T14, typename args_t>
+void kernel_mul_mv_t_t_4_disp(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+    switch (args.nr0) {
+      //case 1: kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 1, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
+        case 2: kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 2, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
+      //case 3: kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 3, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
+      //case 4: kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 4, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
+    };
+}
+
+template<typename T0, typename T04, typename T1, typename T14>
 kernel void kernel_mul_mv_t_t_4(
         constant ggml_metal_kargs_mul_mv & args,
         device const char * src0,
@@ -3647,23 +3683,21 @@ kernel void kernel_mul_mv_t_t_4(
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, NR0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_t_t_4_disp<T0, T04, T1, T14, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-typedef decltype(kernel_mul_mv_t_t_4<half, half4, half, half4, N_R0_F>) mul_mv_t_t_4;
+typedef decltype(kernel_mul_mv_t_t_4<half, half4, half, half4>) mul_mv_t_t_4;
 
-template [[host_name("kernel_mul_mv_f32_f32_4")]]   kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<float, float4, float, float4, N_R0_F>;
-template [[host_name("kernel_mul_mv_f16_f32_4")]]   kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half,  half4,  float, float4, N_R0_F>;
-template [[host_name("kernel_mul_mv_f16_f16_4")]]   kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half,  half4,  half,  half4,  N_R0_F>;
+template [[host_name("kernel_mul_mv_f32_f32_4")]]   kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<float, float4, float, float4>;
+template [[host_name("kernel_mul_mv_f16_f32_4")]]   kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half,  half4,  float, float4>;
+template [[host_name("kernel_mul_mv_f16_f16_4")]]   kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half,  half4,  half,  half4>;
 #if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_mul_mv_bf16_f32_4")]]  kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, float,  float4,  N_R0_F>;
-template [[host_name("kernel_mul_mv_bf16_bf16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, bfloat, bfloat4, N_R0_F>;
+template [[host_name("kernel_mul_mv_bf16_f32_4")]]  kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, float,  float4>;
+template [[host_name("kernel_mul_mv_bf16_bf16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, bfloat, bfloat4>;
 #endif
 
-#define N_MV_T_T 4
-
-template<typename T04, typename T14, typename args_t>
-void kernel_mul_mv_c4_impl(
+template<typename T0, typename T1, typename args_t>
+void kernel_mul_mv_t_t_short_impl(
         args_t args,
         device const char * src0,
         device const char * src1,
@@ -3671,7 +3705,7 @@ void kernel_mul_mv_c4_impl(
         uint3  tgpig,
         ushort tiisg) {
     const int r0 = tgpig.x*32 + tiisg;
-    const int rb = tgpig.y*N_MV_T_T;
+    const int r1 = tgpig.y;
     const int im = tgpig.z;
 
     if (r0 >= args.ne01) {
@@ -3683,33 +3717,32 @@ void kernel_mul_mv_c4_impl(
 
     const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
 
-    device const T04 * x = (device const T04 *) (src0 + offset0);
+    device const T0 * x = (device const T0 *) (src0 + offset0);
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
 
-    for (int row = 0; row < N_MV_T_T; ++row) {
-        int r1 = rb + row;
-        if (r1 >= args.ne11) {
-            break;
-        }
+    const uint64_t offset1 = r1*args.nb11 + (i12   )*args.nb12 + (i13   )*args.nb13;
 
-        const uint64_t offset1 = r1*args.nb11 + (i12   )*args.nb12 + (i13   )*args.nb13;
+    device const T1 * y = (device const T1 *) (src1 + offset1);
 
-        device const T14 * y = (device const T14 *) (src1 + offset1);
+    float res = 0.0f;
 
-        dst_f32[(uint64_t)r1*args.ne0 + r0] = dot((float4) x[0], (float4) y[0]);
+    for (int i = 0; i < args.ne00; ++i) {
+        res += (float) x[i] * (float) y[i];
     }
+
+    dst_f32[(uint64_t)r1*args.ne0 + r0] = res;
 }
 
-template<typename T04, typename T14>
-kernel void kernel_mul_mv_c4(
+template<typename T0, typename T1>
+kernel void kernel_mul_mv_t_t_short(
         constant ggml_metal_kargs_mul_mv & args,
         device const char * src0,
         device const char * src1,
         device       char * dst,
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]]) {
-    kernel_mul_mv_c4_impl<T04, T14, constant ggml_metal_kargs_mul_mv &>(
+    kernel_mul_mv_t_t_short_impl<T0, T1, constant ggml_metal_kargs_mul_mv &>(
         args,
         src0,
         src1,
@@ -3718,14 +3751,14 @@ kernel void kernel_mul_mv_c4(
         tiisg);
 }
 
-typedef decltype(kernel_mul_mv_c4<half4, half4>) mul_mv_c4_t;
+typedef decltype(kernel_mul_mv_t_t_short<half, half>) mul_mv_t_t_short_t;
 
-template [[host_name("kernel_mul_mv_f32_f32_c4")]]  kernel mul_mv_c4_t kernel_mul_mv_c4<float4,  float4>;
-template [[host_name("kernel_mul_mv_f16_f32_c4")]]  kernel mul_mv_c4_t kernel_mul_mv_c4<half4,   float4>;
-template [[host_name("kernel_mul_mv_f16_f16_c4")]]  kernel mul_mv_c4_t kernel_mul_mv_c4<half4,   half4>;
+template [[host_name("kernel_mul_mv_f32_f32_short")]]  kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<float, float>;
+template [[host_name("kernel_mul_mv_f16_f32_short")]]  kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<half,  float>;
+template [[host_name("kernel_mul_mv_f16_f16_short")]]  kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<half,  half>;
 #if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_mul_mv_bf16_f32_c4")]]  kernel mul_mv_c4_t kernel_mul_mv_c4<bfloat4, float4>;
-template [[host_name("kernel_mul_mv_bf16_bf16_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<bfloat4, bfloat4>;
+template [[host_name("kernel_mul_mv_bf16_f32_short")]]  kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<bfloat, float>;
+template [[host_name("kernel_mul_mv_bf16_bf16_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<bfloat, bfloat>;
 #endif
 
 static float rope_yarn_ramp(const float low, const float high, const int i0) {
@@ -8458,7 +8491,7 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]]  kernel mul_mm_id kernel_m
 // matrix-vector multiplication
 //
 
-typedef void (kernel_mul_mv_impl_t)(
+typedef void (kernel_mul_mv_disp_t)(
         ggml_metal_kargs_mul_mv args,
         device const char * src0,
         device const char * src1,
@@ -8466,7 +8499,7 @@ typedef void (kernel_mul_mv_impl_t)(
         uint3  tgpig,
         ushort tiisg);
 
-typedef void (kernel_mul_mv2_impl_t)(
+typedef void (kernel_mul_mv2_disp_t)(
         ggml_metal_kargs_mul_mv args,
         device const char * src0,
         device const char * src1,
@@ -8476,7 +8509,7 @@ typedef void (kernel_mul_mv2_impl_t)(
         ushort tiisg,
         ushort sgitg);
 
-template<kernel_mul_mv_impl_t impl_fn>
+template<kernel_mul_mv_disp_t disp_fn>
 void mmv_fn(
         ggml_metal_kargs_mul_mv args,
         device const char * src0,
@@ -8487,10 +8520,10 @@ void mmv_fn(
         ushort tiitg,
         ushort tiisg,
         ushort sgitg) {
-    impl_fn(args, src0, src1, dst, tgpig, tiisg);
+    disp_fn(args, src0, src1, dst, tgpig, tiisg);
 }
 
-template<kernel_mul_mv2_impl_t impl_fn>
+template<kernel_mul_mv2_disp_t disp_fn>
 void mmv_fn(
         ggml_metal_kargs_mul_mv args,
         device const char * src0,
@@ -8501,12 +8534,12 @@ void mmv_fn(
         ushort tiitg,
         ushort tiisg,
         ushort sgitg) {
-    impl_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    disp_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-typedef decltype(mmv_fn<kernel_mul_mv_t_t_impl<half, half, N_R0_F, ggml_metal_kargs_mul_mv>>) mul_mv_impl_fn_t;
+typedef decltype(mmv_fn<kernel_mul_mv_t_t_disp<half, half, ggml_metal_kargs_mul_mv>>) mul_mv_disp_fn_t;
 
-template<mul_mv_impl_fn_t impl_fn>
+template<mul_mv_disp_fn_t disp_fn>
 kernel void kernel_mul_mv_id(
         constant ggml_metal_kargs_mul_mv_id & args,
         device const char * src0s,
@@ -8553,11 +8586,12 @@ kernel void kernel_mul_mv_id(
         /*.nb13 =*/ args.nb12, // ne12 == 1
         /*.ne0  =*/ args.ne0,
         /*.ne1  =*/ 1, // args.ne1,
+        /*.nr0  =*/ args.nr0,
         /*.r2   =*/ 1,
         /*.r3   =*/ 1,
     };
 
-    impl_fn(
+    disp_fn(
         args0,
         /* src0 */ src0_cur,
         /* src1 */ src1_cur,
@@ -8569,19 +8603,19 @@ kernel void kernel_mul_mv_id(
         sgitg);
 }
 
-typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<float, float, N_R0_F>>>) kernel_mul_mv_id_t;
+typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<float, float>>>) kernel_mul_mv_id_t;
 
-typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<float, float4, float, float4, N_R0_F>>>) kernel_mul_mv_id_4_t;
+typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<float, float4, float, float4>>>) kernel_mul_mv_id_4_t;
 
-template [[host_name("kernel_mul_mv_id_f32_f32")]]     kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<float, float, N_R0_F>>>;
-template [[host_name("kernel_mul_mv_id_f16_f32")]]     kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<half,  float, N_R0_F>>>;
+template [[host_name("kernel_mul_mv_id_f32_f32")]]     kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<float, float>>>;
+template [[host_name("kernel_mul_mv_id_f16_f32")]]     kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<half,  float>>>;
 #if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_mul_mv_id_bf16_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<bfloat, float, N_R0_F>>>;
+template [[host_name("kernel_mul_mv_id_bf16_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<bfloat, float>>>;
 #endif
-template [[host_name("kernel_mul_mv_id_f32_f32_4")]]   kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<float, float4, float, float4, N_R0_F>>>;
-template [[host_name("kernel_mul_mv_id_f16_f32_4")]]   kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<half,  half4,  float, float4, N_R0_F>>>;
+template [[host_name("kernel_mul_mv_id_f32_f32_4")]]   kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<float, float4, float, float4>>>;
+template [[host_name("kernel_mul_mv_id_f16_f32_4")]]   kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<half,  half4,  float, float4>>>;
 #if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_mul_mv_id_bf16_f32_4")]]  kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<bfloat, bfloat4, float, float4, N_R0_F>>>;
+template [[host_name("kernel_mul_mv_id_bf16_f32_4")]]  kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<bfloat, bfloat4, float, float4>>>;
 #endif
 
 template [[host_name("kernel_mul_mv_id_q8_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0>>>;