]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : add ARM_NEON ggml_vec_dot_q4_1()
authorGeorgi Gerganov <redacted>
Wed, 29 Mar 2023 18:47:33 +0000 (21:47 +0300)
committerGeorgi Gerganov <redacted>
Wed, 29 Mar 2023 19:03:07 +0000 (22:03 +0300)
ggml.c

diff --git a/ggml.c b/ggml.c
index c049f00a939d47401d7a80b4314035522733ff2f..0906cf90ec3787d046314cf5eed08e4facfd9e27 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -2008,6 +2008,45 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
     res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
 
     sumf = _mm_cvtss_f32( res ) + acc_offset * QK;
+#elif defined(__ARM_NEON)
+    float sum00 = 0.0f;
+    float sum01 = 0.0f;
+    float sum10 = 0.0f;
+    float sum11 = 0.0f;
+
+    for (int i = 0; i < nb; ++i) {
+        const block_q4_1 * restrict x0 = &x[i + 0];
+        const block_q4_1 * restrict y0 = &y[i + 0];
+
+        const uint8x16_t m4b = vdupq_n_u8(0xf);
+
+        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+        const uint8x16_t v1_0 = vld1q_u8(y0->qs);
+
+        // and with 0xf
+        const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
+        const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
+
+        const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
+        const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
+
+        // dot product into uint16x8_t
+        const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
+        const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
+
+        const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
+        const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
+
+        const uint16x8_t pl0 = vaddq_u16(pl0l, pl0h);
+        const uint16x8_t ph0 = vaddq_u16(ph0l, ph0h);
+
+        sum00 += x0->m*y0->m;
+        sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
+        sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
+        sum11 += x0->d*y0->d*vaddvq_u16(vaddq_u16(pl0, ph0));
+    }
+
+    sumf = QK*sum00 + sum01 + sum10 + sum11;
 #else
     // scalar
     for (int i = 0; i < nb; i++) {