kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
}
-#define N_F32_F32 4
+#define N_MV_T_T 4
-void kernel_mul_mv_f32_f32_impl(
+template<typename T0, typename T04, typename T1, typename T14>
+void kernel_mul_mv_impl(
device const char * src0,
device const char * src1,
device float * dst,
uint64_t nb12,
int64_t ne0,
int64_t ne1,
- uint r2,
- uint r3,
- uint3 tgpig,
- uint tiisg) {
-
+ uint r2,
+ uint r3,
+ uint3 tgpig,
+ uint tiisg) {
const int64_t r0 = tgpig.x;
- const int64_t rb = tgpig.y*N_F32_F32;
+ const int64_t rb = tgpig.y*N_MV_T_T;
const int64_t im = tgpig.z;
const uint i12 = im%ne12;
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
- device const float * x = (device const float *) (src0 + offset0);
+ device const T0 * x = (device const T0 *) (src0 + offset0);
if (ne00 < 128) {
- for (int row = 0; row < N_F32_F32; ++row) {
+ for (int row = 0; row < N_MV_T_T; ++row) {
int r1 = rb + row;
if (r1 >= ne11) {
break;
}
- device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+ device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12);
float sumf = 0;
for (int i = tiisg; i < ne00; i += 32) {
- sumf += (float) x[i] * (float) y[i];
+ sumf += (T0) x[i] * (T1) y[i];
}
float all_sum = simd_sum(sumf);
}
}
} else {
- device const float4 * x4 = (device const float4 *)x;
- for (int row = 0; row < N_F32_F32; ++row) {
+ device const T04 * x4 = (device const T04 *) x;
+ for (int row = 0; row < N_MV_T_T; ++row) {
int r1 = rb + row;
if (r1 >= ne11) {
break;
}
- device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
- device const float4 * y4 = (device const float4 *) y;
+ device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12);
+ device const T14 * y4 = (device const T14 *) y;
float sumf = 0;
for (int i = tiisg; i < ne00/4; i += 32) {
- for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+ for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
- for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+ for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]);
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
}
}
}
-[[host_name("kernel_mul_mv_f32_f32")]]
-kernel void kernel_mul_mv_f32_f32(
+template<typename T0, typename T04, typename T1, typename T14>
+kernel void kernel_mul_mv(
device const char * src0,
device const char * src1,
device float * dst,
constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
- kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
+ kernel_mul_mv_impl<T0, T04, T1, T14>(
+ src0,
+ src1,
+ dst,
+ ne00,
+ ne01,
+ ne02,
+ nb00,
+ nb01,
+ nb02,
+ ne10,
+ ne11,
+ ne12,
+ nb10,
+ nb11,
+ nb12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg);
}
-#define N_F16_F16 4
+typedef decltype(kernel_mul_mv<half, half4, half, half4>) mul_mv_t;
-kernel void kernel_mul_mv_f16_f16(
- device const char * src0,
- device const char * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]]) {
+template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv<float, float4, float, float4>;
+template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv<half, half4, float, float4>;
+template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv<half, half4, half, half4>;
- const int64_t r0 = tgpig.x;
- const int64_t rb = tgpig.y*N_F16_F16;
- const int64_t im = tgpig.z;
-
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
-
- const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
-
- device const half * x = (device const half *) (src0 + offset0);
-
- if (ne00 < 128) {
- for (int row = 0; row < N_F16_F16; ++row) {
- int r1 = rb + row;
- if (r1 >= ne11) {
- break;
- }
-
- device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
-
- float sumf = 0;
- for (int i = tiisg; i < ne00; i += 32) {
- sumf += (half) x[i] * (half) y[i];
- }
-
- float all_sum = simd_sum(sumf);
- if (tiisg == 0) {
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
- }
- }
- } else {
- device const half4 * x4 = (device const half4 *)x;
- for (int row = 0; row < N_F16_F16; ++row) {
- int r1 = rb + row;
- if (r1 >= ne11) {
- break;
- }
-
- device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
- device const half4 * y4 = (device const half4 *) y;
-
- float sumf = 0;
- for (int i = tiisg; i < ne00/4; i += 32) {
- for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k];
- }
-
- float all_sum = simd_sum(sumf);
- if (tiisg == 0) {
- for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i];
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
- }
- }
- }
-}
-
-void kernel_mul_mv_f16_f32_1row_impl(
+template<typename T, typename T4>
+kernel void kernel_mul_mv_1row(
device const char * src0,
device const char * src1,
device float * dst,
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
- device const half * x = (device const half *) (src0 + offset0);
+ device const T * x = (device const T *) (src0 + offset0);
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
float sumf = 0;
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
} else {
- device const half4 * x4 = (device const half4 *) x;
+ device const T4 * x4 = (device const T4 *) x;
device const float4 * y4 = (device const float4 *) y;
+
for (int i = tiisg; i < ne00/4; i += 32) {
- for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k];
+ for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
}
+
float all_sum = simd_sum(sumf);
+
if (tiisg == 0) {
- for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+ for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]);
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
}
}
-[[host_name("kernel_mul_mv_f16_f32_1row")]]
-kernel void kernel_mul_mv_f16_f32_1row(
- device const char * src0,
- device const char * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]]) {
- kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
-}
-
-#define N_F16_F32 4
-
-void kernel_mul_mv_f16_f32_impl(
- device const char * src0,
- device const char * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb00,
- uint64_t nb01,
- uint64_t nb02,
- int64_t ne10,
- int64_t ne11,
- int64_t ne12,
- uint64_t nb10,
- uint64_t nb11,
- uint64_t nb12,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
- uint3 tgpig,
- uint tiisg) {
-
- const int64_t r0 = tgpig.x;
- const int64_t rb = tgpig.y*N_F16_F32;
- const int64_t im = tgpig.z;
-
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
-
- const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
-
- device const half * x = (device const half *) (src0 + offset0);
-
- if (ne00 < 128) {
- for (int row = 0; row < N_F16_F32; ++row) {
- int r1 = rb + row;
- if (r1 >= ne11) {
- break;
- }
-
- device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
-
- float sumf = 0;
- for (int i = tiisg; i < ne00; i += 32) {
- sumf += (float) x[i] * (float) y[i];
- }
-
- float all_sum = simd_sum(sumf);
- if (tiisg == 0) {
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
- }
- }
- } else {
- device const half4 * x4 = (device const half4 *)x;
- for (int row = 0; row < N_F16_F32; ++row) {
- int r1 = rb + row;
- if (r1 >= ne11) {
- break;
- }
-
- device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
- device const float4 * y4 = (device const float4 *) y;
-
- float sumf = 0;
- for (int i = tiisg; i < ne00/4; i += 32) {
- for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
- }
+typedef decltype(kernel_mul_mv_1row<half, half4>) mul_mv_1row_t;
- float all_sum = simd_sum(sumf);
- if (tiisg == 0) {
- for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
- }
- }
- }
-}
-
-[[host_name("kernel_mul_mv_f16_f32")]]
-kernel void kernel_mul_mv_f16_f32(
- device const char * src0,
- device const char * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]]) {
- kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
-}
+template [[host_name("kernel_mul_mv_f16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<half, half4>;
// Assumes row size (ne00) is a multiple of 4
-kernel void kernel_mul_mv_f16_f32_l4(
+template<typename T, typename T4>
+kernel void kernel_mul_mv_l4(
device const char * src0,
device const char * src1,
device float * dst,
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
- device const half4 * x4 = (device const half4 *) (src0 + offset0);
+ device const T4 * x4 = (device const T4 *) (src0 + offset0);
for (int r1 = 0; r1 < nrows; ++r1) {
device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
float sumf = 0;
for (int i = tiisg; i < ne00/4; i += 32) {
- for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+ for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
}
float all_sum = simd_sum(sumf);
}
}
+typedef decltype(kernel_mul_mv_l4<half, half4>) mul_mv_l4_t;
+
+template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<half, half4>;
+
static float rope_yarn_ramp(const float low, const float high, const int i0) {
const float y = (i0 / 2 - low) / max(0.001f, high - low);
return 1.0f - min(1.0f, max(0.0f, y));
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
-kernel void kernel_cpy_f16_f16(
- device const half * src0,
- device half * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig[2];
- const int64_t i02 = tgpig[1];
- const int64_t i01 = tgpig[0];
-
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-
- const int64_t i3 = n / (ne2*ne1*ne0);
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
-
- device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
- for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
- device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
- dst_data[i00] = src[0];
- }
-}
-
-kernel void kernel_cpy_f16_f32(
- device const half * src0,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig[2];
- const int64_t i02 = tgpig[1];
- const int64_t i01 = tgpig[0];
-
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-
- const int64_t i3 = n / (ne2*ne1*ne0);
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
-
- device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
- for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
- device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
- dst_data[i00] = src[0];
- }
-}
-
-kernel void kernel_cpy_f32_f16(
- device const float * src0,
- device half * dst,
+template<typename T0, typename T1>
+kernel void kernel_cpy(
+ device const void * src0,
+ device void * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
- device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+ device T1 * dst_data = (device T1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
-
- dst_data[i00] = src[0];
+ device const T0 * src = (device T0 *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+ dst_data[i00] = (T1) src[0];
}
}
-kernel void kernel_cpy_f32_f32(
- device const float * src0,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig[2];
- const int64_t i02 = tgpig[1];
- const int64_t i01 = tgpig[0];
-
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
- const int64_t i3 = n / (ne2*ne1*ne0);
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
-
- device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
- for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
-
- dst_data[i00] = src[0];
- }
-}
+template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy<float, float>;
+template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy<float, half>;
+template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy<half, half>;
+template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy<half, float>;
kernel void kernel_cpy_f32_q8_0(
device const float * src0,
}
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
-kernel void kernel_get_rows(
+kernel void kernel_get_rows_q(
device const void * src0,
- device const char * src1,
+ device const void * src1,
device float * dst,
constant int64_t & ne00,
constant uint64_t & nb01,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg [[threads_per_threadgroup]]) {
- //const int64_t i = tgpig;
- //const int64_t r = ((device int32_t *) src1)[i];
-
const int64_t i10 = tgpig.x;
const int64_t i11 = tgpig.y;
- const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+ const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
const int64_t i02 = i11;
for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
float4x4 temp;
- dequantize_func(
- ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
+ dequantize_func(((device const block_q *) ((const device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
*(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
}
}
-kernel void kernel_get_rows_f32(
- device const void * src0,
- device const char * src1,
- device float * dst,
- constant int64_t & ne00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint3 tptg [[threads_per_threadgroup]]) {
- const int64_t i10 = tgpig.x;
- const int64_t i11 = tgpig.y;
-
- const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
-
- const int64_t i02 = i11;
-
- for (int ind = tiitg; ind < ne00; ind += tptg.x) {
- ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
- ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
- }
-}
-
-kernel void kernel_get_rows_f16(
+template<typename T>
+kernel void kernel_get_rows_f(
device const void * src0,
- device const char * src1,
+ device const void * src1,
device float * dst,
constant int64_t & ne00,
constant uint64_t & nb01,
const int64_t i10 = tgpig.x;
const int64_t i11 = tgpig.y;
- const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+ const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
const int64_t i02 = i11;
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
- ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
- ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
+ (( device float *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
+ ((const device T *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
}
}
kernel void kernel_get_rows_i32(
device const void * src0,
- device const char * src1,
+ device const void * src1,
device int32_t * dst,
constant int64_t & ne00,
constant uint64_t & nb01,
const int64_t i10 = tgpig.x;
const int64_t i11 = tgpig.y;
- const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+ const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
const int64_t i02 = i11;
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
- ((device int32_t *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
- ((device int32_t *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
+ (( device int32_t *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
+ ((const device int32_t *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
}
}
#define SG_MAT_ROW 8
// each block_q contains 16*nl weights
-template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
-void kernel_mul_mm_impl(device const uchar * src0,
- device const uchar * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne02,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup uchar * shared_memory [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
+kernel void kernel_mul_mm(device const uchar * src0,
+ device const uchar * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne02,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
- threadgroup half * sa = (threadgroup half *)(shared_memory);
+ threadgroup T * sa = (threadgroup T *)(shared_memory);
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
const uint r0 = tgpig.y;
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
- simdgroup_half8x8 ma[4];
+ simdgroup_T8x8 ma[4];
simdgroup_float8x8 mb[2];
simdgroup_float8x8 c_res[8];
for (int i = 0; i < 8; i++){
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
// load data and store to threadgroup memory
- half4x4 temp_a;
+ T4x4 temp_a;
dequantize_func(x, il, temp_a);
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup_barrier(mem_flags::mem_threadgroup);
// load matrices from threadgroup memory and conduct outer products
- threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
+ threadgroup T * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
#pragma unroll(4)
}
}
-template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
-kernel void kernel_mul_mm(device const uchar * src0,
- device const uchar * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne02,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup uchar * shared_memory [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mm_impl<block_q, nl, dequantize_func>(
- src0,
- src1,
- dst,
- ne00,
- ne02,
- nb01,
- nb02,
- ne12,
- nb10,
- nb11,
- nb12,
- ne0,
- ne1,
- r2,
- r3,
- shared_memory,
- tgpig,
- tiitg,
- sgitg);
-}
-
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
kernel void kernel_mul_mm_id(
device const uchar * src0s,
// get rows
//
-typedef void (get_rows_t)(
- device const void * src0,
- device const char * src1,
- device float * dst,
- constant int64_t & ne00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- uint3, uint, uint3);
-
-//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
-//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
-template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
-template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
-template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
-template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>;
-template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
-template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
-template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
-template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
-template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
-template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
-template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
-template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
-template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
-template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_rows<block_iq3_s, QK_NL, dequantize_iq3_s>;
-template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows<block_iq2_s, QK_NL, dequantize_iq2_s>;
-template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
-template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows<block_iq1_m, QK_NL, dequantize_iq1_m>;
-template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
-template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
+typedef decltype(kernel_get_rows_f<float>) get_rows_f_t;
+
+template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float>;
+template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half>;
+
+typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
+
+template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>;
+template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_1, 2, dequantize_q5_1>;
+template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q8_0, 2, dequantize_q8_0>;
+template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q2_K, QK_NL, dequantize_q2_K>;
+template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q3_K, QK_NL, dequantize_q3_K>;
+template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_K, QK_NL, dequantize_q4_K>;
+template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_K, QK_NL, dequantize_q5_K>;
+template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q6_K, QK_NL, dequantize_q6_K>;
+template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
+template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
+template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
+template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_s, QK_NL, dequantize_iq3_s>;
+template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_s, QK_NL, dequantize_iq2_s>;
+template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_s, QK_NL, dequantize_iq1_s>;
+template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_m, QK_NL, dequantize_iq1_m>;
+template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2, dequantize_iq4_nl>;
+template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
//
// matrix-matrix multiplication
//
-typedef decltype(kernel_mul_mm<float4x4, 1, dequantize_f32>) mat_mm_t;
-
-template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
-template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
-template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
-template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
-template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
-template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
-template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
-template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
-template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
-template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
-template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
-template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
-template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
-template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
-template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
-template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_s, QK_NL, dequantize_iq3_s>;
-template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_s, QK_NL, dequantize_iq2_s>;
-template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
-template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_m, QK_NL, dequantize_iq1_m>;
-template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
-template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
+typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mat_mm_t;
+
+template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
+template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
+template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
+template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
+template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
+template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
+template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
+template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
+template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
+template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
+template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
+template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
+template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
+template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
+template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
+template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
+template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
+template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
+template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
//
// indirect matrix-matrix multiplication
impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
}
-typedef decltype(mmv_fn<kernel_mul_mv_f32_f32_impl>) mul_mv_impl_fn_t;
+typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4>>) mul_mv_impl_fn_t;
template<mul_mv_impl_fn_t impl_fn>
kernel void kernel_mul_mv_id(
sgitg);
}
-typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>) kernel_mul_mv_id_t;
-
-template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f16_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl>>;
+typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>) kernel_mul_mv_id_t;
+
+template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>;
+template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<half, half4, float, float4>>>;
+template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl>>;