float4 sa_vec(0.0);
- for (int j = 0; j < head_size; j += 4) {
+ for (uint j = 0; j < head_size; j += 4) {
float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
sa_vec += a_vec * s_vec;
return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
}
-// putting them in the kernel cause a significant performance penalty
-#define N_DST 4 // each SIMD group works on 4 rows
-#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
-//Note: This is a template, but strictly speaking it only applies to
-// quantizations where the block size is 32. It also does not
-// guard against the number of rows not being divisible by
-// N_DST, so this is another explicit assumption of the implementation.
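+// nr0 is the number of src0 rows accumulated per simdgroup, nsg the number of
+// simdgroups per threadgroup, and nw the SIMD width; the per-type N_R0_*/N_SG_*
+// constants replace the old global N_DST/N_SIMDGROUP so that each quantization
+// type can use its own values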
-template<typename block_q_type, int nr, int nsg, int nw, typename args_t>
+template<typename block_q_type, int nr0, int nsg, int nw, typename args_t>
void mul_vec_q_n_f32_impl(
args_t args,
device const char * src0,
const int r1 = tgpig.y;
const int im = tgpig.z;
- const int first_row = (r0 * nsg + sgitg) * nr;
+ const int first_row = (r0 * nsg + sgitg) * nr0;
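+ // e.g. with nsg = 2 and nr0 = 4 (the old N_SIMDGROUP/N_DST defaults) a
+ // threadgroup covers 8 rows of src0, and this simdgroup accumulates rows
+ // first_row .. first_row + nr0 - 1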
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
device const float * y = (device const float *) (src1 + offset1);
// pointers to src0 rows
- device const block_q_type * ax[nr];
- for (int row = 0; row < nr; ++row) {
+ device const block_q_type * ax[nr0];
+ for (int row = 0; row < nr0; ++row) {
const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
}
float yl[16]; // src1 vector cache
- float sumf[nr] = {0.f};
+ float sumf[nr0] = {0.f};
const short ix = (tiisg/2);
const short il = (tiisg%2)*8;
float sumy[2] = { 0.f, 0.f };
#pragma unroll
- for (int i = 0; i < 8; i += 2) {
+ for (short i = 0; i < 8; i += 2) {
sumy[0] += yb[i + 0] + yb[i + 1];
yl[i + 0] = yb[i + 0];
yl[i + 1] = yb[i + 1]/256.f;
}
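+ // yl[i+1] is pre-scaled by 1/256 so that a quant taken from the high byte of a
+ // 16-bit word (masks such as 0x0F00) can be multiplied in without shifting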
#pragma unroll
- for (int row = 0; row < nr; row++) {
+ for (short row = 0; row < nr0; row++) {
sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il);
}
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
- for (int row = 0; row < nr; ++row) {
+ for (int row = 0; row < nr0; ++row) {
const float tot = simd_sum(sumf[row]);
if (tiisg == 0 && first_row + row < args.ne01) {
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+ mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SG_Q4_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
kernel void kernel_mul_mv_q4_1_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+ mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SG_Q4_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
kernel void kernel_mul_mv_q5_0_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+ mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SG_Q5_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
kernel void kernel_mul_mv_q5_1_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+ mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SG_Q5_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
#define NB_Q8_0 8
-template<typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
void kernel_mul_mv_q8_0_f32_impl(
args_t args,
device const char * src0,
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
- const int nr = N_DST;
- const int nsg = N_SIMDGROUP;
- const int nw = N_SIMDWIDTH;
-
const int nb = args.ne00/QK8_0;
+
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
- const int first_row = (r0*nsg + sgitg)*nr;
+ const int first_row = (r0 * nsg + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
device const float * y = (device const float *) (src1 + offset1);
// pointers to src0 rows
- device const block_q8_0 * ax[nr];
- for (int row = 0; row < nr; ++row) {
+ device const block_q8_0 * ax[nr0];
+ for (int row = 0; row < nr0; ++row) {
const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0);
}
float yl[NB_Q8_0];
- float sumf[nr] = { 0.f };
+ float sumf[nr0] = { 0.f };
const short ix = tiisg/4;
const short il = tiisg%4;
yl[i] = yb[i];
}
- for (int row = 0; row < nr; row++) {
+ for (short row = 0; row < nr0; row++) {
device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0;
float sumq = 0.f;
for (short iq = 0; iq < NB_Q8_0; ++iq) {
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
- for (int row = 0; row < nr; ++row) {
+ for (int row = 0; row < nr0; ++row) {
const float tot = simd_sum(sumf[row]);
if (tiisg == 0 && first_row + row < args.ne01) {
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q8_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+ kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SG_Q8_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
// mat-vec kernel processing in chunks of float4
sumf += (T0) x[i] * (T1) y[i];
}
- float all_sum = simd_sum(sumf);
+ float sum_all = simd_sum(sumf);
if (tiisg == 0) {
- dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
+ dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
}
}
} else {
sumf += dot((float4) x4[i], (float4) y4[i]);
}
- float all_sum = simd_sum(sumf);
+ float sum_all = simd_sum(sumf);
if (tiisg == 0) {
- for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]);
- dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
+ for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]);
+ dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
}
}
}
for (int i = tiisg; i < args.ne00; i += 32) {
sumf += (float) x[i] * (float) y[i];
}
- float all_sum = simd_sum(sumf);
+ float sum_all = simd_sum(sumf);
if (tiisg == 0) {
- dst_f32[r0] = all_sum;
+ dst_f32[r0] = sum_all;
}
} else {
device const T4 * x4 = (device const T4 *) x;
sumf += dot((float4) x4[i], y4[i]);
}
- float all_sum = simd_sum(sumf);
+ float sum_all = simd_sum(sumf);
if (tiisg == 0) {
- for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]);
- dst_f32[r0] = all_sum;
+ for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]);
+ dst_f32[r0] = sum_all;
}
}
}
sumf += dot((float4) x4[i], y4[i]);
}
- float all_sum = simd_sum(sumf);
+ float sum_all = simd_sum(sumf);
if (tiisg == 0) {
- dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
+ dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
}
}
}
float amax = 0.0f; // absolute max
float max = 0.0f;
- for (int j = 0; j < QK4_0; j++) {
+ for (int j = 0; j < QK4_NL; j++) {
const float v = src[j];
if (amax < fabs(v)) {
amax = fabs(v);
max = v;
}
}
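+// the signed max is kept alongside amax: the sign of the value with the largest
+// magnitude determines the sign of the block scale (per the usual iq4_nl quantizer)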
-template<typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
void kernel_mul_mv_q2_K_f32_impl(
args_t args,
device const char * src0,
const int r1 = tgpig.y;
const int im = tgpig.z;
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int first_row = (r0 * nsg + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
device const float * y = (device const float *) (src1 + offset1);
float yl[32];
- float sumf[N_DST]={0.f}, all_sum;
+ float sumf[nr0]={0.f};
- const int ix = tiisg/8; // 0...3
- const int it = tiisg%8; // 0...7
- const int iq = it/4; // 0 or 1
- const int ir = it%4; // 0...3
- const int is = (8*ir)/16;// 0 or 1
+ const short ix = tiisg/8; // 0...3
+ const short it = tiisg%8; // 0...7
+ const short iq = it/4; // 0 or 1
+ const short ir = it%4; // 0...3
+ const short is = (8*ir)/16;// 0 or 1
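+ // 16-bit (short) lane indices instead of int: the values are tiny, and narrower
+ // arithmetic tends to reduce register pressure on Apple GPUs (assumed rationale)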
device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
for (int ib = ix; ib < nb; ib += 4) {
-
float4 sumy = {0.f, 0.f, 0.f, 0.f};
- for (int i = 0; i < 8; ++i) {
+ for (short i = 0; i < 8; ++i) {
yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
device const half * dh = &x[ib].d;
- for (int row = 0; row < N_DST; row++) {
+ for (short row = 0; row < nr0; row++) {
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
for (int i = 0; i < 8; i += 2) {
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
- for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
- all_sum = simd_sum(sumf[row]);
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+ float sum_all = simd_sum(sumf[row]);
if (tiisg == 0) {
- dst_f32[first_row + row] = all_sum;
+ dst_f32[first_row + row] = sum_all;
}
}
}
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q2_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+ kernel_mul_mv_q2_K_f32_impl<N_R0_Q2_K, N_SG_Q2_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
-template<typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
void kernel_mul_mv_q3_K_f32_impl(
args_t args,
device const char * src0,
const int r1 = tgpig.y;
const int im = tgpig.z;
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
+ const int first_row = (r0 * nsg + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
//const uint16_t kmask1 = 0x3030;
//const uint16_t kmask2 = 0x0f0f;
- const int tid = tiisg/4;
- const int ix = tiisg%4;
- const int ip = tid/4; // 0 or 1
- const int il = 2*((tid%4)/2); // 0 or 2
- const int ir = tid%2;
- const int n = 8;
- const int l0 = n*ir;
+ const short tid = tiisg/4;
+ const short ix = tiisg%4;
+ const short ip = tid/4; // 0 or 1
+ const short il = 2*((tid%4)/2); // 0 or 2
+ const short ir = tid%2;
+ const short l0 = 8*ir;
// One would think that the Metal compiler would figure out that ip and il can only have
// 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
// below using uint16_t's.
const uint16_t s_shift1 = 4*ip;
const uint16_t s_shift2 = s_shift1 + il;
- const int q_offset = 32*ip + l0;
- const int y_offset = 128*ip + 32*il + l0;
+ const short q_offset = 32*ip + l0;
+ const short y_offset = 128*ip + 32*il + l0;
device const float * y1 = yy + ix*QK_K + y_offset;
thread uint16_t * scales16 = (thread uint16_t *)&scales32;
thread const int8_t * scales = (thread const int8_t *)&scales32;
- float sumf1[2] = {0.f};
- float sumf2[2] = {0.f};
+ float sumf1[nr0] = {0.f};
+ float sumf2[nr0] = {0.f};
+
for (int i = ix; i < nb; i += 4) {
- for (int l = 0; l < 8; ++l) {
+ for (short l = 0; l < 8; ++l) {
yl[l+ 0] = y1[l+ 0];
yl[l+ 8] = y1[l+16];
yl[l+16] = y1[l+32];
device const uint16_t * a = (device const uint16_t *)(x[i].scales);
device const half * dh = &x[i].d;
- for (int row = 0; row < 2; ++row) {
+ for (short row = 0; row < nr0; ++row) {
const float d_all = (float)dh[0];
scales16[0] = a[4];
scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
- for (int l = 0; l < n; l += 2) {
+ for (short l = 0; l < 8; l += 2) {
const int32_t qs = q[l/2];
s1 += yl[l+0] * (qs & qm[il/2][0]);
s2 += yl[l+1] * (qs & qm[il/2][1]);
sumf2[row] += d2 * (scales[2] - 32);
s1 = s2 = s3 = s4 = s5 = s6 = 0;
- for (int l = 0; l < n; l += 2) {
+ for (short l = 0; l < 8; l += 2) {
const int32_t qs = q[l/2+8];
s1 += yl[l+8] * (qs & qm[il/2][0]);
s2 += yl[l+9] * (qs & qm[il/2][1]);
y1 += 4 * QK_K;
}
- for (int row = 0; row < 2; ++row) {
+ for (int row = 0; row < nr0; ++row) {
const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
sumf1[row] = simd_sum(sumf);
}
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
if (tiisg == 0) {
- for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
dst_f32[first_row + row] = sumf1[row];
}
}
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q3_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+ kernel_mul_mv_q3_K_f32_impl<N_R0_Q3_K, N_SG_Q3_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
-template<typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
void kernel_mul_mv_q4_K_f32_impl(
args_t args,
device const char * src0,
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
-
const uint16_t kmask1 = 0x3f3f;
const uint16_t kmask2 = 0x0f0f;
const uint16_t kmask3 = 0xc0c0;
- const int ix = tiisg/8; // 0...3
- const int it = tiisg%8; // 0...7
- const int iq = it/4; // 0 or 1
- const int ir = it%4; // 0...3
+ const short ix = tiisg/8; // 0...3
+ const short it = tiisg%8; // 0...7
+ const short iq = it/4; // 0 or 1
+ const short ir = it%4; // 0...3
const int nb = args.ne00/QK_K;
+
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
- //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
- const int first_row = r0 * N_DST;
+
+ const int first_row = (r0 * nsg + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
float yl[16];
float yh[16];
- float sumf[N_DST]={0.f}, all_sum;
+
+ float sumf[nr0]={0.f};
device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
for (int ib = ix; ib < nb; ib += 4) {
float4 sumy = {0.f, 0.f, 0.f, 0.f};
- for (int i = 0; i < 8; ++i) {
+
+ for (short i = 0; i < 8; ++i) {
yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0];
yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
device const half * dh = &x[ib].d;
- for (int row = 0; row < N_DST; row++) {
+ for (short row = 0; row < nr0; row++) {
sc16[0] = sc[0] & kmask1;
sc16[1] = sc[2] & kmask1;
sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
- for (int i = 0; i < 8; i += 2) {
- acc1[0] += yl[i+0] * (q1[i/2] & 0x000F);
- acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00);
- acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0);
- acc1[3] += yl[i+9] * (q1[i/2] & 0xF000);
- acc2[0] += yh[i+0] * (q2[i/2] & 0x000F);
- acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00);
- acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0);
- acc2[3] += yh[i+9] * (q2[i/2] & 0xF000);
+
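+ // each uint16_t in q1/q2 packs four 4-bit quants; the masks select the low/high
+ // nibble of each byte, and the 1/256 and 1/16 factors applied to the partial
+ // sums below undo the implicit byte and nibble shifts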
+ for (short i = 0; i < 4; ++i) {
+ acc1[0] += yl[2*i + 0] * (q1[i] & 0x000F);
+ acc1[1] += yl[2*i + 1] * (q1[i] & 0x0F00);
+ acc1[2] += yl[2*i + 8] * (q1[i] & 0x00F0);
+ acc1[3] += yl[2*i + 9] * (q1[i] & 0xF000);
+ acc2[0] += yh[2*i + 0] * (q2[i] & 0x000F);
+ acc2[1] += yh[2*i + 1] * (q2[i] & 0x0F00);
+ acc2[2] += yh[2*i + 8] * (q2[i] & 0x00F0);
+ acc2[3] += yh[2*i + 9] * (q2[i] & 0xF000);
}
float dall = dh[0];
float dmin = dh[1];
+
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
(acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
(acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0;
- for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
- all_sum = simd_sum(sumf[row]);
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+ float sum_all = simd_sum(sumf[row]);
if (tiisg == 0) {
- dst_f32[first_row + row] = all_sum;
+ dst_f32[first_row + row] = sum_all;
}
}
}
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q4_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+ kernel_mul_mv_q4_K_f32_impl<N_R0_Q4_K, N_SG_Q4_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
-template<typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
void kernel_mul_mv_q5_K_f32_impl(
args_t args,
device const char * src0,
const int r1 = tgpig.y;
const int im = tgpig.z;
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
+ const int first_row = (r0 * nsg + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0);
device const float * yy = (device const float *) (src1 + offset1);
- float sumf[2]={0.f};
+ float sumf[nr0]={0.f};
float yl[16], yh[16];
const uint16_t kmask2 = 0x0f0f;
const uint16_t kmask3 = 0xc0c0;
- const int tid = tiisg/4;
- const int ix = tiisg%4;
- const int iq = tid/4;
- const int ir = tid%4;
- const int n = 8;
+ const short tid = tiisg/4;
+ const short ix = tiisg%4;
+ const short iq = tid/4;
+ const short ir = tid%4;
- const int l0 = n*ir;
- const int q_offset = 32*iq + l0;
- const int y_offset = 64*iq + l0;
+ const short l0 = 8*ir;
+ const short q_offset = 32*iq + l0;
+ const short y_offset = 64*iq + l0;
const uint8_t hm1 = 1u << (2*iq);
const uint8_t hm2 = hm1 << 1;
device const float * y2 = y1 + 128;
float4 sumy = {0.f, 0.f, 0.f, 0.f};
- for (int l = 0; l < 8; ++l) {
+ for (short l = 0; l < 8; ++l) {
yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
}
- for (int row = 0; row < 2; ++row) {
+ for (short row = 0; row < nr0; ++row) {
device const uint8_t * q2 = q1 + 64;
sc16[0] = a[0] & kmask1;
float4 acc1 = {0.f};
float4 acc2 = {0.f};
- for (int l = 0; l < n; ++l) {
+ for (short l = 0; l < 8; ++l) {
uint8_t h = qh[l];
acc1[0] += yl[l+0] * (q1[l] & 0x0F);
acc1[1] += yl[l+8] * (q1[l] & 0xF0);
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
- for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
const float tot = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = tot;
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q5_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+ kernel_mul_mv_q5_K_f32_impl<N_R0_Q5_K, N_SG_Q5_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
-template <typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
void kernel_mul_mv_q6_K_f32_impl(
args_t args,
device const char * src0,
const int r1 = tgpig.y;
const int im = tgpig.z;
- const int row = 2*r0 + sgitg;
-
- if (row >= args.ne0) {
- return;
- }
+ const int first_row = (r0 * nsg + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
- const uint64_t offset0 = row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
- const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
+ const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0);
device const float * yy = (device const float *) (src1 + offset1);
- float sumf = 0;
+ float sumf[nr0] = { 0.f };
+
+ float yl[16];
- const int tid = tiisg/2;
- const int ix = tiisg%2;
- const int ip = tid/8; // 0 or 1
- const int il = tid%8;
- const int n = 4;
- const int l0 = n*il;
- const int is = 8*ip + l0/16;
+ const short tid = tiisg/2;
+ const short ix = tiisg%2;
+ const short ip = tid/8; // 0 or 1
+ const short il = tid%8;
+ const short l0 = 4*il;
+ const short is = 8*ip + l0/16;
- const int y_offset = 128*ip + l0;
- const int q_offset_l = 64*ip + l0;
- const int q_offset_h = 32*ip + l0;
+ const short y_offset = 128*ip + l0;
+ const short q_offset_l = 64*ip + l0;
+ const short q_offset_h = 32*ip + l0;
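+ // each thread covers 4 consecutive values at offset l0 = 4*il within the half of
+ // the 256-wide superblock selected by ip; the ql/qh/scale offsets follow the
+ // q6_K block layout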
for (int i = ix; i < nb; i += 2) {
device const uint8_t * q1 = x[i].ql + q_offset_l;
device const uint8_t * q2 = q1 + 32;
device const uint8_t * qh = x[i].qh + q_offset_h;
device const int8_t * sc = x[i].scales + is;
+ device const half * dh = &x[i].d;
device const float * y = yy + i * QK_K + y_offset;
- const float dall = x[i].d;
-
- float4 sums = {0.f, 0.f, 0.f, 0.f};
- for (int l = 0; l < n; ++l) {
- sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
- sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
- sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
- sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
+ for (short l = 0; l < 4; ++l) {
+ yl[4*l + 0] = y[l + 0];
+ yl[4*l + 1] = y[l + 32];
+ yl[4*l + 2] = y[l + 64];
+ yl[4*l + 3] = y[l + 96];
}
- sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
+ for (short row = 0; row < nr0; ++row) {
+ const float dall = dh[0];
+
+ float4 sums = {0.f, 0.f, 0.f, 0.f};
+ for (short l = 0; l < 4; ++l) {
+ sums[0] += yl[4*l + 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
+ sums[1] += yl[4*l + 1] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
+ sums[2] += yl[4*l + 2] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
+ sums[3] += yl[4*l + 3] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
+ }
+
+ sumf[row] += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
+
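+ // step to the same position in the next src0 row: nb01 is the row stride in
+ // bytes, so the byte-sized pointers advance by nb01 and the half pointer by nb01/2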
+ q1 += args.nb01;
+ q2 += args.nb01;
+ qh += args.nb01;
+ sc += args.nb01;
+ dh += args.nb01/2;
+ }
}
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
- const float tot = simd_sum(sumf);
- if (tiisg == 0) {
- dst_f32[row] = tot;
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+ float sum_all = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst_f32[first_row + row] = sum_all;
+ }
}
}
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q6_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+ kernel_mul_mv_q6_K_f32_impl<N_R0_Q6_K, N_SG_Q6_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
// ======================= "True" 2-bit
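+// iq2_xxs: groups of 8 weights share a codebook entry (aux8[l] indexes the grid)
+// plus 7 packed sign bits; the block scale is d * (0.5 + (aux32 >> 28)), and the
+// constant 0.25 factor is applied once at the write-out instead of per element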
-template<typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
void kernel_mul_mv_iq2_xxs_f32_impl(
args_t args,
device const char * src0,
const int r1 = tgpig.y;
const int im = tgpig.z;
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int first_row = (r0 * nsg + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
device const float * y = (device const float *) (src1 + offset1);
float yl[32];
- float sumf[N_DST]={0.f}, all_sum;
+ float sumf[nr0]={0.f};
const int nb32 = nb * (QK_K / 32);
device const float * y4 = y + 32 * ix;
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
- for (int i = 0; i < 32; ++i) {
+ for (short i = 0; i < 32; ++i) {
yl[i] = y4[i];
}
device const uint16_t * q2 = xr->qs + 4 * ib;
device const half * dh = &xr->d;
- for (int row = 0; row < N_DST; row++) {
-
+ for (short row = 0; row < nr0; row++) {
const float db = dh[0];
device const uint8_t * aux8 = (device const uint8_t *)q2;
const uint32_t aux32 = q2[2] | (q2[3] << 16);
const float d = db * (0.5f + (aux32 >> 28));
float sum = 0;
- for (int l = 0; l < 4; ++l) {
+ for (short l = 0; l < 4; ++l) {
const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + aux8[l]);
const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
- for (int j = 0; j < 8; ++j) {
+ for (short j = 0; j < 8; ++j) {
sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
}
}
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
- for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
- all_sum = simd_sum(sumf[row]);
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+ float sum_all = simd_sum(sumf[row]);
if (tiisg == 0) {
- dst_f32[first_row + row] = all_sum * 0.25f;
+ dst_f32[first_row + row] = sum_all * 0.25f;
}
}
}
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq2_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+ kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, N_SG_IQ2_XXS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
-template<typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
void kernel_mul_mv_iq2_xs_f32_impl(
args_t args,
device const char * src0,
const int r1 = tgpig.y;
const int im = tgpig.z;
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int first_row = (r0 * nsg + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
device const float * y = (device const float *) (src1 + offset1);
float yl[32];
- float sumf[N_DST]={0.f}, all_sum;
+ float sumf[nr0]={0.f};
const int nb32 = nb * (QK_K / 32);
device const float * y4 = y + 32 * ix;
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
- for (int i = 0; i < 32; ++i) {
+ for (short i = 0; i < 32; ++i) {
yl[i] = y4[i];
}
device const uint8_t * sc = xr->scales + ib;
device const half * dh = &xr->d;
- for (int row = 0; row < N_DST; row++) {
-
+ for (short row = 0; row < nr0; row++) {
const float db = dh[0];
const uint8_t ls1 = sc[0] & 0xf;
const uint8_t ls2 = sc[0] >> 4;
const float d2 = db * (0.5f + ls2);
float sum1 = 0, sum2 = 0;
- for (int l = 0; l < 2; ++l) {
+ for (short l = 0; l < 2; ++l) {
const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
const uint8_t signs = ssigns[(q2[l] >> 9)];
- for (int j = 0; j < 8; ++j) {
+ for (short j = 0; j < 8; ++j) {
sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
}
}
- for (int l = 2; l < 4; ++l) {
+ for (short l = 2; l < 4; ++l) {
const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
const uint8_t signs = ssigns[(q2[l] >> 9)];
- for (int j = 0; j < 8; ++j) {
+ for (short j = 0; j < 8; ++j) {
sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
}
}
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
- for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
- all_sum = simd_sum(sumf[row]);
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+ float sum_all = simd_sum(sumf[row]);
if (tiisg == 0) {
- dst_f32[first_row + row] = all_sum * 0.25f;
+ dst_f32[first_row + row] = sum_all * 0.25f;
}
}
}
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq2_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+ kernel_mul_mv_iq2_xs_f32_impl<N_R0_IQ2_XS, N_SG_IQ2_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
-template <typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
void kernel_mul_mv_iq3_xxs_f32_impl(
args_t args,
device const char * src0,
const int r1 = tgpig.y;
const int im = tgpig.z;
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int first_row = (r0 * nsg + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
device const float * y = (device const float *) (src1 + offset1);
float yl[32];
- float sumf[N_DST]={0.f}, all_sum;
+ float sumf[nr0]={0.f};
const int nb32 = nb * (QK_K / 32);
device const float * y4 = y + 32 * ix;
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
- for (int i = 0; i < 32; ++i) {
+ for (short i = 0; i < 32; ++i) {
yl[i] = y4[i];
}
device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib;
device const half * dh = &xr->d;
- for (int row = 0; row < N_DST; row++) {
+ for (short row = 0; row < nr0; row++) {
const float db = dh[0];
const uint32_t aux32 = gas[0] | (gas[1] << 16);
const float d = db * (0.5f + (aux32 >> 28));
float2 sum = {0};
- for (int l = 0; l < 4; ++l) {
+ for (short l = 0; l < 4; ++l) {
const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + q3[2*l+0]);
const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + q3[2*l+1]);
const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
- for (int j = 0; j < 4; ++j) {
+ for (short j = 0; j < 4; ++j) {
sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
}
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
- for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
- all_sum = simd_sum(sumf[row]);
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+ float sum_all = simd_sum(sumf[row]);
if (tiisg == 0) {
- dst_f32[first_row + row] = all_sum * 0.5f;
+ dst_f32[first_row + row] = sum_all * 0.5f;
}
}
}
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq3_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+ kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, N_SG_IQ3_XXS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
-template<typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
void kernel_mul_mv_iq3_s_f32_impl(
args_t args,
device const char * src0,
const int r1 = tgpig.y;
const int im = tgpig.z;
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int first_row = (r0 * nsg + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
device const float * y = (device const float *) (src1 + offset1);
float yl[32];
- float sumf[N_DST]={0.f}, all_sum;
+ float sumf[nr0]={0.f};
const int nb32 = nb * (QK_K / 32);
device const float * y4 = y + 32 * ix;
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
- for (int i = 0; i < 32; ++i) {
+ for (short i = 0; i < 32; ++i) {
yl[i] = y4[i];
}
device const uint8_t * signs = xr->signs + 4 * ib;
device const half * dh = &xr->d;
- for (int row = 0; row < N_DST; row++) {
-
+ for (short row = 0; row < nr0; row++) {
const float db = dh[0];
const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf));
float2 sum = {0};
- for (int l = 0; l < 4; ++l) {
+ for (short l = 0; l < 4; ++l) {
const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? svalues + 256 : svalues;
const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? svalues + 256 : svalues;
const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]);
const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]);
- for (int j = 0; j < 4; ++j) {
+ for (short j = 0; j < 4; ++j) {
sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
}
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
- for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
- all_sum = simd_sum(sumf[row]);
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+ float sum_all = simd_sum(sumf[row]);
if (tiisg == 0) {
- dst_f32[first_row + row] = all_sum;
+ dst_f32[first_row + row] = sum_all;
}
}
}
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq3_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+ kernel_mul_mv_iq3_s_f32_impl<N_R0_IQ3_S, N_SG_IQ3_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
-template <typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
void kernel_mul_mv_iq2_s_f32_impl(
args_t args,
device const char * src0,
const int r1 = tgpig.y;
const int im = tgpig.z;
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int first_row = (r0 * nsg + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
device const float * y = (device const float *) (src1 + offset1);
float yl[32];
- float sumf[N_DST]={0.f}, all_sum;
+ float sumf[nr0]={0.f};
const int nb32 = nb * (QK_K / 32);
// threadgroup_barrier(mem_flags::mem_threadgroup);
//}
- const int ix = tiisg;
+ const short ix = tiisg;
device const float * y4 = y + 32 * ix;
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
- for (int i = 0; i < 32; ++i) {
+ for (short i = 0; i < 32; ++i) {
yl[i] = y4[i];
}
device const uint8_t * signs = qs + QK_K/8;
device const half * dh = &xr->d;
- for (int row = 0; row < N_DST; row++) {
-
+ for (short row = 0; row < nr0; row++) {
const float db = dh[0];
const float d1 = db * (0.5f + (sc[0] & 0xf));
const float d2 = db * (0.5f + (sc[0] >> 4));
float2 sum = {0};
- for (int l = 0; l < 2; ++l) {
+ for (short l = 0; l < 2; ++l) {
//const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
//const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
- for (int j = 0; j < 8; ++j) {
+ for (short j = 0; j < 8; ++j) {
sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]);
sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]);
}
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
- for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
- all_sum = simd_sum(sumf[row]);
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+ float sum_all = simd_sum(sumf[row]);
if (tiisg == 0) {
- dst_f32[first_row + row] = all_sum * 0.25f;
+ dst_f32[first_row + row] = sum_all * 0.25f;
}
}
}
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq2_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+ kernel_mul_mv_iq2_s_f32_impl<N_R0_IQ2_S, N_SG_IQ2_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
-template<typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
void kernel_mul_mv_iq1_s_f32_impl(
args_t args,
device const char * src0,
const int r1 = tgpig.y;
const int im = tgpig.z;
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int first_row = (r0 * nsg + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
device const float * y = (device const float *) (src1 + offset1);
float yl[32];
- float sumf[N_DST]={0.f}, all_sum;
+ float sumf[nr0]={0.f};
const int nb32 = nb * (QK_K / 32);
- const int ix = tiisg;
+ const short ix = tiisg;
device const float * y4 = y + 32 * ix;
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
float sumy = 0;
- for (int i = 0; i < 32; ++i) {
+ for (short i = 0; i < 32; ++i) {
yl[i] = y4[i];
sumy += yl[i];
}
device const uint16_t * qh = xr->qh + ib;
device const half * dh = &xr->d;
- for (int row = 0; row < N_DST; row++) {
-
+ for (short row = 0; row < nr0; row++) {
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700)));
constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700)));
constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700)));
float sum = 0;
- for (int j = 0; j < 4; ++j) {
+ for (short j = 0; j < 4; ++j) {
sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
+ yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4)
+ yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
- for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
- all_sum = simd_sum(sumf[row]);
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+ float sum_all = simd_sum(sumf[row]);
if (tiisg == 0) {
- dst_f32[first_row + row] = all_sum;
+ dst_f32[first_row + row] = sum_all;
}
}
}
-template <typename args_t>
+[[host_name("kernel_mul_mv_iq1_s_f32")]]
+kernel void kernel_mul_mv_iq1_s_f32(
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_iq1_s_f32_impl<N_R0_IQ1_S, N_SG_IQ1_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+template<int nr0, int nsg, int nw, typename args_t>
void kernel_mul_mv_iq1_m_f32_impl(
args_t args,
device const char * src0,
ushort sgitg) {
const int nb = args.ne00/QK_K;
+
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int first_row = (r0 * nsg + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
device const float * y = (device const float *) (src1 + offset1);
float yl[32];
- float sumf[N_DST]={0.f}, all_sum;
+ float sumf[nr0]={0.f};
const int nb32 = nb * (QK_K / 32);
- const int ix = tiisg;
+ const short ix = tiisg;
device const float * y4 = y + 32 * ix;
iq1m_scale_t scale;
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
float4 sumy = {0.f};
- for (int i = 0; i < 8; ++i) {
+ for (short i = 0; i < 8; ++i) {
yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8];
yl[i+16] = y4[i+16]; sumy[2] += yl[i+16];
device const uint8_t * qh = xr->qh + 2 * ib;
device const uint16_t * sc = (device const uint16_t *)xr->scales;
- for (int row = 0; row < N_DST; row++) {
+ for (short row = 0; row < nr0; row++) {
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700)));
float2 sum = {0.f};
- for (int j = 0; j < 4; ++j) {
+ for (short j = 0; j < 4; ++j) {
sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
+ yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4);
sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
- for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
- all_sum = simd_sum(sumf[row]);
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+ float sum_all = simd_sum(sumf[row]);
if (tiisg == 0) {
- dst_f32[first_row + row] = all_sum;
+ dst_f32[first_row + row] = sum_all;
}
}
}
-template<typename args_t>
+[[host_name("kernel_mul_mv_iq1_m_f32")]]
+kernel void kernel_mul_mv_iq1_m_f32(
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_iq1_m_f32_impl<N_R0_IQ1_M, N_SG_IQ1_M, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+template<int nr0, int nsg, int nw, typename args_t>
void kernel_mul_mv_iq4_nl_f32_impl(
args_t args,
device const char * src0,
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
const int nb = args.ne00/QK4_NL;
+
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
- const int first_row = (r0 * 2 + sgitg) * 2;
+
+ const int first_row = (r0 * nsg + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
- const int ix = tiisg/2; // 0...15
- const int it = tiisg%2; // 0 or 1
+ const short ix = tiisg/2; // 0...15
+ const short it = tiisg%2; // 0 or 1
shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
threadgroup_barrier(mem_flags::mem_threadgroup);
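+ // the 16 iq4_nl codebook values are staged once into threadgroup memory; the
+ // barrier guarantees every thread sees them before the dot-product loop reads
+ // shmem_f32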
float4 yl[4];
- float sumf[2]={0.f}, all_sum;
+ float sumf[nr0]={0.f};
device const float * yb = y + ix * QK4_NL + it * 8;
float4 qf1, qf2;
for (int ib = ix; ib < nb; ib += 16) {
-
device const float4 * y4 = (device const float4 *)yb;
- yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
-
- for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) {
+ yl[0] = y4[0];
+ yl[1] = y4[4];
+ yl[2] = y4[1];
+ yl[3] = y4[5];
+ for (short row = 0; row < nr0; row++) {
device const block_iq4_nl & xb = x[row*nb + ib];
device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
acc1 += acc2;
sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
-
}
yb += 16 * QK4_NL;
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
- for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
- all_sum = simd_sum(sumf[row]);
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+ float sum_all = simd_sum(sumf[row]);
if (tiisg == 0) {
- dst_f32[first_row + row] = all_sum;
+ dst_f32[first_row + row] = sum_all;
}
}
}
-template<typename args_t>
+[[host_name("kernel_mul_mv_iq4_nl_f32")]]
+kernel void kernel_mul_mv_iq4_nl_f32(
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_iq4_nl_f32_impl<N_R0_IQ4_NL, N_SG_IQ4_NL, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
+template<int nr0, int nsg, int nw, typename args_t>
void kernel_mul_mv_iq4_xs_f32_impl(
args_t args,
device const char * src0,
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
- const int first_row = (r0 * 2 + sgitg) * 2;
+ const int first_row = (r0 * nsg + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
- const int ix = tiisg/16; // 0 or 1
- const int it = tiisg%16; // 0...15
- const int ib = it/2;
- const int il = it%2;
+ const short ix = tiisg/16; // 0 or 1
+ const short it = tiisg%16; // 0...15
+ const short ib = it/2;
+ const short il = it%2;
shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
threadgroup_barrier(mem_flags::mem_threadgroup);
float4 yl[4];
- float sumf[2]={0.f}, all_sum;
+ float sumf[nr0]={0.f};
device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
for (int ibl = ix; ibl < nb; ibl += 2) {
device const float4 * y4 = (device const float4 *)yb;
- yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
+ yl[0] = y4[0];
+ yl[1] = y4[4];
+ yl[2] = y4[1];
+ yl[3] = y4[5];
- for (int row = 0; row < 2; ++row) {
+ for (short row = 0; row < nr0; ++row) {
device const block_iq4_xs & xb = x[row*nb + ibl];
device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32;
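+ // reassemble the 6-bit block scale: low 4 bits from scales_l (two blocks per
+ // byte), high 2 bits from scales_h, biased by -32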
sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
-
}
yb += 2 * QK_K;
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
- for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
- all_sum = simd_sum(sumf[row]);
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+ float sum_all = simd_sum(sumf[row]);
if (tiisg == 0) {
- dst_f32[first_row + row] = all_sum;
+ dst_f32[first_row + row] = sum_all;
}
}
}
-[[host_name("kernel_mul_mv_iq1_s_f32")]]
-kernel void kernel_mul_mv_iq1_s_f32(
- constant ggml_metal_kargs_mul_mv & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-
- kernel_mul_mv_iq1_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
-}
-
-[[host_name("kernel_mul_mv_iq1_m_f32")]]
-kernel void kernel_mul_mv_iq1_m_f32(
- constant ggml_metal_kargs_mul_mv & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-
- kernel_mul_mv_iq1_m_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
-}
-
-[[host_name("kernel_mul_mv_iq4_nl_f32")]]
-kernel void kernel_mul_mv_iq4_nl_f32(
- constant ggml_metal_kargs_mul_mv & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-
- kernel_mul_mv_iq4_nl_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
-}
-
[[host_name("kernel_mul_mv_iq4_xs_f32")]]
kernel void kernel_mul_mv_iq4_xs_f32(
constant ggml_metal_kargs_mul_mv & args,
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq4_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+ kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, N_SG_IQ4_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
#if defined(GGML_METAL_USE_BF16)
template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<bfloat, bfloat4, float, float4>>>;
#endif
-template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SG_Q8_0, N_SIMDWIDTH>>>;
+
+template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SG_Q4_0, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SG_Q4_1, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SG_Q5_0, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SG_Q5_1, N_SIMDWIDTH>>>;
+
+template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl <N_R0_Q2_K, N_SG_Q2_K, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl <N_R0_Q3_K, N_SG_Q3_K, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl <N_R0_Q4_K, N_SG_Q4_K, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl <N_R0_Q5_K, N_SG_Q5_K, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl <N_R0_Q6_K, N_SG_Q6_K, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl <N_R0_IQ1_S, N_SG_IQ1_S, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl <N_R0_IQ1_M, N_SG_IQ1_M, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, N_SG_IQ2_XXS, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl <N_R0_IQ2_XS, N_SG_IQ2_XS, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, N_SG_IQ3_XXS, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl <N_R0_IQ3_S, N_SG_IQ3_S, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl <N_R0_IQ2_S, N_SG_IQ2_S, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl <N_R0_IQ4_NL, N_SG_IQ4_NL, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl <N_R0_IQ4_XS, N_SG_IQ4_XS, N_SIMDWIDTH>>>;
kernel void kernel_pool_2d_max_f32(
device const float * src0,