// ===================== Helper functions
//
static inline int nearest_int(float fval) {
- assert(fval <= 4194303.f);
+ assert(fabsf(fval) <= 4194303.f);
float val = fval + 12582912.f;
int i; memcpy(&i, &val, sizeof(int));
return (i & 0x007fffff) - 0x00400000;
return nrow * row_size;
}
+// ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)
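+//
+// Both TQ1_0 and TQ2_0 store one fp16 scale d per block of QK_K values; the values themselves
+// are ternary (-1, 0, 1) and are kept as unsigned codes 0, 1, 2.
+//
+// TQ1_0 packs 5 values per byte in qs and 4 per byte in qh (with QK_K == 256 that works out to
+// 240 values in 48 bytes of qs plus 16 values in 4 bytes of qh): the 5 trits form a base-3
+// number p in [0, 242], which is then scaled to a full byte as q = ceil(p * 256 / 243).
+// Rounding up is what allows trit n to be read back with 8-bit arithmetic as
+// ((uint8_t)(q * pow(3, n)) * 3) >> 8.
+// For example, the trits (2, 1, 0, 1, 2) give p = 2*81 + 1*27 + 0*9 + 1*3 + 2 = 194 and
+// q = ceil(194 * 256 / 243) = 205; decoding: (205*3) >> 8 = 2, ((205*3 & 0xFF) * 3) >> 8 = 1, ...
+//
+// TQ2_0 simply packs 4 values per byte, 2 bits each.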
+
+void quantize_row_tq1_0_ref(const float * restrict x, block_tq1_0 * restrict y, int64_t k) {
+ assert(k % QK_K == 0);
+ const int64_t nb = k / QK_K;
+
+ for (int64_t i = 0; i < nb; i++) {
+ float amax = 0.0f; // absolute max
+
+ for (int j = 0; j < QK_K; j++) {
+ const float v = x[j];
+ amax = MAX(amax, fabsf(v));
+ }
+
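+        // the values are reconstructed as d * {-1, 0, 1}, so the scale is simply the
+        // absolute max of the block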
+ const float d = amax;
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = GGML_FP32_TO_FP16(d);
+
+ // 5 elements per byte, along 32 bytes
+ for (size_t j = 0; j < sizeof(y->qs) - sizeof(y->qs) % 32; j += 32) {
+ for (size_t m = 0; m < 32; ++m) {
+ uint8_t q = 0;
+ for (size_t n = 0; n < 5; ++n) {
+ int xi = lroundf(x[m + n*32] * id) + 1; // -1, 0, 1 -> 0, 1, 2
+ q *= 3;
+ q += xi;
+ }
+ // ceiling division (243 == pow(3, 5))
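+                // rounding up is what makes the decode exact: e.g. p = 81 (trits 1,0,0,0,0)
+                // would floor to 85, and (85*3) >> 8 = 0 instead of 1, while ceil gives 86 -> 1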
+ q = ((uint16_t)q * 256 + (243 - 1)) / 243;
+ y[i].qs[j + m] = q;
+ }
+ x += 5*32;
+ }
+ // along 16 bytes
+ for (size_t j = sizeof(y->qs) - sizeof(y->qs) % 32; j < sizeof(y->qs); j += 16) {
+ for (size_t m = 0; m < 16; ++m) {
+ uint8_t q = 0;
+ for (size_t n = 0; n < 5; ++n) {
+ int xi = lroundf(x[m + n*16] * id) + 1; // -1, 0, 1 -> 0, 1, 2
+ q *= 3;
+ q += xi;
+ }
+ // ceiling division (243 == pow(3, 5))
+ q = ((uint16_t)q * 256 + (243 - 1)) / 243;
+ y[i].qs[j + m] = q;
+ }
+ x += 5*16;
+ }
+ // 4 elements per byte
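+        // the last 16 values of the block go into qh, 4 per byte; the lowest base-3 digit
+        // of each byte is left unused (hence the extra q *= 3 below)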
+ for (size_t j = 0; j < sizeof(y->qh); ++j) {
+ uint8_t q = 0;
+ for (size_t m = 0; m < 4; ++m) {
+ // -1, 0, 1 -> 0, 1, 2
+ int xi = lroundf(x[j + m*sizeof(y->qh)] * id) + 1;
+ q *= 3;
+ q += xi;
+ }
+ // shift the first value to the most significant trit
+ q *= 3;
+ // ceiling division (243 == pow(3, 5))
+ q = ((uint16_t)q * 256 + (243 - 1)) / 243;
+ y[i].qh[j] = q;
+ }
+ x += 4*sizeof(y->qh);
+ }
+}
+
+void quantize_row_tq2_0_ref(const float * restrict x, block_tq2_0 * restrict y, int64_t k) {
+ assert(k % QK_K == 0);
+ const int64_t nb = k / QK_K;
+
+ for (int64_t i = 0; i < nb; i++) {
+ float amax = 0.0f; // absolute max
+
+ for (int j = 0; j < QK_K; j++) {
+ const float v = x[j];
+ amax = MAX(amax, fabsf(v));
+ }
+
+ const float d = amax;
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = GGML_FP32_TO_FP16(d);
+
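+        // 4 values per byte, 2 bits each: value n of a group goes into bits 2*n..2*n+1,
+        // and the 4 values packed together are 32 positions apart in the row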
+ for (size_t j = 0; j < sizeof(y->qs); j += 32) {
+ for (size_t m = 0; m < 32; ++m) {
+ uint8_t q = 0;
+ for (size_t n = 0; n < 4; ++n) {
+ // -1, 0, 1 -> 0, 1, 2
+ int xi = lroundf(x[m + n*32] * id) + 1;
+ q += (xi & 3) << (2*n);
+ }
+ y[i].qs[j + m] = q;
+ }
+ x += 4*32;
+ }
+ }
+}
+
+void quantize_row_tq1_0(const float * restrict x, void * restrict vy, int64_t k) {
+ assert(k % QK_K == 0);
+ block_tq1_0 * restrict y = vy;
+ quantize_row_tq1_0_ref(x, y, k);
+}
+
+void quantize_row_tq2_0(const float * restrict x, void * restrict vy, int64_t k) {
+ assert(k % QK_K == 0);
+ block_tq2_0 * restrict y = vy;
+ quantize_row_tq2_0_ref(x, y, k);
+}
+
+size_t quantize_tq1_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+ (void)quant_weights; // not used
+ const size_t row_size = ggml_row_size(GGML_TYPE_TQ1_0, n_per_row);
+ quantize_row_tq1_0(src, dst, (int64_t)nrow*n_per_row);
+ return nrow * row_size;
+}
+
+size_t quantize_tq2_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+ (void)quant_weights; // not used
+ const size_t row_size = ggml_row_size(GGML_TYPE_TQ2_0, n_per_row);
+ quantize_row_tq2_0(src, dst, (int64_t)nrow*n_per_row);
+ return nrow * row_size;
+}
+
+
+void dequantize_row_tq1_0(const block_tq1_0 * restrict x, float * restrict y, int64_t k) {
+ assert(k % QK_K == 0);
+ const int64_t nb = k / QK_K;
+
+ const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243};
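+    // multiplying a packed byte by pow3[n] (mod 256) moves trit n to the top;
+    // ((uint16_t) q * 3) >> 8 then reads it out as 0, 1 or 2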
+
+ for (int64_t i = 0; i < nb; ++i) {
+
+ const float d = GGML_FP16_TO_FP32(x[i].d);
+
+ for (size_t j = 0; j < sizeof(x->qs) - sizeof(x->qs) % 32; j += 32) {
+ for (size_t n = 0; n < 5; ++n) {
+ for (size_t m = 0; m < 32; ++m) {
+ uint8_t q = x[i].qs[j + m] * pow3[n];
+ int16_t xi = ((uint16_t) q * 3) >> 8;
+ *y++ = (float) (xi - 1) * d;
+ }
+ }
+ }
+ for (size_t j = sizeof(x->qs) - sizeof(x->qs) % 32; j < sizeof(x->qs); j += 16) {
+ for (size_t n = 0; n < 5; ++n) {
+ for (size_t m = 0; m < 16; ++m) {
+ uint8_t q = x[i].qs[j + m] * pow3[n];
+ int16_t xi = ((uint16_t) q * 3) >> 8;
+ *y++ = (float) (xi - 1) * d;
+ }
+ }
+ }
+
+ for (size_t n = 0; n < 4; ++n) {
+ for (size_t j = 0; j < sizeof(x->qh); ++j) {
+ uint8_t q = x[i].qh[j] * pow3[n];
+ int16_t xi = ((uint16_t) q * 3) >> 8;
+ *y++ = (float) (xi - 1) * d;
+ }
+ }
+ }
+}
+
+void dequantize_row_tq2_0(const block_tq2_0 * restrict x, float * restrict y, int64_t k) {
+ assert(k % QK_K == 0);
+ const int64_t nb = k / QK_K;
+
+ for (int64_t i = 0; i < nb; ++i) {
+
+ const float d = GGML_FP16_TO_FP32(x[i].d);
+
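+        // 2 bits per value; subtracting 1 maps the stored codes 0, 1, 2 back to -1, 0, 1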
+ for (size_t j = 0; j < sizeof(x->qs); j += 32) {
+ for (size_t l = 0; l < 4; ++l) {
+ for (size_t m = 0; m < 32; ++m) {
+ int8_t q = (x[i].qs[j + m] >> (l*2)) & 3;
+ *y++ = (float) (q - 1) * d;
+ }
+ }
+ }
+ }
+}
+
// ====================== "True" 2-bit (de)-quantization
void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int64_t k) {
*s = sumf;
}
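+// Dot products with q8_K activations.
+//
+// The SIMD paths below keep the ternary codes unsigned (0, 1, 2) while multiplying with the
+// int8 activations, then subtract the sums of the activations (accumulated from y[i].bsums)
+// once per block to undo the +1 offset; the scalar fallback subtracts 1 per element instead.
+// In all cases the integer result is scaled by the two block scales, x[i].d * y[i].d.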
+void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+
+ const block_tq1_0 * restrict x = vx;
+ const block_q8_K * restrict y = vy;
+
+ const int nb = n / QK_K;
+
+#if defined(__ARM_NEON)
+ float sumf = 0.0f;
+
+ uint8_t k_shift[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27};
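+    // powers of 3 per group of 4 lanes (3^0..3^3), used below to move each of the
+    // 4 trits packed in a qh byte to the top once the qh word is broadcast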
+
+ const uint8x16_t shift = vld1q_u8(k_shift);
+
+ for (int i = 0; i < nb; ++i) {
+#if defined(__ARM_FEATURE_DOTPROD)
+ int32x4_t sumi0 = vdupq_n_s32(0);
+ int32x4_t sumi1 = vdupq_n_s32(0);
+#else
+ int16x8_t sumi0 = vdupq_n_s16(0);
+ int16x8_t sumi1 = vdupq_n_s16(0);
+#endif
+
+ // first 32 bytes of 5 elements
+ {
+ uint8x16_t qx0 = vld1q_u8(x[i].qs + 0);
+ uint8x16_t qx1 = vld1q_u8(x[i].qs + 16);
+ uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(3));
+ uint8x16_t qx3 = vmulq_u8(qx1, vdupq_n_u8(3));
+ uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(9));
+ uint8x16_t qx5 = vmulq_u8(qx1, vdupq_n_u8(9));
+ uint8x16_t qx6 = vmulq_u8(qx0, vdupq_n_u8(27));
+ uint8x16_t qx7 = vmulq_u8(qx1, vdupq_n_u8(27));
+ uint8x16_t qx8 = vmulq_u8(qx0, vdupq_n_u8(81));
+ uint8x16_t qx9 = vmulq_u8(qx1, vdupq_n_u8(81));
+
+ // multiply by 3 and keep the 2 bits above 8 bits
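+            // vhaddq_u8(q, q >> 1) == floor(3*q/4) without leaving 8 bits,
+            // so the >> 6 yields (3*q) >> 8, i.e. the trit as 0, 1 or 2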
+ int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6));
+ int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6));
+ int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6));
+ int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6));
+ int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6));
+ int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6));
+ int8x16_t sqx6 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx6, vshrq_n_u8(qx6, 1)), 6));
+ int8x16_t sqx7 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx7, vshrq_n_u8(qx7, 1)), 6));
+ int8x16_t sqx8 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx8, vshrq_n_u8(qx8, 1)), 6));
+ int8x16_t sqx9 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx9, vshrq_n_u8(qx9, 1)), 6));
+
+ const int8x16_t qy0 = vld1q_s8(y[i].qs + 0);
+ const int8x16_t qy1 = vld1q_s8(y[i].qs + 16);
+ const int8x16_t qy2 = vld1q_s8(y[i].qs + 32);
+ const int8x16_t qy3 = vld1q_s8(y[i].qs + 48);
+ const int8x16_t qy4 = vld1q_s8(y[i].qs + 64);
+ const int8x16_t qy5 = vld1q_s8(y[i].qs + 80);
+ const int8x16_t qy6 = vld1q_s8(y[i].qs + 96);
+ const int8x16_t qy7 = vld1q_s8(y[i].qs + 112);
+ const int8x16_t qy8 = vld1q_s8(y[i].qs + 128);
+ const int8x16_t qy9 = vld1q_s8(y[i].qs + 144);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+ sumi0 = vdotq_s32(sumi0, sqx0, qy0);
+ sumi1 = vdotq_s32(sumi1, sqx1, qy1);
+ sumi0 = vdotq_s32(sumi0, sqx2, qy2);
+ sumi1 = vdotq_s32(sumi1, sqx3, qy3);
+ sumi0 = vdotq_s32(sumi0, sqx4, qy4);
+ sumi1 = vdotq_s32(sumi1, sqx5, qy5);
+ sumi0 = vdotq_s32(sumi0, sqx6, qy6);
+ sumi1 = vdotq_s32(sumi1, sqx7, qy7);
+ sumi0 = vdotq_s32(sumi0, sqx8, qy8);
+ sumi1 = vdotq_s32(sumi1, sqx9, qy9);
+#else
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1));
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2));
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2));
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3));
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3));
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4));
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6));
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6));
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7));
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7));
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx8), vget_low_s8(qy8));
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx8), vget_high_s8(qy8));
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx9), vget_low_s8(qy9));
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx9), vget_high_s8(qy9));
+#endif
+ }
+
+        // last 16 bytes of the 5-element packing, along with the 4 qh bytes of 4 elements each
+ {
+ uint8x16_t qx0 = vld1q_u8(x[i].qs + 32);
+ uint8x16_t qx1 = vmulq_u8(qx0, vdupq_n_u8(3));
+ uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(9));
+ uint8x16_t qx3 = vmulq_u8(qx0, vdupq_n_u8(27));
+ uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(81));
+ uint32_t qh;
+ memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned
+ uint8x16_t qx5 = vreinterpretq_u8_u32(vdupq_n_u32(qh));
+ qx5 = vmulq_u8(qx5, shift);
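+            // the 4 qh bytes are repeated across the vector and each group of 4 lanes gets one
+            // power of 3, so all 16 qh trits are extracted at once against y[i].qs + 240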
+
+ // multiply by 3 and keep the 2 bits above 8 bits
+ int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6));
+ int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6));
+ int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6));
+ int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6));
+ int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6));
+ int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6));
+
+ const int8x16_t qy0 = vld1q_s8(y[i].qs + 160);
+ const int8x16_t qy1 = vld1q_s8(y[i].qs + 176);
+ const int8x16_t qy2 = vld1q_s8(y[i].qs + 192);
+ const int8x16_t qy3 = vld1q_s8(y[i].qs + 208);
+ const int8x16_t qy4 = vld1q_s8(y[i].qs + 224);
+ const int8x16_t qy5 = vld1q_s8(y[i].qs + 240);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+ sumi0 = vdotq_s32(sumi0, sqx0, qy0);
+ sumi1 = vdotq_s32(sumi1, sqx1, qy1);
+ sumi0 = vdotq_s32(sumi0, sqx2, qy2);
+ sumi1 = vdotq_s32(sumi1, sqx3, qy3);
+ sumi0 = vdotq_s32(sumi0, sqx4, qy4);
+ sumi1 = vdotq_s32(sumi1, sqx5, qy5);
+#else
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1));
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2));
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2));
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3));
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3));
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4));
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));
+#endif
+ }
+
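+        // the codes used above are 0..2 instead of -1..1, so subtract the sum of the
+        // activations (bsums) once to compensate for the +1 offset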
+ const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
+ const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);
+
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
+#if defined(__ARM_FEATURE_DOTPROD)
+ sumi0 = vaddq_s32(sumi0, sumi1);
+ sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));
+
+ sumf += d * (float) vaddvq_s32(sumi0);
+#else
+ sumi0 = vaddq_s16(sumi0, sumi1);
+ sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1));
+
+ sumf += d * (float) vaddlvq_s16(sumi0);
+#endif
+ }
+
+ *s = sumf;
+
+#elif defined(__AVX2__)
+ __m256 sumf = _mm256_setzero_ps();
+
+ for (int i = 0; i < nb; ++i) {
+ // 16-bit sums
+ __m256i sumi0 = _mm256_setzero_si256();
+ __m256i sumi1 = _mm256_setzero_si256();
+ __m256i sumi2 = _mm256_setzero_si256();
+
+ // first 32 bytes of 5 elements
+ {
+ __m256i qx0 = _mm256_loadu_si256((const __m256i *) (x[i].qs));
+ // 8-bit multiplies with shifts, masks and adds
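+            // (x << 3) & 0xF8 is an 8-bit multiply by 8 (the mask clears the bits that the
+            // 16-bit shift pulls in from the neighbouring byte); adding x back gives x * 9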
+ __m256i qx1 = _mm256_add_epi8(qx0, _mm256_add_epi8(qx0, qx0)); // 1 * 3
+ __m256i qx2 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx0, 3), _mm256_set1_epi8(-8)), qx0); // 1 * 9
+ __m256i qx3 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx1, 3), _mm256_set1_epi8(-8)), qx1); // 3 * 9
+ __m256i qx4 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx2, 3), _mm256_set1_epi8(-8)), qx2); // 9 * 9
+
+            // TODO: could _mm256_mulhi_epu16 be faster even though it is 16-bit?
+
+ // Cancel the +1 from avg so that it behaves like a halving add
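+            // _mm256_avg_epu8 rounds up ((a + b + 1) >> 1), so after the saturating -1 the
+            // nested avg below computes floor(3*q/4) just like the NEON vhaddq path, and the
+            // >> 6 that follows extracts (3*q) >> 8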
+ qx0 = _mm256_subs_epu8(qx0, _mm256_set1_epi8(1));
+ qx1 = _mm256_subs_epu8(qx1, _mm256_set1_epi8(1));
+ qx2 = _mm256_subs_epu8(qx2, _mm256_set1_epi8(1));
+ qx3 = _mm256_subs_epu8(qx3, _mm256_set1_epi8(1));
+ qx4 = _mm256_subs_epu8(qx4, _mm256_set1_epi8(1));
+ // Multiply by 3 and get the top 2 bits
+ qx0 = _mm256_avg_epu8(qx0, _mm256_avg_epu8(qx0, _mm256_setzero_si256()));
+ qx1 = _mm256_avg_epu8(qx1, _mm256_avg_epu8(qx1, _mm256_setzero_si256()));
+ qx2 = _mm256_avg_epu8(qx2, _mm256_avg_epu8(qx2, _mm256_setzero_si256()));
+ qx3 = _mm256_avg_epu8(qx3, _mm256_avg_epu8(qx3, _mm256_setzero_si256()));
+ qx4 = _mm256_avg_epu8(qx4, _mm256_avg_epu8(qx4, _mm256_setzero_si256()));
+ qx0 = _mm256_and_si256(_mm256_srli_epi16(qx0, 6), _mm256_set1_epi8(3));
+ qx1 = _mm256_and_si256(_mm256_srli_epi16(qx1, 6), _mm256_set1_epi8(3));
+ qx2 = _mm256_and_si256(_mm256_srli_epi16(qx2, 6), _mm256_set1_epi8(3));
+ qx3 = _mm256_and_si256(_mm256_srli_epi16(qx3, 6), _mm256_set1_epi8(3));
+ qx4 = _mm256_and_si256(_mm256_srli_epi16(qx4, 6), _mm256_set1_epi8(3));
+
+ const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 0));
+ const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 32));
+ const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 64));
+ const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 96));
+ const __m256i qy4 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 128));
+
+ qx0 = _mm256_maddubs_epi16(qx0, qy0);
+ qx1 = _mm256_maddubs_epi16(qx1, qy1);
+ qx2 = _mm256_maddubs_epi16(qx2, qy2);
+ qx3 = _mm256_maddubs_epi16(qx3, qy3);
+ qx4 = _mm256_maddubs_epi16(qx4, qy4);
+
+ sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1));
+ sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3));
+ sumi2 = _mm256_add_epi16(sumi2, qx4);
+ }
+
+        // last 16 bytes of the 5-element packing, along with the 4 qh bytes of 4 elements each
+ {
+ __m128i qx0 = _mm_loadu_si128((const __m128i *) (x[i].qs + 32));
+ uint32_t qh;
+ memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned
+ __m256i qx5_l = _mm256_cvtepu8_epi16(_mm_set1_epi32(qh));
+ __m128i qx1 = _mm_add_epi8(qx0, _mm_add_epi8(qx0, qx0)); // 1 * 3
+ __m128i qx2 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx0, 3), _mm_set1_epi8(-8)), qx0); // 1 * 9
+ __m128i qx3 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx1, 3), _mm_set1_epi8(-8)), qx1); // 3 * 9
+ __m128i qx4 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx2, 3), _mm_set1_epi8(-8)), qx2); // 9 * 9
+ __m256i qx01 = MM256_SET_M128I(qx1, qx0);
+ __m256i qx23 = MM256_SET_M128I(qx3, qx2);
+
+ // avx2 does not have 8-bit multiplies, so 16-bit it is.
+ qx5_l = _mm256_mullo_epi16(qx5_l, _mm256_set_epi16(27, 27, 27, 27, 9, 9, 9, 9, 3, 3, 3, 3, 1, 1, 1, 1));
+ qx5_l = _mm256_and_si256(qx5_l, _mm256_set1_epi16(0xFF));
+ __m128i qx5 = _mm_packus_epi16(_mm256_castsi256_si128(qx5_l), _mm256_extracti128_si256(qx5_l, 1));
+
+ __m256i qx45 = MM256_SET_M128I(qx5, qx4);
+
+ // Cancel the +1 from avg so that it behaves like a halving add
+ qx01 = _mm256_subs_epu8(qx01, _mm256_set1_epi8(1));
+ qx23 = _mm256_subs_epu8(qx23, _mm256_set1_epi8(1));
+ qx45 = _mm256_subs_epu8(qx45, _mm256_set1_epi8(1));
+ // Multiply by 3 and get the top 2 bits
+ qx01 = _mm256_avg_epu8(qx01, _mm256_avg_epu8(qx01, _mm256_setzero_si256()));
+ qx23 = _mm256_avg_epu8(qx23, _mm256_avg_epu8(qx23, _mm256_setzero_si256()));
+ qx45 = _mm256_avg_epu8(qx45, _mm256_avg_epu8(qx45, _mm256_setzero_si256()));
+ qx01 = _mm256_and_si256(_mm256_srli_epi16(qx01, 6), _mm256_set1_epi8(3));
+ qx23 = _mm256_and_si256(_mm256_srli_epi16(qx23, 6), _mm256_set1_epi8(3));
+ qx45 = _mm256_and_si256(_mm256_srli_epi16(qx45, 6), _mm256_set1_epi8(3));
+
+ const __m256i qy01 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 160));
+ const __m256i qy23 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 192));
+ const __m256i qy45 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 224));
+
+ qx01 = _mm256_maddubs_epi16(qx01, qy01);
+ qx23 = _mm256_maddubs_epi16(qx23, qy23);
+ qx45 = _mm256_maddubs_epi16(qx45, qy45);
+
+ sumi0 = _mm256_add_epi16(sumi0, qx01);
+ sumi1 = _mm256_add_epi16(sumi1, qx23);
+ sumi2 = _mm256_add_epi16(sumi2, qx45);
+ }
+
+ const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums);
+ const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d));
+
+ sumi0 = _mm256_sub_epi16(sumi0, ysum);
+ sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(sumi1, sumi2));
+ sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1));
+
+ sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf);
+ }
+
+ *s = hsum_float_8(sumf);
+
+#else
+ const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243};
+
+ float sumf = 0.0f;
+
+ for (int i = 0; i < nb; ++i) {
+ int sum = 0;
+
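+        // each group of packed bytes covers 5x as many activations: trit l of byte (j + m)
+        // pairs with y[i].qs[j*5 + l*32 + m] (and similarly with stride 16 in the tail loop)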
+ for (size_t j = 0; j < sizeof(x->qs) - sizeof(x->qs) % 32; j += 32) {
+ for (size_t l = 0; l < 5; ++l) {
+ for (size_t m = 0; m < 32; ++m) {
+ uint8_t q = x[i].qs[j + m] * pow3[l];
+ uint16_t xi = ((uint16_t) q * 3) >> 8;
+ sum += (xi - 1) * y[i].qs[j*5 + l*32 + m];
+ }
+ }
+ }
+ for (size_t j = sizeof(x->qs) - sizeof(x->qs) % 32; j < sizeof(x->qs); j += 16) {
+ for (size_t l = 0; l < 5; ++l) {
+ for (size_t m = 0; m < 16; ++m) {
+ uint8_t q = x[i].qs[j + m] * pow3[l];
+ uint16_t xi = ((uint16_t) q * 3) >> 8;
+ sum += (xi - 1) * y[i].qs[j*5 + l*16 + m];
+ }
+ }
+ }
+
+ for (size_t l = 0; l < 4; ++l) {
+ for (size_t j = 0; j < sizeof(x->qh); ++j) {
+ uint8_t q = x[i].qh[j] * pow3[l];
+ uint16_t xi = ((uint16_t) q * 3) >> 8;
+ sum += (xi - 1) * y[i].qs[sizeof(x->qs)*5 + l*sizeof(x->qh) + j];
+ }
+ }
+
+ sumf += (float) sum * (GGML_FP16_TO_FP32(x[i].d) * y[i].d);
+ }
+
+ *s = sumf;
+#endif
+}
+
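+// Same structure as ggml_vec_dot_tq1_0_q8_K, but the 2-bit codes are extracted directly
+// with shifts and masks instead of the base-3 decode.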
+void ggml_vec_dot_tq2_0_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+
+ const block_tq2_0 * restrict x = vx;
+ const block_q8_K * restrict y = vy;
+
+ const int nb = n / QK_K;
+
+#if defined(__ARM_NEON)
+ float sumf = 0.0f;
+
+ const uint8x16_t m3 = vdupq_n_u8(3);
+
+ for (int i = 0; i < nb; ++i) {
+#if defined(__ARM_FEATURE_DOTPROD)
+ int32x4_t sumi0 = vdupq_n_s32(0);
+ int32x4_t sumi1 = vdupq_n_s32(0);
+#else
+ int16x8_t sumi0 = vdupq_n_s16(0);
+ int16x8_t sumi1 = vdupq_n_s16(0);
+#endif
+
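+        // each 32-byte group of qs unpacks to 128 codes, matching y[i].qs + j*4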
+ for (size_t j = 0; j < sizeof(x->qs); j += 32) {
+ uint8x16_t qx0 = vld1q_u8(x[i].qs + j);
+ uint8x16_t qx1 = vld1q_u8(x[i].qs + j + 16);
+ uint8x16_t qx2 = vshrq_n_u8(qx0, 2);
+ uint8x16_t qx3 = vshrq_n_u8(qx1, 2);
+ uint8x16_t qx4 = vshrq_n_u8(qx0, 4);
+ uint8x16_t qx5 = vshrq_n_u8(qx1, 4);
+ uint8x16_t qx6 = vshrq_n_u8(qx0, 6);
+ uint8x16_t qx7 = vshrq_n_u8(qx1, 6);
+
+ int8x16_t sqx0 = vreinterpretq_s8_u8(vandq_u8(qx0, m3));
+ int8x16_t sqx1 = vreinterpretq_s8_u8(vandq_u8(qx1, m3));
+ int8x16_t sqx2 = vreinterpretq_s8_u8(vandq_u8(qx2, m3));
+ int8x16_t sqx3 = vreinterpretq_s8_u8(vandq_u8(qx3, m3));
+ int8x16_t sqx4 = vreinterpretq_s8_u8(vandq_u8(qx4, m3));
+ int8x16_t sqx5 = vreinterpretq_s8_u8(vandq_u8(qx5, m3));
+ int8x16_t sqx6 = vreinterpretq_s8_u8(vandq_u8(qx6, m3));
+ int8x16_t sqx7 = vreinterpretq_s8_u8(vandq_u8(qx7, m3));
+
+ const int8x16_t qy0 = vld1q_s8(y[i].qs + j*4 + 0);
+ const int8x16_t qy1 = vld1q_s8(y[i].qs + j*4 + 16);
+ const int8x16_t qy2 = vld1q_s8(y[i].qs + j*4 + 32);
+ const int8x16_t qy3 = vld1q_s8(y[i].qs + j*4 + 48);
+ const int8x16_t qy4 = vld1q_s8(y[i].qs + j*4 + 64);
+ const int8x16_t qy5 = vld1q_s8(y[i].qs + j*4 + 80);
+ const int8x16_t qy6 = vld1q_s8(y[i].qs + j*4 + 96);
+ const int8x16_t qy7 = vld1q_s8(y[i].qs + j*4 + 112);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+ sumi0 = vdotq_s32(sumi0, sqx0, qy0);
+ sumi1 = vdotq_s32(sumi1, sqx1, qy1);
+ sumi0 = vdotq_s32(sumi0, sqx2, qy2);
+ sumi1 = vdotq_s32(sumi1, sqx3, qy3);
+ sumi0 = vdotq_s32(sumi0, sqx4, qy4);
+ sumi1 = vdotq_s32(sumi1, sqx5, qy5);
+ sumi0 = vdotq_s32(sumi0, sqx6, qy6);
+ sumi1 = vdotq_s32(sumi1, sqx7, qy7);
+#else
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1));
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2));
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2));
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3));
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3));
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4));
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6));
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6));
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7));
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7));
+#endif
+ }
+
+ const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
+ const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);
+
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
+#if defined(__ARM_FEATURE_DOTPROD)
+ sumi0 = vaddq_s32(sumi0, sumi1);
+ sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));
+
+ sumf += d * (float) vaddvq_s32(sumi0);
+#else
+ sumi0 = vaddq_s16(sumi0, sumi1);
+ sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1));
+
+ sumf += d * (float) vaddlvq_s16(sumi0);
+#endif
+ }
+
+ *s = sumf;
+
+#elif defined(__AVX2__)
+ __m256 sumf = _mm256_setzero_ps();
+
+ for (int i = 0; i < nb; ++i) {
+ // 16-bit sums, because 256*127 still fits
+ __m256i sumi0 = _mm256_setzero_si256();
+ __m256i sumi1 = _mm256_setzero_si256();
+
+ for (size_t j = 0; j < sizeof(x->qs); j += 32) {
+ __m256i qx0 = _mm256_loadu_si256((const __m256i *) (x[i].qs + j));
+ __m256i qx1 = _mm256_srli_epi16(qx0, 2);
+ __m256i qx2 = _mm256_srli_epi16(qx0, 4);
+ __m256i qx3 = _mm256_srli_epi16(qx0, 6);
+
+ // 0, 1, 2 (should not be 3)
+ qx0 = _mm256_and_si256(qx0, _mm256_set1_epi8(3));
+ qx1 = _mm256_and_si256(qx1, _mm256_set1_epi8(3));
+ qx2 = _mm256_and_si256(qx2, _mm256_set1_epi8(3));
+ qx3 = _mm256_and_si256(qx3, _mm256_set1_epi8(3));
+
+ const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 0));
+ const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 32));
+ const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 64));
+ const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 96));
+
+ qx0 = _mm256_maddubs_epi16(qx0, qy0);
+ qx1 = _mm256_maddubs_epi16(qx1, qy1);
+ qx2 = _mm256_maddubs_epi16(qx2, qy2);
+ qx3 = _mm256_maddubs_epi16(qx3, qy3);
+
+ sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1));
+ sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3));
+ }
+
+ const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums);
+ const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d));
+
+ sumi0 = _mm256_add_epi16(sumi0, sumi1);
+ sumi0 = _mm256_sub_epi16(sumi0, ysum);
+ sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1));
+
+ sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf);
+ }
+
+ *s = hsum_float_8(sumf);
+
+#else
+ float sumf = 0.0f;
+
+ for (int i = 0; i < nb; ++i) {
+ int32_t sumi = 0;
+
+ for (size_t j = 0; j < sizeof(x->qs); j += 32) {
+ for (size_t l = 0; l < 4; ++l) {
+ for (size_t k = 0; k < 32; ++k) {
+ sumi += y[i].qs[j*4 + l*32 + k] * (((x[i].qs[j + k] >> (l*2)) & 3) - 1);
+ }
+ }
+ }
+
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+ sumf += (float) sumi * d;
+ }
+
+ *s = sumf;
+#endif
+}
+
void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc);
}
}
} break;
+ case GGML_TYPE_TQ1_0:
+ {
+ VALIDATE_ROW_DATA_D_F16_IMPL(block_tq1_0, data, nb);
+ } break;
+ case GGML_TYPE_TQ2_0:
+ {
+ VALIDATE_ROW_DATA_D_F16_IMPL(block_tq2_0, data, nb);
+ } break;
case GGML_TYPE_IQ1_S:
{
VALIDATE_ROW_DATA_D_F16_IMPL(block_iq1_s, data, nb);