]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
IQ3_S: a much better alternative to Q3_K (llama/5676)
authorKawrakow <redacted>
Sat, 24 Feb 2024 14:23:52 +0000 (16:23 +0200)
committerGeorgi Gerganov <redacted>
Sun, 25 Feb 2024 17:58:46 +0000 (19:58 +0200)
* iq4_nl: squash commits for easier rebase

* Basics (quantize, dequantize)
* CUDA dequantize and dot product
* Slightly faster CUDA dot product (120 t/s)
* Switch to 6-bit scales
* Scalar dot product
* AVX2 dot product
* ARM_NEON dot product
* Works on metal, but still slow
* Slightly better Metal dot product
* Another small Metal improvement
* Metal dot product is getting there
* Faster CUDA dot product
* Add 1/8 ffn_down layers as Q5_K when no imatrix has been provided
* Report the actual bpw
* Add _xs mix that is 4.05 bpw for non-MoE models
* Remove IQ4_XS for now, slightly adjust kvalues_iq4nl
* AVX2 dot product uses Q8_0 instead of Q8_K
* Add to test-backend-ops
* Minor fix
* Also use use Q5_K for attn_output in MoE models
* Fixes after merging latest master
* Switching to blocks of 32
* AVX2 for blocks of 32
* Scaler dot product for blocks of 32
* ARM_NEON dot product for blocks of 32
* Metal kernels for blocks of 32
* Slightly faster Metal kernels

* Resurrecting iq3_xs

After all the experimentation, nothing was better than this.

* Minor PPL improvement via a block scale fudge factor

* Minor improvement via 3 neighbours

* iq3_xs: working scalar and AVX2 dot products

* iq3_xs: ARM_NEON dot product - works but extremely slow (10 t/s)

* iq3_xs: working Metal implementation

* Adding IQ3_M - IQ3_XS mix with mostly Q4_K

* iiq3_xs: a 3.4375 bpw variant

* iq3_xs: make CUDA work for new version

* iq3_xs: make scalar and AVX2 work for new version

* iq3_s: make ARM_NEON work with new version

* iq3_xs: make new version work on metal

Performance is very similar to Q3_K_S

* iq3_xs: tiny Metal speed improvement

* iq3_xs: tiny Metal speed improvement

* Fix stupid warning

* Q3_K_XS now uses a mix of IQ3_XS and IQ3_XXS

* iq3_xs: rename to iq3_s

* iq3_s: make tests pass

* Move Q3_K_XS mix to 3.25 bpw

* Attempt to fix failing tests

* Another attempt to fix the Windows builds

* Attempt to fix ROCm

* ROCm again

* iq3_s: partial fix for QK_K = 64

* iq3_s: make it work on metal for QK_K = 64

Pleasent surprise: the coding was super-block size independent,
so all it took was to delete some QK_K == 256 guards.

* Will this fix ROCm?

---------

Co-authored-by: Iwan Kawrakow <redacted>
ggml-cuda.cu
ggml-metal.m
ggml-metal.metal
ggml-quants.c
ggml-quants.h
ggml.c
ggml.h

index c1b7a729dce6ce336cb8eb01963d657d2790c44a..42410e8b6d5f8f05b9a6cbd8175e65e4ff9cfe1e 100644 (file)
 #endif
 
 typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
+typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
 static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
     const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
     const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
@@ -196,6 +197,18 @@ static __device__ __forceinline__ int __vsub4(const int a, const int b) {
     return __vsubss4(a, b);
 }
 
+static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigned int b) {
+    const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
+    const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
+    unsigned int c;
+    uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
+#pragma unroll
+    for (int i = 0; i < 4; ++i) {
+        vc[i] = va[i] == vb[i] ? 0xff : 0x00;
+    }
+    return c;
+}
+
 static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
 #if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__)
     c = __builtin_amdgcn_sdot4(a, b, c, false);
@@ -518,6 +531,17 @@ typedef struct {
 } block_iq3_xxs;
 static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
 
+#define QR3_XS 8
+#define QI3_XS (QK_K / (4*QR3_XS))
+typedef struct {
+    half d;
+    uint8_t qs[QK_K/4];
+    uint8_t qh[QK_K/32];
+    uint8_t signs[QK_K/8];
+    uint8_t scales[QK_K/64];
+} block_iq3_s;
+static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 27*(QK_K/64), "wrong iq3_s block size/padding");
+
 #define QR1_S 8
 #define QI1_S (QK_K / (4*QR1_S))
 typedef struct {
@@ -1700,6 +1724,74 @@ static const __device__ uint32_t iq3xxs_grid[256] = {
     0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
 };
 
+static const __device__ uint32_t iq3xs_grid[512] = {
+    0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14,
+    0x04040c24, 0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414,
+    0x0404242c, 0x0404243e, 0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24,
+    0x04043e3e, 0x040c0404, 0x040c040c, 0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c,
+    0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c, 0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c,
+    0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e, 0x04140c04, 0x04140c1c, 0x04140c34,
+    0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c, 0x0414243e, 0x04142c0c,
+    0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404, 0x041c1414,
+    0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c,
+    0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404,
+    0x0424241c, 0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434,
+    0x042c1c1c, 0x042c240c, 0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c,
+    0x04340c1c, 0x04341c0c, 0x04342c14, 0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404,
+    0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04, 0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414,
+    0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c, 0x0c040c3e, 0x0c041404, 0x0c041414,
+    0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c, 0x0c043e14, 0x0c0c0404,
+    0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04, 0x0c0c1c1c,
+    0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c,
+    0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404,
+    0x0c143e14, 0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e,
+    0x0c1c1c04, 0x0c1c1c24, 0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14,
+    0x0c240c24, 0x0c241c0c, 0x0c241c1c, 0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c,
+    0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04, 0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424,
+    0x0c34143e, 0x0c342424, 0x0c342434, 0x0c3e040c, 0x0c3e041c, 0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c,
+    0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c, 0x1404041c, 0x1404042c,
+    0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c, 0x1404143e,
+    0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e,
+    0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e,
+    0x140c1414, 0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424,
+    0x1414043e, 0x1414140c, 0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e,
+    0x14143e0c, 0x14143e24, 0x141c0404, 0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424,
+    0x141c1c0c, 0x141c1c1c, 0x141c2414, 0x141c2c04, 0x141c3434, 0x1424040c, 0x1424043e, 0x14241404,
+    0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14, 0x14243e2c, 0x142c0424, 0x142c0c0c,
+    0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404, 0x14340414, 0x1434043e,
+    0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04, 0x143e241c,
+    0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c,
+    0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c,
+    0x1c0c040c, 0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404,
+    0x1c0c3e14, 0x1c0c3e34, 0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04,
+    0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24, 0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c,
+    0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c, 0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414,
+    0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c, 0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c,
+    0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c, 0x1c3e040c, 0x1c3e041c,
+    0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404, 0x24041424,
+    0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c,
+    0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c,
+    0x2414041c, 0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414,
+    0x24143e04, 0x241c0424, 0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c,
+    0x24240404, 0x24240414, 0x24241424, 0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e,
+    0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24, 0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04,
+    0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24, 0x2c041414, 0x2c042404, 0x2c042424,
+    0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c, 0x2c0c042c, 0x2c0c0c14,
+    0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04, 0x2c141c34,
+    0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c,
+    0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434,
+    0x2c2c2c0c, 0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c,
+    0x34040c2c, 0x34041c0c, 0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424,
+    0x34140c14, 0x34141c24, 0x34142414, 0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24,
+    0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c, 0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24,
+    0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c, 0x3e040404, 0x3e040424, 0x3e04043e,
+    0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414, 0x3e0c0414, 0x3e0c0c0c,
+    0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34, 0x3e14140c,
+    0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c,
+    0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404,
+};
+
+
 static const __device__ uint64_t iq1s_grid[512] = {
     0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
     0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01,
@@ -1973,6 +2065,32 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds
 
 }
 
+template<typename dst_t>
+static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int i   = blockIdx.x;
+    const block_iq3_s * x = (const block_iq3_s *) vx;
+
+    const int tid = threadIdx.x;
+#if QK_K == 256
+    const int il = tid/8; // 0...3
+    const int ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const uint8_t * qs = x[i].qs + 8*ib;
+    const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
+    const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)));
+    const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf)) * 0.5f;
+    const uint8_t signs = x[i].signs[4*ib + il];
+    for (int j = 0; j < 4; ++j) {
+        y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
+        y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
+    }
+#else
+    assert(false);
+#endif
+
+}
+
 template<typename dst_t>
 static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
 
@@ -4717,6 +4835,41 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
 #endif
 }
 
+// TODO: don't use lookup table for signs
+static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+#if QK_K == 256
+    const block_iq3_s * bq2 = (const block_iq3_s *) vbq;
+
+    const int ib32 = iqs;
+    const uint8_t  * qs = bq2->qs + 8*ib32;
+    const int8_t   * q8 = bq8_1[ib32].qs;
+    int sumi = 0;
+    for (int l = 0; l < 4; ++l) {
+        const uint32_t * grid1 = iq3xs_grid + (qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256));
+        const uint32_t * grid2 = iq3xs_grid + (qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256));
+        uint32_t signs0 = __vcmpeq4(((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
+        uint32_t signs1 = __vcmpeq4(((bq2->signs[4*ib32+l] >>  4) * 0x01010101) & 0x08040201, 0x08040201);
+        const int grid_l = __vsub4(grid1[0] ^ signs0, signs0);
+        const int grid_h = __vsub4(grid2[0] ^ signs1, signs1);
+        sumi = __dp4a(grid_l, *((int *)q8+0), sumi);
+        sumi = __dp4a(grid_h, *((int *)q8+1), sumi);
+        q8 += 8;
+    }
+    const float d = (float)bq2->d * (0.5f + ((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * __low2float(bq8_1[ib32].ds) * 0.5f;
+    return d * sumi;
+#else
+    assert(false);
+    return 0.f;
+#endif
+#else
+    assert(false);
+    return 0.f;
+#endif
+}
+
+
 static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
 #if QK_K == 256
@@ -6849,6 +7002,12 @@ static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int k,
     dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
 }
 
+template<typename dst_t>
+static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_iq3_s<<<nb, 32, 0, stream>>>(vx, y);
+}
+
 template<typename dst_t>
 static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
@@ -6904,6 +7063,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
             return dequantize_row_iq1_s_cuda;
         case GGML_TYPE_IQ4_NL:
             return dequantize_row_iq4_nl_cuda;
+        case GGML_TYPE_IQ3_S:
+            return dequantize_row_iq3_s_cuda;
         case GGML_TYPE_F32:
             return convert_unary_cuda<float>;
         default:
@@ -6943,6 +7104,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_row_iq1_s_cuda;
         case GGML_TYPE_IQ4_NL:
             return dequantize_row_iq4_nl_cuda;
+        case GGML_TYPE_IQ3_S:
+            return dequantize_row_iq3_s_cuda;
         case GGML_TYPE_F16:
             return convert_unary_cuda<half>;
         default:
@@ -8688,6 +8851,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ3_S:
             return max_compute_capability >= CC_RDNA2 ? 128 : 64;
         default:
             GGML_ASSERT(false);
@@ -8713,6 +8877,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ3_S:
             return max_compute_capability >= CC_VOLTA ? 128 : 64;
         case GGML_TYPE_Q6_K:
             return 64;
@@ -8818,6 +8983,10 @@ static void ggml_cuda_op_mul_mat_vec_q(
             mul_mat_vec_q_cuda<QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
                 (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
             break;
+        case GGML_TYPE_IQ3_S:
+            mul_mat_vec_q_cuda<QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
         default:
             GGML_ASSERT(false);
             break;
@@ -11541,7 +11710,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                 }
                 ggml_type a_type = a->type;
                 if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS ||
-                    a_type == GGML_TYPE_IQ1_S   || a_type == GGML_TYPE_IQ4_NL) {
+                    a_type == GGML_TYPE_IQ1_S   || a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ3_S) {
                     if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
                         return false;
                     }
index 8d2e32c6d05bb19bfa371915bea174c0f67463a3..6bf0b190926e34c4532c82d75293c33dc13f0c83 100644 (file)
@@ -61,6 +61,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
@@ -85,6 +86,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
@@ -105,6 +107,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
@@ -122,6 +125,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
@@ -139,6 +143,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F32,
@@ -452,6 +457,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,          get_rows_iq2_xxs,       true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,           get_rows_iq2_xs,        true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,          get_rows_iq3_xxs,       true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,            get_rows_iq3_s,         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,            get_rows_iq1_s,         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,           get_rows_iq4_nl,        true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,              get_rows_i32,           true);
@@ -476,6 +482,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,        mul_mv_iq2_xxs_f32,     ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,         mul_mv_iq2_xs_f32,      ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,        mul_mv_iq3_xxs_f32,     ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,          mul_mv_iq3_s_f32,       ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,          mul_mv_iq1_s_f32,       ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,         mul_mv_iq4_nl_f32,      ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,         mul_mv_id_f32_f32,      ctx->support_simdgroup_reduction);
@@ -496,6 +503,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,     mul_mv_id_iq2_xxs_f32,  ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,      mul_mv_id_iq2_xs_f32,   ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,     mul_mv_id_iq3_xxs_f32,  ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,       mul_mv_id_iq3_s_f32,    ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,       mul_mv_id_iq1_s_f32,    ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,      mul_mv_id_iq4_nl_f32,   ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,            mul_mm_f32_f32,         ctx->support_simdgroup_mm);
@@ -513,6 +521,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,        mul_mm_iq2_xxs_f32,     ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,         mul_mm_iq2_xs_f32,      ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,        mul_mm_iq3_xxs_f32,     ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,          mul_mm_iq3_s_f32,       ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,          mul_mm_iq1_s_f32,       ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,         mul_mm_iq4_nl_f32,      ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,         mul_mm_id_f32_f32,      ctx->support_simdgroup_mm);
@@ -530,6 +539,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,     mul_mm_id_iq2_xxs_f32,  ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,      mul_mm_id_iq2_xs_f32,   ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,     mul_mm_id_iq3_xxs_f32,  ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,       mul_mm_id_iq3_s_f32,    ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,       mul_mm_id_iq1_s_f32,    ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,      mul_mm_id_iq4_nl_f32,   ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32,                  rope_f32,               true);
@@ -1347,6 +1357,7 @@ static bool ggml_metal_graph_compute(
                                 case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
                                 case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
                                 case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
+                                case GGML_TYPE_IQ3_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32  ].pipeline; break;
                                 case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32  ].pipeline; break;
                                 case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
                                 default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
@@ -1483,6 +1494,12 @@ static bool ggml_metal_graph_compute(
                                         nth1 = 16;
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
                                     } break;
+                                case GGML_TYPE_IQ3_S:
+                                    {
+                                        nth0 = 4;
+                                        nth1 = 16;
+                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
+                                    } break;
                                 case GGML_TYPE_IQ1_S:
                                     {
                                         nth0 = 4;
@@ -1537,8 +1554,8 @@ static bool ggml_metal_graph_compute(
                                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
                                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
-                            else if (src0t == GGML_TYPE_IQ3_XXS) {
-                                const int mem_size = 256*4+128;
+                            else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
+                                const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
                                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
                                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
@@ -1640,6 +1657,7 @@ static bool ggml_metal_graph_compute(
                                 case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
                                 case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
                                 case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
+                                case GGML_TYPE_IQ3_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32  ].pipeline; break;
                                 case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32  ].pipeline; break;
                                 case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
                                 default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
@@ -1779,6 +1797,12 @@ static bool ggml_metal_graph_compute(
                                         nth1 = 16;
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
                                     } break;
+                                case GGML_TYPE_IQ3_S:
+                                    {
+                                        nth0 = 4;
+                                        nth1 = 16;
+                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline;
+                                    } break;
                                 case GGML_TYPE_IQ1_S:
                                     {
                                         nth0 = 4;
@@ -1849,8 +1873,8 @@ static bool ggml_metal_graph_compute(
                                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
                                 [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
-                            else if (src2t == GGML_TYPE_IQ3_XXS) {
-                                const int mem_size = 256*4+128;
+                            else if (src2t == GGML_TYPE_IQ3_XXS || src2t == GGML_TYPE_IQ3_S) {
+                                const int mem_size = src2t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
                                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
                                 [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
@@ -1900,6 +1924,7 @@ static bool ggml_metal_graph_compute(
                             case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
                             case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
                             case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
+                            case GGML_TYPE_IQ3_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S  ].pipeline; break;
                             case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S  ].pipeline; break;
                             case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
                             case GGML_TYPE_I32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32    ].pipeline; break;
index c223a981c246a39286647847ef5a3eab81fef9f8..b3bf405391d3e347d30c79e9918e2afdb8fbd6a2 100644 (file)
@@ -2525,6 +2525,20 @@ typedef struct {
 } block_iq3_xxs;
 // 98 bytes / block for QK_K = 256, so 3.0625 bpw
 
+// 3.4375 bpw
+#if QK_K == 64
+#define IQ3S_N_SCALE 2
+#else
+#define IQ3S_N_SCALE QK_K/64
+#endif
+typedef struct {
+    half d;
+    uint8_t qs[QK_K/4];
+    uint8_t qh[QK_K/32];
+    uint8_t signs[QK_K/8];
+    uint8_t scales[IQ3S_N_SCALE];
+} block_iq3_s;
+
 typedef struct {
     half d;
     uint8_t qs[QK_K/8];
@@ -3795,6 +3809,73 @@ constexpr constant static uint32_t iq3xxs_grid[256] = {
     0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
 };
 
+constexpr constant static uint32_t iq3xs_grid[512] = {
+    0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14,
+    0x04040c24, 0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414,
+    0x0404242c, 0x0404243e, 0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24,
+    0x04043e3e, 0x040c0404, 0x040c040c, 0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c,
+    0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c, 0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c,
+    0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e, 0x04140c04, 0x04140c1c, 0x04140c34,
+    0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c, 0x0414243e, 0x04142c0c,
+    0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404, 0x041c1414,
+    0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c,
+    0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404,
+    0x0424241c, 0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434,
+    0x042c1c1c, 0x042c240c, 0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c,
+    0x04340c1c, 0x04341c0c, 0x04342c14, 0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404,
+    0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04, 0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414,
+    0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c, 0x0c040c3e, 0x0c041404, 0x0c041414,
+    0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c, 0x0c043e14, 0x0c0c0404,
+    0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04, 0x0c0c1c1c,
+    0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c,
+    0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404,
+    0x0c143e14, 0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e,
+    0x0c1c1c04, 0x0c1c1c24, 0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14,
+    0x0c240c24, 0x0c241c0c, 0x0c241c1c, 0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c,
+    0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04, 0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424,
+    0x0c34143e, 0x0c342424, 0x0c342434, 0x0c3e040c, 0x0c3e041c, 0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c,
+    0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c, 0x1404041c, 0x1404042c,
+    0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c, 0x1404143e,
+    0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e,
+    0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e,
+    0x140c1414, 0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424,
+    0x1414043e, 0x1414140c, 0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e,
+    0x14143e0c, 0x14143e24, 0x141c0404, 0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424,
+    0x141c1c0c, 0x141c1c1c, 0x141c2414, 0x141c2c04, 0x141c3434, 0x1424040c, 0x1424043e, 0x14241404,
+    0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14, 0x14243e2c, 0x142c0424, 0x142c0c0c,
+    0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404, 0x14340414, 0x1434043e,
+    0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04, 0x143e241c,
+    0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c,
+    0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c,
+    0x1c0c040c, 0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404,
+    0x1c0c3e14, 0x1c0c3e34, 0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04,
+    0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24, 0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c,
+    0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c, 0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414,
+    0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c, 0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c,
+    0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c, 0x1c3e040c, 0x1c3e041c,
+    0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404, 0x24041424,
+    0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c,
+    0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c,
+    0x2414041c, 0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414,
+    0x24143e04, 0x241c0424, 0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c,
+    0x24240404, 0x24240414, 0x24241424, 0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e,
+    0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24, 0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04,
+    0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24, 0x2c041414, 0x2c042404, 0x2c042424,
+    0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c, 0x2c0c042c, 0x2c0c0c14,
+    0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04, 0x2c141c34,
+    0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c,
+    0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434,
+    0x2c2c2c0c, 0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c,
+    0x34040c2c, 0x34041c0c, 0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424,
+    0x34140c14, 0x34141c24, 0x34142414, 0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24,
+    0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c, 0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24,
+    0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c, 0x3e040404, 0x3e040424, 0x3e04043e,
+    0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414, 0x3e0c0414, 0x3e0c0c0c,
+    0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34, 0x3e14140c,
+    0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c,
+    0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404,
+};
+
 #define NGRID_IQ1S 512
 constexpr constant static uint64_t iq1s_grid[NGRID_IQ1S] = {
     0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
@@ -4361,6 +4442,136 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
     kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
 }
 
+void kernel_mul_mv_iq3_s_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        threadgroup int8_t * shared_values [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    const int nb = ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int ib_row = first_row * nb;
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+    device const block_iq3_s * x = (device const block_iq3_s *) src0 + ib_row + offset0;
+    device const float       * y = (device const float       *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float yl[32];
+    float sumf[N_DST]={0.f}, all_sum;
+
+    const int nb32 = nb * (QK_K / 32);
+
+    threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values;
+    {
+        int nval = 8;
+        int pos  = (32*sgitg + tiisg)*nval;
+        for (int i = 0; i < nval; ++i) values[pos + i] = iq3xs_grid[pos + i];
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    const int ix = tiisg;
+
+    device const float * y4 = y + 32 * ix;
+
+    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+        for (int i = 0; i < 32; ++i) {
+            yl[i] = y4[i];
+        }
+
+        const int ibl = ib32 / (QK_K / 32);
+        const int ib  = ib32 % (QK_K / 32);
+
+        device const block_iq3_s * xr = x + ibl;
+        device const uint8_t * qs = xr->qs + 8 * ib;
+        device const uint8_t * qh = xr->qh + ib;
+        device const uint8_t * sc = xr->scales + (ib/2);
+        device const uint8_t * signs = xr->signs + 4 * ib;
+        device const half * dh = &xr->d;
+
+        for (int row = 0; row < N_DST; row++) {
+
+            const float db = dh[0];
+            const float d = db * (0.5f + ((sc[0] >> 4*(ib%2)) & 0xf));
+
+            float2 sum = {0};
+            for (int l = 0; l < 4; ++l) {
+                const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)));
+                const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)));
+                for (int j = 0; j < 4; ++j) {
+                    sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
+                    sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
+                }
+            }
+            sumf[row] += d * (sum[0] + sum[1]);
+
+            dh  += nb*sizeof(block_iq3_s)/2;
+            qs  += nb*sizeof(block_iq3_s);
+            qh  += nb*sizeof(block_iq3_s);
+            sc  += nb*sizeof(block_iq3_s);
+            signs += nb*sizeof(block_iq3_s);
+        }
+
+        y4 += 32 * 32;
+    }
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f;
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_iq3_s_f32")]]
+kernel void kernel_mul_mv_iq3_s_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        threadgroup int8_t * shared_values [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+}
+
 void kernel_mul_mv_iq1_s_f32_impl(
         device const  void * src0,
         device const float * src1,
@@ -4952,6 +5163,31 @@ void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x
     }
 }
 
+template <typename type4x4>
+void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const float d = xb->d;
+    const int ib32 = il/2;
+    il = il%2;
+    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+    device const uint8_t * qs = xb->qs + 8*ib32;
+    device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
+    const uint8_t qh = xb->qh[ib32] >> 4*il;
+    const float dl = d * (0.5f + ((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * 0.5f;
+    constant uint8_t * grid1 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+0] | ((qh << 8) & 256)));
+    constant uint8_t * grid2 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+1] | ((qh << 7) & 256)));
+    for (int i = 0; i < 4; ++i) {
+        reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
+        reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
+    }
+    grid1 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+2] | ((qh << 6) & 256)));
+    grid2 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+3] | ((qh << 5) & 256)));
+    for (int i = 0; i < 4; ++i) {
+        reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
+        reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
+    }
+}
+
 template <typename type4x4>
 void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
     // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
@@ -5525,6 +5761,7 @@ template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows
 template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
 template [[host_name("kernel_get_rows_iq2_xs")]]  kernel get_rows_t kernel_get_rows<block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
 template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
+template [[host_name("kernel_get_rows_iq3_s")]]   kernel get_rows_t kernel_get_rows<block_iq3_s,   QK_NL, dequantize_iq3_s>;
 template [[host_name("kernel_get_rows_iq1_s")]]   kernel get_rows_t kernel_get_rows<block_iq1_s,   QK_NL, dequantize_iq1_s>;
 template [[host_name("kernel_get_rows_iq4_nl")]]  kernel get_rows_t kernel_get_rows<block_iq4_nl,  2, dequantize_iq4_nl>;
 
@@ -5566,6 +5803,7 @@ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
 template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
 template [[host_name("kernel_mul_mm_iq2_xs_f32")]]  kernel mat_mm_t kernel_mul_mm<block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
 template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
+template [[host_name("kernel_mul_mm_iq3_s_f32")]]   kernel mat_mm_t kernel_mul_mm<block_iq3_s,   QK_NL, dequantize_iq3_s>;
 template [[host_name("kernel_mul_mm_iq1_s_f32")]]   kernel mat_mm_t kernel_mul_mm<block_iq1_s,   QK_NL, dequantize_iq1_s>;
 template [[host_name("kernel_mul_mm_iq4_nl_f32")]]  kernel mat_mm_t kernel_mul_mm<block_iq4_nl,  2, dequantize_iq4_nl>;
 
@@ -5619,6 +5857,7 @@ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mu
 template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
 template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
 template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
+template [[host_name("kernel_mul_mm_id_iq3_s_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_s,   QK_NL, dequantize_iq3_s>;
 template [[host_name("kernel_mul_mm_id_iq1_s_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s,   QK_NL, dequantize_iq1_s>;
 template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl,  2, dequantize_iq4_nl>;
 
@@ -6589,6 +6828,71 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
         sgitg);
 }
 
+[[host_name("kernel_mul_mv_id_iq3_s_f32")]]
+kernel void kernel_mul_mv_id_iq3_s_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         float * dst,
+        constant    uint64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant    uint64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        threadgroup int8_t   * shared_values [[threadgroup(0)]],
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_iq3_s_f32_impl(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        dst + bid*ne0,
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        shared_values,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
 [[host_name("kernel_mul_mv_id_iq1_s_f32")]]
 kernel void kernel_mul_mv_id_iq1_s_f32(
         device const    char * ids,
index b15977f53e2f3f279352f8fa97145a36ea266550..5c5f2ce1b9b87b363453985cdc68428ab415c807 100644 (file)
@@ -3505,6 +3505,73 @@ static const uint32_t iq3xxs_grid[256] = {
     0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
 };
 
+static const uint32_t iq3xs_grid[512] = {
+    0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14,
+    0x04040c24, 0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414,
+    0x0404242c, 0x0404243e, 0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24,
+    0x04043e3e, 0x040c0404, 0x040c040c, 0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c,
+    0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c, 0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c,
+    0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e, 0x04140c04, 0x04140c1c, 0x04140c34,
+    0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c, 0x0414243e, 0x04142c0c,
+    0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404, 0x041c1414,
+    0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c,
+    0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404,
+    0x0424241c, 0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434,
+    0x042c1c1c, 0x042c240c, 0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c,
+    0x04340c1c, 0x04341c0c, 0x04342c14, 0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404,
+    0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04, 0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414,
+    0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c, 0x0c040c3e, 0x0c041404, 0x0c041414,
+    0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c, 0x0c043e14, 0x0c0c0404,
+    0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04, 0x0c0c1c1c,
+    0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c,
+    0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404,
+    0x0c143e14, 0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e,
+    0x0c1c1c04, 0x0c1c1c24, 0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14,
+    0x0c240c24, 0x0c241c0c, 0x0c241c1c, 0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c,
+    0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04, 0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424,
+    0x0c34143e, 0x0c342424, 0x0c342434, 0x0c3e040c, 0x0c3e041c, 0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c,
+    0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c, 0x1404041c, 0x1404042c,
+    0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c, 0x1404143e,
+    0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e,
+    0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e,
+    0x140c1414, 0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424,
+    0x1414043e, 0x1414140c, 0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e,
+    0x14143e0c, 0x14143e24, 0x141c0404, 0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424,
+    0x141c1c0c, 0x141c1c1c, 0x141c2414, 0x141c2c04, 0x141c3434, 0x1424040c, 0x1424043e, 0x14241404,
+    0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14, 0x14243e2c, 0x142c0424, 0x142c0c0c,
+    0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404, 0x14340414, 0x1434043e,
+    0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04, 0x143e241c,
+    0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c,
+    0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c,
+    0x1c0c040c, 0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404,
+    0x1c0c3e14, 0x1c0c3e34, 0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04,
+    0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24, 0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c,
+    0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c, 0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414,
+    0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c, 0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c,
+    0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c, 0x1c3e040c, 0x1c3e041c,
+    0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404, 0x24041424,
+    0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c,
+    0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c,
+    0x2414041c, 0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414,
+    0x24143e04, 0x241c0424, 0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c,
+    0x24240404, 0x24240414, 0x24241424, 0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e,
+    0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24, 0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04,
+    0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24, 0x2c041414, 0x2c042404, 0x2c042424,
+    0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c, 0x2c0c042c, 0x2c0c0c14,
+    0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04, 0x2c141c34,
+    0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c,
+    0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434,
+    0x2c2c2c0c, 0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c,
+    0x34040c2c, 0x34041c0c, 0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424,
+    0x34140c14, 0x34141c24, 0x34142414, 0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24,
+    0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c, 0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24,
+    0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c, 0x3e040404, 0x3e040424, 0x3e04043e,
+    0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414, 0x3e0c0414, 0x3e0c0c0c,
+    0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34, 0x3e14140c,
+    0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c,
+    0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404,
+};
+
 #define NGRID_IQ2XXS 512
 static const  uint64_t iq1s_grid[NGRID_IQ2XXS] = {
     0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
@@ -3736,6 +3803,49 @@ void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y
     }
 }
 
+// ====================== 3.3125 bpw (de)-quantization
+
+void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+        const uint8_t * qs = x[i].qs;
+        const uint8_t * qh = x[i].qh;
+        const uint8_t * signs = x[i].signs;
+
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            const float db1 = d * (0.5f + (x[i].scales[ib32/2] & 0xf)) * 0.5f;
+            const float db2 = d * (0.5f + (x[i].scales[ib32/2] >>  4)) * 0.5f;
+            for (int l = 0; l < 4; ++l) {
+                const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)));
+                const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)));
+                for (int j = 0; j < 4; ++j) {
+                    y[j+0] = db1 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f);
+                    y[j+4] = db1 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f);
+                }
+                y += 8;
+            }
+            qs += 8;
+            signs += 4;
+            for (int l = 0; l < 4; ++l) {
+                const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)));
+                const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)));
+                for (int j = 0; j < 4; ++j) {
+                    y[j+0] = db2 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f);
+                    y[j+4] = db2 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f);
+                }
+                y += 8;
+            }
+            qh += 2;
+            qs += 8;
+            signs += 4;
+        }
+    }
+}
+
 // ====================== 1.5625 bpw (de)-quantization
 
 void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, int k) {
@@ -8806,6 +8916,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
 #endif
 
+#if defined (__AVX2__) || defined (__ARM_NEON)
 static const int8_t keven_signs_q2xs[1024] = {
      1,  1,  1,  1,  1,  1,  1,  1, -1,  1,  1,  1,  1,  1,  1, -1,  1, -1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1,  1,  1,
      1,  1, -1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1,  1,  1,  1, -1, -1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1, -1,
@@ -8840,6 +8951,7 @@ static const int8_t keven_signs_q2xs[1024] = {
      1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1, -1,  1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1,  1,
      1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1,  1,  1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1,
 };
+#endif
 
 void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
@@ -9327,6 +9439,202 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void
 #endif
 }
 
+void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_iq3_s * restrict x = vx;
+    const block_q8_K  * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#if defined(__ARM_NEON)
+
+   static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+                                       0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
+   };
+
+    static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,};
+
+    const uint8x16x2_t mask1 = vld1q_u8_x2(k_mask1);
+    const uint8x16_t   mask2 = vld1q_u8(k_mask2);
+
+    uint8x16x2_t vs;
+    ggml_int8x16x4_t q3s;
+    ggml_int8x16x4_t q8b;
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint8_t * restrict qs = x[i].qs;
+        const uint8_t * restrict qh = x[i].qh;
+        const uint16_t * restrict signs = (const uint16_t *)x[i].signs;
+        const int8_t   * restrict q8 = y[i].qs;
+        int sumi1 = 0, sumi2 = 0;
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
+            const uint32x4_t aux32x4_0 = {iq3xs_grid[qs[ 0] | ((qh[ib32+0] << 8) & 256)], iq3xs_grid[qs[ 1] | ((qh[ib32+0] << 7) & 256)],
+                                          iq3xs_grid[qs[ 2] | ((qh[ib32+0] << 6) & 256)], iq3xs_grid[qs[ 3] | ((qh[ib32+0] << 5) & 256)]};
+            const uint32x4_t aux32x4_1 = {iq3xs_grid[qs[ 4] | ((qh[ib32+0] << 4) & 256)], iq3xs_grid[qs[ 5] | ((qh[ib32+0] << 3) & 256)],
+                                          iq3xs_grid[qs[ 6] | ((qh[ib32+0] << 2) & 256)], iq3xs_grid[qs[ 7] | ((qh[ib32+0] << 1) & 256)]};
+            const uint32x4_t aux32x4_2 = {iq3xs_grid[qs[ 8] | ((qh[ib32+1] << 8) & 256)], iq3xs_grid[qs[ 9] | ((qh[ib32+1] << 7) & 256)],
+                                          iq3xs_grid[qs[10] | ((qh[ib32+1] << 6) & 256)], iq3xs_grid[qs[11] | ((qh[ib32+1] << 5) & 256)]};
+            const uint32x4_t aux32x4_3 = {iq3xs_grid[qs[12] | ((qh[ib32+1] << 4) & 256)], iq3xs_grid[qs[13] | ((qh[ib32+1] << 3) & 256)],
+                                          iq3xs_grid[qs[14] | ((qh[ib32+1] << 2) & 256)], iq3xs_grid[qs[15] | ((qh[ib32+1] << 1) & 256)]};
+            qs += 16;
+
+            vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | (signs[1] << 16)));
+            vs.val[1] = vandq_u8(vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
+            vs.val[0] = vandq_u8(vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
+            vs.val[0] = vceqq_u8(vs.val[0], mask2);
+            vs.val[1] = vceqq_u8(vs.val[1], mask2);
+
+            q3s.val[0] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[0], vreinterpretq_u8_u32(aux32x4_0))), vreinterpretq_s8_u8(vs.val[0]));
+            q3s.val[1] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[1], vreinterpretq_u8_u32(aux32x4_1))), vreinterpretq_s8_u8(vs.val[1]));
+
+            vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | (signs[3] << 16)));
+            vs.val[1] = vandq_u8(vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
+            vs.val[0] = vandq_u8(vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
+            vs.val[0] = vceqq_u8(vs.val[0], mask2);
+            vs.val[1] = vceqq_u8(vs.val[1], mask2);
+
+            signs += 4;
+
+            q3s.val[2] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[0], vreinterpretq_u8_u32(aux32x4_2))), vreinterpretq_s8_u8(vs.val[0]));
+            q3s.val[3] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[1], vreinterpretq_u8_u32(aux32x4_3))), vreinterpretq_s8_u8(vs.val[1]));
+
+            const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]);
+            const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]);
+            sumi1 += vaddvq_s32(p1) * (1 + 2*(x[i].scales[ib32/2] & 0xf));
+            sumi2 += vaddvq_s32(p2) * (1 + 2*(x[i].scales[ib32/2] >>  4));
+        }
+        sumf += d*(sumi1 + sumi2);
+    }
+    *s = 0.25f * sumf;
+
+#elif defined(__AVX2__)
+
+   static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+                                       0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
+   };
+
+    static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+                                        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+    };
+
+    const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1);
+    const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2);
+
+    __m256 accumf = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint8_t * restrict qs = x[i].qs;
+        const uint8_t * restrict qh = x[i].qh;
+        const uint16_t * restrict signs = (const uint16_t *)x[i].signs;
+        const int8_t  * restrict q8 = y[i].qs;
+        __m256i sumi1 = _mm256_setzero_si256();
+        __m256i sumi2 = _mm256_setzero_si256();
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q2_1 = _mm256_set_epi32(iq3xs_grid[qs[7] | ((qh[ib32+0] << 1) & 256)],
+                                                  iq3xs_grid[qs[6] | ((qh[ib32+0] << 2) & 256)],
+                                                  iq3xs_grid[qs[5] | ((qh[ib32+0] << 3) & 256)],
+                                                  iq3xs_grid[qs[4] | ((qh[ib32+0] << 4) & 256)],
+                                                  iq3xs_grid[qs[3] | ((qh[ib32+0] << 5) & 256)],
+                                                  iq3xs_grid[qs[2] | ((qh[ib32+0] << 6) & 256)],
+                                                  iq3xs_grid[qs[1] | ((qh[ib32+0] << 7) & 256)],
+                                                  iq3xs_grid[qs[0] | ((qh[ib32+0] << 8) & 256)]);
+            qs += 8;
+            const __m256i q2_2 = _mm256_set_epi32(iq3xs_grid[qs[7] | ((qh[ib32+1] << 1) & 256)],
+                                                  iq3xs_grid[qs[6] | ((qh[ib32+1] << 2) & 256)],
+                                                  iq3xs_grid[qs[5] | ((qh[ib32+1] << 3) & 256)],
+                                                  iq3xs_grid[qs[4] | ((qh[ib32+1] << 4) & 256)],
+                                                  iq3xs_grid[qs[3] | ((qh[ib32+1] << 5) & 256)],
+                                                  iq3xs_grid[qs[2] | ((qh[ib32+1] << 6) & 256)],
+                                                  iq3xs_grid[qs[1] | ((qh[ib32+1] << 7) & 256)],
+                                                  iq3xs_grid[qs[0] | ((qh[ib32+1] << 8) & 256)]);
+            qs += 8;
+
+            __m256i aux256 = _mm256_set1_epi32(signs[0] | (signs[1] << 16));
+            aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
+            const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2);
+            const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1);
+
+            aux256 = _mm256_set1_epi32(signs[2] | (signs[3] << 16));
+            aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
+            const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2);
+            const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2);
+
+            signs += 4;
+
+            const __m256i dot1  = _mm256_maddubs_epi16(q2_1, q8s_1);
+            const __m256i dot2  = _mm256_maddubs_epi16(q2_2, q8s_2);
+            const uint16_t ls1 = x[i].scales[ib32/2] & 0xf;
+            const uint16_t ls2 = x[i].scales[ib32/2] >>  4;
+            const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
+            const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
+            sumi1 = _mm256_add_epi32(sumi1, p1);
+            sumi2 = _mm256_add_epi32(sumi2, p2);
+        }
+
+        accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
+
+    }
+
+    *s = 0.25f * hsum_float_8(accumf);
+
+#else
+
+    float sumf = 0.f;
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint8_t * restrict qs = x[i].qs;
+        const uint8_t * restrict qh = x[i].qh;
+        const uint8_t * restrict signs = x[i].signs;
+        const int8_t  * restrict q8 = y[i].qs;
+        int32_t bsum = 0;
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            const uint32_t ls1 = 2*(x[i].scales[ib32/2] & 0xf) + 1;
+            const uint32_t ls2 = 2*(x[i].scales[ib32/2] >>  4) + 1;
+            int32_t sumi = 0;
+            for (int l = 0; l < 4; ++l) {
+                const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256)));
+                const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256)));
+                for (int j = 0; j < 4; ++j) {
+                    sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);
+                    sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);
+                }
+                q8 += 8;
+            }
+            qs += 8;
+            signs += 4;
+            bsum += sumi * ls1;
+            sumi = 0;
+            for (int l = 0; l < 4; ++l) {
+                const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256)));
+                const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256)));
+                for (int j = 0; j < 4; ++j) {
+                    sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);
+                    sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);
+                }
+                q8 += 8;
+            }
+            qs += 8;
+            signs += 4;
+            bsum += sumi * ls2;
+        }
+        sumf += d * bsum;
+    }
+    *s = 0.25f * sumf;
+#endif
+}
+
+
 #ifdef __AVX2__
 static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
     const __m256i ax = _mm256_sign_epi8(x, x);
@@ -9523,6 +9831,7 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
     float sumf = 0;
 
     for (int ib = 0; ib < nb; ib += 2) {
+
         q4bits.val[0] = vld1q_u8(x[ib+0].qs);
         q4bits.val[1] = vld1q_u8(x[ib+1].qs);
         q8b.val[0]    = vld1q_s8(y[ib+0].qs);
@@ -10239,14 +10548,15 @@ typedef struct {
     uint16_t * neighbours;
 } iq3_entry_t;
 
-static iq3_entry_t iq3_data[1] = {
+static iq3_entry_t iq3_data[2] = {
+    {NULL, NULL, NULL},
     {NULL, NULL, NULL},
 };
 
 static inline int iq3_data_index(int grid_size) {
     (void)grid_size;
-    GGML_ASSERT(grid_size == 256);
-    return 0;
+    GGML_ASSERT(grid_size == 256 || grid_size == 512);
+    return grid_size == 256 ? 0 : 1;
 }
 
 static int iq3_compare_func(const void * left, const void * right) {
@@ -10278,9 +10588,44 @@ void iq3xs_init_impl(int grid_size) {
          3185,  3215,  3252,  3288,  3294,  3364,  3397,  3434,  3483,  3523,  3537,  3587,  3589,  3591,  3592,  3610,
          3626,  3670,  3680,  3722,  3749,  3754,  3776,  3789,  3803,  3824,  3857,  3873,  3904,  3906,  3924,  3992,
     };
+    static const uint16_t kgrid_512[512] = {
+            0,     1,     2,     5,     7,     8,     9,    10,    12,    14,    16,    17,    21,    27,    32,    34,
+           37,    39,    41,    43,    48,    50,    57,    60,    63,    64,    65,    66,    68,    72,    73,    77,
+           80,    83,    87,    89,    93,   100,   113,   117,   122,   128,   129,   133,   135,   136,   139,   142,
+          145,   149,   152,   156,   162,   165,   167,   169,   171,   184,   187,   195,   201,   205,   208,   210,
+          217,   219,   222,   228,   232,   234,   247,   249,   253,   256,   267,   271,   273,   276,   282,   288,
+          291,   297,   312,   322,   324,   336,   338,   342,   347,   353,   357,   359,   374,   379,   390,   393,
+          395,   409,   426,   441,   448,   450,   452,   464,   466,   470,   475,   488,   492,   512,   513,   514,
+          516,   520,   521,   523,   525,   527,   528,   530,   537,   540,   542,   556,   558,   561,   570,   576,
+          577,   579,   582,   584,   588,   593,   600,   603,   609,   616,   618,   632,   638,   640,   650,   653,
+          655,   656,   660,   666,   672,   675,   685,   688,   698,   705,   708,   711,   712,   715,   721,   727,
+          728,   732,   737,   754,   760,   771,   773,   778,   780,   793,   795,   802,   806,   808,   812,   833,
+          840,   843,   849,   856,   858,   873,   912,   916,   919,   932,   934,   961,   963,   968,   970,   977,
+          989,   993,  1010,  1016,  1024,  1025,  1027,  1029,  1031,  1032,  1034,  1036,  1038,  1041,  1043,  1047,
+         1048,  1050,  1057,  1059,  1061,  1064,  1066,  1079,  1080,  1083,  1085,  1088,  1090,  1096,  1099,  1103,
+         1106,  1109,  1113,  1116,  1122,  1129,  1153,  1156,  1159,  1169,  1171,  1176,  1183,  1185,  1195,  1199,
+         1209,  1212,  1216,  1218,  1221,  1225,  1234,  1236,  1241,  1243,  1250,  1256,  1270,  1281,  1287,  1296,
+         1299,  1306,  1309,  1313,  1338,  1341,  1348,  1353,  1362,  1375,  1376,  1387,  1400,  1408,  1410,  1415,
+         1425,  1453,  1457,  1477,  1481,  1494,  1496,  1507,  1512,  1538,  1545,  1547,  1549,  1551,  1554,  1561,
+         1563,  1565,  1570,  1572,  1575,  1577,  1587,  1593,  1601,  1603,  1605,  1612,  1617,  1619,  1632,  1648,
+         1658,  1662,  1664,  1674,  1680,  1690,  1692,  1704,  1729,  1736,  1740,  1745,  1747,  1751,  1752,  1761,
+         1763,  1767,  1773,  1787,  1795,  1801,  1806,  1810,  1817,  1834,  1840,  1844,  1857,  1864,  1866,  1877,
+         1882,  1892,  1902,  1915,  1934,  1953,  1985,  1987,  2000,  2002,  2013,  2048,  2052,  2058,  2064,  2068,
+         2071,  2074,  2081,  2088,  2104,  2114,  2119,  2121,  2123,  2130,  2136,  2141,  2147,  2153,  2157,  2177,
+         2179,  2184,  2189,  2193,  2203,  2208,  2223,  2226,  2232,  2244,  2249,  2251,  2256,  2258,  2265,  2269,
+         2304,  2306,  2324,  2335,  2336,  2361,  2373,  2375,  2385,  2418,  2443,  2460,  2480,  2504,  2509,  2520,
+         2531,  2537,  2562,  2568,  2572,  2578,  2592,  2596,  2599,  2602,  2614,  2620,  2625,  2627,  2629,  2634,
+         2641,  2650,  2682,  2688,  2697,  2707,  2712,  2718,  2731,  2754,  2759,  2760,  2775,  2788,  2793,  2805,
+         2811,  2817,  2820,  2832,  2842,  2854,  2890,  2902,  2921,  2923,  2978,  3010,  3012,  3026,  3081,  3083,
+         3085,  3097,  3099,  3120,  3136,  3152,  3159,  3188,  3210,  3228,  3234,  3245,  3250,  3256,  3264,  3276,
+         3281,  3296,  3349,  3363,  3378,  3392,  3395,  3420,  3440,  3461,  3488,  3529,  3531,  3584,  3588,  3591,
+         3600,  3602,  3614,  3616,  3628,  3634,  3650,  3657,  3668,  3683,  3685,  3713,  3716,  3720,  3726,  3729,
+         3736,  3753,  3778,  3802,  3805,  3819,  3841,  3845,  3851,  3856,  3880,  3922,  3938,  3970,  3993,  4032,
+    };
+
     const int kmap_size = 4096;
-    const int nwant = 2;
-    const uint16_t * kgrid = kgrid_256;
+    const int nwant = grid_size == 256 ? 2 : 3;
+    const uint16_t * kgrid = grid_size == 256 ? kgrid_256 : kgrid_512;
     uint32_t * kgrid_q3xs;
     int      * kmap_q3xs;
     uint16_t * kneighbors_q3xs;
@@ -10377,7 +10722,7 @@ void iq3xs_init_impl(int grid_size) {
 }
 
 void iq3xs_free_impl(int grid_size) {
-    GGML_ASSERT(grid_size == 256);
+    GGML_ASSERT(grid_size == 256 || grid_size == 512);
     const int gindex = iq3_data_index(grid_size);
     if (iq3_data[gindex].grid) {
         free(iq3_data[gindex].grid);       iq3_data[gindex].grid = NULL;
@@ -10410,9 +10755,10 @@ static int iq3_find_best_neighbour(const uint16_t * restrict neighbours, const u
     return grid_index;
 }
 
-static void quantize_row_iq3_xxs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
+static void quantize_row_iq3_xxs_impl(int grid_size, const float * restrict x, void * restrict vy, int n,
+        const float * restrict quant_weights) {
 
-    const int gindex = iq3_data_index(256);
+    const int gindex = iq3_data_index(grid_size);
 
     const uint32_t * kgrid_q3xs      = iq3_data[gindex].grid;
     const int      * kmap_q3xs       = iq3_data[gindex].map;
@@ -10426,9 +10772,23 @@ static void quantize_row_iq3_xxs_impl(const float * restrict x, void * restrict
 
     const int kMaxQ = 8;
 
-    const int nbl = n/256;
+    const int nbl = n/QK_K;
 
-    block_iq3_xxs * y = vy;
+    ggml_fp16_t * dh;
+    uint8_t * qs;
+    int block_size;
+    if (grid_size == 256) {
+        block_iq3_xxs * y = vy;
+        dh = &y->d;
+        qs = y->qs;
+        block_size = sizeof(block_iq3_xxs);
+    } else {
+        block_iq3_s * y = vy;
+        dh = &y->d;
+        qs = y->qs;
+        block_size = sizeof(block_iq3_s);
+    }
+    int quant_size = block_size - sizeof(ggml_fp16_t);
 
     float scales[QK_K/32];
     float weight[32];
@@ -10439,20 +10799,21 @@ static void quantize_row_iq3_xxs_impl(const float * restrict x, void * restrict
     bool   is_on_grid[8];
     bool   is_on_grid_aux[8];
     uint8_t block_signs[8];
-    uint8_t q3[3*(QK_K/8)];
+    uint8_t q3[3*(QK_K/8)+QK_K/32];
     uint32_t * scales_and_signs = (uint32_t *)(q3 + QK_K/4);
+    uint8_t  * qh = q3 + 3*(QK_K/8);
 
     for (int ibl = 0; ibl < nbl; ++ibl) {
 
-        y[ibl].d = GGML_FP32_TO_FP16(0.f);
-        memset(q3, 0, 3*QK_K/8);
+        dh[0] = GGML_FP32_TO_FP16(0.f);
+        memset(q3, 0, 3*QK_K/8+QK_K/32);
 
         float max_scale = 0;
 
         const float * xbl = x + QK_K*ibl;
         float sumx2 = 0;
         for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
-        float sigma2 = sumx2/QK_K;
+        float sigma2 = 2*sumx2/QK_K;
 
         for (int ib = 0; ib < QK_K/32; ++ib) {
             const float * xb = xbl + 32*ib;
@@ -10570,7 +10931,13 @@ static void quantize_row_iq3_xxs_impl(const float * restrict x, void * restrict
                     printf("\n");
                     GGML_ASSERT(false);
                 }
-                q3[8*ib+k] = grid_index;
+                if (grid_size == 256) {
+                    q3[8*ib+k] = grid_index;
+                } else {
+                    q3[8*ib+k] = grid_index & 255;
+                    qh[ib] |= ((grid_index >> 8) << k);
+                }
+
             }
             scales_and_signs[ib] = block_signs[0] | (block_signs[1] << 7) | (block_signs[2] << 14) | (block_signs[3] << 21);
             GGML_ASSERT(scale >= 0);
@@ -10579,63 +10946,25 @@ static void quantize_row_iq3_xxs_impl(const float * restrict x, void * restrict
         }
 
         if (!max_scale) {
-            memset(y[ibl].qs, 0, 3*QK_K/8);
+            memset(qs, 0, quant_size);
+            dh += block_size/sizeof(ggml_fp16_t);
+            qs += block_size;
             continue;
         }
 
         float d = max_scale/31;
-        y[ibl].d = GGML_FP32_TO_FP16(d);
+        dh[0] = GGML_FP32_TO_FP16(d * 1.0125f);  // small improvement via this fudge factor
         float id = 1/d;
-        float sumqx = 0, sumq2 = 0;
         for (int ib = 0; ib < QK_K/32; ++ib) {
             int l = nearest_int(0.5f*(id*scales[ib]-1));
             l = MAX(0, MIN(15, l));
             scales_and_signs[ib] |= ((uint32_t)l << 28);
-            if (false) {
-                const float * xb = xbl + 32*ib;
-                if (quant_weights) {
-                    const float * qw = quant_weights + QK_K*ibl + 32*ib;
-                    for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
-                } else {
-                    for (int i = 0; i < 32; ++i) weight[i] = xb[i]*xb[i];
-                }
-                const float db = 0.25f * d * (1 + 2*l);
-                for (int k = 0; k < 8; ++k) {
-                    const int8_t * signs = keven_signs_q2xs + 8*((scales_and_signs[ib] >> 7*(k/2)) & 127) + 4*(k%2);
-                    const float * xk = xb + 4*k;
-                    const float * wk = weight + 4*k;
-                    //const uint8_t * grid = (const uint8_t *)(kgrid_q3xs + q3[8*ib+k]);
-                    const uint8_t * grid = (const uint8_t *)(iq3xxs_grid + q3[8*ib+k]);
-                    float best_mse = 0; int best_index = q3[8*ib+k];
-                    for (int j = 0; j < 4; ++j) {
-                        float diff = db * grid[j] * signs[j] - xk[j];
-                        best_mse += wk[j] * diff * diff;
-                    }
-                    for (int idx = 0; idx < 256; ++idx) {
-                        //grid = (const uint8_t *)(kgrid_q3xs + idx);
-                        grid = (const uint8_t *)(iq3xxs_grid + idx);
-                        float mse = 0;
-                        for (int j = 0; j < 4; ++j) {
-                            float diff = db * grid[j] * signs[j] - xk[j];
-                            mse += wk[j] * diff * diff;
-                        }
-                        if (mse < best_mse) {
-                            best_mse = mse; best_index = idx;
-                        }
-                    }
-                    q3[8*ib+k] = best_index;
-                    //grid = (const uint8_t *)(kgrid_q3xs + best_index);
-                    grid = (const uint8_t *)(iq3xxs_grid + best_index);
-                    for (int j = 0; j < 4; ++j) {
-                        float q = db * grid[j] * signs[j];
-                        sumqx += wk[j] * q * xk[j];
-                        sumq2 += wk[j] * q * q;
-                    }
-                }
-                if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(d*sumqx/sumq2);
-            }
         }
-        memcpy(y[ibl].qs, q3, 3*QK_K/8);
+        memcpy(qs, q3, quant_size);
+
+        dh += block_size/sizeof(ggml_fp16_t);
+        qs += block_size;
+
     }
 }
 
@@ -10645,7 +10974,7 @@ size_t quantize_iq3_xxs(const float * src, void * dst, int nrow, int n_per_row,
     int nblock = n_per_row/QK_K;
     char * qrow = (char *)dst;
     for (int row = 0; row < nrow; ++row) {
-        quantize_row_iq3_xxs_impl(src, qrow, n_per_row, quant_weights);
+        quantize_row_iq3_xxs_impl(256, src, qrow, n_per_row, quant_weights);
         src += n_per_row;
         qrow += nblock*sizeof(block_iq3_xxs);
     }
@@ -10660,9 +10989,226 @@ void quantize_row_iq3_xxs(const float * restrict x, void * restrict vy, int k) {
 
 void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * restrict y, int k) {
     assert(k % QK_K == 0);
-    quantize_row_iq3_xxs_impl(x, y, k, NULL);
+    quantize_row_iq3_xxs_impl(256, x, y, k, NULL);
+}
+
+static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, void * restrict vy, int n,
+        const float * restrict quant_weights,
+        float   * scales,
+        float   * weight,
+        float   * xval,
+        int8_t  * L,
+        int8_t  * Laux,
+        float   * waux,
+        bool    * is_on_grid,
+        bool    * is_on_grid_aux,
+        uint8_t * block_signs) {
+
+    const int gindex = iq3_data_index(512);
+
+    const uint32_t * kgrid_q3xs      = iq3_data[gindex].grid;
+    const int      * kmap_q3xs       = iq3_data[gindex].map;
+    const uint16_t * kneighbors_q3xs = iq3_data[gindex].neighbours;
+
+    //GGML_ASSERT(quant_weights   && "missing quantization weights");
+    GGML_ASSERT(kgrid_q3xs      && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(kmap_q3xs       && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(kneighbors_q3xs && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(n%QK_K == 0);
+
+    const int kMaxQ = 8;
+
+    const int nbl = n/QK_K;
+
+    block_iq3_s * y = vy;
+
+    const int bs4 = block_size/4;
+    const int bs8 = block_size/8;
+
+    for (int ibl = 0; ibl < nbl; ++ibl) {
+
+        memset(&y[ibl], 0, sizeof(block_iq3_s));
+        y[ibl].d = GGML_FP32_TO_FP16(0.f);
+
+        uint8_t * qs = y[ibl].qs;
+        uint8_t * qh = y[ibl].qh;
+        uint8_t * signs = y[ibl].signs;
+
+        float max_scale = 0;
+
+        const float * xbl = x + QK_K*ibl;
+        float sumx2 = 0;
+        for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
+        float sigma2 = 2*sumx2/QK_K;
+
+        for (int ib = 0; ib < QK_K/block_size; ++ib) {
+            const float * xb = xbl + block_size*ib;
+            if (quant_weights) {
+                const float * qw = quant_weights + QK_K*ibl + block_size*ib;
+                for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+            } else {
+                for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i];
+            }
+            for (int i = 0; i < block_size; ++i) waux[i] = sqrtf(weight[i]);
+            for (int k = 0; k < bs8; ++k) {
+                uint8_t s = 0;
+                for (int i = 0; i < 8; ++i) {
+                    if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];
+                    else {
+                        xval[8*k + i] = -xb[8*k + i]; s |= (1 << i);
+                    }
+                }
+                block_signs[k] = s;
+            }
+            float max = xval[0];
+            for (int i = 1; i < block_size; ++i) max = MAX(max, xval[i]);
+            if (!max) {
+                scales[ib] = 0;
+                continue;
+            }
+            float best = 0;
+            float scale = max/(2*kMaxQ-1);
+            for (int is = -15; is <= 15; ++is) {
+                float id = (2*kMaxQ-1+is*0.2f)/max;
+                float this_scale = 1/id;
+                for (int k = 0; k < bs4; ++k) {
+                    for (int i = 0; i < 4; ++i) {
+                        int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
+                        Laux[4*k+i] = MAX(0, MIN(kMaxQ-1, l));
+                    }
+                    uint16_t u = 0;
+                    for (int i = 0; i < 4; ++i) u |= (Laux[4*k+i] << 3*i);
+                    int grid_index = kmap_q3xs[u];
+                    is_on_grid_aux[k] = true;
+                    if (grid_index < 0) {
+                        is_on_grid_aux[k] = false;
+                        const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;
+                        grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, this_scale, Laux + 4*k);
+                    }
+                }
+                float sumqx = 0, sumq2 = 0;
+                for (int i = 0; i < block_size; ++i) {
+                    float w = weight[i];
+                    float q = 2*Laux[i] + 1;
+                    sumqx += w*xval[i]*q;
+                    sumq2 += w*q*q;
+                }
+                if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
+                    scale = sumqx/sumq2; best = scale*sumqx;
+                    for (int i = 0; i < block_size; ++i) L[i] = Laux[i];
+                    for (int k = 0; k < bs4; ++k) is_on_grid[k] = is_on_grid_aux[k];
+                }
+            }
+            int n_not_ongrid = 0;
+            for (int k = 0; k < bs4; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
+            if (n_not_ongrid > 0 && scale > 0) {
+                float id = 1/scale;
+                for (int k = 0; k < bs4; ++k) {
+                    if (is_on_grid[k]) continue;
+                    uint16_t u = 0;
+                    for (int i = 0; i < 4; ++i) {
+                        int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
+                        l = MAX(0, MIN(kMaxQ-1, l));
+                        u |= (l << 3*i);
+                    }
+                    int grid_index = kmap_q3xs[u];
+                    if (grid_index < 0) {
+                        const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;
+                        grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, scale, L + 4*k);
+                    }
+                    const int8_t * pg = (const int8_t *)(kgrid_q3xs + grid_index);
+                    for (int i = 0; i < 4; ++i) L[4*k+i] = (pg[i] - 1)/2;
+                }
+                float sumqx = 0, sumq2 = 0;
+                for (int i = 0; i < block_size; ++i) {
+                    float w = weight[i];
+                    float q = 2*L[i] + 1;
+                    sumqx += w*xval[i]*q;
+                    sumq2 += w*q*q;
+                }
+                if (sumq2 > 0) scale = sumqx/sumq2;
+            }
+            if (scale < 0) {
+                // This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale)
+                // and correspondingly flip quant signs.
+                scale = -scale;
+                for (int k = 0; k < bs8; ++k) block_signs[k] = ~block_signs[k];
+            }
+            for (int k = 0; k < bs4; ++k) {
+                uint16_t u = 0;
+                for (int i = 0; i < 4; ++i) u |= (L[4*k+i] << 3*i);
+                int grid_index = kmap_q3xs[u];
+                if (grid_index < 0) {
+                    printf("Oops: found point %u not on grid:", u);
+                    for (int i = 0; i < 4; ++i) printf(" %d", L[4*k+i]);
+                    printf("\n");
+                    GGML_ASSERT(false);
+                }
+                qs[k] = grid_index & 255;
+                qh[(ib*bs4+k)/8] |= ((grid_index >> 8) << ((ib*bs4+k)%8));
+            }
+            qs += bs4;
+            for (int k = 0; k < bs8; ++k) signs[k] = block_signs[k];
+            signs += bs8;
+            GGML_ASSERT(scale >= 0);
+            scales[ib] = scale;
+            max_scale = MAX(max_scale, scale);
+        }
+
+        if (!max_scale) {
+            continue;
+        }
+
+        float d = max_scale/31;
+        y[ibl].d = GGML_FP32_TO_FP16(d);
+        float id = 1/d;
+        for (int ib = 0; ib < QK_K/block_size; ib += 2) {
+            int l1 = nearest_int(0.5f*(id*scales[ib+0]-1));
+            l1 = MAX(0, MIN(15, l1));
+            int l2 = nearest_int(0.5f*(id*scales[ib+1]-1));
+            l2 = MAX(0, MIN(15, l2));
+            y[ibl].scales[ib/2] = l1 | (l2 << 4);
+        }
+
+    }
+}
+
+#define IQ3S_BLOCK_SIZE 32
+size_t quantize_iq3_s(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
+    (void)hist;
+    GGML_ASSERT(n_per_row%QK_K == 0);
+    int nblock = n_per_row/QK_K;
+    float scales[QK_K/IQ3S_BLOCK_SIZE];
+    float weight[IQ3S_BLOCK_SIZE];
+    float xval[IQ3S_BLOCK_SIZE];
+    int8_t L[IQ3S_BLOCK_SIZE];
+    int8_t Laux[IQ3S_BLOCK_SIZE];
+    float  waux[IQ3S_BLOCK_SIZE];
+    bool   is_on_grid[IQ3S_BLOCK_SIZE/4];
+    bool   is_on_grid_aux[IQ3S_BLOCK_SIZE/4];
+    uint8_t block_signs[IQ3S_BLOCK_SIZE/8];
+    char * qrow = (char *)dst;
+    for (int row = 0; row < nrow; ++row) {
+        quantize_row_iq3_s_impl(IQ3S_BLOCK_SIZE, src, qrow, n_per_row, quant_weights,
+                scales, weight, xval, L, Laux, waux, is_on_grid, is_on_grid_aux, block_signs);
+        src += n_per_row;
+        qrow += nblock*sizeof(block_iq3_s);
+    }
+    return nrow * nblock * sizeof(block_iq3_s);
+}
+
+void quantize_row_iq3_s(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK_K == 0);
+    block_iq3_s * restrict y = vy;
+    quantize_row_iq3_s_reference(x, y, k);
 }
 
+void quantize_row_iq3_s_reference(const float * restrict x, block_iq3_s * restrict y, int k) {
+    assert(k % QK_K == 0);
+    quantize_iq3_s(x, y, 1, k, NULL, NULL);
+}
+
+
 // =================================== 1.5 bpw ===================================================
 
 static int iq1_find_best_neighbour(const uint16_t * restrict neighbours, const uint64_t * restrict grid,
index 113623b62938a48e9170944894f6279d933c2f2f..303b0b6f9552eb077700bf6a0f078d857561e09e 100644 (file)
@@ -191,6 +191,21 @@ typedef struct {
 } block_iq3_xxs;
 static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
 
+// 3.4375 bpw
+#if QK_K == 64
+#define IQ3S_N_SCALE 2
+#else
+#define IQ3S_N_SCALE QK_K/64
+#endif
+typedef struct {
+    ggml_fp16_t d;
+    uint8_t qs[QK_K/4];
+    uint8_t qh[QK_K/32];
+    uint8_t signs[QK_K/8];
+    uint8_t scales[IQ3S_N_SCALE];
+} block_iq3_s;
+static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 13*(QK_K/32) + IQ3S_N_SCALE, "wrong iq3_s block size/padding");
+
 typedef struct {
     ggml_fp16_t d;
     uint8_t qs[QK_K/8];
@@ -226,6 +241,7 @@ void quantize_row_q6_K_reference(const float * GGML_RESTRICT x, block_q6_K * GGM
 void quantize_row_q8_K_reference(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int k);
 void quantize_row_iq3_xxs_reference(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int k);
 void quantize_row_iq4_nl_reference (const float * GGML_RESTRICT x, block_iq4_nl  * GGML_RESTRICT y, int k);
+void quantize_row_iq3_s_reference  (const float * GGML_RESTRICT x, block_iq3_s   * GGML_RESTRICT y, int k);
 
 void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
 void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
@@ -242,6 +258,7 @@ void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in
 void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
 void quantize_row_iq3_xxs(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
 void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
+void quantize_row_iq3_s  (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
 
 // Dequantization
 void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
@@ -262,6 +279,7 @@ void dequantize_row_iq2_xs (const block_iq2_xs  * GGML_RESTRICT x, float * GGML_
 void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
 void dequantize_row_iq1_s  (const block_iq1_s   * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
 void dequantize_row_iq4_nl (const block_iq4_nl  * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
+void dequantize_row_iq3_s  (const block_iq3_s   * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
 
 // Dot product
 void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
@@ -280,6 +298,7 @@ void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
 void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_iq1_s_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq3_s_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 
 //
 // Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
@@ -289,6 +308,7 @@ size_t quantize_iq2_xs (const float * src, void * dst, int nrows, int n_per_row,
 size_t quantize_iq3_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_iq1_s  (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_iq4_nl (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
+size_t quantize_iq3_s  (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_q2_K   (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_q3_K   (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_q4_K   (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
diff --git a/ggml.c b/ggml.c
index 938fedd689b84ec9a7911ab51a6913527a52e084..e3d8303cb6cb8d7ab88c6b8f222d85bcc7aa1d0b 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -682,6 +682,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .vec_dot_type             = GGML_TYPE_Q8_K,
         .nrows                    = 1,
     },
+    [GGML_TYPE_IQ3_S] = {
+        .type_name                = "iq3_s",
+        .blck_size                = QK_K,
+        .type_size                = sizeof(block_iq3_s),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_iq3_s,
+        .from_float               = quantize_row_iq3_s,
+        .from_float_reference     = (ggml_from_float_t)quantize_row_iq3_s_reference,
+        .vec_dot                  = ggml_vec_dot_iq3_s_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
+    },
     [GGML_TYPE_IQ1_S] = {
         .type_name                = "iq1_s",
         .blck_size                = QK_K,
@@ -2308,6 +2320,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
         case GGML_FTYPE_MOSTLY_IQ3_XXS:       wtype = GGML_TYPE_IQ3_XXS;  break;
         case GGML_FTYPE_MOSTLY_IQ1_S:         wtype = GGML_TYPE_IQ1_S;    break;
         case GGML_FTYPE_MOSTLY_IQ4_NL:        wtype = GGML_TYPE_IQ4_NL;   break;
+        case GGML_FTYPE_MOSTLY_IQ3_S:         wtype = GGML_TYPE_IQ3_S;    break;
         case GGML_FTYPE_UNKNOWN:              wtype = GGML_TYPE_COUNT; break;
         case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
     }
@@ -7742,6 +7755,7 @@ static void ggml_compute_forward_add(
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ3_S:
             {
                 ggml_compute_forward_add_q_f32(params, dst);
             } break;
@@ -8021,6 +8035,7 @@ static void ggml_compute_forward_add1(
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ3_S:
             {
                 ggml_compute_forward_add1_q_f32(params, dst);
             } break;
@@ -8145,6 +8160,7 @@ static void ggml_compute_forward_acc(
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ3_S:
         default:
             {
                 GGML_ASSERT(false);
@@ -11043,6 +11059,7 @@ static void ggml_compute_forward_out_prod(
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ3_S:
             {
                 ggml_compute_forward_out_prod_q_f32(params, dst);
             } break;
@@ -11231,6 +11248,7 @@ static void ggml_compute_forward_set(
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ3_S:
         default:
             {
                 GGML_ASSERT(false);
@@ -11433,6 +11451,7 @@ static void ggml_compute_forward_get_rows(
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ3_S:
             {
                 ggml_compute_forward_get_rows_q(params, dst);
             } break;
@@ -12133,6 +12152,7 @@ static void ggml_compute_forward_alibi(
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ3_S:
         case GGML_TYPE_Q8_K:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
@@ -12216,6 +12236,7 @@ static void ggml_compute_forward_clamp(
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ3_S:
         case GGML_TYPE_Q8_K:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
@@ -19467,6 +19488,7 @@ void ggml_quantize_init(enum ggml_type type) {
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ1_S:   iq2xs_init_impl(type); break;
         case GGML_TYPE_IQ3_XXS: iq3xs_init_impl(256); break;
+        case GGML_TYPE_IQ3_S:   iq3xs_init_impl(512); break;
         default: // nothing
             break;
     }
@@ -19741,6 +19763,15 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
                 result = quantize_iq3_xxs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
                 GGML_ASSERT(result == row_size * nrows);
             } break;
+        case GGML_TYPE_IQ3_S:
+            {
+                GGML_ASSERT(start % QK_K == 0);
+                GGML_ASSERT(start % n_per_row == 0);
+                size_t start_row = start / n_per_row;
+                size_t row_size = ggml_row_size(type, n_per_row);
+                result = quantize_iq3_s(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
+                GGML_ASSERT(result == row_size * nrows);
+            } break;
         case GGML_TYPE_IQ1_S:
             {
                 GGML_ASSERT(start % QK_K == 0);
diff --git a/ggml.h b/ggml.h
index d918e7826f3f4aa37bf89055ffd9f4e09de86291..d37186d8edf62d83bb8072629ba5fa3aa3b5fc74 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -350,6 +350,7 @@ extern "C" {
         GGML_TYPE_IQ3_XXS = 18,
         GGML_TYPE_IQ1_S   = 19,
         GGML_TYPE_IQ4_NL  = 20,
+        GGML_TYPE_IQ3_S   = 21,
         GGML_TYPE_I8,
         GGML_TYPE_I16,
         GGML_TYPE_I32,
@@ -389,6 +390,7 @@ extern "C" {
         GGML_FTYPE_MOSTLY_IQ3_XXS = 17, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ1_S   = 18, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ4_NL  = 19, // except 1d tensors
+        GGML_FTYPE_MOSTLY_IQ3_S   = 20, // except 1d tensors
     };
 
     // available tensor operations: