return scale;
}
-#if QK_K == 256
static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
if (j < 4) {
*d = q[j] & 63; *m = q[j + 4] & 63;
*m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
}
}
-#endif
//========================- 2-bit (de)-quantization
}
}
-#if QK_K == 256
for (int j = 0; j < QK_K; j += 128) {
for (int l = 0; l < 32; ++l) {
y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
}
}
-#else
- for (int l = 0; l < 16; ++l) {
- y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
- }
-#endif
x += QK_K;
-
}
}
const uint8_t * q = x[i].qs;
-#if QK_K == 256
int is = 0;
float dl, ml;
for (int n = 0; n < QK_K; n += 128) {
}
q += 32;
}
-#else
- float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
- float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
- float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
- float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
- for (int l = 0; l < 16; ++l) {
- y[l+ 0] = dl1 * ((int8_t)((q[l] >> 0) & 3)) - ml1;
- y[l+16] = dl2 * ((int8_t)((q[l] >> 2) & 3)) - ml2;
- y[l+32] = dl3 * ((int8_t)((q[l] >> 4) & 3)) - ml3;
- y[l+48] = dl4 * ((int8_t)((q[l] >> 6) & 3)) - ml4;
- }
- y += QK_K;
-#endif
}
}
}
float dm, mm;
-#if QK_K == 64
- float max_scale = 0, max_min = 0;
- for (int j = 0; j < QK_K/16; ++j) {
- max_scale = MAX(max_scale, scales[j]);
- max_min = MAX(max_min, mins[j]);
- }
- dm = max_scale/15;
- mm = max_min/15;
- if (max_scale) {
- float id = 1/dm;
- for (int j = 0; j < QK_K/16; ++j) {
- int l = nearest_int(id*scales[j]);
- Ls[j] = MAX(0, MIN(15, l));
- }
- } else {
- memset(Ls, 0, QK_K/16);
- }
- if (max_min) {
- float id = 1/mm;
- for (int j = 0; j < QK_K/16; ++j) {
- int l = nearest_int(id*mins[j]);
- Lm[j] = MAX(0, MIN(15, l));
- }
- } else {
- memset(Lm, 0, QK_K/16);
- }
-#else
dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw);
mm = make_qp_quants(QK_K/16, 15, mins, Lm, sw);
-#endif
+
y[i].d = GGML_FP32_TO_FP16(dm);
y[i].dmin = GGML_FP32_TO_FP16(mm);
dm = GGML_FP16_TO_FP32(y[i].d);
}
}
-#if QK_K == 256
for (int j = 0; j < QK_K; j += 128) {
for (int l = 0; l < 32; ++l) {
y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
}
}
-#else
- for (int l = 0; l < 16; ++l) {
- y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
- }
-#endif
x += QK_K;
-
}
}
}
}
-#if QK_K == 256
memset(y[i].scales, 0, 12);
if (max_scale) {
float iscale = -32.f/max_scale;
L[16*j + ii] = l + 4;
}
}
-#else
- if (max_scale) {
- float iscale = -8.f/max_scale;
- for (int j = 0; j < QK_K/16; j+=2) {
- int l1 = nearest_int(iscale*scales[j]);
- l1 = 8 + MAX(-8, MIN(7, l1));
- int l2 = nearest_int(iscale*scales[j+1]);
- l2 = 8 + MAX(-8, MIN(7, l2));
- y[i].scales[j/2] = l1 | (l2 << 4);
- }
- y[i].d = GGML_FP32_TO_FP16(1/iscale);
- } else {
- for (int j = 0; j < QK_K/16; j+=2) {
- y[i].scales[j/2] = 0;
- }
- y[i].d = GGML_FP32_TO_FP16(0.f);
- }
- for (int j = 0; j < QK_K/16; ++j) {
- int s = j%2 == 0 ? y[i].scales[j/2] & 0xF : y[i].scales[j/2] >> 4;
- float d = GGML_FP16_TO_FP32(y[i].d) * (s - 8);
- if (!d) {
- continue;
- }
- for (int ii = 0; ii < 16; ++ii) {
- int l = nearest_int(x[16*j + ii]/d);
- l = MAX(-4, MIN(3, l));
- L[16*j + ii] = l + 4;
- }
- }
-#endif
memset(y[i].hmask, 0, QK_K/8);
// We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc.
m = 0; hm <<= 1;
}
}
-#if QK_K == 256
for (int j = 0; j < QK_K; j += 128) {
for (int l = 0; l < 32; ++l) {
y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
}
}
-#else
- for (int l = 0; l < 16; ++l) {
- y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
- }
-#endif
x += QK_K;
}
}
-#if QK_K == 256
void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int64_t k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;
}
}
-#else
-void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int64_t k) {
- assert(k % QK_K == 0);
- assert(QK_K == 64);
- const int nb = k / QK_K;
-
- for (int i = 0; i < nb; i++) {
-
- const float d_all = GGML_FP16_TO_FP32(x[i].d);
-
- const uint8_t * restrict q = x[i].qs;
- const uint8_t * restrict hm = x[i].hmask;
-
- const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
- const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
- const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
- const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
-
- for (int l=0; l<8; ++l) {
- uint8_t h = hm[l];
- y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
- y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
- y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
- y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
- y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
- y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
- y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
- y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
- }
- y += QK_K;
- }
-}
-#endif
void quantize_row_q3_K(const float * restrict x, void * restrict vy, int64_t k) {
quantize_row_q3_K_reference(x, vy, k);
}
static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restrict y, int64_t n_per_row, const float * restrict quant_weights) {
-#if QK_K != 256
- (void)quant_weights;
- quantize_row_q3_K_reference(x, y, n_per_row);
-#else
assert(n_per_row % QK_K == 0);
const int nb = n_per_row / QK_K;
x += QK_K;
}
-#endif
}
size_t quantize_q3_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
float scales[QK_K/32];
for (int i = 0; i < nb; i++) {
-
float max_scale = 0; // as we are deducting the min, scales are always positive
float max_min = 0;
for (int j = 0; j < QK_K/32; ++j) {
}
}
-#if QK_K == 256
float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
float inv_min = max_min > 0 ? 63.f/max_min : 0.f;
for (int j = 0; j < QK_K/32; ++j) {
L[32*j + ii] = l;
}
}
-#else
- const float s_factor = 15.f;
- float inv_scale = max_scale > 0 ? s_factor/max_scale : 0.f;
- float inv_min = max_min > 0 ? s_factor/max_min : 0.f;
- int d1 = nearest_int(inv_scale*scales[0]);
- int m1 = nearest_int(inv_min*mins[0]);
- int d2 = nearest_int(inv_scale*scales[1]);
- int m2 = nearest_int(inv_min*mins[1]);
- y[i].scales[0] = d1 | (m1 << 4);
- y[i].scales[1] = d2 | (m2 << 4);
- y[i].d[0] = GGML_FP32_TO_FP16(max_scale/s_factor);
- y[i].d[1] = GGML_FP32_TO_FP16(max_min/s_factor);
- float sumlx = 0;
- int suml2 = 0;
- for (int j = 0; j < QK_K/32; ++j) {
- const uint8_t sd = y[i].scales[j] & 0xF;
- const uint8_t sm = y[i].scales[j] >> 4;
- const float d = GGML_FP16_TO_FP32(y[i].d[0]) * sd;
- if (!d) continue;
- const float m = GGML_FP16_TO_FP32(y[i].d[1]) * sm;
- for (int ii = 0; ii < 32; ++ii) {
- int l = nearest_int((x[32*j + ii] + m)/d);
- l = MAX(0, MIN(15, l));
- L[32*j + ii] = l;
- sumlx += (x[32*j + ii] + m)*l*sd;
- suml2 += l*l*sd*sd;
- }
- }
- if (suml2) {
- y[i].d[0] = GGML_FP32_TO_FP16(sumlx/suml2);
- }
-#endif
uint8_t * q = y[i].qs;
for (int j = 0; j < QK_K; j += 64) {
for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4);
}
x += QK_K;
-
}
}
const int nb = k / QK_K;
for (int i = 0; i < nb; i++) {
-
const uint8_t * q = x[i].qs;
-#if QK_K == 256
-
const float d = GGML_FP16_TO_FP32(x[i].d);
const float min = GGML_FP16_TO_FP32(x[i].dmin);
for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
q += 32; is += 2;
}
-#else
- const float dall = GGML_FP16_TO_FP32(x[i].d[0]);
- const float mall = GGML_FP16_TO_FP32(x[i].d[1]);
- const float d1 = dall * (x[i].scales[0] & 0xF), m1 = mall * (x[i].scales[0] >> 4);
- const float d2 = dall * (x[i].scales[1] & 0xF), m2 = mall * (x[i].scales[1] >> 4);
- for (int l = 0; l < 32; ++l) {
- y[l+ 0] = d1 * (q[l] & 0xF) - m1;
- y[l+32] = d2 * (q[l] >> 4) - m2;
- }
- y += QK_K;
-#endif
-
}
}
}
static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restrict y, int64_t n_per_row, const float * quant_weights) {
-#if QK_K != 256
- (void)quant_weights;
- quantize_row_q4_K_reference(x, y, n_per_row);
-#else
assert(n_per_row % QK_K == 0);
const int64_t nb = n_per_row / QK_K;
x += QK_K;
}
-#endif
}
size_t quantize_q4_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
assert(k % QK_K == 0);
const int64_t nb = k / QK_K;
-#if QK_K == 256
uint8_t L[QK_K];
float mins[QK_K/32];
float scales[QK_K/32];
float weights[32];
uint8_t Laux[32];
-#else
- int8_t L[QK_K];
- float scales[QK_K/16];
-#endif
for (int i = 0; i < nb; i++) {
-
-#if QK_K == 256
-
float max_scale = 0; // as we are deducting the min, scales are always positive
float max_min = 0;
for (int j = 0; j < QK_K/32; ++j) {
m1 <<= 2; m2 <<= 2;
ql += 32;
}
-#else
- float max_scale = 0, amax = 0;
- for (int j = 0; j < QK_K/16; ++j) {
- scales[j] = make_qx_quants(16, 16, x + 16*j, L + 16*j, 1, NULL);
- float abs_scale = fabsf(scales[j]);
- if (abs_scale > amax) {
- amax = abs_scale;
- max_scale = scales[j];
- }
- }
-
- float iscale = -128.f/max_scale;
- for (int j = 0; j < QK_K/16; ++j) {
- int l = nearest_int(iscale*scales[j]);
- y[i].scales[j] = MAX(-128, MIN(127, l));
- }
- y[i].d = GGML_FP32_TO_FP16(1/iscale);
-
- for (int j = 0; j < QK_K/16; ++j) {
- const float d = GGML_FP16_TO_FP32(y[i].d) * y[i].scales[j];
- if (!d) continue;
- for (int ii = 0; ii < 16; ++ii) {
- int l = nearest_int(x[16*j + ii]/d);
- l = MAX(-16, MIN(15, l));
- L[16*j + ii] = l + 16;
- }
- }
-
- uint8_t * restrict qh = y[i].qh;
- uint8_t * restrict ql = y[i].qs;
- memset(qh, 0, QK_K/8);
-
- for (int j = 0; j < 32; ++j) {
- int jm = j%8;
- int is = j/8;
- int l1 = L[j];
- if (l1 > 15) {
- l1 -= 16; qh[jm] |= (1 << is);
- }
- int l2 = L[j + 32];
- if (l2 > 15) {
- l2 -= 16; qh[jm] |= (1 << (4 + is));
- }
- ql[j] = l1 | (l2 << 4);
- }
-#endif
x += QK_K;
-
}
}
const int64_t nb = k / QK_K;
for (int i = 0; i < nb; i++) {
-
const uint8_t * ql = x[i].qs;
const uint8_t * qh = x[i].qh;
-#if QK_K == 256
-
const float d = GGML_FP16_TO_FP32(x[i].d);
const float min = GGML_FP16_TO_FP32(x[i].dmin);
ql += 32; is += 2;
u1 <<= 2; u2 <<= 2;
}
-#else
- float d = GGML_FP16_TO_FP32(x[i].d);
- const int8_t * restrict s = x[i].scales;
- for (int l = 0; l < 8; ++l) {
- y[l+ 0] = d * s[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
- y[l+ 8] = d * s[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
- y[l+16] = d * s[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
- y[l+24] = d * s[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
- y[l+32] = d * s[2] * ((ql[l+ 0] >> 4) - (qh[l] & 0x10 ? 0 : 16));
- y[l+40] = d * s[2] * ((ql[l+ 8] >> 4) - (qh[l] & 0x20 ? 0 : 16));
- y[l+48] = d * s[3] * ((ql[l+16] >> 4) - (qh[l] & 0x40 ? 0 : 16));
- y[l+56] = d * s[3] * ((ql[l+24] >> 4) - (qh[l] & 0x80 ? 0 : 16));
- }
- y += QK_K;
-#endif
}
}
}
static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restrict y, int64_t n_per_row, const float * quant_weights) {
-#if QK_K != 256
- (void)quant_weights;
- quantize_row_q5_K_reference(x, y, n_per_row);
-#else
assert(n_per_row % QK_K == 0);
const int64_t nb = n_per_row / QK_K;
x += QK_K;
}
-#endif
}
size_t quantize_q5_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
uint8_t * restrict ql = y[i].ql;
uint8_t * restrict qh = y[i].qh;
-#if QK_K == 256
for (int j = 0; j < QK_K; j += 128) {
for (int l = 0; l < 32; ++l) {
const uint8_t q1 = L[j + l + 0] & 0xF;
ql += 64;
qh += 32;
}
-#else
- for (int l = 0; l < 32; ++l) {
- const uint8_t q1 = L[l + 0] & 0xF;
- const uint8_t q2 = L[l + 32] & 0xF;
- ql[l] = q1 | (q2 << 4);
- }
- for (int l = 0; l < 16; ++l) {
- qh[l] = (L[l] >> 4) | ((L[l + 16] >> 4) << 2) | ((L[l + 32] >> 4) << 4) | ((L[l + 48] >> 4) << 6);
- }
-#endif
x += QK_K;
-
}
}
const int64_t nb = k / QK_K;
for (int i = 0; i < nb; i++) {
-
const float d = GGML_FP16_TO_FP32(x[i].d);
const uint8_t * restrict ql = x[i].ql;
const uint8_t * restrict qh = x[i].qh;
const int8_t * restrict sc = x[i].scales;
-#if QK_K == 256
for (int n = 0; n < QK_K; n += 128) {
for (int l = 0; l < 32; ++l) {
int is = l/16;
qh += 32;
sc += 8;
}
-#else
- for (int l = 0; l < 16; ++l) {
- const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
- const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
- const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
- const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
- y[l+ 0] = d * sc[0] * q1;
- y[l+16] = d * sc[1] * q2;
- y[l+32] = d * sc[2] * q3;
- y[l+48] = d * sc[3] * q4;
- }
- y += 64;
-#endif
-
}
}
}
static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restrict y, int64_t n_per_row, const float * quant_weights) {
-#if QK_K != 256
- (void)quant_weights;
- quantize_row_q6_K_reference(x, y, n_per_row);
-#else
assert(n_per_row % QK_K == 0);
const int64_t nb = n_per_row / QK_K;
x += QK_K;
}
-#endif
}
size_t quantize_q6_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
float delta[4];
uint16_t idx[4];
-#if QK_K != 64
iq1m_scale_t scale;
-#endif
for (int i = 0; i < nb; i++) {
const uint16_t * sc = (const uint16_t *)x[i].scales;
-#if QK_K == 64
- const float d = GGML_FP16_TO_FP32(x[i].d);
-#else
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
const float d = GGML_FP16_TO_FP32(scale.f16);
-#endif
+
const uint8_t * qs = x[i].qs;
const uint8_t * qh = x[i].qh;
for (int ib = 0; ib < QK_K/32; ++ib) {
-#if QK_K == 64
- const float dl1 = d * (2*((sc[ib/2] >> (8*(ib%2)+0)) & 0xf) + 1);
- const float dl2 = d * (2*((sc[ib/2] >> (8*(ib%2)+4)) & 0xf) + 1);
-#else
const float dl1 = d * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1);
const float dl2 = d * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1);
-#endif
+
idx[0] = qs[0] | ((qh[0] << 8) & 0x700);
idx[1] = qs[1] | ((qh[0] << 4) & 0x700);
idx[2] = qs[2] | ((qh[1] << 8) & 0x700);
void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y, int64_t k) {
assert(k % QK_K == 0);
-#if QK_K == 64
- dequantize_row_iq4_nl((const block_iq4_nl *)x, y, k);
-#else
const int64_t nb = k / QK_K;
for (int i = 0; i < nb; i++) {
qs += 16;
}
}
-#endif
}
//===================================== Q8_K ==============================================
#endif
}
-#if QK_K == 256
void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc);
#endif
}
-#else
-
-void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+ assert(n % QK_K == 0);
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
- const block_q2_K * restrict x = vx;
+ const uint32_t kmask1 = 0x03030303;
+ const uint32_t kmask2 = 0x0f0f0f0f;
+
+ const block_q3_K * restrict x = vx;
const block_q8_K * restrict y = vy;
const int nb = n / QK_K;
#ifdef __ARM_NEON
- const uint8x16_t m3 = vdupq_n_u8(0x3);
- const int32x4_t vzero = vdupq_n_s32(0);
+ uint32_t aux[3];
+ uint32_t utmp[4];
+
+ const uint8x16_t m3b = vdupq_n_u8(0x3);
+ const int32x4_t vzero = vdupq_n_s32(0);
- ggml_int8x16x4_t q2bytes;
+ const uint8x16_t m0 = vdupq_n_u8(1);
+ const uint8x16_t m1 = vshlq_n_u8(m0, 1);
+ const uint8x16_t m2 = vshlq_n_u8(m0, 2);
+ const uint8x16_t m3 = vshlq_n_u8(m0, 3);
+ const int8_t m32 = 32;
- uint32_t aux32[2];
- const uint8_t * scales = (const uint8_t *)aux32;
+ ggml_int8x16x4_t q3bytes;
float sum = 0;
for (int i = 0; i < nb; ++i) {
- const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
- const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
- const uint8_t * restrict q2 = x[i].qs;
+ const uint8_t * restrict q3 = x[i].qs;
+ const uint8_t * restrict qh = x[i].hmask;
const int8_t * restrict q8 = y[i].qs;
- const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
-
- aux32[0] = sc[0] & 0x0f0f0f0f;
- aux32[1] = (sc[0] >> 4) & 0x0f0f0f0f;
-
- sum += dmin * (scales[4] * y[i].bsums[0] + scales[5] * y[i].bsums[1] + scales[6] * y[i].bsums[2] + scales[7] * y[i].bsums[3]);
- int isum1 = 0, isum2 = 0;
-
- const uint8x16_t q2bits = vld1q_u8(q2);
-
- const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8);
-
- q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits, m3));
- q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 2), m3));
- q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 4), m3));
- q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 6), m3));
-
- isum1 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0];
- isum2 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1];
- isum1 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2];
- isum2 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3];
-
- sum += d * (isum1 + isum2);
- }
+ ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh);
- *s = sum;
+ ggml_uint8x16x4_t q3h;
-#elif defined __AVX2__
+ int32_t isum = 0;
- const __m256i m3 = _mm256_set1_epi8(3);
+ // Set up scales
+ memcpy(aux, x[i].scales, 12);
+ utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
+ utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
+ utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
+ utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
- __m256 acc = _mm256_setzero_ps();
+ int8_t * scale = (int8_t *)utmp;
+ for (int j = 0; j < 16; ++j) scale[j] -= m32;
- uint32_t ud, um;
- const uint8_t * restrict db = (const uint8_t *)&ud;
- const uint8_t * restrict mb = (const uint8_t *)&um;
+ for (int j = 0; j < QK_K/128; ++j) {
- float summs = 0;
+ const ggml_uint8x16x2_t q3bits = ggml_vld1q_u8_x2(q3); q3 += 32;
+ const ggml_int8x16x4_t q8bytes_1 = ggml_vld1q_s8_x4(q8); q8 += 64;
+ const ggml_int8x16x4_t q8bytes_2 = ggml_vld1q_s8_x4(q8); q8 += 64;
- // TODO: optimize this
+ q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2);
+ q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2);
+ q3h.val[2] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[0]), 1);
+ q3h.val[3] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[1]), 1);
- for (int i = 0; i < nb; ++i) {
+ q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), vreinterpretq_s8_u8(q3h.val[0]));
+ q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), vreinterpretq_s8_u8(q3h.val[1]));
+ q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
+ q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
- const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
- const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];
- const uint8_t * restrict q2 = x[i].qs;
- const int8_t * restrict q8 = y[i].qs;
+ scale += 4;
- const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
- ud = (sc[0] >> 0) & 0x0f0f0f0f;
- um = (sc[0] >> 4) & 0x0f0f0f0f;
+ q3h.val[0] = vbicq_u8(m2, qhbits.val[0]);
+ q3h.val[1] = vbicq_u8(m2, qhbits.val[1]);
+ q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1);
+ q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1);
- int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3];
- summs += dmin * smin;
+ q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), vreinterpretq_s8_u8(q3h.val[0]));
+ q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), vreinterpretq_s8_u8(q3h.val[1]));
+ q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
+ q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
- const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2);
- const __m256i q2_0 = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q2bits, 2), q2bits), m3);
- const __m256i q2_1 = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q2bits, 6), _mm_srli_epi16(q2bits, 4)), m3);
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];
- const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
- const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+ scale += 4;
- const __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0);
- const __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1);
+ if (j == 0) {
+ qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4);
+ qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4);
+ }
- const __m256i p_0 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 0));
- const __m256i p_1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 1));
- const __m256i p_2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 0));
- const __m256i p_3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 1));
+ }
+ sum += d * isum;
- acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0), acc);
- acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1), acc);
- acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2), acc);
- acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3), acc);
}
- *s = hsum_float_8(acc) + summs;
+ *s = sum;
-#elif defined __AVX__
+#elif defined __AVX2__
- const __m128i m3 = _mm_set1_epi8(3);
+ const __m256i m3 = _mm256_set1_epi8(3);
+ const __m256i mone = _mm256_set1_epi8(1);
+ const __m128i m32 = _mm_set1_epi8(32);
__m256 acc = _mm256_setzero_ps();
- uint32_t ud, um;
- const uint8_t * restrict db = (const uint8_t *)&ud;
- const uint8_t * restrict mb = (const uint8_t *)&um;
-
- float summs = 0;
-
- // TODO: optimize this
+ uint32_t aux[3];
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
- const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
- const uint8_t * restrict q2 = x[i].qs;
+ const uint8_t * restrict q3 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
- const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
- ud = (sc[0] >> 0) & 0x0f0f0f0f;
- um = (sc[0] >> 4) & 0x0f0f0f0f;
-
- int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3];
- summs += dmin * smin;
-
- const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2);
- const __m128i q2_0 = _mm_and_si128(q2bits, m3);
- const __m128i q2_1 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
- const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
- const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
-
- const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
- const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
-
- const __m128i p0 = _mm_maddubs_epi16(q2_0, _mm256_extractf128_si256(q8_0, 0));
- const __m128i p1 = _mm_maddubs_epi16(q2_1, _mm256_extractf128_si256(q8_0, 1));
- const __m128i p2 = _mm_maddubs_epi16(q2_2, _mm256_extractf128_si256(q8_1, 0));
- const __m128i p3 = _mm_maddubs_epi16(q2_3, _mm256_extractf128_si256(q8_1, 1));
-
- const __m256i p_0 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p0, p0)), _mm_cvtepi16_epi32(p0));
- const __m256i p_1 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p1, p1)), _mm_cvtepi16_epi32(p1));
- const __m256i p_2 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p2, p2)), _mm_cvtepi16_epi32(p2));
- const __m256i p_3 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p3, p3)), _mm_cvtepi16_epi32(p3));
-
- acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0)), acc);
- acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1)), acc);
- acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2)), acc);
- acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3)), acc);
- }
-
- *s = hsum_float_8(acc) + summs;
-
-#elif defined __riscv_v_intrinsic
-
- uint32_t aux32[2];
- const uint8_t * scales = (const uint8_t *)aux32;
-
- float sumf = 0;
-
- for (int i = 0; i < nb; ++i) {
-
- const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
- const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
-
- const uint8_t * restrict q2 = x[i].qs;
- const int8_t * restrict q8 = y[i].qs;
- const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
-
- aux32[0] = sc[0] & 0x0f0f0f0f;
- aux32[1] = (sc[0] >> 4) & 0x0f0f0f0f;
-
- sumf += dmin * (scales[4] * y[i].bsums[0] + scales[5] * y[i].bsums[1] + scales[6] * y[i].bsums[2] + scales[7] * y[i].bsums[3]);
-
- int isum1 = 0;
- int isum2 = 0;
-
- size_t vl = 16;
-
- vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1);
-
- // load Q2
- vuint8mf2_t q2_x = __riscv_vle8_v_u8mf2(q2, vl);
-
- vint8mf2_t q2_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q2_x, 0x03, vl));
- vint8mf2_t q2_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x2, vl), 0x03 , vl));
- vint8mf2_t q2_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x4, vl), 0x03 , vl));
- vint8mf2_t q2_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x6, vl), 0x03 , vl));
-
- // load Q8, and take product with Q2
- vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q2_0, __riscv_vle8_v_i8mf2(q8, vl), vl);
- vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q2_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl);
- vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q2_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl);
- vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q2_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl);
-
- vint16m1_t vs_0 = __riscv_vredsum_vs_i16m1_i16m1(p0, vzero, vl);
- vint16m1_t vs_1 = __riscv_vredsum_vs_i16m1_i16m1(p1, vzero, vl);
- vint16m1_t vs_2 = __riscv_vredsum_vs_i16m1_i16m1(p2, vzero, vl);
- vint16m1_t vs_3 = __riscv_vredsum_vs_i16m1_i16m1(p3, vzero, vl);
-
- isum1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[0];
- isum2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[1];
- isum1 += __riscv_vmv_x_s_i16m1_i16(vs_2) * scales[2];
- isum2 += __riscv_vmv_x_s_i16m1_i16(vs_3) * scales[3];
-
- sumf += d * (isum1 + isum2);
-
- }
-
- *s = sumf;
-
-
-#elif defined(__POWER9_VECTOR__)
- const vector signed char lowMask = vec_splats((signed char)0x3);
- const vector signed char lowScaleMask = vec_splats((signed char)0xF);
- const vector unsigned char v2 = vec_splats((unsigned char)0x2);
- const vector unsigned char v4 = vec_splats((unsigned char)0x4);
- const vector unsigned char v6 = vec_splats((unsigned char)0x6);
-
- vector float vsumf0 = vec_splats(0.0f);
- vector float vsumf1 = vec_splats(0.0f);
- vector float vsumf2 = vec_splats(0.0f);
- vector float vsumf3 = vec_splats(0.0f);
-
-#pragma GCC unroll 2
- for (int i = 0; i < nb; ++i) {
- __builtin_prefetch(x[i].qs, 0, 1);
- __builtin_prefetch(y[i].qs, 0, 1);
-
- vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
- vector float vyd = vec_splats(y[i].d);
- vector float vd = vec_mul(vxd, vyd);
-
- vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[i].dmin));
- vector float vdmin = vec_mul(vxmin, vyd);
-
- vector signed short q8ysums0 = vec_xl_len(y[i].bsums, 8);
-
- vector signed char q2xmins = (vector signed char)vec_xl_len(x[i].scales, 4);
- vector signed char vscales = vec_and(q2xmins, lowScaleMask);
-
- q2xmins = vec_sr(q2xmins, v4);
- vector signed short q2xmins0 = vec_unpackh((vector signed char)q2xmins);
-
- vector signed int prod0 = vec_mule(q2xmins0, q8ysums0);
- vector signed int prod1 = vec_mulo(q2xmins0, q8ysums0);
-
- vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0);
- vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1);
-
- vector signed char qxs0 = (vector signed char)vec_xl( 0, x[i].qs);
- vector signed char q2x00 = vec_and(qxs0, lowMask);
- vector signed char q2x01 = vec_and(vec_sr(qxs0, v2), lowMask);
- vector signed char q2x02 = vec_and(vec_sr(qxs0, v4), lowMask);
- vector signed char q2x03 = vec_and(vec_sr(qxs0, v6), lowMask);
-
- vector signed char q8y00 = vec_xl( 0, y[i].qs);
- vector signed char q8y01 = vec_xl( 16, y[i].qs);
- vector signed char q8y02 = vec_xl( 32, y[i].qs);
- vector signed char q8y03 = vec_xl( 48, y[i].qs);
-
- vector signed short qv0 = vec_add(vec_mule(q2x00, q8y00), vec_mulo(q2x00, q8y00));
- vector signed short qv1 = vec_add(vec_mule(q2x01, q8y01), vec_mulo(q2x01, q8y01));
- vector signed short qv2 = vec_add(vec_mule(q2x02, q8y02), vec_mulo(q2x02, q8y02));
- vector signed short qv3 = vec_add(vec_mule(q2x03, q8y03), vec_mulo(q2x03, q8y03));
-
- vector signed short vscales_h = vec_unpackh(vscales);
- vector signed short vs0 = vec_splat(vscales_h, 0);
- vector signed short vs1 = vec_splat(vscales_h, 1);
- vector signed short vs2 = vec_splat(vscales_h, 2);
- vector signed short vs3 = vec_splat(vscales_h, 3);
-
- vector signed int vsumi0 = vec_add(vec_mule(qv0, vs0), vec_mulo(qv0, vs0));
- vector signed int vsumi1 = vec_add(vec_mule(qv1, vs1), vec_mulo(qv1, vs1));
- vector signed int vsumi2 = vec_add(vec_mule(qv2, vs2), vec_mulo(qv2, vs2));
- vector signed int vsumi3 = vec_add(vec_mule(qv3, vs3), vec_mulo(qv3, vs3));
-
- vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
- vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
- vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
- vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
- }
-
- vsumf0 = vec_add(vsumf0, vsumf2);
- vsumf1 = vec_add(vsumf1, vsumf3);
-
- vsumf0 = vec_add(vsumf0, vsumf1);
-
- vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
- vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
-
- *s = vec_extract(vsumf0, 0);
-
-#elif defined __loongarch_asx
-
- const __m256i m3 = __lasx_xvreplgr2vr_b(3);
-
- __m256 acc = (__m256)__lasx_xvldi(0);
-
- uint32_t ud, um;
- const uint8_t * restrict db = (const uint8_t *)&ud;
- const uint8_t * restrict mb = (const uint8_t *)&um;
-
- float summs = 0;
-
- // TODO: optimize this
-
- for (int i = 0; i < nb; ++i) {
-
- const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
- const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
-
- const uint8_t * restrict q2 = x[i].qs;
- const int8_t * restrict q8 = y[i].qs;
-
- const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
- ud = (sc[0] >> 0) & 0x0f0f0f0f;
- um = (sc[0] >> 4) & 0x0f0f0f0f;
-
- int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3];
- summs += dmin * smin;
-
- const __m128i q2bits = __lsx_vld((const __m128i*)q2, 0);
- const __m256i q2_0 = __lasx_xvand_v(lasx_insertf128(__lsx_vsrli_h(q2bits, 2), q2bits), m3);
- const __m256i q2_1 = __lasx_xvand_v(lasx_insertf128(__lsx_vsrli_h(q2bits, 6), __lsx_vsrli_h(q2bits, 4)), m3);
-
- const __m256i q8_0 = __lasx_xvld((const __m256i*)(q8+ 0), 0);
- const __m256i q8_1 = __lasx_xvld((const __m256i*)(q8+32), 0);
-
- const __m256i p0 = lasx_maddubs_h(q2_0, q8_0);
- const __m256i p1 = lasx_maddubs_h(q2_1, q8_1);
-
- const __m256i p_0 = lasx_ext16_32(lasx_extracti128(p0, 0));
- const __m256i p_1 = lasx_ext16_32(lasx_extracti128(p0, 1));
- const __m256i p_2 = lasx_ext16_32(lasx_extracti128(p1, 0));
- const __m256i p_3 = lasx_ext16_32(lasx_extracti128(p1, 1));
-
- ft_union t0, t1, t2, t3;
- t0.f = d * db[0];
- t1.f = d * db[1];
- t2.f = d * db[2];
- t3.f = d * db[3];
- acc = __lasx_xvfmadd_s(__lasx_xvreplgr2vr_w(t0.i), __lasx_xvffint_s_w(p_0), acc);
- acc = __lasx_xvfmadd_s(__lasx_xvreplgr2vr_w(t1.i), __lasx_xvffint_s_w(p_1), acc);
- acc = __lasx_xvfmadd_s(__lasx_xvreplgr2vr_w(t2.i), __lasx_xvffint_s_w(p_2), acc);
- acc = __lasx_xvfmadd_s(__lasx_xvreplgr2vr_w(t3.i), __lasx_xvffint_s_w(p_3), acc);
- }
-
- *s = hsum_float_8(acc) + summs;
-
-#else
-
- float sumf = 0;
-
- int isum[QK_K/16];
-
- for (int i = 0; i < nb; ++i) {
-
- const uint8_t * q2 = x[i].qs;
- const int8_t * q8 = y[i].qs;
- const uint8_t * sc = x[i].scales;
-
- int summs = 0;
- for (int j = 0; j < QK_K/16; ++j) {
- summs += y[i].bsums[j] * (sc[j] >> 4);
- }
-
- const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
- const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
-
- memset(isum, 0, (QK_K/16)*sizeof(int));
- for (int l = 0; l < 16; ++l) {
- isum[0] += q8[l+ 0] * ((q2[l] >> 0) & 3);
- isum[1] += q8[l+16] * ((q2[l] >> 2) & 3);
- isum[2] += q8[l+32] * ((q2[l] >> 4) & 3);
- isum[3] += q8[l+48] * ((q2[l] >> 6) & 3);
- }
- for (int l = 0; l < QK_K/16; ++l) {
- isum[l] *= (sc[l] & 0xF);
- }
- sumf += dall * (isum[0] + isum[1] + isum[2] + isum[3]) - dmin * summs;
- }
- *s = sumf;
-#endif
-}
-#endif
-
-#if QK_K == 256
-void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
- assert(n % QK_K == 0);
- assert(nrc == 1);
- UNUSED(nrc);
- UNUSED(bx);
- UNUSED(by);
- UNUSED(bs);
-
- const uint32_t kmask1 = 0x03030303;
- const uint32_t kmask2 = 0x0f0f0f0f;
-
- const block_q3_K * restrict x = vx;
- const block_q8_K * restrict y = vy;
-
- const int nb = n / QK_K;
-
-#ifdef __ARM_NEON
-
- uint32_t aux[3];
- uint32_t utmp[4];
-
- const uint8x16_t m3b = vdupq_n_u8(0x3);
- const int32x4_t vzero = vdupq_n_s32(0);
-
- const uint8x16_t m0 = vdupq_n_u8(1);
- const uint8x16_t m1 = vshlq_n_u8(m0, 1);
- const uint8x16_t m2 = vshlq_n_u8(m0, 2);
- const uint8x16_t m3 = vshlq_n_u8(m0, 3);
- const int8_t m32 = 32;
-
- ggml_int8x16x4_t q3bytes;
-
- float sum = 0;
-
- for (int i = 0; i < nb; ++i) {
-
- const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
-
- const uint8_t * restrict q3 = x[i].qs;
- const uint8_t * restrict qh = x[i].hmask;
- const int8_t * restrict q8 = y[i].qs;
-
- ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh);
-
- ggml_uint8x16x4_t q3h;
-
- int32_t isum = 0;
-
- // Set up scales
- memcpy(aux, x[i].scales, 12);
- utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
- utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
- utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
- utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
-
- int8_t * scale = (int8_t *)utmp;
- for (int j = 0; j < 16; ++j) scale[j] -= m32;
-
- for (int j = 0; j < QK_K/128; ++j) {
-
- const ggml_uint8x16x2_t q3bits = ggml_vld1q_u8_x2(q3); q3 += 32;
- const ggml_int8x16x4_t q8bytes_1 = ggml_vld1q_s8_x4(q8); q8 += 64;
- const ggml_int8x16x4_t q8bytes_2 = ggml_vld1q_s8_x4(q8); q8 += 64;
-
- q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2);
- q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2);
- q3h.val[2] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[0]), 1);
- q3h.val[3] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[1]), 1);
-
- q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), vreinterpretq_s8_u8(q3h.val[0]));
- q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), vreinterpretq_s8_u8(q3h.val[1]));
- q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
- q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
-
- isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
- isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
- isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
- isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];
-
- scale += 4;
-
- q3h.val[0] = vbicq_u8(m2, qhbits.val[0]);
- q3h.val[1] = vbicq_u8(m2, qhbits.val[1]);
- q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1);
- q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1);
-
- q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), vreinterpretq_s8_u8(q3h.val[0]));
- q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), vreinterpretq_s8_u8(q3h.val[1]));
- q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
- q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
-
- isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
- isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
- isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
- isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];
-
- scale += 4;
-
- if (j == 0) {
- qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4);
- qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4);
- }
-
- }
- sum += d * isum;
-
- }
-
- *s = sum;
-
-#elif defined __AVX2__
-
- const __m256i m3 = _mm256_set1_epi8(3);
- const __m256i mone = _mm256_set1_epi8(1);
- const __m128i m32 = _mm_set1_epi8(32);
-
- __m256 acc = _mm256_setzero_ps();
-
- uint32_t aux[3];
-
- for (int i = 0; i < nb; ++i) {
-
- const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
-
- const uint8_t * restrict q3 = x[i].qs;
- const int8_t * restrict q8 = y[i].qs;
-
- // Set up scales
- memcpy(aux, x[i].scales, 12);
- __m128i scales128 = _mm_set_epi32(
- ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
- ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
- (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
- (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
- scales128 = _mm_sub_epi8(scales128, m32);
- const __m256i all_scales = _mm256_cvtepi8_epi16(scales128);
- const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
- const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
- const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};
-
- // high bit
- const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].hmask);
-
- // integer accumulator
- __m256i sumi = _mm256_setzero_si256();
-
- int bit = 0;
- int is = 0;
-
- for (int j = 0; j < QK_K/128; ++j) {
- // load low 2 bits
- const __m256i q3bits = _mm256_loadu_si256((const __m256i*)q3); q3 += 32;
-
- // prepare low and high bits
- const __m256i q3l_0 = _mm256_and_si256(q3bits, m3);
- const __m256i q3h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
- ++bit;
-
- const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3);
- const __m256i q3h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
- ++bit;
-
- const __m256i q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3);
- const __m256i q3h_2 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
- ++bit;
-
- const __m256i q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3);
- const __m256i q3h_3 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
- ++bit;
-
- // load Q8 quants
- const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
- const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
- const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
- const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
-
- // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16,
- // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
- // and 2 if the high bit was set)
- __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0);
- __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1);
- __m256i q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2);
- __m256i q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3);
-
- __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0);
- __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1);
- __m256i p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2);
- __m256i p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3);
-
- p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
- p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
- p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
- p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
-
- // multiply with scales
- p16_0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0);
- p16_1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1);
- p16_2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2);
- p16_3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3);
-
- // accumulate
- p16_0 = _mm256_add_epi32(p16_0, p16_1);
- p16_2 = _mm256_add_epi32(p16_2, p16_3);
- sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2));
-
- }
-
- // multiply with block scale and accumulate
- acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
-
- }
-
- *s = hsum_float_8(acc);
-
-#elif defined __AVX__
-
- const __m128i m3 = _mm_set1_epi8(3);
- const __m128i mone = _mm_set1_epi8(1);
- const __m128i m32 = _mm_set1_epi8(32);
- const __m128i m2 = _mm_set1_epi8(2);
-
- __m256 acc = _mm256_setzero_ps();
-
- const uint32_t *aux;
-
- for (int i = 0; i < nb; ++i) {
-
- const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
-
- const uint8_t * restrict q3 = x[i].qs;
- const int8_t * restrict q8 = y[i].qs;
-
- // Set up scales
- aux = (const uint32_t *)x[i].scales;
- __m128i scales128 = _mm_set_epi32(
- ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
- ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
- (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
- (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
- scales128 = _mm_sub_epi8(scales128, m32);
- const __m128i scales_0 = _mm_cvtepi8_epi16(scales128);
- const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales128, scales128));
- const __m128i scales[2] = { scales_0, scales_1 };
-
- // high bit *128*2 from block_q3_K.hmask[QK_K/8]
- const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].hmask[0]);
- const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].hmask[16]);
-
- // integer accumulator
- __m128i sumi_0 = _mm_setzero_si128();
- __m128i sumi_1 = _mm_setzero_si128();
-
- for (int j = 0; j < QK_K/128; ++j) {
- // load low 2 bits *64*2 from block_q3_K.qs[QK_K/4]
- const __m128i q3bits_0 = _mm_loadu_si128((const __m128i*)q3); q3 += 16;
- const __m128i q3bits_1 = _mm_loadu_si128((const __m128i*)q3); q3 += 16;
-
- // prepare low and high bits
- const int bit = j << 2;
-
- const __m128i q3l_0 = _mm_and_si128(q3bits_0, m3);
- const __m128i q3l_1 = _mm_and_si128(q3bits_1, m3);
- const __m128i q3h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit)), bit), 2);
- const __m128i q3h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit)), bit), 2);
-
- const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 2), m3);
- const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 2), m3);
- const __m128i q3h_2 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+1)), bit+1), 2);
- const __m128i q3h_3 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+1)), bit+1), 2);
-
- const __m128i q3l_4 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 4), m3);
- const __m128i q3l_5 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 4), m3);
- const __m128i q3h_4 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+2)), bit+2), 2);
- const __m128i q3h_5 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+2)), bit+2), 2);
-
- const __m128i q3l_6 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 6), m3);
- const __m128i q3l_7 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 6), m3);
- const __m128i q3h_6 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+3)), bit+3), 2);
- const __m128i q3h_7 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+3)), bit+3), 2);
-
- // load Q8 quants from block_q8_K.qs[QK_K]
- const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
- const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
- const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
- const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
- const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
- const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
- const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
- const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
-
- // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16,
- // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
- // and 2 if the high bit was set)
- __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, q8_0);
- __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, q8_1);
- __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, q8_2);
- __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, q8_3);
- __m128i q8s_4 = _mm_maddubs_epi16(q3h_4, q8_4);
- __m128i q8s_5 = _mm_maddubs_epi16(q3h_5, q8_5);
- __m128i q8s_6 = _mm_maddubs_epi16(q3h_6, q8_6);
- __m128i q8s_7 = _mm_maddubs_epi16(q3h_7, q8_7);
-
- __m128i p16_0 = _mm_maddubs_epi16(q3l_0, q8_0);
- __m128i p16_1 = _mm_maddubs_epi16(q3l_1, q8_1);
- __m128i p16_2 = _mm_maddubs_epi16(q3l_2, q8_2);
- __m128i p16_3 = _mm_maddubs_epi16(q3l_3, q8_3);
- __m128i p16_4 = _mm_maddubs_epi16(q3l_4, q8_4);
- __m128i p16_5 = _mm_maddubs_epi16(q3l_5, q8_5);
- __m128i p16_6 = _mm_maddubs_epi16(q3l_6, q8_6);
- __m128i p16_7 = _mm_maddubs_epi16(q3l_7, q8_7);
-
- p16_0 = _mm_sub_epi16(p16_0, q8s_0);
- p16_1 = _mm_sub_epi16(p16_1, q8s_1);
- p16_2 = _mm_sub_epi16(p16_2, q8s_2);
- p16_3 = _mm_sub_epi16(p16_3, q8s_3);
- p16_4 = _mm_sub_epi16(p16_4, q8s_4);
- p16_5 = _mm_sub_epi16(p16_5, q8s_5);
- p16_6 = _mm_sub_epi16(p16_6, q8s_6);
- p16_7 = _mm_sub_epi16(p16_7, q8s_7);
-
- // multiply with scales
- __m128i shuffle = _mm_set1_epi16(0x0100);
- p16_0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_0);
- shuffle = _mm_add_epi16(shuffle, m2);
- p16_1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_1);
- shuffle = _mm_add_epi16(shuffle, m2);
- p16_2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_2);
- shuffle = _mm_add_epi16(shuffle, m2);
- p16_3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_3);
- shuffle = _mm_add_epi16(shuffle, m2);
- p16_4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_4);
- shuffle = _mm_add_epi16(shuffle, m2);
- p16_5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_5);
- shuffle = _mm_add_epi16(shuffle, m2);
- p16_6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_6);
- shuffle = _mm_add_epi16(shuffle, m2);
- p16_7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_7);
-
- // accumulate
- p16_0 = _mm_add_epi32(p16_0, p16_1);
- p16_2 = _mm_add_epi32(p16_2, p16_3);
- p16_4 = _mm_add_epi32(p16_4, p16_5);
- p16_6 = _mm_add_epi32(p16_6, p16_7);
- sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
- sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_4, p16_6));
-
- }
-
- // multiply with block scale and accumulate
- __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
- acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
-
- }
-
- *s = hsum_float_8(acc);
-
-#elif defined __riscv_v_intrinsic
-
- uint32_t aux[3];
- uint32_t utmp[4];
-
- float sumf = 0;
- for (int i = 0; i < nb; ++i) {
-
- const uint8_t * restrict q3 = x[i].qs;
- const uint8_t * restrict qh = x[i].hmask;
- const int8_t * restrict q8 = y[i].qs;
-
- memcpy(aux, x[i].scales, 12);
- utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
- utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
- utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
- utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
-
- int8_t * scale = (int8_t *)utmp;
- for (int j = 0; j < 16; ++j) scale[j] -= 32;
-
-
- size_t vl = 32;
- uint8_t m = 1;
-
- vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
- vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl);
-
- int sum_t = 0;
-
- for (int j = 0; j < QK_K; j += 128) {
-
- vl = 32;
-
- // load Q3
- vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl);
-
- vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl));
- vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl));
- vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl));
- vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl));
-
- // compute mask for subtraction
- vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl);
- vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl);
- vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_m(vmask_0, q3_0, 0x4, vl);
- m <<= 1;
-
- vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
- vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl);
- vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_m(vmask_1, q3_1, 0x4, vl);
- m <<= 1;
-
- vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
- vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl);
- vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_m(vmask_2, q3_2, 0x4, vl);
- m <<= 1;
-
- vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl);
- vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl);
- vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_m(vmask_3, q3_3, 0x4, vl);
- m <<= 1;
-
- // load Q8 and take product with Q3
- vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl);
- vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
- vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
- vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
-
- vl = 16;
-
- // retrieve lane to multiply with scale
- vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl);
- vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl);
- vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl);
- vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl);
- vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl);
- vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl);
- vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl);
- vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl);
-
- vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl);
- vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl);
- vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl);
- vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl);
-
- sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);
-
- q3 += 32; q8 += 128; scale += 8;
-
- }
-
- const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
-
- sumf += d*sum_t;
-
- }
-
- *s = sumf;
-
-#elif defined(__POWER9_VECTOR__)
- const vector signed char lowMask = vec_splats((signed char)0x3);
- const vector signed char v1 = vec_splats((signed char)0x1);
- const vector unsigned char v2 = vec_splats((unsigned char)0x2);
- const vector unsigned char v3 = vec_splats((unsigned char)0x3);
- const vector unsigned char v4 = vec_splats((unsigned char)0x4);
- const vector unsigned char v6 = vec_splats((unsigned char)0x6);
- const vector signed char off = vec_splats((signed char)0x20);
-
- vector float vsumf0 = vec_splats(0.0f);
- vector float vsumf1 = vec_splats(0.0f);
- vector float vsumf2 = vec_splats(0.0f);
- vector float vsumf3 = vec_splats(0.0f);
-
- for (int i = 0; i < nb; ++i) {
- vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
- vector float vyd = vec_splats(y[i].d);
- vector float vd = vec_mul(vxd, vyd);
-
- uint32_t aux[3];
- uint32_t utmp[4];
-
- memcpy(aux, x[i].scales, 12);
- utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
- utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
- utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
- utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
-
- vector signed char vscales = (vector signed char)vec_xl( 0, utmp);
- vector signed char qxhs0 = (vector signed char)vec_xl( 0, x[i].hmask);
- vector signed char qxhs1 = (vector signed char)vec_xl(16, x[i].hmask);
-
- vscales = vec_sub(vscales, off);
-
- vector signed int vsumi0 = vec_splats((int32_t)0);
- vector signed int vsumi1 = vec_splats((int32_t)0);
- vector signed int vsumi2 = vec_splats((int32_t)0);
- vector signed int vsumi3 = vec_splats((int32_t)0);
- vector signed int vsumi4 = vec_splats((int32_t)0);
- vector signed int vsumi5 = vec_splats((int32_t)0);
- vector signed int vsumi6 = vec_splats((int32_t)0);
- vector signed int vsumi7 = vec_splats((int32_t)0);
-
-
- const uint8_t * restrict q3 = x[i].qs;
- const int8_t * restrict q8 = y[i].qs;
-
- for (int j = 0; j < QK_K/128; ++j) {
- __builtin_prefetch(q3, 0, 1);
- __builtin_prefetch(q8, 0, 1);
-
- vector signed char qxs0 = (vector signed char)vec_xl( 0, q3);
- vector signed char qxs1 = (vector signed char)vec_xl(16, q3);
- q3 += 32;
-
- //the low 2 bits
- vector signed char qxs00 = vec_and(qxs0, lowMask);
- vector signed char qxs01 = vec_and(vec_sr(qxs0, v2), lowMask);
- vector signed char qxs02 = vec_and(vec_sr(qxs0, v4), lowMask);
- vector signed char qxs03 = vec_and(vec_sr(qxs0, v6), lowMask);
- vector signed char qxs10 = vec_and(qxs1, lowMask);
- vector signed char qxs11 = vec_and(vec_sr(qxs1, v2), lowMask);
- vector signed char qxs12 = vec_and(vec_sr(qxs1, v4), lowMask);
- vector signed char qxs13 = vec_and(vec_sr(qxs1, v6), lowMask);
-
- //the 3rd bit
- vector signed char qxh00 = vec_sl(vec_andc(v1, qxhs0), v2);
- vector signed char qxh01 = vec_sl(vec_andc(v1, vec_sr(qxhs0, (vector unsigned char)v1)), v2);
- vector signed char qxh02 = vec_sl(vec_andc(v1, vec_sr(qxhs0, v2)), v2);
- vector signed char qxh03 = vec_sl(vec_andc(v1, vec_sr(qxhs0, v3)), v2);
- vector signed char qxh10 = vec_sl(vec_andc(v1, qxhs1), v2);
- vector signed char qxh11 = vec_sl(vec_andc(v1, vec_sr(qxhs1, (vector unsigned char)v1)), v2);
- vector signed char qxh12 = vec_sl(vec_andc(v1, vec_sr(qxhs1, v2)), v2);
- vector signed char qxh13 = vec_sl(vec_andc(v1, vec_sr(qxhs1, v3)), v2);
- qxhs0 = vec_sr(qxhs0, v4);
- qxhs1 = vec_sr(qxhs1, v4);
-
- vector signed char q3x00 = vec_sub(qxs00, qxh00);
- vector signed char q3x01 = vec_sub(qxs01, qxh01);
- vector signed char q3x02 = vec_sub(qxs02, qxh02);
- vector signed char q3x03 = vec_sub(qxs03, qxh03);
- vector signed char q3x10 = vec_sub(qxs10, qxh10);
- vector signed char q3x11 = vec_sub(qxs11, qxh11);
- vector signed char q3x12 = vec_sub(qxs12, qxh12);
- vector signed char q3x13 = vec_sub(qxs13, qxh13);
-
- vector signed char q8y00 = vec_xl( 0, q8);
- vector signed char q8y10 = vec_xl( 16, q8);
- vector signed char q8y01 = vec_xl( 32, q8);
- vector signed char q8y11 = vec_xl( 48, q8);
- vector signed char q8y02 = vec_xl( 64, q8);
- vector signed char q8y12 = vec_xl( 80, q8);
- vector signed char q8y03 = vec_xl( 96, q8);
- vector signed char q8y13 = vec_xl(112, q8);
- q8 += 128;
-
- vector signed short vscales_h = vec_unpackh(vscales);
- vector signed short vs0 = vec_splat(vscales_h, 0);
- vector signed short vs1 = vec_splat(vscales_h, 1);
- vector signed short vs2 = vec_splat(vscales_h, 2);
- vector signed short vs3 = vec_splat(vscales_h, 3);
- vector signed short vs4 = vec_splat(vscales_h, 4);
- vector signed short vs5 = vec_splat(vscales_h, 5);
- vector signed short vs6 = vec_splat(vscales_h, 6);
- vector signed short vs7 = vec_splat(vscales_h, 7);
- vscales = vec_sld(vscales, vscales, 8);
-
- vector signed short qv00 = vec_add(vec_mule(q3x00, q8y00), vec_mulo(q3x00, q8y00));
- vector signed short qv01 = vec_add(vec_mule(q3x01, q8y01), vec_mulo(q3x01, q8y01));
- vector signed short qv02 = vec_add(vec_mule(q3x02, q8y02), vec_mulo(q3x02, q8y02));
- vector signed short qv03 = vec_add(vec_mule(q3x03, q8y03), vec_mulo(q3x03, q8y03));
- vector signed short qv10 = vec_add(vec_mule(q3x10, q8y10), vec_mulo(q3x10, q8y10));
- vector signed short qv11 = vec_add(vec_mule(q3x11, q8y11), vec_mulo(q3x11, q8y11));
- vector signed short qv12 = vec_add(vec_mule(q3x12, q8y12), vec_mulo(q3x12, q8y12));
- vector signed short qv13 = vec_add(vec_mule(q3x13, q8y13), vec_mulo(q3x13, q8y13));
-
- vector signed int vsum0 = vec_add(vec_mule(qv00, vs0), vec_mulo(qv00, vs0));
- vector signed int vsum1 = vec_add(vec_mule(qv01, vs2), vec_mulo(qv01, vs2));
- vector signed int vsum2 = vec_add(vec_mule(qv02, vs4), vec_mulo(qv02, vs4));
- vector signed int vsum3 = vec_add(vec_mule(qv03, vs6), vec_mulo(qv03, vs6));
- vector signed int vsum4 = vec_add(vec_mule(qv10, vs1), vec_mulo(qv10, vs1));
- vector signed int vsum5 = vec_add(vec_mule(qv11, vs3), vec_mulo(qv11, vs3));
- vector signed int vsum6 = vec_add(vec_mule(qv12, vs5), vec_mulo(qv12, vs5));
- vector signed int vsum7 = vec_add(vec_mule(qv13, vs7), vec_mulo(qv13, vs7));
-
- vsumi0 = vec_add(vsum0, vsumi0);
- vsumi1 = vec_add(vsum1, vsumi1);
- vsumi2 = vec_add(vsum2, vsumi2);
- vsumi3 = vec_add(vsum3, vsumi3);
- vsumi4 = vec_add(vsum4, vsumi4);
- vsumi5 = vec_add(vsum5, vsumi5);
- vsumi6 = vec_add(vsum6, vsumi6);
- vsumi7 = vec_add(vsum7, vsumi7);
- }
-
- vsumi0 = vec_add(vsumi0, vsumi4);
- vsumi1 = vec_add(vsumi1, vsumi5);
- vsumi2 = vec_add(vsumi2, vsumi6);
- vsumi3 = vec_add(vsumi3, vsumi7);
-
- vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
- vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
- vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
- vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
- }
-
- vsumf0 = vec_add(vsumf0, vsumf2);
- vsumf1 = vec_add(vsumf1, vsumf3);
-
- vsumf0 = vec_add(vsumf0, vsumf1);
-
- vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
- vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
-
- *s = vec_extract(vsumf0, 0);
-
-#elif defined __loongarch_asx
-
- const __m256i m3 = __lasx_xvreplgr2vr_b(3);
- const __m256i mone = __lasx_xvreplgr2vr_b(1);
- const __m128i m32 = __lsx_vreplgr2vr_b(32);
-
- __m256 acc = (__m256)__lasx_xvldi(0);
-
- uint32_t aux[3];
-
- for (int i = 0; i < nb; ++i) {
-
- const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
- // Set up scales
- memcpy(aux, x[i].scales, 12);
- __m128i scales128 = lsx_set_w(
- ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
- ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
- (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
- (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
- scales128 = __lsx_vsub_b(scales128, m32);
- const __m256i all_scales = lasx_ext8_16(scales128);
- const __m128i l_scales = lasx_extracti128(all_scales, 0);
- const __m128i h_scales = lasx_extracti128(all_scales, 1);
- const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)};
-
- // high bit
- const __m256i hbits = __lasx_xvld((const __m256i*)x[i].hmask, 0);
-
- // integer accumulator
- __m256i sumi = __lasx_xvldi(0);
-
- int bit = 0;
- int is = 0;
-
- const uint8_t * restrict q3 = x[i].qs;
- const int8_t * restrict q8 = y[i].qs;
-
- for (int j = 0; j < QK_K/128; ++j) {
- // load low 2 bits
- const __m256i q3bits = __lasx_xvld((const __m256i*)q3, 0); q3 += 32;
-
- // prepare low and high bits
- const __m256i q3l_0 = __lasx_xvand_v(q3bits, m3);
- const __m256i q3h_0 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvandn_v(hbits, __lasx_xvslli_h(mone, bit)), bit), 2);
- ++bit;
-
- const __m256i q3l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 2), m3);
- const __m256i q3h_1 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvandn_v(hbits, __lasx_xvslli_h(mone, bit)), bit), 2);
- ++bit;
-
- const __m256i q3l_2 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 4), m3);
- const __m256i q3h_2 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvandn_v(hbits, __lasx_xvslli_h(mone, bit)), bit), 2);
- ++bit;
-
- const __m256i q3l_3 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 6), m3);
- const __m256i q3h_3 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvandn_v(hbits, __lasx_xvslli_h(mone, bit)), bit), 2);
- ++bit;
-
- // load Q8 quants
- const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
- const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
- const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
- const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
-
- // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use lasx_maddubs_h,
- // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
- // and 2 if the high bit was set)
- __m256i q8s_0 = lasx_maddubs_h(q3h_0, q8_0);
- __m256i q8s_1 = lasx_maddubs_h(q3h_1, q8_1);
- __m256i q8s_2 = lasx_maddubs_h(q3h_2, q8_2);
- __m256i q8s_3 = lasx_maddubs_h(q3h_3, q8_3);
-
- __m256i p16_0 = lasx_maddubs_h(q3l_0, q8_0);
- __m256i p16_1 = lasx_maddubs_h(q3l_1, q8_1);
- __m256i p16_2 = lasx_maddubs_h(q3l_2, q8_2);
- __m256i p16_3 = lasx_maddubs_h(q3l_3, q8_3);
-
- p16_0 = __lasx_xvsub_h(p16_0, q8s_0);
- p16_1 = __lasx_xvsub_h(p16_1, q8s_1);
- p16_2 = __lasx_xvsub_h(p16_2, q8s_2);
- p16_3 = __lasx_xvsub_h(p16_3, q8s_3);
-
- // multiply with scales
- p16_0 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0);
- p16_1 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1);
- p16_2 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2);
- p16_3 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3);
-
- // accumulate
- p16_0 = __lasx_xvadd_w(p16_0, p16_1);
- p16_2 = __lasx_xvadd_w(p16_2, p16_3);
- sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_2));
- }
- // multiply with block scale and accumulate
- acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);//FIXME
- }
-
- *s = hsum_float_8(acc);
-
-#else
- // scalar version
- // This function is written like this so the compiler can manage to vectorize most of it
- // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the
- // manually vectorized version above. Every other version I tried would run at least 4 times slower.
- // The ideal situation would be if we could just write the code once, and the compiler would
- // automatically produce the best possible set of machine instructions, instead of us having to manually
- // write vectorized versions for AVX, ARM_NEON, etc.
-
- int8_t aux8[QK_K];
- int16_t aux16[8];
- float sums [8];
- int32_t aux32[8];
- memset(sums, 0, 8*sizeof(float));
-
- uint32_t auxs[4];
- const int8_t * scales = (const int8_t*)auxs;
-
- float sumf = 0;
- for (int i = 0; i < nb; ++i) {
- const uint8_t * restrict q3 = x[i].qs;
- const uint8_t * restrict hm = x[i].hmask;
- const int8_t * restrict q8 = y[i].qs;
- memset(aux32, 0, 8*sizeof(int32_t));
- int8_t * restrict a = aux8;
- uint8_t m = 1;
- for (int j = 0; j < QK_K; j += 128) {
- for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3;
- for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
- a += 32; m <<= 1;
- for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3;
- for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
- a += 32; m <<= 1;
- for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3;
- for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
- a += 32; m <<= 1;
- for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3;
- for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
- a += 32; m <<= 1;
- q3 += 32;
- }
- a = aux8;
-
- memcpy(auxs, x[i].scales, 12);
- uint32_t tmp = auxs[2];
- auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
- auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
- auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
- auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
- for (int j = 0; j < QK_K/16; ++j) {
- for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
- for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
- q8 += 8; a += 8;
- for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
- for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
- q8 += 8; a += 8;
- }
- const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
- for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
- }
- for (int l = 0; l < 8; ++l) sumf += sums[l];
- *s = sumf;
-
-#endif
-
-}
-
-#else
-
-void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
- assert(n % QK_K == 0);
- assert(nrc == 1);
- UNUSED(nrc);
- UNUSED(bx);
- UNUSED(by);
- UNUSED(bs);
-
- const block_q3_K * restrict x = vx;
- const block_q8_K * restrict y = vy;
-
- const int nb = n / QK_K;
-
-#ifdef __ARM_NEON
- const int32x4_t vzero = vdupq_n_s32(0);
-
- const uint8x16_t m3b = vdupq_n_u8(0x3);
- const uint8x16_t mh = vdupq_n_u8(4);
-
- ggml_int8x16x4_t q3bytes;
-
- uint16_t aux16[2];
- int8_t * scales = (int8_t *)aux16;
-
- float sum = 0;
-
- for (int i = 0; i < nb; ++i) {
-
- ggml_uint8x16x4_t q3h;
-
- const uint8x8_t hbits = vld1_u8(x[i].hmask);
- const uint8x16_t q3bits = vld1q_u8(x[i].qs);
- const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(y[i].qs);
-
- const uint16_t a = *(const uint16_t *)x[i].scales;
- aux16[0] = a & 0x0f0f;
- aux16[1] = (a >> 4) & 0x0f0f;
-
- for (int j = 0; j < 4; ++j) scales[j] -= 8;
-
- int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]);
-
- const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
-
- const uint8x16_t htmp = vcombine_u8(hbits, vshr_n_u8(hbits, 1));
- q3h.val[0] = vandq_u8(mh, vshlq_n_u8(htmp, 2));
- q3h.val[1] = vandq_u8(mh, htmp);
- q3h.val[2] = vandq_u8(mh, vshrq_n_u8(htmp, 2));
- q3h.val[3] = vandq_u8(mh, vshrq_n_u8(htmp, 4));
-
- q3bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q3bits, m3b), q3h.val[0]));
- q3bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 2), m3b), q3h.val[1]));
- q3bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2]));
- q3bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q3bits, 6), q3h.val[3]));
-
- isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0];
- isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2];
- isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1];
- isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3];
-
- sum += d * isum;
-
- }
-
- *s = sum;
-
-#elif defined __AVX2__
-
- const __m256i m3 = _mm256_set1_epi8(3);
- const __m256i m1 = _mm256_set1_epi8(1);
-
- __m256 acc = _mm256_setzero_ps();
-
- uint64_t aux64;
-
- uint16_t aux16[2];
- const int8_t * aux8 = (const int8_t *)aux16;
-
- for (int i = 0; i < nb; ++i) {
-
- const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
-
- const uint8_t * restrict q3 = x[i].qs;
- const int8_t * restrict q8 = y[i].qs;
-
- const uint16_t a = *(const uint16_t *)x[i].scales;
- aux16[0] = a & 0x0f0f;
- aux16[1] = (a >> 4) & 0x0f0f;
-
- const __m256i scale_0 = MM256_SET_M128I(_mm_set1_epi16(aux8[2] - 8), _mm_set1_epi16(aux8[0] - 8));
- const __m256i scale_1 = MM256_SET_M128I(_mm_set1_epi16(aux8[3] - 8), _mm_set1_epi16(aux8[1] - 8));
-
- memcpy(&aux64, x[i].hmask, 8);
-
- const __m128i haux = _mm_set_epi64x(aux64 >> 1, aux64 >> 0);
- __m256i q3h_0 = MM256_SET_M128I(_mm_srli_epi16(haux, 2), haux);
- __m256i q3h_1 = _mm256_srli_epi16(q3h_0, 4);
- q3h_0 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_0, m1), 2);
- q3h_1 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_1, m1), 2);
-
- // load low 2 bits
- const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3);
-
- // prepare low and high bits
- const __m256i q3aux = MM256_SET_M128I(_mm_srli_epi16(q3bits, 2), q3bits);
- const __m256i q3l_0 = _mm256_and_si256(q3aux, m3);
- const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3aux, 4), m3);
-
- // load Q8 quants
- const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
- const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
-
- // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16,
- // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
- // and 2 if the high bit was set)
- const __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0);
- const __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1);
-
- __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0);
- __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1);
-
- p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
- p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
-
- // multiply with scales
- p16_0 = _mm256_madd_epi16(scale_0, p16_0);
- p16_1 = _mm256_madd_epi16(scale_1, p16_1);
-
- p16_0 = _mm256_add_epi32(p16_0, p16_1);
-
- // multiply with block scale and accumulate
- acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16_0), acc);
-
- }
-
- *s = hsum_float_8(acc);
-
-#elif defined __AVX__
-
- const __m128i m3 = _mm_set1_epi8(3);
- const __m128i m1 = _mm_set1_epi8(1);
-
- __m256 acc = _mm256_setzero_ps();
-
- uint64_t aux64;
-
- uint16_t aux16[2];
- const int8_t * aux8 = (const int8_t *)aux16;
-
- for (int i = 0; i < nb; ++i) {
-
- const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
-
- const uint8_t * restrict q3 = x[i].qs;
- const int8_t * restrict q8 = y[i].qs;
-
- const uint16_t a = *(const uint16_t *)x[i].scales;
- aux16[0] = a & 0x0f0f;
- aux16[1] = (a >> 4) & 0x0f0f;
-
- const __m128i scale_0 = _mm_set1_epi16(aux8[0] - 8);
- const __m128i scale_1 = _mm_set1_epi16(aux8[2] - 8);
- const __m128i scale_2 = _mm_set1_epi16(aux8[1] - 8);
- const __m128i scale_3 = _mm_set1_epi16(aux8[3] - 8);
-
- memcpy(&aux64, x[i].hmask, 8);
-
- __m128i q3h_0 = _mm_set_epi64x(aux64 >> 1, aux64 >> 0);
- __m128i q3h_1 = _mm_srli_epi16(q3h_0, 2);
- __m128i q3h_2 = _mm_srli_epi16(q3h_0, 4);
- __m128i q3h_3 = _mm_srli_epi16(q3h_0, 6);
- q3h_0 = _mm_slli_epi16(_mm_andnot_si128(q3h_0, m1), 2);
- q3h_1 = _mm_slli_epi16(_mm_andnot_si128(q3h_1, m1), 2);
- q3h_2 = _mm_slli_epi16(_mm_andnot_si128(q3h_2, m1), 2);
- q3h_3 = _mm_slli_epi16(_mm_andnot_si128(q3h_3, m1), 2);
-
- // load low 2 bits
- const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3);
-
- // prepare low and high bits
- const __m128i q3l_0 = _mm_and_si128(q3bits, m3);
- const __m128i q3l_1 = _mm_and_si128(_mm_srli_epi16(q3bits, 2), m3);
- const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits, 4), m3);
- const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits, 6), m3);
-
- // load Q8 quants
- const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
- const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
-
- // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm_maddubs_epi16,
- // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
- // and 2 if the high bit was set)
- const __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, _mm256_extractf128_si256(q8_0, 0));
- const __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, _mm256_extractf128_si256(q8_0, 1));
- const __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, _mm256_extractf128_si256(q8_1, 0));
- const __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, _mm256_extractf128_si256(q8_1, 1));
-
- __m128i p16_0 = _mm_maddubs_epi16(q3l_0, _mm256_extractf128_si256(q8_0, 0));
- __m128i p16_1 = _mm_maddubs_epi16(q3l_1, _mm256_extractf128_si256(q8_0, 1));
- __m128i p16_2 = _mm_maddubs_epi16(q3l_2, _mm256_extractf128_si256(q8_1, 0));
- __m128i p16_3 = _mm_maddubs_epi16(q3l_3, _mm256_extractf128_si256(q8_1, 1));
-
- p16_0 = _mm_sub_epi16(p16_0, q8s_0);
- p16_1 = _mm_sub_epi16(p16_1, q8s_1);
- p16_2 = _mm_sub_epi16(p16_2, q8s_2);
- p16_3 = _mm_sub_epi16(p16_3, q8s_3);
-
- // multiply with scales
- p16_0 = _mm_madd_epi16(scale_0, p16_0);
- p16_1 = _mm_madd_epi16(scale_1, p16_1);
- p16_2 = _mm_madd_epi16(scale_2, p16_2);
- p16_3 = _mm_madd_epi16(scale_3, p16_3);
-
- p16_0 = _mm_add_epi32(p16_0, p16_2);
- p16_1 = _mm_add_epi32(p16_1, p16_3);
- __m256i p16 = MM256_SET_M128I(p16_1, p16_0);
-
- // multiply with block scale and accumulate
- acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16)), acc);
-
- }
-
- *s = hsum_float_8(acc);
-
-#elif defined __riscv_v_intrinsic
-
- uint16_t aux16[2];
- int8_t * scales = (int8_t *)aux16;
-
- float sumf = 0;
-
- for (int i = 0; i < nb; ++i) {
-
- const uint8_t * restrict q3 = x[i].qs;
- const int8_t * restrict q8 = y[i].qs;
-
- const uint16_t a = *(const uint16_t *)x[i].scales;
- aux16[0] = a & 0x0f0f;
- aux16[1] = (a >> 4) & 0x0f0f;
-
- for (int j = 0; j < 4; ++j) scales[j] -= 8;
-
- int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]);
-
- const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
-
- vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
-
- // load qh
- vuint8mf4_t qh_x1 = __riscv_vle8_v_u8mf4(x[i].hmask, 8);
- vuint8mf2_t qh_x2 = __riscv_vlmul_ext_v_u8mf4_u8mf2(__riscv_vsrl_vx_u8mf4(qh_x1, 1, 8));
-
- size_t vl = 16;
-
- // extend and combine both qh_x1 and qh_x2
- vuint8mf2_t qh_x = __riscv_vslideup_vx_u8mf2(__riscv_vlmul_ext_v_u8mf4_u8mf2(qh_x1), qh_x2, vl/2, vl);
-
- vuint8mf2_t qh_0 = __riscv_vand_vx_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x2, vl), 0x4, vl);
- vuint8mf2_t qh_1 = __riscv_vand_vx_u8mf2(qh_x, 0x4, vl);
- vuint8mf2_t qh_2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl), 0x4, vl);
- vuint8mf2_t qh_3 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x4, vl), 0x4, vl);
-
- // load Q3
- vuint8mf2_t q3_x = __riscv_vle8_v_u8mf2(q3, vl);
-
- vuint8mf2_t q3h_0 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q3_x, 0x3, vl), qh_0, vl);
- vuint8mf2_t q3h_1 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 2, vl), 0x3, vl), qh_1, vl);
- vuint8mf2_t q3h_2 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 4, vl), 0x3, vl), qh_2, vl);
- vuint8mf2_t q3h_3 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 0x6, vl), qh_3, vl);
-
- vint8mf2_t q3_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_0);
- vint8mf2_t q3_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_1);
- vint8mf2_t q3_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_2);
- vint8mf2_t q3_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_3);
-
- // load Q8 and take product with Q3
- vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q3_0, __riscv_vle8_v_i8mf2(q8, vl), vl);
- vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q3_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl);
- vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q3_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl);
- vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q3_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl);
-
- vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl);
- vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl);
- vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl);
- vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl);
-
- isum += __riscv_vmv_x_s_i32m1_i32(vs_0) * scales[0];
- isum += __riscv_vmv_x_s_i32m1_i32(vs_1) * scales[2];
- isum += __riscv_vmv_x_s_i32m1_i32(vs_2) * scales[1];
- isum += __riscv_vmv_x_s_i32m1_i32(vs_3) * scales[3];
-
- sumf += d * isum;
-
- }
-
- *s = sumf;
-
-#elif defined(__POWER9_VECTOR__)
- const vector signed char lowMask = vec_splats((signed char)0x3);
- const vector signed char v1 = vec_splats((signed char)0x1);
- const vector unsigned char v2 = vec_splats((unsigned char)0x2);
- const vector unsigned char v4 = vec_splats((unsigned char)0x4);
- const vector unsigned char v6 = vec_splats((unsigned char)0x6);
- const vector signed char off = vec_splats((signed char)0x8);
-
- vector float vsumf0 = vec_splats(0.0f);
- vector float vsumf1 = vec_splats(0.0f);
- vector float vsumf2 = vec_splats(0.0f);
- vector float vsumf3 = vec_splats(0.0f);
-
-#pragma GCC unroll 2
- for (int i = 0; i < nb; ++i) {
- __builtin_prefetch(x[i].qs, 0, 1);
- __builtin_prefetch(y[i].qs, 0, 1);
-
- vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
- vector float vyd = vec_splats(y[i].d);
- vector float vd = vec_mul(vxd, vyd);
-
- uint16_t aux16[2];
- int8_t * scales = (int8_t *)aux16;
-
- const uint16_t a = *(const uint16_t *)x[i].scales;
- aux16[0] = a & 0x0f0f;
- aux16[1] = (a >> 4) & 0x0f0f;
-
- vector signed char vscales = (vector signed char)vec_xl_len(scales, 8);
- vector signed char qxhs0 = (vector signed char)vec_xl_len(x[i].hmask, 8);
- qxhs0 = vec_or(qxhs0, vec_sr(vec_sld(qxhs0, qxhs0, 8), (vector unsigned char)v1));
-
- vscales = vec_sub(vscales, off);
-
- vector signed char qxs0 = (vector signed char)vec_xl( 0, x[i].qs);
- vector signed char qxs00 = vec_and(qxs0, lowMask);
- vector signed char qxs01 = vec_and(vec_sr(qxs0, v2), lowMask);
- vector signed char qxs10 = vec_and(vec_sr(qxs0, v4), lowMask);
- vector signed char qxs11 = vec_and(vec_sr(qxs0, v6), lowMask);
-
- //the 3rd bit
- vector signed char qxh00 = vec_sl(vec_andc(v1, qxhs0), v2);
- vector signed char qxh01 = vec_sl(vec_andc(v1, vec_sr(qxhs0, v2)), v2);
- vector signed char qxh02 = vec_sl(vec_andc(v1, vec_sr(qxhs0, v4)), v2);
- vector signed char qxh03 = vec_sl(vec_andc(v1, vec_sr(qxhs0, v6)), v2);
- qxhs0 = vec_sr(qxhs0, v4);
-
- vector signed char q3x00 = vec_sub(qxs00, qxh00);
- vector signed char q3x01 = vec_sub(qxs01, qxh01);
- vector signed char q3x10 = vec_sub(qxs10, qxh02);
- vector signed char q3x11 = vec_sub(qxs11, qxh03);
-
- vector signed char q8y00 = vec_xl( 0, y[i].qs);
- vector signed char q8y01 = vec_xl( 16, y[i].qs);
- vector signed char q8y10 = vec_xl( 32, y[i].qs);
- vector signed char q8y11 = vec_xl( 48, y[i].qs);
-
- vector signed short vscales_h = vec_unpackh(vscales);
- vector signed short vs0 = vec_splat(vscales_h, 0);
- vector signed short vs1 = vec_splat(vscales_h, 1);
- vector signed short vs2 = vec_splat(vscales_h, 2);
- vector signed short vs3 = vec_splat(vscales_h, 3);
-
- vector signed short qv00 = vec_add(vec_mule(q3x00, q8y00), vec_mulo(q3x00, q8y00));
- vector signed short qv10 = vec_add(vec_mule(q3x10, q8y10), vec_mulo(q3x10, q8y10));
- vector signed short qv01 = vec_add(vec_mule(q3x01, q8y01), vec_mulo(q3x01, q8y01));
- vector signed short qv11 = vec_add(vec_mule(q3x11, q8y11), vec_mulo(q3x11, q8y11));
-
- vector signed int vsumi0 = vec_add(vec_mule(qv00, vs0), vec_mulo(qv00, vs0));
- vector signed int vsumi1 = vec_add(vec_mule(qv10, vs1), vec_mulo(qv10, vs1));
- vector signed int vsumi2 = vec_add(vec_mule(qv01, vs2), vec_mulo(qv01, vs2));
- vector signed int vsumi3 = vec_add(vec_mule(qv11, vs3), vec_mulo(qv11, vs3));
-
- vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
- vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
- vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
- vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
- }
-
- vsumf0 = vec_add(vsumf0, vsumf2);
- vsumf1 = vec_add(vsumf1, vsumf3);
-
- vsumf0 = vec_add(vsumf0, vsumf1);
-
- vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
- vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
-
- *s = vec_extract(vsumf0, 0);
-
-#elif defined __loongarch_asx
-
- const __m256i m3 = __lasx_xvreplgr2vr_b(3);
- const __m256i m1 = __lasx_xvreplgr2vr_b(1);
-
- __m256 acc = (__m256)__lasx_xvldi(0);
-
- uint64_t aux64;
-
- uint16_t aux16[2];
- const int8_t * aux8 = (const int8_t *)aux16;
-
- for (int i = 0; i < nb; ++i) {
-
- const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
-
- const uint8_t * restrict q3 = x[i].qs;
- const int8_t * restrict q8 = y[i].qs;
- const __m256i scale_0 = lasx_insertf128(__lasx_xvreplgr2vr_h(aux8[2] - 8), __lasx_xvreplgr2vr_h(aux8[0] - 8));
- const __m256i scale_1 = lasx_insertf128(__lasx_xvreplgr2vr_h(aux8[3] - 8), __lasx_xvreplgr2vr_h(aux8[1] - 8));
-
- memcpy(&aux64, x[i].hmask, 8);
-
- __m128i haux = __lsx_vinsgr2vr_d(haux, aux64, 0);
- haux = __lsx_vinsgr2vr_d(haux, aux64 >> 1, 1);
- __m256i q3h_0 = lasx_insertf128(__lsx_vsrli_h(haux, 2), haux);
- __m256i q3h_1 = __lasx_xvsrli_h(q3h_0, 4);
- q3h_0 = __lasx_xvslli_h(__lasx_xvandn_v(q3h_0, m1), 2);
- q3h_1 = __lasx_xvslli_h(__lasx_xvandn_v(q3h_1, m1), 2);
-
- // load low 2 bits
- const __m128i q3bits = __lsx_vld((const __m128i*)q3, 0);
-
- // prepare low and high bits
- const __m256i q3aux = lasx_insertf128(__lsx_vsrli_h(q3bits, 2), q3bits);
- const __m256i q3l_0 = __lasx_xvand_v(q3aux, m3);
- const __m256i q3l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q3aux, 4), m3);
-
- // load Q8 quants
- const __m256i q8_0 = __lasx_xvld((const __m256i*)(q8+ 0), 0);
- const __m256i q8_1 = __lasx_xvld((const __m256i*)(q8+32), 0);
-
- // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use lasx_maddubs_h,
- // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
- // and 2 if the high bit was set)
- const __m256i q8s_0 = lasx_maddubs_h(q3h_0, q8_0);
- const __m256i q8s_1 = lasx_maddubs_h(q3h_1, q8_1);
-
- __m256i p16_0 = lasx_maddubs_h(q3l_0, q8_0);
- __m256i p16_1 = lasx_maddubs_h(q3l_1, q8_1);
-
- p16_0 = __lasx_xvsub_h(p16_0, q8s_0);
- p16_1 = __lasx_xvsub_h(p16_1, q8s_1);
-
- // multiply with scales
- p16_0 = lasx_madd_h(scale_0, p16_0);
- p16_1 = lasx_madd_h(scale_1, p16_1);
-
- p16_0 = __lasx_xvadd_w(p16_0, p16_1);
-
- // multiply with block scale and accumulate
- acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(p16_0), acc);
- }
-
- *s = hsum_float_8(acc);
-
-#else
-
- int8_t aux8[QK_K];
- int16_t aux16[8];
- float sums [8];
- int32_t aux32[8];
- int32_t scales[4];
- memset(sums, 0, 8*sizeof(float));
-
- float sumf = 0;
- for (int i = 0; i < nb; ++i) {
- const uint8_t * restrict q3 = x[i].qs;
- const uint8_t * restrict hm = x[i].hmask;
- const int8_t * restrict q8 = y[i].qs;
- int8_t * restrict a = aux8;
- for (int l = 0; l < 8; ++l) {
- a[l+ 0] = (int8_t)((q3[l+0] >> 0) & 3) - (hm[l] & 0x01 ? 0 : 4);
- a[l+ 8] = (int8_t)((q3[l+8] >> 0) & 3) - (hm[l] & 0x02 ? 0 : 4);
- a[l+16] = (int8_t)((q3[l+0] >> 2) & 3) - (hm[l] & 0x04 ? 0 : 4);
- a[l+24] = (int8_t)((q3[l+8] >> 2) & 3) - (hm[l] & 0x08 ? 0 : 4);
- a[l+32] = (int8_t)((q3[l+0] >> 4) & 3) - (hm[l] & 0x10 ? 0 : 4);
- a[l+40] = (int8_t)((q3[l+8] >> 4) & 3) - (hm[l] & 0x20 ? 0 : 4);
- a[l+48] = (int8_t)((q3[l+0] >> 6) & 3) - (hm[l] & 0x40 ? 0 : 4);
- a[l+56] = (int8_t)((q3[l+8] >> 6) & 3) - (hm[l] & 0x80 ? 0 : 4);
- }
-
- scales[0] = (x[i].scales[0] & 0xF) - 8;
- scales[1] = (x[i].scales[0] >> 4) - 8;
- scales[2] = (x[i].scales[1] & 0xF) - 8;
- scales[3] = (x[i].scales[1] >> 4) - 8;
-
- memset(aux32, 0, 8*sizeof(int32_t));
- for (int j = 0; j < QK_K/16; ++j) {
- for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
- q8 += 8; a += 8;
- for (int l = 0; l < 8; ++l) aux16[l] += q8[l] * a[l];
- q8 += 8; a += 8;
- for (int l = 0; l < 8; ++l) aux32[l] += scales[j] * aux16[l];
- }
- const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
- for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
- }
- for (int l = 0; l < 8; ++l) sumf += sums[l];
- *s = sumf;
-
-#endif
-
-}
-#endif
-
-#if QK_K == 256
-void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
- assert(n % QK_K == 0);
- assert(nrc == 1);
- UNUSED(nrc);
- UNUSED(bx);
- UNUSED(by);
- UNUSED(bs);
-
- const block_q4_K * restrict x = vx;
- const block_q8_K * restrict y = vy;
-
- const int nb = n / QK_K;
-
- static const uint32_t kmask1 = 0x3f3f3f3f;
- static const uint32_t kmask2 = 0x0f0f0f0f;
- static const uint32_t kmask3 = 0x03030303;
-
- uint32_t utmp[4];
-
-#ifdef __ARM_NEON
- const uint8x16_t m4b = vdupq_n_u8(0xf);
- const int32x4_t mzero = vdupq_n_s32(0);
-
- ggml_int8x16x2_t q4bytes;
- ggml_int8x16x2_t q8bytes;
-
- float sumf = 0;
-
- for (int i = 0; i < nb; ++i) {
-
- const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
- const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
-
- const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
-
- memcpy(utmp, x[i].scales, 12);
-
- uint32x2_t mins8 = { 0 };
- mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0);
- mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1);
-
- utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
- utmp[0] &= kmask1;
-
- const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
- const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
- vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
- sumf -= dmin * vaddvq_s32(prod);
-
- const uint8_t * scales = (const uint8_t *)utmp;
-
- const uint8_t * restrict q4 = x[i].qs;
- const int8_t * restrict q8 = y[i].qs;
-
- int32_t sumi1 = 0;
- int32_t sumi2 = 0;
-
- for (int j = 0; j < QK_K/64; ++j) {
- const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); q4 += 32;
-
- q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
- q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
- q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
-
- const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
- sumi1 += vaddvq_s32(p1) * scales[2*j+0];
-
- q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
- q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
- q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
-
- const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
-
- sumi2 += vaddvq_s32(p2) * scales[2*j+1];
- }
-
- sumf += d * (sumi1 + sumi2);
-
- }
-
- *s = sumf;
-
-#elif defined __AVX2__
-
- const __m256i m4 = _mm256_set1_epi8(0xF);
-
- __m256 acc = _mm256_setzero_ps();
- __m128 acc_m = _mm_setzero_ps();
-
- for (int i = 0; i < nb; ++i) {
-
- const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
- const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
-
- memcpy(utmp, x[i].scales, 12);
- utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
- const uint32_t uaux = utmp[1] & kmask1;
- utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
- utmp[2] = uaux;
- utmp[0] &= kmask1;
-
- const uint8_t * restrict q4 = x[i].qs;
- const int8_t * restrict q8 = y[i].qs;
-
- const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
-
- const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
- const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
- const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
- acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m);
-
- const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
- const __m256i scales = MM256_SET_M128I(sc128, sc128);
-
- __m256i sumi = _mm256_setzero_si256();
-
- for (int j = 0; j < QK_K/64; ++j) {
-
- const __m256i scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0));
- const __m256i scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1));
-
- const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
- const __m256i q4l = _mm256_and_si256(q4bits, m4);
- const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);
-
- const __m256i q8l = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
- __m256i p16l = _mm256_maddubs_epi16(q4l, q8l);
- p16l = _mm256_madd_epi16(scale_l, p16l);
-
- const __m256i q8h = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
- __m256i p16h = _mm256_maddubs_epi16(q4h, q8h);
- p16h = _mm256_madd_epi16(scale_h, p16h);
- const __m256i sumj = _mm256_add_epi32(p16l, p16h);
-
- sumi = _mm256_add_epi32(sumi, sumj);
- }
-
- __m256 vd = _mm256_set1_ps(d);
- acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
-
- }
-
- acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
- acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
-
- *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);
-
-#elif defined __AVX__
-
- const __m128i m4 = _mm_set1_epi8(0xF);
- const __m128i m2 = _mm_set1_epi8(0x2);
-
- __m256 acc = _mm256_setzero_ps();
- __m128 acc_m = _mm_setzero_ps();
-
- for (int i = 0; i < nb; ++i) {
-
- const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
- const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
-
- const uint8_t * restrict q4 = x[i].qs;
- const int8_t * restrict q8 = y[i].qs;
-
- memcpy(utmp, x[i].scales, 12);
- utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
- const uint32_t uaux = utmp[1] & kmask1;
- utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
- utmp[2] = uaux;
- utmp[0] &= kmask1;
-
- const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]);
- const __m128i scales = _mm_cvtepu8_epi16(utmps);
- const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps));
-
- const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]);
- const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]);
- const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1);
- const __m128i prod = _mm_madd_epi16(mins, q8s);
- acc_m = _mm_add_ps(_mm_mul_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod)), acc_m);
-
- __m128i sumi_0 = _mm_setzero_si128();
- __m128i sumi_1 = _mm_setzero_si128();
-
- __m128i shuffle = _mm_set1_epi16(0x0100);
- for (int j = 0; j < QK_K/64; ++j) {
-
- const __m128i scale_l = _mm_shuffle_epi8(scales, shuffle);
- shuffle = _mm_add_epi16(shuffle, m2);
- const __m128i scale_h = _mm_shuffle_epi8(scales, shuffle);
- shuffle = _mm_add_epi16(shuffle, m2);
-
- __m128i q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
- const __m128i q4l_0 = _mm_and_si128(q4bits, m4);
- const __m128i q4h_0 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4);
- q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
- const __m128i q4l_1 = _mm_and_si128(q4bits, m4);
- const __m128i q4h_1 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4);
-
- const __m128i q8l_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
- __m128i p16l = _mm_maddubs_epi16(q4l_0, q8l_0);
- p16l = _mm_madd_epi16(scale_l, p16l);
- sumi_0 = _mm_add_epi32(sumi_0, p16l);
- const __m128i q8l_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
- p16l = _mm_maddubs_epi16(q4l_1, q8l_1);
- p16l = _mm_madd_epi16(scale_l, p16l);
- sumi_1 = _mm_add_epi32(sumi_1, p16l);
-
- const __m128i q8h_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
- __m128i p16h = _mm_maddubs_epi16(q4h_0, q8h_0);
- p16h = _mm_madd_epi16(scale_h, p16h);
- sumi_0 = _mm_add_epi32(sumi_0, p16h);
- const __m128i q8h_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
- p16h = _mm_maddubs_epi16(q4h_1, q8h_1);
- p16h = _mm_madd_epi16(scale_h, p16h);
- sumi_1 = _mm_add_epi32(sumi_1, p16h);
-
- }
-
- __m256 vd = _mm256_set1_ps(d);
- __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
- acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
-
- }
-
- acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
- acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
-
- *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);
-
-#elif defined __riscv_v_intrinsic
-
- const uint8_t * scales = (const uint8_t*)&utmp[0];
- const uint8_t * mins = (const uint8_t*)&utmp[2];
-
- float sumf = 0;
-
- for (int i = 0; i < nb; ++i) {
-
- size_t vl = 8;
-
- const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
- const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
-
- vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);
- vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl);
- vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl);
-
- memcpy(utmp, x[i].scales, 12);
- utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
- const uint32_t uaux = utmp[1] & kmask1;
- utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
- utmp[2] = uaux;
- utmp[0] &= kmask1;
-
- vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl);
- vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl));
- vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl);
-
- vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
- sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);
-
- const uint8_t * restrict q4 = x[i].qs;
- const int8_t * restrict q8 = y[i].qs;
-
- vl = 32;
-
- int32_t sum_1 = 0;
- int32_t sum_2 = 0;
-
- vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1);
-
- for (int j = 0; j < QK_K/64; ++j) {
- // load Q4
- vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl);
-
- // load Q8 and multiply it with lower Q4 nibble
- vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl);
- vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl));
- vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl);
- vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl);
-
- sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0];
-
- // load Q8 and multiply it with upper Q4 nibble
- vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl);
- vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl));
- vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl);
- vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl);
-
- sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1];
-
- q4 += 32; q8 += 64;
-
- }
-
- sumf += d*(sum_1 + sum_2);
-
- }
-
- *s = sumf;
-
-#elif defined(__POWER9_VECTOR__)
- const vector signed char lowMask = vec_splats((signed char)0xF);
- const vector unsigned char v4 = vec_splats((unsigned char)0x4);
-
- vector float vsumf0 = vec_splats(0.0f);
- vector float vsumf1 = vec_splats(0.0f);
- vector float vsumf2 = vec_splats(0.0f);
- vector float vsumf3 = vec_splats(0.0f);
-
- for (int i = 0; i < nb; ++i) {
- vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
- vector float vyd = vec_splats(y[i].d);
- vector float vd = vec_mul(vxd, vyd);
-
- vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[i].dmin));
- vector float vdmin = vec_mul(vxmin, vyd);
-
- vector signed short q8ysums0 = vec_xl( 0, y[i].bsums);
- vector signed short q8ysums1 = vec_xl(16, y[i].bsums);
-
- memcpy(utmp, x[i].scales, 12);
-
- utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
- const uint32_t uaux = utmp[1] & kmask1;
- utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
- utmp[2] = uaux;
- utmp[0] &= kmask1;
-
- vector signed char utmps = (vector signed char)vec_xl( 0, utmp);
- vector signed short vscales = vec_unpackh(utmps);
- vector signed short q4xmins = vec_unpackl(utmps);
- vector signed short q4xmins0 = vec_mergeh(q4xmins, q4xmins);
- vector signed short q4xmins1 = vec_mergel(q4xmins, q4xmins);
-
- vector signed int prod0 = vec_mule(q4xmins0, q8ysums0);
- vector signed int prod1 = vec_mule(q4xmins1, q8ysums1);
- vector signed int prod2 = vec_mulo(q4xmins0, q8ysums0);
- vector signed int prod3 = vec_mulo(q4xmins1, q8ysums1);
-
- vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0);
- vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1);
- vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2);
- vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3);
-
- vector signed int vsumi0 = vec_splats((int32_t)0);
- vector signed int vsumi1 = vec_splats((int32_t)0);
- vector signed int vsumi2 = vec_splats((int32_t)0);
- vector signed int vsumi3 = vec_splats((int32_t)0);
- vector signed int vsumi4 = vec_splats((int32_t)0);
- vector signed int vsumi5 = vec_splats((int32_t)0);
- vector signed int vsumi6 = vec_splats((int32_t)0);
- vector signed int vsumi7 = vec_splats((int32_t)0);
-
- const uint8_t * restrict q4 = x[i].qs;
- const int8_t * restrict q8 = y[i].qs;
-
- for (int j = 0; j < QK_K/64; j+=2) {
- __builtin_prefetch(q4, 0, 1);
- __builtin_prefetch(q8, 0, 1);
-
- vector signed char qxs0 = (vector signed char)vec_xl( 0, q4);
- vector signed char qxs1 = (vector signed char)vec_xl(16, q4);
- vector signed char qxs2 = (vector signed char)vec_xl(32, q4);
- vector signed char qxs3 = (vector signed char)vec_xl(48, q4);
- q4 += 64;
-
- vector signed char q4x00 = vec_and(qxs0, lowMask);
- vector signed char q4x01 = vec_sr(qxs0, v4);
- vector signed char q4x10 = vec_and(qxs1, lowMask);
- vector signed char q4x11 = vec_sr(qxs1, v4);
- vector signed char q4x20 = vec_and(qxs2, lowMask);
- vector signed char q4x21 = vec_sr(qxs2, v4);
- vector signed char q4x30 = vec_and(qxs3, lowMask);
- vector signed char q4x31 = vec_sr(qxs3, v4);
-
- vector signed char q8y00 = vec_xl( 0, q8);
- vector signed char q8y10 = vec_xl( 16, q8);
- vector signed char q8y01 = vec_xl( 32, q8);
- vector signed char q8y11 = vec_xl( 48, q8);
- vector signed char q8y20 = vec_xl( 64, q8);
- vector signed char q8y30 = vec_xl( 80, q8);
- vector signed char q8y21 = vec_xl( 96, q8);
- vector signed char q8y31 = vec_xl(112, q8);
- q8 += 128;
-
- vector signed short qv00 = vec_add(vec_mule(q4x00, q8y00), vec_mulo(q4x00, q8y00));
- vector signed short qv01 = vec_add(vec_mule(q4x01, q8y01), vec_mulo(q4x01, q8y01));
- vector signed short qv10 = vec_add(vec_mule(q4x10, q8y10), vec_mulo(q4x10, q8y10));
- vector signed short qv11 = vec_add(vec_mule(q4x11, q8y11), vec_mulo(q4x11, q8y11));
- vector signed short qv20 = vec_add(vec_mule(q4x20, q8y20), vec_mulo(q4x20, q8y20));
- vector signed short qv21 = vec_add(vec_mule(q4x21, q8y21), vec_mulo(q4x21, q8y21));
- vector signed short qv30 = vec_add(vec_mule(q4x30, q8y30), vec_mulo(q4x30, q8y30));
- vector signed short qv31 = vec_add(vec_mule(q4x31, q8y31), vec_mulo(q4x31, q8y31));
-
- vector signed short vs0 = vec_splat(vscales, 0);
- vector signed short vs1 = vec_splat(vscales, 1);
- vector signed short vs2 = vec_splat(vscales, 2);
- vector signed short vs3 = vec_splat(vscales, 3);
- vscales = vec_sld(vscales, vscales, 8);
-
- qv00 = vec_add(qv00, qv10);
- qv10 = vec_add(qv01, qv11);
- qv20 = vec_add(qv20, qv30);
- qv30 = vec_add(qv21, qv31);
-
- vsumi0 = vec_add(vec_mule(qv00, vs0), vsumi0);
- vsumi1 = vec_add(vec_mulo(qv00, vs0), vsumi1);
- vsumi2 = vec_add(vec_mule(qv10, vs1), vsumi2);
- vsumi3 = vec_add(vec_mulo(qv10, vs1), vsumi3);
- vsumi4 = vec_add(vec_mule(qv20, vs2), vsumi4);
- vsumi5 = vec_add(vec_mulo(qv20, vs2), vsumi5);
- vsumi6 = vec_add(vec_mule(qv30, vs3), vsumi6);
- vsumi7 = vec_add(vec_mulo(qv30, vs3), vsumi7);
- }
-
- vsumi0 = vec_add(vsumi0, vsumi4);
- vsumi1 = vec_add(vsumi1, vsumi5);
- vsumi2 = vec_add(vsumi2, vsumi6);
- vsumi3 = vec_add(vsumi3, vsumi7);
-
- vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
- vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
- vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
- vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
- }
-
- vsumf0 = vec_add(vsumf0, vsumf2);
- vsumf1 = vec_add(vsumf1, vsumf3);
-
- vsumf0 = vec_add(vsumf0, vsumf1);
-
- vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
- vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
-
- *s = vec_extract(vsumf0, 0);
-
-#elif defined __loongarch_asx
-
- const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
-
- __m256 acc = (__m256)__lasx_xvldi(0);
- __m128 acc_m = (__m128)__lsx_vldi(0);
-
- for (int i = 0; i < nb; ++i) {
-
- const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
- const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
-
- memcpy(utmp, x[i].scales, 12);
-
- const uint8_t * restrict q4 = x[i].qs;
- const int8_t * restrict q8 = y[i].qs;
-
- const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]));
-
- const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0);
- const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1));
- const __m128i prod = lsx_madd_h(lasx_extracti128(mins_and_scales, 1), q8s);
- acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m);
-
- const __m128i sc128 = lasx_extracti128(mins_and_scales, 0);
- const __m256i scales = lasx_insertf128(sc128, sc128);
-
- __m256i sumi = __lasx_xvldi(0);
-
- for (int j = 0; j < QK_K/64; ++j) {
-
- const __m256i scale_l = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+0));
- const __m256i scale_h = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+1));
-
- const __m256i q4bits = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
- const __m256i q4l = __lasx_xvand_v(q4bits, m4);
- const __m256i q4h = __lasx_xvand_v(__lasx_xvsrli_h(q4bits, 4), m4);
-
- const __m256i q8l = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
- __m256i p16l = lasx_maddubs_h(q4l, q8l);
- p16l = lasx_madd_h(scale_l, p16l);
-
- const __m256i q8h = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
- __m256i p16h = lasx_maddubs_h(q4h, q8h);
- p16h = lasx_madd_h(scale_h, p16h);
- const __m256i sumj = __lasx_xvadd_w(p16l, p16h);
-
- sumi = __lasx_xvadd_w(sumi, sumj);
- }
-
- __m256 vd = __lasx_xvreplfr2vr_s(d);
- acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc);
- }
-
- acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vpermi_w((__m128i)acc_m, (__m128i)acc_m, 0xee));
- __m128i tmp1 = __lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w((__m128i)acc_m, 1), 0);
- acc_m = __lsx_vfadd_s(acc_m, (__m128)tmp1);
-
- ft_union fi;
- fi.i = __lsx_vpickve2gr_w(acc_m, 0);
- *s = hsum_float_8(acc) + fi.f ;
-
-#else
-
- const uint8_t * scales = (const uint8_t*)&utmp[0];
- const uint8_t * mins = (const uint8_t*)&utmp[2];
-
- int8_t aux8[QK_K];
- int16_t aux16[8];
- float sums [8];
- int32_t aux32[8];
- memset(sums, 0, 8*sizeof(float));
-
- float sumf = 0;
- for (int i = 0; i < nb; ++i) {
- const uint8_t * restrict q4 = x[i].qs;
- const int8_t * restrict q8 = y[i].qs;
- memset(aux32, 0, 8*sizeof(int32_t));
- int8_t * restrict a = aux8;
- for (int j = 0; j < QK_K/64; ++j) {
- for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
- a += 32;
- for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4);
- a += 32; q4 += 32;
- }
- memcpy(utmp, x[i].scales, 12);
- utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
- const uint32_t uaux = utmp[1] & kmask1;
- utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
- utmp[2] = uaux;
- utmp[0] &= kmask1;
-
- int sumi = 0;
- for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
- a = aux8;
- int is = 0;
- for (int j = 0; j < QK_K/32; ++j) {
- int32_t scale = scales[is++];
- for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
- for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
- q8 += 8; a += 8;
- for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
- for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
- q8 += 8; a += 8;
- for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
- for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
- q8 += 8; a += 8;
- for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
- for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
- q8 += 8; a += 8;
- }
- const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
- for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
- const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
- sumf -= dmin * sumi;
- }
- for (int l = 0; l < 8; ++l) sumf += sums[l];
- *s = sumf;
-#endif
-}
-#else
-void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
- assert(n % QK_K == 0);
- assert(nrc == 1);
- UNUSED(nrc);
- UNUSED(bx);
- UNUSED(by);
- UNUSED(bs);
-
- const block_q4_K * restrict x = vx;
- const block_q8_K * restrict y = vy;
-
- const int nb = n / QK_K;
-
-#ifdef __ARM_NEON
- const uint8x16_t m4b = vdupq_n_u8(0xf);
-
- const int32x4_t mzero = vdupq_n_s32(0);
-
- float sumf = 0;
-
- ggml_int8x16x2_t q4bytes;
- ggml_int8x16x4_t q8bytes;
-
- float sum_mins = 0.f;
-
- uint16_t aux16[2];
- const uint8_t * restrict scales = (const uint8_t *)aux16;
-
- for (int i = 0; i < nb; ++i) {
-
- const uint8_t * restrict q4 = x[i].qs;
- const int8_t * restrict q8 = y[i].qs;
-
- const uint16_t * restrict a = (const uint16_t *)x[i].scales;
- aux16[0] = a[0] & 0x0f0f;
- aux16[1] = (a[0] >> 4) & 0x0f0f;
-
- const int32_t summi = scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]);
- sum_mins += y[i].d * GGML_FP16_TO_FP32(x[i].d[1]) * summi;
-
- const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d[0]);
-
- const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4);
-
- q8bytes = ggml_vld1q_s8_x4(q8);
- q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
- q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
-
- const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
- const int32_t sumi1 = vaddvq_s32(p1) * scales[0];
-
- q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
- q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
-
- const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]);
- const int32_t sumi2 = vaddvq_s32(p2) * scales[1];
-
- sumf += d * (sumi1 + sumi2);
- }
-
- *s = sumf - sum_mins;
-
-#elif defined __AVX2__
-
- const __m256i m4 = _mm256_set1_epi8(0xF);
-
- __m256 acc = _mm256_setzero_ps();
-
- float summs = 0;
-
- uint16_t aux16[2];
- const uint8_t * scales = (const uint8_t *)aux16;
-
- for (int i = 0; i < nb; ++i) {
-
- const float d = GGML_FP16_TO_FP32(x[i].d[0]) * y[i].d;
- const float m = GGML_FP16_TO_FP32(x[i].d[1]) * y[i].d;
- const __m256 vd = _mm256_set1_ps(d);
-
- const uint16_t * a = (const uint16_t *)x[i].scales;
- aux16[0] = a[0] & 0x0f0f;
- aux16[1] = (a[0] >> 4) & 0x0f0f;
-
- summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
-
- const uint8_t * restrict q4 = x[i].qs;
- const int8_t * restrict q8 = y[i].qs;
-
- const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4);
- const __m256i q4l = _mm256_and_si256(q4bits, m4);
- const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);
-
- const __m256i q8l = _mm256_loadu_si256((const __m256i*)(q8+ 0));
- const __m256i q8h = _mm256_loadu_si256((const __m256i*)(q8+32));
-
- const __m256i p16l = _mm256_maddubs_epi16(q4l, q8l);
- const __m256i p16h = _mm256_maddubs_epi16(q4h, q8h);
-
- const __m256i p32l = _mm256_madd_epi16(_mm256_set1_epi16(scales[0]), p16l);
- acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32l), acc);
-
- const __m256i p32h = _mm256_madd_epi16(_mm256_set1_epi16(scales[1]), p16h);
- acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32h), acc);
-
- }
-
- *s = hsum_float_8(acc) - summs;
-
-#elif defined __AVX__
-
- const __m128i m4 = _mm_set1_epi8(0xF);
-
- __m256 acc = _mm256_setzero_ps();
-
- float summs = 0;
-
- uint16_t aux16[2];
- const uint8_t * scales = (const uint8_t *)aux16;
-
- for (int i = 0; i < nb; ++i) {
-
- const float d = GGML_FP16_TO_FP32(x[i].d[0]) * y[i].d;
- const float m = GGML_FP16_TO_FP32(x[i].d[1]) * y[i].d;
- const __m256 vd = _mm256_set1_ps(d);
-
- const uint16_t * a = (const uint16_t *)x[i].scales;
- aux16[0] = a[0] & 0x0f0f;
- aux16[1] = (a[0] >> 4) & 0x0f0f;
-
- summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
-
- const uint8_t * restrict q4 = x[i].qs;
- const int8_t * restrict q8 = y[i].qs;
-
- const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4);
- const __m128i q4bits_0 = _mm256_extractf128_si256(q4bits, 0);
- const __m128i q4bits_1 = _mm256_extractf128_si256(q4bits, 1);
- const __m128i q4_0 = _mm_and_si128(q4bits_0, m4);
- const __m128i q4_1 = _mm_and_si128(q4bits_1, m4);
- const __m128i q4_2 = _mm_and_si128(_mm_srli_epi16(q4bits_0, 4), m4);
- const __m128i q4_3 = _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4);
-
- const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
- const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
-
- const __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0));
- const __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1));
- const __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0));
- const __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1));
-
- const __m128i p32_0 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_0);
- const __m128i p32_1 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_1);
- acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(MM256_SET_M128I(p32_1, p32_0))), acc);
-
- const __m128i p32_2 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_2);
- const __m128i p32_3 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_3);
- acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(MM256_SET_M128I(p32_3, p32_2))), acc);
-
- }
-
- *s = hsum_float_8(acc) - summs;
-
-#elif defined __riscv_v_intrinsic
-
- uint16_t s16[2];
- const uint8_t * restrict scales = (const uint8_t *)s16;
-
- float sumf = 0;
-
- for (int i = 0; i < nb; ++i) {
-
- const uint8_t * restrict q4 = x[i].qs;
- const int8_t * restrict q8 = y[i].qs;
-
- const uint16_t * restrict b = (const uint16_t *)x[i].scales;
- s16[0] = b[0] & 0x0f0f;
- s16[1] = (b[0] >> 4) & 0x0f0f;
-
- sumf -= y[i].d * GGML_FP16_TO_FP32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
- const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d[0]);
-
- size_t vl = 32;
-
- vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1);
-
- // load Q4
- vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl);
-
- // load Q8 and multiply it with lower Q4 nibble
- vint8m1_t q4_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl));
- vint16m2_t va_0 = __riscv_vwmul_vv_i16m2(q4_a, __riscv_vle8_v_i8m1(q8, vl), vl);
- vint16m1_t aux1 = __riscv_vredsum_vs_i16m2_i16m1(va_0, vzero, vl);
-
- sumf += d*scales[0]*__riscv_vmv_x_s_i16m1_i16(aux1);
-
- // load Q8 and multiply it with upper Q4 nibble
- vint8m1_t q4_s = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl));
- vint16m2_t va_1 = __riscv_vwmul_vv_i16m2(q4_s, __riscv_vle8_v_i8m1(q8+32, vl), vl);
- vint16m1_t aux2 = __riscv_vredsum_vs_i16m2_i16m1(va_1, vzero, vl);
-
- sumf += d*scales[1]*__riscv_vmv_x_s_i16m1_i16(aux2);
-
- }
-
- *s = sumf;
-
-#elif defined(__POWER9_VECTOR__)
- const vector signed char lowMask = vec_splats((signed char)0xF);
- const vector unsigned char v4 = vec_splats((unsigned char)0x4);
-
- vector float vsumf0 = vec_splats(0.0f);
- vector float vsumf1 = vec_splats(0.0f);
- vector float vsumf2 = vec_splats(0.0f);
- vector float vsumf3 = vec_splats(0.0f);
-
-#pragma GCC unroll 2
- for (int i = 0; i < nb; ++i) {
- __builtin_prefetch(x[i].qs, 0, 1);
- __builtin_prefetch(y[i].qs, 0, 1);
-
- vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d[1]));
- vector float vyd = vec_splats(y[i].d);
- vector float vd= vec_mul(vxd, vyd);
-
- uint16_t s16[2];
- const uint8_t * scales = (const uint8_t *)s16;
-
- const uint16_t * restrict b = (const uint16_t *)x[i].scales;
- s16[0] = b[0] & 0x0f0f;
- s16[1] = (b[0] >> 4) & 0x0f0f;
-
- vector signed char utmps = (vector signed char)vec_xl_len(scales, 4);
- vector signed short vscales = (vector signed short)vec_unpackh(utmps);
- vector signed short q4xmins0 = vec_mergeh(vscales, vscales);
- q4xmins0 = vec_sld(q4xmins0, q4xmins0, 8);
-
- vector signed short q8ysums0 = vec_xl_len((const int16_t *)(y[i].bsums), 8);
-
- vector signed int prod0 = vec_mule(q4xmins0, q8ysums0);
- vector signed int prod1 = vec_mulo(q4xmins0, q8ysums0);
-
- vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vd, vsumf0);
- vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vd, vsumf1);
-
- vd = vec_mul(vyd, vec_splats(GGML_FP16_TO_FP32(x[i].d[0])));
-
- vector signed char qxs0 = (vector signed char)vec_xl( 0, x[i].qs);
- vector signed char qxs1 = (vector signed char)vec_xl(16, x[i].qs);
- vector signed char q4x00 = vec_and(qxs0, lowMask);
- vector signed char q4x01 = vec_sr(qxs0, v4);
- vector signed char q4x10 = vec_and(qxs1, lowMask);
- vector signed char q4x11 = vec_sr(qxs1, v4);
-
- vector signed char q8y00 = vec_xl( 0, y[i].qs);
- vector signed char q8y10 = vec_xl(16, y[i].qs);
- vector signed char q8y01 = vec_xl(32, y[i].qs);
- vector signed char q8y11 = vec_xl(48, y[i].qs);
-
- vector signed short qv00 = vec_add(vec_mule(q4x00, q8y00), vec_mulo(q4x00, q8y00));
- vector signed short qv01 = vec_add(vec_mule(q4x01, q8y01), vec_mulo(q4x01, q8y01));
- vector signed short qv10 = vec_add(vec_mule(q4x10, q8y10), vec_mulo(q4x10, q8y10));
- vector signed short qv11 = vec_add(vec_mule(q4x11, q8y11), vec_mulo(q4x11, q8y11));
-
- vector signed short vs0 = vec_splat(vscales, 0);
- vector signed short vs1 = vec_splat(vscales, 1);
-
- vector signed int vsumi0 = vec_add(vec_mule(qv00, vs0), vec_mulo(qv00, vs0));
- vector signed int vsumi1 = vec_add(vec_mule(qv10, vs0), vec_mulo(qv10, vs0));
- vector signed int vsumi2 = vec_add(vec_mule(qv01, vs1), vec_mulo(qv01, vs1));
- vector signed int vsumi3 = vec_add(vec_mule(qv11, vs1), vec_mulo(qv11, vs1));
-
- vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
- vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
- vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
- vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
- }
-
- vsumf0 = vec_add(vsumf0, vsumf2);
- vsumf1 = vec_add(vsumf1, vsumf3);
-
- vsumf0 = vec_add(vsumf0, vsumf1);
-
- vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
- vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
-
- *s = vec_extract(vsumf0, 0);
-
-#elif defined __loongarch_asx
-
- const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
-
- __m256 acc = (__m256)__lasx_xvldi(0);
-
- float summs = 0;
-
- uint16_t aux16[2];
- const uint8_t * scales = (const uint8_t *)aux16;
-
- for (int i = 0; i < nb; ++i) {
-
- const float d = GGML_FP16_TO_FP32(x[i].d[0]) * y[i].d;
- const float m = GGML_FP16_TO_FP32(x[i].d[1]) * y[i].d;
- const __m256 vd = __lasx_xvreplfr2vr_s(d);
-
- const uint16_t * a = (const uint16_t *)x[i].scales;
- aux16[0] = a[0] & 0x0f0f;
- aux16[1] = (a[0] >> 4) & 0x0f0f;
-
- summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
-
- const uint8_t * restrict q4 = x[i].qs;
- const int8_t * restrict q8 = y[i].qs;
-
- const __m256i q4bits = __lasx_xvld((const __m256i*)q4, 0);
- const __m256i q4l = __lasx_xvand_v(q4bits, m4);
- const __m256i q4h = __lasx_xvand_v(__lasx_xvsrli_h(q4bits, 4), m4);
-
- const __m256i q8l = __lasx_xvld((const __m256i*)(q8+ 0), 0);
- const __m256i q8h = __lasx_xvld((const __m256i*)(q8+32), 0);
-
- const __m256i p16l = lasx_maddubs_h(q4l, q8l);
- const __m256i p16h = lasx_maddubs_h(q4h, q8h);
-
- const __m256i p32l = lasx_madd_h(__lasx_xvreplgr2vr_h(scales[0]), p16l);
- acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(p32l), acc);
-
- const __m256i p32h = lasx_madd_h(__lasx_xvreplgr2vr_h(scales[1]), p16h);
- acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(p32h), acc);
- }
-
- *s = hsum_float_8(acc) - summs;
-
-#else
-
- uint8_t aux8[QK_K];
- int16_t aux16[16];
- float sums [8];
- memset(sums, 0, 8*sizeof(float));
-
- uint16_t s16[2];
- const uint8_t * restrict scales = (const uint8_t *)s16;
-
- float sumf = 0;
- for (int i = 0; i < nb; ++i) {
- const uint8_t * restrict q4 = x[i].qs;
- const int8_t * restrict q8 = y[i].qs;
- uint8_t * restrict a = aux8;
- for (int l = 0; l < 32; ++l) a[l+ 0] = q4[l] & 0xF;
- for (int l = 0; l < 32; ++l) a[l+32] = q4[l] >> 4;
-
- const uint16_t * restrict b = (const uint16_t *)x[i].scales;
- s16[0] = b[0] & 0x0f0f;
- s16[1] = (b[0] >> 4) & 0x0f0f;
-
- sumf -= y[i].d * GGML_FP16_TO_FP32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
-
- const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d[0]);
-
- for (int j = 0; j < QK_K/32; ++j) {
- for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l];
- q8 += 16; a += 16;
- for (int l = 0; l < 16; ++l) aux16[l] += q8[l] * a[l];
- q8 += 16; a += 16;
- const float dl = d * scales[j];
- for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[l+8]);
- }
- }
- for (int l = 0; l < 8; ++l) sumf += sums[l];
- *s = sumf;
-#endif
-}
-#endif
-
-#if QK_K == 256
-void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
- assert(n % QK_K == 0);
- assert(nrc == 1);
- UNUSED(nrc);
- UNUSED(bx);
- UNUSED(by);
- UNUSED(bs);
-
- const block_q5_K * restrict x = vx;
- const block_q8_K * restrict y = vy;
-
- const int nb = n / QK_K;
-
- static const uint32_t kmask1 = 0x3f3f3f3f;
- static const uint32_t kmask2 = 0x0f0f0f0f;
- static const uint32_t kmask3 = 0x03030303;
-
- uint32_t utmp[4];
-
-#ifdef __ARM_NEON
- const uint8x16_t m4b = vdupq_n_u8(0xf);
- const uint8x16_t mone = vdupq_n_u8(1);
- const uint8x16_t mtwo = vdupq_n_u8(2);
- const int32x4_t mzero = vdupq_n_s32(0);
-
- ggml_int8x16x4_t q5bytes;
-
- float sumf = 0;
-
- for (int i = 0; i < nb; ++i) {
-
- const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
- const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
-
- const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
-
- memcpy(utmp, x[i].scales, 12);
- utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
- const uint32_t uaux = utmp[1] & kmask1;
- utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
- utmp[2] = uaux;
- utmp[0] &= kmask1;
-
- const uint8x8_t mins8 = vld1_u8((const uint8_t*)utmp + 8);
- const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(mins8));
- const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
- vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
- int32_t sumi_mins = vaddvq_s32(prod);
-
- const uint8_t * scales = (const uint8_t *)utmp;
-
- const uint8_t * restrict q5 = x[i].qs;
- const uint8_t * restrict qh = x[i].qh;
- const int8_t * restrict q8 = y[i].qs;
-
- ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh);
-
- ggml_uint8x16x4_t q5h;
-
- int32_t sumi = 0;
-
- for (int j = 0; j < QK_K/64; ++j) {
-
- const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5); q5 += 32;
- const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
-
- q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
- q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
- q5h.val[2] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[0]), 3);
- q5h.val[3] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[1]), 3);
- qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 2);
- qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 2);
-
- q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0]));
- q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1]));
- q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2]));
- q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3]));
-
- sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++;
- sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++;
- }
-
- sumf += d * sumi - dmin * sumi_mins;
- }
-
- *s = sumf;
-
-#elif defined __AVX2__
-
- const __m256i m4 = _mm256_set1_epi8(0xF);
- const __m128i mzero = _mm_setzero_si128();
- const __m256i mone = _mm256_set1_epi8(1);
-
- __m256 acc = _mm256_setzero_ps();
-
- float summs = 0.f;
-
- for (int i = 0; i < nb; ++i) {
-
- const uint8_t * restrict q5 = x[i].qs;
- const int8_t * restrict q8 = y[i].qs;
-
-#if QK_K == 256
- const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
- const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
-
- memcpy(utmp, x[i].scales, 12);
- utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
- const uint32_t uaux = utmp[1] & kmask1;
- utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
- utmp[2] = uaux;
- utmp[0] &= kmask1;
-#else
- // TODO
- const float d = 0, dmin = 0;
-#endif
-
- const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
-
- const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
- const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
- const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
- const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
- summs += dmin * _mm_extract_epi32(hsum, 0);
-
- const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
- const __m256i scales = MM256_SET_M128I(sc128, sc128);
+ // Set up scales
+ memcpy(aux, x[i].scales, 12);
+ __m128i scales128 = _mm_set_epi32(
+ ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
+ ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
+ (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
+ (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
+ scales128 = _mm_sub_epi8(scales128, m32);
+ const __m256i all_scales = _mm256_cvtepi8_epi16(scales128);
+ const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
+ const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
+ const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};
- const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].qh);
- __m256i hmask = mone;
+ // high bit
+ const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].hmask);
+ // integer accumulator
__m256i sumi = _mm256_setzero_si256();
int bit = 0;
+ int is = 0;
- for (int j = 0; j < QK_K/64; ++j) {
+ for (int j = 0; j < QK_K/128; ++j) {
+ // load low 2 bits
+ const __m256i q3bits = _mm256_loadu_si256((const __m256i*)q3); q3 += 32;
- const __m256i scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0));
- const __m256i scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1));
+ // prepare low and high bits
+ const __m256i q3l_0 = _mm256_and_si256(q3bits, m3);
+ const __m256i q3h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
+ ++bit;
- const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); q5 += 32;
+ const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3);
+ const __m256i q3h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
+ ++bit;
- const __m256i q5l_0 = _mm256_and_si256(q5bits, m4);
- const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
- const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0);
- hmask = _mm256_slli_epi16(hmask, 1);
+ const __m256i q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3);
+ const __m256i q3h_2 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
+ ++bit;
- const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
- const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
- const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1);
- hmask = _mm256_slli_epi16(hmask, 1);
+ const __m256i q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3);
+ const __m256i q3h_3 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
+ ++bit;
+ // load Q8 quants
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+ const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+ const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
- __m256i p16_0 = _mm256_maddubs_epi16(q5_0, q8_0);
- __m256i p16_1 = _mm256_maddubs_epi16(q5_1, q8_1);
+ // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16,
+ // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
+ // and 2 if the high bit was set)
+ __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0);
+ __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1);
+ __m256i q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2);
+ __m256i q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3);
- p16_0 = _mm256_madd_epi16(scale_0, p16_0);
- p16_1 = _mm256_madd_epi16(scale_1, p16_1);
+ __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0);
+ __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1);
+ __m256i p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2);
+ __m256i p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3);
- sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
+ p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
+ p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
+ p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
+ p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
+
+ // multiply with scales
+ p16_0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0);
+ p16_1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1);
+ p16_2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2);
+ p16_3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3);
+
+ // accumulate
+ p16_0 = _mm256_add_epi32(p16_0, p16_1);
+ p16_2 = _mm256_add_epi32(p16_2, p16_3);
+ sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2));
}
- __m256 vd = _mm256_set1_ps(d);
- acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
+ // multiply with block scale and accumulate
+ acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
}
- *s = hsum_float_8(acc) + summs;
+ *s = hsum_float_8(acc);
#elif defined __AVX__
- const __m128i m4 = _mm_set1_epi8(0xF);
- const __m128i mzero = _mm_setzero_si128();
- const __m128i mone = _mm_set1_epi8(1);
+ const __m128i m3 = _mm_set1_epi8(3);
+ const __m128i mone = _mm_set1_epi8(1);
+ const __m128i m32 = _mm_set1_epi8(32);
const __m128i m2 = _mm_set1_epi8(2);
__m256 acc = _mm256_setzero_ps();
- float summs = 0.f;
+ const uint32_t *aux;
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
- const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
- const uint8_t * restrict q5 = x[i].qs;
+ const uint8_t * restrict q3 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
- memcpy(utmp, x[i].scales, 12);
- utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
- const uint32_t uaux = utmp[1] & kmask1;
- utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
- utmp[2] = uaux;
- utmp[0] &= kmask1;
-
- const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]);
- const __m128i scales = _mm_cvtepu8_epi16(utmps);
- const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps));
-
- const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]);
- const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]);
- const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1);
- const __m128i prod = _mm_madd_epi16(mins, q8s);
- const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
- summs += dmin * _mm_extract_epi32(hsum, 0);
+ // Set up scales
+ aux = (const uint32_t *)x[i].scales;
+ __m128i scales128 = _mm_set_epi32(
+ ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
+ ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
+ (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
+ (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
+ scales128 = _mm_sub_epi8(scales128, m32);
+ const __m128i scales_0 = _mm_cvtepi8_epi16(scales128);
+ const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales128, scales128));
+ const __m128i scales[2] = { scales_0, scales_1 };
- const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].qh[0]);
- const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].qh[16]);
- __m128i hmask = mone;
+ // high bit *128*2 from block_q3_K.hmask[QK_K/8]
+ const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].hmask[0]);
+ const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].hmask[16]);
+ // integer accumulator
__m128i sumi_0 = _mm_setzero_si128();
__m128i sumi_1 = _mm_setzero_si128();
- int bit = 0;
+ for (int j = 0; j < QK_K/128; ++j) {
+ // load low 2 bits *64*2 from block_q3_K.qs[QK_K/4]
+ const __m128i q3bits_0 = _mm_loadu_si128((const __m128i*)q3); q3 += 16;
+ const __m128i q3bits_1 = _mm_loadu_si128((const __m128i*)q3); q3 += 16;
- __m128i shuffle = _mm_set1_epi16(0x0100);
- for (int j = 0; j < QK_K/64; ++j) {
+ // prepare low and high bits
+ const int bit = j << 2;
- const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle);
- shuffle = _mm_add_epi16(shuffle, m2);
- const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle);
- shuffle = _mm_add_epi16(shuffle, m2);
+ const __m128i q3l_0 = _mm_and_si128(q3bits_0, m3);
+ const __m128i q3l_1 = _mm_and_si128(q3bits_1, m3);
+ const __m128i q3h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit)), bit), 2);
+ const __m128i q3h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit)), bit), 2);
- const __m128i q5bits_0 = _mm_loadu_si128((const __m128i*)q5); q5 += 16;
- const __m128i q5bits_1 = _mm_loadu_si128((const __m128i*)q5); q5 += 16;
+ const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 2), m3);
+ const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 2), m3);
+ const __m128i q3h_2 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+1)), bit+1), 2);
+ const __m128i q3h_3 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+1)), bit+1), 2);
- __m128i q5l_0 = _mm_and_si128(q5bits_0, m4);
- __m128i q5l_1 = _mm_and_si128(q5bits_1, m4);
- __m128i q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4);
- __m128i q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4);
- __m128i q5_0 = _mm_add_epi8(q5l_0, q5h_0);
- __m128i q5_1 = _mm_add_epi8(q5l_1, q5h_1);
- hmask = _mm_slli_epi16(hmask, 1);
+ const __m128i q3l_4 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 4), m3);
+ const __m128i q3l_5 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 4), m3);
+ const __m128i q3h_4 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+2)), bit+2), 2);
+ const __m128i q3h_5 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+2)), bit+2), 2);
- __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
- __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
- __m128i p16_0 = _mm_maddubs_epi16(q5_0, q8_0);
- __m128i p16_1 = _mm_maddubs_epi16(q5_1, q8_1);
- p16_0 = _mm_madd_epi16(scale_0, p16_0);
- p16_1 = _mm_madd_epi16(scale_0, p16_1);
+ const __m128i q3l_6 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 6), m3);
+ const __m128i q3l_7 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 6), m3);
+ const __m128i q3h_6 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+3)), bit+3), 2);
+ const __m128i q3h_7 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+3)), bit+3), 2);
- q5l_0 = _mm_and_si128(_mm_srli_epi16(q5bits_0, 4), m4);
- q5l_1 = _mm_and_si128(_mm_srli_epi16(q5bits_1, 4), m4);
- q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4);
- q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4);
- q5_0 = _mm_add_epi8(q5l_0, q5h_0);
- q5_1 = _mm_add_epi8(q5l_1, q5h_1);
- hmask = _mm_slli_epi16(hmask, 1);
+ // load Q8 quants from block_q8_K.qs[QK_K]
+ const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+
+ // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16,
+ // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
+ // and 2 if the high bit was set)
+ __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, q8_0);
+ __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, q8_1);
+ __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, q8_2);
+ __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, q8_3);
+ __m128i q8s_4 = _mm_maddubs_epi16(q3h_4, q8_4);
+ __m128i q8s_5 = _mm_maddubs_epi16(q3h_5, q8_5);
+ __m128i q8s_6 = _mm_maddubs_epi16(q3h_6, q8_6);
+ __m128i q8s_7 = _mm_maddubs_epi16(q3h_7, q8_7);
+
+ __m128i p16_0 = _mm_maddubs_epi16(q3l_0, q8_0);
+ __m128i p16_1 = _mm_maddubs_epi16(q3l_1, q8_1);
+ __m128i p16_2 = _mm_maddubs_epi16(q3l_2, q8_2);
+ __m128i p16_3 = _mm_maddubs_epi16(q3l_3, q8_3);
+ __m128i p16_4 = _mm_maddubs_epi16(q3l_4, q8_4);
+ __m128i p16_5 = _mm_maddubs_epi16(q3l_5, q8_5);
+ __m128i p16_6 = _mm_maddubs_epi16(q3l_6, q8_6);
+ __m128i p16_7 = _mm_maddubs_epi16(q3l_7, q8_7);
+
+ p16_0 = _mm_sub_epi16(p16_0, q8s_0);
+ p16_1 = _mm_sub_epi16(p16_1, q8s_1);
+ p16_2 = _mm_sub_epi16(p16_2, q8s_2);
+ p16_3 = _mm_sub_epi16(p16_3, q8s_3);
+ p16_4 = _mm_sub_epi16(p16_4, q8s_4);
+ p16_5 = _mm_sub_epi16(p16_5, q8s_5);
+ p16_6 = _mm_sub_epi16(p16_6, q8s_6);
+ p16_7 = _mm_sub_epi16(p16_7, q8s_7);
- q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
- q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
- __m128i p16_2 = _mm_maddubs_epi16(q5_0, q8_0);
- __m128i p16_3 = _mm_maddubs_epi16(q5_1, q8_1);
- p16_2 = _mm_madd_epi16(scale_1, p16_2);
- p16_3 = _mm_madd_epi16(scale_1, p16_3);
+ // multiply with scales
+ __m128i shuffle = _mm_set1_epi16(0x0100);
+ p16_0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_0);
+ shuffle = _mm_add_epi16(shuffle, m2);
+ p16_1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_1);
+ shuffle = _mm_add_epi16(shuffle, m2);
+ p16_2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_2);
+ shuffle = _mm_add_epi16(shuffle, m2);
+ p16_3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_3);
+ shuffle = _mm_add_epi16(shuffle, m2);
+ p16_4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_4);
+ shuffle = _mm_add_epi16(shuffle, m2);
+ p16_5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_5);
+ shuffle = _mm_add_epi16(shuffle, m2);
+ p16_6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_6);
+ shuffle = _mm_add_epi16(shuffle, m2);
+ p16_7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_7);
+ // accumulate
+ p16_0 = _mm_add_epi32(p16_0, p16_1);
+ p16_2 = _mm_add_epi32(p16_2, p16_3);
+ p16_4 = _mm_add_epi32(p16_4, p16_5);
+ p16_6 = _mm_add_epi32(p16_6, p16_7);
sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
- sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
+ sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_4, p16_6));
}
- __m256 vd = _mm256_set1_ps(d);
+ // multiply with block scale and accumulate
__m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
- acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
+ acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
}
- *s = hsum_float_8(acc) + summs;
+ *s = hsum_float_8(acc);
#elif defined __riscv_v_intrinsic
- const uint8_t * scales = (const uint8_t*)&utmp[0];
- const uint8_t * mins = (const uint8_t*)&utmp[2];
+ uint32_t aux[3];
+ uint32_t utmp[4];
float sumf = 0;
- float sums = 0.0;
+ for (int i = 0; i < nb; ++i) {
- size_t vl;
+ const uint8_t * restrict q3 = x[i].qs;
+ const uint8_t * restrict qh = x[i].hmask;
+ const int8_t * restrict q8 = y[i].qs;
- for (int i = 0; i < nb; ++i) {
+ memcpy(aux, x[i].scales, 12);
+ utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
+ utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
+ utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
+ utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
- vl = 8;
+ int8_t * scale = (int8_t *)utmp;
+ for (int j = 0; j < 16; ++j) scale[j] -= 32;
- const uint8_t * restrict q5 = x[i].qs;
- const uint8_t * restrict hm = x[i].qh;
- const int8_t * restrict q8 = y[i].qs;
- const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
- const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
+ size_t vl = 32;
+ uint8_t m = 1;
- vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);
- vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl);
- vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl);
+ vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+ vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl);
- memcpy(utmp, x[i].scales, 12);
- utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
- const uint32_t uaux = utmp[1] & kmask1;
- utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
- utmp[2] = uaux;
- utmp[0] &= kmask1;
+ int sum_t = 0;
- vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl);
- vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl));
- vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl);
+ for (int j = 0; j < QK_K; j += 128) {
- vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
- sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);
+ vl = 32;
- vl = 32;
- int32_t aux32 = 0;
- int is = 0;
+ // load Q3
+ vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl);
- uint8_t m = 1;
- vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
- vuint8m1_t vqh = __riscv_vle8_v_u8m1(hm, vl);
+ vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl));
+ vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl));
+ vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl));
+ vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl));
- for (int j = 0; j < QK_K/64; ++j) {
- // load Q5 and Q8
- vuint8m1_t q5_x = __riscv_vle8_v_u8m1(q5, vl);
- vint8m1_t q8_y1 = __riscv_vle8_v_i8m1(q8, vl);
- vint8m1_t q8_y2 = __riscv_vle8_v_i8m1(q8+32, vl);
+ // compute mask for subtraction
+ vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl);
+ vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl);
+ vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_m(vmask_0, q3_0, 0x4, vl);
+ m <<= 1;
- // compute mask for addition
- vint8m1_t q5_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q5_x, 0x0F, vl));
vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
- vbool8_t vmask_1 = __riscv_vmsne_vx_u8m1_b8(qh_m1, 0, vl);
- vint8m1_t q5_m1 = __riscv_vadd_vx_i8m1_m(vmask_1, q5_a, 16, vl);
+ vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl);
+ vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_m(vmask_1, q3_1, 0x4, vl);
m <<= 1;
- vint8m1_t q5_l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q5_x, 0x04, vl));
vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
- vbool8_t vmask_2 = __riscv_vmsne_vx_u8m1_b8(qh_m2, 0, vl);
- vint8m1_t q5_m2 = __riscv_vadd_vx_i8m1_m(vmask_2, q5_l, 16, vl);
+ vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl);
+ vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_m(vmask_2, q3_2, 0x4, vl);
m <<= 1;
- vint16m2_t v0 = __riscv_vwmul_vv_i16m2(q5_m1, q8_y1, vl);
- vint16m2_t v1 = __riscv_vwmul_vv_i16m2(q5_m2, q8_y2, vl);
+ vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl);
+ vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl);
+ vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_m(vmask_3, q3_3, 0x4, vl);
+ m <<= 1;
- vint32m4_t vs1 = __riscv_vwmul_vx_i32m4(v0, scales[is++], vl);
- vint32m4_t vs2 = __riscv_vwmul_vx_i32m4(v1, scales[is++], vl);
+ // load Q8 and take product with Q3
+ vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl);
+ vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
+ vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
+ vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
- vint32m1_t vacc1 = __riscv_vredsum_vs_i32m4_i32m1(vs1, vzero, vl);
- vint32m1_t vacc2 = __riscv_vredsum_vs_i32m4_i32m1(vs2, vzero, vl);
+ vl = 16;
- aux32 += __riscv_vmv_x_s_i32m1_i32(vacc1) + __riscv_vmv_x_s_i32m1_i32(vacc2);
- q5 += 32; q8 += 64;
+ // retrieve lane to multiply with scale
+ vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl);
+ vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl);
+ vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl);
+ vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl);
+ vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl);
+ vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl);
+ vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl);
+ vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl);
+
+ vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl);
+ vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl);
+ vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl);
+ vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl);
+
+ sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);
+
+ q3 += 32; q8 += 128; scale += 8;
}
- vfloat32m1_t vaux = __riscv_vfmul_vf_f32m1(__riscv_vfmv_v_f_f32m1(aux32, 1), d, 1);
- sums += __riscv_vfmv_f_s_f32m1_f32(vaux);
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
+ sumf += d*sum_t;
}
- *s = sumf+sums;
+ *s = sumf;
#elif defined(__POWER9_VECTOR__)
- const vector signed char lowMask = vec_splats((signed char)0xF);
- const vector unsigned char v1 = vec_splats((unsigned char)0x1);
+ const vector signed char lowMask = vec_splats((signed char)0x3);
+ const vector signed char v1 = vec_splats((signed char)0x1);
const vector unsigned char v2 = vec_splats((unsigned char)0x2);
const vector unsigned char v3 = vec_splats((unsigned char)0x3);
const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+ const vector unsigned char v6 = vec_splats((unsigned char)0x6);
+ const vector signed char off = vec_splats((signed char)0x20);
vector float vsumf0 = vec_splats(0.0f);
vector float vsumf1 = vec_splats(0.0f);
vector float vyd = vec_splats(y[i].d);
vector float vd = vec_mul(vxd, vyd);
- vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[i].dmin));
- vector float vdmin = vec_mul(vxmin, vyd);
-
- memcpy(utmp, x[i].scales, 12);
-
- utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
- const uint32_t uaux = utmp[1] & kmask1;
- utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
- utmp[2] = uaux;
- utmp[0] &= kmask1;
-
- vector signed short q8ysums0 = vec_xl( 0, y[i].bsums);
- vector signed short q8ysums1 = vec_xl(16, y[i].bsums);
-
- vector signed char utmps = (vector signed char)vec_xl( 0, utmp);
- vector signed short vscales = vec_unpackh(utmps);
-
- vector signed short q5xmins = vec_unpackl(utmps);
- vector signed short q5xmins0 = vec_mergeh(q5xmins, q5xmins);
- vector signed short q5xmins1 = vec_mergel(q5xmins, q5xmins);
+ uint32_t aux[3];
+ uint32_t utmp[4];
- vector signed int prod0 = vec_mule(q5xmins0, q8ysums0);
- vector signed int prod1 = vec_mule(q5xmins1, q8ysums1);
- vector signed int prod2 = vec_mulo(q5xmins0, q8ysums0);
- vector signed int prod3 = vec_mulo(q5xmins1, q8ysums1);
+ memcpy(aux, x[i].scales, 12);
+ utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
+ utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
+ utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
+ utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
- vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0);
- vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1);
- vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2);
- vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3);
+ vector signed char vscales = (vector signed char)vec_xl( 0, utmp);
+ vector signed char qxhs0 = (vector signed char)vec_xl( 0, x[i].hmask);
+ vector signed char qxhs1 = (vector signed char)vec_xl(16, x[i].hmask);
- vector signed char qxhs0 = (vector signed char)vec_xl( 0, x[i].qh);
- vector signed char qxhs1 = (vector signed char)vec_xl(16, x[i].qh);
+ vscales = vec_sub(vscales, off);
vector signed int vsumi0 = vec_splats((int32_t)0);
vector signed int vsumi1 = vec_splats((int32_t)0);
vector signed int vsumi2 = vec_splats((int32_t)0);
vector signed int vsumi3 = vec_splats((int32_t)0);
+ vector signed int vsumi4 = vec_splats((int32_t)0);
+ vector signed int vsumi5 = vec_splats((int32_t)0);
+ vector signed int vsumi6 = vec_splats((int32_t)0);
+ vector signed int vsumi7 = vec_splats((int32_t)0);
- const uint8_t * restrict q5 = x[i].qs;
+
+ const uint8_t * restrict q3 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
- for (int j = 0; j < QK_K/64; ++j) {
- __builtin_prefetch(q5, 0, 1);
+ for (int j = 0; j < QK_K/128; ++j) {
+ __builtin_prefetch(q3, 0, 1);
__builtin_prefetch(q8, 0, 1);
- vector signed char qxs0 = (vector signed char)vec_xl( 0, q5);
- vector signed char qxs1 = (vector signed char)vec_xl(16, q5);
- q5 += 32;
+ vector signed char qxs0 = (vector signed char)vec_xl( 0, q3);
+ vector signed char qxs1 = (vector signed char)vec_xl(16, q3);
+ q3 += 32;
+ //the low 2 bits
vector signed char qxs00 = vec_and(qxs0, lowMask);
- vector signed char qxs01 = vec_sr(qxs0, v4);
+ vector signed char qxs01 = vec_and(vec_sr(qxs0, v2), lowMask);
+ vector signed char qxs02 = vec_and(vec_sr(qxs0, v4), lowMask);
+ vector signed char qxs03 = vec_and(vec_sr(qxs0, v6), lowMask);
vector signed char qxs10 = vec_and(qxs1, lowMask);
- vector signed char qxs11 = vec_sr(qxs1, v4);
+ vector signed char qxs11 = vec_and(vec_sr(qxs1, v2), lowMask);
+ vector signed char qxs12 = vec_and(vec_sr(qxs1, v4), lowMask);
+ vector signed char qxs13 = vec_and(vec_sr(qxs1, v6), lowMask);
- vector signed char q5h00 = vec_sl(vec_and((vector signed char)v1, qxhs0), v4);
- vector signed char q5h01 = vec_sl(vec_and((vector signed char)v2, qxhs0), v3);
- vector signed char q5h10 = vec_sl(vec_and((vector signed char)v1, qxhs1), v4);
- vector signed char q5h11 = vec_sl(vec_and((vector signed char)v2, qxhs1), v3);
- qxhs0 = vec_sr(qxhs0, v2);
- qxhs1 = vec_sr(qxhs1, v2);
+ //the 3rd bit
+ vector signed char qxh00 = vec_sl(vec_andc(v1, qxhs0), v2);
+ vector signed char qxh01 = vec_sl(vec_andc(v1, vec_sr(qxhs0, (vector unsigned char)v1)), v2);
+ vector signed char qxh02 = vec_sl(vec_andc(v1, vec_sr(qxhs0, v2)), v2);
+ vector signed char qxh03 = vec_sl(vec_andc(v1, vec_sr(qxhs0, v3)), v2);
+ vector signed char qxh10 = vec_sl(vec_andc(v1, qxhs1), v2);
+ vector signed char qxh11 = vec_sl(vec_andc(v1, vec_sr(qxhs1, (vector unsigned char)v1)), v2);
+ vector signed char qxh12 = vec_sl(vec_andc(v1, vec_sr(qxhs1, v2)), v2);
+ vector signed char qxh13 = vec_sl(vec_andc(v1, vec_sr(qxhs1, v3)), v2);
+ qxhs0 = vec_sr(qxhs0, v4);
+ qxhs1 = vec_sr(qxhs1, v4);
- vector signed char q5x00 = vec_or(q5h00, qxs00);
- vector signed char q5x01 = vec_or(q5h01, qxs01);
- vector signed char q5x10 = vec_or(q5h10, qxs10);
- vector signed char q5x11 = vec_or(q5h11, qxs11);
+ vector signed char q3x00 = vec_sub(qxs00, qxh00);
+ vector signed char q3x01 = vec_sub(qxs01, qxh01);
+ vector signed char q3x02 = vec_sub(qxs02, qxh02);
+ vector signed char q3x03 = vec_sub(qxs03, qxh03);
+ vector signed char q3x10 = vec_sub(qxs10, qxh10);
+ vector signed char q3x11 = vec_sub(qxs11, qxh11);
+ vector signed char q3x12 = vec_sub(qxs12, qxh12);
+ vector signed char q3x13 = vec_sub(qxs13, qxh13);
- vector signed char q8y00 = vec_xl( 0, q8);
- vector signed char q8y10 = vec_xl(16, q8);
- vector signed char q8y01 = vec_xl(32, q8);
- vector signed char q8y11 = vec_xl(48, q8);
- q8 += 64;
+ vector signed char q8y00 = vec_xl( 0, q8);
+ vector signed char q8y10 = vec_xl( 16, q8);
+ vector signed char q8y01 = vec_xl( 32, q8);
+ vector signed char q8y11 = vec_xl( 48, q8);
+ vector signed char q8y02 = vec_xl( 64, q8);
+ vector signed char q8y12 = vec_xl( 80, q8);
+ vector signed char q8y03 = vec_xl( 96, q8);
+ vector signed char q8y13 = vec_xl(112, q8);
+ q8 += 128;
- vector signed short qv00 = vec_add(vec_mule(q5x00, q8y00), vec_mulo(q5x00, q8y00));
- vector signed short qv01 = vec_add(vec_mule(q5x01, q8y01), vec_mulo(q5x01, q8y01));
- vector signed short qv10 = vec_add(vec_mule(q5x10, q8y10), vec_mulo(q5x10, q8y10));
- vector signed short qv11 = vec_add(vec_mule(q5x11, q8y11), vec_mulo(q5x11, q8y11));
+ vector signed short vscales_h = vec_unpackh(vscales);
+ vector signed short vs0 = vec_splat(vscales_h, 0);
+ vector signed short vs1 = vec_splat(vscales_h, 1);
+ vector signed short vs2 = vec_splat(vscales_h, 2);
+ vector signed short vs3 = vec_splat(vscales_h, 3);
+ vector signed short vs4 = vec_splat(vscales_h, 4);
+ vector signed short vs5 = vec_splat(vscales_h, 5);
+ vector signed short vs6 = vec_splat(vscales_h, 6);
+ vector signed short vs7 = vec_splat(vscales_h, 7);
+ vscales = vec_sld(vscales, vscales, 8);
- vector signed short vs0 = vec_splat(vscales, 0);
- vector signed short vs1 = vec_splat(vscales, 1);
- vscales = vec_sld(vscales, vscales, 12);
+ vector signed short qv00 = vec_add(vec_mule(q3x00, q8y00), vec_mulo(q3x00, q8y00));
+ vector signed short qv01 = vec_add(vec_mule(q3x01, q8y01), vec_mulo(q3x01, q8y01));
+ vector signed short qv02 = vec_add(vec_mule(q3x02, q8y02), vec_mulo(q3x02, q8y02));
+ vector signed short qv03 = vec_add(vec_mule(q3x03, q8y03), vec_mulo(q3x03, q8y03));
+ vector signed short qv10 = vec_add(vec_mule(q3x10, q8y10), vec_mulo(q3x10, q8y10));
+ vector signed short qv11 = vec_add(vec_mule(q3x11, q8y11), vec_mulo(q3x11, q8y11));
+ vector signed short qv12 = vec_add(vec_mule(q3x12, q8y12), vec_mulo(q3x12, q8y12));
+ vector signed short qv13 = vec_add(vec_mule(q3x13, q8y13), vec_mulo(q3x13, q8y13));
- qv00 = vec_add(qv00, qv10);
- qv01 = vec_add(qv01, qv11);
+ vector signed int vsum0 = vec_add(vec_mule(qv00, vs0), vec_mulo(qv00, vs0));
+ vector signed int vsum1 = vec_add(vec_mule(qv01, vs2), vec_mulo(qv01, vs2));
+ vector signed int vsum2 = vec_add(vec_mule(qv02, vs4), vec_mulo(qv02, vs4));
+ vector signed int vsum3 = vec_add(vec_mule(qv03, vs6), vec_mulo(qv03, vs6));
+ vector signed int vsum4 = vec_add(vec_mule(qv10, vs1), vec_mulo(qv10, vs1));
+ vector signed int vsum5 = vec_add(vec_mule(qv11, vs3), vec_mulo(qv11, vs3));
+ vector signed int vsum6 = vec_add(vec_mule(qv12, vs5), vec_mulo(qv12, vs5));
+ vector signed int vsum7 = vec_add(vec_mule(qv13, vs7), vec_mulo(qv13, vs7));
- vsumi0 = vec_add(vec_mule(qv00, vs0), vsumi0);
- vsumi1 = vec_add(vec_mulo(qv00, vs0), vsumi1);
- vsumi2 = vec_add(vec_mule(qv01, vs1), vsumi2);
- vsumi3 = vec_add(vec_mulo(qv01, vs1), vsumi3);
+ vsumi0 = vec_add(vsum0, vsumi0);
+ vsumi1 = vec_add(vsum1, vsumi1);
+ vsumi2 = vec_add(vsum2, vsumi2);
+ vsumi3 = vec_add(vsum3, vsumi3);
+ vsumi4 = vec_add(vsum4, vsumi4);
+ vsumi5 = vec_add(vsum5, vsumi5);
+ vsumi6 = vec_add(vsum6, vsumi6);
+ vsumi7 = vec_add(vsum7, vsumi7);
}
+ vsumi0 = vec_add(vsumi0, vsumi4);
+ vsumi1 = vec_add(vsumi1, vsumi5);
+ vsumi2 = vec_add(vsumi2, vsumi6);
+ vsumi3 = vec_add(vsumi3, vsumi7);
+
vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
#elif defined __loongarch_asx
- const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
- const __m128i mzero = __lsx_vldi(0);
- const __m256i mone = __lasx_xvreplgr2vr_b(1);
+ const __m256i m3 = __lasx_xvreplgr2vr_b(3);
+ const __m256i mone = __lasx_xvreplgr2vr_b(1);
+ const __m128i m32 = __lsx_vreplgr2vr_b(32);
__m256 acc = (__m256)__lasx_xvldi(0);
- float summs = 0.f;
-
- for (int i = 0; i < nb; ++i) {
+ uint32_t aux[3];
- const uint8_t * restrict q5 = x[i].qs;
- const int8_t * restrict q8 = y[i].qs;
+ for (int i = 0; i < nb; ++i) {
-#if QK_K == 256
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
- const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
-
- memcpy(utmp, x[i].scales, 12);
-#else
- // TODO
- const float d = 0, dmin = 0;
-#endif
-
- const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]));
-
- const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0);
- const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1));
- const __m128i prod = lsx_madd_h(lasx_extracti128(mins_and_scales, 1), q8s);
- const __m128i hsum = lsx_hadd_w(lsx_hadd_w(prod, mzero), mzero);
- summs += dmin * __lsx_vpickve2gr_w(hsum, 0); //TODO check
-
- const __m128i sc128 = lasx_extracti128(mins_and_scales, 0);
- const __m256i scales = lasx_insertf128(sc128, sc128);
+ // Set up scales
+ memcpy(aux, x[i].scales, 12);
+ __m128i scales128 = lsx_set_w(
+ ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
+ ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
+ (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
+ (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
+ scales128 = __lsx_vsub_b(scales128, m32);
+ const __m256i all_scales = lasx_ext8_16(scales128);
+ const __m128i l_scales = lasx_extracti128(all_scales, 0);
+ const __m128i h_scales = lasx_extracti128(all_scales, 1);
+ const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)};
- const __m256i hbits = __lasx_xvld((const __m256i*)x[i].qh, 0);
- __m256i hmask = mone;
+ // high bit
+ const __m256i hbits = __lasx_xvld((const __m256i*)x[i].hmask, 0);
+ // integer accumulator
__m256i sumi = __lasx_xvldi(0);
int bit = 0;
+ int is = 0;
- for (int j = 0; j < QK_K/64; ++j) {
+ const uint8_t * restrict q3 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
- const __m256i scale_0 = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+0));
- const __m256i scale_1 = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+1));
+ for (int j = 0; j < QK_K/128; ++j) {
+ // load low 2 bits
+ const __m256i q3bits = __lasx_xvld((const __m256i*)q3, 0); q3 += 32;
- const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0); q5 += 32;
+ // prepare low and high bits
+ const __m256i q3l_0 = __lasx_xvand_v(q3bits, m3);
+ const __m256i q3h_0 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvandn_v(hbits, __lasx_xvslli_h(mone, bit)), bit), 2);
+ ++bit;
- const __m256i q5l_0 = __lasx_xvand_v(q5bits, m4);
- const __m256i q5h_0 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvand_v(hbits, hmask), bit++), 4);
- const __m256i q5_0 = __lasx_xvadd_b(q5l_0, q5h_0);
- hmask = __lasx_xvslli_h(hmask, 1);
+ const __m256i q3l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 2), m3);
+ const __m256i q3h_1 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvandn_v(hbits, __lasx_xvslli_h(mone, bit)), bit), 2);
+ ++bit;
- const __m256i q5l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q5bits, 4), m4);
- const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvand_v(hbits, hmask), bit++), 4);
- const __m256i q5_1 = __lasx_xvadd_b(q5l_1, q5h_1);
- hmask = __lasx_xvslli_h(hmask, 1);
+ const __m256i q3l_2 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 4), m3);
+ const __m256i q3h_2 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvandn_v(hbits, __lasx_xvslli_h(mone, bit)), bit), 2);
+ ++bit;
+
+ const __m256i q3l_3 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 6), m3);
+ const __m256i q3h_3 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvandn_v(hbits, __lasx_xvslli_h(mone, bit)), bit), 2);
+ ++bit;
+ // load Q8 quants
const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+ const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+ const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
- __m256i p16_0 = lasx_maddubs_h(q5_0, q8_0);
- __m256i p16_1 = lasx_maddubs_h(q5_1, q8_1);
+ // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use lasx_maddubs_h,
+ // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
+ // and 2 if the high bit was set)
+ __m256i q8s_0 = lasx_maddubs_h(q3h_0, q8_0);
+ __m256i q8s_1 = lasx_maddubs_h(q3h_1, q8_1);
+ __m256i q8s_2 = lasx_maddubs_h(q3h_2, q8_2);
+ __m256i q8s_3 = lasx_maddubs_h(q3h_3, q8_3);
- p16_0 = lasx_madd_h(scale_0, p16_0);
- p16_1 = lasx_madd_h(scale_1, p16_1);
+ __m256i p16_0 = lasx_maddubs_h(q3l_0, q8_0);
+ __m256i p16_1 = lasx_maddubs_h(q3l_1, q8_1);
+ __m256i p16_2 = lasx_maddubs_h(q3l_2, q8_2);
+ __m256i p16_3 = lasx_maddubs_h(q3l_3, q8_3);
- sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1));
- }
+ p16_0 = __lasx_xvsub_h(p16_0, q8s_0);
+ p16_1 = __lasx_xvsub_h(p16_1, q8s_1);
+ p16_2 = __lasx_xvsub_h(p16_2, q8s_2);
+ p16_3 = __lasx_xvsub_h(p16_3, q8s_3);
- __m256 vd = __lasx_xvreplfr2vr_s(d);
- acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc);
+ // multiply with scales
+ p16_0 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0);
+ p16_1 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1);
+ p16_2 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2);
+ p16_3 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3);
+
+ // accumulate
+ p16_0 = __lasx_xvadd_w(p16_0, p16_1);
+ p16_2 = __lasx_xvadd_w(p16_2, p16_3);
+ sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_2));
+ }
+ // multiply with block scale and accumulate
+ acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);//FIXME
}
- *s = hsum_float_8(acc) + summs;
+ *s = hsum_float_8(acc);
#else
-
- const uint8_t * scales = (const uint8_t*)&utmp[0];
- const uint8_t * mins = (const uint8_t*)&utmp[2];
+ // scalar version
+ // This function is written like this so the compiler can manage to vectorize most of it
+ // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the
+ // manually vectorized version above. Every other version I tried would run at least 4 times slower.
+ // The ideal situation would be if we could just write the code once, and the compiler would
+ // automatically produce the best possible set of machine instructions, instead of us having to manually
+ // write vectorized versions for AVX, ARM_NEON, etc.
int8_t aux8[QK_K];
int16_t aux16[8];
int32_t aux32[8];
memset(sums, 0, 8*sizeof(float));
+ uint32_t auxs[4];
+ const int8_t * scales = (const int8_t*)auxs;
+
float sumf = 0;
for (int i = 0; i < nb; ++i) {
- const uint8_t * restrict q4 = x[i].qs;
- const uint8_t * restrict hm = x[i].qh;
+ const uint8_t * restrict q3 = x[i].qs;
+ const uint8_t * restrict hm = x[i].hmask;
const int8_t * restrict q8 = y[i].qs;
memset(aux32, 0, 8*sizeof(int32_t));
int8_t * restrict a = aux8;
uint8_t m = 1;
- for (int j = 0; j < QK_K/64; ++j) {
- for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
- for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
+ for (int j = 0; j < QK_K; j += 128) {
+ for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3;
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
a += 32; m <<= 1;
- for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4);
- for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
+ for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3;
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
a += 32; m <<= 1;
- q4 += 32;
+ for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3;
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
+ a += 32; m <<= 1;
+ for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3;
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
+ a += 32; m <<= 1;
+ q3 += 32;
}
- memcpy(utmp, x[i].scales, 12);
- utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
- const uint32_t uaux = utmp[1] & kmask1;
- utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
- utmp[2] = uaux;
- utmp[0] &= kmask1;
-
- int sumi = 0;
- for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
a = aux8;
- int is = 0;
- for (int j = 0; j < QK_K/32; ++j) {
- int32_t scale = scales[is++];
- for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
- for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
- q8 += 8; a += 8;
- for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
- for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
- q8 += 8; a += 8;
+
+ memcpy(auxs, x[i].scales, 12);
+ uint32_t tmp = auxs[2];
+ auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
+ auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
+ auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
+ auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
+ for (int j = 0; j < QK_K/16; ++j) {
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
- for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+ for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
q8 += 8; a += 8;
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
- for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+ for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
q8 += 8; a += 8;
}
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
- const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
- sumf -= dmin * sumi;
}
for (int l = 0; l < 8; ++l) sumf += sums[l];
*s = sumf;
+
#endif
-}
-#else
+}
-void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(n % QK_K == 0);
assert(nrc == 1);
UNUSED(nrc);
UNUSED(by);
UNUSED(bs);
- const block_q5_K * restrict x = vx;
+ const block_q4_K * restrict x = vx;
const block_q8_K * restrict y = vy;
const int nb = n / QK_K;
+ static const uint32_t kmask1 = 0x3f3f3f3f;
+ static const uint32_t kmask2 = 0x0f0f0f0f;
+ static const uint32_t kmask3 = 0x03030303;
+
+ uint32_t utmp[4];
+
#ifdef __ARM_NEON
const uint8x16_t m4b = vdupq_n_u8(0xf);
- const uint8x16_t mh = vdupq_n_u8(16);
const int32x4_t mzero = vdupq_n_s32(0);
- ggml_int8x16x4_t q5bytes;
- ggml_uint8x16x4_t q5h;
+ ggml_int8x16x2_t q4bytes;
+ ggml_int8x16x2_t q8bytes;
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
- const int8_t * sc = x[i].scales;
+ const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
- const uint8_t * restrict q5 = x[i].qs;
- const uint8_t * restrict qh = x[i].qh;
+ const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
+
+ memcpy(utmp, x[i].scales, 12);
+
+ uint32x2_t mins8 = { 0 };
+ mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0);
+ mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1);
+
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+ utmp[0] &= kmask1;
+
+ const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
+ const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
+ vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
+ sumf -= dmin * vaddvq_s32(prod);
+
+ const uint8_t * scales = (const uint8_t *)utmp;
+
+ const uint8_t * restrict q4 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
- const uint8x8_t qhbits = vld1_u8(qh);
+ int32_t sumi1 = 0;
+ int32_t sumi2 = 0;
+
+ for (int j = 0; j < QK_K/64; ++j) {
+ const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); q4 += 32;
+
+ q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
+ q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
+ q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
+
+ const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
+ sumi1 += vaddvq_s32(p1) * scales[2*j+0];
- const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5);
- const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8);
+ q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
+ q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
+ q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
- const uint8x16_t htmp = vcombine_u8(qhbits, vshr_n_u8(qhbits, 1));
- q5h.val[0] = vbicq_u8(mh, vshlq_n_u8(htmp, 4));
- q5h.val[1] = vbicq_u8(mh, vshlq_n_u8(htmp, 2));
- q5h.val[2] = vbicq_u8(mh, htmp);
- q5h.val[3] = vbicq_u8(mh, vshrq_n_u8(htmp, 2));
+ const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
- q5bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[0], m4b)), vreinterpretq_s8_u8(q5h.val[0]));
- q5bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[1], m4b)), vreinterpretq_s8_u8(q5h.val[1]));
- q5bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[0], 4)), vreinterpretq_s8_u8(q5h.val[2]));
- q5bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[1], 4)), vreinterpretq_s8_u8(q5h.val[3]));
+ sumi2 += vaddvq_s32(p2) * scales[2*j+1];
+ }
- int32_t sumi1 = sc[0] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]));
- int32_t sumi2 = sc[1] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1]));
- int32_t sumi3 = sc[2] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]));
- int32_t sumi4 = sc[3] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3]));
+ sumf += d * (sumi1 + sumi2);
- sumf += d * (sumi1 + sumi2 + sumi3 + sumi4);
}
*s = sumf;
#elif defined __AVX2__
const __m256i m4 = _mm256_set1_epi8(0xF);
- const __m256i mone = _mm256_set1_epi8(1);
__m256 acc = _mm256_setzero_ps();
+ __m128 acc_m = _mm_setzero_ps();
- for (int i = 0; i < nb; ++i) {
+ for (int i = 0; i < nb; ++i) {
- const uint8_t * restrict q5 = x[i].qs;
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+ memcpy(utmp, x[i].scales, 12);
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+ const uint32_t uaux = utmp[1] & kmask1;
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+ utmp[2] = uaux;
+ utmp[0] &= kmask1;
+
+ const uint8_t * restrict q4 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
- const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
+
+ const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
+ const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
+ const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
+ acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m);
- const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5);
+ const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
+ const __m256i scales = MM256_SET_M128I(sc128, sc128);
- const __m256i scale_l = MM256_SET_M128I(_mm_set1_epi16(x[i].scales[1]), _mm_set1_epi16(x[i].scales[0]));
- const __m256i scale_h = MM256_SET_M128I(_mm_set1_epi16(x[i].scales[3]), _mm_set1_epi16(x[i].scales[2]));
+ __m256i sumi = _mm256_setzero_si256();
- int64_t aux64;
- memcpy(&aux64, x[i].qh, 8);
- const __m128i haux128 = _mm_set_epi64x(aux64 >> 1, aux64);
- const __m256i haux256 = MM256_SET_M128I(_mm_srli_epi16(haux128, 2), haux128);
+ for (int j = 0; j < QK_K/64; ++j) {
- const __m256i q5h_0 = _mm256_slli_epi16(_mm256_andnot_si256(haux256, mone), 4);
- const __m256i q5h_1 = _mm256_slli_epi16(_mm256_andnot_si256(_mm256_srli_epi16(haux256, 4), mone), 4);
+ const __m256i scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0));
+ const __m256i scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1));
- const __m256i q5l_0 = _mm256_and_si256(q5bits, m4);
- const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
+ const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
+ const __m256i q4l = _mm256_and_si256(q4bits, m4);
+ const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);
- const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
- const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+ const __m256i q8l = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+ __m256i p16l = _mm256_maddubs_epi16(q4l, q8l);
+ p16l = _mm256_madd_epi16(scale_l, p16l);
- const __m256i p16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5l_0, q8_0));
- const __m256i p16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5l_1, q8_1));
- const __m256i s16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5h_0, q8_0));
- const __m256i s16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5h_1, q8_1));
+ const __m256i q8h = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+ __m256i p16h = _mm256_maddubs_epi16(q4h, q8h);
+ p16h = _mm256_madd_epi16(scale_h, p16h);
+ const __m256i sumj = _mm256_add_epi32(p16l, p16h);
- const __m256i dot = _mm256_sub_epi32(_mm256_add_epi32(p16_0, p16_1), _mm256_add_epi32(s16_0, s16_1));
+ sumi = _mm256_add_epi32(sumi, sumj);
+ }
- acc = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(dot), acc);
+ __m256 vd = _mm256_set1_ps(d);
+ acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
}
- *s = hsum_float_8(acc);
+ acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
+ acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
+
+ *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);
#elif defined __AVX__
const __m128i m4 = _mm_set1_epi8(0xF);
- const __m128i mone = _mm_set1_epi8(1);
+ const __m128i m2 = _mm_set1_epi8(0x2);
__m256 acc = _mm256_setzero_ps();
+ __m128 acc_m = _mm_setzero_ps();
- for (int i = 0; i < nb; ++i) {
+ for (int i = 0; i < nb; ++i) {
- const uint8_t * restrict q5 = x[i].qs;
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+ const uint8_t * restrict q4 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
- const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ memcpy(utmp, x[i].scales, 12);
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+ const uint32_t uaux = utmp[1] & kmask1;
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+ utmp[2] = uaux;
+ utmp[0] &= kmask1;
+
+ const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]);
+ const __m128i scales = _mm_cvtepu8_epi16(utmps);
+ const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps));
- const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5);
+ const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]);
+ const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]);
+ const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1);
+ const __m128i prod = _mm_madd_epi16(mins, q8s);
+ acc_m = _mm_add_ps(_mm_mul_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod)), acc_m);
- const __m128i scale_0 = _mm_set1_epi16(x[i].scales[0]);
- const __m128i scale_1 = _mm_set1_epi16(x[i].scales[1]);
- const __m128i scale_2 = _mm_set1_epi16(x[i].scales[2]);
- const __m128i scale_3 = _mm_set1_epi16(x[i].scales[3]);
+ __m128i sumi_0 = _mm_setzero_si128();
+ __m128i sumi_1 = _mm_setzero_si128();
- int64_t aux64;
- memcpy(&aux64, x[i].qh, 8);
- const __m128i haux128_0 = _mm_set_epi64x(aux64 >> 1, aux64);
- const __m128i haux128_1 = _mm_srli_epi16(haux128_0, 2);
+ __m128i shuffle = _mm_set1_epi16(0x0100);
+ for (int j = 0; j < QK_K/64; ++j) {
- const __m128i q5h_0 = _mm_slli_epi16(_mm_andnot_si128(haux128_0, mone), 4);
- const __m128i q5h_1 = _mm_slli_epi16(_mm_andnot_si128(haux128_1, mone), 4);
- const __m128i q5h_2 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_0, 4), mone), 4);
- const __m128i q5h_3 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_1, 4), mone), 4);
+ const __m128i scale_l = _mm_shuffle_epi8(scales, shuffle);
+ shuffle = _mm_add_epi16(shuffle, m2);
+ const __m128i scale_h = _mm_shuffle_epi8(scales, shuffle);
+ shuffle = _mm_add_epi16(shuffle, m2);
- const __m128i q5l_0 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 0), m4);
- const __m128i q5l_1 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 1), m4);
- const __m128i q5l_2 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 0), 4), m4);
- const __m128i q5l_3 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 1), 4), m4);
+ __m128i q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+ const __m128i q4l_0 = _mm_and_si128(q4bits, m4);
+ const __m128i q4h_0 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4);
+ q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+ const __m128i q4l_1 = _mm_and_si128(q4bits, m4);
+ const __m128i q4h_1 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4);
- const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
- const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+ const __m128i q8l_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ __m128i p16l = _mm_maddubs_epi16(q4l_0, q8l_0);
+ p16l = _mm_madd_epi16(scale_l, p16l);
+ sumi_0 = _mm_add_epi32(sumi_0, p16l);
+ const __m128i q8l_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ p16l = _mm_maddubs_epi16(q4l_1, q8l_1);
+ p16l = _mm_madd_epi16(scale_l, p16l);
+ sumi_1 = _mm_add_epi32(sumi_1, p16l);
- const __m128i p16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5l_0, _mm256_extractf128_si256(q8_0, 0)));
- const __m128i p16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5l_1, _mm256_extractf128_si256(q8_0, 1)));
- const __m128i p16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5l_2, _mm256_extractf128_si256(q8_1, 0)));
- const __m128i p16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5l_3, _mm256_extractf128_si256(q8_1, 1)));
- const __m128i s16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5h_0, _mm256_extractf128_si256(q8_0, 0)));
- const __m128i s16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5h_1, _mm256_extractf128_si256(q8_0, 1)));
- const __m128i s16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5h_2, _mm256_extractf128_si256(q8_1, 0)));
- const __m128i s16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5h_3, _mm256_extractf128_si256(q8_1, 1)));
+ const __m128i q8h_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ __m128i p16h = _mm_maddubs_epi16(q4h_0, q8h_0);
+ p16h = _mm_madd_epi16(scale_h, p16h);
+ sumi_0 = _mm_add_epi32(sumi_0, p16h);
+ const __m128i q8h_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ p16h = _mm_maddubs_epi16(q4h_1, q8h_1);
+ p16h = _mm_madd_epi16(scale_h, p16h);
+ sumi_1 = _mm_add_epi32(sumi_1, p16h);
- const __m128i dot_0 = _mm_sub_epi32(_mm_add_epi32(p16_0, p16_2), _mm_add_epi32(s16_0, s16_2));
- const __m128i dot_1 = _mm_sub_epi32(_mm_add_epi32(p16_1, p16_3), _mm_add_epi32(s16_1, s16_3));
+ }
- acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(dot_1, dot_0))), acc);
+ __m256 vd = _mm256_set1_ps(d);
+ __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
+ acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
}
- *s = hsum_float_8(acc);
+ acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
+ acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
+
+ *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);
#elif defined __riscv_v_intrinsic
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
+ const uint8_t * mins = (const uint8_t*)&utmp[2];
+
float sumf = 0;
for (int i = 0; i < nb; ++i) {
+ size_t vl = 8;
+
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
- const int8_t * sc = x[i].scales;
+ const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
- const uint8_t * restrict q5 = x[i].qs;
- const uint8_t * restrict qh = x[i].qh;
+ vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);
+ vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl);
+ vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl);
+
+ memcpy(utmp, x[i].scales, 12);
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+ const uint32_t uaux = utmp[1] & kmask1;
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+ utmp[2] = uaux;
+ utmp[0] &= kmask1;
+
+ vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl);
+ vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl));
+ vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl);
+
+ vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
+ sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);
+
+ const uint8_t * restrict q4 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
- vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+ vl = 32;
- // load qh
- vuint8mf4_t qh_x1 = __riscv_vle8_v_u8mf4(qh, 8);
- vuint8mf2_t qh_x2 = __riscv_vlmul_ext_v_u8mf4_u8mf2(__riscv_vsrl_vx_u8mf4(qh_x1, 1, 8));
+ int32_t sum_1 = 0;
+ int32_t sum_2 = 0;
- size_t vl = 16;
+ vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1);
+
+ for (int j = 0; j < QK_K/64; ++j) {
+ // load Q4
+ vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl);
+
+ // load Q8 and multiply it with lower Q4 nibble
+ vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl);
+ vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl));
+ vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl);
+ vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl);
+
+ sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0];
+
+ // load Q8 and multiply it with upper Q4 nibble
+ vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl);
+ vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl));
+ vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl);
+ vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl);
+
+ sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1];
+
+ q4 += 32; q8 += 64;
+
+ }
+
+ sumf += d*(sum_1 + sum_2);
+
+ }
+
+ *s = sumf;
+
+#elif defined(__POWER9_VECTOR__)
+ const vector signed char lowMask = vec_splats((signed char)0xF);
+ const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+
+ vector float vsumf0 = vec_splats(0.0f);
+ vector float vsumf1 = vec_splats(0.0f);
+ vector float vsumf2 = vec_splats(0.0f);
+ vector float vsumf3 = vec_splats(0.0f);
+
+ for (int i = 0; i < nb; ++i) {
+ vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
+ vector float vyd = vec_splats(y[i].d);
+ vector float vd = vec_mul(vxd, vyd);
+
+ vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[i].dmin));
+ vector float vdmin = vec_mul(vxmin, vyd);
+
+ vector signed short q8ysums0 = vec_xl( 0, y[i].bsums);
+ vector signed short q8ysums1 = vec_xl(16, y[i].bsums);
- // combine both qh_1 and qh_2
- vuint8mf2_t qh_x = __riscv_vslideup_vx_u8mf2(__riscv_vlmul_ext_v_u8mf4_u8mf2(qh_x1), qh_x2, vl/2, vl);
+ memcpy(utmp, x[i].scales, 12);
- vuint8mf2_t qh_h0 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x4, vl), vl), 16, vl);
- vuint8mf2_t qh_h1 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x2, vl), vl), 16, vl);
- vuint8mf2_t qh_h2 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(qh_x, vl), 16, vl);
- vuint8mf2_t qh_h3 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x4, vl), vl), 16, vl);
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+ const uint32_t uaux = utmp[1] & kmask1;
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+ utmp[2] = uaux;
+ utmp[0] &= kmask1;
- vint8mf2_t qh_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h0);
- vint8mf2_t qh_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h1);
- vint8mf2_t qh_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h2);
- vint8mf2_t qh_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h3);
+ vector signed char utmps = (vector signed char)vec_xl( 0, utmp);
+ vector signed short vscales = vec_unpackh(utmps);
+ vector signed short q4xmins = vec_unpackl(utmps);
+ vector signed short q4xmins0 = vec_mergeh(q4xmins, q4xmins);
+ vector signed short q4xmins1 = vec_mergel(q4xmins, q4xmins);
- // load q5
- vuint8mf2_t q5_x1 = __riscv_vle8_v_u8mf2(q5, vl);
- vuint8mf2_t q5_x2 = __riscv_vle8_v_u8mf2(q5+16, vl);
+ vector signed int prod0 = vec_mule(q4xmins0, q8ysums0);
+ vector signed int prod1 = vec_mule(q4xmins1, q8ysums1);
+ vector signed int prod2 = vec_mulo(q4xmins0, q8ysums0);
+ vector signed int prod3 = vec_mulo(q4xmins1, q8ysums1);
- vint8mf2_t q5s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q5_x1, 0xF, vl));
- vint8mf2_t q5s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q5_x2, 0xF, vl));
- vint8mf2_t q5s_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(q5_x1, 0x4, vl));
- vint8mf2_t q5s_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(q5_x2, 0x4, vl));
+ vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0);
+ vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1);
+ vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2);
+ vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3);
- vint8mf2_t q5_0 = __riscv_vsub_vv_i8mf2(q5s_0, qh_0, vl);
- vint8mf2_t q5_1 = __riscv_vsub_vv_i8mf2(q5s_1, qh_1, vl);
- vint8mf2_t q5_2 = __riscv_vsub_vv_i8mf2(q5s_2, qh_2, vl);
- vint8mf2_t q5_3 = __riscv_vsub_vv_i8mf2(q5s_3, qh_3, vl);
+ vector signed int vsumi0 = vec_splats((int32_t)0);
+ vector signed int vsumi1 = vec_splats((int32_t)0);
+ vector signed int vsumi2 = vec_splats((int32_t)0);
+ vector signed int vsumi3 = vec_splats((int32_t)0);
+ vector signed int vsumi4 = vec_splats((int32_t)0);
+ vector signed int vsumi5 = vec_splats((int32_t)0);
+ vector signed int vsumi6 = vec_splats((int32_t)0);
+ vector signed int vsumi7 = vec_splats((int32_t)0);
- // load Q8 and multiply it with Q5
- vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q5_0, __riscv_vle8_v_i8mf2(q8, vl), vl);
- vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q5_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl);
- vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q5_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl);
- vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q5_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl);
+ const uint8_t * restrict q4 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
- vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl);
- vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl);
- vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl);
- vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl);
+ for (int j = 0; j < QK_K/64; j+=2) {
+ __builtin_prefetch(q4, 0, 1);
+ __builtin_prefetch(q8, 0, 1);
- int32_t sumi1 = sc[0] * __riscv_vmv_x_s_i32m1_i32(vs_0);
- int32_t sumi2 = sc[1] * __riscv_vmv_x_s_i32m1_i32(vs_1);
- int32_t sumi3 = sc[2] * __riscv_vmv_x_s_i32m1_i32(vs_2);
- int32_t sumi4 = sc[3] * __riscv_vmv_x_s_i32m1_i32(vs_3);
+ vector signed char qxs0 = (vector signed char)vec_xl( 0, q4);
+ vector signed char qxs1 = (vector signed char)vec_xl(16, q4);
+ vector signed char qxs2 = (vector signed char)vec_xl(32, q4);
+ vector signed char qxs3 = (vector signed char)vec_xl(48, q4);
+ q4 += 64;
- sumf += d * (sumi1 + sumi2 + sumi3 + sumi4);
+ vector signed char q4x00 = vec_and(qxs0, lowMask);
+ vector signed char q4x01 = vec_sr(qxs0, v4);
+ vector signed char q4x10 = vec_and(qxs1, lowMask);
+ vector signed char q4x11 = vec_sr(qxs1, v4);
+ vector signed char q4x20 = vec_and(qxs2, lowMask);
+ vector signed char q4x21 = vec_sr(qxs2, v4);
+ vector signed char q4x30 = vec_and(qxs3, lowMask);
+ vector signed char q4x31 = vec_sr(qxs3, v4);
- }
+ vector signed char q8y00 = vec_xl( 0, q8);
+ vector signed char q8y10 = vec_xl( 16, q8);
+ vector signed char q8y01 = vec_xl( 32, q8);
+ vector signed char q8y11 = vec_xl( 48, q8);
+ vector signed char q8y20 = vec_xl( 64, q8);
+ vector signed char q8y30 = vec_xl( 80, q8);
+ vector signed char q8y21 = vec_xl( 96, q8);
+ vector signed char q8y31 = vec_xl(112, q8);
+ q8 += 128;
- *s = sumf;
+ vector signed short qv00 = vec_add(vec_mule(q4x00, q8y00), vec_mulo(q4x00, q8y00));
+ vector signed short qv01 = vec_add(vec_mule(q4x01, q8y01), vec_mulo(q4x01, q8y01));
+ vector signed short qv10 = vec_add(vec_mule(q4x10, q8y10), vec_mulo(q4x10, q8y10));
+ vector signed short qv11 = vec_add(vec_mule(q4x11, q8y11), vec_mulo(q4x11, q8y11));
+ vector signed short qv20 = vec_add(vec_mule(q4x20, q8y20), vec_mulo(q4x20, q8y20));
+ vector signed short qv21 = vec_add(vec_mule(q4x21, q8y21), vec_mulo(q4x21, q8y21));
+ vector signed short qv30 = vec_add(vec_mule(q4x30, q8y30), vec_mulo(q4x30, q8y30));
+ vector signed short qv31 = vec_add(vec_mule(q4x31, q8y31), vec_mulo(q4x31, q8y31));
-#elif defined(__POWER9_VECTOR__)
- const vector signed char lowMask = vec_splats((signed char)0xF);
- const vector unsigned char v1 = vec_splats((unsigned char)0x1);
- const vector unsigned char v2 = vec_splats((unsigned char)0x2);
- const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+ vector signed short vs0 = vec_splat(vscales, 0);
+ vector signed short vs1 = vec_splat(vscales, 1);
+ vector signed short vs2 = vec_splat(vscales, 2);
+ vector signed short vs3 = vec_splat(vscales, 3);
+ vscales = vec_sld(vscales, vscales, 8);
- vector float vsumf0 = vec_splats(0.0f);
- vector float vsumf1 = vec_splats(0.0f);
- vector float vsumf2 = vec_splats(0.0f);
- vector float vsumf3 = vec_splats(0.0f);
+ qv00 = vec_add(qv00, qv10);
+ qv10 = vec_add(qv01, qv11);
+ qv20 = vec_add(qv20, qv30);
+ qv30 = vec_add(qv21, qv31);
-#pragma GCC unroll 2
- for (int i = 0; i < nb; ++i) {
- __builtin_prefetch(x[i].qs, 0, 1);
- __builtin_prefetch(y[i].qs, 0, 1);
+ vsumi0 = vec_add(vec_mule(qv00, vs0), vsumi0);
+ vsumi1 = vec_add(vec_mulo(qv00, vs0), vsumi1);
+ vsumi2 = vec_add(vec_mule(qv10, vs1), vsumi2);
+ vsumi3 = vec_add(vec_mulo(qv10, vs1), vsumi3);
+ vsumi4 = vec_add(vec_mule(qv20, vs2), vsumi4);
+ vsumi5 = vec_add(vec_mulo(qv20, vs2), vsumi5);
+ vsumi6 = vec_add(vec_mule(qv30, vs3), vsumi6);
+ vsumi7 = vec_add(vec_mulo(qv30, vs3), vsumi7);
+ }
- vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
- vector float vyd = vec_splats(y[i].d);
- vector float vd= vec_mul(vxd, vyd);
-
- vector signed char qxs0 = (vector signed char)vec_xl( 0, x[i].qs);
- vector signed char qxs1 = (vector signed char)vec_xl(16, x[i].qs);
- vector signed char qxs00 = (vector signed char)vec_and(qxs0, lowMask);
- vector signed char qxs01 = (vector signed char)vec_sr(qxs0, v4);
- vector signed char qxs10 = (vector signed char)vec_and(qxs1, lowMask);
- vector signed char qxs11 = (vector signed char)vec_sr(qxs1, v4);
-
- vector signed char qxhs = (vector signed char)vec_xl_len(x[i].qh, 8);
- vector signed char qxhs0 = vec_or(qxhs, vec_sr(vec_sld(qxhs, qxhs, 8), v1));
- vector signed char qxhs1 = vec_sr(qxhs0, v2);
- vector signed char qxh00 = vec_sl(vec_andc((vector signed char)v1, qxhs0), v4);
- vector signed char qxh10 = vec_sl(vec_andc((vector signed char)v1, qxhs1), v4);
- vector signed char qxh01 = vec_sl(vec_andc((vector signed char)v1, vec_sr(qxhs0, v4)), v4);
- vector signed char qxh11 = vec_sl(vec_andc((vector signed char)v1, vec_sr(qxhs1, v4)), v4);
-
- vector signed char q5x00 = vec_sub(qxs00, qxh00);
- vector signed char q5x10 = vec_sub(qxs10, qxh10);
- vector signed char q5x01 = vec_sub(qxs01, qxh01);
- vector signed char q5x11 = vec_sub(qxs11, qxh11);
-
- vector signed char q8y00 = vec_xl( 0, y[i].qs);
- vector signed char q8y10 = vec_xl(16, y[i].qs);
- vector signed char q8y01 = vec_xl(32, y[i].qs);
- vector signed char q8y11 = vec_xl(48, y[i].qs);
-
- vector signed short qv00 = vec_add(vec_mule(q5x00, q8y00), vec_mulo(q5x00, q8y00));
- vector signed short qv01 = vec_add(vec_mule(q5x01, q8y01), vec_mulo(q5x01, q8y01));
- vector signed short qv10 = vec_add(vec_mule(q5x10, q8y10), vec_mulo(q5x10, q8y10));
- vector signed short qv11 = vec_add(vec_mule(q5x11, q8y11), vec_mulo(q5x11, q8y11));
-
- vector signed short vs = (vector signed short)vec_unpackh(vec_xl_len(x[i].scales, 4));
- vector signed short vs0 = vec_splat(vs, 0);
- vector signed short vs1 = vec_splat(vs, 1);
- vector signed short vs2 = vec_splat(vs, 2);
- vector signed short vs3 = vec_splat(vs, 3);
-
- vector signed int vsumi0 = vec_add(vec_mule(qv00, vs0), vec_mulo(qv00, vs0));
- vector signed int vsumi1 = vec_add(vec_mule(qv10, vs1), vec_mulo(qv10, vs1));
- vector signed int vsumi2 = vec_add(vec_mule(qv01, vs2), vec_mulo(qv01, vs2));
- vector signed int vsumi3 = vec_add(vec_mule(qv11, vs3), vec_mulo(qv11, vs3));
+ vsumi0 = vec_add(vsumi0, vsumi4);
+ vsumi1 = vec_add(vsumi1, vsumi5);
+ vsumi2 = vec_add(vsumi2, vsumi6);
+ vsumi3 = vec_add(vsumi3, vsumi7);
vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
#elif defined __loongarch_asx
const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
- const __m256i mone = __lasx_xvreplgr2vr_b(1);
__m256 acc = (__m256)__lasx_xvldi(0);
+ __m128 acc_m = (__m128)__lsx_vldi(0);
- for (int i = 0; i < nb; ++i) {
+ for (int i = 0; i < nb; ++i) {
- const uint8_t * restrict q5 = x[i].qs;
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+ memcpy(utmp, x[i].scales, 12);
+
+ const uint8_t * restrict q4 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
- const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]));
+
+ const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0);
+ const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1));
+ const __m128i prod = lsx_madd_h(lasx_extracti128(mins_and_scales, 1), q8s);
+ acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m);
- const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0);
+ const __m128i sc128 = lasx_extracti128(mins_and_scales, 0);
+ const __m256i scales = lasx_insertf128(sc128, sc128);
- const __m256i scale_l = lasx_insertf128(__lsx_vreplgr2vr_h(x[i].scales[1]), __lsx_vreplgr2vr_h(x[i].scales[0]));
- const __m256i scale_h = lasx_insertf128(__lsx_vreplgr2vr_h(x[i].scales[3]), __lsx_vreplgr2vr_h(x[i].scales[2]));
+ __m256i sumi = __lasx_xvldi(0);
- int64_t aux64;
- memcpy(&aux64, x[i].qh, 8);
- __m128i haux128 = __lsx_vinsgr2vr_d(haux128, aux64, 0);
- haux128 = __lsx_vinsgr2vr_d(haux128, aux64 >> 1, 1);
- const __m256i haux256 = lasx_insertf128(__lsx_vsrli_h(haux128, 2), haux128);
+ for (int j = 0; j < QK_K/64; ++j) {
- const __m256i q5h_0 = __lasx_xvslli_h(__lasx_xvandn_v(haux256, mone), 4);
- const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvandn_v(__lasx_xvsrli_h(haux256, 4), mone), 4);
+ const __m256i scale_l = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+0));
+ const __m256i scale_h = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+1));
- const __m256i q5l_0 = __lasx_xvand_v(q5bits, m4);
- const __m256i q5l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q5bits, 4), m4);
+ const __m256i q4bits = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
+ const __m256i q4l = __lasx_xvand_v(q4bits, m4);
+ const __m256i q4h = __lasx_xvand_v(__lasx_xvsrli_h(q4bits, 4), m4);
- const __m256i q8_0 = __lasx_xvld((const __m256i*)(q8+ 0), 0);
- const __m256i q8_1 = __lasx_xvld((const __m256i*)(q8+32), 0);
+ const __m256i q8l = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+ __m256i p16l = lasx_maddubs_h(q4l, q8l);
+ p16l = lasx_madd_h(scale_l, p16l);
- const __m256i p16_0 = lasx_madd_h(scale_l, lasx_maddubs_h(q5l_0, q8_0));
- const __m256i p16_1 = lasx_madd_h(scale_h, lasx_maddubs_h(q5l_1, q8_1));
- const __m256i s16_0 = lasx_madd_h(scale_l, lasx_maddubs_h(q5h_0, q8_0));
- const __m256i s16_1 = lasx_madd_h(scale_h, lasx_maddubs_h(q5h_1, q8_1));
+ const __m256i q8h = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+ __m256i p16h = lasx_maddubs_h(q4h, q8h);
+ p16h = lasx_madd_h(scale_h, p16h);
+ const __m256i sumj = __lasx_xvadd_w(p16l, p16h);
- const __m256i dot = __lasx_xvsub_w(__lasx_xvadd_w(p16_0, p16_1), __lasx_xvadd_w(s16_0, s16_1));
+ sumi = __lasx_xvadd_w(sumi, sumj);
+ }
- acc = __lasx_xvfmadd_s((__m256)__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(dot), acc);
+ __m256 vd = __lasx_xvreplfr2vr_s(d);
+ acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc);
}
- *s = hsum_float_8(acc);
+ acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vpermi_w((__m128i)acc_m, (__m128i)acc_m, 0xee));
+ __m128i tmp1 = __lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w((__m128i)acc_m, 1), 0);
+ acc_m = __lsx_vfadd_s(acc_m, (__m128)tmp1);
+
+ ft_union fi;
+ fi.i = __lsx_vpickve2gr_w(acc_m, 0);
+ *s = hsum_float_8(acc) + fi.f ;
#else
- int8_t aux8[QK_K];
- int16_t aux16[16];
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
+ const uint8_t * mins = (const uint8_t*)&utmp[2];
+
+ int8_t aux8[QK_K];
+ int16_t aux16[8];
float sums [8];
+ int32_t aux32[8];
memset(sums, 0, 8*sizeof(float));
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const uint8_t * restrict q4 = x[i].qs;
- const uint8_t * restrict hm = x[i].qh;
const int8_t * restrict q8 = y[i].qs;
+ memset(aux32, 0, 8*sizeof(int32_t));
int8_t * restrict a = aux8;
- for (int l = 0; l < 32; ++l) {
- a[l+ 0] = q4[l] & 0xF;
- a[l+32] = q4[l] >> 4;
- }
- for (int is = 0; is < 8; ++is) {
- uint8_t m = 1 << is;
- for (int l = 0; l < 8; ++l) a[8*is + l] -= (hm[l] & m ? 0 : 16);
+ for (int j = 0; j < QK_K/64; ++j) {
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
+ a += 32;
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4);
+ a += 32; q4 += 32;
}
+ memcpy(utmp, x[i].scales, 12);
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+ const uint32_t uaux = utmp[1] & kmask1;
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+ utmp[2] = uaux;
+ utmp[0] &= kmask1;
- const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
- const int8_t * restrict sc = x[i].scales;
-
- for (int j = 0; j < QK_K/16; ++j) {
- const float dl = d * sc[j];
- for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l];
- for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[8+l]);
- q8 += 16; a += 16;
+ int sumi = 0;
+ for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
+ a = aux8;
+ int is = 0;
+ for (int j = 0; j < QK_K/32; ++j) {
+ int32_t scale = scales[is++];
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+ q8 += 8; a += 8;
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+ q8 += 8; a += 8;
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+ q8 += 8; a += 8;
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+ q8 += 8; a += 8;
}
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
+ const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
+ sumf -= dmin * sumi;
}
for (int l = 0; l < 8; ++l) sumf += sums[l];
*s = sumf;
#endif
}
-#endif
-
-#if QK_K == 256
-void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(n % QK_K == 0);
assert(nrc == 1);
UNUSED(nrc);
UNUSED(by);
UNUSED(bs);
- const block_q6_K * restrict x = vx;
+ const block_q5_K * restrict x = vx;
const block_q8_K * restrict y = vy;
const int nb = n / QK_K;
-#ifdef __ARM_NEON
- float sum = 0;
-
- const uint8x16_t m4b = vdupq_n_u8(0xF);
- const int32x4_t vzero = vdupq_n_s32(0);
- //const int8x16_t m32s = vdupq_n_s8(32);
+ static const uint32_t kmask1 = 0x3f3f3f3f;
+ static const uint32_t kmask2 = 0x0f0f0f0f;
+ static const uint32_t kmask3 = 0x03030303;
- const uint8x16_t mone = vdupq_n_u8(3);
+ uint32_t utmp[4];
- ggml_int8x16x4_t q6bytes;
- ggml_uint8x16x4_t q6h;
+#ifdef __ARM_NEON
+ const uint8x16_t m4b = vdupq_n_u8(0xf);
+ const uint8x16_t mone = vdupq_n_u8(1);
+ const uint8x16_t mtwo = vdupq_n_u8(2);
+ const int32x4_t mzero = vdupq_n_s32(0);
- for (int i = 0; i < nb; ++i) {
+ ggml_int8x16x4_t q5bytes;
- const float d_all = GGML_FP16_TO_FP32(x[i].d);
+ float sumf = 0;
- const uint8_t * restrict q6 = x[i].ql;
- const uint8_t * restrict qh = x[i].qh;
- const int8_t * restrict q8 = y[i].qs;
+ for (int i = 0; i < nb; ++i) {
- const int8_t * restrict scale = x[i].scales;
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
- const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
- const int8x16_t scales = vld1q_s8(scale);
- const ggml_int16x8x2_t q6scales = {{vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}};
+ const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
- const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])),
- vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))),
- vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[1]), vget_low_s16 (q6scales.val[1])),
- vmull_s16(vget_high_s16(q8sums.val[1]), vget_high_s16(q6scales.val[1]))));
- int32_t isum_mins = vaddvq_s32(prod);
+ memcpy(utmp, x[i].scales, 12);
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+ const uint32_t uaux = utmp[1] & kmask1;
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+ utmp[2] = uaux;
+ utmp[0] &= kmask1;
- int32_t isum = 0;
+ const uint8x8_t mins8 = vld1_u8((const uint8_t*)utmp + 8);
+ const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(mins8));
+ const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
+ vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
+ int32_t sumi_mins = vaddvq_s32(prod);
- for (int j = 0; j < QK_K/128; ++j) {
+ const uint8_t * scales = (const uint8_t *)utmp;
- ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); qh += 32;
- ggml_uint8x16x4_t q6bits = ggml_vld1q_u8_x4(q6); q6 += 64;
- ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
+ const uint8_t * restrict q5 = x[i].qs;
+ const uint8_t * restrict qh = x[i].qh;
+ const int8_t * restrict q8 = y[i].qs;
- q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
- q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
- uint8x16_t shifted = vshrq_n_u8(qhbits.val[0], 2);
- q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
- shifted = vshrq_n_u8(qhbits.val[1], 2);
- q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+ ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh);
- //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s);
- //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s);
- //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])), m32s);
- //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])), m32s);
- q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0]));
- q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1]));
- q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2]));
- q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3]));
+ ggml_uint8x16x4_t q5h;
- isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
- vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
- vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
- vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
+ int32_t sumi = 0;
- scale += 4;
+ for (int j = 0; j < QK_K/64; ++j) {
- q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
+ const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5); q5 += 32;
+ const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
- shifted = vshrq_n_u8(qhbits.val[0], 4);
- q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
- shifted = vshrq_n_u8(qhbits.val[1], 4);
- q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
- shifted = vshrq_n_u8(qhbits.val[0], 6);
- q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
- shifted = vshrq_n_u8(qhbits.val[1], 6);
- q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+ q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
+ q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
+ q5h.val[2] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[0]), 3);
+ q5h.val[3] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[1]), 3);
+ qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 2);
+ qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 2);
- //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])), m32s);
- //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])), m32s);
- //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])), m32s);
- //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])), m32s);
- q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0]));
- q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1]));
- q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2]));
- q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3]));
+ q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0]));
+ q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1]));
+ q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2]));
+ q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3]));
- isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
- vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
- vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
- vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
- scale += 4;
+ sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++;
+ sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++;
}
- //sum += isum * d_all * y[i].d;
- sum += d_all * y[i].d * (isum - 32 * isum_mins);
+ sumf += d * sumi - dmin * sumi_mins;
}
- *s = sum;
+
+ *s = sumf;
#elif defined __AVX2__
const __m256i m4 = _mm256_set1_epi8(0xF);
- const __m256i m2 = _mm256_set1_epi8(3);
- const __m256i m32s = _mm256_set1_epi8(32);
+ const __m128i mzero = _mm_setzero_si128();
+ const __m256i mone = _mm256_set1_epi8(1);
__m256 acc = _mm256_setzero_ps();
+ float summs = 0.f;
+
for (int i = 0; i < nb; ++i) {
+ const uint8_t * restrict q5 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
- const uint8_t * restrict q4 = x[i].ql;
- const uint8_t * restrict qh = x[i].qh;
- const int8_t * restrict q8 = y[i].qs;
+ memcpy(utmp, x[i].scales, 12);
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+ const uint32_t uaux = utmp[1] & kmask1;
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+ utmp[2] = uaux;
+ utmp[0] &= kmask1;
- const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
+ const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
+
+ const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
+ const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
+ const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
+ const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
+ summs += dmin * _mm_extract_epi32(hsum, 0);
+
+ const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
+ const __m256i scales = MM256_SET_M128I(sc128, sc128);
+
+ const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].qh);
+ __m256i hmask = mone;
__m256i sumi = _mm256_setzero_si256();
- int is = 0;
+ int bit = 0;
- for (int j = 0; j < QK_K/128; ++j) {
+ for (int j = 0; j < QK_K/64; ++j) {
- const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0));
- const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
- const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
- const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));
- is += 4;
+ const __m256i scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0));
+ const __m256i scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1));
- const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
- const __m256i q4bits2 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
- const __m256i q4bitsH = _mm256_loadu_si256((const __m256i*)qh); qh += 32;
+ const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); q5 += 32;
- const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4);
- const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4);
- const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4);
- const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4);
+ const __m256i q5l_0 = _mm256_and_si256(q5bits, m4);
+ const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
+ const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0);
+ hmask = _mm256_slli_epi16(hmask, 1);
- const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
- const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1);
- const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2);
- const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3);
+ const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
+ const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
+ const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1);
+ hmask = _mm256_slli_epi16(hmask, 1);
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
- const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
- const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
-
- __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0);
- __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1);
- __m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2);
- __m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3);
-
- __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0);
- __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1);
- __m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2);
- __m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3);
- p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
- p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
- p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
- p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
+ __m256i p16_0 = _mm256_maddubs_epi16(q5_0, q8_0);
+ __m256i p16_1 = _mm256_maddubs_epi16(q5_1, q8_1);
- p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0);
- p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1);
- p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2);
- p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3);
+ p16_0 = _mm256_madd_epi16(scale_0, p16_0);
+ p16_1 = _mm256_madd_epi16(scale_1, p16_1);
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
- sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3));
}
- acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
+ __m256 vd = _mm256_set1_ps(d);
+ acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
+
}
- *s = hsum_float_8(acc);
+ *s = hsum_float_8(acc) + summs;
#elif defined __AVX__
const __m128i m4 = _mm_set1_epi8(0xF);
- const __m128i m3 = _mm_set1_epi8(3);
- const __m128i m32s = _mm_set1_epi8(32);
+ const __m128i mzero = _mm_setzero_si128();
+ const __m128i mone = _mm_set1_epi8(1);
const __m128i m2 = _mm_set1_epi8(2);
__m256 acc = _mm256_setzero_ps();
+ float summs = 0.f;
+
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
- const uint8_t * restrict q4 = x[i].ql;
- const uint8_t * restrict qh = x[i].qh;
+ const uint8_t * restrict q5 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
- const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
+ memcpy(utmp, x[i].scales, 12);
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+ const uint32_t uaux = utmp[1] & kmask1;
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+ utmp[2] = uaux;
+ utmp[0] &= kmask1;
- __m128i sumi_0 = _mm_setzero_si128();
- __m128i sumi_1 = _mm_setzero_si128();
+ const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]);
+ const __m128i scales = _mm_cvtepu8_epi16(utmps);
+ const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps));
- __m128i shuffle = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
- for (int j = 0; j < QK_K/128; ++j) {
+ const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]);
+ const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]);
+ const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1);
+ const __m128i prod = _mm_madd_epi16(mins, q8s);
+ const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
+ summs += dmin * _mm_extract_epi32(hsum, 0);
- const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
- const __m128i q4bitsH_1 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
+ const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].qh[0]);
+ const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].qh[16]);
+ __m128i hmask = mone;
- const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4);
- const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4);
- const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 2), m3), 4);
- const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 2), m3), 4);
- const __m128i q4h_4 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 4), m3), 4);
- const __m128i q4h_5 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 4), m3), 4);
- const __m128i q4h_6 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 6), m3), 4);
- const __m128i q4h_7 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 6), m3), 4);
+ __m128i sumi_0 = _mm_setzero_si128();
+ __m128i sumi_1 = _mm_setzero_si128();
- const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
- const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
- const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
- const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+ int bit = 0;
- const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m4), q4h_0);
- const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m4), q4h_1);
- const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m4), q4h_2);
- const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m4), q4h_3);
- const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m4), q4h_4);
- const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m4), q4h_5);
- const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m4), q4h_6);
- const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m4), q4h_7);
+ __m128i shuffle = _mm_set1_epi16(0x0100);
+ for (int j = 0; j < QK_K/64; ++j) {
- const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
- const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
- const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
- const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
- const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
- const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
- const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
- const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle);
+ shuffle = _mm_add_epi16(shuffle, m2);
+ const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle);
+ shuffle = _mm_add_epi16(shuffle, m2);
- __m128i q8s_0 = _mm_maddubs_epi16(m32s, q8_0);
- __m128i q8s_1 = _mm_maddubs_epi16(m32s, q8_1);
- __m128i q8s_2 = _mm_maddubs_epi16(m32s, q8_2);
- __m128i q8s_3 = _mm_maddubs_epi16(m32s, q8_3);
- __m128i q8s_4 = _mm_maddubs_epi16(m32s, q8_4);
- __m128i q8s_5 = _mm_maddubs_epi16(m32s, q8_5);
- __m128i q8s_6 = _mm_maddubs_epi16(m32s, q8_6);
- __m128i q8s_7 = _mm_maddubs_epi16(m32s, q8_7);
+ const __m128i q5bits_0 = _mm_loadu_si128((const __m128i*)q5); q5 += 16;
+ const __m128i q5bits_1 = _mm_loadu_si128((const __m128i*)q5); q5 += 16;
- __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0);
- __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1);
- __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2);
- __m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3);
- __m128i p16_4 = _mm_maddubs_epi16(q4_4, q8_4);
- __m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5);
- __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6);
- __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7);
+ __m128i q5l_0 = _mm_and_si128(q5bits_0, m4);
+ __m128i q5l_1 = _mm_and_si128(q5bits_1, m4);
+ __m128i q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4);
+ __m128i q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4);
+ __m128i q5_0 = _mm_add_epi8(q5l_0, q5h_0);
+ __m128i q5_1 = _mm_add_epi8(q5l_1, q5h_1);
+ hmask = _mm_slli_epi16(hmask, 1);
- p16_0 = _mm_sub_epi16(p16_0, q8s_0);
- p16_1 = _mm_sub_epi16(p16_1, q8s_1);
- p16_2 = _mm_sub_epi16(p16_2, q8s_2);
- p16_3 = _mm_sub_epi16(p16_3, q8s_3);
- p16_4 = _mm_sub_epi16(p16_4, q8s_4);
- p16_5 = _mm_sub_epi16(p16_5, q8s_5);
- p16_6 = _mm_sub_epi16(p16_6, q8s_6);
- p16_7 = _mm_sub_epi16(p16_7, q8s_7);
+ __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ __m128i p16_0 = _mm_maddubs_epi16(q5_0, q8_0);
+ __m128i p16_1 = _mm_maddubs_epi16(q5_1, q8_1);
+ p16_0 = _mm_madd_epi16(scale_0, p16_0);
+ p16_1 = _mm_madd_epi16(scale_0, p16_1);
- const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle);
- shuffle = _mm_add_epi8(shuffle, m2);
- const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle);
- shuffle = _mm_add_epi8(shuffle, m2);
- const __m128i scale_2 = _mm_shuffle_epi8(scales, shuffle);
- shuffle = _mm_add_epi8(shuffle, m2);
- const __m128i scale_3 = _mm_shuffle_epi8(scales, shuffle);
- shuffle = _mm_add_epi8(shuffle, m2);
+ q5l_0 = _mm_and_si128(_mm_srli_epi16(q5bits_0, 4), m4);
+ q5l_1 = _mm_and_si128(_mm_srli_epi16(q5bits_1, 4), m4);
+ q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4);
+ q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4);
+ q5_0 = _mm_add_epi8(q5l_0, q5h_0);
+ q5_1 = _mm_add_epi8(q5l_1, q5h_1);
+ hmask = _mm_slli_epi16(hmask, 1);
- p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
- p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1);
- p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
- p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3);
- p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4);
- p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_2, scale_2)), p16_5);
- p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6);
- p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_3, scale_3)), p16_7);
+ q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ __m128i p16_2 = _mm_maddubs_epi16(q5_0, q8_0);
+ __m128i p16_3 = _mm_maddubs_epi16(q5_1, q8_1);
+ p16_2 = _mm_madd_epi16(scale_1, p16_2);
+ p16_3 = _mm_madd_epi16(scale_1, p16_3);
sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
- sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6));
- sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7));
}
+ __m256 vd = _mm256_set1_ps(d);
__m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
- acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
+ acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
+
}
- *s = hsum_float_8(acc);
+ *s = hsum_float_8(acc) + summs;
#elif defined __riscv_v_intrinsic
- float sumf = 0;
- for (int i = 0; i < nb; ++i) {
-
- const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
+ const uint8_t * mins = (const uint8_t*)&utmp[2];
- const uint8_t * restrict q6 = x[i].ql;
- const uint8_t * restrict qh = x[i].qh;
- const int8_t * restrict q8 = y[i].qs;
+ float sumf = 0;
+ float sums = 0.0;
- const int8_t * restrict scale = x[i].scales;
+ size_t vl;
- size_t vl;
+ for (int i = 0; i < nb; ++i) {
- vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+ vl = 8;
- int sum_t = 0;
- int is = 0;
+ const uint8_t * restrict q5 = x[i].qs;
+ const uint8_t * restrict hm = x[i].qh;
+ const int8_t * restrict q8 = y[i].qs;
- for (int j = 0; j < QK_K/128; ++j) {
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
- vl = 32;
+ vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);
+ vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl);
+ vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl);
- // load qh
- vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl);
+ memcpy(utmp, x[i].scales, 12);
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+ const uint32_t uaux = utmp[1] & kmask1;
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+ utmp[2] = uaux;
+ utmp[0] &= kmask1;
- // load Q6
- vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl);
- vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl);
+ vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl);
+ vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl));
+ vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl);
- vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl);
- vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl);
- vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl);
- vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl);
+ vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
+ sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);
- vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl);
- vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl);
- vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl);
- vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl);
+ vl = 32;
+ int32_t aux32 = 0;
+ int is = 0;
- vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl);
- vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl);
- vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl);
- vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl);
+ uint8_t m = 1;
+ vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+ vuint8m1_t vqh = __riscv_vle8_v_u8m1(hm, vl);
- vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl);
- vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl);
- vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl);
- vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl);
+ for (int j = 0; j < QK_K/64; ++j) {
+ // load Q5 and Q8
+ vuint8m1_t q5_x = __riscv_vle8_v_u8m1(q5, vl);
+ vint8m1_t q8_y1 = __riscv_vle8_v_i8m1(q8, vl);
+ vint8m1_t q8_y2 = __riscv_vle8_v_i8m1(q8+32, vl);
- // load Q8 and take product
- vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl);
- vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
- vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
- vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
+ // compute mask for addition
+ vint8m1_t q5_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q5_x, 0x0F, vl));
+ vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
+ vbool8_t vmask_1 = __riscv_vmsne_vx_u8m1_b8(qh_m1, 0, vl);
+ vint8m1_t q5_m1 = __riscv_vadd_vx_i8m1_m(vmask_1, q5_a, 16, vl);
+ m <<= 1;
- vl = 16;
+ vint8m1_t q5_l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q5_x, 0x04, vl));
+ vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
+ vbool8_t vmask_2 = __riscv_vmsne_vx_u8m1_b8(qh_m2, 0, vl);
+ vint8m1_t q5_m2 = __riscv_vadd_vx_i8m1_m(vmask_2, q5_l, 16, vl);
+ m <<= 1;
- vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl);
- vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl);
- vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl);
- vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl);
- vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl);
- vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl);
- vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl);
- vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl);
+ vint16m2_t v0 = __riscv_vwmul_vv_i16m2(q5_m1, q8_y1, vl);
+ vint16m2_t v1 = __riscv_vwmul_vv_i16m2(q5_m2, q8_y2, vl);
- vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl);
- vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl);
- vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl);
- vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl);
+ vint32m4_t vs1 = __riscv_vwmul_vx_i32m4(v0, scales[is++], vl);
+ vint32m4_t vs2 = __riscv_vwmul_vx_i32m4(v1, scales[is++], vl);
- sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);
+ vint32m1_t vacc1 = __riscv_vredsum_vs_i32m4_i32m1(vs1, vzero, vl);
+ vint32m1_t vacc2 = __riscv_vredsum_vs_i32m4_i32m1(vs2, vzero, vl);
- q6 += 64; qh += 32; q8 += 128; is=8;
+ aux32 += __riscv_vmv_x_s_i32m1_i32(vacc1) + __riscv_vmv_x_s_i32m1_i32(vacc2);
+ q5 += 32; q8 += 64;
}
- sumf += d * sum_t;
+ vfloat32m1_t vaux = __riscv_vfmul_vf_f32m1(__riscv_vfmv_v_f_f32m1(aux32, 1), d, 1);
+ sums += __riscv_vfmv_f_s_f32m1_f32(vaux);
}
- *s = sumf;
+ *s = sumf+sums;
#elif defined(__POWER9_VECTOR__)
const vector signed char lowMask = vec_splats((signed char)0xF);
+ const vector unsigned char v1 = vec_splats((unsigned char)0x1);
const vector unsigned char v2 = vec_splats((unsigned char)0x2);
const vector unsigned char v3 = vec_splats((unsigned char)0x3);
const vector unsigned char v4 = vec_splats((unsigned char)0x4);
- const vector unsigned char v6 = vec_splats((unsigned char)0x6);
- const vector signed char off = vec_splats((signed char)0x20);
vector float vsumf0 = vec_splats(0.0f);
vector float vsumf1 = vec_splats(0.0f);
vector float vyd = vec_splats(y[i].d);
vector float vd = vec_mul(vxd, vyd);
+ vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[i].dmin));
+ vector float vdmin = vec_mul(vxmin, vyd);
+
+ memcpy(utmp, x[i].scales, 12);
+
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+ const uint32_t uaux = utmp[1] & kmask1;
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+ utmp[2] = uaux;
+ utmp[0] &= kmask1;
+
+ vector signed short q8ysums0 = vec_xl( 0, y[i].bsums);
+ vector signed short q8ysums1 = vec_xl(16, y[i].bsums);
+
+ vector signed char utmps = (vector signed char)vec_xl( 0, utmp);
+ vector signed short vscales = vec_unpackh(utmps);
+
+ vector signed short q5xmins = vec_unpackl(utmps);
+ vector signed short q5xmins0 = vec_mergeh(q5xmins, q5xmins);
+ vector signed short q5xmins1 = vec_mergel(q5xmins, q5xmins);
+
+ vector signed int prod0 = vec_mule(q5xmins0, q8ysums0);
+ vector signed int prod1 = vec_mule(q5xmins1, q8ysums1);
+ vector signed int prod2 = vec_mulo(q5xmins0, q8ysums0);
+ vector signed int prod3 = vec_mulo(q5xmins1, q8ysums1);
+
+ vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0);
+ vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1);
+ vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2);
+ vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3);
+
+ vector signed char qxhs0 = (vector signed char)vec_xl( 0, x[i].qh);
+ vector signed char qxhs1 = (vector signed char)vec_xl(16, x[i].qh);
+
vector signed int vsumi0 = vec_splats((int32_t)0);
vector signed int vsumi1 = vec_splats((int32_t)0);
vector signed int vsumi2 = vec_splats((int32_t)0);
vector signed int vsumi3 = vec_splats((int32_t)0);
- vector signed int vsumi4 = vec_splats((int32_t)0);
- vector signed int vsumi5 = vec_splats((int32_t)0);
- vector signed int vsumi6 = vec_splats((int32_t)0);
- vector signed int vsumi7 = vec_splats((int32_t)0);
- const uint8_t * restrict q6 = x[i].ql;
- const uint8_t * restrict qh = x[i].qh;
- const int8_t * restrict qs = x[i].scales;
+ const uint8_t * restrict q5 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
- for (int j = 0; j < QK_K/128; ++j) {
- __builtin_prefetch(q6, 0, 0);
- __builtin_prefetch(qh, 0, 0);
- __builtin_prefetch(q8, 0, 0);
+ for (int j = 0; j < QK_K/64; ++j) {
+ __builtin_prefetch(q5, 0, 1);
+ __builtin_prefetch(q8, 0, 1);
- vector signed char qxs0 = (vector signed char)vec_xl( 0, q6);
- vector signed char qxs1 = (vector signed char)vec_xl(16, q6);
- vector signed char qxs2 = (vector signed char)vec_xl(32, q6);
- vector signed char qxs3 = (vector signed char)vec_xl(48, q6);
- q6 += 64;
+ vector signed char qxs0 = (vector signed char)vec_xl( 0, q5);
+ vector signed char qxs1 = (vector signed char)vec_xl(16, q5);
+ q5 += 32;
vector signed char qxs00 = vec_and(qxs0, lowMask);
vector signed char qxs01 = vec_sr(qxs0, v4);
vector signed char qxs10 = vec_and(qxs1, lowMask);
vector signed char qxs11 = vec_sr(qxs1, v4);
- vector signed char qxs20 = vec_and(qxs2, lowMask);
- vector signed char qxs21 = vec_sr(qxs2, v4);
- vector signed char qxs30 = vec_and(qxs3, lowMask);
- vector signed char qxs31 = vec_sr(qxs3, v4);
-
- vector signed char qxhs0 = (vector signed char)vec_xl( 0, qh);
- vector signed char qxhs1 = (vector signed char)vec_xl(16, qh);
- qh += 32;
-
- vector signed char qxh00 = vec_sl(vec_and((vector signed char)v3, qxhs0), v4);
- vector signed char qxh01 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v4)), v4);
- vector signed char qxh10 = vec_sl(vec_and((vector signed char)v3, qxhs1), v4);
- vector signed char qxh11 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v4)), v4);
- vector signed char qxh20 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v2)), v4);
- vector signed char qxh21 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v6)), v4);
- vector signed char qxh30 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v2)), v4);
- vector signed char qxh31 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v6)), v4);
- vector signed char q6x00 = vec_sub(vec_or(qxh00, qxs00), off);
- vector signed char q6x01 = vec_sub(vec_or(qxh01, qxs01), off);
- vector signed char q6x10 = vec_sub(vec_or(qxh10, qxs10), off);
- vector signed char q6x11 = vec_sub(vec_or(qxh11, qxs11), off);
- vector signed char q6x20 = vec_sub(vec_or(qxh20, qxs20), off);
- vector signed char q6x21 = vec_sub(vec_or(qxh21, qxs21), off);
- vector signed char q6x30 = vec_sub(vec_or(qxh30, qxs30), off);
- vector signed char q6x31 = vec_sub(vec_or(qxh31, qxs31), off);
+ vector signed char q5h00 = vec_sl(vec_and((vector signed char)v1, qxhs0), v4);
+ vector signed char q5h01 = vec_sl(vec_and((vector signed char)v2, qxhs0), v3);
+ vector signed char q5h10 = vec_sl(vec_and((vector signed char)v1, qxhs1), v4);
+ vector signed char q5h11 = vec_sl(vec_and((vector signed char)v2, qxhs1), v3);
+ qxhs0 = vec_sr(qxhs0, v2);
+ qxhs1 = vec_sr(qxhs1, v2);
- vector signed char q8y00 = vec_xl( 0, q8);
- vector signed char q8y10 = vec_xl( 16, q8);
- vector signed char q8y20 = vec_xl( 32, q8);
- vector signed char q8y30 = vec_xl( 48, q8);
- vector signed char q8y01 = vec_xl( 64, q8);
- vector signed char q8y11 = vec_xl( 80, q8);
- vector signed char q8y21 = vec_xl( 96, q8);
- vector signed char q8y31 = vec_xl(112, q8);
- q8 += 128;
+ vector signed char q5x00 = vec_or(q5h00, qxs00);
+ vector signed char q5x01 = vec_or(q5h01, qxs01);
+ vector signed char q5x10 = vec_or(q5h10, qxs10);
+ vector signed char q5x11 = vec_or(q5h11, qxs11);
- vector signed short qv00 = vec_add(vec_mule(q6x00, q8y00), vec_mulo(q6x00, q8y00));
- vector signed short qv10 = vec_add(vec_mule(q6x10, q8y10), vec_mulo(q6x10, q8y10));
- vector signed short qv20 = vec_add(vec_mule(q6x20, q8y20), vec_mulo(q6x20, q8y20));
- vector signed short qv30 = vec_add(vec_mule(q6x30, q8y30), vec_mulo(q6x30, q8y30));
- vector signed short qv01 = vec_add(vec_mule(q6x01, q8y01), vec_mulo(q6x01, q8y01));
- vector signed short qv11 = vec_add(vec_mule(q6x11, q8y11), vec_mulo(q6x11, q8y11));
- vector signed short qv21 = vec_add(vec_mule(q6x21, q8y21), vec_mulo(q6x21, q8y21));
- vector signed short qv31 = vec_add(vec_mule(q6x31, q8y31), vec_mulo(q6x31, q8y31));
+ vector signed char q8y00 = vec_xl( 0, q8);
+ vector signed char q8y10 = vec_xl(16, q8);
+ vector signed char q8y01 = vec_xl(32, q8);
+ vector signed char q8y11 = vec_xl(48, q8);
+ q8 += 64;
- vector signed short vscales = vec_unpackh(vec_xl_len(qs, 8));
- qs += 8;
+ vector signed short qv00 = vec_add(vec_mule(q5x00, q8y00), vec_mulo(q5x00, q8y00));
+ vector signed short qv01 = vec_add(vec_mule(q5x01, q8y01), vec_mulo(q5x01, q8y01));
+ vector signed short qv10 = vec_add(vec_mule(q5x10, q8y10), vec_mulo(q5x10, q8y10));
+ vector signed short qv11 = vec_add(vec_mule(q5x11, q8y11), vec_mulo(q5x11, q8y11));
vector signed short vs0 = vec_splat(vscales, 0);
vector signed short vs1 = vec_splat(vscales, 1);
- vector signed short vs2 = vec_splat(vscales, 2);
- vector signed short vs3 = vec_splat(vscales, 3);
- vector signed short vs4 = vec_splat(vscales, 4);
- vector signed short vs5 = vec_splat(vscales, 5);
- vector signed short vs6 = vec_splat(vscales, 6);
- vector signed short vs7 = vec_splat(vscales, 7);
+ vscales = vec_sld(vscales, vscales, 12);
+
+ qv00 = vec_add(qv00, qv10);
+ qv01 = vec_add(qv01, qv11);
vsumi0 = vec_add(vec_mule(qv00, vs0), vsumi0);
vsumi1 = vec_add(vec_mulo(qv00, vs0), vsumi1);
- vsumi2 = vec_add(vec_mule(qv01, vs4), vsumi2);
- vsumi3 = vec_add(vec_mulo(qv01, vs4), vsumi3);
- vsumi4 = vec_add(vec_mule(qv10, vs1), vsumi4);
- vsumi5 = vec_add(vec_mulo(qv10, vs1), vsumi5);
- vsumi6 = vec_add(vec_mule(qv11, vs5), vsumi6);
- vsumi7 = vec_add(vec_mulo(qv11, vs5), vsumi7);
-
- vsumi0 = vec_add(vec_mule(qv20, vs2), vsumi0);
- vsumi1 = vec_add(vec_mulo(qv20, vs2), vsumi1);
- vsumi2 = vec_add(vec_mule(qv21, vs6), vsumi2);
- vsumi3 = vec_add(vec_mulo(qv21, vs6), vsumi3);
- vsumi4 = vec_add(vec_mule(qv30, vs3), vsumi4);
- vsumi5 = vec_add(vec_mulo(qv30, vs3), vsumi5);
- vsumi6 = vec_add(vec_mule(qv31, vs7), vsumi6);
- vsumi7 = vec_add(vec_mulo(qv31, vs7), vsumi7);
+ vsumi2 = vec_add(vec_mule(qv01, vs1), vsumi2);
+ vsumi3 = vec_add(vec_mulo(qv01, vs1), vsumi3);
}
- vsumi0 = vec_add(vsumi0, vsumi4);
- vsumi1 = vec_add(vsumi1, vsumi5);
- vsumi2 = vec_add(vsumi2, vsumi6);
- vsumi3 = vec_add(vsumi3, vsumi7);
-
vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
#elif defined __loongarch_asx
const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
- const __m256i m2 = __lasx_xvreplgr2vr_b(3);
- const __m256i m32s = __lasx_xvreplgr2vr_b(32);
+ const __m128i mzero = __lsx_vldi(0);
+ const __m256i mone = __lasx_xvreplgr2vr_b(1);
__m256 acc = (__m256)__lasx_xvldi(0);
- for (int i = 0; i < nb; ++i) {
+ float summs = 0.f;
- const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ for (int i = 0; i < nb; ++i) {
- const uint8_t * restrict q4 = x[i].ql;
- const uint8_t * restrict qh = x[i].qh;
+ const uint8_t * restrict q5 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
- const __m128i scales = __lsx_vld((const __m128i*)x[i].scales, 0);
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+ memcpy(utmp, x[i].scales, 12);
+
+ const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]));
+
+ const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0);
+ const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1));
+ const __m128i prod = lsx_madd_h(lasx_extracti128(mins_and_scales, 1), q8s);
+ const __m128i hsum = lsx_hadd_w(lsx_hadd_w(prod, mzero), mzero);
+ summs += dmin * __lsx_vpickve2gr_w(hsum, 0); //TODO check
+
+ const __m128i sc128 = lasx_extracti128(mins_and_scales, 0);
+ const __m256i scales = lasx_insertf128(sc128, sc128);
+
+ const __m256i hbits = __lasx_xvld((const __m256i*)x[i].qh, 0);
+ __m256i hmask = mone;
__m256i sumi = __lasx_xvldi(0);
- int is = 0;
+ int bit = 0;
- for (int j = 0; j < QK_K/128; ++j) {
+ for (int j = 0; j < QK_K/64; ++j) {
- const __m128i scale_0 = lsx_shuffle_b(scales, get_scale_shuffle(is + 0));
- const __m128i scale_1 = lsx_shuffle_b(scales, get_scale_shuffle(is + 1));
- const __m128i scale_2 = lsx_shuffle_b(scales, get_scale_shuffle(is + 2));
- const __m128i scale_3 = lsx_shuffle_b(scales, get_scale_shuffle(is + 3));
- is += 4;
+ const __m256i scale_0 = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+0));
+ const __m256i scale_1 = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+1));
- const __m256i q4bits1 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
- const __m256i q4bits2 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
- const __m256i q4bitsH = __lasx_xvld((const __m256i*)qh, 0); qh += 32;
+ const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0); q5 += 32;
- const __m256i q4h_0 = __lasx_xvslli_h(__lasx_xvand_v(q4bitsH, m2), 4);
- const __m256i q4h_1 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 2), m2), 4);
- const __m256i q4h_2 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 4), m2), 4);
- const __m256i q4h_3 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 6), m2), 4);
+ const __m256i q5l_0 = __lasx_xvand_v(q5bits, m4);
+ const __m256i q5h_0 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvand_v(hbits, hmask), bit++), 4);
+ const __m256i q5_0 = __lasx_xvadd_b(q5l_0, q5h_0);
+ hmask = __lasx_xvslli_h(hmask, 1);
- const __m256i q4_0 = __lasx_xvor_v(__lasx_xvand_v(q4bits1, m4), q4h_0);
- const __m256i q4_1 = __lasx_xvor_v(__lasx_xvand_v(q4bits2, m4), q4h_1);
- const __m256i q4_2 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits1, 4), m4), q4h_2);
- const __m256i q4_3 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits2, 4), m4), q4h_3);
+ const __m256i q5l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q5bits, 4), m4);
+ const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvand_v(hbits, hmask), bit++), 4);
+ const __m256i q5_1 = __lasx_xvadd_b(q5l_1, q5h_1);
+ hmask = __lasx_xvslli_h(hmask, 1);
const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
- const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
- const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
-
- __m256i q8s_0 = lasx_maddubs_h(m32s, q8_0);
- __m256i q8s_1 = lasx_maddubs_h(m32s, q8_1);
- __m256i q8s_2 = lasx_maddubs_h(m32s, q8_2);
- __m256i q8s_3 = lasx_maddubs_h(m32s, q8_3);
-
- __m256i p16_0 = lasx_maddubs_h(q4_0, q8_0);
- __m256i p16_1 = lasx_maddubs_h(q4_1, q8_1);
- __m256i p16_2 = lasx_maddubs_h(q4_2, q8_2);
- __m256i p16_3 = lasx_maddubs_h(q4_3, q8_3);
- p16_0 = __lasx_xvsub_h(p16_0, q8s_0);
- p16_1 = __lasx_xvsub_h(p16_1, q8s_1);
- p16_2 = __lasx_xvsub_h(p16_2, q8s_2);
- p16_3 = __lasx_xvsub_h(p16_3, q8s_3);
+ __m256i p16_0 = lasx_maddubs_h(q5_0, q8_0);
+ __m256i p16_1 = lasx_maddubs_h(q5_1, q8_1);
- p16_0 = lasx_madd_h(lasx_ext8_16(scale_0), p16_0);
- p16_1 = lasx_madd_h(lasx_ext8_16(scale_1), p16_1);
- p16_2 = lasx_madd_h(lasx_ext8_16(scale_2), p16_2);
- p16_3 = lasx_madd_h(lasx_ext8_16(scale_3), p16_3);
+ p16_0 = lasx_madd_h(scale_0, p16_0);
+ p16_1 = lasx_madd_h(scale_1, p16_1);
sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1));
- sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_2, p16_3));
}
- acc = __lasx_xvfmadd_s((__m256)__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);
+ __m256 vd = __lasx_xvreplfr2vr_s(d);
+ acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc);
}
- *s = hsum_float_8(acc);
+ *s = hsum_float_8(acc) + summs;
#else
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
+ const uint8_t * mins = (const uint8_t*)&utmp[2];
+
int8_t aux8[QK_K];
int16_t aux16[8];
float sums [8];
float sumf = 0;
for (int i = 0; i < nb; ++i) {
- const uint8_t * restrict q4 = x[i].ql;
- const uint8_t * restrict qh = x[i].qh;
+ const uint8_t * restrict q4 = x[i].qs;
+ const uint8_t * restrict hm = x[i].qh;
const int8_t * restrict q8 = y[i].qs;
memset(aux32, 0, 8*sizeof(int32_t));
int8_t * restrict a = aux8;
- for (int j = 0; j < QK_K; j += 128) {
- for (int l = 0; l < 32; ++l) {
- a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
- a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
- a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
- a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
- }
- a += 128;
- q4 += 64;
- qh += 32;
+ uint8_t m = 1;
+ for (int j = 0; j < QK_K/64; ++j) {
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
+ for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
+ a += 32; m <<= 1;
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4);
+ for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
+ a += 32; m <<= 1;
+ q4 += 32;
}
+ memcpy(utmp, x[i].scales, 12);
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+ const uint32_t uaux = utmp[1] & kmask1;
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+ utmp[2] = uaux;
+ utmp[0] &= kmask1;
+
+ int sumi = 0;
+ for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
a = aux8;
int is = 0;
- for (int j = 0; j < QK_K/16; ++j) {
- int scale = x[i].scales[is++];
+ for (int j = 0; j < QK_K/32; ++j) {
+ int32_t scale = scales[is++];
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+ q8 += 8; a += 8;
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+ q8 += 8; a += 8;
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
q8 += 8; a += 8;
}
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
+ const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
+ sumf -= dmin * sumi;
}
for (int l = 0; l < 8; ++l) sumf += sums[l];
*s = sumf;
#endif
}
-#else
-
void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(n % QK_K == 0);
assert(nrc == 1);
float sum = 0;
const uint8x16_t m4b = vdupq_n_u8(0xF);
- const int8x16_t m32s = vdupq_n_s8(32);
const int32x4_t vzero = vdupq_n_s32(0);
+ //const int8x16_t m32s = vdupq_n_s8(32);
const uint8x16_t mone = vdupq_n_u8(3);
const int8_t * restrict scale = x[i].scales;
+ const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
+ const int8x16_t scales = vld1q_s8(scale);
+ const ggml_int16x8x2_t q6scales = {{vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}};
+
+ const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])),
+ vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))),
+ vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[1]), vget_low_s16 (q6scales.val[1])),
+ vmull_s16(vget_high_s16(q8sums.val[1]), vget_high_s16(q6scales.val[1]))));
+ int32_t isum_mins = vaddvq_s32(prod);
+
int32_t isum = 0;
- uint8x16_t qhbits = vld1q_u8(qh);
- ggml_uint8x16x2_t q6bits = ggml_vld1q_u8_x2(q6);
- ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8);
+ for (int j = 0; j < QK_K/128; ++j) {
+
+ ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); qh += 32;
+ ggml_uint8x16x4_t q6bits = ggml_vld1q_u8_x4(q6); q6 += 64;
+ ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
+
+ q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
+ q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
+ uint8x16_t shifted = vshrq_n_u8(qhbits.val[0], 2);
+ q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+ shifted = vshrq_n_u8(qhbits.val[1], 2);
+ q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+
+ //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s);
+ //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s);
+ //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])), m32s);
+ //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])), m32s);
+ q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0]));
+ q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1]));
+ q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2]));
+ q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3]));
+
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
+ vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
+ vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
+ vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
+
+ scale += 4;
- q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits), 4);
- uint8x16_t shifted = vshrq_n_u8(qhbits, 2);
- q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
- shifted = vshrq_n_u8(qhbits, 4);
- q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
- shifted = vshrq_n_u8(qhbits, 6);
- q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+ q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
- q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s);
- q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s);
- q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[2])), m32s);
- q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[3])), m32s);
+ shifted = vshrq_n_u8(qhbits.val[0], 4);
+ q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+ shifted = vshrq_n_u8(qhbits.val[1], 4);
+ q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+ shifted = vshrq_n_u8(qhbits.val[0], 6);
+ q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+ shifted = vshrq_n_u8(qhbits.val[1], 6);
+ q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
- isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
- vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
- vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
- vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
+ //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])), m32s);
+ //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])), m32s);
+ //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])), m32s);
+ //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])), m32s);
+ q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0]));
+ q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1]));
+ q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2]));
+ q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3]));
- sum += isum * d_all * y[i].d;
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
+ vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
+ vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
+ vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
+ scale += 4;
+ }
+ //sum += isum * d_all * y[i].d;
+ sum += d_all * y[i].d * (isum - 32 * isum_mins);
}
*s = sum;
const uint8_t * restrict qh = x[i].qh;
const int8_t * restrict q8 = y[i].qs;
- const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]);
- const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]);
- const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]);
- const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]);
+ const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
__m256i sumi = _mm256_setzero_si256();
- const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1);
- const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3);
+ int is = 0;
+
+ for (int j = 0; j < QK_K/128; ++j) {
+
+ const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0));
+ const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
+ const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
+ const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));
+ is += 4;
+
+ const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
+ const __m256i q4bits2 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
+ const __m256i q4bitsH = _mm256_loadu_si256((const __m256i*)qh); qh += 32;
- const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4);
- const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh);
+ const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4);
+ const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4);
+ const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4);
+ const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4);
- const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q4bitsH, 2), q4bitsH), m2), 4);
- const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q4bitsH, 6), _mm_srli_epi16(q4bitsH, 4)), m2), 4);
+ const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
+ const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1);
+ const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2);
+ const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3);
- const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
- const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_1);
+ const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+ const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+ const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
- const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
- const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+ __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0);
+ __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1);
+ __m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2);
+ __m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3);
- __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0);
- __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1);
+ __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0);
+ __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1);
+ __m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2);
+ __m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3);
- __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0);
- __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1);
+ p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
+ p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
+ p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
+ p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
- p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
- p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
+ p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0);
+ p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1);
+ p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2);
+ p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3);
- p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0);
- p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1);
+ sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
+ sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3));
- sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
+ }
acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
}
#elif defined __AVX__
const __m128i m4 = _mm_set1_epi8(0xF);
- const __m128i m2 = _mm_set1_epi8(3);
+ const __m128i m3 = _mm_set1_epi8(3);
const __m128i m32s = _mm_set1_epi8(32);
+ const __m128i m2 = _mm_set1_epi8(2);
__m256 acc = _mm256_setzero_ps();
const uint8_t * restrict qh = x[i].qh;
const int8_t * restrict q8 = y[i].qs;
- const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]);
- const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]);
- const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]);
- const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]);
+ const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
__m128i sumi_0 = _mm_setzero_si128();
__m128i sumi_1 = _mm_setzero_si128();
- const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1);
- const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3);
+ __m128i shuffle = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
+ for (int j = 0; j < QK_K/128; ++j) {
+
+ const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
+ const __m128i q4bitsH_1 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
+
+ const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4);
+ const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4);
+ const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 2), m3), 4);
+ const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 2), m3), 4);
+ const __m128i q4h_4 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 4), m3), 4);
+ const __m128i q4h_5 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 4), m3), 4);
+ const __m128i q4h_6 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 6), m3), 4);
+ const __m128i q4h_7 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 6), m3), 4);
+
+ const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+ const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+ const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+ const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
- const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4);
- const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh);
+ const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m4), q4h_0);
+ const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m4), q4h_1);
+ const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m4), q4h_2);
+ const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m4), q4h_3);
+ const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m4), q4h_4);
+ const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m4), q4h_5);
+ const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m4), q4h_6);
+ const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m4), q4h_7);
- const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH, m2), 4);
- const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 2), m2), 4);
- const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 4), m2), 4);
- const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 6), m2), 4);
+ const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
- const __m128i q4_0 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 0), m4), q4h_0);
- const __m128i q4_1 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 1), m4), q4h_1);
- const __m128i q4_2 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 0), 4), m4), q4h_2);
- const __m128i q4_3 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 1), 4), m4), q4h_3);
+ __m128i q8s_0 = _mm_maddubs_epi16(m32s, q8_0);
+ __m128i q8s_1 = _mm_maddubs_epi16(m32s, q8_1);
+ __m128i q8s_2 = _mm_maddubs_epi16(m32s, q8_2);
+ __m128i q8s_3 = _mm_maddubs_epi16(m32s, q8_3);
+ __m128i q8s_4 = _mm_maddubs_epi16(m32s, q8_4);
+ __m128i q8s_5 = _mm_maddubs_epi16(m32s, q8_5);
+ __m128i q8s_6 = _mm_maddubs_epi16(m32s, q8_6);
+ __m128i q8s_7 = _mm_maddubs_epi16(m32s, q8_7);
- const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
- const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+ __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0);
+ __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1);
+ __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2);
+ __m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3);
+ __m128i p16_4 = _mm_maddubs_epi16(q4_4, q8_4);
+ __m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5);
+ __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6);
+ __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7);
- __m128i q8s_0 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 0));
- __m128i q8s_1 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 1));
- __m128i q8s_2 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 0));
- __m128i q8s_3 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 1));
+ p16_0 = _mm_sub_epi16(p16_0, q8s_0);
+ p16_1 = _mm_sub_epi16(p16_1, q8s_1);
+ p16_2 = _mm_sub_epi16(p16_2, q8s_2);
+ p16_3 = _mm_sub_epi16(p16_3, q8s_3);
+ p16_4 = _mm_sub_epi16(p16_4, q8s_4);
+ p16_5 = _mm_sub_epi16(p16_5, q8s_5);
+ p16_6 = _mm_sub_epi16(p16_6, q8s_6);
+ p16_7 = _mm_sub_epi16(p16_7, q8s_7);
- __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0));
- __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1));
- __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0));
- __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1));
+ const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle);
+ shuffle = _mm_add_epi8(shuffle, m2);
+ const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle);
+ shuffle = _mm_add_epi8(shuffle, m2);
+ const __m128i scale_2 = _mm_shuffle_epi8(scales, shuffle);
+ shuffle = _mm_add_epi8(shuffle, m2);
+ const __m128i scale_3 = _mm_shuffle_epi8(scales, shuffle);
+ shuffle = _mm_add_epi8(shuffle, m2);
- p16_0 = _mm_sub_epi16(p16_0, q8s_0);
- p16_1 = _mm_sub_epi16(p16_1, q8s_1);
- p16_2 = _mm_sub_epi16(p16_2, q8s_2);
- p16_3 = _mm_sub_epi16(p16_3, q8s_3);
+ p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
+ p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1);
+ p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
+ p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3);
+ p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4);
+ p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_2, scale_2)), p16_5);
+ p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6);
+ p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_3, scale_3)), p16_7);
- p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
- p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1);
- p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
- p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3);
+ sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
+ sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
+ sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6));
+ sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7));
- sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
- sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
+ }
- acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(MM256_SET_M128I(sumi_1, sumi_0))), acc);
+ __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
+ acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
}
*s = hsum_float_8(acc);
#elif defined __riscv_v_intrinsic
float sumf = 0;
-
for (int i = 0; i < nb; ++i) {
- const float d_all = GGML_FP16_TO_FP32(x[i].d);
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
const uint8_t * restrict q6 = x[i].ql;
const uint8_t * restrict qh = x[i].qh;
- const int8_t * restrict q8 = y[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
const int8_t * restrict scale = x[i].scales;
- int32_t isum = 0;
-
- size_t vl = 16;
+ size_t vl;
vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
- // load Q6
- vuint8mf2_t q6_0 = __riscv_vle8_v_u8mf2(q6, vl);
- vuint8mf2_t q6_1 = __riscv_vle8_v_u8mf2(q6+16, vl);
+ int sum_t = 0;
+ int is = 0;
- // load qh
- vuint8mf2_t qh_x = __riscv_vle8_v_u8mf2(qh, vl);
+ for (int j = 0; j < QK_K/128; ++j) {
+
+ vl = 32;
+
+ // load qh
+ vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl);
+
+ // load Q6
+ vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl);
+ vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl);
+
+ vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl);
+ vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl);
+ vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl);
+ vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl);
+
+ vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl);
+ vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl);
+ vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl);
+ vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl);
+
+ vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl);
+ vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl);
+ vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl);
+ vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl);
+
+ vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl);
+ vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl);
+ vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl);
+ vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl);
+
+ // load Q8 and take product
+ vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl);
+ vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
+ vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
+ vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
- vuint8mf2_t qh0 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl);
- qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl);
- vuint8mf2_t qh1 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl);
- qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl);
- vuint8mf2_t qh2 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl);
- qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl);
- vuint8mf2_t qh3 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl);
+ vl = 16;
+
+ vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl);
+ vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl);
+ vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl);
+ vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl);
+ vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl);
+ vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl);
+ vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl);
+ vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl);
+
+ vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl);
+ vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl);
+ vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl);
+ vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl);
+
+ sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);
+
+ q6 += 64; qh += 32; q8 += 128; is=8;
+
+ }
+
+ sumf += d * sum_t;
+
+ }
+
+ *s = sumf;
+
+#elif defined(__POWER9_VECTOR__)
+ const vector signed char lowMask = vec_splats((signed char)0xF);
+ const vector unsigned char v2 = vec_splats((unsigned char)0x2);
+ const vector unsigned char v3 = vec_splats((unsigned char)0x3);
+ const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+ const vector unsigned char v6 = vec_splats((unsigned char)0x6);
+ const vector signed char off = vec_splats((signed char)0x20);
+
+ vector float vsumf0 = vec_splats(0.0f);
+ vector float vsumf1 = vec_splats(0.0f);
+ vector float vsumf2 = vec_splats(0.0f);
+ vector float vsumf3 = vec_splats(0.0f);
- vuint8mf2_t q6h_0 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q6_0, 0xF, vl), qh0, vl);
- vuint8mf2_t q6h_1 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q6_1, 0xF, vl), qh1, vl);
- vuint8mf2_t q6h_2 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q6_0, 0x4, vl), qh2, vl);
- vuint8mf2_t q6h_3 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q6_1, 0x4, vl), qh3, vl);
+ for (int i = 0; i < nb; ++i) {
+ vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
+ vector float vyd = vec_splats(y[i].d);
+ vector float vd = vec_mul(vxd, vyd);
+
+ vector signed int vsumi0 = vec_splats((int32_t)0);
+ vector signed int vsumi1 = vec_splats((int32_t)0);
+ vector signed int vsumi2 = vec_splats((int32_t)0);
+ vector signed int vsumi3 = vec_splats((int32_t)0);
+ vector signed int vsumi4 = vec_splats((int32_t)0);
+ vector signed int vsumi5 = vec_splats((int32_t)0);
+ vector signed int vsumi6 = vec_splats((int32_t)0);
+ vector signed int vsumi7 = vec_splats((int32_t)0);
- vint8mf2_t q6v_0 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_0), 32, vl);
- vint8mf2_t q6v_1 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_1), 32, vl);
- vint8mf2_t q6v_2 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_2), 32, vl);
- vint8mf2_t q6v_3 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_3), 32, vl);
+ const uint8_t * restrict q6 = x[i].ql;
+ const uint8_t * restrict qh = x[i].qh;
+ const int8_t * restrict qs = x[i].scales;
+ const int8_t * restrict q8 = y[i].qs;
- // load Q8 and take product
- vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q6v_0, __riscv_vle8_v_i8mf2(q8, vl), vl);
- vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q6v_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl);
- vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q6v_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl);
- vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q6v_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl);
+ for (int j = 0; j < QK_K/128; ++j) {
+ __builtin_prefetch(q6, 0, 0);
+ __builtin_prefetch(qh, 0, 0);
+ __builtin_prefetch(q8, 0, 0);
- vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl);
- vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl);
- vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl);
- vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl);
+ vector signed char qxs0 = (vector signed char)vec_xl( 0, q6);
+ vector signed char qxs1 = (vector signed char)vec_xl(16, q6);
+ vector signed char qxs2 = (vector signed char)vec_xl(32, q6);
+ vector signed char qxs3 = (vector signed char)vec_xl(48, q6);
+ q6 += 64;
- isum += __riscv_vmv_x_s_i32m1_i32(vs_0) * scale[0];
- isum += __riscv_vmv_x_s_i32m1_i32(vs_1) * scale[1];
- isum += __riscv_vmv_x_s_i32m1_i32(vs_2) * scale[2];
- isum += __riscv_vmv_x_s_i32m1_i32(vs_3) * scale[3];
+ vector signed char qxs00 = vec_and(qxs0, lowMask);
+ vector signed char qxs01 = vec_sr(qxs0, v4);
+ vector signed char qxs10 = vec_and(qxs1, lowMask);
+ vector signed char qxs11 = vec_sr(qxs1, v4);
+ vector signed char qxs20 = vec_and(qxs2, lowMask);
+ vector signed char qxs21 = vec_sr(qxs2, v4);
+ vector signed char qxs30 = vec_and(qxs3, lowMask);
+ vector signed char qxs31 = vec_sr(qxs3, v4);
- sumf += isum * d_all * y[i].d;
+ vector signed char qxhs0 = (vector signed char)vec_xl( 0, qh);
+ vector signed char qxhs1 = (vector signed char)vec_xl(16, qh);
+ qh += 32;
- }
+ vector signed char qxh00 = vec_sl(vec_and((vector signed char)v3, qxhs0), v4);
+ vector signed char qxh01 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v4)), v4);
+ vector signed char qxh10 = vec_sl(vec_and((vector signed char)v3, qxhs1), v4);
+ vector signed char qxh11 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v4)), v4);
+ vector signed char qxh20 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v2)), v4);
+ vector signed char qxh21 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v6)), v4);
+ vector signed char qxh30 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v2)), v4);
+ vector signed char qxh31 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v6)), v4);
- *s = sumf;
+ vector signed char q6x00 = vec_sub(vec_or(qxh00, qxs00), off);
+ vector signed char q6x01 = vec_sub(vec_or(qxh01, qxs01), off);
+ vector signed char q6x10 = vec_sub(vec_or(qxh10, qxs10), off);
+ vector signed char q6x11 = vec_sub(vec_or(qxh11, qxs11), off);
+ vector signed char q6x20 = vec_sub(vec_or(qxh20, qxs20), off);
+ vector signed char q6x21 = vec_sub(vec_or(qxh21, qxs21), off);
+ vector signed char q6x30 = vec_sub(vec_or(qxh30, qxs30), off);
+ vector signed char q6x31 = vec_sub(vec_or(qxh31, qxs31), off);
-#elif defined(__POWER9_VECTOR__)
- const vector signed char lowMask = vec_splats((signed char)0xF);
- const vector unsigned char v2 = vec_splats((unsigned char)0x2);
- const vector unsigned char v3 = vec_splats((unsigned char)0x3);
- const vector unsigned char v4 = vec_splats((unsigned char)0x4);
- const vector unsigned char v6 = vec_splats((unsigned char)0x6);
- const vector signed char off = vec_splats((signed char)0x20);
+ vector signed char q8y00 = vec_xl( 0, q8);
+ vector signed char q8y10 = vec_xl( 16, q8);
+ vector signed char q8y20 = vec_xl( 32, q8);
+ vector signed char q8y30 = vec_xl( 48, q8);
+ vector signed char q8y01 = vec_xl( 64, q8);
+ vector signed char q8y11 = vec_xl( 80, q8);
+ vector signed char q8y21 = vec_xl( 96, q8);
+ vector signed char q8y31 = vec_xl(112, q8);
+ q8 += 128;
- vector float vsumf0 = vec_splats(0.0f);
- vector float vsumf1 = vec_splats(0.0f);
- vector float vsumf2 = vec_splats(0.0f);
- vector float vsumf3 = vec_splats(0.0f);
+ vector signed short qv00 = vec_add(vec_mule(q6x00, q8y00), vec_mulo(q6x00, q8y00));
+ vector signed short qv10 = vec_add(vec_mule(q6x10, q8y10), vec_mulo(q6x10, q8y10));
+ vector signed short qv20 = vec_add(vec_mule(q6x20, q8y20), vec_mulo(q6x20, q8y20));
+ vector signed short qv30 = vec_add(vec_mule(q6x30, q8y30), vec_mulo(q6x30, q8y30));
+ vector signed short qv01 = vec_add(vec_mule(q6x01, q8y01), vec_mulo(q6x01, q8y01));
+ vector signed short qv11 = vec_add(vec_mule(q6x11, q8y11), vec_mulo(q6x11, q8y11));
+ vector signed short qv21 = vec_add(vec_mule(q6x21, q8y21), vec_mulo(q6x21, q8y21));
+ vector signed short qv31 = vec_add(vec_mule(q6x31, q8y31), vec_mulo(q6x31, q8y31));
-#pragma GCC unroll 2
- for (int i = 0; i < nb; ++i) {
- __builtin_prefetch(x[i].ql, 0, 1);
- __builtin_prefetch(x[i].qh, 0, 1);
- __builtin_prefetch(y[i].qs, 0, 1);
+ vector signed short vscales = vec_unpackh(vec_xl_len(qs, 8));
+ qs += 8;
- vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
- vector float vyd = vec_splats(y[i].d);
- vector float vd= vec_mul(vxd, vyd);
+ vector signed short vs0 = vec_splat(vscales, 0);
+ vector signed short vs1 = vec_splat(vscales, 1);
+ vector signed short vs2 = vec_splat(vscales, 2);
+ vector signed short vs3 = vec_splat(vscales, 3);
+ vector signed short vs4 = vec_splat(vscales, 4);
+ vector signed short vs5 = vec_splat(vscales, 5);
+ vector signed short vs6 = vec_splat(vscales, 6);
+ vector signed short vs7 = vec_splat(vscales, 7);
- vector signed char qxs0 = (vector signed char)vec_xl( 0, x[i].ql);
- vector signed char qxs1 = (vector signed char)vec_xl(16, x[i].ql);
- vector signed char qxs00 = vec_and(qxs0, lowMask);
- vector signed char qxs01 = vec_sr(qxs0, v4);
- vector signed char qxs10 = vec_and(qxs1, lowMask);
- vector signed char qxs11 = vec_sr(qxs1, v4);
+ vsumi0 = vec_add(vec_mule(qv00, vs0), vsumi0);
+ vsumi1 = vec_add(vec_mulo(qv00, vs0), vsumi1);
+ vsumi2 = vec_add(vec_mule(qv01, vs4), vsumi2);
+ vsumi3 = vec_add(vec_mulo(qv01, vs4), vsumi3);
+ vsumi4 = vec_add(vec_mule(qv10, vs1), vsumi4);
+ vsumi5 = vec_add(vec_mulo(qv10, vs1), vsumi5);
+ vsumi6 = vec_add(vec_mule(qv11, vs5), vsumi6);
+ vsumi7 = vec_add(vec_mulo(qv11, vs5), vsumi7);
- vector signed char qxhs0 = (vector signed char)vec_xl( 0, x[i].qh);
+ vsumi0 = vec_add(vec_mule(qv20, vs2), vsumi0);
+ vsumi1 = vec_add(vec_mulo(qv20, vs2), vsumi1);
+ vsumi2 = vec_add(vec_mule(qv21, vs6), vsumi2);
+ vsumi3 = vec_add(vec_mulo(qv21, vs6), vsumi3);
+ vsumi4 = vec_add(vec_mule(qv30, vs3), vsumi4);
+ vsumi5 = vec_add(vec_mulo(qv30, vs3), vsumi5);
+ vsumi6 = vec_add(vec_mule(qv31, vs7), vsumi6);
+ vsumi7 = vec_add(vec_mulo(qv31, vs7), vsumi7);
+ }
- vector signed char qxh00 = vec_sl(vec_and((vector signed char)v3, qxhs0), v4);
- vector signed char qxh01 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v4)), v4);
- vector signed char qxh10 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v2)), v4);
- vector signed char qxh11 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v6)), v4);
-
- vector signed char q6x00 = vec_sub(vec_or(qxh00, qxs00), off);
- vector signed char q6x01 = vec_sub(vec_or(qxh01, qxs01), off);
- vector signed char q6x10 = vec_sub(vec_or(qxh10, qxs10), off);
- vector signed char q6x11 = vec_sub(vec_or(qxh11, qxs11), off);
-
- vector signed char q8y00 = vec_xl( 0, y[i].qs);
- vector signed char q8y10 = vec_xl(16, y[i].qs);
- vector signed char q8y01 = vec_xl(32, y[i].qs);
- vector signed char q8y11 = vec_xl(48, y[i].qs);
-
- vector signed short qv00 = vec_add(vec_mule(q6x00, q8y00), vec_mulo(q6x00, q8y00));
- vector signed short qv10 = vec_add(vec_mule(q6x10, q8y10), vec_mulo(q6x10, q8y10));
- vector signed short qv01 = vec_add(vec_mule(q6x01, q8y01), vec_mulo(q6x01, q8y01));
- vector signed short qv11 = vec_add(vec_mule(q6x11, q8y11), vec_mulo(q6x11, q8y11));
-
- vector signed short vs = (vector signed short)vec_unpackh(vec_xl_len(x[i].scales, 4));
- vector signed short vs0 = vec_splat(vs, 0);
- vector signed short vs1 = vec_splat(vs, 1);
- vector signed short vs2 = vec_splat(vs, 2);
- vector signed short vs3 = vec_splat(vs, 3);
-
- vector signed int vsumi0 = vec_add(vec_mule(qv00, vs0), vec_mulo(qv00, vs0));
- vector signed int vsumi1 = vec_add(vec_mule(qv10, vs1), vec_mulo(qv10, vs1));
- vector signed int vsumi2 = vec_add(vec_mule(qv01, vs2), vec_mulo(qv01, vs2));
- vector signed int vsumi3 = vec_add(vec_mule(qv11, vs3), vec_mulo(qv11, vs3));
+ vsumi0 = vec_add(vsumi0, vsumi4);
+ vsumi1 = vec_add(vsumi1, vsumi5);
+ vsumi2 = vec_add(vsumi2, vsumi6);
+ vsumi3 = vec_add(vsumi3, vsumi7);
vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
const uint8_t * restrict qh = x[i].qh;
const int8_t * restrict q8 = y[i].qs;
- const __m64 scales_1 = __lasx_xvreplgr2vr_b(x[i].scales[0]);
- const __m64 scales_2 = __lasx_xvreplgr2vr_b(x[i].scales[1]);
- const __m64 scales_3 = __lasx_xvreplgr2vr_b(x[i].scales[2]);
- const __m64 scales_4 = __lasx_xvreplgr2vr_b(x[i].scales[3]);
+ const __m128i scales = __lsx_vld((const __m128i*)x[i].scales, 0);
__m256i sumi = __lasx_xvldi(0);
- __m128i scale_0 = __lsx_vinsgr2vr_d(scale_0, scales_1, 0);
- scale_0 = __lsx_vinsgr2vr_d(scale_0, scales_2, 1);
- __m128i scale_1 = __lsx_vinsgr2vr_d(scale_1, scales_3, 0);
- scale_1 = __lsx_vinsgr2vr_d(scale_1, scales_4, 1);
+ int is = 0;
+
+ for (int j = 0; j < QK_K/128; ++j) {
+
+ const __m128i scale_0 = lsx_shuffle_b(scales, get_scale_shuffle(is + 0));
+ const __m128i scale_1 = lsx_shuffle_b(scales, get_scale_shuffle(is + 1));
+ const __m128i scale_2 = lsx_shuffle_b(scales, get_scale_shuffle(is + 2));
+ const __m128i scale_3 = lsx_shuffle_b(scales, get_scale_shuffle(is + 3));
+ is += 4;
- const __m256i q4bits1 = __lasx_xvld((const __m256i*)q4, 0);
- const __m128i q4bitsH = __lsx_vld((const __m128i*)qh, 0);
+ const __m256i q4bits1 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
+ const __m256i q4bits2 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
+ const __m256i q4bitsH = __lasx_xvld((const __m256i*)qh, 0); qh += 32;
- const __m256i q4h_0 = __lasx_xvslli_h(__lasx_xvand_v(lasx_insertf128(__lasx_xvsrli_h(q4bitsH, 2), q4bitsH), m2), 4);
- const __m256i q4h_1 = __lasx_xvslli_h(__lasx_xvand_v(lasx_insertf128(__lasx_xvsrli_h(q4bitsH, 6), __lasx_xvsrli_h(q4bitsH, 4)), m2), 4);
+ const __m256i q4h_0 = __lasx_xvslli_h(__lasx_xvand_v(q4bitsH, m2), 4);
+ const __m256i q4h_1 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 2), m2), 4);
+ const __m256i q4h_2 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 4), m2), 4);
+ const __m256i q4h_3 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 6), m2), 4);
- const __m256i q4_0 = __lasx_xvor_v(__lasx_xvand_v(q4bits1, m4), q4h_0);
- const __m256i q4_1 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits1, 4), m4), q4h_1);
+ const __m256i q4_0 = __lasx_xvor_v(__lasx_xvand_v(q4bits1, m4), q4h_0);
+ const __m256i q4_1 = __lasx_xvor_v(__lasx_xvand_v(q4bits2, m4), q4h_1);
+ const __m256i q4_2 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits1, 4), m4), q4h_2);
+ const __m256i q4_3 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits2, 4), m4), q4h_3);
- const __m256i q8_0 = __lasx_xvld((const __m256i*)(q8+ 0), 0);
- const __m256i q8_1 = __lasx_xvld((const __m256i*)(q8+32), 0);
+ const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+ const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+ const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+ const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
- __m256i q8s_0 = lasx_maddubs_h(m32s, q8_0);
- __m256i q8s_1 = lasx_maddubs_h(m32s, q8_1);
+ __m256i q8s_0 = lasx_maddubs_h(m32s, q8_0);
+ __m256i q8s_1 = lasx_maddubs_h(m32s, q8_1);
+ __m256i q8s_2 = lasx_maddubs_h(m32s, q8_2);
+ __m256i q8s_3 = lasx_maddubs_h(m32s, q8_3);
- __m256i p16_0 = lasx_maddubs_h(q4_0, q8_0);
- __m256i p16_1 = lasx_maddubs_h(q4_1, q8_1);
+ __m256i p16_0 = lasx_maddubs_h(q4_0, q8_0);
+ __m256i p16_1 = lasx_maddubs_h(q4_1, q8_1);
+ __m256i p16_2 = lasx_maddubs_h(q4_2, q8_2);
+ __m256i p16_3 = lasx_maddubs_h(q4_3, q8_3);
- p16_0 = __lasx_xvsub_h(p16_0, q8s_0);
- p16_1 = __lasx_xvsub_h(p16_1, q8s_1);
+ p16_0 = __lasx_xvsub_h(p16_0, q8s_0);
+ p16_1 = __lasx_xvsub_h(p16_1, q8s_1);
+ p16_2 = __lasx_xvsub_h(p16_2, q8s_2);
+ p16_3 = __lasx_xvsub_h(p16_3, q8s_3);
- p16_0 = lasx_madd_h(lasx_ext8_16(scale_0), p16_0);
- p16_1 = lasx_madd_h(lasx_ext8_16(scale_1), p16_1);
+ p16_0 = lasx_madd_h(lasx_ext8_16(scale_0), p16_0);
+ p16_1 = lasx_madd_h(lasx_ext8_16(scale_1), p16_1);
+ p16_2 = lasx_madd_h(lasx_ext8_16(scale_2), p16_2);
+ p16_3 = lasx_madd_h(lasx_ext8_16(scale_3), p16_3);
- sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1));
+ sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1));
+ sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_2, p16_3));
+ }
- acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);
+ acc = __lasx_xvfmadd_s((__m256)__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);
}
*s = hsum_float_8(acc);
const int8_t * restrict q8 = y[i].qs;
memset(aux32, 0, 8*sizeof(int32_t));
int8_t * restrict a = aux8;
- for (int l = 0; l < 16; ++l) {
- a[l+ 0] = (int8_t)((q4[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
- a[l+16] = (int8_t)((q4[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
- a[l+32] = (int8_t)((q4[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
- a[l+48] = (int8_t)((q4[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
+ for (int j = 0; j < QK_K; j += 128) {
+ for (int l = 0; l < 32; ++l) {
+ a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
+ a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
+ a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
+ a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
+ }
+ a += 128;
+ q4 += 64;
+ qh += 32;
}
+ a = aux8;
int is = 0;
for (int j = 0; j < QK_K/16; ++j) {
int scale = x[i].scales[is++];
#endif
}
-#endif
-
#if defined (__AVX2__) || defined (__ARM_NEON) || defined (__POWER9_VECTOR__) || defined(__loongarch_asx)
static const int8_t keven_signs_q2xs[1024] = {
1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1,
const __m256i block_sign_shuffle_1 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_1);
const __m256i block_sign_shuffle_2 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_2);
-#if QK_K == 64
- static const uint8_t k_bit_helper[16] = {
- 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
- };
- const __m128i bit_helper = _mm_loadu_si128((const __m128i*)k_bit_helper);
- const __m128i m511 = _mm_set1_epi16(511);
- typedef union {
- __m128i vec_index;
- uint16_t index[8];
- } index_t;
-
- index_t idx;
- __m256 accumf = _mm256_setzero_ps();
- for (int i = 0; i < nb; ++i) {
- const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
- const __m128i q2_data = _mm_loadu_si128((const __m128i*)x[i].qs);
- idx.vec_index = _mm_and_si128(q2_data, m511);
-
- const __m128i partial_sign_bits = _mm_srli_epi16(q2_data, 9);
- const __m128i partial_sign_bits_upper = _mm_srli_epi16(q2_data, 13);
- const __m128i partial_sign_bits_for_counting = _mm_xor_si128(partial_sign_bits, partial_sign_bits_upper);
-
- const __m128i odd_bits = _mm_shuffle_epi8(bit_helper, partial_sign_bits_for_counting);
- const __m128i full_sign_bits = _mm_or_si128(partial_sign_bits, odd_bits);
- const __m256i full_signs = MM256_SET_M128I(full_sign_bits, full_sign_bits);
-
- const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)y[i].qs);
- const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)(y[i].qs+32));
-
- const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[idx.index[3]], iq2xs_grid[idx.index[2]],
- iq2xs_grid[idx.index[1]], iq2xs_grid[idx.index[0]]);
- const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[idx.index[7]], iq2xs_grid[idx.index[6]],
- iq2xs_grid[idx.index[5]], iq2xs_grid[idx.index[4]]);
-
- __m256i signs;
- signs = _mm256_shuffle_epi8(full_signs, block_sign_shuffle_1);
- signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
- const __m256i q8s_1 = _mm256_sign_epi8(q8_1, _mm256_or_si256(signs, mone));
-
- signs = _mm256_shuffle_epi8(full_signs, block_sign_shuffle_2);
- signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
- const __m256i q8s_2 = _mm256_sign_epi8(q8_2, _mm256_or_si256(signs, mone));
-
- const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
- const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
-
- const __m256i sc1 = MM256_SET_M128I(_mm_set1_epi16(2*(x[i].scales[0] >> 4)+1), _mm_set1_epi16(2*(x[i].scales[0] & 0xf)+1));
- const __m256i sc2 = MM256_SET_M128I(_mm_set1_epi16(2*(x[i].scales[1] >> 4)+1), _mm_set1_epi16(2*(x[i].scales[1] & 0xf)+1));
-
- const __m256i sum = _mm256_add_epi32(_mm256_madd_epi16(sc1, dot1), _mm256_madd_epi16(sc2, dot2));
-
- accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sum), accumf);
-
- }
-
- *s = 0.125f * hsum_float_8(accumf);
-#else
-
static const uint8_t k_bit_helper[32] = {
0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
}
*s = 0.125f * hsum_float_8(accumf);
-#endif
#elif defined(__loongarch_asx)
const __m256i mone = __lasx_xvreplgr2vr_b(1);
const __m256i block_sign_shuffle_1 = __lasx_xvld((const __m256i*)block_sign_shuffle_mask_1, 0);
const __m256i block_sign_shuffle_2 = __lasx_xvld((const __m256i*)block_sign_shuffle_mask_2, 0);
-#if QK_K == 64
- static const uint8_t k_bit_helper[16] = {
- 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
- };
- const __m128i bit_helper = __lsx_vld((const __m128i*)k_bit_helper, 0);
- const __m128i m511 = __lsx_vreplgr2vr_h(511);
- typedef union {
- __m128i vec_index;
- uint16_t index[8];
- } index_t;
-
- index_t idx;
- __m256 accumf = (__m256)__lasx_xvldi(0);
- for (int i = 0; i < nb; ++i) {
- const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
- const __m128i q2_data = __lsx_vld((const __m128i*)x[i].qs, 0);
- idx.vec_index = __lsx_vand_v(q2_data, m511);
-
- const __m128i partial_sign_bits = __lsx_vsrli_h(q2_data, 9);
- const __m128i partial_sign_bits_upper = __lsx_vsrli_h(q2_data, 13);
- const __m128i partial_sign_bits_for_counting = __lsx_vxor_v(partial_sign_bits, partial_sign_bits_upper);
-
- const __m128i odd_bits = lsx_shuffle_b(bit_helper, partial_sign_bits_for_counting);
- const __m128i full_sign_bits = __lsx_vor_v(partial_sign_bits, odd_bits);
- const __m256i full_signs = lasx_insertf128(full_sign_bits, full_sign_bits);
-
- const __m256i q8_1 = __lasx_xvld((const __m256i *)y[i].qs, 0);
- const __m256i q8_2 = __lasx_xvld((const __m256i *)(y[i].qs+32), 0);
-
- const __m256i q2_1 = lasx_set_d(iq2xs_grid[idx.index[3]], iq2xs_grid[idx.index[2]],
- iq2xs_grid[idx.index[1]], iq2xs_grid[idx.index[0]]);
- const __m256i q2_2 = lasx_set_d(iq2xs_grid[idx.index[7]], iq2xs_grid[idx.index[6]],
- iq2xs_grid[idx.index[5]], iq2xs_grid[idx.index[4]]);
- __m256i signs;
- signs = lasx_shuffle_b(full_signs, block_sign_shuffle_1);
- signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask);
- const __m256i q8s_1 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_1);
-
- signs = lasx_shuffle_b(full_signs, block_sign_shuffle_2);
- signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask);
- const __m256i q8s_2 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_2);
-
- const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1);
- const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2);
-
- const __m256i sc1 = lasx_insertf128(_mm_set1_epi16(2*(x[i].scales[0] >> 4)+1), __lsx_vreplgr2vr_h(2*(x[i].scales[0] & 0xf)+1));
- const __m256i sc2 = lasx_insertf128(_mm_set1_epi16(2*(x[i].scales[1] >> 4)+1), __lsx_vreplgr2vr_h(2*(x[i].scales[1] & 0xf)+1));
-
- const __m256i sum = __lasx_xvadd_w(lasx_madd_h(sc1, dot1), lasx_madd_h(sc2, dot2));
-
- accumf = __lasx_vfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sum), accumf);
- }
-
- *s = 0.125f * hsum_float_8(accumf);
-#else
-
static const uint8_t k_bit_helper[32] = {
0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
}
*s = 0.125f * hsum_float_8(accumf);
-#endif
-
-
#elif defined(__POWER9_VECTOR__)
vector float vsumf0 = vec_splats(0.0f);
vector float vsumf1 = vec_splats(0.0f);
ggml_int8x16x4_t q8b;
vec_index_t idx;
-#if QK_K == 256
uint32_t scales32[2];
const uint8_t * scales8 = (const uint8_t *)scales32;
-#endif
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const uint16_t * restrict signs = (const uint16_t *)x[i].signs;
const int8_t * restrict q8 = y[i].qs;
-#if QK_K == 256
memcpy(scales32, x[i].scales, 4);
scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101;
scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101;
-#endif
int sumi1 = 0, sumi2 = 0;
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]);
const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]);
-#if QK_K == 256
+
sumi1 += vaddvq_s32(p1) * scales8[ib32/2+0];
sumi2 += vaddvq_s32(p2) * scales8[ib32/2+4];
-#else
- sumi1 += vaddvq_s32(p1) * (1 + 2*(x[i].scales[ib32/2] & 0xf));
- sumi2 += vaddvq_s32(p2) * (1 + 2*(x[i].scales[ib32/2] >> 4));
-#endif
}
sumf += d*(sumi1 + sumi2);
}
const int nb = n / QK_K;
-#if QK_K != 64
iq1m_scale_t scale;
-#endif
#if defined __ARM_NEON
-
-#if QK_K == 64
- const int32x4_t mask = vdupq_n_s32(0xf);
-#else
const int32x4_t mask = vdupq_n_s32(0x7);
-#endif
const int32x4_t mone = vdupq_n_s32(1);
const int32x4_t mzero = vdupq_n_s32(0);
const uint8_t * qh = x[i].qh;
const uint16_t * sc = (const uint16_t *)x[i].scales;
-#if QK_K != 64
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
-#endif
int32x4_t sumi1 = mzero;
int32x4_t sumi2 = mzero;
const int32x4_t p4 = vpaddq_s32(ggml_vdotq_s32(mzero, deltas.val[aux8[2]], q8b.val[2]), ggml_vdotq_s32(mzero, deltas.val[aux8[3]], q8b.val[3]));
const int32x4_t p34 = vpaddq_s32(p3, p4);
-#if QK_K == 64
- int32x4_t scales_4 = ggml_vld1q_u32(sc[0] >> 0, sc[0] >> 4, sc[0] >> 8, sc[0] >> 12);
-#else
int32x4_t scales_4 = ggml_vld1q_u32(sc[ib/2] >> 0, sc[ib/2] >> 3, sc[ib/2] >> 6, sc[ib/2] >> 9);
-#endif
+
scales_4 = vaddq_s32(vshlq_n_s32(vandq_s32(scales_4, mask), 1), mone);
sumi1 = vmlaq_s32(sumi1, scales_4, p12);
}
-#if QK_K == 64
- sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d) * (vaddvq_s32(sumi1) + IQ1M_DELTA * vaddvq_s32(sumi2));
-#else
sumf += y[i].d * GGML_FP16_TO_FP32(scale.f16) * (vaddvq_s32(sumi1) + IQ1M_DELTA * vaddvq_s32(sumi2));
-#endif
}
*s = sumf;
#elif defined __AVX2__
-#if QK_K == 64
- const __m256i mask = _mm256_set1_epi16(0xf);
-#else
const __m256i mask = _mm256_set1_epi16(0x7);
-#endif
const __m256i mone = _mm256_set1_epi16(1);
__m256 accum1 = _mm256_setzero_ps();
const uint8_t * qh = x[i].qh;
const uint16_t * sc = (const uint16_t *)x[i].scales;
-#if QK_K != 64
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
-#endif
__m256i sumi1 = _mm256_setzero_si256();
__m256i sumi2 = _mm256_setzero_si256();
const __m256i dot3 = mul_add_epi8(delta1, q8b_1);
const __m256i dot4 = mul_add_epi8(delta2, q8b_2);
-#if QK_K == 64
- __m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[0] >> 4), _mm_set1_epi16(sc[0] >> 0));
- __m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[0] >> 12), _mm_set1_epi16(sc[0] >> 8));
-#else
+
__m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 3), _mm_set1_epi16(sc[ib/2] >> 0));
__m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 9), _mm_set1_epi16(sc[ib/2] >> 6));
-#endif
+
scale1 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale1, mask), 1), mone);
scale2 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale2, mask), 1), mone);
const __m256i p1 = _mm256_madd_epi16(dot1, scale1);
qs += 8; qh += 4;
}
-#if QK_K == 64
- const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d));
-#else
const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(scale.f16));
-#endif
+
accum1 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi1), accum1);
accum2 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi2), accum2);
-
}
*s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2);
const uint8_t * qh = x[i].qh;
const uint16_t * sc = (const uint16_t *)x[i].scales;
-#if QK_K != 64
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
-#endif
int sumi1 = 0, sumi2 = 0;
for (int ib = 0; ib < QK_K/32; ++ib) {
sum1[l/2] += lsum1;
sum2[l/2] += lsum2*delta[l];
}
-#if QK_K == 64
- const int ls1 = 2*((sc[0] >> (8*(ib%2)+0)) & 0xf) + 1;
- const int ls2 = 2*((sc[0] >> (8*(ib%2)+4)) & 0xf) + 1;
-#else
+
const int ls1 = 2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1;
const int ls2 = 2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1;
-#endif
+
sumi1 += sum1[0] * ls1 + sum1[1] * ls2;
sumi2 += sum2[0] * ls1 + sum2[1] * ls2;
qs += 4;
qh += 2;
}
-#if QK_K == 64
- sumf += GGML_FP16_TO_FP32(x[i].d) * y[i].d * (sumi1 + IQ1M_DELTA * sumi2);
-#else
sumf += GGML_FP16_TO_FP32(scale.f16) * y[i].d * (sumi1 + IQ1M_DELTA * sumi2);
-#endif
}
*s = sumf;
UNUSED(by);
UNUSED(bs);
assert(n % QK_K == 0);
-#if QK_K == 64
- ggml_vec_dot_iq4_nl_q8_0(n, s, bs, vx, bx, vy, by, nrc);
-#else
const block_iq4_xs * restrict x = vx;
const block_q8_K * restrict y = vy;
}
*s = sumf;
#endif
-#endif
}
// ================================ IQ2 quantization =============================================
const float * xx;
for (int ibl = 0; ibl < nbl; ++ibl) {
-
-#if QK_K == 64
- y[ibl].d = GGML_FP32_TO_FP16(0.f);
-#endif
memset(y[ibl].qs, 0, QK_K/8);
memset(y[ibl].qh, 0, QK_K/16);
memset(y[ibl].scales, 0, QK_K/32);
}
uint16_t * sc = (uint16_t *)y[ibl].scales;
-#if QK_K == 64
- float d = max_scale/31;
-#else
float d = max_scale/15;
-#endif
float id = 1/d;
float sumqx_f = 0, sumq2_f = 0;
for (int ib = 0; ib < QK_K/block_size; ++ib) {
int l = nearest_int(0.5f*(id*scales[ib+0]-1));
-#if QK_K == 64
- l = MAX(0, MIN(15, l));
- sc[ib/4] |= (l << 4*(ib%4));
-#else
l = MAX(0, MIN(7, l));
sc[ib/4] |= (l << 3*(ib%4));
-#endif
y[ibl].qh[ib] |= masks[shifts[ib]];
const float * xb = xbl + block_size*ib;
if (quant_weights) {
}
if (sumq2_f > 0) d = sumqx_f/sumq2_f;
s.f16 = GGML_FP32_TO_FP16(d*1.1125f); // 1.1125f is another fudge factor. Don't ask me why it is needed.
-#if QK_K == 64
- y[ibl].d = s.f16;
-#else
sc[0] |= ((s.u16 & 0x000f) << 12);
sc[1] |= ((s.u16 & 0x00f0) << 8);
sc[2] |= ((s.u16 & 0x0f00) << 4);
sc[3] |= ((s.u16 & 0xf000) << 0);
-#endif
}
}
}
size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
-#if QK_K == 64
- return quantize_iq4_nl(src, dst, nrow, n_per_row, quant_weights);
-#else
GGML_ASSERT(n_per_row%QK_K == 0);
int64_t nblock = n_per_row/QK_K;
char * qrow = (char *)dst;
qrow += nblock*sizeof(block_iq4_xs);
}
return nrow * nblock * sizeof(block_iq4_xs);
-#endif
}
void quantize_row_iq4_xs(const float * restrict x, void * restrict vy, int64_t k) {
} break;
case GGML_TYPE_Q4_K:
{
- #ifdef GGML_QKK_64
- VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_K, data, nb, d[0], d[1]);
- #else
VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_K, data, nb, d, dmin);
- #endif
} break;
case GGML_TYPE_Q5_K:
{
- #ifdef GGML_QKK_64
- VALIDATE_ROW_DATA_D_F16_IMPL(block_q5_K, data, nb);
- #else
VALIDATE_ROW_DATA_DM_F16_IMPL(block_q5_K, data, nb, d, dmin);
- #endif
} break;
case GGML_TYPE_Q6_K:
{
{
const block_iq1_m * q = (const block_iq1_m *) data;
for (size_t i = 0; i < nb; ++i) {
- #if QK_K == 64
- if (!validate_fp16(q[i].d, i)) {
- return false;
- }
- #else
iq1m_scale_t scale;
const uint16_t * sc = (const uint16_t *)q[i].scales;
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
if (!validate_fp16(scale.f16, i)) {
return false;
}
- #endif
}
} break;
case GGML_TYPE_IQ2_XXS:
VALIDATE_ROW_DATA_D_F16_IMPL(block_iq3_s, data, nb);
} break;
case GGML_TYPE_IQ4_XS:
- #if QK_K != 64
{
VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_xs, data, nb);
} break;
- #endif
- // with QK_K == 64, iq4_xs is iq4_nl
case GGML_TYPE_IQ4_NL:
{
VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb);