const int nb = n / QK_K;
-#ifdef __ARM_NEON
+#ifdef __ARM_FEATURE_SVE
+ const int vector_length = ggml_cpu_get_sve_cnt()*8;
+ float sum = 0;
+ svuint8_t m4b = svdup_n_u8(0xf);
+ svint32_t vzero = svdup_n_s32(0);
+ svuint8_t mone = svdup_n_u8(0x30);
+ svint8_t q6bytes_1, q6bytes_2, q6bytes_3, q6bytes_4;
+ svuint8_t q6h_1, q6h_2, q6h_3, q6h_4;
+
+ for (int i = 0; i < nb; ++i) {
+ const float d_all = GGML_FP16_TO_FP32(x[i].d);
+
+ const uint8_t * GGML_RESTRICT q6 = x[i].ql;
+ const uint8_t * GGML_RESTRICT qh = x[i].qh;
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
+
+ const int8_t * GGML_RESTRICT scale = x[i].scales;
+
+ const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
+ const svint16_t q8sums_1 = svld1_s16(pg16_8, y[i].bsums);
+ const svint16_t q8sums_2 = svld1_s16(pg16_8, y[i].bsums + 8);
+ const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale));
+ const svint16_t q6scales_2 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale + 8));
+ const svint64_t prod = svdup_n_s64(0);
+ int32_t isum_mins = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(prod, q8sums_1, q6scales_1),
+ svdot_s64(prod, q8sums_2, q6scales_2)));
+ int32_t isum = 0;
+
+ switch (vector_length) {
+ case 128:
+ {
+ const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
+ const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16);
+ svint32_t isum_tmp = svdup_n_s32(0);
+ for (int j = 0; j < QK_K/128; ++j) {
+ svuint8_t qhbits_1 = svld1_u8(pg8_16, qh);
+ svuint8_t qhbits_2 = svld1_u8(pg8_16, qh+16);
+ qh += 32;
+ svuint8_t q6bits_1 = svld1_u8(pg8_16, q6);
+ svuint8_t q6bits_2 = svld1_u8(pg8_16, q6+16);
+ svuint8_t q6bits_3 = svld1_u8(pg8_16, q6+32);
+ svuint8_t q6bits_4 = svld1_u8(pg8_16, q6+48);
+ q6 += 64;
+ svint8_t q8bytes_1 = svld1_s8(pg8_16, q8);
+ svint8_t q8bytes_2 = svld1_s8(pg8_16, q8+16);
+ svint8_t q8bytes_3 = svld1_s8(pg8_16, q8+32);
+ svint8_t q8bytes_4 = svld1_s8(pg8_16, q8+48);
+ q8 += 64;
+
+ q6h_1 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 4));
+ q6h_2 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 4));
+ q6h_3 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 2));
+ q6h_4 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 2));
+ q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_1, m4b), q6h_1));
+ q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_2, m4b), q6h_2));
+ q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_3, m4b), q6h_3));
+ q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_4, m4b), q6h_4));
+ isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
+ isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
+ isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
+ isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
+
+ scale += 4;
+ q8bytes_1 = svld1_s8(pg8_16, q8);
+ q8bytes_2 = svld1_s8(pg8_16, q8+16);
+ q8bytes_3 = svld1_s8(pg8_16, q8+32);
+ q8bytes_4 = svld1_s8(pg8_16, q8+48);
+ q8 += 64;
+
+ q6h_1 = svand_u8_x(pg16_8, mone, qhbits_1);
+ q6h_2 = svand_u8_x(pg16_8, mone, qhbits_2);
+ q6h_3 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_1, 2));
+ q6h_4 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_2, 2));
+ q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_1, 4), q6h_1));
+ q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_2, 4), q6h_2));
+ q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_3, 4), q6h_3));
+ q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_4, 4), q6h_4));
+ isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
+ isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
+ isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
+ isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
+ scale += 4;
+ }
+ isum += svaddv_s32(pg32_4, isum_tmp);
+ sum += d_all * y[i].d * (isum - 32 * isum_mins);
+ }
+ break;
+ case 256:
+ case 512:
+ {
+ const svbool_t pg8_2 = svptrue_pat_b8(SV_VL2);
+ const svbool_t pg32_8 = svptrue_pat_b32(SV_VL8);
+ const svbool_t pg8_32 = svptrue_pat_b8(SV_VL32);
+ svint32_t isum_tmp = svdup_n_s32(0);
+ for (int j = 0; j < QK_K/128; j++) {
+ svuint8_t qhbits_1 = svld1_u8(pg8_32, qh);
+ qh += 32;
+ svuint8_t q6bits_1 = svld1_u8(pg8_32, q6);
+ svuint8_t q6bits_2 = svld1_u8(pg8_32, q6+32);
+ q6 += 64;
+ svint8_t q8bytes_1 = svld1_s8(pg8_32, q8);
+ svint8_t q8bytes_2 = svld1_s8(pg8_32, q8+32);
+ svint8_t q8bytes_3 = svld1_s8(pg8_32, q8+64);
+ svint8_t q8bytes_4 = svld1_s8(pg8_32, q8+96);
+ q8 += 128;
+ q6h_1 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 4));
+ q6h_2 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 2));
+ q6h_3 = svand_u8_x(pg8_32, mone, qhbits_1);
+ q6h_4 = svand_u8_x(pg8_32, mone, svlsr_n_u8_x(pg8_32, qhbits_1, 2));
+ q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_1, m4b), q6h_1));
+ q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_2, m4b), q6h_2));
+ q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_1, 4), q6h_3));
+ q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_2, 4), q6h_4));
+
+ svint8_t scale_lane_1_tmp = svld1_s8(pg8_2, scale);
+ scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
+ scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
+ svint8_t scale_lane_2_tmp = svld1_s8(pg8_2, scale+2);
+ scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
+ scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
+ svint8_t scale_lane_3_tmp = svld1_s8(pg8_2, scale+4);
+ scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
+ scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
+ svint8_t scale_lane_4_tmp = svld1_s8(pg8_2, scale+6);
+ scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
+ scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
+ svint32_t scale_lane_1 = svunpklo_s32(svunpklo_s16(scale_lane_1_tmp));
+ svint32_t scale_lane_2 = svunpklo_s32(svunpklo_s16(scale_lane_2_tmp));
+ svint32_t scale_lane_3 = svunpklo_s32(svunpklo_s16(scale_lane_3_tmp));
+ svint32_t scale_lane_4 = svunpklo_s32(svunpklo_s16(scale_lane_4_tmp));
+
+ isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale_lane_1);
+ isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale_lane_2);
+ isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale_lane_3);
+ isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale_lane_4);
+ scale += 8;
+ }
+ isum += svaddv_s32(pg32_8, isum_tmp);
+ sum += d_all * y[i].d * (isum - 32 * isum_mins);
+ }
+ break;
+ default:
+ assert(false && "Unsupported vector length");
+ break;
+ }
+ }
+
+ *s = sum;
+
+#elif __ARM_NEON
float sum = 0;
const uint8x16_t m4b = vdupq_n_u8(0xF);