#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
+#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
-#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
+#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
+#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
-#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
+#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
+#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
+#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
+#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
-#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
+#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
+#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
-#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
+#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
+#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
-#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
+#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
+#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
-#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
+#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
+#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
-#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
+#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
+#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
-#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
+#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
+#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
-#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
+#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
+#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
-#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
+#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
+#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
-#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
+#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
+#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
-#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
+#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
#define UNUSED GGML_UNUSED
#if defined(__aarch64__) && defined(__ARM_NEON) && (defined(__ARM_FEATURE_MATMUL_INT8) || defined(__ARM_FEATURE_DOTPROD))
-static inline void decode_q4_Kx8_scales_mins(const uint8_t * scales_in,
- int16x8_t * out_mins,
- int8_t * out_scales) {
+// Helper for decoding scales and mins of Q4_K and Q5_K block formats
+static inline void decode_q_Kx8_6bit_scales(const uint8_t * scales_in, int16x8_t * out_mins, int8_t * out_scales) {
constexpr uint32_t kmask1 = 0x3f3f3f3f;
constexpr uint32_t kmask2 = 0x0f0f0f0f;
constexpr uint32_t kmask3 = 0x03030303;
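// kmask1 keeps the 6-bit fields, kmask2 the low nibbles and kmask3 the 2-bit
// high parts of the packed encoding (12 bytes -> 8 scales + 8 mins, 6 bits each).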
for (int i = 0; i < 2; i++) {
int8_t aux_q4sb[8];
const int offset = sb * 24 + i * 12;
- decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
+ decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
}
for (int i = 0; i < 2; i++) {
int8_t aux_q4sb[8];
const int offset = sb * 24 + i * 12;
- decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
+ decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
}
ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
+void ggml_gemv_q5_K_8x8_q8_K(int n,
+ float * GGML_RESTRICT s,
+ size_t bs,
+ const void * GGML_RESTRICT vx,
+ const void * GGML_RESTRICT vy,
+ int nr,
+ int nc) {
+ constexpr int qk = QK_K;
+ const int nb = n / qk;
+
+ constexpr int ncols_interleaved = 8;
+ constexpr int blocklen = 8;
+
+ assert(n % qk == 0);
+ assert(nc % ncols_interleaved == 0);
+
+ UNUSED(nb);
+ UNUSED(ncols_interleaved);
+ UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+ constexpr int col_pairs = ncols_interleaved / 2;
+ const uint8x16_t m4b = vdupq_n_u8(0x0f);
+ const uint8x16_t mone = vdupq_n_u8(1);
+ const uint8x16_t mtwo = vdupq_n_u8(2);
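+    // 5-bit weight reconstruction used below (per byte):
+    //   low half:  (qs & 0x0F) | ((qh & 1) << 4)  via vsliq_n_u8(qs & m4b, qh & mone, 4)
+    //   high half: (qs >> 4)   | ((qh & 2) << 3)  via vorrq_u8(vshrq_n_u8(qs, 4), hbit_hi)
+    // qh is shifted right by 2 after each subblock to expose the next pair of high bits.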
+
+    // 1x8 output tile = 2 float32x4 accumulators
+ float32x4_t acc_f32[ncols_interleaved / 4];
+
+ const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
+
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
+ const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
+
+ for (int i = 0; i < ncols_interleaved / 4; i++) {
+ acc_f32[i] = vdupq_n_f32(0);
+ }
+
+ for (int b = 0; b < nb; b++) {
+ float32x4_t q5_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d)); // d0 d1 d2 d3
+ float32x4_t q5_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4)); // d4 d5 d6 d7
+ float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
+ float32x4_t sb_scale_0 = vmulq_f32(q5_d_0, q8_d);
+ float32x4_t sb_scale_1 = vmulq_f32(q5_d_1, q8_d);
+ float32x4_t q5_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin)); // dmin 0..3
+ float32x4_t q5_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4)); // dmin 4..7
+ float32x4_t sb_min_0 = vmulq_f32(q5_dmin_0, q8_d);
+ float32x4_t sb_min_1 = vmulq_f32(q5_dmin_1, q8_d);
+
+            // Accumulators for the two subblocks (low/high nibbles) handled each iteration
+ int32x4_t acc_lo[col_pairs];
+ int32x4_t acc_hi[col_pairs];
+
+            // bsums holds 16 element sums; a pairwise add leaves the 8 per-subblock sums of the entire block
+ const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
+ int16_t bsums_arr[8];
+ vst1q_s16(bsums_arr, bsums);
+
+ // Load qh once per block and shift after each subblock
+ const uint8_t * qh_base = q5_ptr[b].qh;
+ uint8x16_t qh[col_pairs][4];
+ for (int cp = 0; cp < col_pairs; cp++) {
+ qh[cp][0] = vld1q_u8(qh_base + 16 * cp);
+ qh[cp][1] = vld1q_u8(qh_base + 16 * cp + 64);
+ qh[cp][2] = vld1q_u8(qh_base + 16 * cp + 128);
+ qh[cp][3] = vld1q_u8(qh_base + 16 * cp + 192);
+ }
+
+ for (int sb = 0; sb < QK_K / 64; sb++) {
+ for (int i = 0; i < col_pairs; i++) {
+ acc_lo[i] = vdupq_n_s32(0);
+ acc_hi[i] = vdupq_n_s32(0);
+ }
+ // Need scales for the low and high nibbles
+ // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
+                int16x8_t q5sb_mins[2];  // int16 as it's needed for bias_acc later
+ int16x8_t q5sb_scales[2];
+ for (int i = 0; i < 2; i++) {
+ int8_t aux_q5sb[8];
+ const int offset = sb * 24 + i * 12;
+ decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb);
+ q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb));
+ }
+
+ const uint8_t * qs_base = q5_ptr[b].qs + sb * QK_K;
+
+                // Load the 64 quants from Q8_K, duplicated so the vector dot products can be reused across the interleaved columns
+ const int8_t * q8_base = q8_ptr[b].qs + sb * 64;
+ int8x16_t q8_qs[8];
+ for (int i = 0; i < 8; i++) {
+ q8_qs[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base + i * 8));
+ }
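+                // Each 8-byte group of Q8 quants is duplicated into both vector halves,
+                // so one vdot accumulates the same activations against both columns of an
+                // interleaved column pair.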
+
+ // Q5s column pair loop unrolled
+ {
+ // Cols 01
+ uint8x16_t qs_0 = vld1q_u8(qs_base);
+ uint8x16_t qs_1 = vld1q_u8(qs_base + 64);
+ uint8x16_t qs_2 = vld1q_u8(qs_base + 128);
+ uint8x16_t qs_3 = vld1q_u8(qs_base + 192);
+
+ uint8x16_t hbit_lo_0 = vandq_u8(qh[0][0], mone);
+ uint8x16_t hbit_lo_1 = vandq_u8(qh[0][1], mone);
+ uint8x16_t hbit_lo_2 = vandq_u8(qh[0][2], mone);
+ uint8x16_t hbit_lo_3 = vandq_u8(qh[0][3], mone);
+ uint8x16_t hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[0][0], mtwo), 3);
+ uint8x16_t hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[0][1], mtwo), 3);
+ uint8x16_t hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[0][2], mtwo), 3);
+ uint8x16_t hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[0][3], mtwo), 3);
+
+ qh[0][0] = vshrq_n_u8(qh[0][0], 2);
+ qh[0][1] = vshrq_n_u8(qh[0][1], 2);
+ qh[0][2] = vshrq_n_u8(qh[0][2], 2);
+ qh[0][3] = vshrq_n_u8(qh[0][3], 2);
+
+ acc_lo[0] = ggml_vdotq_s32(
+ acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
+ acc_lo[0] = ggml_vdotq_s32(
+ acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
+ acc_lo[0] = ggml_vdotq_s32(
+ acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
+ acc_lo[0] = ggml_vdotq_s32(
+ acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
+ acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
+ q8_qs[4]);
+ acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
+ q8_qs[5]);
+ acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
+ q8_qs[6]);
+ acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
+ q8_qs[7]);
+
+ // Cols 23
+ qs_0 = vld1q_u8(qs_base + 16);
+ qs_1 = vld1q_u8(qs_base + 80);
+ qs_2 = vld1q_u8(qs_base + 144);
+ qs_3 = vld1q_u8(qs_base + 208);
+
+ hbit_lo_0 = vandq_u8(qh[1][0], mone);
+ hbit_lo_1 = vandq_u8(qh[1][1], mone);
+ hbit_lo_2 = vandq_u8(qh[1][2], mone);
+ hbit_lo_3 = vandq_u8(qh[1][3], mone);
+ hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[1][0], mtwo), 3);
+ hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[1][1], mtwo), 3);
+ hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[1][2], mtwo), 3);
+ hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[1][3], mtwo), 3);
+
+ qh[1][0] = vshrq_n_u8(qh[1][0], 2);
+ qh[1][1] = vshrq_n_u8(qh[1][1], 2);
+ qh[1][2] = vshrq_n_u8(qh[1][2], 2);
+ qh[1][3] = vshrq_n_u8(qh[1][3], 2);
+
+ acc_lo[1] = ggml_vdotq_s32(
+ acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
+ acc_lo[1] = ggml_vdotq_s32(
+ acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
+ acc_lo[1] = ggml_vdotq_s32(
+ acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
+ acc_lo[1] = ggml_vdotq_s32(
+ acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
+ acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
+ q8_qs[4]);
+ acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
+ q8_qs[5]);
+ acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
+ q8_qs[6]);
+ acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
+ q8_qs[7]);
+
+ // Cols 45
+ qs_0 = vld1q_u8(qs_base + 32);
+ qs_1 = vld1q_u8(qs_base + 96);
+ qs_2 = vld1q_u8(qs_base + 160);
+ qs_3 = vld1q_u8(qs_base + 224);
+
+ hbit_lo_0 = vandq_u8(qh[2][0], mone);
+ hbit_lo_1 = vandq_u8(qh[2][1], mone);
+ hbit_lo_2 = vandq_u8(qh[2][2], mone);
+ hbit_lo_3 = vandq_u8(qh[2][3], mone);
+ hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[2][0], mtwo), 3);
+ hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[2][1], mtwo), 3);
+ hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[2][2], mtwo), 3);
+ hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[2][3], mtwo), 3);
+
+ qh[2][0] = vshrq_n_u8(qh[2][0], 2);
+ qh[2][1] = vshrq_n_u8(qh[2][1], 2);
+ qh[2][2] = vshrq_n_u8(qh[2][2], 2);
+ qh[2][3] = vshrq_n_u8(qh[2][3], 2);
+
+ acc_lo[2] = ggml_vdotq_s32(
+ acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
+ acc_lo[2] = ggml_vdotq_s32(
+ acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
+ acc_lo[2] = ggml_vdotq_s32(
+ acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
+ acc_lo[2] = ggml_vdotq_s32(
+ acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
+ acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
+ q8_qs[4]);
+ acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
+ q8_qs[5]);
+ acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
+ q8_qs[6]);
+ acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
+ q8_qs[7]);
+
+                    // Cols 67
+ qs_0 = vld1q_u8(qs_base + 48);
+ qs_1 = vld1q_u8(qs_base + 112);
+ qs_2 = vld1q_u8(qs_base + 176);
+ qs_3 = vld1q_u8(qs_base + 240);
+
+ hbit_lo_0 = vandq_u8(qh[3][0], mone);
+ hbit_lo_1 = vandq_u8(qh[3][1], mone);
+ hbit_lo_2 = vandq_u8(qh[3][2], mone);
+ hbit_lo_3 = vandq_u8(qh[3][3], mone);
+ hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[3][0], mtwo), 3);
+ hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[3][1], mtwo), 3);
+ hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[3][2], mtwo), 3);
+ hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[3][3], mtwo), 3);
+
+ qh[3][0] = vshrq_n_u8(qh[3][0], 2);
+ qh[3][1] = vshrq_n_u8(qh[3][1], 2);
+ qh[3][2] = vshrq_n_u8(qh[3][2], 2);
+ qh[3][3] = vshrq_n_u8(qh[3][3], 2);
+
+ acc_lo[3] = ggml_vdotq_s32(
+ acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
+ acc_lo[3] = ggml_vdotq_s32(
+ acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
+ acc_lo[3] = ggml_vdotq_s32(
+ acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
+ acc_lo[3] = ggml_vdotq_s32(
+ acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
+ acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
+ q8_qs[4]);
+ acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
+ q8_qs[5]);
+ acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
+ q8_qs[6]);
+ acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
+ q8_qs[7]);
+ }
+
+ // Prepare bsum vectors for bias computation
+                // Each pair of subblocks shares the same bsums
+ int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
+ int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
+
+                // Iterate over two column pairs (4 columns) at a time so a single 128-bit register can be used
+                // p = 0 -> cols 0123, p = 2 -> cols 4567
+ for (int i = 0, p = 0; p < col_pairs; i++, p += 2) {
+ int16x4_t group_scales_lo = p == 0 ? vget_low_s16(q5sb_scales[0]) : vget_high_s16(q5sb_scales[0]);
+ int16x4_t group_scales_hi = p == 0 ? vget_low_s16(q5sb_scales[1]) : vget_high_s16(q5sb_scales[1]);
+ int16x4_t group_mins_lo = p == 0 ? vget_low_s16(q5sb_mins[0]) : vget_high_s16(q5sb_mins[0]);
+ int16x4_t group_mins_hi = p == 0 ? vget_low_s16(q5sb_mins[1]) : vget_high_s16(q5sb_mins[1]);
+ float32x4_t sb_scale = p == 0 ? sb_scale_0 : sb_scale_1;
+ float32x4_t sb_min = p == 0 ? sb_min_0 : sb_min_1;
+
+ // 0123 or 4567
+ float32x4_t sumf_0 =
+ vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1])));
+ acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0);
+
+ float32x4_t sumf_1 =
+ vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_hi), vpaddq_s32(acc_hi[p], acc_hi[p + 1])));
+ acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_1);
+
+ // FUSED BIAS: Compute and subtract bias immediately
+ // bias = (bsums_lo * mins_lo + bsums_hi * mins_hi) * sb_min
+ int32x4_t bias = vmull_s16(bsums_vec_lo, group_mins_lo);
+ bias = vmlal_s16(bias, bsums_vec_hi, group_mins_hi);
+ float32x4_t bias_f32 = vcvtq_f32_s32(bias);
+ acc_f32[i] = vmlsq_f32(acc_f32[i], sb_min, bias_f32);
+ }
+ } // for sb
+ } // for b
+
+ int base = x * ncols_interleaved;
+ vst1q_f32(s + base, acc_f32[0]);
+ vst1q_f32(s + base + 4, acc_f32[1]);
+ } // for x
+ return;
+#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+ ggml_gemv_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+}
+
void ggml_gemv_q8_0_4x4_q8_0(int n,
float * GGML_RESTRICT s,
size_t bs,
for (int i = 0; i < 2; i++) {
int8_t aux_q4sb[8];
const int offset = sb * 24 + i * 12;
- decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
+ decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
}
int16x8_t q4sb_mins[2]; // int16 as it's needed for bias_acc later
for (int i = 0; i < 2; i++) {
const int offset = sb * 24 + i * 12;
- decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]);
+ decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]);
}
// q8_ptr[b].qs has interleaved Q8 rows (01, 23)
ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
+void ggml_gemm_q5_K_8x8_q8_K(int n,
+ float * GGML_RESTRICT s,
+ size_t bs,
+ const void * GGML_RESTRICT vx,
+ const void * GGML_RESTRICT vy,
+ int nr,
+ int nc) {
+ constexpr int qk = QK_K;
+ const int nb = n / qk;
+
+ constexpr int ncols_interleaved = 8;
+ constexpr int blocklen = 8;
+
+ assert(n % qk == 0);
+ assert(nr % 4 == 0);
+ assert(nc % ncols_interleaved == 0);
+
+ UNUSED(nb);
+ UNUSED(ncols_interleaved);
+ UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+ constexpr int q8_k_blocklen = 4;
+ constexpr int col_pairs = ncols_interleaved / 2;
+ const uint8x16_t m4b = vdupq_n_u8(0x0f);
+ const uint8x16_t mone = vdupq_n_u8(1);
+ const uint8x16_t mtwo = vdupq_n_u8(2);
+
+ // 8 accumulators: 2 row pairs Ă— 4 col pairs
+ float32x4_t acc_f32[blocklen];
+
+ for (int y = 0; y < nr / q8_k_blocklen; y++) {
+ const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
+
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
+ const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
+
+ for (int i = 0; i < blocklen; i++) {
+ acc_f32[i] = vdupq_n_f32(0);
+ }
+
+ for (int b = 0; b < nb; b++) {
+                // Each bsums pair belongs to the same q8_k subblock
+ const int16x8_t bsums[4]{
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
+ };
+ int16_t bsums_arr[4][8];
+ for (int q8_row = 0; q8_row < 4; q8_row++) {
+ vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
+ }
+
+ int32x4_t sb_acc[4]; // Aux accumulators to store subblock (partial) results
+ int32x4_t acc[8]; // rows 01 stored in [0][1][2][3] rows 23 stored in [4][5][6][7]
+ int32x4_t bias_acc[8]; // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ...
+ for (int i = 0; i < 8; i++) {
+ acc[i] = vdupq_n_s32(0);
+ bias_acc[i] = vdupq_n_s32(0);
+ }
+
+ // Load qh once per block and shift after each subblock
+ const uint8_t * qh_base = q5_ptr[b].qh;
+ uint8x16_t qh[col_pairs][4];
+ for (int cp = 0; cp < col_pairs; cp++) {
+ qh[cp][0] = vld1q_u8(qh_base + 16 * cp);
+ qh[cp][1] = vld1q_u8(qh_base + 16 * cp + 64);
+ qh[cp][2] = vld1q_u8(qh_base + 16 * cp + 128);
+ qh[cp][3] = vld1q_u8(qh_base + 16 * cp + 192);
+ }
+
+ for (int sb = 0; sb < QK_K / 64; sb++) {
+ // Need scales for the low and high nibbles
+ // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
+ int8_t q5sb_scales[2][8];
+                    int16x8_t q5sb_mins[2];  // int16 as it's needed for bias_acc later
+ for (int i = 0; i < 2; i++) {
+ const int offset = sb * 24 + i * 12;
+ decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], q5sb_scales[i]);
+ }
+
+ // q8_ptr[b].qs has interleaved Q8 rows (01, 23)
+ const int8_t * q8_base = q8_ptr[b].qs + sb * 256;
+
+ int8x16_t q8_qs_01[8];
+ int8x16_t q8_qs_23[8];
+
+                    // Load 32 bytes per row pair, 1 subblock at a time
+ for (int i = 0; i < 8; i++) {
+ const int offset = i * 32; // 16 for row 01, 16 for row 23
+ q8_qs_01[i] = vld1q_s8(q8_base + offset);
+ q8_qs_23[i] = vld1q_s8(q8_base + offset + 16);
+ }
+
+ const int8x16_t q8s[2][8] = {
+ { q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3], q8_qs_01[4], q8_qs_01[5], q8_qs_01[6],
+ q8_qs_01[7] },
+ { q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3], q8_qs_23[4], q8_qs_23[5], q8_qs_23[6],
+ q8_qs_23[7] },
+ };
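+                    // q8s[0] aliases the row-pair 01 vectors and q8s[1] the row-pair 23
+                    // vectors so the column-pair loop below can index both uniformly.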
+
+ // Q5s columns iterated in pairs (01, 23, 45, 67)
+ for (int cp = 0; cp < col_pairs; cp++) {
+ for (int i = 0; i < 4; i++) {
+ sb_acc[i] = vdupq_n_s32(0);
+ }
+
+ uint8x16_t qs_cp_0 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 0); // 0 .. 7 & 32..39
+ uint8x16_t qs_cp_1 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 64); // 8 ..15 & 40..47
+ uint8x16_t qs_cp_2 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 128); // 16..23 & 48..55
+ uint8x16_t qs_cp_3 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 192); // 24..31 & 56..63
+
+                        // This is the only part of the algorithm that differs from Q4_K
+                        // Extract the high bits and pack them into 5-bit weights
+ uint8x16_t hbit_lo_0 = vandq_u8(qh[cp][0], mone);
+ uint8x16_t hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[cp][0], mtwo), 3);
+ qh[cp][0] = vshrq_n_u8(qh[cp][0], 2);
+                        // Same as Q4_K from here on: i8mm multiply-accumulates on the reconstructed weights.
+ const int8x16_t qs_lo_0 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_0, m4b), hbit_lo_0, 4));
+ int32x4_t acc_0 = sb_acc[0];
+ acc_0 = vmmlaq_s32(acc_0, qs_lo_0, q8s[0][0]);
+ int32x4_t acc_2 = sb_acc[2];
+ acc_2 = vmmlaq_s32(acc_2, qs_lo_0, q8s[1][0]);
+ const int8x16_t qs_hi_0 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_0, 4), hbit_hi_0));
+ int32x4_t acc_1 = sb_acc[1];
+ acc_1 = vmmlaq_s32(acc_1, qs_hi_0, q8s[0][4]);
+ int32x4_t acc_3 = sb_acc[3];
+ acc_3 = vmmlaq_s32(acc_3, qs_hi_0, q8s[1][4]);
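+                        // vmmlaq_s32 (SMMLA) multiplies a 2x8 int8 row pair with a 2x8 int8
+                        // column pair, accumulating a 2x2 int32 tile per register:
+                        //   lanes = { r0c0, r0c1, r1c0, r1c1 }
+                        // hence the zip/reorder of acc[] after the sb loop.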
+
+                        // Repeat for the remaining three 16-byte chunks (quants 8..15, 16..23, 24..31)
+ uint8x16_t hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[cp][1], mtwo), 3);
+ uint8x16_t hbit_lo_1 = vandq_u8(qh[cp][1], mone);
+ qh[cp][1] = vshrq_n_u8(qh[cp][1], 2);
+ const int8x16_t qs_lo_1 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_1, m4b), hbit_lo_1, 4));
+ acc_0 = vmmlaq_s32(acc_0, qs_lo_1, q8s[0][1]);
+ acc_2 = vmmlaq_s32(acc_2, qs_lo_1, q8s[1][1]);
+ const int8x16_t qs_hi_1 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_1, 4), hbit_hi_1));
+ acc_1 = vmmlaq_s32(acc_1, qs_hi_1, q8s[0][5]);
+ acc_3 = vmmlaq_s32(acc_3, qs_hi_1, q8s[1][5]);
+
+ uint8x16_t hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[cp][2], mtwo), 3);
+ uint8x16_t hbit_lo_2 = vandq_u8(qh[cp][2], mone);
+ qh[cp][2] = vshrq_n_u8(qh[cp][2], 2);
+ const int8x16_t qs_lo_2 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_2, m4b), hbit_lo_2, 4));
+ acc_0 = vmmlaq_s32(acc_0, qs_lo_2, q8s[0][2]);
+ acc_2 = vmmlaq_s32(acc_2, qs_lo_2, q8s[1][2]);
+ const int8x16_t qs_hi_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_2, 4), hbit_hi_2));
+ acc_1 = vmmlaq_s32(acc_1, qs_hi_2, q8s[0][6]);
+ acc_3 = vmmlaq_s32(acc_3, qs_hi_2, q8s[1][6]);
+
+ uint8x16_t hbit_lo_3 = vandq_u8(qh[cp][3], mone);
+ uint8x16_t hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[cp][3], mtwo), 3);
+ qh[cp][3] = vshrq_n_u8(qh[cp][3], 2);
+ const int8x16_t qs_lo_3 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_3, m4b), hbit_lo_3, 4));
+ acc_0 = vmmlaq_s32(acc_0, qs_lo_3, q8s[0][3]);
+ sb_acc[0] = acc_0;
+ acc_2 = vmmlaq_s32(acc_2, qs_lo_3, q8s[1][3]);
+ sb_acc[2] = acc_2;
+
+ // Scales[i] corresponds to column i
+ const int scale_offset = cp * 2;
+ const int32_t s0 = q5sb_scales[0][scale_offset];
+ const int32_t s1 = q5sb_scales[0][scale_offset + 1];
+ const int32x4_t block_scale = vcombine_s32(vdup_n_s32(s0), vdup_n_s32(s1));
+ acc[cp] = vmlaq_s32(acc[cp], sb_acc[0], block_scale);
+ acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[2], block_scale);
+
+ const int8x16_t qs_hi_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_3, 4), hbit_hi_3));
+ acc_1 = vmmlaq_s32(acc_1, qs_hi_3, q8s[0][7]);
+ sb_acc[1] = acc_1;
+ acc_3 = vmmlaq_s32(acc_3, qs_hi_3, q8s[1][7]);
+ sb_acc[3] = acc_3;
+
+ const int32_t s2 = q5sb_scales[1][scale_offset];
+ const int32_t s3 = q5sb_scales[1][scale_offset + 1];
+ const int32x4_t block_scale2 = vcombine_s32(vdup_n_s32(s2), vdup_n_s32(s3));
+ acc[cp] = vmlaq_s32(acc[cp], sb_acc[1], block_scale2);
+ acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[3], block_scale2);
+ }
+
+                    // Accumulate the bias: bsums * mins
+ for (int q8_row = 0; q8_row < 4; q8_row++) {
+                        // Each pair of subblocks shares the same bsums
+                        // Load the scalar bsum and broadcast it to a vector (vdup_n_s16)
+ int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]);
+ int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]);
+
+ bias_acc[2 * q8_row] =
+ vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q5sb_mins[0]));
+ bias_acc[2 * q8_row] =
+ vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q5sb_mins[1]));
+ bias_acc[2 * q8_row + 1] =
+ vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q5sb_mins[0]));
+ bias_acc[2 * q8_row + 1] =
+ vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q5sb_mins[1]));
+ }
+ } // for sb
+
+                // Reorder the i8mm output to match the bias and output memory layout
+ for (int i = 0; i < 8; i++) {
+ int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i]));
+ acc[i] = vcombine_s32(aux.val[0], aux.val[1]);
+ }
+ int32x4_t reorder_acc[8] = {
+ vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])),
+ vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])),
+ vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])),
+ vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])),
+ vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])),
+ vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])),
+ vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])),
+ vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])),
+ };
+
+ for (int i = 0; i < q8_k_blocklen; i++) {
+ for (int j = 0; j < 2; j++) {
+ float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]);
+ float32x4_t q5_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q5_ptr[b].dmin + j * 4)));
+ const float32x4_t dmins = vmulq_f32(q5_dmin, q8_d);
+
+ float32x4_t q5_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q5_ptr[b].d + j * 4)));
+ const float32x4_t scale = vmulq_f32(q5_d, q8_d);
+
+ acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins);
+ acc_f32[2 * i + j] =
+ vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale);
+ }
+ }
+ } // for b
+
+ // With the previous reorder, the tile is already in the correct memory layout.
+ for (int i = 0; i < q8_k_blocklen; i++) {
+ int row = y * q8_k_blocklen + i;
+ for (int j = 0; j < 2; j++) {
+ int col = x * ncols_interleaved + j * 4;
+ int offset = row * bs + col;
+ vst1q_f32(s + offset, acc_f32[2 * i + j]);
+ }
+ }
+ } // for x
+ } // for y
+ return;
+#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+ ggml_gemm_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+}
void ggml_gemm_q8_0_4x4_q8_0(int n,
float * GGML_RESTRICT s,
assert (n % qk == 0);
assert (nc % ncols_interleaved == 0);
- UNUSED(s);
UNUSED(bs);
- UNUSED(vx);
- UNUSED(vy);
UNUSED(nr);
- UNUSED(nc);
- UNUSED(nb);
- UNUSED(ncols_interleaved);
- UNUSED(blocklen);
float sumf[8];
float sum_minf[8];
}
}
+void ggml_gemv_q5_K_8x8_q8_K_generic(int n,
+ float * GGML_RESTRICT s,
+ size_t bs,
+ const void * GGML_RESTRICT vx,
+ const void * GGML_RESTRICT vy,
+ int nr,
+ int nc) {
+ const int qk = QK_K;
+ const int nb = n / qk;
+ const int ncols_interleaved = 8;
+ const int blocklen = 8;
+ static const uint32_t kmask1 = 0x3f3f3f3f;
+ static const uint32_t kmask2 = 0x0f0f0f0f;
+ static const uint32_t kmask3 = 0x03030303;
+
+ assert(n % qk == 0);
+ assert(nc % ncols_interleaved == 0);
+
+ UNUSED(bs);
+ UNUSED(nr);
+
+ float sumf[8];
+ float sum_minf[8];
+ uint32_t utmp[32];
+ int sumi1;
+ int sumi2;
+ int sumi;
+
+ const block_q8_K * a_ptr = (const block_q8_K *) vy;
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
+ const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb);
+
+ for (int j = 0; j < ncols_interleaved; j++) {
+ sumf[j] = 0.0;
+ sum_minf[j] = 0.0;
+ }
+ for (int l = 0; l < nb; l++) {
+ for (int sb = 0; sb < 8; sb++) {
+ memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
+ utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
+ const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
+ utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
+ utmp[sb * 4 + 2] = uaux_0;
+ utmp[sb * 4 + 0] &= kmask1;
+ }
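+            // After unpacking, each 16-byte group of utmp holds the 8 per-column scales
+            // of one subblock in its first 8 bytes and the 8 per-column mins in the next 8.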
+ for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+ uint8_t * scales_0 = (uint8_t *) utmp + (k / 4) * 32;
+ uint8_t * scales_1 = (uint8_t *) utmp + (k / 4) * 32 + 16;
+
+ const int qh_shift = (k / 4) * 2;
+ for (int j = 0; j < ncols_interleaved; j++) {
+ sumi1 = 0;
+ sumi2 = 0;
+ sumi = 0;
+ for (int i = 0; i < blocklen; ++i) {
+ const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i;
+
+ const int qh_idx = (k * 8 + i) % 32;
+ const int qh_chunk = qh_idx / 8;
+ const int qh_pos = qh_idx % 8;
+ const int b_qh_offset = qh_chunk * 64 + j * 8 + qh_pos;
+
+ const uint8_t qh_val = b_ptr[l].qh[b_qh_offset];
+ const uint8_t h0 = (qh_val >> qh_shift) & 1;
+ const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1;
+
+ const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4));
+ const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4));
+
+ const int q8_offset = (k >> 2) * 64 + (k % 4) * blocklen + i;
+
+ sumi1 = (v0 * a_ptr[l].qs[q8_offset]);
+ sumi2 = (v1 * a_ptr[l].qs[q8_offset + 32]);
+ sumi1 = sumi1 * scales_0[j];
+ sumi2 = sumi2 * scales_1[j];
+ sumi += sumi1 + sumi2;
+ }
+ sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
+ }
+ }
+ for (int sb = 0; sb < 8; sb++) {
+ uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
+ for (int j = 0; j < ncols_interleaved; j++) {
+ sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) *
+ GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
+ }
+ }
+ }
+ for (int j = 0; j < ncols_interleaved; j++) {
+ s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
+ }
+ }
+}
+
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK8_0;
const int nb = n / qk;
}
}
+void ggml_gemm_q5_K_8x8_q8_K_generic(int n,
+ float * GGML_RESTRICT s,
+ size_t bs,
+ const void * GGML_RESTRICT vx,
+ const void * GGML_RESTRICT vy,
+ int nr,
+ int nc) {
+ const int qk = QK_K;
+ const int nb = n / qk;
+ const int ncols_interleaved = 8;
+ const int blocklen = 8;
+
+ constexpr uint32_t kmask1 = 0x3f3f3f3f;
+ constexpr uint32_t kmask2 = 0x0f0f0f0f;
+ constexpr uint32_t kmask3 = 0x03030303;
+
+ assert(n % qk == 0);
+ assert(nr % 4 == 0);
+ assert(nc % ncols_interleaved == 0);
+
+ float sumf[4][8];
+ float sum_minf[4][8];
+ uint32_t utmp[32];
+ int sumi1;
+ int sumi2;
+ int sumi;
+
+ for (int y = 0; y < nr / 4; y++) {
+ const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
+ const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb);
+ for (int m = 0; m < 4; m++) {
+ for (int j = 0; j < ncols_interleaved; j++) {
+ sumf[m][j] = 0.0;
+ sum_minf[m][j] = 0.0;
+ }
+ }
+ for (int l = 0; l < nb; l++) {
+ for (int sb = 0; sb < 8; sb++) {
+ memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
+ utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
+ const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
+ utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
+ utmp[sb * 4 + 2] = uaux_0;
+ utmp[sb * 4 + 0] &= kmask1;
+ }
+ for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+ uint8_t * scales_0 = (uint8_t *) utmp + (k / 4) * 32;
+ uint8_t * scales_1 = (uint8_t *) utmp + (k / 4) * 32 + 16;
+
+ const int qh_shift = (k / 4) * 2;
+ for (int m = 0; m < 4; m++) {
+ for (int j = 0; j < ncols_interleaved; j++) {
+ sumi1 = 0;
+ sumi2 = 0;
+ sumi = 0;
+ for (int i = 0; i < blocklen; ++i) {
+ const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i;
+
+ const int qh_idx = (k * 8 + i) % 32;
+ const int qh_chunk = qh_idx / 8;
+ const int qh_pos = qh_idx % 8;
+ const int b_qh_offset = qh_chunk * 64 + j * 8 + qh_pos;
+
+ const uint8_t qh_val = b_ptr[l].qh[b_qh_offset];
+ const uint8_t h0 = (qh_val >> qh_shift) & 1;
+ const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1;
+
+ const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4));
+ const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4));
+
+ const int q8_offset = (k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i;
+
+ sumi1 = (v0 * a_ptr[l].qs[q8_offset]);
+ sumi2 = (v1 * a_ptr[l].qs[q8_offset + 128]);
+ sumi1 = sumi1 * scales_0[j];
+ sumi2 = sumi2 * scales_1[j];
+ sumi += sumi1 + sumi2;
+ }
+ sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
+ }
+ }
+ }
+ for (int sb = 0; sb < 8; sb++) {
+ uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
+ for (int m = 0; m < 4; m++) {
+ const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
+ for (int j = 0; j < ncols_interleaved; j++) {
+ sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) *
+ GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
+ }
+ }
+ }
+ }
+ for (int m = 0; m < 4; m++) {
+ for (int j = 0; j < ncols_interleaved; j++) {
+ s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
+ }
+ }
+ }
+ }
+}
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK8_0;
out.scales[i] = in[src1].scales[src2];
}
return out;
+}
+
+static block_q5_Kx8 make_block_q5_Kx8(block_q5_K * in, unsigned int blck_size_interleave) {
+ block_q5_Kx8 out;
+    // Delta (scale) and dmin values of the eight Q5_K structures are copied into the output interleaved structure
+ for (int i = 0; i < 8; i++) {
+ out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d;
+ }
+
+ for (int i = 0; i < 8; i++) {
+ out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin;
+ }
+ const int end = QK_K * 4 / blck_size_interleave;
+
+ // Interleave Q5_K quants by taking 8 bytes at a time
+ for (int i = 0; i < end; ++i) {
+ int src_id = i % 8;
+ int src_offset = (i / 8) * blck_size_interleave;
+ int dst_offset = i * blck_size_interleave;
+
+ uint64_t elems;
+ memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
+ memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
+ }
+
+    // Repeat for the high bits, 8 bytes at a time as well, since
+    // they are interleaved in Q5_K and indexed as
+    // qh_idx = (qs_idx % 32);
+    // qh_val = qh[qh_idx] >> (qs_idx / 32);
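+    // qh stores 1 high bit per quant vs 4 low bits per quant in qs, so it is a
+    // quarter of the size of qs, hence end / 4 iterations.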
+ for (int i = 0; i < end / 4; ++i) {
+ int src_id = i % 8;
+ int src_offset = (i / 8) * blck_size_interleave;
+ int dst_offset = i * blck_size_interleave;
+
+ uint64_t elems;
+ memcpy(&elems, &in[src_id].qh[src_offset], sizeof(uint64_t));
+ memcpy(&out.qh[dst_offset], &elems, sizeof(uint64_t));
+ }
+
+    // The logic below is copied over from Q4_K.
+    // The goal is to unpack all the scales and mins of a sub-block with every 12-byte load.
+    // The Q5_K structure packs its 8 scales and 8 mins into 12 bytes (6 bits per value).
+    // The output Q5_Kx8 structure has 96 bytes of scales.
+    // Each 12-byte group packs the scales and mins of the corresponding sub-blocks of the Q5_K structures.
+    // E.g. the first 12 bytes contain the 8 scales and 8 mins of the first sub-block of each of the eight Q5_K structures.
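+    // Layout of one 12-byte group (sub-block i across the 8 interleaved blocks):
+    //   bytes 0..3  : low 6 bits of s[0..3], top 2 bits of s[4..7] in bits 6..7
+    //   bytes 4..7  : low 6 bits of m[0..3], top 2 bits of m[4..7] in bits 6..7
+    //   bytes 8..11 : low 4 bits of s[4..7] and m[4..7], nibble-packed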
+ uint8_t s[8], m[8];
+
+ for (int i = 0; i < 4; i++) {
+ for (int j = 0; j < 8; j++) {
+ s[j] = in[j].scales[i] & 63;
+ m[j] = in[j].scales[i + 4] & 63;
+ }
+
+ out.scales[i * 12] = (s[0] & 63) + ((s[4] & 48) << 2);
+ out.scales[i * 12 + 1] = (s[1] & 63) + ((s[5] & 48) << 2);
+ out.scales[i * 12 + 2] = (s[2] & 63) + ((s[6] & 48) << 2);
+ out.scales[i * 12 + 3] = (s[3] & 63) + ((s[7] & 48) << 2);
+ out.scales[i * 12 + 4] = (m[0] & 63) + ((m[4] & 48) << 2);
+ out.scales[i * 12 + 5] = (m[1] & 63) + ((m[5] & 48) << 2);
+ out.scales[i * 12 + 6] = (m[2] & 63) + ((m[6] & 48) << 2);
+ out.scales[i * 12 + 7] = (m[3] & 63) + ((m[7] & 48) << 2);
+ out.scales[i * 12 + 8] = (s[4] & 15) + ((m[4] & 15) << 4);
+ out.scales[i * 12 + 9] = (s[5] & 15) + ((m[5] & 15) << 4);
+ out.scales[i * 12 + 10] = (s[6] & 15) + ((m[6] & 15) << 4);
+ out.scales[i * 12 + 11] = (s[7] & 15) + ((m[7] & 15) << 4);
+ }
+
+ for (int i = 0; i < 4; i++) {
+ for (int j = 0; j < 8; j++) {
+ s[j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i + 8] & 15);
+ m[j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i + 8] & 240) >> 4);
+ }
+
+ out.scales[i * 12 + 48] = (s[0] & 63) + ((s[4] & 48) << 2);
+ out.scales[i * 12 + 49] = (s[1] & 63) + ((s[5] & 48) << 2);
+ out.scales[i * 12 + 50] = (s[2] & 63) + ((s[6] & 48) << 2);
+ out.scales[i * 12 + 51] = (s[3] & 63) + ((s[7] & 48) << 2);
+ out.scales[i * 12 + 52] = (m[0] & 63) + ((m[4] & 48) << 2);
+ out.scales[i * 12 + 53] = (m[1] & 63) + ((m[5] & 48) << 2);
+ out.scales[i * 12 + 54] = (m[2] & 63) + ((m[6] & 48) << 2);
+ out.scales[i * 12 + 55] = (m[3] & 63) + ((m[7] & 48) << 2);
+ out.scales[i * 12 + 56] = (s[4] & 15) + ((m[4] & 15) << 4);
+ out.scales[i * 12 + 57] = (s[5] & 15) + ((m[5] & 15) << 4);
+ out.scales[i * 12 + 58] = (s[6] & 15) + ((m[6] & 15) << 4);
+ out.scales[i * 12 + 59] = (s[7] & 15) + ((m[7] & 15) << 4);
+ }
+
+ return out;
}
static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
GGML_UNUSED(data_size);
}
+static int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor * t,
+ int interleave_block,
+ const void * GGML_RESTRICT data,
+ size_t data_size) {
+ GGML_ASSERT(t->type == GGML_TYPE_Q5_K);
+ GGML_ASSERT(interleave_block == 8);
+ constexpr int nrows_interleaved = 8;
+
+ block_q5_Kx8 * dst = (block_q5_Kx8 *) t->data;
+ const block_q5_K * src = (const block_q5_K *) data;
+ block_q5_K dst_tmp[8];
+ int nrow = ggml_nrows(t);
+ int nblocks = t->ne[0] / QK_K;
+
+ GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q5_K));
+
+ if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
+ return -1;
+ }
+
+ for (int b = 0; b < nrow; b += nrows_interleaved) {
+ for (int64_t x = 0; x < nblocks; x++) {
+ for (int i = 0; i < nrows_interleaved; i++) {
+ dst_tmp[i] = src[x + i * nblocks];
+ }
+ *dst++ = make_block_q5_Kx8(dst_tmp, interleave_block);
+ }
+ src += nrows_interleaved * nblocks;
+ }
+ return 0;
+}
+
static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
GGML_ASSERT(interleave_block == 8);
return repack_q2_K_to_q2_K_8_bl(t, 8, data, data_size);
}
+template <> int repack<block_q5_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
+ return repack_q5_K_to_q5_K_8_bl(t, 8, data, data_size);
+}
+
template <> int repack<block_iq4_nl, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size);
}
ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
+template <> void gemv<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+ ggml_gemv_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+
template <> void gemv<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemv_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
}
ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}
-template <> void gemv<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
- ggml_gemv_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
+template <> void gemv<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+ ggml_gemv_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}
template <> void gemv<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
-template <> void gemm<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
- ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
-}
-
template <> void gemm<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
+template <> void gemm<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+ ggml_gemm_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemm<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+ ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+
template <> void gemm<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}
-template <> void gemm<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
- ggml_gemm_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
+template <> void gemm<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+ ggml_gemm_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}
template <> void gemm<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
static const ggml::cpu::repack::tensor_traits<block_q4_K, 4, 8, GGML_TYPE_Q8_K> q4_K_8x4_q8_K;
static const ggml::cpu::repack::tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
+ // instance for Q5_K
+ static const ggml::cpu::repack::tensor_traits<block_q5_K, 8, 8, GGML_TYPE_Q8_K> q5_K_8x8_q8_K;
+
// instance for Q2
static const ggml::cpu::repack::tensor_traits<block_q2_K, 8, 8, GGML_TYPE_Q8_K> q2_K_8x8_q8_K;
return &q2_K_8x8_q8_K;
}
}
+ } else if (cur->type == GGML_TYPE_Q5_K) {
+ if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
+ if (cur->ne[1] % 8 == 0) {
+ return &q5_K_8x8_q8_K;
+ }
+ }
} else if (cur->type == GGML_TYPE_IQ4_NL) {
if (ggml_cpu_has_avx2()) {
if (cur->ne[1] % 8 == 0) {
};
static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding");
+
struct block_q2_Kx8 {
ggml_half d[8]; // super-block scale for quantized scales
ggml_half dmin[8]; // super-block scale for quantized mins
};
static_assert(sizeof(block_q2_Kx8) == sizeof(ggml_half) * 16 + QK_K/2 + QK_K * 2, "wrong q2_K block size/padding");
+
+struct block_q5_Kx8 {
+ ggml_half d[8]; // super-block scale for quantized scales
+ ggml_half dmin[8]; // super-block scale for quantized mins
+ uint8_t scales[96]; // scales and mins, quantized with 6 bits
+ uint8_t qh[QK_K * 8 / 8]; // high bits of 5-bit quants
+    uint8_t qs[QK_K * 8 / 2]; // low 4 bits of the 5-bit quants
+};
+
+static_assert(sizeof(block_q5_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 5,
+ "wrong q5_K block size/padding");
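+// Size breakdown, assuming QK_K == 256 and K_SCALE_SIZE == 12:
+// d + dmin = 32 bytes, scales = 96 bytes, qh = 256 bytes, qs = 1024 bytes -> 1408 total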
+
struct block_q8_Kx4 {
float d[4]; // delta
int8_t qs[QK_K * 4]; // quants
void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_quantize_mat_q8_K_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
+void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);