From: Supreet Sethi Date: Thu, 30 Mar 2023 17:25:29 +0000 (+0800) Subject: ggml : fix NEON sign types (#51) X-Git-Tag: upstream/0.0.1642~1566 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=652c0c01521e8eb8a977463dce865839259cd02f;p=pkg%2Fggml%2Fsources%2Fggml ggml : fix NEON sign types (#51) --- diff --git a/src/ggml.c b/src/ggml.c index 02675ee6..27b246d0 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -1038,8 +1038,8 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in const uint8x16_t vq = vcombine_u8(vx_0, vx_1); // convert to 2x uint16x8_t - const uint16x8_t vi_0 = vmovl_s8(vget_low_u8 (vq)); - const uint16x8_t vi_1 = vmovl_s8(vget_high_u8(vq)); + const uint16x8_t vi_0 = vmovl_u8(vget_low_u8 (vq)); + const uint16x8_t vi_1 = vmovl_u8(vget_high_u8(vq)); // convert to 4x float32x4_t const float32x4_t vf_0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_0))); diff --git a/tests/test-mul-mat2.c b/tests/test-mul-mat2.c index be7b038d..cdca93ed 100644 --- a/tests/test-mul-mat2.c +++ b/tests/test-mul-mat2.c @@ -810,7 +810,7 @@ void quantize_3_row(const float * restrict src, void * restrict dst, int k) { for (int b = 0; b < QB; ++b) { uint32x4_t m = vdupq_n_u32(1 << b); - uint32x4_t r = vdupq_n_u32(-b); + int32x4_t r = vdupq_n_s32(-b); if (l < 32) { p0[b] = vorrq_u32(p0[b], vshlq_u32(vshlq_u32(vandq_u32(q0, m), r), vld1q_s32(sh + l + 0))); @@ -1887,24 +1887,24 @@ void vec_dot_gq_5(const int n, float * restrict s, const void * restrict x, cons const uint8_t * restrict p0 = pb0 + i*QK/2; const uint8_t * restrict p1 = pb1 + i*QK/2; - const uint8x16_t m4b = vdupq_n_u8(0xf); + const int8x16_t m4b = vdupq_n_s8(0xf); const int8x16_t s8b = vdupq_n_s8(0x8); - const uint8x16_t v0_0 = vld1q_u8(p0); - const uint8x16_t v0_1 = vld1q_u8(p0 + 16); - const uint8x16_t v1_0 = vld1q_u8(p1); - const uint8x16_t v1_1 = vld1q_u8(p1 + 16); + const int8x16_t v0_0 = vld1q_s8(p0); + const int8x16_t v0_1 = vld1q_s8(p0 + 16); + const int8x16_t v1_0 = vld1q_s8(p1); + const int8x16_t v1_1 = vld1q_s8(p1 + 16); // 4-bit -> 8-bit - const uint8x16_t v0_0l = vandq_u8(v0_0, m4b); - const uint8x16_t v0_1l = vandq_u8(v0_1, m4b); - const uint8x16_t v1_0l = vandq_u8(v1_0, m4b); - const uint8x16_t v1_1l = vandq_u8(v1_1, m4b); + const int8x16_t v0_0l = vandq_s8(v0_0, m4b); + const int8x16_t v0_1l = vandq_s8(v0_1, m4b); + const int8x16_t v1_0l = vandq_s8(v1_0, m4b); + const int8x16_t v1_1l = vandq_s8(v1_1, m4b); - const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4); - const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4); - const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4); - const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4); + const int8x16_t v0_0h = vshrq_n_s8(v0_0, 4); + const int8x16_t v0_1h = vshrq_n_s8(v0_1, 4); + const int8x16_t v1_0h = vshrq_n_s8(v1_0, 4); + const int8x16_t v1_1h = vshrq_n_s8(v1_1, 4); // sub 8 const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b); @@ -2280,26 +2280,26 @@ void vec_dot_gq_6(const int n, float * restrict s, const void * restrict x, cons const uint8_t * restrict p0 = pb0 + i*16; const uint8_t * restrict p1 = pb1 + i*16; - const uint8x16_t m4b = vdupq_n_u8(0xf); + const int8x16_t m4b = vdupq_n_s8(0xf); const int8x16_t s8b = vdupq_n_s8(0x8); - const uint8x16_t v0_0 = vld1q_u8(p0); - const uint8x16_t v0_1 = vld1q_u8(p0 + 16); - const uint8x16_t v1_0 = vld1q_u8(p1); - const uint8x16_t v1_1 = vld1q_u8(p1 + 16); + const int8x16_t v0_0 = vld1q_s8(p0); + const int8x16_t v0_1 = vld1q_s8(p0 + 16); + const int8x16_t v1_0 = vld1q_s8(p1); + const int8x16_t v1_1 = vld1q_s8(p1 + 16); // 4-bit -> 8-bit - const uint8x16_t v0_0l = vandq_u8(v0_0, m4b); - const uint8x16_t v1_0l = vandq_u8(v1_0, m4b); + const int8x16_t v0_0l = vandq_s8(v0_0, m4b); + const int8x16_t v1_0l = vandq_s8(v1_0, m4b); - const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4); - const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4); + const int8x16_t v0_0h = vshrq_n_s8(v0_0, 4); + const int8x16_t v1_0h = vshrq_n_s8(v1_0, 4); - const uint8x16_t v0_1l = vandq_u8(v0_1, m4b); - const uint8x16_t v1_1l = vandq_u8(v1_1, m4b); + const int8x16_t v0_1l = vandq_s8(v0_1, m4b); + const int8x16_t v1_1l = vandq_s8(v1_1, m4b); - const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4); - const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4); + const int8x16_t v0_1h = vshrq_n_s8(v0_1, 4); + const int8x16_t v1_1h = vshrq_n_s8(v1_1, 4); // sub 8 const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);