}
}
-template<typename block_q_type, short NR0, short NSG, short NW, typename args_t>
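+// number of simdgroups per threadgroup for the mul_mv kernels; formerly the NSG
+// template parameter, now supplied by the host as a function constant when the
+// pipeline is specialized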
+constant short FC_mul_mv_nsg [[function_constant(FC_MUL_MV + 0)]];
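+// Minimal host-side sketch (not part of this file) of how the constant could be set with
+// the plain Metal API, assuming a hypothetical `lib` library, `err` object and a chosen
+// `short nsg` value, and that the FC_MUL_MV index base is visible to the host:
+//
+//   MTLFunctionConstantValues * cv = [[MTLFunctionConstantValues alloc] init];
+//   [cv setConstantValue:&nsg type:MTLDataTypeShort atIndex:FC_MUL_MV + 0];
+//   id<MTLFunction> fn = [lib newFunctionWithName:@"kernel_mul_mv_f16_f32"
+//                                  constantValues:cv
+//                                           error:&err];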
+
+template<typename block_q_type, short NR0, short NW, typename args_t>
void mul_vec_q_n_f32_impl(
args_t args,
device const char * src0,
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
+ const short NSG = FC_mul_mv_nsg;
+
constexpr short NQ = 16;
const int nb = args.ne00/QK4_0;
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SG_Q4_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+ mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
kernel void kernel_mul_mv_q4_1_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SG_Q4_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+ mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
kernel void kernel_mul_mv_q5_0_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SG_Q5_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+ mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
kernel void kernel_mul_mv_q5_1_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SG_Q5_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+ mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
-template<short NR0, short NSG, short NW, typename args_t>
+template<short NR0, short NW, typename args_t>
void kernel_mul_mv_q8_0_f32_impl(
args_t args,
device const char * src0,
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
+ const short NSG = FC_mul_mv_nsg;
+
constexpr short NQ = 8;
const int nb = args.ne00/QK8_0;
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SG_Q8_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+ kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
// mat-vec kernel processing in chunks of float4
template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>;
template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>;
-#define N_MV_T_T 4
-
-template<typename T0, typename T04, typename T1, typename T14, typename args_t>
-void kernel_mul_mv_impl(
+template<typename T0, typename T1, short NR0, short NW, typename args_t>
+void kernel_mul_mv_t_t_impl(
args_t args,
device const char * src0,
device const char * src1,
device char * dst,
+ threadgroup char * shmem,
uint3 tgpig,
- ushort tiisg) {
- const int r0 = tgpig.x;
- const int rb = tgpig.y*N_MV_T_T;
+ ushort tiisg,
+ ushort sgitg) {
+ const short NSG = FC_mul_mv_nsg;
+
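+ // the row is processed in chunks of NB elements; within a chunk, each participating
+ // thread loads and accumulates NF consecutive elements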
+ constexpr short NB = 32;
+ constexpr short NF = 8;
+
+ const int nb = args.ne00/NB;
+
+ const int r0 = tgpig.x*NR0;
+ const int r1 = tgpig.y;
const int im = tgpig.z;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
- const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
- device const T0 * x = (device const T0 *) (src0 + offset0);
+ //device const T0 * x = (device const T0 *) (src0 + offset0);
+ device const T1 * y = (device const T1 *) (src1 + offset1);
- device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
+ // pointers to src0 rows
+ device const T0 * ax [NR0];
+ FOR_UNROLL (short row = 0; row < NR0; ++row) {
+ const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
- if (args.ne00 < 128) {
- for (int row = 0; row < N_MV_T_T; ++row) {
- int r1 = rb + row;
- if (r1 >= args.ne11) {
- break;
- }
+ ax[row] = (device const T0 *) ((device char *) src0 + offset0);
+ }
- const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
+ float sumf[NR0] = { 0.f };
- device const T1 * y = (device const T1 *) (src1 + offset1);
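+ // ix selects which of the NF chunks owned by this simdgroup the thread works on,
+ // il selects the thread's NF-element slice inside that NB-element chunk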
+ const short ix = tiisg/(NW/NF);
+ const short il = tiisg%(NW/NF);
- float sumf = 0;
- for (int i = tiisg; i < args.ne00; i += 32) {
- sumf += (T0) x[i] * (T1) y[i];
- }
+ const int ib0 = sgitg*NF + ix;
- float sum_all = simd_sum(sumf);
- if (tiisg == 0) {
- dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
- }
+ T1 yl[NF];
+
+ device const T1 * yb = y + (ib0*NB + il*NF);
+
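+ // main loop: each iteration advances the threadgroup by NSG*NF chunks (NF per simdgroup)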
+ for (int ib = ib0; ib < nb; ib += NSG*NF) {
+ for (short i = 0; i < NF; ++i) {
+ yl[i] = yb[i];
}
- } else {
- device const T04 * x4 = (device const T04 *) x;
- for (int row = 0; row < N_MV_T_T; ++row) {
- int r1 = rb + row;
- if (r1 >= args.ne11) {
- break;
+
+ for (short row = 0; row < NR0; row++) {
+ device const T0 * xb = ax[row] + (ib*NB + il*NF);
+
+ float sumq = 0.f;
+ FOR_UNROLL (short i = 0; i < NF; ++i) {
+ sumq += xb[i] * yl[i];
}
- const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
+ sumf[row] += sumq;
+ }
+
+ yb += NSG*NF*NW;
+ }
- device const T1 * y = (device const T1 *) (src1 + offset1);
- device const T14 * y4 = (device const T14 *) y;
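+ // scalar tail for the remaining ne00 - nb*NB elements when ne00 is not a multiple of NB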
+ for (int i = nb*NB + sgitg*NW + tiisg; i < args.ne00; i += NW*NSG) {
+ for (short row = 0; row < NR0; row++) {
+ sumf[row] += ax[row][i] * y[i];
+ }
+ }
- float sumf = 0;
- for (int i = tiisg; i < args.ne00/4; i += 32) {
- sumf += dot((float4) x4[i], (float4) y4[i]);
- }
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
- float sum_all = simd_sum(sumf);
- if (tiisg == 0) {
- for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]);
- dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
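+ // reduce the per-simdgroup partial sums through shmem and store the NR0 results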
+ helper_mv_reduce_and_write<NR0, NW>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
+}
+
+template<typename T0, typename T1, short NR0, short NW>
+kernel void kernel_mul_mv_t_t(
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+ kernel_mul_mv_t_t_impl<T0, T1, NR0, NW, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
+typedef decltype(kernel_mul_mv_t_t<half, half, N_R0_F, N_SIMDWIDTH>) mul_mv_t_t;
+
+template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<float, float, N_R0_F, N_SIMDWIDTH>;
+template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<half, float, N_R0_F, N_SIMDWIDTH>;
+template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<half, half, N_R0_F, N_SIMDWIDTH>;
+#if defined(GGML_METAL_HAS_BF16)
+template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, float, N_R0_F, N_SIMDWIDTH>;
+template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, bfloat, N_R0_F, N_SIMDWIDTH>;
+#endif
+
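+// same accumulation scheme as kernel_mul_mv_t_t_impl, but the main loop reads 4-wide
+// vectors (T04/T14) and accumulates with dot(); the scalar tail loop still handles any
+// remainder of ne00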
+template<typename T0, typename T04, typename T1, typename T14, short NR0, short NW, typename args_t>
+void kernel_mul_mv_t_t_4_impl(
+ args_t args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig,
+ ushort tiisg,
+ ushort sgitg) {
+ const short NSG = FC_mul_mv_nsg;
+
+ constexpr short NB = 32;
+ constexpr short NF = 16;
+ constexpr short NF4 = NF/4;
+
+ const int nb = args.ne00/NB;
+
+ const int r0 = tgpig.x*NR0;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
+
+ //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
+
+ device const T1 * y = (device const T1 *) (src1 + offset1);
+ device const T14 * y4 = (device const T14 *) (src1 + offset1);
+
+ // pointers to src0 rows
+ device const T0 * ax [NR0];
+ device const T04 * ax4[NR0];
+ FOR_UNROLL (short row = 0; row < NR0; ++row) {
+ const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+
+ ax [row] = (device const T0 *) ((device char *) src0 + offset0);
+ ax4[row] = (device const T04 *) ((device char *) src0 + offset0);
+ }
+
+ float sumf[NR0] = { 0.f };
+
+ const short ix = tiisg/(NW/NF);
+ const short il = tiisg%(NW/NF);
+
+ const int ib0 = sgitg*NF + ix;
+
+ T14 yl4[NF4];
+
+ device const T14 * yb4 = y4 + (ib0*NB + il*NF)/4;
+
+ for (int ib = ib0; ib < nb; ib += NSG*NF) {
+ for (short i = 0; i < NF4; ++i) {
+ yl4[i] = yb4[i];
+ }
+
+ for (short row = 0; row < NR0; row++) {
+ device const T04 * xb4 = ax4[row] + (ib*NB + il*NF)/4;
+
+ float sumq = 0.f;
+ FOR_UNROLL (short i = 0; i < NF4; ++i) {
+ sumq += dot(float4(xb4[i]), float4(yl4[i]));
}
+
+ sumf[row] += sumq;
}
+
+ yb4 += NSG*NF*NW/4;
}
+
+ for (int i = nb*NB + sgitg*NW + tiisg; i < args.ne00; i += NW*NSG) {
+ for (short row = 0; row < NR0; row++) {
+ sumf[row] += ax[row][i] * y[i];
+ }
+ }
+
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+ helper_mv_reduce_and_write<NR0, NW>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
}
-template<typename T0, typename T04, typename T1, typename T14>
-kernel void kernel_mul_mv(
+template<typename T0, typename T04, typename T1, typename T14, short NR0, short NW>
+kernel void kernel_mul_mv_t_t_4(
constant ggml_metal_kargs_mul_mv & args,
device const char * src0,
device const char * src1,
device char * dst,
+ threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiisg[[thread_index_in_simdgroup]]) {
- kernel_mul_mv_impl<T0, T04, T1, T14, constant ggml_metal_kargs_mul_mv &>(
- args,
- src0,
- src1,
- dst,
- tgpig,
- tiisg);
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+ kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, NR0, NW, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
-typedef decltype(kernel_mul_mv<half, half4, half, half4>) mul_mv_t;
+typedef decltype(kernel_mul_mv_t_t_4<half, half4, half, half4, N_R0_F, N_SIMDWIDTH>) mul_mv_t_t_4;
-template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv<float, float4, float, float4>;
-template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv<half, half4, float, float4>;
-template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv<half, half4, half, half4>;
+template [[host_name("kernel_mul_mv_f32_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<float, float4, float, float4, N_R0_F, N_SIMDWIDTH>;
+template [[host_name("kernel_mul_mv_f16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half, half4, float, float4, N_R0_F, N_SIMDWIDTH>;
+template [[host_name("kernel_mul_mv_f16_f16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half, half4, half, half4, N_R0_F, N_SIMDWIDTH>;
#if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, float, float4>;
-template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, bfloat, bfloat4>;
+template [[host_name("kernel_mul_mv_bf16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, float, float4, N_R0_F, N_SIMDWIDTH>;
+template [[host_name("kernel_mul_mv_bf16_bf16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, bfloat, bfloat4, N_R0_F, N_SIMDWIDTH>;
#endif
+#define N_MV_T_T 4
+
template<typename T04, typename T14, typename args_t>
void kernel_mul_mv_c4_impl(
args_t args,
template [[host_name("kernel_mul_mv_f32_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<float4, float4>;
template [[host_name("kernel_mul_mv_f16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<half4, float4>;
+template [[host_name("kernel_mul_mv_f16_f16_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<half4, half4>;
#if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_mul_mv_bf16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<bfloat4, float4>;
-#endif
-
-template<typename T, typename T4>
-kernel void kernel_mul_mv_1row(
- constant ggml_metal_kargs_mul_mv & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiisg[[thread_index_in_simdgroup]]) {
-
- const int r0 = tgpig.x;
- const int r1 = tgpig.y;
- const int im = tgpig.z;
-
- const uint i12 = im%args.ne12;
- const uint i13 = im/args.ne12;
-
- const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
- const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
-
- device const T * x = (device const T *) (src0 + offset0);
- device const float * y = (device const float *) (src1 + offset1);
-
- device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
-
- float sumf = 0;
- if (args.ne00 < 128) {
- for (int i = tiisg; i < args.ne00; i += 32) {
- sumf += (float) x[i] * (float) y[i];
- }
- float sum_all = simd_sum(sumf);
- if (tiisg == 0) {
- dst_f32[r0] = sum_all;
- }
- } else {
- device const T4 * x4 = (device const T4 *) x;
- device const float4 * y4 = (device const float4 *) y;
-
- for (int i = tiisg; i < args.ne00/4; i += 32) {
- sumf += dot((float4) x4[i], y4[i]);
- }
-
- float sum_all = simd_sum(sumf);
-
- if (tiisg == 0) {
- for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]);
- dst_f32[r0] = sum_all;
- }
- }
-}
-
-typedef decltype(kernel_mul_mv_1row<half, half4>) mul_mv_1row_t;
-
-template [[host_name("kernel_mul_mv_f16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<half, half4>;
-#if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<bfloat, bfloat4>;
-#endif
-
-// Assumes row size (ne00) is a multiple of 4
-template<typename T, typename T4>
-kernel void kernel_mul_mv_l4(
- constant ggml_metal_kargs_mul_mv & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiisg[[thread_index_in_simdgroup]]) {
-
- const int nrows = args.ne11;
- const int r0 = tgpig.x;
- const int im = tgpig.z;
-
- const uint i12 = im%args.ne12;
- const uint i13 = im/args.ne12;
-
- const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
-
- device const T4 * x4 = (device const T4 *) (src0 + offset0);
-
- device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
-
- for (int r1 = 0; r1 < nrows; ++r1) {
- const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
-
- device const float4 * y4 = (device const float4 *) (src1 + offset1);
-
- float sumf = 0;
- for (int i = tiisg; i < args.ne00/4; i += 32) {
- sumf += dot((float4) x4[i], y4[i]);
- }
-
- float sum_all = simd_sum(sumf);
- if (tiisg == 0) {
- dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
- }
- }
-}
-
-typedef decltype(kernel_mul_mv_l4<half, half4>) mul_mv_l4_t;
-
-template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<half, half4>;
-#if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_mul_mv_bf16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<bfloat, bfloat4>;
+template [[host_name("kernel_mul_mv_bf16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<bfloat4, float4>;
+template [[host_name("kernel_mul_mv_bf16_bf16_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<bfloat4, bfloat4>;
#endif
static float rope_yarn_ramp(const float low, const float high, const int i0) {
}
}
-template<int nr0, int nsg, int nw, typename args_t>
+template<int nr0, int nw, typename args_t>
void kernel_mul_mv_q2_K_f32_impl(
args_t args,
device const char * src0,
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
+ const short NSG = FC_mul_mv_nsg;
const int nb = args.ne00/QK_K;
+
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
- const int first_row = (r0 * nsg + sgitg) * nr0;
+ const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q2_K_f32_impl<N_R0_Q2_K, N_SG_Q2_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+ kernel_mul_mv_q2_K_f32_impl<N_R0_Q2_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
-template<int nr0, int nsg, int nw, typename args_t>
+template<int nr0, int nw, typename args_t>
void kernel_mul_mv_q3_K_f32_impl(
args_t args,
device const char * src0,
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
+ const short NSG = FC_mul_mv_nsg;
const int nb = args.ne00/QK_K;
const int r1 = tgpig.y;
const int im = tgpig.z;
- const int first_row = (r0 * nsg + sgitg) * nr0;
+ const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q3_K_f32_impl<N_R0_Q3_K, N_SG_Q3_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+ kernel_mul_mv_q3_K_f32_impl<N_R0_Q3_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
-template<int nr0, int nsg, int nw, typename args_t>
+template<int nr0, int nw, typename args_t>
void kernel_mul_mv_q4_K_f32_impl(
args_t args,
device const char * src0,
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
+ const short NSG = FC_mul_mv_nsg;
+
const uint16_t kmask1 = 0x3f3f;
const uint16_t kmask2 = 0x0f0f;
const uint16_t kmask3 = 0xc0c0;
const int r1 = tgpig.y;
const int im = tgpig.z;
- const int first_row = (r0 * nsg + sgitg) * nr0;
+ const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q4_K_f32_impl<N_R0_Q4_K, N_SG_Q4_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+ kernel_mul_mv_q4_K_f32_impl<N_R0_Q4_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
-template<int nr0, int nsg, int nw, typename args_t>
+template<int nr0, int nw, typename args_t>
void kernel_mul_mv_q5_K_f32_impl(
args_t args,
device const char * src0,
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
+ const short NSG = FC_mul_mv_nsg;
const int nb = args.ne00/QK_K;
const int r1 = tgpig.y;
const int im = tgpig.z;
- const int first_row = (r0 * nsg + sgitg) * nr0;
+ const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q5_K_f32_impl<N_R0_Q5_K, N_SG_Q5_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+ kernel_mul_mv_q5_K_f32_impl<N_R0_Q5_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
-template<int nr0, int nsg, int nw, typename args_t>
+template<int nr0, int nw, typename args_t>
void kernel_mul_mv_q6_K_f32_impl(
args_t args,
device const char * src0,
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
+ const short NSG = FC_mul_mv_nsg;
const uint8_t kmask1 = 0x03;
const uint8_t kmask2 = 0x0C;
const int r1 = tgpig.y;
const int im = tgpig.z;
- const int first_row = (r0 * nsg + sgitg) * nr0;
+ const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q6_K_f32_impl<N_R0_Q6_K, N_SG_Q6_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+ kernel_mul_mv_q6_K_f32_impl<N_R0_Q6_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
// ======================= "True" 2-bit
-template<int nr0, int nsg, int nw, typename args_t>
+template<int nr0, int nw, typename args_t>
void kernel_mul_mv_iq2_xxs_f32_impl(
args_t args,
device const char * src0,
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
+ const short NSG = FC_mul_mv_nsg;
const int nb = args.ne00/QK_K;
+
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
- const int first_row = (r0 * nsg + sgitg) * nr0;
+ const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, N_SG_IQ2_XXS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+ kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
-template<int nr0, int nsg, int nw, typename args_t>
+template<int nr0, int nw, typename args_t>
void kernel_mul_mv_iq2_xs_f32_impl(
args_t args,
device const char * src0,
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
+ const short NSG = FC_mul_mv_nsg;
const int nb = args.ne00/QK_K;
+
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
- const int first_row = (r0 * nsg + sgitg) * nr0;
+ const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq2_xs_f32_impl<N_R0_IQ2_XS, N_SG_IQ2_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+ kernel_mul_mv_iq2_xs_f32_impl<N_R0_IQ2_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
-template<int nr0, int nsg, int nw, typename args_t>
+template<int nr0, int nw, typename args_t>
void kernel_mul_mv_iq3_xxs_f32_impl(
args_t args,
device const char * src0,
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
+ const short NSG = FC_mul_mv_nsg;
const int nb = args.ne00/QK_K;
+
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
- const int first_row = (r0 * nsg + sgitg) * nr0;
+ const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, N_SG_IQ3_XXS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+ kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
-template<int nr0, int nsg, int nw, typename args_t>
+template<int nr0, int nw, typename args_t>
void kernel_mul_mv_iq3_s_f32_impl(
args_t args,
device const char * src0,
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
+ const short NSG = FC_mul_mv_nsg;
const int nb = args.ne00/QK_K;
+
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
- const int first_row = (r0 * nsg + sgitg) * nr0;
+ const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq3_s_f32_impl<N_R0_IQ3_S, N_SG_IQ3_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+ kernel_mul_mv_iq3_s_f32_impl<N_R0_IQ3_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
-template<int nr0, int nsg, int nw, typename args_t>
+template<int nr0, int nw, typename args_t>
void kernel_mul_mv_iq2_s_f32_impl(
args_t args,
device const char * src0,
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
+ const short NSG = FC_mul_mv_nsg;
const int nb = args.ne00/QK_K;
+
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
- const int first_row = (r0 * nsg + sgitg) * nr0;
+ const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq2_s_f32_impl<N_R0_IQ2_S, N_SG_IQ2_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+ kernel_mul_mv_iq2_s_f32_impl<N_R0_IQ2_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
-template<int nr0, int nsg, int nw, typename args_t>
+template<int nr0, int nw, typename args_t>
void kernel_mul_mv_iq1_s_f32_impl(
args_t args,
device const char * src0,
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
+ const short NSG = FC_mul_mv_nsg;
const int nb = args.ne00/QK_K;
+
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
- const int first_row = (r0 * nsg + sgitg) * nr0;
+ const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq1_s_f32_impl<N_R0_IQ1_S, N_SG_IQ1_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+ kernel_mul_mv_iq1_s_f32_impl<N_R0_IQ1_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
-template<int nr0, int nsg, int nw, typename args_t>
+template<int nr0, int nw, typename args_t>
void kernel_mul_mv_iq1_m_f32_impl(
args_t args,
device const char * src0,
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
+ const short NSG = FC_mul_mv_nsg;
const int nb = args.ne00/QK_K;
const int r1 = tgpig.y;
const int im = tgpig.z;
- const int first_row = (r0 * nsg + sgitg) * nr0;
+ const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq1_m_f32_impl<N_R0_IQ1_M, N_SG_IQ1_M, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+ kernel_mul_mv_iq1_m_f32_impl<N_R0_IQ1_M, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
-template<int nr0, int nsg, int nw, typename args_t>
+template<int nr0, int nw, typename args_t>
void kernel_mul_mv_iq4_nl_f32_impl(
args_t args,
device const char * src0,
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
+ const short NSG = FC_mul_mv_nsg;
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
const int nb = args.ne00/QK4_NL;
const int r1 = tgpig.y;
const int im = tgpig.z;
- const int first_row = (r0 * nsg + sgitg) * nr0;
+ const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq4_nl_f32_impl<N_R0_IQ4_NL, N_SG_IQ4_NL, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+ kernel_mul_mv_iq4_nl_f32_impl<N_R0_IQ4_NL, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
-template<int nr0, int nsg, int nw, typename args_t>
+template<int nr0, int nw, typename args_t>
void kernel_mul_mv_iq4_xs_f32_impl(
args_t args,
device const char * src0,
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
+ const short NSG = FC_mul_mv_nsg;
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
const int nb = args.ne00/QK_K;
+
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
- const int first_row = (r0 * nsg + sgitg) * nr0;
+ const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, N_SG_IQ4_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+ kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
-template<int nr0, int nsg, int nw, typename args_t>
+template<int nr0, int nw, typename args_t>
void kernel_mul_mv_mxfp4_f32_impl(
args_t args,
device const char * src0,
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
+ const short NSG = FC_mul_mv_nsg;
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
const int nb = args.ne00/QK_MXFP4;
const int r1 = tgpig.y;
const int im = tgpig.z;
- const int first_row = (r0 * nsg + sgitg) * nr0;
+ const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, N_SG_MXFP4, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+ kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
impl_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
-typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4, ggml_metal_kargs_mul_mv>>) mul_mv_impl_fn_t;
+typedef decltype(mmv_fn<kernel_mul_mv_t_t_impl<half, half, N_R0_F, N_SIMDWIDTH, ggml_metal_kargs_mul_mv>>) mul_mv_impl_fn_t;
template<mul_mv_impl_fn_t impl_fn>
kernel void kernel_mul_mv_id(
sgitg);
}
-typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>) kernel_mul_mv_id_t;
+typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<float, float, N_R0_F, N_SIMDWIDTH>>>) kernel_mul_mv_id_t;
+
+typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<float, float4, float, float4, N_R0_F, N_SIMDWIDTH>>>) kernel_mul_mv_id_4_t;
-template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>;
-template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<half, half4, float, float4>>>;
+template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<float, float, N_R0_F, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<half, float, N_R0_F, N_SIMDWIDTH>>>;
#if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<bfloat, bfloat4, float, float4>>>;
+template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<bfloat, float, N_R0_F, N_SIMDWIDTH>>>;
#endif
-template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SG_Q8_0, N_SIMDWIDTH>>>;
-
-template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SG_Q4_0, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SG_Q4_1, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SG_Q5_0, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SG_Q5_1, N_SIMDWIDTH>>>;
-
-template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, N_SG_MXFP4, N_SIMDWIDTH>>>;
-
-template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl <N_R0_Q2_K, N_SG_Q2_K, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl <N_R0_Q3_K, N_SG_Q3_K, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl <N_R0_Q4_K, N_SG_Q4_K, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl <N_R0_Q5_K, N_SG_Q5_K, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl <N_R0_Q6_K, N_SG_Q6_K, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl <N_R0_IQ1_S, N_SG_IQ1_S, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl <N_R0_IQ1_M, N_SG_IQ1_M, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, N_SG_IQ2_XXS, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl <N_R0_IQ2_XS, N_SG_IQ2_XS, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, N_SG_IQ3_XXS, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl <N_R0_IQ3_S, N_SG_IQ3_S, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl <N_R0_IQ2_S, N_SG_IQ2_S, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl <N_R0_IQ4_NL, N_SG_IQ4_NL, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl <N_R0_IQ4_XS, N_SG_IQ4_XS, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_f32_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<float, float4, float, float4, N_R0_F, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_f16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<half, half4, float, float4, N_R0_F, N_SIMDWIDTH>>>;
+#if defined(GGML_METAL_HAS_BF16)
+template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<bfloat, bfloat4, float, float4, N_R0_F, N_SIMDWIDTH>>>;
+#endif
+
+template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SIMDWIDTH>>>;
+
+template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SIMDWIDTH>>>;
+
+template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, N_SIMDWIDTH>>>;
+
+template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl <N_R0_Q2_K, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl <N_R0_Q3_K, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl <N_R0_Q4_K, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl <N_R0_Q5_K, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl <N_R0_Q6_K, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl <N_R0_IQ1_S, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl <N_R0_IQ1_M, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl <N_R0_IQ2_XS, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl <N_R0_IQ3_S, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl <N_R0_IQ2_S, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl <N_R0_IQ4_NL, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl <N_R0_IQ4_XS, N_SIMDWIDTH>>>;
kernel void kernel_pool_2d_max_f32(
constant ggml_metal_kargs_pool_2d & args,