]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
Better 1.5 bit quantization (llama/5971)
authorKawrakow <redacted>
Mon, 11 Mar 2024 06:51:49 +0000 (07:51 +0100)
committerGeorgi Gerganov <redacted>
Fri, 15 Mar 2024 12:01:12 +0000 (14:01 +0200)
* Trying blocvks of 16 for IQ1_S - seems slightly better

* iq1s_blocks16: Adjust scale fudge factor to 1.125

* iq1s_blocks16: going to blocks of 32

with 2048 lattice points, so same bpw.
This is even better than blocks of 16.
Should I try blocks of 64? But to keep the same
bpw, when I go to 4096 lattice points, I need to
remove blocks alltogether and just have superblocks of
256 weights.

* iq1s_blocks16: Use 2*<x^2> as sigma2 in weight adjustment

* iq1s_blocks16: scalar and AVX2 dot products

* iq1s_blocks16: CUDA dot product

* iq1s_blocks16: Metal works, Neon does not

Metal works but TG is dog slow (35 t/s). PP is OKish (493 t/s).
Not seeing the bug in the Neon implementation for now.

* iq1s_blocks16: fixed Neon

* iq1s_blocks16: very slightly faster TG on Metal

Still pathetic at 37 t/s

* iq1s_blocks16: speedup Metal by packing codebook into uint32_t's

* Formatting

* iq1s_blocks16: uint32_t codebook is also better in CUDA

TG-128 is now 204 t/s up from 194 t/s.
PP-512 is 5890 t/s, so significantly better than other quants

* iq1s_blocks16: slightly faster Neon dot product

* iq1s_blocks16: faster AVX2 dot product

* iq1s_blocks16: adjust to ggml-common.h

---------

Co-authored-by: Iwan Kawrakow <redacted>
ggml-cuda.cu
ggml-metal.metal
ggml-quants.c
ggml-quants.h

index c207ff87a281a552a9d8709cf4fbf27e5969a255..d2945d3c2048db29e09e85c0c3312d9c695e50ba 100644 (file)
@@ -565,8 +565,8 @@ static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 13*(QK_K/32) + IQ3S_N
 #define QI1_S (QK_K / (4*QR1_S))
 typedef struct {
     half d;
-    uint8_t qs[QK_K/8];
-    uint8_t scales[QK_K/16];
+    uint8_t  qs[QK_K/8];
+    uint16_t qh[QK_K/32];
 } block_iq1_s;
 static_assert(sizeof(block_iq1_s) == sizeof(ggml_fp16_t) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding");
 
@@ -1722,11 +1722,22 @@ static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_
     const int il = tid/8; // 0...3
     const int ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
-    const int i8 = 4*ib+il;
-    uint8_t h = x[i].scales[i8/2] >> 4*(i8%2);
-    const int8_t * grid = (const int8_t *)(iq1s_grid + (x[i].qs[i8] | ((h & 8) << 5)));
-    const float d = (float)x[i].d * (2*(h & 7) + 1);
-    for (int j = 0; j < 8; ++j) y[j] = d * grid[j];
+    const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 0xf) + 1);
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    int grid32[2]; const int8_t * q = (const int8_t *)grid32;
+    grid32[0] = *((const int *)(iq1s_grid_gpu + (x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8))));
+    grid32[1] = __vsub4((grid32[0] >>  4) & 0x0f0f0f0f, 0x01010101);
+    grid32[0] = __vsub4(grid32[0] & 0x0f0f0f0f, 0x01010101);
+    for (int j = 0; j < 8; ++j) {
+        y[j] = d * q[j];
+    }
+#else
+    const uint8_t * grid = (const uint8_t *)(iq1s_grid_gpu + (x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)));
+    for (int j = 0; j < 4; ++j) {
+        y[j+0] = d * ((grid[j] & 0xf) - 1);
+        y[j+4] = d * ((grid[j] >>  4) - 1);
+    }
+#endif
 #else
     assert(false);
 #endif
@@ -4538,44 +4549,33 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
 #endif
 }
 
-
 static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
 #if QK_K == 256
     const block_iq1_s * bq1 = (const block_iq1_s *) vbq;
 
     const int ib32 = iqs;
-    int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;
-    const uint8_t h1 = bq1->scales[2*ib32+0];
-    const uint8_t h2 = bq1->scales[2*ib32+1];
+    int sumi = 0;
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const int * q8 = (const int *)bq8_1[ib32].qs;
-    const int * grid1 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+0] | ((h1 & 0x08) << 5)));
-    const int * grid2 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+1] | ((h1 & 0x80) << 1)));
-    const int * grid3 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+2] | ((h2 & 0x08) << 5)));
-    const int * grid4 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+3] | ((h2 & 0x80) << 1)));
-    for (int j = 0; j < 2; ++j) {
-        sumi1 = __dp4a(q8[j+0], grid1[j], sumi1);
-        sumi2 = __dp4a(q8[j+2], grid2[j], sumi2);
-        sumi3 = __dp4a(q8[j+4], grid3[j], sumi3);
-        sumi4 = __dp4a(q8[j+6], grid4[j], sumi4);
+    for (int l = 0; l < 4; ++l) {
+        const int * grid = (const int *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[ib32] >> 3*l) & 7) << 8)));
+        int grid0 = __vsub4(grid[0] & 0x0f0f0f0f, 0x01010101);
+        int grid1 = __vsub4((grid[0] >> 4) & 0x0f0f0f0f, 0x01010101);
+        sumi = __dp4a(q8[2*l+1], grid1, __dp4a(q8[2*l+0], grid0, sumi));
     }
 #else
     const int8_t   * q8 = bq8_1[ib32].qs;
-    const int8_t * grid1 = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+0] | ((h1 & 0x08) << 5)));
-    const int8_t * grid2 = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+1] | ((h1 & 0x80) << 1)));
-    const int8_t * grid3 = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+2] | ((h2 & 0x08) << 5)));
-    const int8_t * grid4 = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+3] | ((h2 & 0x80) << 1)));
-    for (int j = 0; j < 8; ++j) {
-        sumi1 += q8[j+ 0] * grid1[j];
-        sumi2 += q8[j+ 8] * grid2[j];
-        sumi3 += q8[j+16] * grid3[j];
-        sumi4 += q8[j+24] * grid4[j];
+    for (int l = 0; l < 4; ++l) {
+        const uint8_t * grid = (const uint8_t *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[ib32] >> 3*l) & 7) << 8)));
+        for (int j = 0; j < 4; ++j) {
+            sumi += q8[j] * ((grid[j] & 0xf) - 1) + q8[j+4] * ((grid[j] >>  4) - 1);
+        }
+        q8 += 8;
     }
 #endif
     const float d = (float)bq1->d * __low2float(bq8_1[ib32].ds);
-    return d * (sumi1 * (2*(h1 & 7) + 1) + sumi2 * (2*((h1 >> 4) & 7) + 1) +
-                sumi3 * (2*(h2 & 7) + 1) + sumi4 * (2*((h2 >> 4) & 7) + 1));
+    return d * sumi * (2*(bq1->qh[ib32] >> 12) + 1);
 #else
     assert(false);
     return 0.f;
index 50185ae4dea09567aaad97ddc4f0f1a20a1002db..912822ee64bc3deafe16f08af55a253988701557 100644 (file)
@@ -2595,8 +2595,8 @@ typedef struct {
 
 typedef struct {
     half d;
-    uint8_t qs[QK_K/8];
-    uint8_t scales[QK_K/16];
+    uint8_t  qs[QK_K/8];
+    uint16_t qh[QK_K/32];
 } block_iq1_s;
 
 // Non-linear quants
@@ -4338,48 +4338,53 @@ void kernel_mul_mv_iq1_s_f32_impl(
     device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0;
     device const float       * y = (device const float       *) src1 + r1*ne10 + im*ne00*ne1;
 
-    float yl[16];
+    float yl[32];
     float sumf[N_DST]={0.f}, all_sum;
 
     const int nb32 = nb * (QK_K / 32);
 
-    const int ix = tiisg/2;
-    const int il = tiisg%2;
+    const int ix = tiisg;
 
-    device const float * y4 = y + 32 * ix + 16 * il;
+    device const float * y4 = y + 32 * ix;
 
-    for (int ib32 = ix; ib32 < nb32; ib32 += 16) {
+    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
 
-        for (int i = 0; i < 16; ++i) {
+        float sumy = 0;
+        for (int i = 0; i < 32; ++i) {
             yl[i] = y4[i];
+            sumy += yl[i];
         }
 
         const int ibl = ib32 / (QK_K / 32);
         const int ib  = ib32 % (QK_K / 32);
 
         device const block_iq1_s * xr = x + ibl;
-        device const uint8_t * qs = xr->qs + 4 * ib + 2 * il;
-        device const uint8_t * sc = xr->scales + 2 * ib + il;
-        device const half    * dh = &xr->d;
+        device const uint8_t  * qs = xr->qs + 4 * ib;
+        device const uint16_t * qh = xr->qh + ib;
+        device const half     * dh = &xr->d;
 
         for (int row = 0; row < N_DST; row++) {
 
-            constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5)));
-            constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1)));
+            constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
+            constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700)));
+            constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700)));
+            constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700)));
 
-            float2 sum = {0};
-            for (int j = 0; j < 8; ++j) {
-                sum[0] += yl[j+ 0] * grid1[j];
-                sum[1] += yl[j+ 8] * grid2[j];
+            float sum = 0;
+            for (int j = 0; j < 4; ++j) {
+                sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
+                     + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4)
+                     + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
+                     + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
             }
-            sumf[row] += (float)dh[0] * (sum[0] * (2*(sc[0] & 7) + 1) + sum[1] * (2*((sc[0] >> 4) & 7) + 1));
+            sumf[row] += (float)dh[0] * (sum - sumy) * (2*(qh[0] >> 12) + 1);
 
             dh += nb*sizeof(block_iq1_s)/2;
             qs += nb*sizeof(block_iq1_s);
-            sc += nb*sizeof(block_iq1_s);
+            qh += nb*sizeof(block_iq1_s)/2;
         }
 
-        y4 += 16 * 32;
+        y4 += 32 * 32;
     }
 
     for (int row = 0; row < N_DST; ++row) {
@@ -5066,16 +5071,19 @@ void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 &
 template <typename type4x4>
 void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
     // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const int ib32 = il/2;
+    il = il%2;
     const float d = xb->d;
-    device const uint8_t * qs = xb->qs + 2*il;
-    device const uint8_t * sc = xb->scales + il;
-    const float dl1 = d * (2*(sc[0] & 7) + 1);
-    const float dl2 = d * (2*((sc[0] >> 4) & 7) + 1);
-    constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5)));
-    constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1)));
-    for (int i = 0; i < 8; ++i) {
-        reg[i/4+0][i%4] = dl1 * grid1[i];
-        reg[i/4+2][i%4] = dl2 * grid2[i];
+    device const uint8_t  * qs = xb->qs + 4*ib32 + 2*il;
+    device const uint16_t * qh = xb->qh;
+    const float dl = d * (2*(qh[ib32] >> 12) + 1);
+    constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | (((qh[ib32] >> (6*il+0)) & 7) << 8)));
+    constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | (((qh[ib32] >> (6*il+3)) & 7) << 8)));
+    for (int i = 0; i < 4; ++i) {
+        reg[0][i] = dl * (grid1[i] & 0xf) - dl;
+        reg[1][i] = dl * (grid1[i] >>  4) - dl;
+        reg[2][i] = dl * (grid2[i] & 0xf) - dl;
+        reg[3][i] = dl * (grid2[i] >>  4) - dl;
     }
 }
 
index 42d8a5d8051440317c9332693e0a3dc2f2112ed0..f9a3d9fd229e186315d392502b7afa417e03a333 100644 (file)
@@ -3449,39 +3449,22 @@ void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, in
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
 
-    float db[4];
-    uint16_t idx[4];
-    //const int8_t * grid[4];
-
     for (int i = 0; i < nb; i++) {
 
         const float d = GGML_FP16_TO_FP32(x[i].d);
-        const uint8_t * sc = x[i].scales;
-        const uint8_t * qs = x[i].qs;
+        const uint8_t  * qs = x[i].qs;
+        const uint16_t * qh = x[i].qh;
 
-        for (int i8 = 0; i8 < QK_K/8; i8 += 4) {
-            idx[0] = qs[0] | ((sc[0] & 0x08) << 5);
-            idx[1] = qs[1] | ((sc[0] & 0x80) << 1);
-            idx[2] = qs[2] | ((sc[1] & 0x08) << 5);
-            idx[3] = qs[3] | ((sc[1] & 0x80) << 1);
-            //grid[0] = (const int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5)));
-            //grid[1] = (const int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1)));
-            //grid[2] = (const int8_t *)(iq1s_grid + (qs[2] | ((sc[1] & 0x08) << 5)));
-            //grid[3] = (const int8_t *)(iq1s_grid + (qs[3] | ((sc[1] & 0x80) << 1)));
-            db[0] = d * (2*(sc[0] & 7) + 1);
-            db[1] = d * (2*((sc[0] >> 4) & 7) + 1);
-            db[2] = d * (2*(sc[1] & 7) + 1);
-            db[3] = d * (2*((sc[1] >> 4) & 7) + 1);
+        for (int ib = 0; ib < QK_K/32; ++ib) {
+            const float dl = d * (2*(qh[ib] >> 12) + 1);
             for (int l = 0; l < 4; ++l) {
-                const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]);
+                const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8)));
                 for (int j = 0; j < 8; ++j) {
-                    //y[j] = db[l] * grid[l][j];
-                    y[j] = db[l] * grid[j];
+                    y[j] = dl * grid[j];
                 }
                 y += 8;
             }
             qs += 4;
-            sc += 2;
         }
     }
 }
@@ -9587,113 +9570,72 @@ void ggml_vec_dot_iq1_s_q8_K  (int n, float * restrict s, size_t bs, const void
 
     const int nb = n / QK_K;
 
-    // TODO: implement for QK_K = 64
-#if defined __ARM_NEON && QK_K == 256
-
-    const uint8x16_t m8 = vdupq_n_u8(0x08);
-    const uint8x16_t m7 = vdupq_n_u8(0x07);
-    const uint8x16_t m1 = vdupq_n_u8(0x01);
-    const int32x4_t vzero = vdupq_n_s32(0);
+#if defined __ARM_NEON
 
-    uint16_t gindex[8];
-    uint16x8x2_t vindex;
-    int8x16x4_t q1b;
+    ggml_int8x16x4_t q1b;
     ggml_int8x16x4_t q8b;
-    uint16x8x4_t scales;
-    int32x4x2_t sumi;
-    int32x4x2_t dotq;
 
     float sumf = 0;
     for (int i = 0; i < nb; ++i) {
 
-        const int8_t  * q8 = y[i].qs;
-        const uint8_t * qs = x[i].qs;
-        const uint8_t * sc = x[i].scales;
+        const int8_t   * q8 = y[i].qs;
+        const uint8_t  * qs = x[i].qs;
+        const uint16_t * qh = x[i].qh;
 
-        sumi.val[0] = sumi.val[1] = vzero;
+        int sumi1 = 0, sumi2 = 0;
 
-        for (int i128 = 0; i128 < QK_K/128; ++i128) {
-            const uint8x16_t ql = vld1q_u8(qs); qs += 16;
-            const uint8x8_t tm1 = vld1_u8 (sc); sc +=  8;
-            const uint8x8_t tm2 = vshr_n_u8(tm1, 4);
-            const uint8x16_t qh = vcombine_u8(vzip1_u8(tm1, tm2), vzip2_u8(tm1, tm2));
-            const uint8x16_t hbit = vandq_u8(qh, m8);
-            vindex.val[0] = vorrq_u16(vmovl_u8(vget_low_u8 (ql)), vshlq_n_u16(vmovl_u8(vget_low_u8 (hbit)), 5));
-            vindex.val[1] = vorrq_u16(vmovl_u8(vget_high_u8(ql)), vshlq_n_u16(vmovl_u8(vget_high_u8(hbit)), 5));
-            const uint8x16_t scales8 = vorrq_u8(vshlq_n_u8(vandq_u8(qh, m7), 1), m1);
-            scales.val[0] = vmovl_u8(vget_low_u8 (scales8));
-            scales.val[1] = vmovl_u8(vget_high_u8 (scales8));
+        for (int ib = 0; ib < QK_K/32; ib += 2) {
 
-            for (int l = 0; l < 2; ++l) {
-                vst1q_u16(gindex+0, vindex.val[l]);
-                q1b.val[0] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[0])), vld1_s8((const void *)(iq1s_grid+gindex[1])));
-                q1b.val[1] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[2])), vld1_s8((const void *)(iq1s_grid+gindex[3])));
-                q1b.val[2] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[4])), vld1_s8((const void *)(iq1s_grid+gindex[5])));
-                q1b.val[3] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[6])), vld1_s8((const void *)(iq1s_grid+gindex[7])));
-                q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
+            q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((qh[ib+0] << 8) & 0x700)))),
+                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((qh[ib+0] << 5) & 0x700)))));
+            q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((qh[ib+0] << 2) & 0x700)))),
+                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((qh[ib+0] >> 1) & 0x700)))));
+            q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((qh[ib+1] << 8) & 0x700)))),
+                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((qh[ib+1] << 5) & 0x700)))));
+            q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((qh[ib+1] << 2) & 0x700)))),
+                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | ((qh[ib+1] >> 1) & 0x700)))));
+            qs += 8;
+
+            q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
 
-                dotq.val[0] = vpaddq_s32(ggml_vdotq_s32(vzero, q1b.val[0], q8b.val[0]), ggml_vdotq_s32(vzero, q1b.val[1], q8b.val[1]));
-                dotq.val[1] = vpaddq_s32(ggml_vdotq_s32(vzero, q1b.val[2], q8b.val[2]), ggml_vdotq_s32(vzero, q1b.val[3], q8b.val[3]));
+            const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[0], q8b.val[0]), q1b.val[1], q8b.val[1]);
+            const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[2], q8b.val[2]), q1b.val[3], q8b.val[3]);
+
+            sumi1 += vaddvq_s32(p1) * (2*(qh[ib+0] >> 12) + 1);
+            sumi2 += vaddvq_s32(p2) * (2*(qh[ib+1] >> 12) + 1);
 
-                sumi.val[0] = vmlaq_s32(sumi.val[0], dotq.val[0], vreinterpretq_s32_u32(vmovl_u16(vget_low_u16 (scales.val[l]))));
-                sumi.val[1] = vmlaq_s32(sumi.val[1], dotq.val[1], vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales.val[l]))));
-            }
         }
 
-        sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d) * vaddvq_s32(vaddq_s32(sumi.val[0], sumi.val[1]));
+        sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d) * (sumi1 + sumi2);
     }
 
     *s = sumf;
 
-    // TODO: implement for QK_K = 64
-#elif defined __AVX2__ && QK_K == 256
-
-    const __m128i m8 = _mm_set1_epi8(0x08);
-    const __m128i m7 = _mm_set1_epi8(0x07);
-    const __m128i m1 = _mm_set1_epi8(0x01);
-    const __m128i shuffle_h = _mm_set_epi8(15, 7, 14, 6, 13, 5, 12, 4, 11, 3, 10, 2, 9, 1, 8, 0);
-    const __m128i shuffle_s[4] = {
-        _mm_set_epi32(0x03030303, 0x02020202, 0x01010101, 0x00000000),
-        _mm_set_epi32(0x07070707, 0x06060606, 0x05050505, 0x04040404),
-        _mm_set_epi32(0x0b0b0b0b, 0x0a0a0a0a, 0x09090909, 0x08080808),
-        _mm_set_epi32(0x0f0f0f0f, 0x0e0e0e0e, 0x0d0d0d0d, 0x0c0c0c0c)
-    };
-
-    uint64_t aux64;
-
-    typedef union m256i_uint16 {
-        __m256i reg;
-        uint16_t s[16];
-    } m256i_uint16_t;
-
-    m256i_uint16_t v_gindex;
+#elif defined __AVX2__
 
     __m256 accum = _mm256_setzero_ps();
     for (int i = 0; i < nb; ++i) {
 
-        const int8_t  * q8 = y[i].qs;
-        const uint8_t * qs = x[i].qs;
-        const uint8_t * sc = x[i].scales;
+        const int8_t   * q8 = y[i].qs;
+        const uint8_t  * qs = x[i].qs;
+        const uint16_t * qh = x[i].qh;
 
         __m256i sumi = _mm256_setzero_si256();
-        for (int i128 = 0; i128 < QK_K/128; ++i128) {
-            const __m128i ql = _mm_loadu_si128((const __m128i*)qs); qs += 16;
-            memcpy(&aux64, sc, 8); sc += 8;
-            const __m128i qh = _mm_shuffle_epi8(_mm_set_epi64x(aux64 >> 4, aux64), shuffle_h);
-            const __m256i hbit = _mm256_cvtepu8_epi16(_mm_and_si128(qh, m8));
-            v_gindex.reg = _mm256_or_si256(_mm256_cvtepu8_epi16(ql), _mm256_slli_epi16(hbit, 5));
-            const __m128i scales = _mm_or_si128(_mm_slli_epi16(_mm_and_si128(qh, m7), 1), m1);
+        for (int ib = 0; ib < QK_K/32; ib += 2) {
+            const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)],
+                                                    iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]);
+            const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)],
+                                                    iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]);
+            qs += 8;
+            const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
 
-            for (int i32 = 0; i32 < 4; ++i32) {
-                const __m256i q8b = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
-                const __m256i q1b = _mm256_set_epi64x(iq1s_grid[v_gindex.s[4*i32+3]], iq1s_grid[v_gindex.s[4*i32+2]],
-                                                      iq1s_grid[v_gindex.s[4*i32+1]], iq1s_grid[v_gindex.s[4*i32+0]]);
-                const __m256i dot = mul_add_epi8(q1b, q8b);
-                const __m256i s16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, shuffle_s[i32]));
-                const __m256i p   = _mm256_madd_epi16(s16, dot);
-                sumi = _mm256_add_epi32(sumi, p);
-            }
+            const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1);
+            const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2);
+            const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*(qh[ib+0] >> 12) + 1));
+            const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*(qh[ib+1] >> 12) + 1));
 
+            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p1, p2));
         }
 
         accum = _mm256_fmadd_ps(_mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d)), _mm256_cvtepi32_ps(sumi), accum);
@@ -9704,35 +9646,26 @@ void ggml_vec_dot_iq1_s_q8_K  (int n, float * restrict s, size_t bs, const void
 
 #else
 
-    int db[4];
-    uint16_t idx[4];
-
     float sumf = 0;
-    for (int i = 0; i < nb; ++i) {
+    for (int i = 0; i < nb; i++) {
 
-        const int8_t  * q8 = y[i].qs;
-        const uint8_t * qs = x[i].qs;
-        const uint8_t * sc = x[i].scales;
+        const int8_t   * q8 = y[i].qs;
+        const uint8_t  * qs = x[i].qs;
+        const uint16_t * qh = x[i].qh;
 
         int sumi = 0;
-        for (int i32 = 0; i32 < QK_K/32; ++i32) {
-            idx[0] = qs[0] | ((sc[0] & 0x08) << 5);
-            idx[1] = qs[1] | ((sc[0] & 0x80) << 1);
-            idx[2] = qs[2] | ((sc[1] & 0x08) << 5);
-            idx[3] = qs[3] | ((sc[1] & 0x80) << 1);
-            db[0] = (2*(sc[0] & 7) + 1);
-            db[1] = (2*((sc[0] >> 4) & 7) + 1);
-            db[2] = (2*(sc[1] & 7) + 1);
-            db[3] = (2*((sc[1] >> 4) & 7) + 1);
+        for (int ib = 0; ib < QK_K/32; ++ib) {
+            const int ls = 2*(qh[ib] >> 12) + 1;
+            int lsum = 0;
             for (int l = 0; l < 4; ++l) {
-                const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]);
-                int suml = 0;
-                for (int j = 0; j < 8; ++j) suml += q8[j] * grid[j];
-                sumi += db[l] * suml;
+                const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8)));
+                for (int j = 0; j < 8; ++j) {
+                    lsum += q8[j] * grid[j];
+                }
                 q8 += 8;
             }
+            sumi += ls * lsum;
             qs += 4;
-            sc += 2;
         }
 
         sumf += GGML_FP16_TO_FP32(x[i].d) * y[i].d * sumi;
@@ -9996,7 +9929,7 @@ static inline int iq2_grid_size(enum ggml_type type) {
     GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ2_S);
     return type == GGML_TYPE_IQ2_XXS ? 256 :
            type == GGML_TYPE_IQ2_XS  ? 512 :
-           type == GGML_TYPE_IQ1_S   ? 512 : 1024;
+           type == GGML_TYPE_IQ1_S   ? NGRID_IQ1S : 1024;
 }
 
 static int iq2_compare_func(const void * left, const void * right) {
@@ -10063,39 +9996,135 @@ void iq2xs_init_impl(enum ggml_type type) {
         40962, 40968, 40970, 40992, 41002, 41120, 41297, 41305, 41382, 41472, 41474, 41480, 41514, 41600, 41632, 42048,
         42133, 42597, 42648, 43018, 43040, 43042, 43048, 43168, 43176, 43268, 43396, 43398, 43560, 43562, 43665, 43690,
     };
-    static const uint16_t kgrid_1bit_512[512] = {
-           10,    33,    41,    85,   132,   134,   160,   162,   277,   337,   340,   345,   357,   405,   516,   545,
-          553,   598,   641,   650,   681,  1042,  1044,  1097,  1169,  1176,  1320,  1345,  1365,  1378,  1434,  1444,
-         1545,  1617,  1642,  1685,  2053,  2080,  2089,  2133,  2176,  2182,  2208,  2214,  2306,  2384,  2393,  2440,
-         2453,  2581,  2664,  2690,  2721,  4117,  4161,  4182,  4184,  4261,  4357,  4369,  4372,  4377,  4390,  4422,
-         4432,  4437,  4449,  4457,  4485,  4497,  4505,  4629,  4677,  4696,  4774,  5205,  5217,  5225,  5386,  5397,
-         5409,  5445,  5457,  5460,  5461,  5462,  5465,  5472,  5477,  5525,  5545,  5650,  5668,  5717,  5729,  5769,
-         5777,  6212,  6234,  6244,  6293,  6424,  6482,  6485,  6502,  6505,  6529,  6538,  6565,  6656,  6682,  6788,
-         6806,  6820,  8218,  8224,  8226,  8232,  8277,  8326,  8354,  8469,  8521,  8530,  8549,  8596,  8737,  8794,
-         9221,  9253,  9348,  9369,  9380,  9474,  9557,  9633,  9732,  9753,  9793,  9830,  9862,  9880, 10240, 10272,
-        10282, 10321, 10406, 10517, 10530, 10566, 10585, 10645, 10896, 16466, 16468, 16473, 16485, 16646, 16660, 16665,
-        16725, 16793, 16806, 16914, 16969, 16977, 16996, 17028, 17057, 17408, 17416, 17434, 17493, 17512, 17578, 17685,
-        17696, 17733, 17745, 17748, 17749, 17750, 17753, 17765, 17794, 17813, 17946, 17984, 18005, 18072, 18453, 18529,
-        18569, 18722, 18756, 18762, 18773, 18794, 18833, 18853, 18945, 19026, 19033, 19077, 20489, 20497, 20500, 20517,
-        20565, 20586, 20610, 20633, 20757, 20769, 20776, 20805, 20817, 20820, 20821, 20822, 20825, 20837, 20864, 20872,
-        20885, 20896, 21002, 21029, 21077, 21146, 21510, 21525, 21573, 21585, 21588, 21589, 21590, 21593, 21605, 21653,
-        21665, 21765, 21777, 21780, 21781, 21782, 21785, 21797, 21825, 21828, 21829, 21830, 21833, 21840, 21841, 21842,
-        21844, 21846, 21848, 21849, 21850, 21857, 21860, 21861, 21862, 21865, 21893, 21905, 21908, 21909, 21910, 21913,
-        21925, 22024, 22037, 22085, 22097, 22100, 22101, 22102, 22105, 22117, 22165, 22545, 22566, 22568, 22594, 22608,
-        22613, 22676, 22697, 22793, 22805, 22853, 22865, 22868, 22869, 22870, 22873, 22885, 22933, 22946, 23046, 23072,
-        23125, 23209, 24597, 24640, 24665, 24673, 24725, 24833, 24840, 24869, 24917, 24934, 24965, 25001, 25108, 25110,
-        25152, 25184, 25192, 25234, 25616, 25618, 25625, 25685, 25704, 25738, 25744, 25770, 25877, 25897, 25925, 25937,
-        25940, 25941, 25942, 25945, 25957, 25986, 26005, 26186, 26197, 26276, 26632, 26634, 26725, 26757, 26770, 26885,
-        26965, 26976, 26986, 27032, 27153, 27174, 27200, 27208, 27240, 27269, 27282, 27290, 32778, 32800, 32802, 32808,
-        32810, 32853, 32904, 32922, 32930, 32932, 33105, 33110, 33112, 33125, 33157, 33280, 33288, 33301, 33312, 33320,
-        33424, 33797, 33829, 33858, 34068, 34133, 34146, 34176, 34217, 34306, 34342, 34441, 34454, 34468, 34832, 34918,
-        34965, 34984, 35094, 35137, 35161, 35208, 35232, 35332, 35338, 35368, 35429, 36932, 36934, 36953, 37009, 37125,
-        37136, 37138, 37145, 37157, 37205, 37220, 37258, 37290, 37444, 37446, 37465, 37478, 37525, 37905, 37968, 37973,
-        38040, 38054, 38145, 38154, 38165, 38180, 38186, 38213, 38225, 38228, 38229, 38230, 38233, 38245, 38293, 38485,
-        38504, 38530, 38938, 38985, 38993, 39012, 39040, 39173, 39192, 39253, 39265, 39301, 39316, 39322, 39442, 39497,
-        39504, 39590, 40970, 40984, 40992, 41002, 41045, 41120, 41128, 41237, 41289, 41297, 41317, 41364, 41366, 41514,
-        41557, 41633, 41989, 42021, 42056, 42068, 42074, 42113, 42242, 42265, 42274, 42325, 42340, 42402, 42501, 42512,
-        42533, 42624, 42632, 42666, 43040, 43093, 43106, 43168, 43176, 43264, 43286, 43345, 43429, 43590, 43618, 43680,
+    static const uint16_t kgrid_1bit_2048[NGRID_IQ1S] = {
+            0,     2,     5,     8,    10,    17,    21,    32,    34,    40,    42,    69,    81,    84,    86,   101,
+          128,   130,   136,   138,   149,   160,   162,   168,   170,   260,   261,   273,   276,   278,   281,   282,
+          293,   321,   326,   329,   338,   341,   346,   353,   356,   358,   360,   389,   401,   404,   406,   421,
+          512,   514,   520,   522,   533,   544,   546,   552,   554,   581,   593,   601,   612,   617,   640,   642,
+          648,   650,   657,   661,   665,   672,   674,   680,   682,  1041,  1044,  1046,  1061,  1089,  1097,  1109,
+         1114,  1124,  1125,  1169,  1177,  1189,  1281,  1284,  1285,  1286,  1301,  1304,  1306,  1321,  1344,  1349,
+         1354,  1360,  1361,  1364,  1365,  1366,  1369,  1376,  1378,  1381,  1384,  1386,  1409,  1425,  1429,  1432,
+         1434,  1441,  1444,  1445,  1446,  1449,  1556,  1561,  1601,  1604,  1616,  1618,  1621,  1624,  1632,  1633,
+         1638,  1641,  1669,  1681,  1684,  1689,  2048,  2050,  2056,  2058,  2069,  2080,  2082,  2088,  2090,  2117,
+         2129,  2134,  2149,  2176,  2178,  2184,  2186,  2197,  2208,  2210,  2216,  2218,  2309,  2321,  2324,  2329,
+         2340,  2341,  2369,  2384,  2385,  2389,  2401,  2404,  2409,  2449,  2452,  2454,  2457,  2469,  2560,  2562,
+         2568,  2570,  2581,  2592,  2594,  2600,  2602,  2629,  2641,  2649,  2657,  2661,  2688,  2690,  2693,  2696,
+         2698,  2709,  2720,  2722,  2728,  2730,  4112,  4113,  4116,  4121,  4132,  4133,  4161,  4164,  4176,  4181,
+         4184,  4193,  4196,  4197,  4201,  4241,  4244,  4246,  4257,  4261,  4353,  4356,  4358,  4361,  4368,  4370,
+         4373,  4376,  4385,  4388,  4393,  4421,  4426,  4432,  4433,  4434,  4436,  4437,  4438,  4441,  4448,  4453,
+         4484,  4498,  4501,  4513,  4516,  4625,  4628,  4630,  4645,  4672,  4678,  4681,  4690,  4693,  4696,  4698,
+         4708,  4710,  4741,  4753,  4756,  4758,  4773,  5121,  5126,  5129,  5140,  5141,  5144,  5145,  5153,  5158,
+         5185,  5189,  5190,  5192,  5194,  5201,  5204,  5205,  5206,  5209,  5218,  5221,  5224,  5252,  5257,  5264,
+         5268,  5269,  5272,  5273,  5274,  5281,  5284,  5285,  5289,  5378,  5381,  5386,  5393,  5396,  5397,  5398,
+         5401,  5408,  5410,  5413,  5416,  5418,  5441,  5444,  5445,  5446,  5457,  5458,  5460,  5461,  5462,  5465,
+         5466,  5473,  5476,  5477,  5478,  5481,  5504,  5506,  5508,  5509,  5512,  5514,  5520,  5521,  5524,  5525,
+         5526,  5529,  5530,  5536,  5538,  5541,  5633,  5636,  5637,  5638,  5653,  5654,  5656,  5658,  5665,  5670,
+         5696,  5698,  5700,  5701,  5704,  5706,  5713,  5717,  5718,  5720,  5721,  5729,  5732,  5733,  5736,  5737,
+         5738,  5766,  5770,  5778,  5781,  5796,  5801,  6161,  6166,  6181,  6209,  6212,  6214,  6217,  6224,  6229,
+         6232,  6234,  6240,  6241,  6244,  6246,  6249,  6277,  6289,  6292,  6309,  6416,  6418,  6421,  6426,  6433,
+         6437,  6466,  6468,  6469,  6472,  6481,  6484,  6485,  6486,  6489,  6490,  6496,  6501,  6506,  6537,  6545,
+         6546,  6549,  6552,  6561,  6566,  6569,  6665,  6678,  6692,  6694,  6724,  6726,  6729,  6736,  6738,  6741,
+         6744,  6753,  6758,  6761,  6789,  6801,  6806,  6810,  8192,  8194,  8200,  8202,  8213,  8224,  8226,  8229,
+         8232,  8234,  8261,  8273,  8281,  8289,  8293,  8320,  8322,  8328,  8330,  8341,  8352,  8354,  8357,  8360,
+         8362,  8453,  8465,  8468,  8473,  8485,  8514,  8516,  8521,  8533,  8536,  8538,  8545,  8548,  8549,  8550,
+         8581,  8592,  8598,  8601,  8613,  8705,  8712,  8714,  8721,  8725,  8736,  8738,  8744,  8746,  8773,  8785,
+         8790,  8793,  8805,  8833,  8840,  8842,  8849,  8853,  8864,  8866,  8872,  8874,  9221,  9236,  9238,  9241,
+         9253,  9284,  9285,  9286,  9289,  9298,  9301,  9304,  9306,  9318,  9349,  9361,  9364,  9369,  9377,  9381,
+         9481,  9493,  9505,  9513,  9536,  9541,  9544,  9553,  9556,  9557,  9561,  9570,  9573,  9576,  9609,  9616,
+         9620,  9621,  9624,  9626,  9633,  9636,  9638,  9641,  9733,  9744,  9746,  9753,  9765,  9793,  9801,  9813,
+         9824,  9825,  9833,  9860,  9862,  9872,  9882, 10240, 10242, 10248, 10250, 10261, 10272, 10274, 10280, 10282,
+        10309, 10321, 10324, 10341, 10368, 10370, 10376, 10378, 10400, 10402, 10408, 10410, 10505, 10513, 10516, 10521,
+        10533, 10566, 10569, 10578, 10581, 10593, 10596, 10598, 10601, 10629, 10640, 10646, 10649, 10660, 10661, 10752,
+        10754, 10760, 10762, 10784, 10786, 10792, 10794, 10821, 10833, 10838, 10841, 10853, 10880, 10882, 10888, 10890,
+        10901, 10912, 10914, 10920, 10922, 16389, 16401, 16406, 16421, 16457, 16466, 16469, 16472, 16474, 16481, 16484,
+        16486, 16532, 16537, 16545, 16550, 16640, 16641, 16644, 16646, 16649, 16658, 16661, 16662, 16664, 16666, 16673,
+        16678, 16681, 16709, 16712, 16714, 16721, 16724, 16725, 16726, 16729, 16730, 16741, 16744, 16746, 16769, 16772,
+        16774, 16784, 16786, 16789, 16800, 16801, 16802, 16901, 16913, 16916, 16918, 16933, 16961, 16978, 16981, 16986,
+        16996, 17001, 17033, 17044, 17061, 17409, 17429, 17433, 17449, 17477, 17480, 17482, 17489, 17492, 17493, 17494,
+        17505, 17506, 17509, 17512, 17514, 17537, 17542, 17545, 17552, 17554, 17557, 17568, 17569, 17577, 17665, 17666,
+        17669, 17674, 17681, 17684, 17685, 17686, 17689, 17696, 17701, 17706, 17729, 17732, 17733, 17734, 17737, 17744,
+        17745, 17748, 17749, 17750, 17752, 17753, 17761, 17764, 17765, 17766, 17769, 17794, 17796, 17797, 17800, 17809,
+        17812, 17813, 17814, 17817, 17818, 17829, 17832, 17834, 17921, 17925, 17929, 17940, 17941, 17944, 17946, 17953,
+        17956, 17961, 17984, 17986, 17989, 17992, 18000, 18001, 18002, 18005, 18006, 18009, 18018, 18021, 18024, 18049,
+        18053, 18058, 18068, 18069, 18081, 18084, 18086, 18437, 18449, 18453, 18458, 18469, 18498, 18505, 18512, 18517,
+        18520, 18529, 18532, 18534, 18537, 18565, 18577, 18580, 18582, 18585, 18597, 18689, 18693, 18694, 18698, 18704,
+        18708, 18709, 18712, 18721, 18724, 18726, 18752, 18757, 18762, 18769, 18770, 18772, 18773, 18774, 18777, 18784,
+        18786, 18789, 18790, 18794, 18822, 18825, 18834, 18837, 18838, 18840, 18849, 18852, 18854, 18857, 18966, 19012,
+        19014, 19017, 19029, 19032, 19034, 19044, 19049, 19092, 19109, 20481, 20484, 20485, 20486, 20489, 20498, 20501,
+        20506, 20513, 20516, 20521, 20544, 20549, 20552, 20561, 20564, 20565, 20566, 20569, 20581, 20584, 20614, 20617,
+        20629, 20632, 20640, 20641, 20646, 20649, 20741, 20744, 20745, 20746, 20753, 20756, 20757, 20758, 20760, 20761,
+        20768, 20773, 20774, 20776, 20778, 20801, 20804, 20805, 20806, 20809, 20816, 20817, 20818, 20820, 20821, 20822,
+        20824, 20825, 20826, 20833, 20836, 20837, 20838, 20841, 20866, 20869, 20881, 20884, 20885, 20886, 20889, 20896,
+        20901, 20906, 20993, 20998, 21010, 21013, 21018, 21025, 21028, 21058, 21061, 21066, 21073, 21076, 21077, 21078,
+        21081, 21090, 21093, 21125, 21136, 21138, 21141, 21145, 21146, 21156, 21508, 21509, 21521, 21524, 21525, 21526,
+        21528, 21529, 21537, 21541, 21544, 21546, 21569, 21572, 21573, 21574, 21577, 21578, 21584, 21585, 21588, 21589,
+        21590, 21592, 21593, 21594, 21601, 21602, 21604, 21605, 21606, 21609, 21632, 21640, 21642, 21649, 21652, 21653,
+        21654, 21657, 21665, 21668, 21669, 21674, 21761, 21762, 21764, 21765, 21766, 21769, 21776, 21777, 21778, 21780,
+        21781, 21782, 21785, 21786, 21793, 21796, 21797, 21798, 21801, 21824, 21825, 21826, 21828, 21829, 21830, 21832,
+        21833, 21840, 21841, 21842, 21844, 21845, 21846, 21848, 21849, 21850, 21856, 21857, 21860, 21861, 21862, 21864,
+        21865, 21866, 21889, 21892, 21893, 21897, 21898, 21904, 21905, 21908, 21909, 21910, 21912, 21913, 21921, 21924,
+        21925, 21926, 21929, 22016, 22017, 22018, 22020, 22022, 22024, 22025, 22033, 22036, 22037, 22040, 22041, 22048,
+        22049, 22050, 22052, 22053, 22054, 22056, 22057, 22081, 22085, 22086, 22088, 22089, 22090, 22096, 22097, 22098,
+        22100, 22101, 22102, 22104, 22105, 22106, 22113, 22116, 22117, 22121, 22146, 22149, 22150, 22152, 22153, 22154,
+        22161, 22165, 22170, 22178, 22181, 22182, 22184, 22185, 22532, 22533, 22534, 22537, 22544, 22549, 22552, 22561,
+        22570, 22597, 22600, 22602, 22609, 22612, 22613, 22614, 22616, 22617, 22624, 22626, 22628, 22629, 22658, 22665,
+        22672, 22674, 22677, 22680, 22689, 22697, 22785, 22786, 22789, 22794, 22801, 22804, 22805, 22806, 22809, 22821,
+        22849, 22852, 22853, 22854, 22857, 22864, 22865, 22866, 22868, 22869, 22870, 22872, 22873, 22874, 22881, 22884,
+        22885, 22886, 22889, 22913, 22917, 22921, 22929, 22932, 22933, 22934, 22936, 22937, 22949, 23044, 23048, 23061,
+        23066, 23072, 23077, 23078, 23081, 23109, 23112, 23113, 23121, 23125, 23126, 23128, 23129, 23138, 23141, 23144,
+        23146, 23169, 23178, 23186, 23189, 23190, 23192, 23194, 23201, 24581, 24596, 24598, 24601, 24613, 24644, 24656,
+        24661, 24662, 24664, 24666, 24673, 24676, 24678, 24681, 24705, 24726, 24741, 24833, 24836, 24838, 24841, 24850,
+        24853, 24865, 24866, 24870, 24873, 24901, 24905, 24913, 24917, 24918, 24921, 24933, 24934, 24938, 24964, 24970,
+        24978, 24981, 24993, 24998, 25001, 25105, 25110, 25113, 25152, 25153, 25158, 25173, 25174, 25176, 25184, 25221,
+        25233, 25238, 25253, 25617, 25618, 25621, 25622, 25626, 25633, 25638, 25641, 25664, 25666, 25669, 25672, 25674,
+        25681, 25684, 25685, 25686, 25689, 25690, 25696, 25698, 25701, 25732, 25733, 25737, 25744, 25746, 25748, 25749,
+        25750, 25752, 25754, 25761, 25764, 25769, 25861, 25864, 25866, 25873, 25877, 25878, 25881, 25924, 25925, 25926,
+        25929, 25936, 25937, 25940, 25941, 25942, 25945, 25953, 25956, 25957, 25958, 25961, 25990, 25993, 25994, 26001,
+        26005, 26006, 26009, 26010, 26018, 26021, 26022, 26024, 26114, 26121, 26133, 26144, 26150, 26152, 26153, 26176,
+        26181, 26184, 26186, 26193, 26196, 26197, 26198, 26200, 26202, 26208, 26213, 26216, 26240, 26242, 26245, 26250,
+        26260, 26262, 26264, 26265, 26272, 26276, 26278, 26282, 26646, 26649, 26661, 26689, 26706, 26709, 26714, 26721,
+        26729, 26757, 26769, 26776, 26790, 26881, 26884, 26896, 26901, 26913, 26916, 26918, 26921, 26944, 26945, 26949,
+        26950, 26952, 26961, 26964, 26965, 26966, 26969, 26976, 26981, 26986, 27010, 27012, 27018, 27029, 27041, 27044,
+        27045, 27049, 27153, 27158, 27160, 27201, 27204, 27209, 27216, 27221, 27224, 27226, 27236, 27237, 27241, 27270,
+        27284, 27288, 27290, 27302, 32768, 32770, 32776, 32778, 32800, 32802, 32808, 32810, 32837, 32848, 32849, 32852,
+        32854, 32857, 32869, 32896, 32898, 32904, 32906, 32917, 32928, 32930, 32936, 32938, 33029, 33041, 33044, 33046,
+        33049, 33061, 33089, 33092, 33097, 33104, 33106, 33109, 33110, 33112, 33113, 33124, 33126, 33129, 33157, 33161,
+        33172, 33174, 33177, 33189, 33280, 33282, 33288, 33290, 33301, 33312, 33314, 33320, 33322, 33361, 33364, 33369,
+        33381, 33408, 33410, 33416, 33418, 33429, 33440, 33442, 33448, 33450, 33812, 33817, 33857, 33860, 33873, 33877,
+        33882, 33889, 33892, 33897, 33940, 33945, 34049, 34057, 34066, 34069, 34074, 34086, 34089, 34112, 34113, 34117,
+        34120, 34129, 34132, 34133, 34134, 34137, 34138, 34149, 34150, 34152, 34154, 34177, 34180, 34182, 34185, 34192,
+        34194, 34197, 34200, 34214, 34321, 34326, 34329, 34341, 34369, 34372, 34377, 34378, 34384, 34389, 34393, 34394,
+        34401, 34406, 34410, 34437, 34449, 34458, 34468, 34816, 34818, 34824, 34826, 34837, 34848, 34850, 34856, 34858,
+        34881, 34885, 34897, 34900, 34905, 34917, 34921, 34944, 34946, 34952, 34954, 34965, 34976, 34978, 34984, 34986,
+        35077, 35078, 35089, 35092, 35094, 35109, 35137, 35140, 35142, 35145, 35152, 35154, 35157, 35162, 35169, 35172,
+        35205, 35222, 35225, 35237, 35328, 35330, 35336, 35338, 35349, 35360, 35362, 35368, 35370, 35397, 35409, 35412,
+        35414, 35456, 35458, 35464, 35466, 35477, 35488, 35490, 35496, 35498, 36869, 36881, 36886, 36888, 36889, 36901,
+        36929, 36934, 36937, 36949, 36952, 36954, 36969, 36970, 36997, 37009, 37012, 37014, 37017, 37029, 37121, 37124,
+        37126, 37129, 37136, 37141, 37144, 37146, 37153, 37156, 37158, 37161, 37184, 37189, 37200, 37201, 37204, 37205,
+        37206, 37209, 37218, 37221, 37252, 37254, 37266, 37269, 37272, 37281, 37284, 37286, 37289, 37381, 37393, 37396,
+        37401, 37413, 37444, 37446, 37449, 37456, 37458, 37461, 37464, 37478, 37481, 37509, 37524, 37526, 37545, 37889,
+        37892, 37894, 37904, 37909, 37912, 37926, 37952, 37962, 37969, 37972, 37973, 37974, 37976, 37977, 37984, 37985,
+        37986, 37989, 38020, 38022, 38034, 38036, 38037, 38040, 38049, 38057, 38144, 38149, 38152, 38154, 38160, 38161,
+        38164, 38165, 38166, 38169, 38177, 38181, 38185, 38186, 38209, 38212, 38213, 38214, 38217, 38224, 38225, 38226,
+        38228, 38229, 38230, 38232, 38233, 38234, 38241, 38244, 38245, 38246, 38249, 38273, 38277, 38280, 38289, 38290,
+        38292, 38293, 38294, 38297, 38298, 38304, 38306, 38309, 38312, 38314, 38401, 38404, 38416, 38421, 38425, 38432,
+        38438, 38441, 38469, 38472, 38473, 38481, 38482, 38485, 38486, 38489, 38501, 38504, 38530, 38532, 38537, 38538,
+        38546, 38548, 38549, 38564, 38566, 38569, 38917, 38934, 38937, 38949, 38977, 38982, 38992, 38994, 38997, 38998,
+        39002, 39012, 39013, 39045, 39057, 39062, 39065, 39077, 39172, 39174, 39177, 39184, 39186, 39189, 39192, 39194,
+        39200, 39201, 39204, 39206, 39232, 39234, 39237, 39240, 39242, 39249, 39252, 39253, 39254, 39257, 39266, 39269,
+        39270, 39274, 39297, 39300, 39312, 39314, 39317, 39322, 39329, 39334, 39429, 39445, 39461, 39492, 39494, 39497,
+        39504, 39509, 39512, 39521, 39557, 39569, 39572, 39573, 39574, 40960, 40962, 40968, 40970, 40981, 40992, 40994,
+        41000, 41002, 41029, 41041, 41044, 41046, 41049, 41088, 41090, 41096, 41098, 41109, 41120, 41122, 41128, 41130,
+        41221, 41225, 41233, 41236, 41238, 41241, 41242, 41286, 41289, 41297, 41301, 41304, 41306, 41313, 41316, 41349,
+        41360, 41362, 41366, 41369, 41474, 41480, 41482, 41488, 41497, 41506, 41512, 41514, 41541, 41553, 41558, 41561,
+        41573, 41600, 41602, 41608, 41610, 41621, 41632, 41634, 41640, 41642, 42009, 42021, 42049, 42052, 42064, 42068,
+        42069, 42072, 42074, 42081, 42085, 42086, 42088, 42089, 42117, 42246, 42249, 42256, 42258, 42261, 42264, 42278,
+        42281, 42306, 42309, 42321, 42324, 42325, 42326, 42329, 42341, 42346, 42369, 42372, 42373, 42374, 42377, 42386,
+        42389, 42392, 42501, 42513, 42518, 42522, 42529, 42533, 42564, 42566, 42570, 42578, 42581, 42582, 42584, 42592,
+        42594, 42630, 42640, 42645, 42646, 42649, 42657, 42660, 42662, 43008, 43010, 43016, 43018, 43040, 43042, 43048,
+        43050, 43089, 43092, 43094, 43097, 43136, 43138, 43144, 43146, 43157, 43168, 43170, 43176, 43178, 43269, 43284,
+        43289, 43297, 43301, 43329, 43344, 43349, 43354, 43361, 43366, 43369, 43408, 43414, 43520, 43522, 43528, 43530,
+        43552, 43554, 43560, 43562, 43601, 43604, 43606, 43648, 43650, 43656, 43658, 43669, 43680, 43682, 43688, 43690,
     };
     static const uint16_t kgrid_2bit_1024[1024] = {
             0,     2,     5,     8,    10,    17,    20,    22,    25,    32,    34,    37,    40,    65,    68,    70,
@@ -10169,7 +10198,7 @@ void iq2xs_init_impl(enum ggml_type type) {
     const int nwant = type == GGML_TYPE_IQ1_S ? 3 : type == GGML_TYPE_IQ2_S ? 1 : 2;
     const uint16_t * kgrid = type == GGML_TYPE_IQ2_XXS ? kgrid_2bit_256 :
                              type == GGML_TYPE_IQ2_XS  ? kgrid_2bit_512 :
-                             type == GGML_TYPE_IQ1_S   ? kgrid_1bit_512 : kgrid_2bit_1024;
+                             type == GGML_TYPE_IQ1_S   ? kgrid_1bit_2048 : kgrid_2bit_1024;
     uint64_t * kgrid_q2xs;
     int      * kmap_q2xs;
     uint16_t * kneighbors_q2xs;
@@ -11408,12 +11437,70 @@ static int iq1_find_best_neighbour(const uint16_t * restrict neighbours, const u
     return grid_index;
 }
 
+static int iq1_find_best_neighbour2(const uint16_t * restrict neighbours, const uint64_t * restrict grid,
+        const float * restrict xval, const float * restrict weight, float scale, int8_t * restrict L, int ngrid) {
+    int num_neighbors = neighbours[0];
+    GGML_ASSERT(num_neighbors > 0);
+    float best_score = FLT_MAX;
+    int grid_index = -1;
+    for (int j = 1; j <= num_neighbors; ++j) {
+        const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
+        float d2 = 0;
+        for (int i = 0; i < 8; ++i) {
+            float q = (pg[i] - 3)/2;
+            float w = weight[i];
+            float diff = scale*q - xval[i];
+            d2 += w*diff*diff;
+        }
+        if (d2 < best_score) {
+            best_score = d2;
+            grid_index = neighbours[j];
+        }
+    }
+    if (grid_index < 0) {
+        for (int i = 0; i < ngrid; ++i) {
+            const int8_t * grid_i = (const int8_t *)(grid + i);
+            float d2 = 0;
+            for (int j = 0; j < 8; ++j) {
+                float w = weight[j];
+                float q = (grid_i[j] - 3)/2;
+                float diff = scale*q - xval[i];
+                d2 += w*diff*diff;
+            }
+            if (d2 < best_score) {
+                best_score = d2;
+                grid_index = i;
+            }
+        }
+    }
+    if (grid_index < 0) {
+        printf("Oops, did not find grid point\n");
+        printf("Have %d neighbours\n", num_neighbors);
+        for (int j = 1; j <= num_neighbors; ++j) {
+            const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
+            float sumqx = 0, sumq2 = 0;
+            for (int i = 0; i < 8; ++i) {
+                float q = (pg[i] - 3)/2;
+                float w = weight[i];
+                sumqx += w*q*xval[i];
+                sumq2 += w*q*q;
+            }
+            printf("    neighbour %d: sumqx = %g sumq2 = %g\n", j, (double)sumqx, (double)sumq2);
+        }
+    }
+    GGML_ASSERT(grid_index >= 0);
+    const int8_t * pg = (const int8_t *)(grid + grid_index);
+    for (int i = 0; i < 8; ++i) L[i] = (pg[i] - 1)/2;
+    return grid_index;
+}
+
 static int iq1_sort_helper(const void * left, const void * right) {
     const float * l = left;
     const float * r = right;
     return *l < *r ? -1 : *l > *r ? 1 : 0;
 }
 
+#define IQ1S_BLOCK_SIZE 32
 static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
 
     const int gindex = iq2_data_index(GGML_TYPE_IQ1_S);
@@ -11432,37 +11519,37 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
 
     block_iq1_s * y = vy;
 
-    float  scales[QK_K/8];
-    float  weight[8];
-    int8_t L[8];
-    float  sumx[9];
-    float  sumw[9];
-    float  pairs[16];
+    float  scales[QK_K/IQ1S_BLOCK_SIZE];
+    float  weight[IQ1S_BLOCK_SIZE];
+    int8_t L[IQ1S_BLOCK_SIZE];
+    float  sumx[IQ1S_BLOCK_SIZE+1];
+    float  sumw[IQ1S_BLOCK_SIZE+1];
+    float  pairs[2*IQ1S_BLOCK_SIZE];
     int * idx = (int *)(pairs + 1);
-    uint8_t hbit[QK_K/8];
+    uint16_t index[IQ1S_BLOCK_SIZE/8];
 
     for (int ibl = 0; ibl < nbl; ++ibl) {
 
         y[ibl].d = GGML_FP32_TO_FP16(0.f);
         memset(y[ibl].qs, 0, QK_K/8);
-        memset(y[ibl].scales, 0, QK_K/16);
+        memset(y[ibl].qh, 0, QK_K/16);
 
         float max_scale = 0;
 
         const float * xbl = x + QK_K*ibl;
         float sumx2 = 0;
         for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
-        float sigma2 = sumx2/QK_K;
+        float sigma2 = 2*sumx2/QK_K;
 
-        for (int ib = 0; ib < QK_K/8; ++ib) {
-            const float * xb = xbl + 8*ib;
-            const float * qw = quant_weights + QK_K*ibl + 8*ib;
-            for (int i = 0; i < 8; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+        for (int ib = 0; ib < QK_K/IQ1S_BLOCK_SIZE; ++ib) {
+            const float * xb = xbl + IQ1S_BLOCK_SIZE*ib;
+            const float * qw = quant_weights + QK_K*ibl + IQ1S_BLOCK_SIZE*ib;
+            for (int i = 0; i < IQ1S_BLOCK_SIZE; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
             float max = fabsf(xb[0]);
-            for (int i = 1; i < 8; ++i) max = MAX(max, fabsf(xb[i]));
+            for (int i = 1; i < IQ1S_BLOCK_SIZE; ++i) max = MAX(max, fabsf(xb[i]));
             if (!max) {
                 scales[ib] = 0;
-                memset(L, 1, 8);
+                memset(L, 1, IQ1S_BLOCK_SIZE);
                 continue;
             }
             // Here we solve exactly the sum of squared difference (SSD) weighted minimization problem.
@@ -11471,14 +11558,14 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
             // in ascending order, compute Si = sum[weight[j] xb[j], j = 0...i] and
             // Wi = sum[weight[j], j = 0...i], and use these to quckly get get the optimum scale
             // for each possible and score for each split.
-            for (int j = 0; j < 8; ++j) {
+            for (int j = 0; j < IQ1S_BLOCK_SIZE; ++j) {
                 pairs[2*j] = xb[j];
                 idx[2*j] = j;
             }
-            qsort(pairs, 8, 2*sizeof(float), iq1_sort_helper);
+            qsort(pairs, IQ1S_BLOCK_SIZE, 2*sizeof(float), iq1_sort_helper);
             {
                 sumx[0] = sumw[0] = 0;
-                for (int j = 0; j < 8; ++j) {
+                for (int j = 0; j < IQ1S_BLOCK_SIZE; ++j) {
                     int i = idx[2*j];
                     sumx[j+1] = sumx[j] + weight[i]*xb[i];
                     sumw[j+1] = sumw[j] + weight[i];
@@ -11486,10 +11573,10 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
             }
             float best_score = 0, scale = max;
             int besti1 = 0, besti2 = 0;
-            for (int i1 = 0; i1 <= 8; ++i1) {
-                for (int i2 = i1; i2 <= 8; ++i2) {
-                    float sumqx = -(sumx[i1] - sumx[0]) + (sumx[8] - sumx[i2]);
-                    float sumq2 =  (sumw[i1] - sumw[0]) + (sumw[8] - sumw[i2]);
+            for (int i1 = 0; i1 <= IQ1S_BLOCK_SIZE; ++i1) {
+                for (int i2 = i1; i2 <= IQ1S_BLOCK_SIZE; ++i2) {
+                    float sumqx = -(sumx[i1] - sumx[0]) + (sumx[IQ1S_BLOCK_SIZE] - sumx[i2]);
+                    float sumq2 =  (sumw[i1] - sumw[0]) + (sumw[IQ1S_BLOCK_SIZE] - sumw[i2]);
                     if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
                         scale = sumqx/sumq2; best_score = scale*sumqx;
                         besti1 = i1; besti2 = i2;
@@ -11498,23 +11585,43 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
             }
             for (int j =      0; j < besti1; ++j) L[idx[2*j]] = 0;
             for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1;
-            for (int j = besti2; j <      8; ++j) L[idx[2*j]] = 2;
+            for (int j = besti2; j < IQ1S_BLOCK_SIZE; ++j) L[idx[2*j]] = 2;
             if (scale < 0) {
-                for (int j = 0; j < 8; ++j) L[j] = 2 - L[j];
+                for (int j = 0; j < IQ1S_BLOCK_SIZE; ++j) L[j] = 2 - L[j];
                 scale = -scale;
             }
-            // Now we check if the solution found above corresponds to a grid point and, if not, use a neighbouring
-            // grid point that minimizes SSD.
-            uint16_t u = 0;
-            for (int j = 0; j < 8; ++j) u |= (L[j] << 2*j);
-            int grid_index = kmap_q2xs[u];
-            if (grid_index < 0) {
-                const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
-                grid_index = iq1_find_best_neighbour(neighbours, kgrid_q2xs, xb, weight, &scale, L, NGRID_IQ2XXS);
-                GGML_ASSERT(grid_index >= 0);
-            }
-            y[ibl].qs[ib] = grid_index & 255;
-            hbit[ib] = grid_index >> 8;
+            bool all_on_grid = true;
+            for (int k = 0; k < IQ1S_BLOCK_SIZE/8; ++k) {
+                uint16_t u = 0;
+                for (int j = 0; j < 8; ++j) u |= (L[8*k+j] << 2*j);
+                int grid_index = kmap_q2xs[u];
+                if (grid_index < 0) {
+                    all_on_grid = false;
+                    const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
+                    grid_index = iq1_find_best_neighbour2(neighbours, kgrid_q2xs, xb + 8*k, weight + 8*k, scale, L + 8*k, NGRID_IQ1S);
+                    GGML_ASSERT(grid_index >= 0);
+                }
+                index[k] = grid_index;
+            }
+            if (!all_on_grid) {
+                float sumqx = 0, sumq2 = 0;
+                for (int k = 0; k < IQ1S_BLOCK_SIZE/8; ++k) {
+                    const int8_t * pg = (const int8_t *)(kgrid_q2xs + index[k]);
+                    for (int j = 0; j < 8; ++j) {
+                        float w = weight[8*k + j];
+                        float q = (pg[j] - 3)/2;
+                        sumqx += w*q*xb[8*k+j];
+                        sumq2 += w*q*q;
+                    }
+                }
+                if (sumqx > 0 && sumq2 > 0) scale = sumqx/sumq2;
+            }
+            uint16_t h = 0;
+            for (int k = 0; k < IQ1S_BLOCK_SIZE/8; ++k) {
+                y[ibl].qs[(IQ1S_BLOCK_SIZE/8)*ib + k] = index[k] & 255;
+                h |= (index[k] >> 8) << 3*k;
+            }
+            y[ibl].qh[ib] = h;
             GGML_ASSERT(scale >= 0);
             scales[ib] = scale;
             max_scale = MAX(max_scale, scale);
@@ -11525,14 +11632,13 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
             continue;
         }
 
-        float d = max_scale/15;
-        y[ibl].d = GGML_FP32_TO_FP16(d*1.085f); // 1.085f is another fudge factor. Don't ask me why it is needed.
+        float d = max_scale/31;
+        y[ibl].d = GGML_FP32_TO_FP16(d*1.125f); // 1.085f is another fudge factor. Don't ask me why it is needed.
         float id = 1/d;
-        for (int ib = 0; ib < QK_K/8; ++ib) {
+        for (int ib = 0; ib < QK_K/IQ1S_BLOCK_SIZE; ++ib) {
             int l = nearest_int(0.5f*(id*scales[ib]-1));
-            l = MAX(0, MIN(7, l));
-            if (hbit[ib]) l |= 8;
-            y[ibl].scales[ib/2] |= (l << 4*(ib%2));
+            l = MAX(0, MIN(15, l));
+            y[ibl].qh[ib] |= (l << 12);
         }
     }
 }
index 47dd52856422abbde4b12e88230ed3b9937daa89..74aabf4156385539bfe5a6845acb13fd8ff6e30b 100644 (file)
@@ -217,8 +217,8 @@ static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 13*(QK_K/32) + IQ3S_N
 
 typedef struct {
     ggml_fp16_t d;
-    uint8_t qs[QK_K/8];
-    uint8_t scales[QK_K/16];
+    uint8_t  qs[QK_K/8];
+    uint16_t qh[QK_K/32];
 } block_iq1_s;
 static_assert(sizeof(block_iq1_s) == sizeof(ggml_fp16_t) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding");