]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml: aarch64: implement SVE kernels for q4_K_q8_K vector dot (llama/11227)
authorfj-y-saito <redacted>
Thu, 16 Jan 2025 09:11:49 +0000 (18:11 +0900)
committerGeorgi Gerganov <redacted>
Mon, 3 Feb 2025 20:00:57 +0000 (22:00 +0200)
* Add SVE support for q4_K_q8_K

* Update ggml/src/ggml-cpu/ggml-cpu-quants.c

change to use K_SCALE_SIZE

Co-authored-by: Georgi Gerganov <redacted>
---------

Co-authored-by: Georgi Gerganov <redacted>
ggml/src/ggml-cpu/ggml-cpu-quants.c

index 8e14722667abbc03f906cb96d801e8076592aa3a..88303ff0e61c7b5c8042bd48620d3070b17dbe34 100644 (file)
@@ -5573,7 +5573,88 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     uint32_t utmp[4];
 
-#ifdef __ARM_NEON
+#ifdef __ARM_FEATURE_SVE
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
+
+        memcpy(utmp, x[i].scales, K_SCALE_SIZE);
+
+        uint32x2_t mins8 = { 0 };
+        mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0);
+        mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1);
+
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[0] &= kmask1;
+
+        const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
+        const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
+                                         vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
+        sumf -= dmin * vaddvq_s32(prod);
+
+        const uint8_t * scales = (const uint8_t *)utmp;
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const int vector_length = ggml_cpu_get_sve_cnt()*8;
+        const svuint8_t m4b = svdup_n_u8(0xf);
+        const svint32_t mzero = svdup_n_s32(0);
+        svint32_t sumi1 = svdup_n_s32(0);
+        svint32_t sumi1_1 = svdup_n_s32(0);
+        svint32_t sumi1_2 = svdup_n_s32(0);
+        svint32_t sumi2 = svdup_n_s32(0);
+        svint32_t sumi2_1 = svdup_n_s32(0);
+        svint32_t sumi2_2 = svdup_n_s32(0);
+        switch (vector_length) {
+            case 128:
+                {
+                    for (int j = 0; j < QK_K/64; ++j) {
+                        svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), m4b));
+                        svint8_t q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi1_1 = svmla_n_s32_x(svptrue_b32(), sumi1_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
+                        q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), m4b));
+                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi1_2 = svmla_n_s32_x(svptrue_b32(), sumi1_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
+
+                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), 4));
+                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi2_1 = svmla_n_s32_x(svptrue_b32(), sumi2_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
+                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), 4));
+                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi2_2 = svmla_n_s32_x(svptrue_b32(), sumi2_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
+                        q4 += 32;
+                    }
+                    sumi1 = svadd_s32_x(svptrue_b32(), sumi1_1, sumi1_2);
+                    sumi2 = svadd_s32_x(svptrue_b32(), sumi2_1, sumi2_2);
+                    sumf += d * (svaddv_s32(svptrue_b32(), svadd_s32_x(svptrue_b32(), sumi1, sumi2)));
+                } break;
+            case 256:
+            case 512:
+                {
+                    for (int j = 0; j < QK_K/64; ++j) {
+                        const svuint8_t q4bits  = svld1_u8(svptrue_pat_b8(SV_VL32), q4); q4 += 32;
+                        svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_pat_b8(SV_VL32), q4bits, m4b));
+                        svint8_t q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
+                        sumi1 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
+
+                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q4bits, 4));
+                        q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
+                        sumi2 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
+                    }
+                    sumf += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), sumi1, sumi2)));
+                } break;
+            default:
+                assert(false && "Unsupported vector length");
+                break;
+        }
+    }
+    *s = sumf;
+#elif __ARM_NEON
     const uint8x16_t m4b = vdupq_n_u8(0xf);
     const int32x4_t mzero = vdupq_n_s32(0);