]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : speed-up ggml_vec_dot_q4_1() ARM_NEON + 32-bit ARM support (#900)
authorGeorgi Gerganov <redacted>
Thu, 13 Apr 2023 15:32:36 +0000 (18:32 +0300)
committerGitHub <redacted>
Thu, 13 Apr 2023 15:32:36 +0000 (18:32 +0300)
* ggml : speed-up q4_1 ARM_NEON by ~5%

* ggml : implement vaddvq when missing

* ggml : implement vminvq and vmaxvq when missing

* ggml : implement vzip when missing

* ggml : fix comment

* ggml : try to use correct ifdef

ggml.c

diff --git a/ggml.c b/ggml.c
index eb47d8298cae167315e9ca67ae8b79ad83e27018..b6a24b40c52fc6b239dbc643df5e894d2bb698da 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -491,6 +491,77 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
 }
 #endif
 
+#if __ARM_NEON
+
+#if !defined(__aarch64__)
+
+inline static uint16_t vaddvq_u8(uint8x16_t v) {
+    return
+        (uint16_t)vgetq_lane_u8(v, 0)  + (uint16_t)vgetq_lane_u8(v, 1)  +
+        (uint16_t)vgetq_lane_u8(v, 2)  + (uint16_t)vgetq_lane_u8(v, 3)  +
+        (uint16_t)vgetq_lane_u8(v, 4)  + (uint16_t)vgetq_lane_u8(v, 5)  +
+        (uint16_t)vgetq_lane_u8(v, 6)  + (uint16_t)vgetq_lane_u8(v, 7)  +
+        (uint16_t)vgetq_lane_u8(v, 8)  + (uint16_t)vgetq_lane_u8(v, 9)  +
+        (uint16_t)vgetq_lane_u8(v, 10) + (uint16_t)vgetq_lane_u8(v, 11) +
+        (uint16_t)vgetq_lane_u8(v, 12) + (uint16_t)vgetq_lane_u8(v, 13) +
+        (uint16_t)vgetq_lane_u8(v, 14) + (uint16_t)vgetq_lane_u8(v, 15);
+}
+
+inline static int32_t vaddvq_s16(int16x8_t v) {
+    return
+        (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
+        (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
+        (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
+        (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
+}
+
+inline static uint32_t vaddvq_u16(uint16x8_t v) {
+    return
+        (uint32_t)vgetq_lane_u16(v, 0) + (uint32_t)vgetq_lane_u16(v, 1) +
+        (uint32_t)vgetq_lane_u16(v, 2) + (uint32_t)vgetq_lane_u16(v, 3) +
+        (uint32_t)vgetq_lane_u16(v, 4) + (uint32_t)vgetq_lane_u16(v, 5) +
+        (uint32_t)vgetq_lane_u16(v, 6) + (uint32_t)vgetq_lane_u16(v, 7);
+}
+
+inline static int32_t vaddvq_s32(int32x4_t v) {
+    return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
+}
+
+inline static float vaddvq_f32(float32x4_t v) {
+    return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
+}
+
+inline float vminvq_f32(float32x4_t v) {
+    return
+        MIN(MIN(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
+            MIN(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
+}
+
+inline float vmaxvq_f32(float32x4_t v) {
+    return
+        MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
+            MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
+}
+
+inline int8x8_t vzip1_s8(int8x8_t a, int8x8_t b) {
+    return vget_low_s8(vcombine_s8(a, b));
+}
+
+inline int8x8_t vzip2_s8(int8x8_t a, int8x8_t b) {
+    return vget_high_s8(vcombine_s8(a, b));
+}
+
+inline uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
+    return vget_low_u8(vcombine_u8(a, b));
+}
+
+inline uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
+    return vget_high_u8(vcombine_u8(a, b));
+}
+
+#endif
+#endif
+
 // method 5
 // blocks of QK elements
 // represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
@@ -1218,15 +1289,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
 #define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c)
 #define GGML_F32x4_ADD          vaddq_f32
 #define GGML_F32x4_MUL          vmulq_f32
-#if defined(__ARM_FEATURE_QRDMX)
-    #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
-#else
-    #define GGML_F32x4_REDUCE_ONE(x) \
-    (vgetq_lane_f32(x, 0) +          \
-     vgetq_lane_f32(x, 1) +          \
-     vgetq_lane_f32(x, 2) +          \
-     vgetq_lane_f32(x, 3))
-#endif
+#define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
 #define GGML_F32x4_REDUCE(res, x)              \
 {                                              \
     for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
@@ -1849,55 +1912,43 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
         // 4-bit -> 8-bit
         const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
         const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b));
-
         const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
         const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4));
 
         const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
         const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b));
-
         const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
         const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4));
 
         // sub 8
         const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
         const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
-
         const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
         const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
 
         const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
         const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
-
         const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
         const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
 
 #if defined(__ARM_FEATURE_DOTPROD)
-        // dot product into int16x8_t
+        // dot product into int32x4_t
         int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
         int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
 
         p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
         p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
 
-        // scalar
-#if defined(__ARM_FEATURE_QRDMX)
-        sum0 += x0->d * y0->d * vaddvq_s32(p_0);
-        sum1 += x1->d * y1->d * vaddvq_s32(p_1);
-#else
-        sum0 += x0->d * y0->d * (vgetq_lane_s32(p_0, 0) + vgetq_lane_s32(p_0, 1) + vgetq_lane_s32(p_0, 2) + vgetq_lane_s32(p_0, 3));
-        sum1 += x1->d * y1->d * (vgetq_lane_s32(p_1, 0) + vgetq_lane_s32(p_1, 1) + vgetq_lane_s32(p_1, 2) + vgetq_lane_s32(p_1, 3));
-#endif
+        sum0 += x0->d*y0->d*vaddvq_s32(p_0);
+        sum1 += x1->d*y1->d*vaddvq_s32(p_1);
 #else
         const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
         const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
-
         const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
         const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
 
         const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
         const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
-
         const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
         const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
 
@@ -1910,14 +1961,8 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
         const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
         const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
 
-        // scalar
-#if defined(__ARM_FEATURE_QRDMX)
-        sum0 += x0->d * y0->d * vaddvq_s16(p_0);
-        sum1 += x1->d * y1->d * vaddvq_s16(p_1);
-#else
-        sum0 += x0->d * y0->d * (vgetq_lane_s16(p_0, 0) + vgetq_lane_s16(p_0, 1) + vgetq_lane_s16(p_0, 2) + vgetq_lane_s16(p_0, 3) + vgetq_lane_s16(p_0, 4) + vgetq_lane_s16(p_0, 5) + vgetq_lane_s16(p_0, 6) + vgetq_lane_s16(p_0, 7));
-        sum1 += x1->d * y1->d * (vgetq_lane_s16(p_1, 0) + vgetq_lane_s16(p_1, 1) + vgetq_lane_s16(p_1, 2) + vgetq_lane_s16(p_1, 3) + vgetq_lane_s16(p_1, 4) + vgetq_lane_s16(p_1, 5) + vgetq_lane_s16(p_1, 6) + vgetq_lane_s16(p_1, 7));
-#endif
+        sum0 += x0->d*y0->d*vaddvq_s16(p_0);
+        sum1 += x1->d*y1->d*vaddvq_s16(p_1);
 #endif
     }
 
@@ -2265,36 +2310,71 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
     float sum10 = 0.0f;
     float sum11 = 0.0f;
 
-    for (int i = 0; i < nb; ++i) {
+    for (int i = 0; i < nb; i += 2) {
         const block_q4_1 * restrict x0 = &x[i + 0];
         const block_q4_1 * restrict y0 = &y[i + 0];
+        const block_q4_1 * restrict x1 = &x[i + 1];
+        const block_q4_1 * restrict y1 = &y[i + 1];
 
         const uint8x16_t m4b = vdupq_n_u8(0xf);
 
         const uint8x16_t v0_0 = vld1q_u8(x0->qs);
         const uint8x16_t v1_0 = vld1q_u8(y0->qs);
+        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+        const uint8x16_t v1_1 = vld1q_u8(y1->qs);
 
-        // and with 0xf
+        // 4-bit -> 8-bit
         const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
         const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
-
         const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
         const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
 
-        // dot product into uint16x8_t
+        const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
+        const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
+        const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
+        const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
+
+        sum00 += x0->m*y0->m;
+        sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
+        sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
+
+        sum00 += x1->m*y1->m;
+        sum01 += y1->m*x1->d*(vaddvq_u8(v0_1l) + vaddvq_u8(v0_1h));
+        sum10 += x1->m*y1->d*(vaddvq_u8(v1_1l) + vaddvq_u8(v1_1h));
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        // dot product into int32x4_t
+        int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l);
+        int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l);
+
+        p_0 = vdotq_s32(p_0, v0_0h, v1_0h);
+        p_1 = vdotq_s32(p_1, v0_1h, v1_1h);
+
+        sum11 += x0->d*y0->d*vaddvq_s32(p_0);
+        sum11 += x1->d*y1->d*vaddvq_s32(p_1);
+#else
         const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
         const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
-
         const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
         const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
 
-        const uint16x8_t pl0 = vaddq_u16(pl0l, pl0h);
-        const uint16x8_t ph0 = vaddq_u16(ph0l, ph0h);
+        const uint16x8_t pl1l = vmull_u8(vget_low_u8 (v0_1l), vget_low_u8 (v1_1l));
+        const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l));
+        const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h));
+        const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));
 
-        sum00 += x0->m*y0->m;
-        sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
-        sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
-        sum11 += x0->d*y0->d*vaddvq_u16(vaddq_u16(pl0, ph0));
+        const uint16x8_t pl_0 = vaddq_u16(pl0l, pl0h);
+        const uint16x8_t ph_0 = vaddq_u16(ph0l, ph0h);
+
+        const uint16x8_t pl_1 = vaddq_u16(pl1l, pl1h);
+        const uint16x8_t ph_1 = vaddq_u16(ph1l, ph1h);
+
+        const uint16x8_t p_0 = vaddq_u16(pl_0, ph_0);
+        const uint16x8_t p_1 = vaddq_u16(pl_1, ph_1);
+
+        sum11 += x0->d*y0->d*vaddvq_u16(p_0);
+        sum11 += x1->d*y1->d*vaddvq_u16(p_1);
+#endif
     }
 
     sumf = QK*sum00 + sum01 + sum10 + sum11;