]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
SOTA 3-bit quants (llama/5196)
authorKawrakow <redacted>
Tue, 30 Jan 2024 13:14:12 +0000 (15:14 +0200)
committerGeorgi Gerganov <redacted>
Tue, 30 Jan 2024 19:27:59 +0000 (21:27 +0200)
* iq3_xxs: quantize/dequantize

RMSE seems a bit high-ish at about half-way between q2_K and
q3_K, so need to check more.

* iq3_xxs: CUDA dequantize works

* iq2_xxs: tuning quantization

* iq3_xxs: starting to look better

PPL on wiki.test.raw
LLaMA-v1-7B: 6.4218
LLaMA-v2-7B: 6.3560
Mistral-7B : 6.0717

This is better than Q3_K_XS, with a 5% reduction in quantized model
size.

* iq3_xxs: CUDA dot product

We have
PP-512: 5891 t/s
TG-128: 143.9 t/s

* iq3_xxs: scalar and AVX2 dot products

* iq3_xxs: ARM_NEON and Metal

Metal performance is decent, ARM_NEON is pathetic

* iq3_xxs: slightly better grid points

* Faster iq3_xxs and iq2_xs dot products on CUDA

* iq3_xxs: add some quant mix

* iq3_xxs: fix failing quantization test

Dot product still fails. Is this real?

* iq3_xxs: hopefully fix ROCm

* iq3_xxs: failing tests

This time the dot product accuracy did find an actual bug
in the AVX2 implementation.

* Add IQ3_XXS to test-backend-ops

---------

Co-authored-by: Iwan Kawrakow <redacted>
ggml-cuda.cu
ggml-metal.m
ggml-metal.metal
ggml-quants.c
ggml-quants.h
ggml.c
ggml.h

index 8ff9fbd564775e5aac224291944f996dcfaae28f..7508ead3e41a52d324850d6867c0351afd82c784 100644 (file)
@@ -191,6 +191,10 @@ static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
 #endif // __has_builtin(__builtin_elementwise_sub_sat)
 }
 
+static __device__ __forceinline__ int __vsub4(const int a, const int b) {
+    return __vsubss4(a, b);
+}
+
 static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
 #if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__)
     c = __builtin_amdgcn_sdot4(a, b, c, false);
@@ -505,6 +509,14 @@ typedef struct {
 } block_iq2_xs;
 static_assert(sizeof(block_iq2_xs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding");
 
+#define QR3_XXS 8
+#define QI3_XXS (QK_K / (4*QR3_XXS))
+typedef struct {
+    half d;
+    uint8_t qs[3*(QK_K/8)];
+} block_iq3_xxs;
+static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
+
 #define WARP_SIZE 32
 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
 
@@ -1613,6 +1625,41 @@ static const __device__ uint64_t iq2xs_grid[512] = {
     0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
 };
 
+static const __device__ uint32_t iq3xxs_grid[256] = {
+    0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
+    0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
+    0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404,
+    0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e,
+    0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c,
+    0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c,
+    0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34,
+    0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c,
+    0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c,
+    0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04,
+    0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c,
+    0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414,
+    0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434,
+    0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c,
+    0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e,
+    0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24,
+    0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24,
+    0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c,
+    0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c,
+    0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14,
+    0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414,
+    0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e,
+    0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404,
+    0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c,
+    0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c,
+    0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14,
+    0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c,
+    0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c,
+    0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14,
+    0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14,
+    0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,
+    0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
+};
+
 static const __device__ uint8_t ksigns_iq2xs[128] = {
       0, 129, 130,   3, 132,   5,   6, 135, 136,   9,  10, 139,  12, 141, 142,  15,
     144,  17,  18, 147,  20, 149, 150,  23,  24, 153, 154,  27, 156,  29,  30, 159,
@@ -1624,6 +1671,43 @@ static const __device__ uint8_t ksigns_iq2xs[128] = {
     240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
 };
 
+//#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+static const __device__ uint64_t ksigns64[128] = {
+    0x0000000000000000, 0xff000000000000ff, 0xff0000000000ff00, 0x000000000000ffff,
+    0xff00000000ff0000, 0x0000000000ff00ff, 0x0000000000ffff00, 0xff00000000ffffff,
+    0xff000000ff000000, 0x00000000ff0000ff, 0x00000000ff00ff00, 0xff000000ff00ffff,
+    0x00000000ffff0000, 0xff000000ffff00ff, 0xff000000ffffff00, 0x00000000ffffffff,
+    0xff0000ff00000000, 0x000000ff000000ff, 0x000000ff0000ff00, 0xff0000ff0000ffff,
+    0x000000ff00ff0000, 0xff0000ff00ff00ff, 0xff0000ff00ffff00, 0x000000ff00ffffff,
+    0x000000ffff000000, 0xff0000ffff0000ff, 0xff0000ffff00ff00, 0x000000ffff00ffff,
+    0xff0000ffffff0000, 0x000000ffffff00ff, 0x000000ffffffff00, 0xff0000ffffffffff,
+    0xff00ff0000000000, 0x0000ff00000000ff, 0x0000ff000000ff00, 0xff00ff000000ffff,
+    0x0000ff0000ff0000, 0xff00ff0000ff00ff, 0xff00ff0000ffff00, 0x0000ff0000ffffff,
+    0x0000ff00ff000000, 0xff00ff00ff0000ff, 0xff00ff00ff00ff00, 0x0000ff00ff00ffff,
+    0xff00ff00ffff0000, 0x0000ff00ffff00ff, 0x0000ff00ffffff00, 0xff00ff00ffffffff,
+    0x0000ffff00000000, 0xff00ffff000000ff, 0xff00ffff0000ff00, 0x0000ffff0000ffff,
+    0xff00ffff00ff0000, 0x0000ffff00ff00ff, 0x0000ffff00ffff00, 0xff00ffff00ffffff,
+    0xff00ffffff000000, 0x0000ffffff0000ff, 0x0000ffffff00ff00, 0xff00ffffff00ffff,
+    0x0000ffffffff0000, 0xff00ffffffff00ff, 0xff00ffffffffff00, 0x0000ffffffffffff,
+    0xffff000000000000, 0x00ff0000000000ff, 0x00ff00000000ff00, 0xffff00000000ffff,
+    0x00ff000000ff0000, 0xffff000000ff00ff, 0xffff000000ffff00, 0x00ff000000ffffff,
+    0x00ff0000ff000000, 0xffff0000ff0000ff, 0xffff0000ff00ff00, 0x00ff0000ff00ffff,
+    0xffff0000ffff0000, 0x00ff0000ffff00ff, 0x00ff0000ffffff00, 0xffff0000ffffffff,
+    0x00ff00ff00000000, 0xffff00ff000000ff, 0xffff00ff0000ff00, 0x00ff00ff0000ffff,
+    0xffff00ff00ff0000, 0x00ff00ff00ff00ff, 0x00ff00ff00ffff00, 0xffff00ff00ffffff,
+    0xffff00ffff000000, 0x00ff00ffff0000ff, 0x00ff00ffff00ff00, 0xffff00ffff00ffff,
+    0x00ff00ffffff0000, 0xffff00ffffff00ff, 0xffff00ffffffff00, 0x00ff00ffffffffff,
+    0x00ffff0000000000, 0xffffff00000000ff, 0xffffff000000ff00, 0x00ffff000000ffff,
+    0xffffff0000ff0000, 0x00ffff0000ff00ff, 0x00ffff0000ffff00, 0xffffff0000ffffff,
+    0xffffff00ff000000, 0x00ffff00ff0000ff, 0x00ffff00ff00ff00, 0xffffff00ff00ffff,
+    0x00ffff00ffff0000, 0xffffff00ffff00ff, 0xffffff00ffffff00, 0x00ffff00ffffffff,
+    0xffffffff00000000, 0x00ffffff000000ff, 0x00ffffff0000ff00, 0xffffffff0000ffff,
+    0x00ffffff00ff0000, 0xffffffff00ff00ff, 0xffffffff00ffff00, 0x00ffffff00ffffff,
+    0x00ffffffff000000, 0xffffffffff0000ff, 0xffffffffff00ff00, 0x00ffffffff00ffff,
+    0xffffffffffff0000, 0x00ffffffffff00ff, 0x00ffffffffffff00, 0xffffffffffffffff,
+};
+//#endif
+
 static const __device__ uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
 
 inline bool ggml_cuda_supports_mmq(enum ggml_type type) {
@@ -1690,6 +1774,34 @@ static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst
 
 }
 
+template<typename dst_t>
+static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int i   = blockIdx.x;
+    const block_iq3_xxs * x = (const block_iq3_xxs  *) vx;
+
+    const int tid = threadIdx.x;
+#if QK_K == 256
+    const int il = tid/8; // 0...3
+    const int ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const uint8_t  * q3 = x[i].qs + 8*ib;
+    const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib;
+    const uint8_t  * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*il+0]);
+    const uint8_t  * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*il+1]);
+    const uint32_t aux32 = gas[0] | (gas[1] << 16);
+    const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.5f;
+    const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
+    for (int j = 0; j < 4; ++j) {
+        y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
+        y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
+    }
+#else
+    assert(false);
+#endif
+
+}
+
 static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
 
     static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
@@ -4313,6 +4425,7 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
 
 static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
 #if QK_K == 256
     const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq;
 
@@ -4323,20 +4436,22 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
     const uint8_t ls2 = bq2->scales[ib32] >>  4;
     int sumi1 = 0;
     for (int l = 0; l < 2; ++l) {
-        const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
-        const uint8_t  signs = ksigns_iq2xs[q2[l] >> 9];
-        for (int j = 0; j < 8; ++j) {
-            sumi1 += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
-        }
+        const uint32_t * grid = (const uint32_t *)(iq2xs_grid + (q2[l] & 511));
+        const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9));
+        const int grid_l = __vsub4(grid[0] ^ signs[0], signs[0]);
+        const int grid_h = __vsub4(grid[1] ^ signs[1], signs[1]);
+        sumi1 = __dp4a(grid_l, *((const int *)q8 + 0), sumi1);
+        sumi1 = __dp4a(grid_h, *((const int *)q8 + 1), sumi1);
         q8 += 8;
     }
     int sumi2 = 0;
     for (int l = 2; l < 4; ++l) {
-        const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
-        const uint8_t  signs = ksigns_iq2xs[q2[l] >> 9];
-        for (int j = 0; j < 8; ++j) {
-            sumi2 += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
-        }
+        const uint32_t * grid = (const uint32_t *)(iq2xs_grid + (q2[l] & 511));
+        const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9));
+        const int grid_l = __vsub4(grid[0] ^ signs[0], signs[0]);
+        const int grid_h = __vsub4(grid[1] ^ signs[1], signs[1]);
+        sumi2 = __dp4a(grid_l, *((const int *)q8 + 0), sumi2);
+        sumi2 = __dp4a(grid_h, *((const int *)q8 + 1), sumi2);
         q8 += 8;
     }
     const float d = (float)bq2->d * __low2float(bq8_1[ib32].ds) * 0.25f;
@@ -4345,6 +4460,45 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
     assert(false);
     return 0.f;
 #endif
+#else
+    assert(false);
+    return 0.f;
+#endif
+}
+
+static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+#if QK_K == 256
+    const block_iq3_xxs * bq2 = (const block_iq3_xxs *) vbq;
+
+    const int ib32 = iqs;
+    const uint8_t  * q3 = bq2->qs + 8*ib32;
+    const uint16_t * gas = (const uint16_t *)(bq2->qs + QK_K/4) + 2*ib32;
+    const int8_t   * q8 = bq8_1[ib32].qs;
+    uint32_t aux32 = gas[0] | (gas[1] << 16);
+    int sumi = 0;
+    for (int l = 0; l < 4; ++l) {
+        const uint32_t * grid1 = iq3xxs_grid + q3[2*l+0];
+        const uint32_t * grid2 = iq3xxs_grid + q3[2*l+1];
+        const uint32_t * signs = (const uint32_t *)(ksigns64 + (aux32 & 127));
+        const int grid_l = __vsub4(grid1[0] ^ signs[0], signs[0]);
+        const int grid_h = __vsub4(grid2[0] ^ signs[1], signs[1]);
+        sumi = __dp4a(grid_l, *((int *)q8+0), sumi);
+        sumi = __dp4a(grid_h, *((int *)q8+1), sumi);
+        q8 += 8;
+        aux32 >>= 7;
+    }
+    const float d = (float)bq2->d * (0.5f + aux32) * __low2float(bq8_1[ib32].ds) * 0.5f;
+    return d * sumi;
+#else
+    assert(false);
+    return 0.f;
+#endif
+#else
+    assert(false);
+    return 0.f;
+#endif
 }
 
 template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
@@ -6394,6 +6548,12 @@ static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int k,
     dequantize_block_iq2_xs<<<nb, 32, 0, stream>>>(vx, y);
 }
 
+template<typename dst_t>
+static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
+}
+
 template <typename src_t, typename dst_t>
 static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
@@ -6431,6 +6591,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
             return dequantize_row_iq2_xxs_cuda;
         case GGML_TYPE_IQ2_XS:
             return dequantize_row_iq2_xs_cuda;
+        case GGML_TYPE_IQ3_XXS:
+            return dequantize_row_iq3_xxs_cuda;
         case GGML_TYPE_F32:
             return convert_unary_cuda<float>;
         default:
@@ -6464,6 +6626,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_row_iq2_xxs_cuda;
         case GGML_TYPE_IQ2_XS:
             return dequantize_row_iq2_xs_cuda;
+        case GGML_TYPE_IQ3_XXS:
+            return dequantize_row_iq3_xxs_cuda;
         case GGML_TYPE_F16:
             return convert_unary_cuda<half>;
         default:
@@ -6676,6 +6840,15 @@ static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, float
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
+static void mul_mat_vec_iq3_xxs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
 static void ggml_mul_mat_q4_0_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
@@ -8239,6 +8412,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ3_XXS:
             return max_compute_capability >= CC_RDNA2 ? 128 : 64;
         default:
             GGML_ASSERT(false);
@@ -8261,6 +8435,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
         case GGML_TYPE_Q5_K:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ3_XXS:
             return max_compute_capability >= CC_VOLTA ? 128 : 64;
         case GGML_TYPE_Q6_K:
             return 64;
@@ -8332,6 +8507,9 @@ static void ggml_cuda_op_mul_mat_vec_q(
         case GGML_TYPE_IQ2_XS:
             mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
             break;
+        case GGML_TYPE_IQ3_XXS:
+            mul_mat_vec_iq3_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
         default:
             GGML_ASSERT(false);
             break;
@@ -10968,7 +11146,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                     return false;
                 }
                 ggml_type a_type = a->type;
-                if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS) {
+                if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS) {
                     if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
                         return false;
                     }
index c095be3e34ff77ce4938043215f240c3689e5128..f87859552816b0f90668dd696aebf666169a937f 100644 (file)
@@ -60,6 +60,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
     GGML_METAL_KERNEL_TYPE_RMS_NORM,
     GGML_METAL_KERNEL_TYPE_GROUP_NORM,
@@ -81,6 +82,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
   //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
@@ -98,6 +100,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
@@ -112,6 +115,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
@@ -126,6 +130,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F16,
     GGML_METAL_KERNEL_TYPE_ALIBI_F32,
@@ -426,6 +431,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,             get_rows_q6_K,          true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,          get_rows_iq2_xxs,       true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,           get_rows_iq2_xs,        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,          get_rows_iq3_xxs,       true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,              get_rows_i32,           true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM,                  rms_norm,               ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM,                group_norm,             ctx->support_simdgroup_reduction);
@@ -447,6 +453,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,           mul_mv_q6_K_f32,        ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,        mul_mv_iq2_xxs_f32,     ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,         mul_mv_iq2_xs_f32,      ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,        mul_mv_iq3_xxs_f32,     ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,         mul_mv_id_f32_f32,      ctx->support_simdgroup_reduction);
       //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,         mul_mv_id_f16_f16,      ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,         mul_mv_id_f16_f32,      ctx->support_simdgroup_reduction);
@@ -464,6 +471,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,        mul_mv_id_q6_K_f32,     ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,     mul_mv_id_iq2_xxs_f32,  ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,      mul_mv_id_iq2_xs_f32,   ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,     mul_mv_id_iq3_xxs_f32,  ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,            mul_mm_f32_f32,         ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,            mul_mm_f16_f32,         ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,           mul_mm_q4_0_f32,        ctx->support_simdgroup_mm);
@@ -478,6 +486,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,           mul_mm_q6_K_f32,        ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,        mul_mm_iq2_xxs_f32,     ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,         mul_mm_iq2_xs_f32,      ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,        mul_mm_iq3_xxs_f32,     ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,         mul_mm_id_f32_f32,      ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,         mul_mm_id_f16_f32,      ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,        mul_mm_id_q4_0_f32,     ctx->support_simdgroup_mm);
@@ -492,6 +501,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,        mul_mm_id_q6_K_f32,     ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,     mul_mm_id_iq2_xxs_f32,  ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,      mul_mm_id_iq2_xs_f32,   ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,     mul_mm_id_iq3_xxs_f32,  ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32,                  rope_f32,               true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16,                  rope_f16,               true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32,                 alibi_f32,              true);
@@ -1279,6 +1289,7 @@ static bool ggml_metal_graph_compute(
                                 case GGML_TYPE_Q6_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32   ].pipeline; break;
                                 case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
                                 case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
+                                case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
                                 default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
                             }
 
@@ -1407,6 +1418,12 @@ static bool ggml_metal_graph_compute(
                                         nth1 = 16;
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
                                     } break;
+                                case GGML_TYPE_IQ3_XXS:
+                                    {
+                                        nth0 = 4;
+                                        nth1 = 16;
+                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
+                                    } break;
                                 default:
                                     {
                                         GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
@@ -1449,6 +1466,11 @@ static bool ggml_metal_graph_compute(
                                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
                                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
+                            else if (src0t == GGML_TYPE_IQ3_XXS) {
+                                const int mem_size = 256*4+128;
+                                [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                            }
                             else if (src0t == GGML_TYPE_Q4_K) {
                                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
@@ -1543,6 +1565,7 @@ static bool ggml_metal_graph_compute(
                                 case GGML_TYPE_Q6_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32   ].pipeline; break;
                                 case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
                                 case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
+                                case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
                                 default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
                             }
 
@@ -1674,6 +1697,12 @@ static bool ggml_metal_graph_compute(
                                         nth1 = 16;
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
                                     } break;
+                                case GGML_TYPE_IQ3_XXS:
+                                    {
+                                        nth0 = 4;
+                                        nth1 = 16;
+                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
+                                    } break;
                                 default:
                                     {
                                         GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
@@ -1732,6 +1761,11 @@ static bool ggml_metal_graph_compute(
                                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
                                 [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
+                            else if (src2t == GGML_TYPE_IQ3_XXS) {
+                                const int mem_size = 256*4+128;
+                                [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                            }
                             else if (src2t == GGML_TYPE_Q4_K) {
                                 [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
@@ -1772,6 +1806,7 @@ static bool ggml_metal_graph_compute(
                             case GGML_TYPE_Q6_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K   ].pipeline; break;
                             case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
                             case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
+                            case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
                             case GGML_TYPE_I32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32    ].pipeline; break;
                             default: GGML_ASSERT(false && "not implemented");
                         }
index 029578dc54dbd8d503fb3c92dc766f7b1726ef8a..4274106358612cb0194d689a83a4399ced03e654 100644 (file)
@@ -2459,6 +2459,12 @@ typedef struct {
 } block_iq2_xs;
 // 74 bytes / block for QK_K = 256, so 2.3125 bpw
 
+typedef struct {
+    half d;
+    uint8_t qs[3*QK_K/8];
+} block_iq3_xxs;
+// 98 bytes / block for QK_K = 256, so 3.0625 bpw
+
 //====================================== dot products =========================
 
 void kernel_mul_mv_q2_K_f32_impl(
@@ -3681,6 +3687,42 @@ constexpr constant static uint64_t iq2xs_grid[512] = {
     0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
 };
 
+constexpr constant static uint32_t iq3xxs_grid[256] = {
+    0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3c, 0x04041404, 0x04041414,
+    0x04041c0c, 0x04042414, 0x04043c1c, 0x04043c2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
+    0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3c04, 0x04140404,
+    0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3c,
+    0x04142c0c, 0x04142c3c, 0x04143c2c, 0x041c040c, 0x041c043c, 0x041c0c04, 0x041c0c14, 0x041c142c,
+    0x041c3c04, 0x04240c1c, 0x04241c3c, 0x04242424, 0x04242c3c, 0x04243c1c, 0x04243c2c, 0x042c040c,
+    0x042c043c, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043c0c04, 0x043c0c24, 0x043c0c34,
+    0x043c241c, 0x043c340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c,
+    0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243c, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c,
+    0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04,
+    0x0c143c14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c,
+    0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3c0c, 0x0c34042c, 0x0c3c1414,
+    0x0c3c2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434,
+    0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c,
+    0x140c1c04, 0x140c341c, 0x140c343c, 0x140c3c04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3c,
+    0x14141404, 0x14141414, 0x14141c3c, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24,
+    0x141c3c04, 0x141c3c24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143c, 0x142c240c, 0x142c3c24,
+    0x143c040c, 0x143c041c, 0x143c0c34, 0x143c242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c,
+    0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043c14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c,
+    0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143c14,
+    0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243c, 0x1c243c14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414,
+    0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3c1c1c, 0x1c3c3404, 0x24040424, 0x24040c3c,
+    0x24041c2c, 0x24041c3c, 0x24042c1c, 0x24042c3c, 0x240c3c24, 0x24141404, 0x24141c3c, 0x24142404,
+    0x24143404, 0x24143434, 0x241c043c, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c,
+    0x242c241c, 0x242c3c04, 0x243c042c, 0x243c0c04, 0x243c0c14, 0x243c1c04, 0x2c040c14, 0x2c04240c,
+    0x2c043c04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143c14,
+    0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143c, 0x2c243c14, 0x2c2c0414, 0x2c2c1c0c,
+    0x2c342c04, 0x2c3c1424, 0x2c3c2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c,
+    0x340c340c, 0x34140c3c, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14,
+    0x34341c1c, 0x343c041c, 0x343c140c, 0x3c04041c, 0x3c04042c, 0x3c04043c, 0x3c040c04, 0x3c041c14,
+    0x3c042c14, 0x3c0c1434, 0x3c0c2404, 0x3c140c14, 0x3c14242c, 0x3c142c14, 0x3c1c0404, 0x3c1c0c2c,
+    0x3c1c1c1c, 0x3c1c3404, 0x3c24140c, 0x3c24240c, 0x3c2c0404, 0x3c2c0414, 0x3c2c1424, 0x3c341c04,
+};
+
+
 constexpr constant static uint8_t ksigns_iq2xs[128] = {
       0, 129, 130,   3, 132,   5,   6, 135, 136,   9,  10, 139,  12, 141, 142,  15,
     144,  17,  18, 147,  20, 149, 150,  23,  24, 153, 154,  27, 156,  29,  30, 159,
@@ -3970,6 +4012,143 @@ kernel void kernel_mul_mv_iq2_xs_f32(
     kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
 }
 
+void kernel_mul_mv_iq3_xxs_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        threadgroup int8_t * shared_values [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    const int nb = ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int ib_row = first_row * nb;
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+    device const block_iq3_xxs * x = (device const block_iq3_xxs *) src0 + ib_row + offset0;
+    device const float         * y = (device const float         *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float yl[32];
+    float sumf[N_DST]={0.f}, all_sum;
+
+    const int nb32 = nb * (QK_K / 32);
+
+    threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values;
+    threadgroup uint8_t  * shared_signs = (threadgroup uint8_t *)(values + 256);
+    {
+        int nval = 4;
+        int pos  = (32*sgitg + tiisg)*nval;
+        for (int i = 0; i < nval; ++i) values[pos + i] = iq3xxs_grid[pos + i];
+        nval = 2;
+        pos  = (32*sgitg + tiisg)*nval;
+        for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+#if QK_K == 256
+    const int ix = tiisg;
+
+    device const float * y4 = y + 32 * ix;
+
+    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+        for (int i = 0; i < 32; ++i) {
+            yl[i] = y4[i];
+        }
+
+        const int ibl = ib32 / (QK_K / 32);
+        const int ib  = ib32 % (QK_K / 32);
+
+        device const block_iq3_xxs * xr = x + ibl;
+        device const uint8_t  * q3 = xr->qs + 8 * ib;
+        device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib;
+        device const half * dh = &xr->d;
+
+        for (int row = 0; row < N_DST; row++) {
+
+            const float db = dh[0];
+            const uint32_t aux32 = gas[0] | (gas[1] << 16);
+            const float d = db * (0.5f + (aux32 >> 28));
+
+            float2 sum = {0};
+            for (int l = 0; l < 4; ++l) {
+                const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + q3[2*l+0]);
+                const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + q3[2*l+1]);
+                const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
+                for (int j = 0; j < 4; ++j) {
+                    sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
+                    sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
+                }
+            }
+            sumf[row] += d * (sum[0] + sum[1]);
+
+            dh  += nb*sizeof(block_iq3_xxs)/2;
+            q3  += nb*sizeof(block_iq3_xxs);
+            gas += nb*sizeof(block_iq3_xxs)/2;
+        }
+
+        y4 += 32 * 32;
+    }
+#else
+    // TODO
+#endif
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f;
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_iq3_xxs_f32")]]
+kernel void kernel_mul_mv_iq3_xxs_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        threadgroup int8_t * shared_values [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+}
+
+
 //============================= templates and their specializations =============================
 
 // NOTE: this is not dequantizing - we are simply fitting the template
@@ -4287,6 +4466,33 @@ void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4
     }
 }
 
+template <typename type4x4>
+void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const float d = xb->d;
+    const int ib32 = il/2;
+    il = il%2;
+    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+    device const uint8_t * q3 = xb->qs + 8*ib32;
+    device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32;
+    const uint32_t aux32 = gas[0] | (gas[1] << 16);
+    const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f;
+    constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]);
+    constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]);
+    uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127];
+    for (int i = 0; i < 4; ++i) {
+        reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
+        reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
+    }
+    grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]);
+    grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]);
+    signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127];
+    for (int i = 0; i < 4; ++i) {
+        reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
+        reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
+    }
+}
+
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
 kernel void kernel_get_rows(
         device const  void * src0,
@@ -4828,6 +5034,7 @@ template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows
 template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
 template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
 template [[host_name("kernel_get_rows_iq2_xs")]]  kernel get_rows_t kernel_get_rows<block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
+template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
 
 //
 // matrix-matrix multiplication
@@ -4866,6 +5073,7 @@ template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
 template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
 template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
 template [[host_name("kernel_mul_mm_iq2_xs_f32")]]  kernel mat_mm_t kernel_mul_mm<block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
+template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
 
 //
 // indirect matrix-matrix multiplication
@@ -4916,6 +5124,7 @@ template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mu
 template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
 template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
 template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
+template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
 
 //
 // matrix-vector multiplication
@@ -5818,3 +6027,68 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
         tiisg,
         sgitg);
 }
+
+[[host_name("kernel_mul_mv_id_iq3_xxs_f32")]]
+kernel void kernel_mul_mv_id_iq3_xxs_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         float * dst,
+        constant    uint64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant    uint64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        threadgroup int8_t   * shared_values [[threadgroup(0)]],
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_iq3_xxs_f32_impl(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        dst + bid*ne0,
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        shared_values,
+        tgpig,
+        tiisg,
+        sgitg);
+}
index 7d2f033e9a0fe8b288e1183fc53c9c75b40120e5..ac061b63a7dea9101d9cd82438a82de1ce2fa81b 100644 (file)
@@ -3441,6 +3441,41 @@ static const uint64_t iq2xs_grid[512] = {
     0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
 };
 
+static const uint32_t iq3xxs_grid[256] = {
+    0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
+    0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
+    0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404,
+    0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e,
+    0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c,
+    0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c,
+    0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34,
+    0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c,
+    0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c,
+    0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04,
+    0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c,
+    0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414,
+    0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434,
+    0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c,
+    0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e,
+    0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24,
+    0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24,
+    0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c,
+    0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c,
+    0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14,
+    0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414,
+    0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e,
+    0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404,
+    0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c,
+    0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c,
+    0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14,
+    0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c,
+    0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c,
+    0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14,
+    0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14,
+    0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,
+    0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
+};
+
 static const uint8_t ksigns_iq2xs[128] = {
       0, 129, 130,   3, 132,   5,   6, 135, 136,   9,  10, 139,  12, 141, 142,  15,
     144,  17,  18, 147,  20, 149, 150,  23,  24, 153, 154,  27, 156,  29,  30, 159,
@@ -3507,6 +3542,38 @@ void dequantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y,
     }
 }
 
+// ====================== 3.0625 bpw (de)-quantization
+
+void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    uint32_t aux32;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+        const uint8_t * qs = x[i].qs;
+        const uint8_t * scales_and_signs = qs + QK_K/4;
+
+        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+            memcpy(&aux32, scales_and_signs + 4*ib32, sizeof(uint32_t));
+            const float db = d * (0.5f + (aux32 >> 28)) * 0.5f;
+            for (int l = 0; l < 4; ++l) {
+                const uint8_t  signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];
+                const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + qs[2*l+0]);
+                const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + qs[2*l+1]);
+                for (int j = 0; j < 4; ++j) {
+                    y[j+0] = db * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
+                    y[j+4] = db * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
+                }
+                y += 8;
+            }
+            qs += 8;
+        }
+    }
+}
+
 //===================================== Q8_K ==============================================
 
 void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
@@ -8551,6 +8618,136 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
 #endif
 }
 
+// TODO
+void ggml_vec_dot_iq3_xxs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    assert(n % QK_K == 0);
+
+    const block_iq3_xxs * restrict x = vx;
+    const block_q8_K    * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#if defined(__ARM_NEON)
+
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+    uint32_t aux32[2];
+
+    ggml_int8x16x4_t q3s;
+    ggml_int8x16x4_t q8b;
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint8_t * restrict q3 = x[i].qs;
+        const uint8_t * restrict gas = x[i].qs + QK_K/4;
+        const int8_t   * restrict q8 = y[i].qs;
+        float sumf1 = 0, sumf2 = 0;
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
+            memcpy(aux32, gas, 2*sizeof(uint32_t)); gas += 2*sizeof(uint32_t);
+            const uint32x4_t aux32x4_0 = {iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]};
+            const uint32x4_t aux32x4_1 = {iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]};
+            const uint32x4_t aux32x4_2 = {iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]};
+            const uint32x4_t aux32x4_3 = {iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]};
+            q3 += 16;
+            q3s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >>  0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >>  7) & 127))));
+            q3s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 21) & 127))));
+            q3s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >>  0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >>  7) & 127))));
+            q3s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127))));
+            q3s.val[0] = vmulq_s8(q3s.val[0], vreinterpretq_s8_u32(aux32x4_0));
+            q3s.val[1] = vmulq_s8(q3s.val[1], vreinterpretq_s8_u32(aux32x4_1));
+            q3s.val[2] = vmulq_s8(q3s.val[2], vreinterpretq_s8_u32(aux32x4_2));
+            q3s.val[3] = vmulq_s8(q3s.val[3], vreinterpretq_s8_u32(aux32x4_3));
+            const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]);
+            const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]);
+            sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[0] >> 28));
+            sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[1] >> 28));
+        }
+        sumf += d*(sumf1 + sumf2);
+    }
+    *s = 0.5f * sumf;
+
+#elif defined(__AVX2__)
+
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+    uint32_t aux32[2];
+
+    __m256 accumf = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint8_t * restrict q3 = x[i].qs;
+        const uint8_t * restrict gas = x[i].qs + QK_K/4;
+        const int8_t  * restrict q8 = y[i].qs;
+        __m256i sumi1 = _mm256_setzero_si256();
+        __m256i sumi2 = _mm256_setzero_si256();
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q2_1 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
+                                                  iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
+            q3 += 8;
+            const __m256i q2_2 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
+                                                  iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
+            q3 += 8;
+            memcpy(aux32, gas, 8); gas += 8;
+            const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127],
+                                                   signs64[(aux32[0] >>  7) & 127], signs64[(aux32[0] >>  0) & 127]);
+            const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
+                                                   signs64[(aux32[1] >>  7) & 127], signs64[(aux32[1] >>  0) & 127]);
+            const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
+            const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
+            const __m256i dot1  = _mm256_maddubs_epi16(q2_1, q8s_1);
+            const __m256i dot2  = _mm256_maddubs_epi16(q2_2, q8s_2);
+            const uint16_t ls1 = aux32[0] >> 28;
+            const uint16_t ls2 = aux32[1] >> 28;
+            const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
+            const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
+            sumi1 = _mm256_add_epi32(sumi1, p1);
+            sumi2 = _mm256_add_epi32(sumi2, p2);
+        }
+
+        accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
+
+    }
+
+    *s = 0.25f * hsum_float_8(accumf);
+
+#else
+
+    uint32_t aux32;
+
+    float sumf = 0.f;
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint8_t * restrict q3 = x[i].qs;
+        const uint8_t * restrict gas = x[i].qs + QK_K/4;
+        const int8_t  * restrict q8 = y[i].qs;
+        int32_t bsum = 0;
+        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+            memcpy(&aux32, gas, sizeof(uint32_t)); gas += sizeof(uint32_t);
+            const uint32_t ls = 2*(aux32 >> 28) + 1;
+            int32_t sumi = 0;
+            for (int l = 0; l < 4; ++l) {
+                const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*l+0]);
+                const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*l+1]);
+                const uint8_t  signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];
+                for (int j = 0; j < 4; ++j) {
+                    sumi += grid1[j] * q8[j+0] * (signs & kmask_iq2xs[j+0] ? -1 : 1);
+                    sumi += grid2[j] * q8[j+4] * (signs & kmask_iq2xs[j+4] ? -1 : 1);
+                }
+                q8 += 8;
+            }
+            q3 += 8;
+            bsum += sumi * ls;
+        }
+        sumf += d * bsum;
+    }
+    *s = 0.25f * sumf;
+#endif
+}
+
 // ================================ IQ2 quantization =============================================
 
 typedef struct {
@@ -9189,3 +9386,436 @@ size_t quantize_iq2_xs(const float * src, void * dst, int nrow, int n_per_row, i
     return nrow * nblock * sizeof(block_iq2_xs);
 }
 
+//
+// ============================================= 3-bit using D4 lattice
+//
+
+typedef struct {
+    uint32_t * grid;
+    int      * map;
+    uint16_t * neighbours;
+} iq3_entry_t;
+
+static iq3_entry_t iq3_data[1] = {
+    {NULL, NULL, NULL},
+};
+
+static inline int iq3_data_index(int grid_size) {
+    (void)grid_size;
+    GGML_ASSERT(grid_size == 256);
+    return 0;
+}
+
+static int iq3_compare_func(const void * left, const void * right) {
+    const int * l = (const int *)left;
+    const int * r = (const int *)right;
+    return l[0] < r[0] ? -1 : l[0] > r[0] ? 1 : l[1] < r[1] ? -1 : l[1] > r[1] ? 1 : 0;
+}
+
+void iq3xs_init_impl(int grid_size) {
+    const int gindex = iq3_data_index(grid_size);
+    if (iq3_data[gindex].grid) {
+        return;
+    }
+    static const uint16_t kgrid_256[256] = {
+            0,     2,     4,     9,    11,    15,    16,    18,    25,    34,    59,    61,    65,    67,    72,    74,
+           81,    85,    88,    90,    97,   108,   120,   128,   130,   132,   137,   144,   146,   153,   155,   159,
+          169,   175,   189,   193,   199,   200,   202,   213,   248,   267,   287,   292,   303,   315,   317,   321,
+          327,   346,   362,   413,   436,   456,   460,   462,   483,   497,   513,   515,   520,   522,   529,   531,
+          536,   538,   540,   551,   552,   576,   578,   585,   592,   594,   641,   643,   648,   650,   657,   664,
+          698,   704,   706,   720,   729,   742,   758,   769,   773,   808,   848,   852,   870,   889,   901,   978,
+          992,  1024,  1026,  1033,  1035,  1040,  1042,  1046,  1049,  1058,  1089,  1091,  1093,  1096,  1098,  1105,
+         1112,  1139,  1143,  1144,  1152,  1154,  1161,  1167,  1168,  1170,  1183,  1184,  1197,  1217,  1224,  1228,
+         1272,  1276,  1309,  1323,  1347,  1367,  1377,  1404,  1473,  1475,  1486,  1509,  1537,  1544,  1546,  1553,
+         1555,  1576,  1589,  1594,  1600,  1602,  1616,  1625,  1636,  1638,  1665,  1667,  1672,  1685,  1706,  1722,
+         1737,  1755,  1816,  1831,  1850,  1856,  1862,  1874,  1901,  1932,  1950,  1971,  2011,  2032,  2052,  2063,
+         2077,  2079,  2091,  2095,  2172,  2192,  2207,  2208,  2224,  2230,  2247,  2277,  2308,  2345,  2356,  2389,
+         2403,  2424,  2501,  2504,  2506,  2520,  2570,  2593,  2616,  2624,  2630,  2646,  2669,  2700,  2714,  2746,
+         2754,  2795,  2824,  2835,  2839,  2874,  2882,  2905,  2984,  3028,  3042,  3092,  3108,  3110,  3124,  3153,
+         3185,  3215,  3252,  3288,  3294,  3364,  3397,  3434,  3483,  3523,  3537,  3587,  3589,  3591,  3592,  3610,
+         3626,  3670,  3680,  3722,  3749,  3754,  3776,  3789,  3803,  3824,  3857,  3873,  3904,  3906,  3924,  3992,
+    };
+    const int kmap_size = 4096;
+    const int nwant = 2;
+    const uint16_t * kgrid = kgrid_256;
+    uint32_t * kgrid_q3xs;
+    int      * kmap_q3xs;
+    uint16_t * kneighbors_q3xs;
+
+    printf("================================================================= %s(grid_size = %d)\n", __func__, grid_size);
+    uint32_t * the_grid = (uint32_t *)malloc(grid_size*sizeof(uint32_t));
+    for (int k = 0; k < grid_size; ++k) {
+        int8_t * pos = (int8_t *)(the_grid + k);
+        for (int i = 0; i < 4; ++i) {
+            int l = (kgrid[k] >> 3*i) & 0x7;
+            pos[i] = 2*l + 1;
+        }
+    }
+    kgrid_q3xs = the_grid;
+    iq3_data[gindex].grid = the_grid;
+    kmap_q3xs = (int *)malloc(kmap_size*sizeof(int));
+    iq3_data[gindex].map = kmap_q3xs;
+    for (int i = 0; i < kmap_size; ++i) kmap_q3xs[i] = -1;
+    uint32_t aux32;
+    uint8_t * aux8 = (uint8_t *)&aux32;
+    for (int i = 0; i < grid_size; ++i) {
+        aux32 = kgrid_q3xs[i];
+        uint16_t index = 0;
+        for (int k=0; k<4; ++k) {
+            uint16_t q = (aux8[k] - 1)/2;
+            index |= (q << 3*k);
+        }
+        kmap_q3xs[index] = i;
+    }
+    int8_t pos[4];
+    int * dist2 = (int *)malloc(2*grid_size*sizeof(int));
+    int num_neighbors = 0, num_not_in_map = 0;
+    for (int i = 0; i < kmap_size; ++i) {
+        if (kmap_q3xs[i] >= 0) continue;
+        ++num_not_in_map;
+        for (int k = 0; k < 4; ++k) {
+            int l = (i >> 3*k) & 0x7;
+            pos[k] = 2*l + 1;
+        }
+        for (int j = 0; j < grid_size; ++j) {
+            const int8_t * pg = (const int8_t *)(kgrid_q3xs + j);
+            int d2 = 0;
+            for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
+            dist2[2*j+0] = d2;
+            dist2[2*j+1] = j;
+        }
+        qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func);
+        int n = 0; int d2 = dist2[0];
+        int nhave = 1;
+        for (int j = 0; j < grid_size; ++j) {
+            if (dist2[2*j] > d2) {
+                if (nhave == nwant) break;
+                d2 = dist2[2*j];
+                ++nhave;
+            }
+            ++n;
+        }
+        num_neighbors += n;
+    }
+    printf("%s: %d neighbours in total\n", __func__, num_neighbors);
+    kneighbors_q3xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t));
+    iq3_data[gindex].neighbours = kneighbors_q3xs;
+    int counter = 0;
+    for (int i = 0; i < kmap_size; ++i) {
+        if (kmap_q3xs[i] >= 0) continue;
+        for (int k = 0; k < 4; ++k) {
+            int l = (i >> 3*k) & 0x7;
+            pos[k] = 2*l + 1;
+        }
+        for (int j = 0; j < grid_size; ++j) {
+            const int8_t * pg = (const int8_t *)(kgrid_q3xs + j);
+            int d2 = 0;
+            for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
+            dist2[2*j+0] = d2;
+            dist2[2*j+1] = j;
+        }
+        qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func);
+        kmap_q3xs[i] = -(counter + 1);
+        int d2 = dist2[0];
+        uint16_t * start = &kneighbors_q3xs[counter++];
+        int n = 0, nhave = 1;
+        for (int j = 0; j < grid_size; ++j) {
+            if (dist2[2*j] > d2) {
+                if (nhave == nwant) break;
+                d2 = dist2[2*j];
+                ++nhave;
+            }
+            kneighbors_q3xs[counter++] = dist2[2*j+1];
+            ++n;
+        }
+        *start = n;
+    }
+    free(dist2);
+}
+
+void iq3xs_free_impl(int grid_size) {
+    GGML_ASSERT(grid_size == 256);
+    const int gindex = iq3_data_index(grid_size);
+    if (iq3_data[gindex].grid) {
+        free(iq3_data[gindex].grid);       iq3_data[gindex].grid = NULL;
+        free(iq3_data[gindex].map);        iq3_data[gindex].map  = NULL;
+        free(iq3_data[gindex].neighbours); iq3_data[gindex].neighbours = NULL;
+    }
+}
+
+static int iq3_find_best_neighbour(const uint16_t * restrict neighbours, const uint32_t * restrict grid,
+        const float * restrict xval, const float * restrict weight, float scale, int8_t * restrict L) {
+    int num_neighbors = neighbours[0];
+    GGML_ASSERT(num_neighbors > 0);
+    float best_d2 = FLT_MAX;
+    int grid_index = -1;
+    for (int j = 1; j <= num_neighbors; ++j) {
+        const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
+        float d2 = 0;
+        for (int i = 0; i < 4; ++i) {
+            float q = pg[i];
+            float diff = scale*q - xval[i];
+            d2 += weight[i]*diff*diff;
+        }
+        if (d2 < best_d2) {
+            best_d2 = d2; grid_index = neighbours[j];
+        }
+    }
+    GGML_ASSERT(grid_index >= 0);
+    const int8_t * pg = (const int8_t *)(grid + grid_index);
+    for (int i = 0; i < 4; ++i) L[i] = (pg[i] - 1)/2;
+    return grid_index;
+}
+
+static void quantize_row_iq3_xxs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
+
+    const int gindex = iq3_data_index(256);
+
+    const uint32_t * kgrid_q3xs      = iq3_data[gindex].grid;
+    const int      * kmap_q3xs       = iq3_data[gindex].map;
+    const uint16_t * kneighbors_q3xs = iq3_data[gindex].neighbours;
+
+    //GGML_ASSERT(quant_weights   && "missing quantization weights");
+    GGML_ASSERT(kgrid_q3xs      && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(kmap_q3xs       && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(kneighbors_q3xs && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(n%QK_K == 0);
+
+    const int kMaxQ = 8;
+
+    const int nbl = n/256;
+
+    block_iq3_xxs * y = vy;
+
+    float scales[QK_K/32];
+    float weight[32];
+    float xval[32];
+    int8_t L[32];
+    int8_t Laux[32];
+    float  waux[32];
+    bool   is_on_grid[8];
+    bool   is_on_grid_aux[8];
+    uint8_t block_signs[8];
+    uint8_t q3[3*(QK_K/8)];
+    uint32_t * scales_and_signs = (uint32_t *)(q3 + QK_K/4);
+
+    for (int ibl = 0; ibl < nbl; ++ibl) {
+
+        y[ibl].d = GGML_FP32_TO_FP16(0.f);
+        memset(q3, 0, 3*QK_K/8);
+
+        float max_scale = 0;
+
+        const float * xbl = x + QK_K*ibl;
+        float sumx2 = 0;
+        for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
+        float sigma2 = sumx2/QK_K;
+
+        for (int ib = 0; ib < QK_K/32; ++ib) {
+            const float * xb = xbl + 32*ib;
+            if (quant_weights) {
+                const float * qw = quant_weights + QK_K*ibl + 32*ib;
+                for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+            } else {
+                for (int i = 0; i < 32; ++i) weight[i] = xb[i]*xb[i];
+            }
+            for (int i = 0; i < 32; ++i) waux[i] = sqrtf(weight[i]);
+            for (int k = 0; k < 4; ++k) {
+                int nflip = 0;
+                uint8_t s = 0;
+                for (int i = 0; i < 8; ++i) {
+                    if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];
+                    else {
+                        xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i);
+                    }
+                }
+                if (nflip%2) {
+                    int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin];
+                    for (int i = 1; i < 8; ++i) {
+                        float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i];
+                        if (ax < min) {
+                            min = ax; imin = i;
+                        }
+                    }
+                    xval[8*k+imin] = -xval[8*k+imin];
+                    s ^= (1 << imin);
+                }
+                block_signs[k] = s & 127;
+            }
+            float max = xval[0];
+            for (int i = 1; i < 32; ++i) max = MAX(max, xval[i]);
+            if (!max) {
+                scales[ib] = 0;
+                memset(L, 0, 32);
+                continue;
+            }
+            float best = 0;
+            float scale = max/(2*kMaxQ-1);
+            for (int is = -15; is <= 15; ++is) {
+                float id = (2*kMaxQ-1+is*0.2f)/max;
+                float this_scale = 1/id;
+                for (int k = 0; k < 8; ++k) {
+                    for (int i = 0; i < 4; ++i) {
+                        int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
+                        Laux[4*k+i] = MAX(0, MIN(kMaxQ-1, l));
+                    }
+                    uint16_t u = 0;
+                    for (int i = 0; i < 4; ++i) u |= (Laux[4*k+i] << 3*i);
+                    int grid_index = kmap_q3xs[u];
+                    is_on_grid_aux[k] = true;
+                    if (grid_index < 0) {
+                        is_on_grid_aux[k] = false;
+                        const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;
+                        grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, this_scale, Laux + 4*k);
+                    }
+                }
+                float sumqx = 0, sumq2 = 0;
+                for (int i = 0; i < 32; ++i) {
+                    float w = weight[i];
+                    float q = 2*Laux[i] + 1;
+                    sumqx += w*xval[i]*q;
+                    sumq2 += w*q*q;
+                }
+                if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
+                    scale = sumqx/sumq2; best = scale*sumqx;
+                    for (int i = 0; i < 32; ++i) L[i] = Laux[i];
+                    for (int k = 0; k <  8; ++k) is_on_grid[k] = is_on_grid_aux[k];
+                }
+            }
+            int n_not_ongrid = 0;
+            for (int k = 0; k < 8; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
+            if (n_not_ongrid > 0 && scale > 0) {
+                float id = 1/scale;
+                for (int k = 0; k < 8; ++k) {
+                    if (is_on_grid[k]) continue;
+                    uint16_t u = 0;
+                    for (int i = 0; i < 4; ++i) {
+                        int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
+                        l = MAX(0, MIN(kMaxQ-1, l));
+                        u |= (l << 3*i);
+                    }
+                    int grid_index = kmap_q3xs[u];
+                    if (grid_index < 0) {
+                        const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;
+                        grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, scale, L + 4*k);
+                    }
+                    const int8_t * pg = (const int8_t *)(kgrid_q3xs + grid_index);
+                    for (int i = 0; i < 4; ++i) L[4*k+i] = (pg[i] - 1)/2;
+                }
+                float sumqx = 0, sumq2 = 0;
+                for (int i = 0; i < 32; ++i) {
+                    float w = weight[i];
+                    float q = 2*L[i] + 1;
+                    sumqx += w*xval[i]*q;
+                    sumq2 += w*q*q;
+                }
+                if (sumq2 > 0) scale = sumqx/sumq2;
+            }
+            if (scale < 0) {
+                // This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale)
+                // and correspondingly flip quant signs.
+                scale = -scale;
+                for (int k = 0; k < 4; ++k) block_signs[k] = (~block_signs[k]) & 127;
+            }
+            for (int k = 0; k < 8; ++k) {
+                uint16_t u = 0;
+                for (int i = 0; i < 4; ++i) u |= (L[4*k+i] << 3*i);
+                int grid_index = kmap_q3xs[u];
+                if (grid_index < 0) {
+                    printf("Oops: found point %u not on grid:", u);
+                    for (int i = 0; i < 4; ++i) printf(" %d", L[4*k+i]);
+                    printf("\n");
+                    GGML_ASSERT(false);
+                }
+                q3[8*ib+k] = grid_index;
+            }
+            scales_and_signs[ib] = block_signs[0] | (block_signs[1] << 7) | (block_signs[2] << 14) | (block_signs[3] << 21);
+            GGML_ASSERT(scale >= 0);
+            scales[ib] = scale;
+            max_scale = MAX(max_scale, scale);
+        }
+
+        if (!max_scale) {
+            memset(y[ibl].qs, 0, 3*QK_K/8);
+            continue;
+        }
+
+        float d = max_scale/31;
+        y[ibl].d = GGML_FP32_TO_FP16(d);
+        float id = 1/d;
+        float sumqx = 0, sumq2 = 0;
+        for (int ib = 0; ib < QK_K/32; ++ib) {
+            int l = nearest_int(0.5f*(id*scales[ib]-1));
+            l = MAX(0, MIN(15, l));
+            scales_and_signs[ib] |= ((uint32_t)l << 28);
+            if (false) {
+                const float * xb = xbl + 32*ib;
+                if (quant_weights) {
+                    const float * qw = quant_weights + QK_K*ibl + 32*ib;
+                    for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+                } else {
+                    for (int i = 0; i < 32; ++i) weight[i] = xb[i]*xb[i];
+                }
+                const float db = 0.25f * d * (1 + 2*l);
+                for (int k = 0; k < 8; ++k) {
+                    const int8_t * signs = keven_signs_q2xs + 8*((scales_and_signs[ib] >> 7*(k/2)) & 127) + 4*(k%2);
+                    const float * xk = xb + 4*k;
+                    const float * wk = weight + 4*k;
+                    //const uint8_t * grid = (const uint8_t *)(kgrid_q3xs + q3[8*ib+k]);
+                    const uint8_t * grid = (const uint8_t *)(iq3xxs_grid + q3[8*ib+k]);
+                    float best_mse = 0; int best_index = q3[8*ib+k];
+                    for (int j = 0; j < 4; ++j) {
+                        float diff = db * grid[j] * signs[j] - xk[j];
+                        best_mse += wk[j] * diff * diff;
+                    }
+                    for (int idx = 0; idx < 256; ++idx) {
+                        //grid = (const uint8_t *)(kgrid_q3xs + idx);
+                        grid = (const uint8_t *)(iq3xxs_grid + idx);
+                        float mse = 0;
+                        for (int j = 0; j < 4; ++j) {
+                            float diff = db * grid[j] * signs[j] - xk[j];
+                            mse += wk[j] * diff * diff;
+                        }
+                        if (mse < best_mse) {
+                            best_mse = mse; best_index = idx;
+                        }
+                    }
+                    q3[8*ib+k] = best_index;
+                    //grid = (const uint8_t *)(kgrid_q3xs + best_index);
+                    grid = (const uint8_t *)(iq3xxs_grid + best_index);
+                    for (int j = 0; j < 4; ++j) {
+                        float q = db * grid[j] * signs[j];
+                        sumqx += wk[j] * q * xk[j];
+                        sumq2 += wk[j] * q * q;
+                    }
+                }
+                if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(d*sumqx/sumq2);
+            }
+        }
+        memcpy(y[ibl].qs, q3, 3*QK_K/8);
+    }
+}
+
+size_t quantize_iq3_xxs(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
+    (void)hist;
+    GGML_ASSERT(n_per_row%QK_K == 0);
+    int nblock = n_per_row/QK_K;
+    char * qrow = (char *)dst;
+    for (int row = 0; row < nrow; ++row) {
+        quantize_row_iq3_xxs_impl(src, qrow, n_per_row, quant_weights);
+        src += n_per_row;
+        qrow += nblock*sizeof(block_iq3_xxs);
+    }
+    return nrow * nblock * sizeof(block_iq3_xxs);
+}
+
+void quantize_row_iq3_xxs(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK_K == 0);
+    block_iq3_xxs * restrict y = vy;
+    quantize_row_iq3_xxs_reference(x, y, k);
+}
+
+void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * restrict y, int k) {
+    assert(k % QK_K == 0);
+    quantize_row_iq3_xxs_impl(x, y, k, NULL);
+}
index 7d7cf9178f76e7714fc2080578342fbdd5e5ee65..5c9f63bd9b102bf72ec6584604eefe9001b225e8 100644 (file)
@@ -166,7 +166,7 @@ typedef struct {
 static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding");
 
 // (Almost) "true" 2-bit quantization.
-// Due to the need to use blocks as per ggml dsign, it ends up using
+// Due to the need to use blocks as per ggml design, it ends up using
 // 2.0625 bpw because of the 16-bit scale for each block of 256.
 typedef struct {
     ggml_fp16_t d;
@@ -182,6 +182,15 @@ typedef struct {
 } block_iq2_xs;
 static_assert(sizeof(block_iq2_xs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding");
 
+// (Almost) "true" 3-bit quantization.
+// Due to the need to use blocks as per ggml design, it ends up using
+// 3.0625 bpw because of the 16-bit scale for each block of 256.
+typedef struct {
+    ggml_fp16_t d;
+    uint8_t qs[3*QK_K/8];
+} block_iq3_xxs;
+static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
+
 // Quantization
 void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k);
 void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k);
@@ -196,6 +205,7 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
 void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k);
 void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k);
 void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k);
+void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * restrict y, int k);
 
 void quantize_row_q4_0(const float * restrict x, void * restrict y, int k);
 void quantize_row_q4_1(const float * restrict x, void * restrict y, int k);
@@ -210,6 +220,7 @@ void quantize_row_q4_K(const float * restrict x, void * restrict y, int k);
 void quantize_row_q5_K(const float * restrict x, void * restrict y, int k);
 void quantize_row_q6_K(const float * restrict x, void * restrict y, int k);
 void quantize_row_q8_K(const float * restrict x, void * restrict y, int k);
+void quantize_row_iq3_xxs(const float * restrict x, void * restrict y, int k);
 
 // Dequantization
 void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k);
@@ -227,6 +238,7 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int
 void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k);
 void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k);
 void dequantize_row_iq2_xs (const block_iq2_xs  * restrict x, float * restrict y, int k);
+void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y, int k);
 
 // Dot product
 void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
@@ -242,12 +254,14 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, const void * restrict vx,
 void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 void ggml_vec_dot_iq2_xs_q8_K (int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 
 //
 // Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
 //
 size_t quantize_iq2_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_iq2_xs (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
+size_t quantize_iq3_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_q2_K   (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_q3_K   (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_q4_K   (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
@@ -260,3 +274,5 @@ size_t quantize_q5_1   (const float * src, void * dst, int nrows, int n_per_row,
 
 void iq2xs_init_impl(int grid_size);
 void iq2xs_free_impl(int grid_size);
+void iq3xs_init_impl(int grid_size);
+void iq3xs_free_impl(int grid_size);
diff --git a/ggml.c b/ggml.c
index e6dce1c457a6a2374ef8c511a7efb507d145f62b..a7a9ea319c5f09dd711d783838200d66fb2f10d7 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -632,6 +632,17 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .vec_dot                  = ggml_vec_dot_iq2_xs_q8_K,
         .vec_dot_type             = GGML_TYPE_Q8_K,
     },
+    [GGML_TYPE_IQ3_XXS] = {
+        .type_name                = "iq3_xxs",
+        .blck_size                = QK_K,
+        .type_size                = sizeof(block_iq3_xxs),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_iq3_xxs,
+        .from_float               = quantize_row_iq3_xxs,
+        .from_float_reference     = (ggml_from_float_t)quantize_row_iq3_xxs_reference,
+        .vec_dot                  = ggml_vec_dot_iq3_xxs_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+    },
     [GGML_TYPE_Q8_K] = {
         .type_name                = "q8_K",
         .blck_size                = QK_K,
@@ -2177,6 +2188,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
         case GGML_FTYPE_MOSTLY_Q6_K:          wtype = GGML_TYPE_Q6_K;  break;
         case GGML_FTYPE_MOSTLY_IQ2_XXS:       wtype = GGML_TYPE_IQ2_XXS;  break;
         case GGML_FTYPE_MOSTLY_IQ2_XS:        wtype = GGML_TYPE_IQ2_XS;   break;
+        case GGML_FTYPE_MOSTLY_IQ3_XXS:       wtype = GGML_TYPE_IQ3_XXS;  break;
         case GGML_FTYPE_UNKNOWN:              wtype = GGML_TYPE_COUNT; break;
         case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
     }
@@ -7570,6 +7582,7 @@ static void ggml_compute_forward_add(
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ3_XXS:
             {
                 ggml_compute_forward_add_q_f32(params, src0, src1, dst);
             } break;
@@ -7836,6 +7849,7 @@ static void ggml_compute_forward_add1(
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ3_XXS:
             {
                 ggml_compute_forward_add1_q_f32(params, src0, src1, dst);
             } break;
@@ -7955,6 +7969,7 @@ static void ggml_compute_forward_acc(
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ3_XXS:
         default:
             {
                 GGML_ASSERT(false);
@@ -10706,6 +10721,7 @@ static void ggml_compute_forward_out_prod(
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ3_XXS:
             {
                 ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
             } break;
@@ -10885,6 +10901,7 @@ static void ggml_compute_forward_set(
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ3_XXS:
         default:
             {
                 GGML_ASSERT(false);
@@ -11081,6 +11098,7 @@ static void ggml_compute_forward_get_rows(
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ3_XXS:
             {
                 ggml_compute_forward_get_rows_q(params, src0, src1, dst);
             } break;
@@ -11728,6 +11746,7 @@ static void ggml_compute_forward_alibi(
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_Q8_K:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
@@ -11804,6 +11823,7 @@ static void ggml_compute_forward_clamp(
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_Q8_K:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
@@ -18860,6 +18880,7 @@ void ggml_quantize_init(enum ggml_type type) {
     switch (type) {
         case GGML_TYPE_IQ2_XXS: iq2xs_init_impl(256); break;
         case GGML_TYPE_IQ2_XS:  iq2xs_init_impl(512); break;
+        case GGML_TYPE_IQ3_XXS: iq3xs_init_impl(256); break;
         default: // nothing
             break;
     }
@@ -19122,6 +19143,15 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
                 result = quantize_iq2_xs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
                 GGML_ASSERT(result == row_size * nrows);
             } break;
+        case GGML_TYPE_IQ3_XXS:
+            {
+                GGML_ASSERT(start % QK_K == 0);
+                GGML_ASSERT(start % n_per_row == 0);
+                size_t start_row = start / n_per_row;
+                size_t row_size = ggml_row_size(type, n_per_row);
+                result = quantize_iq3_xxs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
+                GGML_ASSERT(result == row_size * nrows);
+            } break;
         case GGML_TYPE_F16:
             {
                 size_t elemsize = sizeof(ggml_fp16_t);
diff --git a/ggml.h b/ggml.h
index d697fd2bb7c47efeef4317c9479e08004fb53691..bf782e6ad12793931e36f828bf5e01345774d6f2 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -353,6 +353,7 @@ extern "C" {
         GGML_TYPE_Q8_K = 15,
         GGML_TYPE_IQ2_XXS = 16,
         GGML_TYPE_IQ2_XS  = 17,
+        GGML_TYPE_IQ3_XXS = 18,
         GGML_TYPE_I8,
         GGML_TYPE_I16,
         GGML_TYPE_I32,
@@ -389,6 +390,7 @@ extern "C" {
         GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ2_XXS = 15, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ2_XS  = 16, // except 1d tensors
+        GGML_FTYPE_MOSTLY_IQ3_XXS = 17, // except 1d tensors
     };
 
     // available tensor operations: