]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : use vzip instead of vuzp for consistency
authorGeorgi Gerganov <redacted>
Sat, 29 Apr 2023 18:12:56 +0000 (21:12 +0300)
committerGeorgi Gerganov <redacted>
Sat, 29 Apr 2023 18:12:56 +0000 (21:12 +0300)
ggml.c

diff --git a/ggml.c b/ggml.c
index ebbaf11c620cc87fb661d400979d89b3f5d2c058..c9f0f09ea855b9b4a7c7123b1586e6b5f9b3c661 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -2658,35 +2658,35 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
         const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
 
+        // interleave
+        const int8x16_t v0_0lz = vzip1q_s8(v0_0ls, v0_0hs);
+        const int8x16_t v0_0hz = vzip2q_s8(v0_0ls, v0_0hs);
+        const int8x16_t v0_1lz = vzip1q_s8(v0_1ls, v0_1hs);
+        const int8x16_t v0_1hz = vzip2q_s8(v0_1ls, v0_1hs);
+
         // load y
         const int8x16_t v1_0l = vld1q_s8(y0->qs);
         const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
         const int8x16_t v1_1l = vld1q_s8(y1->qs);
         const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
-        // interleave
-        const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
-        const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
-        const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
-        const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
-
 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
-        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
-        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
+        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l), v0_0hz, v1_0h);
+        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l), v0_1hz, v1_1h);
 
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
 #else
-        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
-        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
-        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
-        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
-
-        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
-        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
-        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
-        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
+
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
 
         const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
         const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));