git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml : fix q4_1 dot product types (#759)
author novag <redacted>
Fri, 14 Apr 2023 10:34:20 +0000 (12:34 +0200)
committer GitHub <redacted>
Fri, 14 Apr 2023 10:34:20 +0000 (13:34 +0300)
Co-authored-by: Georgi Gerganov <redacted>
ggml.c

diff --git a/ggml.c b/ggml.c
index 42e3ee314424d5f8dada77990a03c79ef24b6c58..8664054d1f558d7404dcef719ba67e6511f87b97 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -2344,14 +2344,14 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
 
 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
-        int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l);
-        int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l);
+        uint32x4_t p_0 = vdotq_u32(vdupq_n_u32(0), v0_0l, v1_0l);
+        uint32x4_t p_1 = vdotq_u32(vdupq_n_u32(0), v0_1l, v1_1l);
 
-        p_0 = vdotq_s32(p_0, v0_0h, v1_0h);
-        p_1 = vdotq_s32(p_1, v0_1h, v1_1h);
+        p_0 = vdotq_u32(p_0, v0_0h, v1_0h);
+        p_1 = vdotq_u32(p_1, v0_1h, v1_1h);
 
-        sum11 += x0->d*y0->d*vaddvq_s32(p_0);
-        sum11 += x1->d*y1->d*vaddvq_s32(p_1);
+        sum11 += x0->d*y0->d*vaddvq_u32(p_0);
+        sum11 += x1->d*y1->d*vaddvq_u32(p_1);
 #else
         const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
         const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));