const int64_t i3 = tgpig.z;
const int64_t nc = ne10;
- const int64_t ncs = ne00;
- const int64_t nr = ne01;
- const int64_t n_t = ne1;
- const int64_t n_s = ne2;
+ //const int64_t ncs = ne00;
+ //const int64_t nr = ne01;
+ //const int64_t n_t = ne1;
+ //const int64_t n_s = ne2;
device const float * s = (device const float *) ((device const char *) src0 + ir*nb01 + i2*nb00 + i3*nb02);
device const float * c = (device const float *) ((device const char *) src1 + ir*nb11);
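+ // note: the nb0*/nb1* strides are in bytes, hence the casts through (device const char *)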
const int64_t i3 = tgpig.y;
const int64_t nc = d_state;
- const int64_t nr = d_inner;
+ //const int64_t nr = d_inner;
const int64_t n_t = n_seq_tokens;
- const int64_t n_s = n_seqs;
+ //const int64_t n_s = n_seqs;
for (int64_t i2 = 0; i2 < n_t; ++i2) {
device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02);
inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
float d = qb_curr->d;
- float2 acc = 0.f;
+ float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
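+ // four independent accumulators, one per nibble position of the 16-bit loads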
- device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
+ device const uint16_t * qs = ((device const uint16_t *) qb_curr + 1 + il/2);
- for (int i = 0; i < 8; i+=2) {
- acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
- + yl[i + 1] * (qs[i / 2] & 0x0F00);
- acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
- + yl[i + 9] * (qs[i / 2] & 0xF000);
+ for (int i = 0; i < 8; i += 2) {
+ acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F);
+ acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00);
+ acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0);
+ acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000);
}
- return d * (sumy * -8.f + acc[0] + acc[1]);
+
+ return d * (sumy * -8.f + acc[0] + acc[1] + acc[2] + acc[3]);
}
// function to calculate the inner product between half a q4_1 block and 16 floats (yl); sumy is SUM(yl[i])
float d = qb_curr->d;
float m = qb_curr->m;
- float2 acc = 0.f;
+ float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
- device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+ device const uint16_t * qs = ((device const uint16_t *) qb_curr + 2 + il/2);
for (int i = 0; i < 8; i+=2) {
- acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
- + yl[i + 1] * (qs[i / 2] & 0x0F00);
- acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
- + yl[i + 9] * (qs[i / 2] & 0xF000);
+ acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F);
+ acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00);
+ acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0);
+ acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000);
}
- return d * (acc[0] + acc[1]) + sumy * m;
+
+ return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
}
// function to calculate the inner product between half a q5_0 block and 16 floats (yl); sumy is SUM(yl[i])
inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
float d = qb_curr->d;
- float2 acc = 0.f;
+ float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
for (int i = 0; i < 8; i+=2) {
- acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
- + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
- acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
- + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+ acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010));
+ acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
+ acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100));
+ acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
}
- return d * (sumy * -16.f + acc[0] + acc[1]);
+
+ return d * (sumy * -16.f + acc[0] + acc[1] + acc[2] + acc[3]);
}
// function to calculate the inner product between half a q5_1 block and 16 floats (yl); sumy is SUM(yl[i])
float d = qb_curr->d;
float m = qb_curr->m;
- float2 acc = 0.f;
+ float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
for (int i = 0; i < 8; i+=2) {
- acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
- + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
- acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
- + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+ acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010));
+ acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
+ acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100));
+ acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
}
- return d * (acc[0] + acc[1]) + sumy * m;
+
+ return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
}
// putting them in the kernel causes a significant performance penalty
int64_t ne00,
int64_t ne01,
int64_t ne02,
+ uint64_t nb01,
+ uint64_t nb02,
+ uint64_t nb03,
int64_t ne10,
int64_t ne12,
+ uint64_t nb11,
+ uint64_t nb12,
+ uint64_t nb13,
int64_t ne0,
int64_t ne1,
uint r2,
uint r3,
threadgroup int8_t * shared_values,
- uint3 tgpig, uint tiisg, uint sgitg) {
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
const int nb = ne00/QK4_0;
const int r0 = tgpig.x;
const uint i12 = im%ne12;
const uint i13 = im/ne12;
- const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+ //const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+ const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
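+ // note: offset0 is now computed per row in the ax[] setup below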
- device const block_q_type * x = (device const block_q_type *) src0 + offset0;
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+ //device const block_q_type * x = (device const block_q_type *) ((device char *) src0 + offset0);
+ device const float * y = (device const float *) ((device char *) src1 + offset1);
+
+ // pointers to src0 rows
+ device const block_q_type * ax[nr];
+ for (int row = 0; row < nr; ++row) {
+ const uint offset0 = (first_row + row)*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+
+ ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
+ }
float yl[16]; // src1 vector cache
float sumf[nr] = {0.f};
// each thread in a SIMD group deals with half a block.
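+ // (nw threads per SIMD group therefore cover nw/2 blocks per ib iteration)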
for (int ib = ix; ib < nb; ib += nw/2) {
- float sumy = 0;
+ float sumy[2] = { 0.f, 0.f };
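+ // partial sums for the two 8-value halves handled by this thread; summed at the call site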
+
+#pragma unroll
for (int i = 0; i < 8; i += 2) {
- sumy += yb[i] + yb[i+1];
- yl[i+0] = yb[i+ 0];
- yl[i+1] = yb[i+ 1]/256.f;
+ sumy[0] += yb[i + 0] + yb[i + 1];
+ yl[i + 0] = yb[i + 0];
+ yl[i + 1] = yb[i + 1]/256.f;
- sumy += yb[i+16] + yb[i+17];
- yl[i+8] = yb[i+16]/16.f;
- yl[i+9] = yb[i+17]/4096.f;
+ sumy[1] += yb[i + 16] + yb[i + 17];
+ yl[i + 8] = yb[i + 16]/16.f;
+ yl[i + 9] = yb[i + 17]/4096.f;
}
+#pragma unroll
for (int row = 0; row < nr; row++) {
- sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il);
+ sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il);
}
yb += QK4_0 * 16;
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
+ constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
+ constant uint64_t & nb13,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
}
kernel void kernel_mul_mv_q4_1_f32(
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
+ constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
+ constant uint64_t & nb13,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
}
kernel void kernel_mul_mv_q5_0_f32(
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
+ constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
+ constant uint64_t & nb13,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
}
kernel void kernel_mul_mv_q5_1_f32(
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
+ constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
+ constant uint64_t & nb13,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
}
int64_t ne00,
int64_t ne01,
int64_t ne02,
+ uint64_t nb01,
+ uint64_t nb02,
+ uint64_t nb03,
int64_t ne10,
int64_t ne12,
+ uint64_t nb11,
+ uint64_t nb12,
+ uint64_t nb13,
int64_t ne0,
int64_t ne1,
uint r2,
const uint i12 = im%ne12;
const uint i13 = im/ne12;
- const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+ //const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+ const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
- device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+ //device const block_q8_0 * x = (device const block_q8_0 *) ((device char *) src0 + offset0);
+ device const float * y = (device const float *) ((device char *) src1 + offset1);
+
+ // pointers to src0 rows
+ device const block_q8_0 * ax[nr];
+ for (int row = 0; row < nr; ++row) {
+ const uint offset0 = (first_row + row)*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+
+ ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0);
+ }
float yl[NB_Q8_0];
float sumf[nr]={0.f};
}
for (int row = 0; row < nr; row++) {
- device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il;
+ device const int8_t * qs = ax[row][ib].qs + NB_Q8_0*il;
float sumq = 0.f;
for (int iq = 0; iq < NB_Q8_0; ++iq) {
sumq += qs[iq] * yl[iq];
}
- sumf[row] += sumq*x[ib+row*nb].d;
+ sumf[row] += sumq*ax[row][ib].d;
}
yb += NB_Q8_0 * nw;
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
+ constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
+ constant uint64_t & nb13,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
+ kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
}
#define N_MV_T_T 4
uint64_t nb00,
uint64_t nb01,
uint64_t nb02,
+ uint64_t nb03,
int64_t ne10,
int64_t ne11,
int64_t ne12,
uint64_t nb10,
uint64_t nb11,
uint64_t nb12,
+ uint64_t nb13,
int64_t ne0,
int64_t ne1,
uint r2,
const uint i12 = im%ne12;
const uint i13 = im/ne12;
- const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
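+ // nb02*ne02 equals nb03 only for contiguous tensors; use the true stride nb03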
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
device const T0 * x = (device const T0 *) (src0 + offset0);
break;
}
- device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12);
+ const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
+
+ device const T1 * y = (device const T1 *) (src1 + offset1);
float sumf = 0;
for (int i = tiisg; i < ne00; i += 32) {
break;
}
- device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12);
+ const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
+
+ device const T1 * y = (device const T1 *) (src1 + offset1);
device const T14 * y4 = (device const T14 *) y;
float sumf = 0;
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
+ constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
+ constant uint64_t & nb13,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
nb00,
nb01,
nb02,
+ nb03,
ne10,
ne11,
ne12,
nb10,
nb11,
nb12,
+ nb13,
ne0,
ne1,
r2,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
+ constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
+ constant uint64_t & nb13,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
const uint i12 = im%ne12;
const uint i13 = im/ne12;
- const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+ const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
device const T * x = (device const T *) (src0 + offset0);
- device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+ device const float * y = (device const float *) (src1 + offset1);
float sumf = 0;
if (ne00 < 128) {
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
+ constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
+ constant uint64_t & nb13,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
const uint i12 = im%ne12;
const uint i13 = im/ne12;
- const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
device const T4 * x4 = (device const T4 *) (src0 + offset0);
for (int r1 = 0; r1 < nrows; ++r1) {
- device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
+ const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
+
+ device const float4 * y4 = (device const float4 *) (src1 + offset1);
float sumf = 0;
for (int i = tiisg; i < ne00/4; i += 32) {
int64_t ne00,
int64_t ne01,
int64_t ne02,
+ uint64_t nb01,
+ uint64_t nb02,
+ uint64_t nb03,
int64_t ne10,
int64_t ne12,
+ uint64_t nb11,
+ uint64_t nb12,
+ uint64_t nb13,
int64_t ne0,
int64_t ne1,
uint r2,
const int im = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
- const int ib_row = first_row * nb;
const uint i12 = im%ne12;
const uint i13 = im/ne12;
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+ const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+ const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
- device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+ device const block_q2_K * x = (device const block_q2_K *) ((device char *) src0 + offset0);
+ device const float * y = (device const float *) ((device char *) src1 + offset1);
float yl[32];
float sumf[N_DST]={0.f}, all_sum;
- const int step = sizeof(block_q2_K) * nb;
-
const int ix = tiisg/8; // 0...3
const int it = tiisg%8; // 0...7
const int iq = it/4; // 0 or 1
(acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));
- qs += step/2;
- sc += step;
- dh += step/2;
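+ // nb01 is the src0 row stride in bytes: 16-bit pointers (qs, dh) advance by nb01/2, byte pointers (sc) by nb01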
+ qs += nb01/2;
+ sc += nb01;
+ dh += nb01/2;
}
y4 += 4 * QK_K;
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
+ constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
+ constant uint64_t & nb13,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+ kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
}
void kernel_mul_mv_q3_K_f32_impl(
int64_t ne00,
int64_t ne01,
int64_t ne02,
+ uint64_t nb01,
+ uint64_t nb02,
+ uint64_t nb03,
int64_t ne10,
int64_t ne12,
+ uint64_t nb11,
+ uint64_t nb12,
+ uint64_t nb13,
int64_t ne0,
int64_t ne1,
uint r2,
const uint i12 = im%ne12;
const uint i13 = im/ne12;
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+ const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+ const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
- device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
- device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+ device const block_q3_K * x = (device const block_q3_K *) ((device char *) src0 + offset0);
+ device const float * yy = (device const float *) ((device char *) src1 + offset1);
float yl[32];
const int q_offset = 32*ip + l0;
const int y_offset = 128*ip + 32*il + l0;
- const int step = sizeof(block_q3_K) * nb / 2;
-
device const float * y1 = yy + ix*QK_K + y_offset;
uint32_t scales32, aux32;
float sumf1[2] = {0.f};
float sumf2[2] = {0.f};
for (int i = ix; i < nb; i += 4) {
-
for (int l = 0; l < 8; ++l) {
yl[l+ 0] = y1[l+ 0];
yl[l+ 8] = y1[l+16];
device const half * dh = &x[i].d;
for (int row = 0; row < 2; ++row) {
-
const float d_all = (float)dh[0];
scales16[0] = a[4];
sumf1[row] += d1 * (scales[1] - 32);
sumf2[row] += d2 * (scales[3] - 32);
- q += step;
- h += step;
- a += step;
- dh += step;
-
+ q += nb01/2;
+ h += nb01/2;
+ a += nb01/2;
+ dh += nb01/2;
}
y1 += 4 * QK_K;
-
}
for (int row = 0; row < 2; ++row) {
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
+ constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
+ constant uint64_t & nb13,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+ kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
}
void kernel_mul_mv_q4_K_f32_impl(
int64_t ne00,
int64_t ne01,
int64_t ne02,
+ uint64_t nb01,
+ uint64_t nb02,
+ uint64_t nb03,
int64_t ne10,
int64_t ne12,
+ uint64_t nb11,
+ uint64_t nb12,
+ uint64_t nb13,
int64_t ne0,
int64_t ne1,
uint r2,
const int im = tgpig.z;
//const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
const int first_row = r0 * N_DST;
- const int ib_row = first_row * nb;
const uint i12 = im%ne12;
const uint i13 = im/ne12;
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+ const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+ const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
- device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+ device const block_q4_K * x = (device const block_q4_K *) ((device char *) src0 + offset0);
+ device const float * y = (device const float *) ((device char *) src1 + offset1);
float yl[16];
float yh[16];
float sumf[N_DST]={0.f}, all_sum;
- const int step = sizeof(block_q4_K) * nb / 2;
-
device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
uint16_t sc16[4];
thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
for (int ib = ix; ib < nb; ib += 4) {
-
float4 sumy = {0.f, 0.f, 0.f, 0.f};
for (int i = 0; i < 8; ++i) {
yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0];
device const half * dh = &x[ib].d;
for (int row = 0; row < N_DST; row++) {
-
sc16[0] = sc[0] & kmask1;
sc16[1] = sc[2] & kmask1;
sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
(acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
- q1 += step;
- sc += step;
- dh += step;
+ q1 += nb01/2;
+ sc += nb01/2;
+ dh += nb01/2;
}
y4 += 4 * QK_K;
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
+ constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
+ constant uint64_t & nb13,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+ kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
}
void kernel_mul_mv_q5_K_f32_impl(
int64_t ne00,
int64_t ne01,
int64_t ne02,
+ uint64_t nb01,
+ uint64_t nb02,
+ uint64_t nb03,
int64_t ne10,
int64_t ne12,
+ uint64_t nb11,
+ uint64_t nb12,
+ uint64_t nb13,
int64_t ne0,
int64_t ne1,
uint r2,
const uint i12 = im%ne12;
const uint i13 = im/ne12;
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+ const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+ const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
- device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
- device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+ device const block_q5_K * x = (device const block_q5_K *) ((device char *) src0 + offset0);
+ device const float * yy = (device const float *) ((device char *) src1 + offset1);
float sumf[2]={0.f};
- const int step = sizeof(block_q5_K) * nb;
-
float yl[16], yh[16];
const uint16_t kmask1 = 0x3f3f;
device const float * y1 = yy + ix*QK_K + y_offset;
for (int i = ix; i < nb; i += 4) {
-
device const uint8_t * q1 = x[i].qs + q_offset;
device const uint8_t * qh = x[i].qh + l0;
device const half * dh = &x[i].d;
}
for (int row = 0; row < 2; ++row) {
-
device const uint8_t * q2 = q1 + 64;
sc16[0] = a[0] & kmask1;
sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
- q1 += step;
- qh += step;
- dh += step/2;
- a += step/2;
-
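+ // byte pointers (q1, qh) advance by nb01 per row; 16-bit pointers (dh, a) by nb01/2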
+ q1 += nb01;
+ qh += nb01;
+ dh += nb01/2;
+ a += nb01/2;
}
y1 += 4 * QK_K;
-
}
for (int row = 0; row < 2; ++row) {
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
+ constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
+ constant uint64_t & nb13,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+ kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
}
void kernel_mul_mv_q6_K_f32_impl(
int64_t ne00,
int64_t ne01,
int64_t ne02,
+ uint64_t nb01,
+ uint64_t nb02,
+ uint64_t nb03,
int64_t ne10,
int64_t ne12,
+ uint64_t nb11,
+ uint64_t nb12,
+ uint64_t nb13,
int64_t ne0,
int64_t ne1,
uint r2,
const uint i12 = im%ne12;
const uint i13 = im/ne12;
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+ const uint offset0 = row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+ const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
- device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
- device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+ device const block_q6_K * x = (device const block_q6_K *) ((device char *) src0 + offset0);
+ device const float * yy = (device const float *) ((device char *) src1 + offset1);
float sumf = 0;
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
+ constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
+ constant uint64_t & nb13,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+ kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
}
// ======================= "True" 2-bit
int64_t ne00,
int64_t ne01,
int64_t ne02,
+ uint64_t nb01,
+ uint64_t nb02,
+ uint64_t nb03,
int64_t ne10,
int64_t ne12,
+ uint64_t nb11,
+ uint64_t nb12,
+ uint64_t nb13,
int64_t ne0,
int64_t ne1,
uint r2,
const int im = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
- const int ib_row = first_row * nb;
const uint i12 = im%ne12;
const uint i13 = im/ne12;
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+ const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+ const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
- device const block_iq2_xxs * x = (device const block_iq2_xxs *) src0 + ib_row + offset0;
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+ device const block_iq2_xxs * x = (device const block_iq2_xxs *) ((device char *) src0 + offset0);
+ device const float * y = (device const float *) ((device char *) src1 + offset1);
float yl[32];
float sumf[N_DST]={0.f}, all_sum;
}
sumf[row] += d * sum;
- dh += nb*sizeof(block_iq2_xxs)/2;
- q2 += nb*sizeof(block_iq2_xxs)/2;
+ dh += nb01/2;
+ q2 += nb01/2;
}
y4 += 32 * 32;
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
+ constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
+ constant uint64_t & nb13,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+ kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
}
void kernel_mul_mv_iq2_xs_f32_impl(
int64_t ne00,
int64_t ne01,
int64_t ne02,
+ uint64_t nb01,
+ uint64_t nb02,
+ uint64_t nb03,
int64_t ne10,
int64_t ne12,
+ uint64_t nb11,
+ uint64_t nb12,
+ uint64_t nb13,
int64_t ne0,
int64_t ne1,
uint r2,
const int im = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
- const int ib_row = first_row * nb;
const uint i12 = im%ne12;
const uint i13 = im/ne12;
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+ const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+ const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
- device const block_iq2_xs * x = (device const block_iq2_xs *) src0 + ib_row + offset0;
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+ device const block_iq2_xs * x = (device const block_iq2_xs *) ((device char *) src0 + offset0);
+ device const float * y = (device const float *) ((device char *) src1 + offset1);
float yl[32];
float sumf[N_DST]={0.f}, all_sum;
}
sumf[row] += d1 * sum1 + d2 * sum2;
- dh += nb*sizeof(block_iq2_xs)/2;
- q2 += nb*sizeof(block_iq2_xs)/2;
- sc += nb*sizeof(block_iq2_xs);
+ dh += nb01/2;
+ q2 += nb01/2;
+ sc += nb01;
}
y4 += 32 * 32;
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
+ constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
+ constant uint64_t & nb13,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+ kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
}
void kernel_mul_mv_iq3_xxs_f32_impl(
int64_t ne00,
int64_t ne01,
int64_t ne02,
+ uint64_t nb01,
+ uint64_t nb02,
+ uint64_t nb03,
int64_t ne10,
int64_t ne12,
+ uint64_t nb11,
+ uint64_t nb12,
+ uint64_t nb13,
int64_t ne0,
int64_t ne1,
uint r2,
const int im = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
- const int ib_row = first_row * nb;
const uint i12 = im%ne12;
const uint i13 = im/ne12;
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+ const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+ const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
- device const block_iq3_xxs * x = (device const block_iq3_xxs *) src0 + ib_row + offset0;
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+ device const block_iq3_xxs * x = (device const block_iq3_xxs *) ((device char *) src0 + offset0);
+ device const float * y = (device const float *) ((device char *) src1 + offset1);
float yl[32];
float sumf[N_DST]={0.f}, all_sum;
}
sumf[row] += d * (sum[0] + sum[1]);
- dh += nb*sizeof(block_iq3_xxs)/2;
- q3 += nb*sizeof(block_iq3_xxs);
- gas += nb*sizeof(block_iq3_xxs)/2;
+ dh += nb01/2;
+ q3 += nb01;
+ gas += nb01/2;
}
y4 += 32 * 32;
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
+ constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
+ constant uint64_t & nb13,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+ kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
}
void kernel_mul_mv_iq3_s_f32_impl(
int64_t ne00,
int64_t ne01,
int64_t ne02,
+ uint64_t nb01,
+ uint64_t nb02,
+ uint64_t nb03,
int64_t ne10,
int64_t ne12,
+ uint64_t nb11,
+ uint64_t nb12,
+ uint64_t nb13,
int64_t ne0,
int64_t ne1,
uint r2,
const int im = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
- const int ib_row = first_row * nb;
const uint i12 = im%ne12;
const uint i13 = im/ne12;
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+ const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+ const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
- device const block_iq3_s * x = (device const block_iq3_s *) src0 + ib_row + offset0;
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+ device const block_iq3_s * x = (device const block_iq3_s *) ((device char *) src0 + offset0);
+ device const float * y = (device const float *) ((device char *) src1 + offset1);
float yl[32];
float sumf[N_DST]={0.f}, all_sum;
}
sumf[row] += d * (sum[0] + sum[1]);
- dh += nb*sizeof(block_iq3_s)/2;
- qs += nb*sizeof(block_iq3_s);
- qh += nb*sizeof(block_iq3_s);
- sc += nb*sizeof(block_iq3_s);
- signs += nb*sizeof(block_iq3_s);
+ dh += nb01/2;
+ qs += nb01;
+ qh += nb01;
+ sc += nb01;
+ signs += nb01;
}
y4 += 32 * 32;
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
+ constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
+ constant uint64_t & nb13,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+ kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
}
void kernel_mul_mv_iq2_s_f32_impl(
int64_t ne00,
int64_t ne01,
int64_t ne02,
+ uint64_t nb01,
+ uint64_t nb02,
+ uint64_t nb03,
int64_t ne10,
int64_t ne12,
+ uint64_t nb11,
+ uint64_t nb12,
+ uint64_t nb13,
int64_t ne0,
int64_t ne1,
uint r2,
const int im = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
- const int ib_row = first_row * nb;
const uint i12 = im%ne12;
const uint i13 = im/ne12;
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+ const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+ const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
- device const block_iq2_s * x = (device const block_iq2_s *) src0 + ib_row + offset0;
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+ device const block_iq2_s * x = (device const block_iq2_s *) ((device char *) src0 + offset0);
+ device const float * y = (device const float *) ((device char *) src1 + offset1);
float yl[32];
float sumf[N_DST]={0.f}, all_sum;
}
sumf[row] += d1 * sum[0] + d2 * sum[1];
- dh += nb*sizeof(block_iq2_s)/2;
- qs += nb*sizeof(block_iq2_s);
- qh += nb*sizeof(block_iq2_s);
- sc += nb*sizeof(block_iq2_s);
- signs += nb*sizeof(block_iq2_s);
+ dh += nb01/2;
+ qs += nb01;
+ qh += nb01;
+ sc += nb01;
+ signs += nb01;
}
y4 += 32 * 32;
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
+ constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
+ constant uint64_t & nb13,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+ kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
}
void kernel_mul_mv_iq1_s_f32_impl(
int64_t ne00,
int64_t ne01,
int64_t ne02,
+ uint64_t nb01,
+ uint64_t nb02,
+ uint64_t nb03,
int64_t ne10,
int64_t ne12,
+ uint64_t nb11,
+ uint64_t nb12,
+ uint64_t nb13,
int64_t ne0,
int64_t ne1,
uint r2,
const int im = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
- const int ib_row = first_row * nb;
const uint i12 = im%ne12;
const uint i13 = im/ne12;
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
- device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0;
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+ const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+ const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
+
+ device const block_iq1_s * x = (device const block_iq1_s *) ((device char *) src0 + offset0);
+ device const float * y = (device const float *) ((device char *) src1 + offset1);
float yl[32];
float sumf[N_DST]={0.f}, all_sum;
}
sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1);
- dh += nb*sizeof(block_iq1_s)/2;
- qs += nb*sizeof(block_iq1_s);
- qh += nb*sizeof(block_iq1_s)/2;
+ dh += nb01/2;
+ qs += nb01;
+ qh += nb01/2;
}
y4 += 32 * 32;
int64_t ne00,
int64_t ne01,
int64_t ne02,
+ uint64_t nb01,
+ uint64_t nb02,
+ uint64_t nb03,
int64_t ne10,
int64_t ne12,
+ uint64_t nb11,
+ uint64_t nb12,
+ uint64_t nb13,
int64_t ne0,
int64_t ne1,
uint r2,
const int im = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
- const int ib_row = first_row * nb;
const uint i12 = im%ne12;
const uint i13 = im/ne12;
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
- device const block_iq1_m * x = (device const block_iq1_m *) src0 + ib_row + offset0;
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+ const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+ const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
+
+ device const block_iq1_m * x = (device const block_iq1_m *) ((device char *) src0 + offset0);
+ device const float * y = (device const float *) ((device char *) src1 + offset1);
float yl[32];
float sumf[N_DST]={0.f}, all_sum;
sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
(sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
- sc += nb*sizeof(block_iq1_m)/2;
- qs += nb*sizeof(block_iq1_m);
- qh += nb*sizeof(block_iq1_m);
+ sc += nb01/2;
+ qs += nb01;
+ qh += nb01;
}
y4 += 32 * 32;
int64_t ne00,
int64_t ne01,
int64_t ne02,
+ uint64_t nb01,
+ uint64_t nb02,
+ uint64_t nb03,
int64_t ne10,
int64_t ne12,
+ uint64_t nb11,
+ uint64_t nb12,
+ uint64_t nb13,
int64_t ne0,
int64_t ne1,
uint r2,
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * 2 + sgitg) * 2;
- const int ib_row = first_row * nb;
const uint i12 = im%ne12;
const uint i13 = im/ne12;
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
- device const block_iq4_nl * x = (device const block_iq4_nl *) src0 + ib_row + offset0;
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+ const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+ const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
+
+ device const block_iq4_nl * x = (device const block_iq4_nl *) ((device char *) src0 + offset0);
+ device const float * y = (device const float *) ((device char *) src1 + offset1);
const int ix = tiisg/2; // 0...15
const int it = tiisg%2; // 0 or 1
int64_t ne00,
int64_t ne01,
int64_t ne02,
+ uint64_t nb01,
+ uint64_t nb02,
+ uint64_t nb03,
int64_t ne10,
int64_t ne12,
+ uint64_t nb11,
+ uint64_t nb12,
+ uint64_t nb13,
int64_t ne0,
int64_t ne1,
uint r2,
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * 2 + sgitg) * 2;
- const int ib_row = first_row * nb;
const uint i12 = im%ne12;
const uint i13 = im/ne12;
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
- device const block_iq4_xs * x = (device const block_iq4_xs *) src0 + ib_row + offset0;
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+ const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+ const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
+
+ device const block_iq4_xs * x = (device const block_iq4_xs *) ((device char *) src0 + offset0);
+ device const float * y = (device const float *) ((device char *) src1 + offset1);
const int ix = tiisg/16; // 0 or 1
const int it = tiisg%16; // 0...15
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
+ constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
+ constant uint64_t & nb13,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+ kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
}
[[host_name("kernel_mul_mv_iq1_m_f32")]]
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
+ constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
+ constant uint64_t & nb13,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+ kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
}
[[host_name("kernel_mul_mv_iq4_nl_f32")]]
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
+ constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
+ constant uint64_t & nb13,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+ kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
}
[[host_name("kernel_mul_mv_iq4_xs_f32")]]
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
+ constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
+ constant uint64_t & nb13,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+ kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
}
//============================= templates and their specializations =============================
constant int64_t & ne02,
constant uint64_t & nb01,
constant uint64_t & nb02,
+ constant uint64_t & nb03,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
+ constant uint64_t & nb13,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
const uint i12 = im%ne12;
const uint i13 = im/ne12;
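+ // src1 batches are indexed directly by i12/i13; src0 is broadcast via r2/r3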
- uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
+ uint offset0 = (i12/r2)*nb02 + (i13/r3)*nb03;
ushort offset1 = il/nl;
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
device const float * y = (device const float *)(src1
- + nb12 * im
+ + nb13 * i13
+ + nb12 * i12
+ nb11 * (r1 * BLOCK_SIZE_N + thread_col)
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
uint64_t nb00,
uint64_t nb01,
uint64_t nb02,
+ uint64_t nb03,
int64_t ne10,
int64_t ne11,
int64_t ne12,
uint64_t nb10,
uint64_t nb11,
uint64_t nb12,
+ uint64_t nb13,
int64_t ne0,
int64_t ne1,
uint r2,
int64_t ne00,
int64_t ne01,
int64_t ne02,
+ uint64_t nb01,
+ uint64_t nb02,
+ uint64_t nb03,
int64_t ne10,
int64_t ne12,
+ uint64_t nb11,
+ uint64_t nb12,
+ uint64_t nb13,
int64_t ne0,
int64_t ne1,
uint r2,
uint64_t nb00,
uint64_t nb01,
uint64_t nb02,
+ uint64_t nb03,
int64_t ne10,
int64_t ne11,
int64_t ne12,
uint64_t nb10,
uint64_t nb11,
uint64_t nb12,
+ uint64_t nb13,
int64_t ne0,
int64_t ne1,
uint64_t nb1,
uint tiitg,
uint tiisg,
uint sgitg) {
- impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,ne10,ne11,ne12,nb10,nb11,nb12,ne0,ne1,r2,r3,tgpig,tiisg);
+ impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,nb03,ne10,ne11,ne12,nb10,nb11,nb12,nb13,ne0,ne1,r2,r3,tgpig,tiisg);
}
template<kernel_mul_mv2_impl_t impl_fn>
uint64_t nb00,
uint64_t nb01,
uint64_t nb02,
+ uint64_t nb03,
int64_t ne10,
int64_t ne11,
int64_t ne12,
uint64_t nb10,
uint64_t nb11,
uint64_t nb12,
+ uint64_t nb13,
int64_t ne0,
int64_t ne1,
uint64_t nb1,
uint tiitg,
uint tiisg,
uint sgitg) {
- impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
+ impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
}
typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4>>) mul_mv_impl_fn_t;
const int64_t i2 = i12;
device const char * src0_cur = src0s + i02*nb02;
- device const char * src1_cur = src1 + i11*nb11 + i12*nb12;
- device float * dst_cur = dst + i1*ne0 + i2*ne1*ne0;
+ device const char * src1_cur = src1 + i11*nb11 + i12*nb12;
+ device float * dst_cur = dst + i1*ne0 + i2*ne1*ne0;
impl_fn(
/* src0 */ src0_cur,
/* dst */ dst_cur,
/* ne00 */ ne00,
/* ne01 */ ne01,
- /* ne02 */ 1,//ne02,
+ /* ne02 */ 1, // ne02,
/* nb00 */ nb00,
/* nb01 */ nb01,
/* nb02 */ nb02,
+ /* nb03 */ nb02, // ne02 == 1
/* ne10 */ ne10,
- /* ne11 */ 1,//ne11,
- /* ne12 */ 1,//ne12,
- /* ne13 */ 1,//ne13,
+ /* ne11 */ 1, // ne11,
+ /* ne12 */ 1, // ne12,
+ /* ne13 */ 1, // ne13,
/* nb10 */ nb10,
/* nb11 */ nb11,
/* nb12 */ nb12,
+ /* nb13 */ nb12, // ne12 == 1
/* ne0 */ ne0,
- /* ne1 */ 1,//ne1,
+ /* ne1 */ 1, // ne1,
/* nb1 */ nb1,
/* r2 */ 1,
/* r3 */ 1,