}
+#ifdef __ARM_FEATURE_SVE
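+// Decode the 12 packed bytes of q4_K scales/mins pointed to by vx_scales into the
+// same layout the scalar utmp[] decode produces: 32-bit lanes 0-1 hold the eight
+// 6-bit scales (one per byte), lanes 2-3 hold the eight 6-bit mins.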
+static inline svuint32_t ggml_decode_q4scales_and_mins_for_mmla(const uint32_t * vx_scales) {
+ const svbool_t pg_all = svptrue_pat_b32(SV_VL4);
+ const svbool_t pg_false = svpfalse_b(); // 0x0000
+ const svbool_t pg_lo_8 = svwhilelt_b8_s32(0, 8); // 0x00ff
+ const svbool_t pg_odd = svzip1_b32(pg_false, pg_lo_8);
+
+ svuint32_t vutmp_hi, vutmp_lo;
+ svuint32_t vx01 = svld1_u32(pg_lo_8, vx_scales);
+ vutmp_hi = svzip1_u32(vx01, vx01);
+ vutmp_hi = svlsr_n_u32_m(pg_odd, vutmp_hi, 2);
+ vutmp_hi = svreinterpret_u32_u64(svand_n_u64_x(pg_all, svreinterpret_u64_u32(vutmp_hi), UINT64_C(0x303030303f3f3f3f)));
+ const svuint32_t vx2 = svdup_u32(vx_scales[2]);
+ vutmp_lo = svlsr_u32_x(pg_all, vx2, svreinterpret_u32_s32(svindex_s32(-2, 2)));
+ vutmp_lo = svand_n_u32_z(pg_odd, vutmp_lo, UINT32_C(0x0f0f0f0f));
+ svuint32_t vutmp = svorr_u32_z(pg_all, vutmp_hi, vutmp_lo);
+ return vutmp;
+}
+#endif
+
void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(n % QK_K == 0);
#ifdef __ARM_FEATURE_MATMUL_INT8
static const uint32_t kmask3 = 0x03030303;
uint32_t utmp[4];
+#ifdef __ARM_FEATURE_SVE
+ const int vector_length = ggml_cpu_get_sve_cnt()*8;
+#endif
-#if defined(__ARM_FEATURE_MATMUL_INT8)
+#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
+ if (nrc == 2) {
+ svbool_t pg32_2 = svptrue_pat_b32(SV_VL2);
+
+ const block_q4_K * GGML_RESTRICT vx0 = vx;
+ const block_q8_K * GGML_RESTRICT vy0 = vy;
+ const block_q4_K * GGML_RESTRICT vx1 = (const block_q4_K *) ((const uint8_t*)vx + bx);
+ const block_q8_K * GGML_RESTRICT vy1 = (const block_q8_K *) ((const uint8_t*)vy + by);
+
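+ // Scratch for the decoded scales/mins of both rows; svst2 below interleaves
+ // them so that u64[0..1] hold the scales and u64[2..3] hold the mins.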
+ union {
+ uint32_t u32[8];
+ uint64_t u64[4];
+ } new_utmp;
+
+ svfloat32_t sumf1 = svdup_n_f32(0);
+
+ switch (vector_length) {
+ case 128:
+ {
+ svbool_t pg_false = svpfalse_b();
+ svbool_t pg_lo_8 = svwhilelt_b8_s32(0, 8);
+ svbool_t vmins_mask1 = svzip1_b32(pg_lo_8, pg_false);
+ svbool_t vmins_mask2 = svzip1_b32(pg_false, pg_lo_8);
+ svbool_t pg128_all = svptrue_pat_b8(SV_VL16);
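+ // 128-bit path: both rows are processed together. The per-super-block scales
+ // and mins are laid out as {x0*y0, x1*y0, x0*y1, x1*y1}, so one fma per block
+ // updates all four outputs of the 2x2 tile.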
+ for (int i = 0; i < nb; ++i) {
+ svfloat32_t vy_d = svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d));
+ svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d)));
+ svfloat32_t svsuper_block_scales = svmul_f32_x(pg128_all, vy_d, vx_d);
+ svfloat32_t vx_dmins = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].dmin)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].dmin)));
+ svfloat32_t vy_dmins = svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d));
+ svfloat32_t svdmins = svmul_n_f32_x(pg128_all, svmul_f32_x(pg128_all, vy_dmins, vx_dmins), -1);
+ const uint8_t * GGML_RESTRICT q4_0 = vx0[i].qs;
+ const int8_t * GGML_RESTRICT q8_0 = vy0[i].qs;
+ const uint8_t * GGML_RESTRICT q4_1 = vx1[i].qs;
+ const int8_t * GGML_RESTRICT q8_1 = vy1[i].qs;
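+ // Fold the 16 per-16-element bsums of each q8 row into 8 per-sub-block sums
+ // (q4_K sub-blocks span 32 values) for the mins correction below.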
+ svint16_t lo = svld1_s16(pg128_all, vy0[i].bsums + 0);
+ svint16_t hi = svld1_s16(pg128_all, vy0[i].bsums + 8);
+ svint16_t sum_tmp1 = svuzp1_s16(lo, hi);
+ svint16_t sum_tmp2 = svuzp2_s16(lo, hi);
+ svint16_t svq8sums_0 = svadd_s16_x(pg128_all, sum_tmp1, sum_tmp2);
+ lo = svld1_s16(pg128_all, vy1[i].bsums + 0);
+ hi = svld1_s16(pg128_all, vy1[i].bsums + 8);
+ sum_tmp1 = svuzp1_s16(lo, hi);
+ sum_tmp2 = svuzp2_s16(lo, hi);
+ svint16_t svq8sums_1 = svadd_s16_x(pg128_all, sum_tmp1, sum_tmp2);
+ svuint32_t decoded_scales0 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx0[i].scales);
+ svuint32_t decoded_scales1 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx1[i].scales);
+ svuint32x2_t decoded_scales = svcreate2_u32(decoded_scales0, decoded_scales1);
+ svst2_u32(pg128_all, new_utmp.u32, decoded_scales);
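+ // Dot the folded bsums with the 8 mins of each row and shuffle the four
+ // row/column sums into the same {x0*y0, x1*y0, x0*y1, x1*y1} order as svdmins.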
+ svint16_t svmins8_0 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u32(svuzp1_u32(svld1_u32(vmins_mask1, new_utmp.u32+4), svdup_n_u32(0)))));
+ svint16_t svmins8_1 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u32(svuzp2_u32(svld1_u32(vmins_mask2, new_utmp.u32+4), svdup_n_u32(0)))));
+ svint32_t svsumfs_tmp1 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_0, svmins8_0));
+ svint32_t svsumfs_tmp2 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_0, svmins8_1));
+ svint32_t svsumfs_tmp3 = svtrn1_s32(svsumfs_tmp1, svsumfs_tmp2);
+ svint32_t svsumfs_tmp4 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_1, svmins8_0));
+ svint32_t svsumfs_tmp5 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_1, svmins8_1));
+ svint32_t svsumfs_tmp6 = svtrn1_s32(svsumfs_tmp4, svsumfs_tmp5);
+ svint32_t svsumfs_tmp7 = svreinterpret_s32_s64(svtrn2_s64(svreinterpret_s64_s32(svsumfs_tmp3), svreinterpret_s64_s32(svsumfs_tmp6)));
+ svint32_t svsumfs_tmp8 = svreinterpret_s32_s64(svtrn1_s64(svreinterpret_s64_s32(svsumfs_tmp3), svreinterpret_s64_s32(svsumfs_tmp6)));
+ svint32_t svsumfs_tmp = svadd_s32_x(pg128_all, svsumfs_tmp7, svsumfs_tmp8);
+ svint32_t svscales, sumi1, sumi2;
+ svint32_t acc_sumif1 = svdup_n_s32(0);
+ svint32_t acc_sumif2 = svdup_n_s32(0);
+ svint8_t q4bytes_0_l, q4bytes_0_h, q4bytes_1_l, q4bytes_1_h, l0, l1, l2, l3,
+ q8bytes_0_h, q8bytes_0_l, q8bytes_1_h, q8bytes_1_l, r0, r1, r2, r3;
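+ // Each 64-value chunk j holds sub-block 2*j in the low nibbles and 2*j+1 in
+ // the high nibbles. The two rows are interleaved 64-bit-wise (zip1/zip2) so
+ // every svmmla produces a 2x2 tile of row-by-column partial sums.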
+#pragma GCC unroll 1
+ for (int j = 0; j < QK_K/64; ++j) {
+ q4bytes_0_l = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0), 0xf));
+ q4bytes_1_l = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1), 0xf));
+ q4bytes_0_h = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0+16), 0xf));
+ q4bytes_1_h = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1+16), 0xf));
+ l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));
+ l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));
+ l2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));
+ l3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));
+ q8bytes_0_h = svld1_s8(pg128_all, q8_0);
+ q8bytes_1_h = svld1_s8(pg128_all, q8_1);
+ q8bytes_0_l = svld1_s8(pg128_all, q8_0+16);
+ q8bytes_1_l = svld1_s8(pg128_all, q8_1+16);
+ r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));
+ r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));
+ r2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));
+ r3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));
+ sumi1 = svmmla_s32(svmmla_s32(svmmla_s32(svmmla_s32(svdup_n_s32(0), r0, l0), r1, l1), r2, l2), r3, l3);
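+ // Broadcast the scale of sub-block 2*j for each row: shift left then right
+ // within each 32-bit lane to isolate the corresponding byte of new_utmp.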
+ svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg128_all, svlsl_n_u32_x(pg128_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-1)), 24));
+ acc_sumif1 = svmla_s32_x(pg128_all, acc_sumif1, svscales, sumi1);
+
+ q4bytes_0_l = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0), 4));
+ q4bytes_1_l = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1), 4));
+ q4bytes_0_h = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0+16), 4));
+ q4bytes_1_h = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1+16), 4));
+ l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));
+ l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));
+ l2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));
+ l3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));
+ q8bytes_0_h = svld1_s8(pg128_all, q8_0+32);
+ q8bytes_1_h = svld1_s8(pg128_all, q8_1+32);
+ q8bytes_0_l = svld1_s8(pg128_all, q8_0+48);
+ q8bytes_1_l = svld1_s8(pg128_all, q8_1+48);
+ r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));
+ r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));
+ r2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));
+ r3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));
+ sumi2 = svmmla_s32(svmmla_s32(svmmla_s32(svmmla_s32(svdup_n_s32(0), r0, l0), r1, l1), r2, l2), r3, l3);
+ svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg128_all, svlsl_n_u32_x(pg128_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-2)), 24));
+ acc_sumif2 = svmla_s32_x(pg128_all, acc_sumif2, svscales, sumi2);
+ q4_0 += 32; q4_1 += 32; q8_0 += 64; q8_1 += 64;
+ }
+ sumf1 = svmla_f32_x(pg128_all,
+ svmla_f32_x(pg128_all,
+ sumf1,
+ svcvt_f32_x(pg128_all,
+ svadd_s32_x(pg128_all, acc_sumif1, acc_sumif2)),
+ svsuper_block_scales),
+ svdmins,
+ svcvt_f32_s32_x(pg128_all, svsumfs_tmp));
+ } // end of for nb
+ } // end of case 128
+ break;
+ case 256:
+ case 512:
+ {
+ const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
+ const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16);
+ const svbool_t pg256_all = svptrue_pat_b8(SV_ALL);
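+ // VL >= 256: one load covers 32 quantized bytes per row, so each nibble half
+ // needs only two svmmla calls; the partial sums end up split across two
+ // 128-bit halves that are folded together after the j loop.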
+ for (int i = 0; i < nb; ++i) {
+ const uint8_t * GGML_RESTRICT q4_0 = vx0[i].qs;
+ const int8_t * GGML_RESTRICT q8_0 = vy0[i].qs;
+ const uint8_t * GGML_RESTRICT q4_1 = vx1[i].qs;
+ const int8_t * GGML_RESTRICT q8_1 = vy1[i].qs;
+ svint32_t svscales, sumi1, sumi2;
+ svint32_t acc_sumif1 = svdup_n_s32(0);
+ svint32_t acc_sumif2 = svdup_n_s32(0);
+ svint8_t l0, l1, l2, l3, r0, r1, r2, r3;
+ svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d)));
+ svfloat64_t vy_d_tmp = svreinterpret_f64_f32(svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d)));
+ svfloat32_t vy_d = svreinterpret_f32_f64(svuzp1_f64(vy_d_tmp, vy_d_tmp));
+ svfloat32_t svsuper_block_scales = svmul_f32_z(pg32_4, vy_d, vx_d);
+ svfloat32_t vx_dmins = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].dmin)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].dmin)));
+ svfloat64_t vy_dmins_tmp = svreinterpret_f64_f32(svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d)));
+ svfloat32_t vy_dmins = svreinterpret_f32_f64(svuzp1_f64(vy_dmins_tmp, vy_dmins_tmp));
+ svfloat32_t svdmins = svmul_n_f32_x(pg32_4, svmul_f32_x(pg32_4, vx_dmins, vy_dmins), -1);
+ svint16_t rc1 = svuzp1_s16(svld1_s16(pg256_all, vy0[i].bsums), svld1_s16(pg256_all, vy1[i].bsums));
+ svint16_t rc2 = svuzp2_s16(svld1_s16(pg256_all, vy0[i].bsums), svld1_s16(pg256_all, vy1[i].bsums));
+ svint16_t svq8sums = svadd_s16_x(pg256_all, rc1, rc2);
+ svuint32_t decoded_scales0 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx0[i].scales);
+ svuint32_t decoded_scales1 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx1[i].scales);
+ svuint32x2_t decoded_scales = svcreate2_u32(decoded_scales0, decoded_scales1);
+ svst2_u32(pg8_16, new_utmp.u32, decoded_scales);
+ svint16_t new_svq8sums_0 = svreinterpret_s16_u64(svtrn1_u64(svreinterpret_u64_s16(svq8sums), svreinterpret_u64_s16(svq8sums)));
+ svint16_t new_svq8sums_1 = svreinterpret_s16_u64(svtrn2_u64(svreinterpret_u64_s16(svq8sums), svreinterpret_u64_s16(svq8sums)));
+ svuint64_t new_mins_0 = svdup_u64(new_utmp.u64[2]);
+ svuint64_t new_mins_1 = svdup_u64(new_utmp.u64[3]);
+ svint16_t new_svmins8_0 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u64(new_mins_0)));
+ svint16_t new_svmins8_1 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u64(new_mins_1)));
+ svint64_t dot_prod_0 = svdot_s64(svdup_n_s64(0), new_svmins8_0, new_svq8sums_0);
+ svint64_t dot_prod_1 = svdot_s64(dot_prod_0, new_svmins8_1, new_svq8sums_1);
+ svfloat32_t converted_dot_prod_1 = svcvt_f32_s64_x(pg256_all, dot_prod_1);
+ svfloat32_t svsumfs_tmp = svuzp1_f32(converted_dot_prod_1, converted_dot_prod_1);
+
+#pragma GCC unroll 1
+ for (int j = 0; j < QK_K/64; ++j) {
+ svuint8_t q4bytes_0 = svand_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_0), 0xf);
+ svuint8_t q4bytes_1 = svand_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_1), 0xf);
+ svuint8_t q4bytes_2 = svlsr_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_0), 4);
+ svuint8_t q4bytes_3 = svlsr_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_1), 4);
+ l0 = svreinterpret_s8_u64(svzip1_u64(svreinterpret_u64_u8(q4bytes_0), svreinterpret_u64_u8(q4bytes_1)));
+ l1 = svreinterpret_s8_u64(svzip2_u64(svreinterpret_u64_u8(q4bytes_0), svreinterpret_u64_u8(q4bytes_1)));
+ l2 = svreinterpret_s8_u64(svzip1_u64(svreinterpret_u64_u8(q4bytes_2), svreinterpret_u64_u8(q4bytes_3)));
+ l3 = svreinterpret_s8_u64(svzip2_u64(svreinterpret_u64_u8(q4bytes_2), svreinterpret_u64_u8(q4bytes_3)));
+ svint8_t q8bytes_0 = svld1_s8(pg256_all, q8_0);
+ svint8_t q8bytes_1 = svld1_s8(pg256_all, q8_1);
+ svint8_t q8bytes_2 = svld1_s8(pg256_all, q8_0+32);
+ svint8_t q8bytes_3 = svld1_s8(pg256_all, q8_1+32);
+ r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
+ r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
+ r2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_2), svreinterpret_s64_s8(q8bytes_3)));
+ r3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_2), svreinterpret_s64_s8(q8bytes_3)));
+ sumi1 = svmmla_s32(svmmla_s32(svdup_n_s32(0), r0, l0), r1, l1);
+ svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg256_all, svlsl_n_u32_x(pg256_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-1)), 24));
+ acc_sumif1 = svmla_s32_x(pg256_all, acc_sumif1, svscales, sumi1);
+ sumi2 = svmmla_s32(svmmla_s32(svdup_n_s32(0), r2, l2), r3, l3);
+ svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg256_all, svlsl_n_u32_x(pg256_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-2)), 24));
+ acc_sumif2 = svmla_s32_x(pg256_all, acc_sumif2, svscales, sumi2);
+ q4_0 += 32; q4_1 += 32; q8_0 += 64; q8_1 += 64;
+ }
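+ // Fold the upper 128-bit half of the accumulator onto the lower half so that
+ // lanes 0-3 hold the four tile sums.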
+ svint32_t acc_sumif = svadd_s32_x(pg256_all, acc_sumif1, acc_sumif2);
+ svint32_t swap_acc_sumif = svext_s32(acc_sumif, acc_sumif, 4);
+ acc_sumif = svadd_s32_x(pg32_4, acc_sumif, swap_acc_sumif);
+ sumf1 = svmla_f32_x(pg32_4,
+ svmla_f32_x(pg32_4,
+ sumf1,
+ svcvt_f32_x(pg32_4, acc_sumif),
+ svsuper_block_scales),
+ svdmins,
+ svsumfs_tmp);
+ } // end of for nb
+ } // end of case 256-512
+ break;
+ default:
+ assert(false && "Unsupported vector length");
+ break;
+ }
+
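+ // The first pair of results belongs to row s, the second pair (moved down by
+ // 8 bytes with svext) to row s + bs.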
+ svst1_f32(pg32_2, s, sumf1);
+ svst1_f32(pg32_2, s + bs, svreinterpret_f32_u8(svext_u8(svreinterpret_u8_f32(sumf1), svdup_n_u8(0), 8)));
+
+ return;
+ }
+#elif defined(__ARM_FEATURE_MATMUL_INT8)
if (nrc == 2) {
const block_q4_K * GGML_RESTRICT x0 = x;
const block_q4_K * GGML_RESTRICT x1 = (const block_q4_K *) ((const uint8_t *)vx + bx);
const uint8_t * GGML_RESTRICT q4 = x[i].qs;
const int8_t * GGML_RESTRICT q8 = y[i].qs;
- const int vector_length = ggml_cpu_get_sve_cnt()*8;
const svuint8_t m4b = svdup_n_u8(0xf);
const svint32_t mzero = svdup_n_s32(0);
svint32_t sumi1 = svdup_n_s32(0);
const int nb = n / QK_K;
-#if defined(__ARM_FEATURE_MATMUL_INT8)
+#ifdef __ARM_FEATURE_SVE
+ const int vector_length = ggml_cpu_get_sve_cnt()*8;
+#endif
+#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
+ if (nrc == 2) {
+ const svbool_t pg32_2 = svptrue_pat_b32(SV_VL2);
+
+ svfloat32_t sum = svdup_n_f32(0);
+
+ const block_q6_K * GGML_RESTRICT vx0 = vx;
+ const block_q8_K * GGML_RESTRICT vy0 = vy;
+ const block_q6_K * GGML_RESTRICT vx1 = (const block_q6_K *) ((const uint8_t*)vx + bx);
+ const block_q8_K * GGML_RESTRICT vy1 = (const block_q8_K *) ((const uint8_t*)vy + by);
+
+ switch (vector_length) {
+ case 128:
+ {
+ const svbool_t pg128_all = svptrue_pat_b8(SV_ALL);
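+ // 128-bit path: each 6-bit value is rebuilt from a nibble of ql plus two bits
+ // of qh; the implicit -32 offset is applied once at the end via svisum_mins.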
+ for (int i = 0; i < nb; ++i) {
+ const uint8_t * GGML_RESTRICT ql0 = vx0[i].ql;
+ const uint8_t * GGML_RESTRICT qh0 = vx0[i].qh;
+ const uint8_t * GGML_RESTRICT ql1 = vx1[i].ql;
+ const uint8_t * GGML_RESTRICT qh1 = vx1[i].qh;
+ const int8_t * GGML_RESTRICT q80 = vy0[i].qs;
+ const int8_t * GGML_RESTRICT q81 = vy1[i].qs;
+
+ const int8_t * GGML_RESTRICT scale0 = vx0[i].scales;
+ const int8_t * GGML_RESTRICT scale1 = vx1[i].scales;
+
+ svfloat32_t vy_d = svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d));
+ svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d)));
+ svfloat32_t svsuper_block_scales = svmul_f32_x(pg128_all, vy_d, vx_d);
+ // accumulate q8 bsums x q6 scales for the -32 bias correction (128-bit path)
+ const svint16_t q8sums_01 = svld1_s16(pg128_all, vy0[i].bsums);
+ const svint16_t q8sums_02 = svld1_s16(pg128_all, vy0[i].bsums + 8);
+ const svint16_t q8sums_11 = svld1_s16(pg128_all, vy1[i].bsums);
+ const svint16_t q8sums_12 = svld1_s16(pg128_all, vy1[i].bsums + 8);
+ const svint64x2_t q6scales_0_tmp = svld2_s64(pg128_all, (const int64_t *)scale0);
+ const svint16_t q6scales_01 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_0_tmp, 0)));
+ const svint16_t q6scales_02 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_0_tmp, 1)));
+ const svint64x2_t q6scales_1_tmp = svld2_s64(pg128_all, (const int64_t *)scale1);
+ const svint16_t q6scales_11 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_1_tmp, 0)));
+ const svint16_t q6scales_12 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_1_tmp, 1)));
+ const svint64_t prod = svdup_n_s64(0);
+
+ svint32_t isum_tmp1 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_01, q6scales_01), q8sums_02, q6scales_02));
+ svint32_t isum_tmp2 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_01, q6scales_11), q8sums_02, q6scales_12));
+ svint32_t isum_tmp3 = svtrn1_s32(isum_tmp1, isum_tmp2);
+ svint32_t isum_tmp4 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_11, q6scales_01), q8sums_12, q6scales_02));
+ svint32_t isum_tmp5 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_11, q6scales_11), q8sums_12, q6scales_12));
+ svint32_t isum_tmp6 = svtrn1_s32(isum_tmp4, isum_tmp5);
+ svint32_t isum_tmp7 = svreinterpret_s32_s64(svtrn2_s64(svreinterpret_s64_s32(isum_tmp3), svreinterpret_s64_s32(isum_tmp6)));
+ svint32_t isum_tmp8 = svreinterpret_s32_s64(svtrn1_s64(svreinterpret_s64_s32(isum_tmp3), svreinterpret_s64_s32(isum_tmp6)));
+ svint32_t svisum_mins = svadd_s32_x(pg128_all, isum_tmp7, isum_tmp8);
+
+ // process mmla
+ svint8_t l0, l1, r0, r1;
+ svint32_t isum_tmp = svdup_n_s32(0);
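+ // Walk the eight 16-byte sub-blocks of this 128-value chunk: k selects which
+ // ql nibble (k/4) and which qh bit pair (k/2) to combine, and the two rows
+ // are interleaved 64-bit-wise for svmmla.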
+ for (int j = 0; j < QK_K/128; ++j) {
+ for (int k = 0; k < 8; ++k) {
+ svuint8_t qhbits_0 = svld1_u8(pg128_all, qh0+16*(k%2));
+ svuint8_t qhbits_1 = svld1_u8(pg128_all, qh1+16*(k%2));
+ svuint8_t q6bits_0 = svld1_u8(pg128_all, ql0+16*(k%4));
+ svuint8_t q6bits_1 = svld1_u8(pg128_all, ql1+16*(k%4));
+ const int ql_pos = (k/4)*4;
+ svuint8_t q6bytes_0_lo = (ql_pos < 4) ? svand_n_u8_x(pg128_all, q6bits_0, 0xf) : svlsr_n_u8_x(pg128_all, q6bits_0, 4);
+ svuint8_t q6bytes_1_lo = (ql_pos < 4) ? svand_n_u8_x(pg128_all, q6bits_1, 0xf) : svlsr_n_u8_x(pg128_all, q6bits_1, 4);
+ const int qh_pos = (k/2)*2;
+ svuint8_t q6bytes_0_hi = svand_n_u8_x(pg128_all, qhbits_0, 0x3 << qh_pos);
+ svuint8_t q6bytes_1_hi = svand_n_u8_x(pg128_all, qhbits_1, 0x3 << qh_pos);
+ svint8_t q6bytes_0, q6bytes_1;
+ if (qh_pos <= 4) {
+ q6bytes_0 = svreinterpret_s8_u8(svmla_n_u8_x(pg128_all, q6bytes_0_lo, q6bytes_0_hi, 1 << (4 - qh_pos)));
+ q6bytes_1 = svreinterpret_s8_u8(svmla_n_u8_x(pg128_all, q6bytes_1_lo, q6bytes_1_hi, 1 << (4 - qh_pos)));
+ } else {
+ q6bytes_0 = svreinterpret_s8_u8(svorr_u8_x(pg128_all, q6bytes_0_lo, svlsr_n_u8_x(pg128_all, q6bytes_0_hi, (qh_pos - 4))));
+ q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg128_all, q6bytes_1_lo, svlsr_n_u8_x(pg128_all, q6bytes_1_hi, (qh_pos - 4))));
+ }
+ svint8_t q8bytes_0 = svld1_s8(pg128_all, q80+16*(k%8));
+ svint8_t q8bytes_1 = svld1_s8(pg128_all, q81+16*(k%8));
+ l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));
+ l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));
+ r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
+ r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
+ svint32_t svscale = svzip1_s32(svdup_n_s32(scale0[k]), svdup_n_s32(scale1[k]));
+ isum_tmp = svmla_s32_x(pg128_all, isum_tmp, svmmla_s32(svmmla_s32(svdup_n_s32(0), r0, l0), r1, l1), svscale);
+ }
+ qh0 += 32; qh1 += 32;
+ ql0 += 64; ql1 += 64;
+ q80 += 128; q81 += 128;
+ scale0 += 8; scale1 += 8;
+ }
+ sum = svmla_f32_x(pg128_all, sum,
+ svcvt_f32_x(pg128_all, svmla_s32_x(pg128_all, isum_tmp,
+ svisum_mins, svdup_n_s32(-32))),
+ svsuper_block_scales);
+ }
+ } // end of case 128
+ break;
+ case 256:
+ case 512:
+ {
+ const svbool_t pg256_all = svptrue_pat_b8(SV_ALL);
+ const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
+ for (int i = 0; i < nb; ++i) {
+ const uint8_t * GGML_RESTRICT ql0 = vx0[i].ql;
+ const uint8_t * GGML_RESTRICT qh0 = vx0[i].qh;
+ const uint8_t * GGML_RESTRICT ql1 = vx1[i].ql;
+ const uint8_t * GGML_RESTRICT qh1 = vx1[i].qh;
+ const int8_t * GGML_RESTRICT q80 = vy0[i].qs;
+ const int8_t * GGML_RESTRICT q81 = vy1[i].qs;
+
+ const int8_t * GGML_RESTRICT scale0 = vx0[i].scales;
+ const int8_t * GGML_RESTRICT scale1 = vx1[i].scales;
+ svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d)));
+ svfloat64_t vy_d_tmp = svreinterpret_f64_f32(svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d)));
+ svfloat32_t vy_d = svreinterpret_f32_f64(svuzp1_f64(vy_d_tmp, vy_d_tmp));
+ svfloat32_t svsuper_block_scales = svmul_f32_x(pg32_4, vy_d, vx_d);
+ // accumulate q8 bsums x q6 scales for the -32 bias correction (256-bit path)
+ const svint16_t q8sums_0 = svld1_s16(pg256_all, vy0[i].bsums);
+ const svint16_t q8sums_1 = svld1_s16(pg256_all, vy1[i].bsums);
+ const svint16_t q6scales_0 = svunpklo_s16(svld1_s8(pg256_all, scale0));
+ const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(pg256_all, scale1));
+ const svint64_t prod = svdup_n_s64(0);
+ svint32_t isum_tmp1 = svreinterpret_s32_s64(svdot_s64(prod, q8sums_0, q6scales_0));
+ svint32_t isum_tmp2 = svreinterpret_s32_s64(svdot_s64(prod, q8sums_0, q6scales_1));
+ svint32_t isum_tmp3 = svreinterpret_s32_s64(svdot_s64(prod, q8sums_1, q6scales_0));
+ svint32_t isum_tmp4 = svreinterpret_s32_s64(svdot_s64(prod, q8sums_1, q6scales_1));
+ svint32_t isum_tmp5 = svtrn1_s32(isum_tmp1, isum_tmp2);
+ svint32_t isum_tmp6 = svtrn1_s32(isum_tmp3, isum_tmp4);
+ svint32_t isum_tmp7 = svreinterpret_s32_s64(svtrn2_s64(svreinterpret_s64_s32(isum_tmp5), svreinterpret_s64_s32(isum_tmp6)));
+ svint32_t isum_tmp8 = svreinterpret_s32_s64(svtrn1_s64(svreinterpret_s64_s32(isum_tmp5), svreinterpret_s64_s32(isum_tmp6)));
+ svint32_t isum_tmp9 = svadd_s32_x(pg256_all, isum_tmp7, isum_tmp8);
+ svint32_t isum_tmp10 = svreinterpret_s32_u8(svext_u8(svreinterpret_u8_s32(isum_tmp9), svreinterpret_u8_s32(isum_tmp9), 16));
+ svint32_t svisum_mins = svadd_s32_z(pg32_4, isum_tmp9, isum_tmp10);
+
+ // process mmla
+ svint8_t l0, l1, r0, r1;
+ svint32_t isum_tmp = svdup_n_s32(0);
+ for (int j = 0; j < QK_K/128; ++j) {
+ for (int k = 0; k < 8; k+=2) { // process 2 sub-blocks per iteration
+ svuint8_t qhbits_0 = svld1_u8(pg256_all, qh0);
+ svuint8_t qhbits_1 = svld1_u8(pg256_all, qh1);
+ svuint8_t q6bits_0 = svld1_u8(pg256_all, ql0+32*((k%4)/2));
+ svuint8_t q6bits_1 = svld1_u8(pg256_all, ql1+32*((k%4)/2));
+ const int ql_pos = (k/4)*4;
+ svuint8_t q6bytes_0_lo = (ql_pos < 4) ? svand_n_u8_x(pg256_all, q6bits_0, 0xf) : svlsr_n_u8_x(pg256_all, q6bits_0, 4);
+ svuint8_t q6bytes_1_lo = (ql_pos < 4) ? svand_n_u8_x(pg256_all, q6bits_1, 0xf) : svlsr_n_u8_x(pg256_all, q6bits_1, 4);
+ const int qh_pos = (k/2)*2;
+ svuint8_t q6bytes_0_hi = svand_n_u8_x(pg256_all, qhbits_0, 0x3 << qh_pos);
+ svuint8_t q6bytes_1_hi = svand_n_u8_x(pg256_all, qhbits_1, 0x3 << qh_pos);
+ svint8_t q6bytes_0, q6bytes_1;
+ if (qh_pos <= 4) {
+ q6bytes_0 = svreinterpret_s8_u8(svmla_n_u8_x(pg256_all, q6bytes_0_lo, q6bytes_0_hi, 1 << (4 - qh_pos)));
+ q6bytes_1 = svreinterpret_s8_u8(svmla_n_u8_x(pg256_all, q6bytes_1_lo, q6bytes_1_hi, 1 << (4 - qh_pos)));
+ } else {
+ q6bytes_0 = svreinterpret_s8_u8(svorr_u8_x(pg256_all, q6bytes_0_lo, svlsr_n_u8_x(pg256_all, q6bytes_0_hi, (qh_pos - 4))));
+ q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg256_all, q6bytes_1_lo, svlsr_n_u8_x(pg256_all, q6bytes_1_hi, (qh_pos - 4))));
+ }
+ svint8_t q8bytes_0 = svld1_s8(pg256_all, q80+32*(k/2));
+ svint8_t q8bytes_1 = svld1_s8(pg256_all, q81+32*(k/2));
+ l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));
+ l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));
+ r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
+ r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
+ svint32_t svscale0 = svzip1_s32(svdup_n_s32(scale0[k]), svdup_n_s32(scale1[k]));
+ svint32_t svscale1 = svzip1_s32(svdup_n_s32(scale0[k+1]), svdup_n_s32(scale1[k+1]));
+ isum_tmp = svmla_s32_x(pg256_all, isum_tmp, svmmla_s32(svdup_n_s32(0), r0, l0), svscale0);
+ isum_tmp = svmla_s32_x(pg256_all, isum_tmp, svmmla_s32(svdup_n_s32(0), r1, l1), svscale1);
+ }
+ qh0 += 32; qh1 += 32;
+ ql0 += 64; ql1 += 64;
+ q80 += 128; q81 += 128;
+ scale0 += 8; scale1 += 8;
+ } // end of for
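+ // Fold the upper 128-bit half of isum_tmp onto the lower half before the
+ // float conversion.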
+ svint32_t swap_isum_tmp = svext_s32(isum_tmp, isum_tmp, 4);
+ isum_tmp = svadd_s32_x(pg32_4, isum_tmp, swap_isum_tmp);
+ sum = svmla_f32_x(pg32_4, sum,
+ svcvt_f32_x(pg32_4, svmla_s32_x(pg32_4, isum_tmp,
+ svisum_mins, svdup_n_s32(-32))),
+ svsuper_block_scales);
+ }
+ } // end of case 256-512
+ break;
+ default:
+ assert(false && "Unsupported vector length");
+ break;
+ } // end of switch
+
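+ // As in the q4_K path: lanes 0-1 go to row s, lanes 2-3 (extracted with
+ // svext) to row s + bs.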
+ svst1_f32(pg32_2, s, sum);
+ svst1_f32(pg32_2, s + bs, svreinterpret_f32_u8(svext_u8(svreinterpret_u8_f32(sum), svdup_n_u8(0), 8)));
+
+ return;
+ }
+#elif defined(__ARM_FEATURE_MATMUL_INT8)
if (nrc == 2) {
const block_q6_K * GGML_RESTRICT x0 = x;
const block_q6_K * GGML_RESTRICT x1 = (const block_q6_K *) ((const uint8_t *)vx + bx);
// adjust bias, apply superblock scale
{
int32_t bias[4];
-#ifdef __ARM_FEATURE_SVE
- const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
- const svbool_t pg8_8 = svptrue_pat_b8(SV_VL8);
- const svint16_t y0_q8sums_0 = svld1_s16(pg16_8, y0->bsums);
- const svint16_t y0_q8sums_1 = svld1_s16(pg16_8, y0->bsums + 8);
- const svint16_t y1_q8sums_0 = svld1_s16(pg16_8, y1->bsums);
- const svint16_t y1_q8sums_1 = svld1_s16(pg16_8, y1->bsums + 8);
- const svint16_t x0_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x0->scales));
- const svint16_t x0_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x0->scales + 8));
- const svint16_t x1_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x1->scales));
- const svint16_t x1_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x1->scales + 8));
- const svint64_t zero = svdup_n_s64(0);
- bias[0] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x0_q6scales_0),
- svdot_s64(zero, y0_q8sums_1, x0_q6scales_1)));
- bias[1] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x0_q6scales_0),
- svdot_s64(zero, y1_q8sums_1, x0_q6scales_1)));
- bias[2] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x1_q6scales_0),
- svdot_s64(zero, y0_q8sums_1, x1_q6scales_1)));
- bias[3] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x1_q6scales_0),
- svdot_s64(zero, y1_q8sums_1, x1_q6scales_1)));
-#else
// NEON doesn't support int16 dot product, fallback to separated mul and add
const int16x8x2_t q8sums0 = vld1q_s16_x2(y0->bsums);
const int16x8x2_t q8sums1 = vld1q_s16_x2(y1->bsums);
vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales1.val[1]))));
bias[3] = vaddvq_s32(prod);
-#endif
const int32x4_t vibias = vmulq_n_s32(vld1q_s32(bias), 32);
const float32x4_t superblock_scale = {
#endif
#ifdef __ARM_FEATURE_SVE
- const int vector_length = ggml_cpu_get_sve_cnt()*8;
float sum = 0;
svuint8_t m4b = svdup_n_u8(0xf);
svint32_t vzero = svdup_n_s32(0);