uint32_t utmp[4];
-#ifdef __ARM_NEON
+#ifdef __ARM_FEATURE_SVE
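+ // SVE path: a kernel is chosen at run time from the hardware vector length (128-, 256- or 512-bit).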
+ float sumf = 0;
+ for (int i = 0; i < nb; ++i) {
+
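+ // Super-block scales for the quant scales (d) and the mins (dmin), promoted from FP16 and folded with the q8 scale y[i].d.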
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
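+ // Pairwise-add the sixteen q8 block sums so each lane covers one 32-quant sub-block, matching the eight mins.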
+ const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
+
+ memcpy(utmp, x[i].scales, K_SCALE_SIZE);
+
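+ // Extract the eight 6-bit mins from the packed scales/mins bytes, one min per byte.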
+ uint32x2_t mins8 = { 0 };
+ mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0);
+ mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1);
+
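+ // Repack utmp[0..1] so they hold the eight 6-bit scales, one per byte (read through the scales pointer below).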
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+ utmp[0] &= kmask1;
+
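+ // Widen the mins to 16 bits, multiply them against the block sums and subtract the dmin-scaled total.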
+ const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
+ const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
+ vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
+ sumf -= dmin * vaddvq_s32(prod);
+
+ const uint8_t * scales = (const uint8_t *)utmp;
+
+ const uint8_t * restrict q4 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+
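+ // SVE register width in bits (the helper reports bytes, hence the *8); this selects the kernel below.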
+ const int vector_length = ggml_cpu_get_sve_cnt()*8;
+ const svuint8_t m4b = svdup_n_u8(0xf);
+ const svint32_t mzero = svdup_n_s32(0);
+ svint32_t sumi1 = svdup_n_s32(0);
+ svint32_t sumi1_1 = svdup_n_s32(0);
+ svint32_t sumi1_2 = svdup_n_s32(0);
+ svint32_t sumi2 = svdup_n_s32(0);
+ svint32_t sumi2_1 = svdup_n_s32(0);
+ svint32_t sumi2_2 = svdup_n_s32(0);
+ switch (vector_length) {
+ case 128:
+ {
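+ // 128-bit vectors: each j covers 64 quants as two 16-byte halves; low nibbles are weighted
+ // by scales[2*j+0] and high nibbles by scales[2*j+1], using split accumulators.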
+ for (int j = 0; j < QK_K/64; ++j) {
+ svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), m4b));
+ svint8_t q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+ sumi1_1 = svmla_n_s32_x(svptrue_b32(), sumi1_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
+ q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), m4b));
+ q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+ sumi1_2 = svmla_n_s32_x(svptrue_b32(), sumi1_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
+
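+ // Same two halves again, now taking the high nibbles and the next 32 q8 values.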
+ q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), 4));
+ q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+ sumi2_1 = svmla_n_s32_x(svptrue_b32(), sumi2_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
+ q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), 4));
+ q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+ sumi2_2 = svmla_n_s32_x(svptrue_b32(), sumi2_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
+ q4 += 32;
+ }
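+ // Fold the split accumulators, reduce horizontally and scale by d for this super-block.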
+ sumi1 = svadd_s32_x(svptrue_b32(), sumi1_1, sumi1_2);
+ sumi2 = svadd_s32_x(svptrue_b32(), sumi2_1, sumi2_2);
+ sumf += d * (svaddv_s32(svptrue_b32(), svadd_s32_x(svptrue_b32(), sumi1, sumi2)));
+ } break;
+ case 256:
+ case 512:
+ {
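+ // 256-bit and wider vectors: one 32-byte q4 load per 64-quant chunk; its low and high nibbles
+ // are dotted with two 32-byte q8 loads, using fixed-length predicates (SV_VL32 / SV_VL8) so only the first 32 lanes are active.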
+ for (int j = 0; j < QK_K/64; ++j) {
+ const svuint8_t q4bits = svld1_u8(svptrue_pat_b8(SV_VL32), q4); q4 += 32;
+ svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_pat_b8(SV_VL32), q4bits, m4b));
+ svint8_t q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
+ sumi1 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
+
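+ // High nibbles of the same 32 bytes, paired with the next 32 q8 values and the odd-indexed scale.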
+ q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q4bits, 4));
+ q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
+ sumi2 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
+ }
+ sumf += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), sumi1, sumi2)));
+ } break;
+ default:
+ assert(false && "Unsupported vector length");
+ break;
+ }
+ }
+ *s = sumf;
+#elif __ARM_NEON
const uint8x16_t m4b = vdupq_n_u8(0xf);
const int32x4_t mzero = vdupq_n_s32(0);