uint8_t qs[QK4_NL/2];
} block_iq4_nl;
+#if QK_K == 64
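+// with QK_K == 64, IQ4_XS falls back to the IQ4_NL format (32-value blocks), so the same block struct is reused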
+#define block_iq4_xs block_iq4_nl
+#else
typedef struct {
half d;
uint16_t scales_h;
uint8_t scales_l[QK_K/64];
uint8_t qs[QK_K/2];
} block_iq4_xs;
+#endif
//====================================== dot products =========================
threadgroup_barrier(mem_flags::mem_threadgroup);
}
-#if QK_K == 256
const int ix = tiisg;
device const float * y4 = y + 32 * ix;
y4 += 32 * 32;
}
-#else
- (void) x;
- (void) y;
- (void) yl;
- (void) nb32;
-#endif
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
threadgroup_barrier(mem_flags::mem_threadgroup);
}
-#if QK_K == 256
const int ix = tiisg;
device const float * y4 = y + 32 * ix;
y4 += 32 * 32;
}
-#else
- (void) x;
- (void) y;
- (void) yl;
- (void) nb32;
-#endif
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
threadgroup_barrier(mem_flags::mem_threadgroup);
}
-#if QK_K == 256
const int ix = tiisg;
device const float * y4 = y + 32 * ix;
y4 += 32 * 32;
}
-#else
- (void) x;
- (void) y;
- (void) yl;
- (void) nb32;
-#endif
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
const int nb32 = nb * (QK_K / 32);
-#if QK_K == 256
const int ix = tiisg/2;
const int il = tiisg%2;
y4 += 16 * 32;
}
-#else
- (void) x;
- (void) y;
- (void) yl;
- (void) nb32;
-#endif
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
}
}
+#if QK_K != 64
void kernel_mul_mv_iq4_xs_f32_impl(
device const void * src0,
device const float * src1,
}
}
}
+#endif
[[host_name("kernel_mul_mv_iq1_s_f32")]]
kernel void kernel_mul_mv_iq1_s_f32(
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
+#if QK_K == 64
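+ // QK_K == 64: the weights are stored as iq4_nl blocks, so dispatch to the iq4_nl kernel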
+ kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+#else
kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+#endif
}
//============================= templates and their specializations =============================
template <typename type4x4>
void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
+#if QK_K == 64
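+ // valid because block_iq4_xs is #define'd to block_iq4_nl in this configuration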
+ dequantize_iq4_nl(xb, il, reg);
+#else
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const int ib32 = il/2;
il = il%2;
reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
}
+#endif
}
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows<block_iq2_s, QK_NL, dequantize_iq2_s>;
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
+#if QK_K == 64
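+// QK_K == 64 blocks hold 32 values in the iq4_nl layout, i.e. 2 tiles of 16 elements instead of QK_NL tiles per 256-value block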
+template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, 2, dequantize_iq4_xs>;
+#else
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
+#endif
//
// matrix-matrix multiplication
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_s, QK_NL, dequantize_iq2_s>;
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
+#if QK_K == 64
+template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_xs>;
+#else
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
+#endif
//
// indirect matrix-matrix multiplication
template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>;
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
+#if QK_K == 64
+template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, 2, dequantize_iq4_xs>;
+#else
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
+#endif
//
// matrix-vector multiplication
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+#if QK_K == 64
+ kernel_mul_mv_iq4_nl_f32_impl(
+#else
kernel_mul_mv_iq4_xs_f32_impl(
+#endif
src0[id],
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
float mins[QK_K/16];
float scales[QK_K/16];
float sw[QK_K/16];
- float weight[QK_K/16];
+ float weight[16];
uint8_t Ls[QK_K/16], Lm[QK_K/16];
for (int i = 0; i < nb; i++) {
float sigma2 = sumx2/QK_K;
for (int j = 0; j < QK_K/16; ++j) {
const float * restrict qw = quant_weights + QK_K * i + 16*j;
- for (int l = 0; l < QK_K/16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]);
+ for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]);
- for (int l = 0; l < QK_K/16; ++l) sw[j] += weight[l];
+ for (int l = 0; l < 16; ++l) sw[j] += weight[l];
- scales[j] = make_qkx3_quants(QK_K/16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
+ scales[j] = make_qkx3_quants(16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
}
- float dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw);
- float mm = make_qp_quants(QK_K/16, 15, mins, Lm, sw);
+ float dm, mm;
+#if QK_K == 64
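+ // QK_K == 64: quantize the sub-block scales and mins to 4 bits against their maxima with simple round-to-nearest instead of the weighted make_qp_quants fit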
+ float max_scale = 0, max_min = 0;
+ for (int j = 0; j < QK_K/16; ++j) {
+ max_scale = MAX(max_scale, scales[j]);
+ max_min = MAX(max_min, mins[j]);
+ }
+ dm = max_scale/15;
+ mm = max_min/15;
+ if (max_scale) {
+ float id = 1/dm;
+ for (int j = 0; j < QK_K/16; ++j) {
+ int l = nearest_int(id*scales[j]);
+ Ls[j] = MAX(0, MIN(15, l));
+ }
+ } else {
+ memset(Ls, 0, QK_K/16);
+ }
+ if (max_min) {
+ float id = 1/mm;
+ for (int j = 0; j < QK_K/16; ++j) {
+ int l = nearest_int(id*mins[j]);
+ Lm[j] = MAX(0, MIN(15, l));
+ }
+ } else {
+ memset(Lm, 0, QK_K/16);
+ }
+#else
+ dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw);
+ mm = make_qp_quants(QK_K/16, 15, mins, Lm, sw);
+#endif
y[i].d = GGML_FP32_TO_FP16(dm);
y[i].dmin = GGML_FP32_TO_FP16(mm);
dm = GGML_FP16_TO_FP32(y[i].d);
void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y, int k) {
assert(k % QK_K == 0);
+#if QK_K == 64
+ dequantize_row_iq4_nl((const block_iq4_nl *)x, y, k);
+#else
const int nb = k / QK_K;
for (int i = 0; i < nb; i++) {
qs += 16;
}
}
+#endif
}
//===================================== Q8_K ==============================================
float sumf = 0;
- int isum[4];
+ int isum[QK_K/16];
for (int i = 0; i < nb; ++i) {
const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
- isum[0] = isum[1] = isum[2] = isum[3] = 0;
+ memset(isum, 0, (QK_K/16)*sizeof(int));
for (int l = 0; l < 16; ++l) {
isum[0] += q8[l+ 0] * ((q2[l] >> 0) & 3);
isum[1] += q8[l+16] * ((q2[l] >> 2) & 3);
isum[2] += q8[l+32] * ((q2[l] >> 4) & 3);
isum[3] += q8[l+48] * ((q2[l] >> 6) & 3);
}
- for (int l = 0; l < 4; ++l) {
+ for (int l = 0; l < QK_K/16; ++l) {
isum[l] *= (sc[l] & 0xF);
}
sumf += dall * (isum[0] + isum[1] + isum[2] + isum[3]) - dmin * summs;
#elif defined(__AVX2__)
- const __m128i m4 = _mm_set1_epi8(0xf);
- const __m128i m1 = _mm_set1_epi8(1);
- const __m256i m511 = _mm256_set1_epi16(511);
const __m256i mone = _mm256_set1_epi8(1);
-
- static const uint8_t k_bit_helper[32] = {
- 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
- 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
- };
static const char block_sign_shuffle_mask_1[32] = {
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
};
- const __m256i bit_helper = _mm256_loadu_si256((const __m256i*)k_bit_helper);
const __m256i bit_selector_mask = _mm256_loadu_si256((const __m256i*)bit_selector_mask_bytes);
const __m256i block_sign_shuffle_1 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_1);
const __m256i block_sign_shuffle_2 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_2);
+#if QK_K == 64
+ static const uint8_t k_bit_helper[16] = {
+ 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+ };
+ const __m128i bit_helper = _mm_loadu_si128((const __m128i*)k_bit_helper);
+ const __m128i m511 = _mm_set1_epi16(511);
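+ // each 16-bit qs entry packs a 9-bit index into iq2xs_grid (low bits) and 7 explicit sign bits (high bits)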
+ typedef union {
+ __m128i vec_index;
+ uint16_t index[8];
+ } index_t;
+
+ index_t idx;
+ __m256 accumf = _mm256_setzero_ps();
+ for (int i = 0; i < nb; ++i) {
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ const __m128i q2_data = _mm_loadu_si128((const __m128i*)x[i].qs);
+ idx.vec_index = _mm_and_si128(q2_data, m511);
+
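+ // reconstruct the implicit 8th sign of each group of 8: bit_helper maps the folded sign bits to their parity (0x00 or 0x80)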
+ const __m128i partial_sign_bits = _mm_srli_epi16(q2_data, 9);
+ const __m128i partial_sign_bits_upper = _mm_srli_epi16(q2_data, 13);
+ const __m128i partial_sign_bits_for_counting = _mm_xor_si128(partial_sign_bits, partial_sign_bits_upper);
+
+ const __m128i odd_bits = _mm_shuffle_epi8(bit_helper, partial_sign_bits_for_counting);
+ const __m128i full_sign_bits = _mm_or_si128(partial_sign_bits, odd_bits);
+ const __m256i full_signs = _mm256_set_m128i(full_sign_bits, full_sign_bits);
+
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)y[i].qs);
+ const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)(y[i].qs+32));
+
+ const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[idx.index[3]], iq2xs_grid[idx.index[2]],
+ iq2xs_grid[idx.index[1]], iq2xs_grid[idx.index[0]]);
+ const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[idx.index[7]], iq2xs_grid[idx.index[6]],
+ iq2xs_grid[idx.index[5]], iq2xs_grid[idx.index[4]]);
+
+ __m256i signs;
+ signs = _mm256_shuffle_epi8(full_signs, block_sign_shuffle_1);
+ signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
+ const __m256i q8s_1 = _mm256_sign_epi8(q8_1, _mm256_or_si256(signs, mone));
+
+ signs = _mm256_shuffle_epi8(full_signs, block_sign_shuffle_2);
+ signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
+ const __m256i q8s_2 = _mm256_sign_epi8(q8_2, _mm256_or_si256(signs, mone));
+
+ const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
+ const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
+
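+ // scales are packed two per byte; the per-sub-block factor is (2*scale + 1), with the common 1/8 applied once at the end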
+ const __m256i sc1 = _mm256_set_m128i(_mm_set1_epi16(2*(x[i].scales[0] >> 4)+1), _mm_set1_epi16(2*(x[i].scales[0] & 0xf)+1));
+ const __m256i sc2 = _mm256_set_m128i(_mm_set1_epi16(2*(x[i].scales[1] >> 4)+1), _mm_set1_epi16(2*(x[i].scales[1] & 0xf)+1));
+
+ const __m256i sum = _mm256_add_epi32(_mm256_madd_epi16(sc1, dot1), _mm256_madd_epi16(sc2, dot2));
+
+ accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sum), accumf);
+ }
+
+ *s = 0.125f * hsum_float_8(accumf);
+#else
+
+ static const uint8_t k_bit_helper[32] = {
+ 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+ 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+ };
+ const __m256i bit_helper = _mm256_loadu_si256((const __m256i*)k_bit_helper);
+ const __m256i m511 = _mm256_set1_epi16(511);
+ const __m128i m4 = _mm_set1_epi8(0xf);
+ const __m128i m1 = _mm_set1_epi8(1);
+
uint64_t aux64;
// somewhat hacky, but gives a significant boost in performance
}
*s = 0.125f * hsum_float_8(accumf);
+#endif
#else
const int nb = n / QK_K;
-#if defined __ARM_NEON
+ // TODO: implement for QK_K = 64
+#if defined __ARM_NEON && QK_K == 256
const uint8x16_t m8 = vdupq_n_u8(0x08);
const uint8x16_t m7 = vdupq_n_u8(0x07);
*s = sumf;
-#elif defined __AVX2__
+ // TODO: implement for QK_K = 64
+#elif defined __AVX2__ && QK_K == 256
const __m128i m8 = _mm_set1_epi8(0x08);
const __m128i m7 = _mm_set1_epi8(0x07);
UNUSED(by);
UNUSED(bs);
assert(n % QK_K == 0);
+#if QK_K == 64
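+ // QK_K == 64: the weights use the iq4_nl layout and the activations are quantized to q8_0 (as for iq4_nl), so the iq4_nl dot product applies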
+ ggml_vec_dot_iq4_nl_q8_0(n, s, bs, vx, bx, vy, by, nrc);
+#else
const block_iq4_xs * restrict x = vx;
const block_q8_K * restrict y = vy;
}
*s = sumf;
#endif
+#endif
}
// ================================ IQ2 quantization =============================================
const int kMaxQ = 3;
- const int nbl = n/256;
+ const int nbl = n/QK_K;
block_iq2_xxs * y = vy;
const int kMaxQ = 3;
- const int nbl = n/256;
+ const int nbl = n/QK_K;
block_iq2_xs * y = vy;
GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?");
GGML_ASSERT(n%QK_K == 0);
- const int nbl = n/256;
+ const int nbl = n/QK_K;
block_iq1_s * y = vy;
}
size_t quantize_iq4_xs(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
+#if QK_K == 64
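+ // QK_K == 64: emit plain iq4_nl blocks; the returned row size is that of iq4_nl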
+ return quantize_iq4_nl(src, dst, nrow, n_per_row, hist, quant_weights);
+#else
(void)hist;
GGML_ASSERT(n_per_row%QK_K == 0);
int nblock = n_per_row/QK_K;
qrow += nblock*sizeof(block_iq4_xs);
}
return nrow * nblock * sizeof(block_iq4_xs);
+#endif
}
void quantize_row_iq4_xs(const float * restrict x, void * restrict vy, int k) {
const int kMaxQ = 3;
- const int nbl = n/256;
+ const int nbl = n/QK_K;
block_iq2_s * y = vy;