]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : fix NEON sign types (#51)
authorSupreet Sethi <redacted>
Thu, 30 Mar 2023 17:25:29 +0000 (01:25 +0800)
committerGitHub <redacted>
Thu, 30 Mar 2023 17:25:29 +0000 (20:25 +0300)
src/ggml.c
tests/test-mul-mat2.c

index 02675ee67072d7668c3c018733cda1ca18927688..27b246d09bcff30113f0d490eb496a5b067bf34a 100644 (file)
@@ -1038,8 +1038,8 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
             const uint8x16_t vq = vcombine_u8(vx_0, vx_1);
 
             // convert to 2x uint16x8_t
-            const uint16x8_t vi_0 = vmovl_s8(vget_low_u8 (vq));
-            const uint16x8_t vi_1 = vmovl_s8(vget_high_u8(vq));
+            const uint16x8_t vi_0 = vmovl_u8(vget_low_u8 (vq));
+            const uint16x8_t vi_1 = vmovl_u8(vget_high_u8(vq));
 
             // convert to 4x float32x4_t
             const float32x4_t vf_0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_0)));
index be7b038dfb1629b41401b854ee97bc31409cdc4d..cdca93edfb9f8a8411f2de99a1552c7daa380414 100644 (file)
@@ -810,7 +810,7 @@ void quantize_3_row(const float * restrict src, void * restrict dst, int k) {
 
                     for (int b = 0; b < QB; ++b) {
                         uint32x4_t m = vdupq_n_u32(1 << b);
-                        uint32x4_t r = vdupq_n_u32(-b);
+                        int32x4_t r = vdupq_n_s32(-b);
 
                         if (l < 32) {
                             p0[b] = vorrq_u32(p0[b], vshlq_u32(vshlq_u32(vandq_u32(q0, m), r), vld1q_s32(sh + l + 0)));
@@ -1887,24 +1887,24 @@ void vec_dot_gq_5(const int n, float * restrict s, const void * restrict x, cons
         const uint8_t * restrict p0 = pb0 + i*QK/2;
         const uint8_t * restrict p1 = pb1 + i*QK/2;
 
-        const uint8x16_t m4b = vdupq_n_u8(0xf);
+        const int8x16_t m4b = vdupq_n_s8(0xf);
         const int8x16_t  s8b = vdupq_n_s8(0x8);
 
-        const uint8x16_t v0_0 = vld1q_u8(p0);
-        const uint8x16_t v0_1 = vld1q_u8(p0 + 16);
-        const uint8x16_t v1_0 = vld1q_u8(p1);
-        const uint8x16_t v1_1 = vld1q_u8(p1 + 16);
+        const int8x16_t v0_0 = vld1q_s8(p0);
+        const int8x16_t v0_1 = vld1q_s8(p0 + 16);
+        const int8x16_t v1_0 = vld1q_s8(p1);
+        const int8x16_t v1_1 = vld1q_s8(p1 + 16);
 
         // 4-bit -> 8-bit
-        const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
-        const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
-        const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
-        const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
+        const int8x16_t v0_0l = vandq_s8(v0_0, m4b);
+        const int8x16_t v0_1l = vandq_s8(v0_1, m4b);
+        const int8x16_t v1_0l = vandq_s8(v1_0, m4b);
+        const int8x16_t v1_1l = vandq_s8(v1_1, m4b);
 
-        const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
-        const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
-        const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
-        const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
+        const int8x16_t v0_0h = vshrq_n_s8(v0_0, 4);
+        const int8x16_t v0_1h = vshrq_n_s8(v0_1, 4);
+        const int8x16_t v1_0h = vshrq_n_s8(v1_0, 4);
+        const int8x16_t v1_1h = vshrq_n_s8(v1_1, 4);
 
         // sub 8
         const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
@@ -2280,26 +2280,26 @@ void vec_dot_gq_6(const int n, float * restrict s, const void * restrict x, cons
         const uint8_t * restrict p0 = pb0 + i*16;
         const uint8_t * restrict p1 = pb1 + i*16;
 
-        const uint8x16_t m4b = vdupq_n_u8(0xf);
+        const int8x16_t m4b = vdupq_n_s8(0xf);
         const int8x16_t  s8b = vdupq_n_s8(0x8);
 
-        const uint8x16_t v0_0 = vld1q_u8(p0);
-        const uint8x16_t v0_1 = vld1q_u8(p0 + 16);
-        const uint8x16_t v1_0 = vld1q_u8(p1);
-        const uint8x16_t v1_1 = vld1q_u8(p1 + 16);
+        const int8x16_t v0_0 = vld1q_s8(p0);
+        const int8x16_t v0_1 = vld1q_s8(p0 + 16);
+        const int8x16_t v1_0 = vld1q_s8(p1);
+        const int8x16_t v1_1 = vld1q_s8(p1 + 16);
 
         // 4-bit -> 8-bit
-        const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
-        const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
+        const int8x16_t v0_0l = vandq_s8(v0_0, m4b);
+        const int8x16_t v1_0l = vandq_s8(v1_0, m4b);
 
-        const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
-        const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
+        const int8x16_t v0_0h = vshrq_n_s8(v0_0, 4);
+        const int8x16_t v1_0h = vshrq_n_s8(v1_0, 4);
 
-        const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
-        const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
+        const int8x16_t v0_1l = vandq_s8(v0_1, m4b);
+        const int8x16_t v1_1l = vandq_s8(v1_1, m4b);
 
-        const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
-        const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
+        const int8x16_t v0_1h = vshrq_n_s8(v0_1, 4);
+        const int8x16_t v1_1h = vshrq_n_s8(v1_1, 4);
 
         // sub 8
         const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);