for (int b = 0; b < QB; ++b) {
uint32x4_t m = vdupq_n_u32(1 << b);
- uint32x4_t r = vdupq_n_u32(-b);
+ int32x4_t r = vdupq_n_s32(-b);
if (l < 32) {
p0[b] = vorrq_u32(p0[b], vshlq_u32(vshlq_u32(vandq_u32(q0, m), r), vld1q_s32(sh + l + 0)));
const uint8_t * restrict p0 = pb0 + i*QK/2;
const uint8_t * restrict p1 = pb1 + i*QK/2;
- const uint8x16_t m4b = vdupq_n_u8(0xf);
+ const int8x16_t m4b = vdupq_n_s8(0xf);
const int8x16_t s8b = vdupq_n_s8(0x8);
- const uint8x16_t v0_0 = vld1q_u8(p0);
- const uint8x16_t v0_1 = vld1q_u8(p0 + 16);
- const uint8x16_t v1_0 = vld1q_u8(p1);
- const uint8x16_t v1_1 = vld1q_u8(p1 + 16);
+ const int8x16_t v0_0 = vld1q_s8(p0);
+ const int8x16_t v0_1 = vld1q_s8(p0 + 16);
+ const int8x16_t v1_0 = vld1q_s8(p1);
+ const int8x16_t v1_1 = vld1q_s8(p1 + 16);
// 4-bit -> 8-bit
- const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
- const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
- const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
- const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
+ const int8x16_t v0_0l = vandq_s8(v0_0, m4b);
+ const int8x16_t v0_1l = vandq_s8(v0_1, m4b);
+ const int8x16_t v1_0l = vandq_s8(v1_0, m4b);
+ const int8x16_t v1_1l = vandq_s8(v1_1, m4b);
- const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
- const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
- const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
- const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
+ const int8x16_t v0_0h = vshrq_n_s8(v0_0, 4);
+ const int8x16_t v0_1h = vshrq_n_s8(v0_1, 4);
+ const int8x16_t v1_0h = vshrq_n_s8(v1_0, 4);
+ const int8x16_t v1_1h = vshrq_n_s8(v1_1, 4);
// sub 8
const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
const uint8_t * restrict p0 = pb0 + i*16;
const uint8_t * restrict p1 = pb1 + i*16;
- const uint8x16_t m4b = vdupq_n_u8(0xf);
+ const int8x16_t m4b = vdupq_n_s8(0xf);
const int8x16_t s8b = vdupq_n_s8(0x8);
- const uint8x16_t v0_0 = vld1q_u8(p0);
- const uint8x16_t v0_1 = vld1q_u8(p0 + 16);
- const uint8x16_t v1_0 = vld1q_u8(p1);
- const uint8x16_t v1_1 = vld1q_u8(p1 + 16);
+ const int8x16_t v0_0 = vld1q_s8(p0);
+ const int8x16_t v0_1 = vld1q_s8(p0 + 16);
+ const int8x16_t v1_0 = vld1q_s8(p1);
+ const int8x16_t v1_1 = vld1q_s8(p1 + 16);
// 4-bit -> 8-bit
- const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
- const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
+ const int8x16_t v0_0l = vandq_s8(v0_0, m4b);
+ const int8x16_t v1_0l = vandq_s8(v1_0, m4b);
- const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
- const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
+ const int8x16_t v0_0h = vshrq_n_s8(v0_0, 4);
+ const int8x16_t v1_0h = vshrq_n_s8(v1_0, 4);
- const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
- const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
+ const int8x16_t v0_1l = vandq_s8(v0_1, m4b);
+ const int8x16_t v1_1l = vandq_s8(v1_1, m4b);
- const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
- const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
+ const int8x16_t v0_1h = vshrq_n_s8(v0_1, 4);
+ const int8x16_t v1_1h = vshrq_n_s8(v1_1, 4);
// sub 8
const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);