]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
1.5 bit quantization (llama/5453)
authorKawrakow <redacted>
Sun, 18 Feb 2024 16:16:55 +0000 (18:16 +0200)
committerGeorgi Gerganov <redacted>
Mon, 19 Feb 2024 13:53:23 +0000 (15:53 +0200)
* iq1_s: WIP basics

* iq1_s: CUDA is working

* iq1_s: scalar CPU dot product

* iq1_s: WIP AVX2 dot product - something is not right

* Fix tests

* Fix shadow warnings

* Fix after merge with latest master

* iq1_s: AVX2 finally works

* iq1_s: ARM_NEON dot product. Works, but not very fast

* iq1_s: better grid

* iq1_s: use IQ2_XXS for attn_output

At a cost of 0.04 extra bpw this gives a big improvement in PPL.

* iq1_s: Metal basics

Dequantize works, but not dot product

* iq1_s: Metal works, but quite slow

As usual, Apple Silicon does not like the code I write.

* iq1_s: Tests

* iq1_s: slightly faster dot product

---------

Co-authored-by: Iwan Kawrakow <redacted>
ggml-backend.c
ggml-cuda.cu
ggml-metal.m
ggml-metal.metal
ggml-quants.c
ggml-quants.h
ggml.c
ggml.h

index d019d813ad5f070f5745515fbf9d8c9172172cd4..0a1fa9e720cf97cec1fae39b41dcb44d826aacad 100644 (file)
@@ -756,7 +756,7 @@ GGML_CALL static bool ggml_backend_cpu_graph_compute(ggml_backend_t backend, str
 GGML_CALL static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
     switch (op->op) {
         case GGML_OP_CPY:
-            return op->type != GGML_TYPE_IQ2_XXS && op->type != GGML_TYPE_IQ2_XS; // missing type_traits.from_float
+            return op->type != GGML_TYPE_IQ2_XXS && op->type != GGML_TYPE_IQ2_XS && op->type != GGML_TYPE_IQ1_S; // missing type_traits.from_float
         case GGML_OP_MUL_MAT:
             return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type;
         default:
index 5fd8a87e4150f0f06bd23884e94a81b1ebc632e2..933ebbc4eff72f1151e1c908a717fcee38f67166 100644 (file)
@@ -517,6 +517,15 @@ typedef struct {
 } block_iq3_xxs;
 static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
 
+#define QR1_S 8
+#define QI1_S (QK_K / (4*QR1_S))
+typedef struct {
+    half d;
+    uint8_t qs[QK_K/8];
+    uint8_t scales[QK_K/16];
+} block_iq1_s;
+static_assert(sizeof(block_iq1_s) == sizeof(ggml_fp16_t) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding");
+
 #define WARP_SIZE 32
 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
 
@@ -1681,6 +1690,137 @@ static const __device__ uint32_t iq3xxs_grid[256] = {
     0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
 };
 
+static const __device__ uint64_t iq1s_grid[512] = {
+    0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
+    0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01,
+    0xffffff00ff000000, 0xffffff000000ff00, 0xffffff00000000ff, 0xffffff0000000100,
+    0xffffff0000010000, 0xffffff0001000000, 0xffffff01ffff00ff, 0xffffff01ff01ff00,
+    0xffffff01ff010100, 0xffffff0100000001, 0xffffff0101ffff00, 0xffffff0101ff0101,
+    0xffffff0101010100, 0xffff00ffff00ff01, 0xffff00ffff0000ff, 0xffff00ff00ff0100,
+    0xffff00ff0100ff00, 0xffff00ff010001ff, 0xffff0000ff0101ff, 0xffff000000ffff00,
+    0xffff000000000000, 0xffff00000001ff01, 0xffff000001000101, 0xffff0000010100ff,
+    0xffff0001ffff0100, 0xffff00010000ff00, 0xffff000100010101, 0xffff000101000000,
+    0xffff01ffffff0000, 0xffff01ffff01ffff, 0xffff01ffff010100, 0xffff01ff00000000,
+    0xffff01ff01ffffff, 0xffff01ff01ff0001, 0xffff01ff0101ffff, 0xffff01ff01010001,
+    0xffff0100ffffff01, 0xffff01000000ffff, 0xffff010000000100, 0xffff010001ff01ff,
+    0xffff010001000000, 0xffff0101ff000000, 0xffff0101000101ff, 0xffff010101ffff01,
+    0xffff01010101ff00, 0xff00ffffff000000, 0xff00ffff00ffff00, 0xff00ffff00000001,
+    0xff00ffff000001ff, 0xff00ffff01010000, 0xff00ff00ffff0000, 0xff00ff00ff00ff00,
+    0xff00ff00ff0000ff, 0xff00ff00ff000100, 0xff00ff00ff010001, 0xff00ff0000ff0001,
+    0xff00ff000000ffff, 0xff00ff0000000000, 0xff00ff000001ff00, 0xff00ff0000010100,
+    0xff00ff0001ff0000, 0xff00ff000100ff00, 0xff00ff0001000100, 0xff00ff01ff000000,
+    0xff00ff0100ff0000, 0xff00ff01000001ff, 0xff00ff0101010001, 0xff0000ff00000000,
+    0xff0000ff0001ff00, 0xff0000ff00010100, 0xff000000ffff0101, 0xff000000ff000000,
+    0xff000000ff01ff00, 0xff00000000ff0000, 0xff0000000000ff00, 0xff000000000000ff,
+    0xff00000000000000, 0xff00000000000001, 0xff00000000000100, 0xff0000000001ffff,
+    0xff00000000010000, 0xff00000001000000, 0xff00000001010100, 0xff000001ff00ff01,
+    0xff000001ff0100ff, 0xff00000100000000, 0xff0000010001ff00, 0xff00000101ff0100,
+    0xff0000010100ff00, 0xff0001ff00ff00ff, 0xff0001ff00000101, 0xff0001ff000100ff,
+    0xff0001ff01000000, 0xff000100ff0001ff, 0xff0001000000ff01, 0xff00010000000000,
+    0xff00010000010001, 0xff00010000010100, 0xff00010001ffff00, 0xff00010001ff0101,
+    0xff00010001010000, 0xff000101ffffffff, 0xff000101ff000101, 0xff00010101ff00ff,
+    0xff00010101000001, 0xff000101010100ff, 0xff01ffffff000101, 0xff01ffffff01ffff,
+    0xff01ffffff01ff01, 0xff01ffffff0101ff, 0xff01ffff00000000, 0xff01ffff01ff0001,
+    0xff01ffff0101ff01, 0xff01ff00ff000000, 0xff01ff0000ff0100, 0xff01ff000000ff01,
+    0xff01ff0000010000, 0xff01ff00010000ff, 0xff01ff01ff01ff00, 0xff01ff0100000101,
+    0xff0100ffffff0000, 0xff0100ffff010000, 0xff0100ff01ff00ff, 0xff0100ff01000100,
+    0xff0100ff010100ff, 0xff010000ffffff01, 0xff01000000000000, 0xff0100000101ff00,
+    0xff010001ffff00ff, 0xff010001ff000100, 0xff01000100ffff00, 0xff01000100010001,
+    0xff01000101ff0001, 0xff010001010001ff, 0xff0101ffffffffff, 0xff0101ffff01ffff,
+    0xff0101ffff010101, 0xff0101ff0000ff00, 0xff0101ff01010001, 0xff010100ff000000,
+    0xff010100ff01ff01, 0xff01010000ff0001, 0xff01010000000100, 0xff01010001000000,
+    0xff0101010100ffff, 0x00ffffff0000ff01, 0x00ffffff000000ff, 0x00ffffff00000100,
+    0x00ffffff00010000, 0x00ffff00ffff0001, 0x00ffff00ff0000ff, 0x00ffff00ff000100,
+    0x00ffff0000000000, 0x00ffff0001000100, 0x00ffff0001010001, 0x00ffff01ff00ff01,
+    0x00ffff0100ff0100, 0x00ffff010000ff00, 0x00ffff01000100ff, 0x00ffff0101ff00ff,
+    0x00ffff010101ff00, 0x00ff00ffffffffff, 0x00ff00ffffff01ff, 0x00ff00ffff000101,
+    0x00ff00ff00000000, 0x00ff00ff000101ff, 0x00ff00ff01010101, 0x00ff0000ff000000,
+    0x00ff0000ff01ffff, 0x00ff000000ff0000, 0x00ff00000000ff00, 0x00ff0000000000ff,
+    0x00ff000000000000, 0x00ff000000000001, 0x00ff000000000100, 0x00ff000000010000,
+    0x00ff000001ffff01, 0x00ff000001000000, 0x00ff0001ff000101, 0x00ff000100ffffff,
+    0x00ff000100000000, 0x00ff0001010001ff, 0x00ff01ffff000000, 0x00ff01ff0001ff00,
+    0x00ff01ff01ff0100, 0x00ff0100ff01ff01, 0x00ff010000ff00ff, 0x00ff010000ff0101,
+    0x00ff010000000000, 0x00ff010000010101, 0x00ff01000100ff00, 0x00ff010001010000,
+    0x00ff0101ffffff00, 0x00ff01010000ff01, 0x00ff010100000100, 0x00ff010101ff0000,
+    0x0000ffffffff0100, 0x0000ffffff00ff00, 0x0000ffffff0000ff, 0x0000ffffff010000,
+    0x0000ffff00000000, 0x0000ffff00010101, 0x0000ffff01ffff01, 0x0000ffff01000100,
+    0x0000ff00ff000000, 0x0000ff00ff01ff00, 0x0000ff00ff0101ff, 0x0000ff0000ff0000,
+    0x0000ff000000ff00, 0x0000ff00000000ff, 0x0000ff0000000000, 0x0000ff0000000001,
+    0x0000ff0000000100, 0x0000ff0000010000, 0x0000ff0001ffffff, 0x0000ff0001ff01ff,
+    0x0000ff0001000000, 0x0000ff000101ffff, 0x0000ff01ffff0101, 0x0000ff01ff010000,
+    0x0000ff0100000000, 0x0000ff0101000101, 0x000000ffffff0001, 0x000000ffff000000,
+    0x000000ff00ff0000, 0x000000ff0000ff00, 0x000000ff000000ff, 0x000000ff00000000,
+    0x000000ff00000001, 0x000000ff00000100, 0x000000ff00010000, 0x000000ff01000000,
+    0x000000ff0101ff00, 0x00000000ffff0000, 0x00000000ff00ff00, 0x00000000ff0000ff,
+    0x00000000ff000000, 0x00000000ff000001, 0x00000000ff000100, 0x00000000ff010000,
+    0x0000000000ffff00, 0x0000000000ff00ff, 0x0000000000ff0000, 0x0000000000ff0001,
+    0x0000000000ff0100, 0x000000000000ffff, 0x000000000000ff00, 0x000000000000ff01,
+    0x00000000000000ff, 0x0000000000000001, 0x00000000000001ff, 0x0000000000000100,
+    0x0000000000000101, 0x000000000001ff00, 0x00000000000100ff, 0x0000000000010000,
+    0x0000000000010001, 0x0000000000010100, 0x0000000001ff0000, 0x000000000100ff00,
+    0x00000000010000ff, 0x0000000001000000, 0x0000000001000001, 0x0000000001000100,
+    0x0000000001010000, 0x00000001ffff01ff, 0x00000001ff000000, 0x0000000100ff0000,
+    0x000000010000ff00, 0x00000001000000ff, 0x0000000100000000, 0x0000000100000001,
+    0x0000000100000100, 0x0000000100010000, 0x0000000101000000, 0x000001ffff00ff00,
+    0x000001ffff010001, 0x000001ffff0101ff, 0x000001ff00ffff01, 0x000001ff0000ffff,
+    0x000001ff00000000, 0x000001ff010000ff, 0x000001ff01010100, 0x00000100ffff0100,
+    0x00000100ff000000, 0x0000010000ff0000, 0x000001000000ff00, 0x00000100000000ff,
+    0x0000010000000000, 0x0000010000000001, 0x0000010000000100, 0x0000010000010000,
+    0x0000010001000000, 0x000001000101ff01, 0x00000101ffff0001, 0x00000101ff01ffff,
+    0x0000010100000000, 0x0000010101010100, 0x0001ffffff000000, 0x0001ffff00ffffff,
+    0x0001ffff00000100, 0x0001ffff0001ff00, 0x0001ffff01000000, 0x0001ff00ffffff00,
+    0x0001ff00ffff01ff, 0x0001ff00ff010000, 0x0001ff0000000000, 0x0001ff0000010001,
+    0x0001ff0001ff0000, 0x0001ff0001010100, 0x0001ff01ff0000ff, 0x0001ff01ff000001,
+    0x0001ff0100ffffff, 0x0001ff010001ffff, 0x0001ff01000101ff, 0x0001ff010100ff01,
+    0x000100ffff00ffff, 0x000100ffff00ff01, 0x000100ffff000100, 0x000100ff00000000,
+    0x000100ff000101ff, 0x000100ff01ff0101, 0x000100ff0100ffff, 0x000100ff01010101,
+    0x00010000ff000000, 0x00010000ff010100, 0x0001000000ff0000, 0x000100000000ff00,
+    0x00010000000000ff, 0x0001000000000000, 0x0001000000000001, 0x0001000000000100,
+    0x0001000000010000, 0x0001000001ffff01, 0x0001000001000000, 0x0001000100ff0101,
+    0x0001000100000000, 0x00010001010100ff, 0x000101ffffff01ff, 0x000101ffffff0101,
+    0x000101ff00010000, 0x000101ff01ff0000, 0x000101ff0100ff01, 0x00010100ffff0000,
+    0x0001010000000000, 0x000101000001ffff, 0x0001010000010101, 0x00010100010001ff,
+    0x00010101ff00ff00, 0x00010101ff010001, 0x0001010100ffffff, 0x0001010100ff01ff,
+    0x00010101000101ff, 0x0001010101ff0000, 0x000101010100ff01, 0x0001010101000101,
+    0x01ffffffffff0101, 0x01ffffffff01ffff, 0x01ffffffff01ff01, 0x01ffffffff0101ff,
+    0x01ffffffff010101, 0x01ffffff00000000, 0x01ffffff01ff01ff, 0x01ffffff01000101,
+    0x01ffffff0101ff01, 0x01ffffff010100ff, 0x01ffff000000ff00, 0x01ffff0000000001,
+    0x01ffff00000001ff, 0x01ffff0000010000, 0x01ffff0001ff0000, 0x01ffff01ffffffff,
+    0x01ffff01ffff01ff, 0x01ffff01ff000000, 0x01ffff01ff01ffff, 0x01ffff01ff0101ff,
+    0x01ffff010100ffff, 0x01ff00ffffff0000, 0x01ff00ffff010000, 0x01ff00ff00ffff01,
+    0x01ff0000ff0000ff, 0x01ff000000000000, 0x01ff00000001ff01, 0x01ff000001ffffff,
+    0x01ff000001010100, 0x01ff0001ffffff01, 0x01ff0001ff010001, 0x01ff000101ff0100,
+    0x01ff000101000001, 0x01ff0001010100ff, 0x01ff01ffff00ffff, 0x01ff01ff00010001,
+    0x01ff01ff01000000, 0x01ff01ff010101ff, 0x01ff0100ff000001, 0x01ff010000ffff00,
+    0x01ff010000000100, 0x01ff010001ff01ff, 0x01ff01000101ffff, 0x01ff0101ffff00ff,
+    0x01ff0101ffff0101, 0x01ff0101ff0101ff, 0x01ff010100010000, 0x0100ffff00ff00ff,
+    0x0100ffff00ff0001, 0x0100ffff00000100, 0x0100ffff0100ff00, 0x0100ff00ffff0000,
+    0x0100ff00ff00ffff, 0x0100ff00ff00ff01, 0x0100ff00ff000100, 0x0100ff00ff010000,
+    0x0100ff0000000000, 0x0100ff00000100ff, 0x0100ff0001ff0101, 0x0100ff0001010101,
+    0x0100ff0100ff00ff, 0x0100ff0100ff0001, 0x0100ff0100000100, 0x0100ff0100010001,
+    0x0100ff0101000000, 0x010000ffff00ff00, 0x010000ff0000ffff, 0x010000ff00000000,
+    0x010000ff010001ff, 0x010000ff01010001, 0x01000000ffffff00, 0x01000000ffff0101,
+    0x01000000ff000000, 0x01000000ff0100ff, 0x01000000ff010101, 0x0100000000ff0000,
+    0x010000000000ff00, 0x01000000000000ff, 0x0100000000000000, 0x0100000000000001,
+    0x0100000000000100, 0x0100000000010000, 0x0100000001000000, 0x0100000100000000,
+    0x01000001000101ff, 0x0100000101ffff01, 0x010001ffff000101, 0x010001ff00ff0100,
+    0x010001ff0000ff00, 0x010001ff000100ff, 0x010001ff01ffffff, 0x01000100ffff0000,
+    0x01000100ff0001ff, 0x0100010000000000, 0x010001000001ff00, 0x0100010001ff0000,
+    0x01000100010000ff, 0x0100010001000101, 0x01000101ff00ff01, 0x0100010100ff0100,
+    0x010001010000ffff, 0x0100010101010001, 0x0101ffffffff0101, 0x0101ffffff0001ff,
+    0x0101ffffff01ffff, 0x0101ffffff010101, 0x0101ffff00000000, 0x0101ffff0101ffff,
+    0x0101ffff010101ff, 0x0101ff00ff000000, 0x0101ff0000ff0100, 0x0101ff000000ff00,
+    0x0101ff0000010000, 0x0101ff00010000ff, 0x0101ff0001000001, 0x0101ff01ff010101,
+    0x0101ff0100000000, 0x0101ff010101ff00, 0x010100ffffff0000, 0x010100ffff010000,
+    0x010100ff00ff01ff, 0x010100ff000000ff, 0x010100ff00000101, 0x010100ff01ffff00,
+    0x01010000ffffff01, 0x01010000ff000100, 0x01010000ff01ff01, 0x0101000000000000,
+    0x01010000000100ff, 0x010100000101ff01, 0x01010001ffff0000, 0x01010001ff00ffff,
+    0x01010001ff010000, 0x0101000101ffffff, 0x0101000101ff01ff, 0x0101000101010101,
+    0x010101ffff01ffff, 0x010101ff00000000, 0x010101ff0001ff01, 0x010101ff0101ffff,
+    0x010101ff010101ff, 0x01010100ffffffff, 0x01010100ff000001, 0x010101000000ff00,
+    0x0101010001010000, 0x0101010100ff0001, 0x010101010001ff01, 0x010101010101ffff,
+};
+
 static const __device__ uint8_t ksigns_iq2xs[128] = {
       0, 129, 130,   3, 132,   5,   6, 135, 136,   9,  10, 139,  12, 141, 142,  15,
     144,  17,  18, 147,  20, 149, 150,  23,  24, 153, 154,  27, 156,  29,  30, 159,
@@ -1823,6 +1963,29 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds
 
 }
 
+template<typename dst_t>
+static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int i   = blockIdx.x;
+    const block_iq1_s * x = (const block_iq1_s  *) vx;
+
+    const int tid = threadIdx.x;
+#if QK_K == 256
+    const int il = tid/8; // 0...3
+    const int ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const int i8 = 4*ib+il;
+    uint8_t h = x[i].scales[i8/2] >> 4*(i8%2);
+    const int8_t * grid = (const int8_t *)(iq1s_grid + (x[i].qs[i8] | ((h & 8) << 5)));
+    const float d = (float)x[i].d * (2*(h & 7) + 1);
+    for (int j = 0; j < 8; ++j) y[j] = d * grid[j];
+#else
+    assert(false);
+#endif
+
+}
+
+
 static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
 
     static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
@@ -4522,6 +4685,49 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
 #endif
 }
 
+static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+#if QK_K == 256
+    const block_iq1_s * bq1 = (const block_iq1_s *) vbq;
+
+    const int ib32 = iqs;
+    int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;
+    const uint8_t h1 = bq1->scales[2*ib32+0];
+    const uint8_t h2 = bq1->scales[2*ib32+1];
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    const int * q8 = (const int *)bq8_1[ib32].qs;
+    const int * grid1 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+0] | ((h1 & 0x08) << 5)));
+    const int * grid2 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+1] | ((h1 & 0x80) << 1)));
+    const int * grid3 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+2] | ((h2 & 0x08) << 5)));
+    const int * grid4 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+3] | ((h2 & 0x80) << 1)));
+    for (int j = 0; j < 2; ++j) {
+        sumi1 = __dp4a(q8[j+0], grid1[j], sumi1);
+        sumi2 = __dp4a(q8[j+2], grid2[j], sumi2);
+        sumi3 = __dp4a(q8[j+4], grid3[j], sumi3);
+        sumi4 = __dp4a(q8[j+6], grid4[j], sumi4);
+    }
+#else
+    const int8_t   * q8 = bq8_1[ib32].qs;
+    const int8_t * grid1 = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+0] | ((h1 & 0x08) << 5)));
+    const int8_t * grid2 = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+1] | ((h1 & 0x80) << 1)));
+    const int8_t * grid3 = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+2] | ((h2 & 0x08) << 5)));
+    const int8_t * grid4 = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+3] | ((h2 & 0x80) << 1)));
+    for (int j = 0; j < 8; ++j) {
+        sumi1 += q8[j+ 0] * grid1[j];
+        sumi2 += q8[j+ 8] * grid2[j];
+        sumi3 += q8[j+16] * grid3[j];
+        sumi4 += q8[j+24] * grid4[j];
+    }
+#endif
+    const float d = (float)bq1->d * __low2float(bq8_1[ib32].ds);
+    return d * (sumi1 * (2*(h1 & 7) + 1) + sumi2 * (2*((h1 >> 4) & 7) + 1) +
+                sumi3 * (2*(h2 & 7) + 1) + sumi4 * (2*((h2 >> 4) & 7) + 1));
+#else
+    assert(false);
+    return 0.f;
+#endif
+}
+
 template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
               allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
 static __device__ __forceinline__ void mul_mat_q(
@@ -6561,6 +6767,12 @@ static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int k,
     dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
 }
 
+template<typename dst_t>
+static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_iq1_s<<<nb, 32, 0, stream>>>(vx, y);
+}
+
 template <typename src_t, typename dst_t>
 static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
@@ -6600,6 +6812,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
             return dequantize_row_iq2_xs_cuda;
         case GGML_TYPE_IQ3_XXS:
             return dequantize_row_iq3_xxs_cuda;
+        case GGML_TYPE_IQ1_S:
+            return dequantize_row_iq1_s_cuda;
         case GGML_TYPE_F32:
             return convert_unary_cuda<float>;
         default:
@@ -6635,6 +6849,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_row_iq2_xs_cuda;
         case GGML_TYPE_IQ3_XXS:
             return dequantize_row_iq3_xxs_cuda;
+        case GGML_TYPE_IQ1_S:
+            return dequantize_row_iq1_s_cuda;
         case GGML_TYPE_F16:
             return convert_unary_cuda<half>;
         default:
@@ -8378,6 +8594,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ1_S:
             return max_compute_capability >= CC_RDNA2 ? 128 : 64;
         default:
             GGML_ASSERT(false);
@@ -8401,6 +8618,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ1_S:
             return max_compute_capability >= CC_VOLTA ? 128 : 64;
         case GGML_TYPE_Q6_K:
             return 64;
@@ -8498,6 +8716,10 @@ static void ggml_cuda_op_mul_mat_vec_q(
             mul_mat_vec_q_cuda<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
                 (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
             break;
+        case GGML_TYPE_IQ1_S:
+            mul_mat_vec_q_cuda<QK_K, QI1_S, block_iq1_s, 1, vec_dot_iq1_s_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
         default:
             GGML_ASSERT(false);
             break;
@@ -11214,7 +11436,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                     return false;
                 }
                 ggml_type a_type = a->type;
-                if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS) {
+                if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ1_S) {
                     if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
                         return false;
                     }
index cf4bb395ff0f752585f90b074876c6790695f344..956e323a0d053cb75fd57c8dead79e72b8603317 100644 (file)
@@ -61,6 +61,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
     GGML_METAL_KERNEL_TYPE_RMS_NORM,
     GGML_METAL_KERNEL_TYPE_GROUP_NORM,
@@ -83,6 +84,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
   //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
@@ -101,6 +103,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
@@ -116,6 +119,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
@@ -131,6 +135,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F16,
     GGML_METAL_KERNEL_TYPE_ALIBI_F32,
@@ -442,6 +447,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,          get_rows_iq2_xxs,       true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,           get_rows_iq2_xs,        true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,          get_rows_iq3_xxs,       true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,            get_rows_iq1_s,         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,              get_rows_i32,           true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM,                  rms_norm,               ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM,                group_norm,             ctx->support_simdgroup_reduction);
@@ -464,6 +470,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,        mul_mv_iq2_xxs_f32,     ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,         mul_mv_iq2_xs_f32,      ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,        mul_mv_iq3_xxs_f32,     ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,          mul_mv_iq1_s_f32,       ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,         mul_mv_id_f32_f32,      ctx->support_simdgroup_reduction);
       //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,         mul_mv_id_f16_f16,      ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,         mul_mv_id_f16_f32,      ctx->support_simdgroup_reduction);
@@ -482,6 +489,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,     mul_mv_id_iq2_xxs_f32,  ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,      mul_mv_id_iq2_xs_f32,   ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,     mul_mv_id_iq3_xxs_f32,  ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,       mul_mv_id_iq1_s_f32,    ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,            mul_mm_f32_f32,         ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,            mul_mm_f16_f32,         ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,           mul_mm_q4_0_f32,        ctx->support_simdgroup_mm);
@@ -497,6 +505,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,        mul_mm_iq2_xxs_f32,     ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,         mul_mm_iq2_xs_f32,      ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,        mul_mm_iq3_xxs_f32,     ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,          mul_mm_iq1_s_f32,       ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,         mul_mm_id_f32_f32,      ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,         mul_mm_id_f16_f32,      ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,        mul_mm_id_q4_0_f32,     ctx->support_simdgroup_mm);
@@ -512,6 +521,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,     mul_mm_id_iq2_xxs_f32,  ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,      mul_mm_id_iq2_xs_f32,   ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,     mul_mm_id_iq3_xxs_f32,  ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,       mul_mm_id_iq1_s_f32,    ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32,                  rope_f32,               true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16,                  rope_f16,               true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32,                 alibi_f32,              true);
@@ -1327,6 +1337,7 @@ static bool ggml_metal_graph_compute(
                                 case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
                                 case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
                                 case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
+                                case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32  ].pipeline; break;
                                 default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
                             }
 
@@ -1461,6 +1472,12 @@ static bool ggml_metal_graph_compute(
                                         nth1 = 16;
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
                                     } break;
+                                case GGML_TYPE_IQ1_S:
+                                    {
+                                        nth0 = 4;
+                                        nth1 = 16;
+                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
+                                    } break;
                                 default:
                                     {
                                         GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
@@ -1495,7 +1512,7 @@ static bool ggml_metal_graph_compute(
 
                             if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
                                 src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
-                                src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
+                                src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S) { // || src0t == GGML_TYPE_Q4_K) {
                                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
                             else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
@@ -1601,6 +1618,7 @@ static bool ggml_metal_graph_compute(
                                 case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
                                 case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
                                 case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
+                                case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32  ].pipeline; break;
                                 default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
                             }
 
@@ -1738,6 +1756,12 @@ static bool ggml_metal_graph_compute(
                                         nth1 = 16;
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
                                     } break;
+                                case GGML_TYPE_IQ1_S:
+                                    {
+                                        nth0 = 4;
+                                        nth1 = 16;
+                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
+                                    } break;
                                 default:
                                     {
                                         GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
@@ -1788,7 +1812,7 @@ static bool ggml_metal_graph_compute(
 
                             if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
                                 src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
-                                src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
+                                src2t == GGML_TYPE_Q2_K || src2t == GGML_TYPE_IQ1_S) { // || src2t == GGML_TYPE_Q4_K) {
                                 [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
                             else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
@@ -1842,6 +1866,7 @@ static bool ggml_metal_graph_compute(
                             case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
                             case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
                             case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
+                            case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S  ].pipeline; break;
                             case GGML_TYPE_I32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32    ].pipeline; break;
                             default: GGML_ASSERT(false && "not implemented");
                         }
index 09ebcc9e3040fdaa3a46bdb582c950ac14a28ba4..a009621113893eb66ccb119bad8ca8c58b7460fc 100644 (file)
@@ -2525,6 +2525,13 @@ typedef struct {
 } block_iq3_xxs;
 // 98 bytes / block for QK_K = 256, so 3.0625 bpw
 
+typedef struct {
+    half d;
+    uint8_t qs[QK_K/8];
+    uint8_t scales[QK_K/16];
+} block_iq1_s;
+
+
 //====================================== dot products =========================
 
 void kernel_mul_mv_q2_K_f32_impl(
@@ -3782,6 +3789,137 @@ constexpr constant static uint32_t iq3xxs_grid[256] = {
     0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
 };
 
+#define NGRID_IQ1S 512
+constexpr constant static uint64_t iq1s_grid[NGRID_IQ1S] = {
+    0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
+    0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01,
+    0xffffff00ff000000, 0xffffff000000ff00, 0xffffff00000000ff, 0xffffff0000000100,
+    0xffffff0000010000, 0xffffff0001000000, 0xffffff01ffff00ff, 0xffffff01ff01ff00,
+    0xffffff01ff010100, 0xffffff0100000001, 0xffffff0101ffff00, 0xffffff0101ff0101,
+    0xffffff0101010100, 0xffff00ffff00ff01, 0xffff00ffff0000ff, 0xffff00ff00ff0100,
+    0xffff00ff0100ff00, 0xffff00ff010001ff, 0xffff0000ff0101ff, 0xffff000000ffff00,
+    0xffff000000000000, 0xffff00000001ff01, 0xffff000001000101, 0xffff0000010100ff,
+    0xffff0001ffff0100, 0xffff00010000ff00, 0xffff000100010101, 0xffff000101000000,
+    0xffff01ffffff0000, 0xffff01ffff01ffff, 0xffff01ffff010100, 0xffff01ff00000000,
+    0xffff01ff01ffffff, 0xffff01ff01ff0001, 0xffff01ff0101ffff, 0xffff01ff01010001,
+    0xffff0100ffffff01, 0xffff01000000ffff, 0xffff010000000100, 0xffff010001ff01ff,
+    0xffff010001000000, 0xffff0101ff000000, 0xffff0101000101ff, 0xffff010101ffff01,
+    0xffff01010101ff00, 0xff00ffffff000000, 0xff00ffff00ffff00, 0xff00ffff00000001,
+    0xff00ffff000001ff, 0xff00ffff01010000, 0xff00ff00ffff0000, 0xff00ff00ff00ff00,
+    0xff00ff00ff0000ff, 0xff00ff00ff000100, 0xff00ff00ff010001, 0xff00ff0000ff0001,
+    0xff00ff000000ffff, 0xff00ff0000000000, 0xff00ff000001ff00, 0xff00ff0000010100,
+    0xff00ff0001ff0000, 0xff00ff000100ff00, 0xff00ff0001000100, 0xff00ff01ff000000,
+    0xff00ff0100ff0000, 0xff00ff01000001ff, 0xff00ff0101010001, 0xff0000ff00000000,
+    0xff0000ff0001ff00, 0xff0000ff00010100, 0xff000000ffff0101, 0xff000000ff000000,
+    0xff000000ff01ff00, 0xff00000000ff0000, 0xff0000000000ff00, 0xff000000000000ff,
+    0xff00000000000000, 0xff00000000000001, 0xff00000000000100, 0xff0000000001ffff,
+    0xff00000000010000, 0xff00000001000000, 0xff00000001010100, 0xff000001ff00ff01,
+    0xff000001ff0100ff, 0xff00000100000000, 0xff0000010001ff00, 0xff00000101ff0100,
+    0xff0000010100ff00, 0xff0001ff00ff00ff, 0xff0001ff00000101, 0xff0001ff000100ff,
+    0xff0001ff01000000, 0xff000100ff0001ff, 0xff0001000000ff01, 0xff00010000000000,
+    0xff00010000010001, 0xff00010000010100, 0xff00010001ffff00, 0xff00010001ff0101,
+    0xff00010001010000, 0xff000101ffffffff, 0xff000101ff000101, 0xff00010101ff00ff,
+    0xff00010101000001, 0xff000101010100ff, 0xff01ffffff000101, 0xff01ffffff01ffff,
+    0xff01ffffff01ff01, 0xff01ffffff0101ff, 0xff01ffff00000000, 0xff01ffff01ff0001,
+    0xff01ffff0101ff01, 0xff01ff00ff000000, 0xff01ff0000ff0100, 0xff01ff000000ff01,
+    0xff01ff0000010000, 0xff01ff00010000ff, 0xff01ff01ff01ff00, 0xff01ff0100000101,
+    0xff0100ffffff0000, 0xff0100ffff010000, 0xff0100ff01ff00ff, 0xff0100ff01000100,
+    0xff0100ff010100ff, 0xff010000ffffff01, 0xff01000000000000, 0xff0100000101ff00,
+    0xff010001ffff00ff, 0xff010001ff000100, 0xff01000100ffff00, 0xff01000100010001,
+    0xff01000101ff0001, 0xff010001010001ff, 0xff0101ffffffffff, 0xff0101ffff01ffff,
+    0xff0101ffff010101, 0xff0101ff0000ff00, 0xff0101ff01010001, 0xff010100ff000000,
+    0xff010100ff01ff01, 0xff01010000ff0001, 0xff01010000000100, 0xff01010001000000,
+    0xff0101010100ffff, 0x00ffffff0000ff01, 0x00ffffff000000ff, 0x00ffffff00000100,
+    0x00ffffff00010000, 0x00ffff00ffff0001, 0x00ffff00ff0000ff, 0x00ffff00ff000100,
+    0x00ffff0000000000, 0x00ffff0001000100, 0x00ffff0001010001, 0x00ffff01ff00ff01,
+    0x00ffff0100ff0100, 0x00ffff010000ff00, 0x00ffff01000100ff, 0x00ffff0101ff00ff,
+    0x00ffff010101ff00, 0x00ff00ffffffffff, 0x00ff00ffffff01ff, 0x00ff00ffff000101,
+    0x00ff00ff00000000, 0x00ff00ff000101ff, 0x00ff00ff01010101, 0x00ff0000ff000000,
+    0x00ff0000ff01ffff, 0x00ff000000ff0000, 0x00ff00000000ff00, 0x00ff0000000000ff,
+    0x00ff000000000000, 0x00ff000000000001, 0x00ff000000000100, 0x00ff000000010000,
+    0x00ff000001ffff01, 0x00ff000001000000, 0x00ff0001ff000101, 0x00ff000100ffffff,
+    0x00ff000100000000, 0x00ff0001010001ff, 0x00ff01ffff000000, 0x00ff01ff0001ff00,
+    0x00ff01ff01ff0100, 0x00ff0100ff01ff01, 0x00ff010000ff00ff, 0x00ff010000ff0101,
+    0x00ff010000000000, 0x00ff010000010101, 0x00ff01000100ff00, 0x00ff010001010000,
+    0x00ff0101ffffff00, 0x00ff01010000ff01, 0x00ff010100000100, 0x00ff010101ff0000,
+    0x0000ffffffff0100, 0x0000ffffff00ff00, 0x0000ffffff0000ff, 0x0000ffffff010000,
+    0x0000ffff00000000, 0x0000ffff00010101, 0x0000ffff01ffff01, 0x0000ffff01000100,
+    0x0000ff00ff000000, 0x0000ff00ff01ff00, 0x0000ff00ff0101ff, 0x0000ff0000ff0000,
+    0x0000ff000000ff00, 0x0000ff00000000ff, 0x0000ff0000000000, 0x0000ff0000000001,
+    0x0000ff0000000100, 0x0000ff0000010000, 0x0000ff0001ffffff, 0x0000ff0001ff01ff,
+    0x0000ff0001000000, 0x0000ff000101ffff, 0x0000ff01ffff0101, 0x0000ff01ff010000,
+    0x0000ff0100000000, 0x0000ff0101000101, 0x000000ffffff0001, 0x000000ffff000000,
+    0x000000ff00ff0000, 0x000000ff0000ff00, 0x000000ff000000ff, 0x000000ff00000000,
+    0x000000ff00000001, 0x000000ff00000100, 0x000000ff00010000, 0x000000ff01000000,
+    0x000000ff0101ff00, 0x00000000ffff0000, 0x00000000ff00ff00, 0x00000000ff0000ff,
+    0x00000000ff000000, 0x00000000ff000001, 0x00000000ff000100, 0x00000000ff010000,
+    0x0000000000ffff00, 0x0000000000ff00ff, 0x0000000000ff0000, 0x0000000000ff0001,
+    0x0000000000ff0100, 0x000000000000ffff, 0x000000000000ff00, 0x000000000000ff01,
+    0x00000000000000ff, 0x0000000000000001, 0x00000000000001ff, 0x0000000000000100,
+    0x0000000000000101, 0x000000000001ff00, 0x00000000000100ff, 0x0000000000010000,
+    0x0000000000010001, 0x0000000000010100, 0x0000000001ff0000, 0x000000000100ff00,
+    0x00000000010000ff, 0x0000000001000000, 0x0000000001000001, 0x0000000001000100,
+    0x0000000001010000, 0x00000001ffff01ff, 0x00000001ff000000, 0x0000000100ff0000,
+    0x000000010000ff00, 0x00000001000000ff, 0x0000000100000000, 0x0000000100000001,
+    0x0000000100000100, 0x0000000100010000, 0x0000000101000000, 0x000001ffff00ff00,
+    0x000001ffff010001, 0x000001ffff0101ff, 0x000001ff00ffff01, 0x000001ff0000ffff,
+    0x000001ff00000000, 0x000001ff010000ff, 0x000001ff01010100, 0x00000100ffff0100,
+    0x00000100ff000000, 0x0000010000ff0000, 0x000001000000ff00, 0x00000100000000ff,
+    0x0000010000000000, 0x0000010000000001, 0x0000010000000100, 0x0000010000010000,
+    0x0000010001000000, 0x000001000101ff01, 0x00000101ffff0001, 0x00000101ff01ffff,
+    0x0000010100000000, 0x0000010101010100, 0x0001ffffff000000, 0x0001ffff00ffffff,
+    0x0001ffff00000100, 0x0001ffff0001ff00, 0x0001ffff01000000, 0x0001ff00ffffff00,
+    0x0001ff00ffff01ff, 0x0001ff00ff010000, 0x0001ff0000000000, 0x0001ff0000010001,
+    0x0001ff0001ff0000, 0x0001ff0001010100, 0x0001ff01ff0000ff, 0x0001ff01ff000001,
+    0x0001ff0100ffffff, 0x0001ff010001ffff, 0x0001ff01000101ff, 0x0001ff010100ff01,
+    0x000100ffff00ffff, 0x000100ffff00ff01, 0x000100ffff000100, 0x000100ff00000000,
+    0x000100ff000101ff, 0x000100ff01ff0101, 0x000100ff0100ffff, 0x000100ff01010101,
+    0x00010000ff000000, 0x00010000ff010100, 0x0001000000ff0000, 0x000100000000ff00,
+    0x00010000000000ff, 0x0001000000000000, 0x0001000000000001, 0x0001000000000100,
+    0x0001000000010000, 0x0001000001ffff01, 0x0001000001000000, 0x0001000100ff0101,
+    0x0001000100000000, 0x00010001010100ff, 0x000101ffffff01ff, 0x000101ffffff0101,
+    0x000101ff00010000, 0x000101ff01ff0000, 0x000101ff0100ff01, 0x00010100ffff0000,
+    0x0001010000000000, 0x000101000001ffff, 0x0001010000010101, 0x00010100010001ff,
+    0x00010101ff00ff00, 0x00010101ff010001, 0x0001010100ffffff, 0x0001010100ff01ff,
+    0x00010101000101ff, 0x0001010101ff0000, 0x000101010100ff01, 0x0001010101000101,
+    0x01ffffffffff0101, 0x01ffffffff01ffff, 0x01ffffffff01ff01, 0x01ffffffff0101ff,
+    0x01ffffffff010101, 0x01ffffff00000000, 0x01ffffff01ff01ff, 0x01ffffff01000101,
+    0x01ffffff0101ff01, 0x01ffffff010100ff, 0x01ffff000000ff00, 0x01ffff0000000001,
+    0x01ffff00000001ff, 0x01ffff0000010000, 0x01ffff0001ff0000, 0x01ffff01ffffffff,
+    0x01ffff01ffff01ff, 0x01ffff01ff000000, 0x01ffff01ff01ffff, 0x01ffff01ff0101ff,
+    0x01ffff010100ffff, 0x01ff00ffffff0000, 0x01ff00ffff010000, 0x01ff00ff00ffff01,
+    0x01ff0000ff0000ff, 0x01ff000000000000, 0x01ff00000001ff01, 0x01ff000001ffffff,
+    0x01ff000001010100, 0x01ff0001ffffff01, 0x01ff0001ff010001, 0x01ff000101ff0100,
+    0x01ff000101000001, 0x01ff0001010100ff, 0x01ff01ffff00ffff, 0x01ff01ff00010001,
+    0x01ff01ff01000000, 0x01ff01ff010101ff, 0x01ff0100ff000001, 0x01ff010000ffff00,
+    0x01ff010000000100, 0x01ff010001ff01ff, 0x01ff01000101ffff, 0x01ff0101ffff00ff,
+    0x01ff0101ffff0101, 0x01ff0101ff0101ff, 0x01ff010100010000, 0x0100ffff00ff00ff,
+    0x0100ffff00ff0001, 0x0100ffff00000100, 0x0100ffff0100ff00, 0x0100ff00ffff0000,
+    0x0100ff00ff00ffff, 0x0100ff00ff00ff01, 0x0100ff00ff000100, 0x0100ff00ff010000,
+    0x0100ff0000000000, 0x0100ff00000100ff, 0x0100ff0001ff0101, 0x0100ff0001010101,
+    0x0100ff0100ff00ff, 0x0100ff0100ff0001, 0x0100ff0100000100, 0x0100ff0100010001,
+    0x0100ff0101000000, 0x010000ffff00ff00, 0x010000ff0000ffff, 0x010000ff00000000,
+    0x010000ff010001ff, 0x010000ff01010001, 0x01000000ffffff00, 0x01000000ffff0101,
+    0x01000000ff000000, 0x01000000ff0100ff, 0x01000000ff010101, 0x0100000000ff0000,
+    0x010000000000ff00, 0x01000000000000ff, 0x0100000000000000, 0x0100000000000001,
+    0x0100000000000100, 0x0100000000010000, 0x0100000001000000, 0x0100000100000000,
+    0x01000001000101ff, 0x0100000101ffff01, 0x010001ffff000101, 0x010001ff00ff0100,
+    0x010001ff0000ff00, 0x010001ff000100ff, 0x010001ff01ffffff, 0x01000100ffff0000,
+    0x01000100ff0001ff, 0x0100010000000000, 0x010001000001ff00, 0x0100010001ff0000,
+    0x01000100010000ff, 0x0100010001000101, 0x01000101ff00ff01, 0x0100010100ff0100,
+    0x010001010000ffff, 0x0100010101010001, 0x0101ffffffff0101, 0x0101ffffff0001ff,
+    0x0101ffffff01ffff, 0x0101ffffff010101, 0x0101ffff00000000, 0x0101ffff0101ffff,
+    0x0101ffff010101ff, 0x0101ff00ff000000, 0x0101ff0000ff0100, 0x0101ff000000ff00,
+    0x0101ff0000010000, 0x0101ff00010000ff, 0x0101ff0001000001, 0x0101ff01ff010101,
+    0x0101ff0100000000, 0x0101ff010101ff00, 0x010100ffffff0000, 0x010100ffff010000,
+    0x010100ff00ff01ff, 0x010100ff000000ff, 0x010100ff00000101, 0x010100ff01ffff00,
+    0x01010000ffffff01, 0x01010000ff000100, 0x01010000ff01ff01, 0x0101000000000000,
+    0x01010000000100ff, 0x010100000101ff01, 0x01010001ffff0000, 0x01010001ff00ffff,
+    0x01010001ff010000, 0x0101000101ffffff, 0x0101000101ff01ff, 0x0101000101010101,
+    0x010101ffff01ffff, 0x010101ff00000000, 0x010101ff0001ff01, 0x010101ff0101ffff,
+    0x010101ff010101ff, 0x01010100ffffffff, 0x01010100ff000001, 0x010101000000ff00,
+    0x0101010001010000, 0x0101010100ff0001, 0x010101010001ff01, 0x010101010101ffff,
+};
 
 constexpr constant static uint8_t ksigns_iq2xs[128] = {
       0, 129, 130,   3, 132,   5,   6, 135, 136,   9,  10, 139,  12, 141, 142,  15,
@@ -4208,6 +4346,123 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
     kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
 }
 
+void kernel_mul_mv_iq1_s_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    const int nb = ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int ib_row = first_row * nb;
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+    device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0;
+    device const float       * y = (device const float       *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float yl[16];
+    float sumf[N_DST]={0.f}, all_sum;
+
+    const int nb32 = nb * (QK_K / 32);
+
+#if QK_K == 256
+    const int ix = tiisg/2;
+    const int il = tiisg%2;
+
+    device const float * y4 = y + 32 * ix + 16 * il;
+
+    for (int ib32 = ix; ib32 < nb32; ib32 += 16) {
+
+        for (int i = 0; i < 16; ++i) {
+            yl[i] = y4[i];
+        }
+
+        const int ibl = ib32 / (QK_K / 32);
+        const int ib  = ib32 % (QK_K / 32);
+
+        device const block_iq1_s * xr = x + ibl;
+        device const uint8_t * qs = xr->qs + 4 * ib + 2 * il;
+        device const uint8_t * sc = xr->scales + 2 * ib + il;
+        device const half    * dh = &xr->d;
+
+        for (int row = 0; row < N_DST; row++) {
+
+            constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5)));
+            constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1)));
+
+            float2 sum = {0};
+            for (int j = 0; j < 8; ++j) {
+                sum[0] += yl[j+ 0] * grid1[j];
+                sum[1] += yl[j+ 8] * grid2[j];
+            }
+            sumf[row] += (float)dh[0] * (sum[0] * (2*(sc[0] & 7) + 1) + sum[1] * (2*((sc[0] >> 4) & 7) + 1));
+
+            dh += nb*sizeof(block_iq1_s)/2;
+            qs += nb*sizeof(block_iq1_s);
+            sc += nb*sizeof(block_iq1_s);
+        }
+
+        y4 += 16 * 32;
+    }
+#else
+    // TODO
+#endif
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_iq1_s_f32")]]
+kernel void kernel_mul_mv_iq1_s_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
 
 //============================= templates and their specializations =============================
 
@@ -4553,6 +4808,22 @@ void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x
     }
 }
 
+template <typename type4x4>
+void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const float d = xb->d;
+    device const uint8_t * qs = xb->qs + 2*il;
+    device const uint8_t * sc = xb->scales + il;
+    const float dl1 = d * (2*(sc[0] & 7) + 1);
+    const float dl2 = d * (2*((sc[0] >> 4) & 7) + 1);
+    constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5)));
+    constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1)));
+    for (int i = 0; i < 8; ++i) {
+        reg[i/4+0][i%4] = dl1 * grid1[i];
+        reg[i/4+2][i%4] = dl2 * grid2[i];
+    }
+}
+
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
 kernel void kernel_get_rows(
         device const  void * src0,
@@ -5095,6 +5366,7 @@ template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows
 template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
 template [[host_name("kernel_get_rows_iq2_xs")]]  kernel get_rows_t kernel_get_rows<block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
 template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
+template [[host_name("kernel_get_rows_iq1_s")]]   kernel get_rows_t kernel_get_rows<block_iq1_s,   QK_NL, dequantize_iq1_s>;
 
 //
 // matrix-matrix multiplication
@@ -5134,6 +5406,7 @@ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
 template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
 template [[host_name("kernel_mul_mm_iq2_xs_f32")]]  kernel mat_mm_t kernel_mul_mm<block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
 template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
+template [[host_name("kernel_mul_mm_iq1_s_f32")]]   kernel mat_mm_t kernel_mul_mm<block_iq1_s,   QK_NL, dequantize_iq1_s>;
 
 //
 // indirect matrix-matrix multiplication
@@ -5185,6 +5458,7 @@ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mu
 template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
 template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
 template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
+template [[host_name("kernel_mul_mm_id_iq1_s_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s,   QK_NL, dequantize_iq1_s>;
 
 //
 // matrix-vector multiplication
@@ -6152,3 +6426,66 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
         tiisg,
         sgitg);
 }
+
+[[host_name("kernel_mul_mv_id_iq1_s_f32")]]
+kernel void kernel_mul_mv_id_iq1_s_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         float * dst,
+        constant    uint64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant    uint64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_iq1_s_f32_impl(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        dst + bid*ne0,
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
index f44377f455a91ccc2a9982f53a92520a1c02117b..48f5294e1e78831869c7ec474807c0654daa068f 100644 (file)
@@ -3480,6 +3480,139 @@ static const uint32_t iq3xxs_grid[256] = {
     0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
 };
 
+#define NGRID_IQ2XXS 512
+static const  uint64_t iq1s_grid[NGRID_IQ2XXS] = {
+    0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
+    0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01,
+    0xffffff00ff000000, 0xffffff000000ff00, 0xffffff00000000ff, 0xffffff0000000100,
+    0xffffff0000010000, 0xffffff0001000000, 0xffffff01ffff00ff, 0xffffff01ff01ff00,
+    0xffffff01ff010100, 0xffffff0100000001, 0xffffff0101ffff00, 0xffffff0101ff0101,
+    0xffffff0101010100, 0xffff00ffff00ff01, 0xffff00ffff0000ff, 0xffff00ff00ff0100,
+    0xffff00ff0100ff00, 0xffff00ff010001ff, 0xffff0000ff0101ff, 0xffff000000ffff00,
+    0xffff000000000000, 0xffff00000001ff01, 0xffff000001000101, 0xffff0000010100ff,
+    0xffff0001ffff0100, 0xffff00010000ff00, 0xffff000100010101, 0xffff000101000000,
+    0xffff01ffffff0000, 0xffff01ffff01ffff, 0xffff01ffff010100, 0xffff01ff00000000,
+    0xffff01ff01ffffff, 0xffff01ff01ff0001, 0xffff01ff0101ffff, 0xffff01ff01010001,
+    0xffff0100ffffff01, 0xffff01000000ffff, 0xffff010000000100, 0xffff010001ff01ff,
+    0xffff010001000000, 0xffff0101ff000000, 0xffff0101000101ff, 0xffff010101ffff01,
+    0xffff01010101ff00, 0xff00ffffff000000, 0xff00ffff00ffff00, 0xff00ffff00000001,
+    0xff00ffff000001ff, 0xff00ffff01010000, 0xff00ff00ffff0000, 0xff00ff00ff00ff00,
+    0xff00ff00ff0000ff, 0xff00ff00ff000100, 0xff00ff00ff010001, 0xff00ff0000ff0001,
+    0xff00ff000000ffff, 0xff00ff0000000000, 0xff00ff000001ff00, 0xff00ff0000010100,
+    0xff00ff0001ff0000, 0xff00ff000100ff00, 0xff00ff0001000100, 0xff00ff01ff000000,
+    0xff00ff0100ff0000, 0xff00ff01000001ff, 0xff00ff0101010001, 0xff0000ff00000000,
+    0xff0000ff0001ff00, 0xff0000ff00010100, 0xff000000ffff0101, 0xff000000ff000000,
+    0xff000000ff01ff00, 0xff00000000ff0000, 0xff0000000000ff00, 0xff000000000000ff,
+    0xff00000000000000, 0xff00000000000001, 0xff00000000000100, 0xff0000000001ffff,
+    0xff00000000010000, 0xff00000001000000, 0xff00000001010100, 0xff000001ff00ff01,
+    0xff000001ff0100ff, 0xff00000100000000, 0xff0000010001ff00, 0xff00000101ff0100,
+    0xff0000010100ff00, 0xff0001ff00ff00ff, 0xff0001ff00000101, 0xff0001ff000100ff,
+    0xff0001ff01000000, 0xff000100ff0001ff, 0xff0001000000ff01, 0xff00010000000000,
+    0xff00010000010001, 0xff00010000010100, 0xff00010001ffff00, 0xff00010001ff0101,
+    0xff00010001010000, 0xff000101ffffffff, 0xff000101ff000101, 0xff00010101ff00ff,
+    0xff00010101000001, 0xff000101010100ff, 0xff01ffffff000101, 0xff01ffffff01ffff,
+    0xff01ffffff01ff01, 0xff01ffffff0101ff, 0xff01ffff00000000, 0xff01ffff01ff0001,
+    0xff01ffff0101ff01, 0xff01ff00ff000000, 0xff01ff0000ff0100, 0xff01ff000000ff01,
+    0xff01ff0000010000, 0xff01ff00010000ff, 0xff01ff01ff01ff00, 0xff01ff0100000101,
+    0xff0100ffffff0000, 0xff0100ffff010000, 0xff0100ff01ff00ff, 0xff0100ff01000100,
+    0xff0100ff010100ff, 0xff010000ffffff01, 0xff01000000000000, 0xff0100000101ff00,
+    0xff010001ffff00ff, 0xff010001ff000100, 0xff01000100ffff00, 0xff01000100010001,
+    0xff01000101ff0001, 0xff010001010001ff, 0xff0101ffffffffff, 0xff0101ffff01ffff,
+    0xff0101ffff010101, 0xff0101ff0000ff00, 0xff0101ff01010001, 0xff010100ff000000,
+    0xff010100ff01ff01, 0xff01010000ff0001, 0xff01010000000100, 0xff01010001000000,
+    0xff0101010100ffff, 0x00ffffff0000ff01, 0x00ffffff000000ff, 0x00ffffff00000100,
+    0x00ffffff00010000, 0x00ffff00ffff0001, 0x00ffff00ff0000ff, 0x00ffff00ff000100,
+    0x00ffff0000000000, 0x00ffff0001000100, 0x00ffff0001010001, 0x00ffff01ff00ff01,
+    0x00ffff0100ff0100, 0x00ffff010000ff00, 0x00ffff01000100ff, 0x00ffff0101ff00ff,
+    0x00ffff010101ff00, 0x00ff00ffffffffff, 0x00ff00ffffff01ff, 0x00ff00ffff000101,
+    0x00ff00ff00000000, 0x00ff00ff000101ff, 0x00ff00ff01010101, 0x00ff0000ff000000,
+    0x00ff0000ff01ffff, 0x00ff000000ff0000, 0x00ff00000000ff00, 0x00ff0000000000ff,
+    0x00ff000000000000, 0x00ff000000000001, 0x00ff000000000100, 0x00ff000000010000,
+    0x00ff000001ffff01, 0x00ff000001000000, 0x00ff0001ff000101, 0x00ff000100ffffff,
+    0x00ff000100000000, 0x00ff0001010001ff, 0x00ff01ffff000000, 0x00ff01ff0001ff00,
+    0x00ff01ff01ff0100, 0x00ff0100ff01ff01, 0x00ff010000ff00ff, 0x00ff010000ff0101,
+    0x00ff010000000000, 0x00ff010000010101, 0x00ff01000100ff00, 0x00ff010001010000,
+    0x00ff0101ffffff00, 0x00ff01010000ff01, 0x00ff010100000100, 0x00ff010101ff0000,
+    0x0000ffffffff0100, 0x0000ffffff00ff00, 0x0000ffffff0000ff, 0x0000ffffff010000,
+    0x0000ffff00000000, 0x0000ffff00010101, 0x0000ffff01ffff01, 0x0000ffff01000100,
+    0x0000ff00ff000000, 0x0000ff00ff01ff00, 0x0000ff00ff0101ff, 0x0000ff0000ff0000,
+    0x0000ff000000ff00, 0x0000ff00000000ff, 0x0000ff0000000000, 0x0000ff0000000001,
+    0x0000ff0000000100, 0x0000ff0000010000, 0x0000ff0001ffffff, 0x0000ff0001ff01ff,
+    0x0000ff0001000000, 0x0000ff000101ffff, 0x0000ff01ffff0101, 0x0000ff01ff010000,
+    0x0000ff0100000000, 0x0000ff0101000101, 0x000000ffffff0001, 0x000000ffff000000,
+    0x000000ff00ff0000, 0x000000ff0000ff00, 0x000000ff000000ff, 0x000000ff00000000,
+    0x000000ff00000001, 0x000000ff00000100, 0x000000ff00010000, 0x000000ff01000000,
+    0x000000ff0101ff00, 0x00000000ffff0000, 0x00000000ff00ff00, 0x00000000ff0000ff,
+    0x00000000ff000000, 0x00000000ff000001, 0x00000000ff000100, 0x00000000ff010000,
+    0x0000000000ffff00, 0x0000000000ff00ff, 0x0000000000ff0000, 0x0000000000ff0001,
+    0x0000000000ff0100, 0x000000000000ffff, 0x000000000000ff00, 0x000000000000ff01,
+    0x00000000000000ff, 0x0000000000000001, 0x00000000000001ff, 0x0000000000000100,
+    0x0000000000000101, 0x000000000001ff00, 0x00000000000100ff, 0x0000000000010000,
+    0x0000000000010001, 0x0000000000010100, 0x0000000001ff0000, 0x000000000100ff00,
+    0x00000000010000ff, 0x0000000001000000, 0x0000000001000001, 0x0000000001000100,
+    0x0000000001010000, 0x00000001ffff01ff, 0x00000001ff000000, 0x0000000100ff0000,
+    0x000000010000ff00, 0x00000001000000ff, 0x0000000100000000, 0x0000000100000001,
+    0x0000000100000100, 0x0000000100010000, 0x0000000101000000, 0x000001ffff00ff00,
+    0x000001ffff010001, 0x000001ffff0101ff, 0x000001ff00ffff01, 0x000001ff0000ffff,
+    0x000001ff00000000, 0x000001ff010000ff, 0x000001ff01010100, 0x00000100ffff0100,
+    0x00000100ff000000, 0x0000010000ff0000, 0x000001000000ff00, 0x00000100000000ff,
+    0x0000010000000000, 0x0000010000000001, 0x0000010000000100, 0x0000010000010000,
+    0x0000010001000000, 0x000001000101ff01, 0x00000101ffff0001, 0x00000101ff01ffff,
+    0x0000010100000000, 0x0000010101010100, 0x0001ffffff000000, 0x0001ffff00ffffff,
+    0x0001ffff00000100, 0x0001ffff0001ff00, 0x0001ffff01000000, 0x0001ff00ffffff00,
+    0x0001ff00ffff01ff, 0x0001ff00ff010000, 0x0001ff0000000000, 0x0001ff0000010001,
+    0x0001ff0001ff0000, 0x0001ff0001010100, 0x0001ff01ff0000ff, 0x0001ff01ff000001,
+    0x0001ff0100ffffff, 0x0001ff010001ffff, 0x0001ff01000101ff, 0x0001ff010100ff01,
+    0x000100ffff00ffff, 0x000100ffff00ff01, 0x000100ffff000100, 0x000100ff00000000,
+    0x000100ff000101ff, 0x000100ff01ff0101, 0x000100ff0100ffff, 0x000100ff01010101,
+    0x00010000ff000000, 0x00010000ff010100, 0x0001000000ff0000, 0x000100000000ff00,
+    0x00010000000000ff, 0x0001000000000000, 0x0001000000000001, 0x0001000000000100,
+    0x0001000000010000, 0x0001000001ffff01, 0x0001000001000000, 0x0001000100ff0101,
+    0x0001000100000000, 0x00010001010100ff, 0x000101ffffff01ff, 0x000101ffffff0101,
+    0x000101ff00010000, 0x000101ff01ff0000, 0x000101ff0100ff01, 0x00010100ffff0000,
+    0x0001010000000000, 0x000101000001ffff, 0x0001010000010101, 0x00010100010001ff,
+    0x00010101ff00ff00, 0x00010101ff010001, 0x0001010100ffffff, 0x0001010100ff01ff,
+    0x00010101000101ff, 0x0001010101ff0000, 0x000101010100ff01, 0x0001010101000101,
+    0x01ffffffffff0101, 0x01ffffffff01ffff, 0x01ffffffff01ff01, 0x01ffffffff0101ff,
+    0x01ffffffff010101, 0x01ffffff00000000, 0x01ffffff01ff01ff, 0x01ffffff01000101,
+    0x01ffffff0101ff01, 0x01ffffff010100ff, 0x01ffff000000ff00, 0x01ffff0000000001,
+    0x01ffff00000001ff, 0x01ffff0000010000, 0x01ffff0001ff0000, 0x01ffff01ffffffff,
+    0x01ffff01ffff01ff, 0x01ffff01ff000000, 0x01ffff01ff01ffff, 0x01ffff01ff0101ff,
+    0x01ffff010100ffff, 0x01ff00ffffff0000, 0x01ff00ffff010000, 0x01ff00ff00ffff01,
+    0x01ff0000ff0000ff, 0x01ff000000000000, 0x01ff00000001ff01, 0x01ff000001ffffff,
+    0x01ff000001010100, 0x01ff0001ffffff01, 0x01ff0001ff010001, 0x01ff000101ff0100,
+    0x01ff000101000001, 0x01ff0001010100ff, 0x01ff01ffff00ffff, 0x01ff01ff00010001,
+    0x01ff01ff01000000, 0x01ff01ff010101ff, 0x01ff0100ff000001, 0x01ff010000ffff00,
+    0x01ff010000000100, 0x01ff010001ff01ff, 0x01ff01000101ffff, 0x01ff0101ffff00ff,
+    0x01ff0101ffff0101, 0x01ff0101ff0101ff, 0x01ff010100010000, 0x0100ffff00ff00ff,
+    0x0100ffff00ff0001, 0x0100ffff00000100, 0x0100ffff0100ff00, 0x0100ff00ffff0000,
+    0x0100ff00ff00ffff, 0x0100ff00ff00ff01, 0x0100ff00ff000100, 0x0100ff00ff010000,
+    0x0100ff0000000000, 0x0100ff00000100ff, 0x0100ff0001ff0101, 0x0100ff0001010101,
+    0x0100ff0100ff00ff, 0x0100ff0100ff0001, 0x0100ff0100000100, 0x0100ff0100010001,
+    0x0100ff0101000000, 0x010000ffff00ff00, 0x010000ff0000ffff, 0x010000ff00000000,
+    0x010000ff010001ff, 0x010000ff01010001, 0x01000000ffffff00, 0x01000000ffff0101,
+    0x01000000ff000000, 0x01000000ff0100ff, 0x01000000ff010101, 0x0100000000ff0000,
+    0x010000000000ff00, 0x01000000000000ff, 0x0100000000000000, 0x0100000000000001,
+    0x0100000000000100, 0x0100000000010000, 0x0100000001000000, 0x0100000100000000,
+    0x01000001000101ff, 0x0100000101ffff01, 0x010001ffff000101, 0x010001ff00ff0100,
+    0x010001ff0000ff00, 0x010001ff000100ff, 0x010001ff01ffffff, 0x01000100ffff0000,
+    0x01000100ff0001ff, 0x0100010000000000, 0x010001000001ff00, 0x0100010001ff0000,
+    0x01000100010000ff, 0x0100010001000101, 0x01000101ff00ff01, 0x0100010100ff0100,
+    0x010001010000ffff, 0x0100010101010001, 0x0101ffffffff0101, 0x0101ffffff0001ff,
+    0x0101ffffff01ffff, 0x0101ffffff010101, 0x0101ffff00000000, 0x0101ffff0101ffff,
+    0x0101ffff010101ff, 0x0101ff00ff000000, 0x0101ff0000ff0100, 0x0101ff000000ff00,
+    0x0101ff0000010000, 0x0101ff00010000ff, 0x0101ff0001000001, 0x0101ff01ff010101,
+    0x0101ff0100000000, 0x0101ff010101ff00, 0x010100ffffff0000, 0x010100ffff010000,
+    0x010100ff00ff01ff, 0x010100ff000000ff, 0x010100ff00000101, 0x010100ff01ffff00,
+    0x01010000ffffff01, 0x01010000ff000100, 0x01010000ff01ff01, 0x0101000000000000,
+    0x01010000000100ff, 0x010100000101ff01, 0x01010001ffff0000, 0x01010001ff00ffff,
+    0x01010001ff010000, 0x0101000101ffffff, 0x0101000101ff01ff, 0x0101000101010101,
+    0x010101ffff01ffff, 0x010101ff00000000, 0x010101ff0001ff01, 0x010101ff0101ffff,
+    0x010101ff010101ff, 0x01010100ffffffff, 0x01010100ff000001, 0x010101000000ff00,
+    0x0101010001010000, 0x0101010100ff0001, 0x010101010001ff01, 0x010101010101ffff,
+
+};
+
 static const uint8_t ksigns_iq2xs[128] = {
       0, 129, 130,   3, 132,   5,   6, 135, 136,   9,  10, 139,  12, 141, 142,  15,
     144,  17,  18, 147,  20, 149, 150,  23,  24, 153, 154,  27, 156,  29,  30, 159,
@@ -3578,6 +3711,49 @@ void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y
     }
 }
 
+// ====================== 1.5625 bpw (de)-quantization
+
+void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    float db[4];
+    uint16_t idx[4];
+    //const int8_t * grid[4];
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+        const uint8_t * sc = x[i].scales;
+        const uint8_t * qs = x[i].qs;
+
+        for (int i8 = 0; i8 < QK_K/8; i8 += 4) {
+            idx[0] = qs[0] | ((sc[0] & 0x08) << 5);
+            idx[1] = qs[1] | ((sc[0] & 0x80) << 1);
+            idx[2] = qs[2] | ((sc[1] & 0x08) << 5);
+            idx[3] = qs[3] | ((sc[1] & 0x80) << 1);
+            //grid[0] = (const int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5)));
+            //grid[1] = (const int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1)));
+            //grid[2] = (const int8_t *)(iq1s_grid + (qs[2] | ((sc[1] & 0x08) << 5)));
+            //grid[3] = (const int8_t *)(iq1s_grid + (qs[3] | ((sc[1] & 0x80) << 1)));
+            db[0] = d * (2*(sc[0] & 7) + 1);
+            db[1] = d * (2*((sc[0] >> 4) & 7) + 1);
+            db[2] = d * (2*(sc[1] & 7) + 1);
+            db[3] = d * (2*((sc[1] >> 4) & 7) + 1);
+            for (int l = 0; l < 4; ++l) {
+                const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]);
+                for (int j = 0; j < 8; ++j) {
+                    //y[j] = db[l] * grid[l][j];
+                    y[j] = db[l] * grid[j];
+                }
+                y += 8;
+            }
+            qs += 4;
+            sc += 2;
+        }
+    }
+}
+
 //===================================== Q8_K ==============================================
 
 void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
@@ -3679,7 +3855,7 @@ static inline __m128i get_scale_shuffle(int i) {
 }
 #endif
 
-void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) {
     const int qk = QK8_0;
     const int nb = n / qk;
 
@@ -3690,8 +3866,8 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     assert(nrc == 1);
 #endif
     UNUSED(nrc);
-    UNUSED(bx);
-    UNUSED(by);
+    UNUSED(bbx);
+    UNUSED(bby);
     UNUSED(bs);
 
     const block_q4_0 * restrict x = vx;
@@ -4046,7 +4222,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
 #endif
 }
 
-void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) {
     const int qk = QK8_1;
     const int nb = n / qk;
 
@@ -4057,8 +4233,8 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
     assert(nrc == 1);
 #endif
     UNUSED(nrc);
-    UNUSED(bx);
-    UNUSED(by);
+    UNUSED(bbx);
+    UNUSED(bby);
     UNUSED(bs);
 
     const block_q4_1 * restrict x = vx;
@@ -4264,7 +4440,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
 #endif
 }
 
-void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) {
     const int qk = QK8_0;
     const int nb = n / qk;
 
@@ -4272,8 +4448,8 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     assert(qk == QK5_0);
     assert(nrc == 1);
     UNUSED(nrc);
-    UNUSED(bx);
-    UNUSED(by);
+    UNUSED(bbx);
+    UNUSED(bby);
     UNUSED(bs);
 
     const block_q5_0 * restrict x = vx;
@@ -4555,7 +4731,7 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
 #endif
 }
 
-void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) {
     const int qk = QK8_1;
     const int nb = n / qk;
 
@@ -4563,8 +4739,8 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
     assert(qk == QK5_1);
     assert(nrc == 1);
     UNUSED(nrc);
-    UNUSED(bx);
-    UNUSED(by);
+    UNUSED(bbx);
+    UNUSED(bby);
     UNUSED(bs);
 
     const block_q5_1 * restrict x = vx;
@@ -4859,7 +5035,7 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
 #endif
 }
 
-void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) {
     const int qk = QK8_0;
     const int nb = n / qk;
 
@@ -4870,8 +5046,8 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     assert(nrc == 1);
 #endif
     UNUSED(nrc);
-    UNUSED(bx);
-    UNUSED(by);
+    UNUSED(bbx);
+    UNUSED(bby);
     UNUSED(bs);
 
     const block_q8_0 * restrict x = vx;
@@ -9107,6 +9283,178 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void
 #endif
 }
 
+#ifdef __AVX2__
+static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
+    const __m256i ax = _mm256_sign_epi8(x, x);
+    const __m256i sy = _mm256_sign_epi8(y, x);
+    return _mm256_maddubs_epi16(ax, sy);
+}
+#endif
+
+void ggml_vec_dot_iq1_s_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_iq1_s * restrict x = vx;
+    const block_q8_K  * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#if defined __ARM_NEON
+
+    const uint8x16_t m8 = vdupq_n_u8(0x08);
+    const uint8x16_t m7 = vdupq_n_u8(0x07);
+    const uint8x16_t m1 = vdupq_n_u8(0x01);
+    const int32x4_t vzero = vdupq_n_s32(0);
+
+    uint16_t gindex[8];
+    uint16x8x2_t vindex;
+    int8x16x4_t q1b;
+    int8x16x4_t q8b;
+    uint16x8x4_t scales;
+    int32x4x2_t sumi;
+    int32x4x2_t dotq;
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const int8_t  * q8 = y[i].qs;
+        const uint8_t * qs = x[i].qs;
+        const uint8_t * sc = x[i].scales;
+
+        sumi.val[0] = sumi.val[1] = vzero;
+
+        for (int i128 = 0; i128 < QK_K/128; ++i128) {
+            const uint8x16_t ql = vld1q_u8(qs); qs += 16;
+            const uint8x8_t tm1 = vld1_u8 (sc); sc +=  8;
+            const uint8x8_t tm2 = vshr_n_u8(tm1, 4);
+            const uint8x16_t qh = vcombine_u8(vzip1_u8(tm1, tm2), vzip2_u8(tm1, tm2));
+            const uint8x16_t hbit = vandq_u8(qh, m8);
+            vindex.val[0] = vorrq_u16(vmovl_u8(vget_low_u8 (ql)), vshlq_n_u16(vmovl_u8(vget_low_u8 (hbit)), 5));
+            vindex.val[1] = vorrq_u16(vmovl_u8(vget_high_u8(ql)), vshlq_n_u16(vmovl_u8(vget_high_u8(hbit)), 5));
+            const uint8x16_t scales8 = vorrq_u8(vshlq_n_u8(vandq_u8(qh, m7), 1), m1);
+            scales.val[0] = vmovl_u8(vget_low_u8 (scales8));
+            scales.val[1] = vmovl_u8(vget_high_u8 (scales8));
+
+            for (int l = 0; l < 2; ++l) {
+                vst1q_u16(gindex+0, vindex.val[l]);
+                q1b.val[0] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[0])), vld1_s8((const void *)(iq1s_grid+gindex[1])));
+                q1b.val[1] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[2])), vld1_s8((const void *)(iq1s_grid+gindex[3])));
+                q1b.val[2] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[4])), vld1_s8((const void *)(iq1s_grid+gindex[5])));
+                q1b.val[3] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[6])), vld1_s8((const void *)(iq1s_grid+gindex[7])));
+                q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
+
+                dotq.val[0] = vpaddq_s32(ggml_vdotq_s32(vzero, q1b.val[0], q8b.val[0]), ggml_vdotq_s32(vzero, q1b.val[1], q8b.val[1]));
+                dotq.val[1] = vpaddq_s32(ggml_vdotq_s32(vzero, q1b.val[2], q8b.val[2]), ggml_vdotq_s32(vzero, q1b.val[3], q8b.val[3]));
+
+                sumi.val[0] = vmlaq_s32(sumi.val[0], dotq.val[0], vreinterpretq_s32_u32(vmovl_u16(vget_low_u16 (scales.val[l]))));
+                sumi.val[1] = vmlaq_s32(sumi.val[1], dotq.val[1], vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales.val[l]))));
+            }
+        }
+
+        sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d) * vaddvq_s32(vaddq_s32(sumi.val[0], sumi.val[1]));
+    }
+
+    *s = sumf;
+
+#elif defined __AVX2__
+
+    const __m128i m8 = _mm_set1_epi8(0x08);
+    const __m128i m7 = _mm_set1_epi8(0x07);
+    const __m128i m1 = _mm_set1_epi8(0x01);
+    const __m128i shuffle_h = _mm_set_epi8(15, 7, 14, 6, 13, 5, 12, 4, 11, 3, 10, 2, 9, 1, 8, 0);
+    const __m128i shuffle_s[4] = {
+        _mm_set_epi32(0x03030303, 0x02020202, 0x01010101, 0x00000000),
+        _mm_set_epi32(0x07070707, 0x06060606, 0x05050505, 0x04040404),
+        _mm_set_epi32(0x0b0b0b0b, 0x0a0a0a0a, 0x09090909, 0x08080808),
+        _mm_set_epi32(0x0f0f0f0f, 0x0e0e0e0e, 0x0d0d0d0d, 0x0c0c0c0c)
+    };
+
+    uint64_t aux64;
+
+    __m256i v_gindex;
+    const uint16_t * gindex = (const uint16_t *)&v_gindex;
+
+    __m256 accum = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+
+        const int8_t  * q8 = y[i].qs;
+        const uint8_t * qs = x[i].qs;
+        const uint8_t * sc = x[i].scales;
+
+        __m256i sumi = _mm256_setzero_si256();
+        for (int i128 = 0; i128 < QK_K/128; ++i128) {
+            const __m128i ql = _mm_loadu_si128((const __m128i*)qs); qs += 16;
+            memcpy(&aux64, sc, 8); sc += 8;
+            const __m128i qh = _mm_shuffle_epi8(_mm_set_epi64x(aux64 >> 4, aux64), shuffle_h);
+            const __m256i hbit = _mm256_cvtepu8_epi16(_mm_and_si128(qh, m8));
+            v_gindex = _mm256_or_si256(_mm256_cvtepu8_epi16(ql), _mm256_slli_epi16(hbit, 5));
+            const __m128i scales = _mm_or_si128(_mm_slli_epi16(_mm_and_si128(qh, m7), 1), m1);
+
+            for (int i32 = 0; i32 < 4; ++i32) {
+                const __m256i q8b = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+                const __m256i q1b = _mm256_set_epi64x(iq1s_grid[gindex[4*i32+3]], iq1s_grid[gindex[4*i32+2]],
+                                                      iq1s_grid[gindex[4*i32+1]], iq1s_grid[gindex[4*i32+0]]);
+                const __m256i dot = mul_add_epi8(q1b, q8b);
+                const __m256i s16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, shuffle_s[i32]));
+                const __m256i p   = _mm256_madd_epi16(s16, dot);
+                sumi = _mm256_add_epi32(sumi, p);
+            }
+
+        }
+
+        accum = _mm256_fmadd_ps(_mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d)), _mm256_cvtepi32_ps(sumi), accum);
+
+    }
+
+    *s = hsum_float_8(accum);
+
+#else
+
+    int db[4];
+    uint16_t idx[4];
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const int8_t  * q8 = y[i].qs;
+        const uint8_t * qs = x[i].qs;
+        const uint8_t * sc = x[i].scales;
+
+        int sumi = 0;
+        for (int i32 = 0; i32 < QK_K/32; ++i32) {
+            idx[0] = qs[0] | ((sc[0] & 0x08) << 5);
+            idx[1] = qs[1] | ((sc[0] & 0x80) << 1);
+            idx[2] = qs[2] | ((sc[1] & 0x08) << 5);
+            idx[3] = qs[3] | ((sc[1] & 0x80) << 1);
+            db[0] = (2*(sc[0] & 7) + 1);
+            db[1] = (2*((sc[0] >> 4) & 7) + 1);
+            db[2] = (2*(sc[1] & 7) + 1);
+            db[3] = (2*((sc[1] >> 4) & 7) + 1);
+            for (int l = 0; l < 4; ++l) {
+                const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]);
+                int suml = 0;
+                for (int j = 0; j < 8; ++j) suml += q8[j] * grid[j];
+                sumi += db[l] * suml;
+                q8 += 8;
+            }
+            qs += 4;
+            sc += 2;
+        }
+
+        sumf += GGML_FP16_TO_FP32(x[i].d) * y[i].d * sumi;
+    }
+
+    *s = sumf;
+
+#endif
+
+}
+
 // ================================ IQ2 quantization =============================================
 
 typedef struct {
@@ -9115,14 +9463,22 @@ typedef struct {
     uint16_t * neighbours;
 } iq2_entry_t;
 
-static iq2_entry_t iq2_data[2] = {
+static iq2_entry_t iq2_data[3] = {
+    {NULL, NULL, NULL},
     {NULL, NULL, NULL},
     {NULL, NULL, NULL},
 };
 
-static inline int iq2_data_index(int grid_size) {
-    GGML_ASSERT(grid_size == 256 || grid_size == 512);
-    return grid_size == 256 ? 0 : 1;
+static inline int iq2_data_index(enum ggml_type type) {
+    GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S);
+    return type == GGML_TYPE_IQ2_XXS ? 0 :
+           type == GGML_TYPE_IQ2_XS  ? 1 : 2;
+}
+
+static inline int iq2_grid_size(enum ggml_type type) {
+    GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S);
+    return type == GGML_TYPE_IQ2_XXS ? 256 :
+           type == GGML_TYPE_IQ2_XS  ? 512 : 512;
 }
 
 static int iq2_compare_func(const void * left, const void * right) {
@@ -9131,12 +9487,13 @@ static int iq2_compare_func(const void * left, const void * right) {
     return l[0] < r[0] ? -1 : l[0] > r[0] ? 1 : l[1] < r[1] ? -1 : l[1] > r[1] ? 1 : 0;
 }
 
-void iq2xs_init_impl(int grid_size) {
-    const int gindex = iq2_data_index(grid_size);
+void iq2xs_init_impl(enum ggml_type type) {
+    const int gindex = iq2_data_index(type);
+    const int grid_size = iq2_grid_size(type);
     if (iq2_data[gindex].grid) {
         return;
     }
-    static const uint16_t kgrid_256[256] = {
+    static const uint16_t kgrid_2bit_256[256] = {
             0,     2,     5,     8,    10,    17,    20,    32,    34,    40,    42,    65,    68,    80,    88,    97,
           100,   128,   130,   138,   162,   257,   260,   272,   277,   320,   388,   408,   512,   514,   546,   642,
          1025,  1028,  1040,  1057,  1060,  1088,  1090,  1096,  1120,  1153,  1156,  1168,  1188,  1280,  1282,  1288,
@@ -9154,7 +9511,7 @@ void iq2xs_init_impl(int grid_size) {
         33888, 34048, 34118, 34196, 34313, 34368, 34400, 34818, 35076, 35345, 36868, 36880, 36900, 36928, 37025, 37142,
         37248, 37445, 37888, 37922, 37956, 38225, 39041, 39200, 40962, 41040, 41093, 41225, 41472, 42008, 43088, 43268,
     };
-    static const uint16_t kgrid_512[512] = {
+    static const uint16_t kgrid_2bit_512[512] = {
             0,     2,     5,     8,    10,    17,    20,    22,    25,    32,    34,    37,    40,    65,    68,    70,
            73,    80,    82,    85,    88,    97,   100,   128,   130,   133,   136,   145,   148,   153,   160,   257,
           260,   262,   265,   272,   274,   277,   280,   282,   289,   292,   320,   322,   325,   328,   337,   340,
@@ -9188,9 +9545,45 @@ void iq2xs_init_impl(int grid_size) {
         40962, 40968, 40970, 40992, 41002, 41120, 41297, 41305, 41382, 41472, 41474, 41480, 41514, 41600, 41632, 42048,
         42133, 42597, 42648, 43018, 43040, 43042, 43048, 43168, 43176, 43268, 43396, 43398, 43560, 43562, 43665, 43690,
     };
+    static const uint16_t kgrid_1bit_512[512] = {
+           10,    33,    41,    85,   132,   134,   160,   162,   277,   337,   340,   345,   357,   405,   516,   545,
+          553,   598,   641,   650,   681,  1042,  1044,  1097,  1169,  1176,  1320,  1345,  1365,  1378,  1434,  1444,
+         1545,  1617,  1642,  1685,  2053,  2080,  2089,  2133,  2176,  2182,  2208,  2214,  2306,  2384,  2393,  2440,
+         2453,  2581,  2664,  2690,  2721,  4117,  4161,  4182,  4184,  4261,  4357,  4369,  4372,  4377,  4390,  4422,
+         4432,  4437,  4449,  4457,  4485,  4497,  4505,  4629,  4677,  4696,  4774,  5205,  5217,  5225,  5386,  5397,
+         5409,  5445,  5457,  5460,  5461,  5462,  5465,  5472,  5477,  5525,  5545,  5650,  5668,  5717,  5729,  5769,
+         5777,  6212,  6234,  6244,  6293,  6424,  6482,  6485,  6502,  6505,  6529,  6538,  6565,  6656,  6682,  6788,
+         6806,  6820,  8218,  8224,  8226,  8232,  8277,  8326,  8354,  8469,  8521,  8530,  8549,  8596,  8737,  8794,
+         9221,  9253,  9348,  9369,  9380,  9474,  9557,  9633,  9732,  9753,  9793,  9830,  9862,  9880, 10240, 10272,
+        10282, 10321, 10406, 10517, 10530, 10566, 10585, 10645, 10896, 16466, 16468, 16473, 16485, 16646, 16660, 16665,
+        16725, 16793, 16806, 16914, 16969, 16977, 16996, 17028, 17057, 17408, 17416, 17434, 17493, 17512, 17578, 17685,
+        17696, 17733, 17745, 17748, 17749, 17750, 17753, 17765, 17794, 17813, 17946, 17984, 18005, 18072, 18453, 18529,
+        18569, 18722, 18756, 18762, 18773, 18794, 18833, 18853, 18945, 19026, 19033, 19077, 20489, 20497, 20500, 20517,
+        20565, 20586, 20610, 20633, 20757, 20769, 20776, 20805, 20817, 20820, 20821, 20822, 20825, 20837, 20864, 20872,
+        20885, 20896, 21002, 21029, 21077, 21146, 21510, 21525, 21573, 21585, 21588, 21589, 21590, 21593, 21605, 21653,
+        21665, 21765, 21777, 21780, 21781, 21782, 21785, 21797, 21825, 21828, 21829, 21830, 21833, 21840, 21841, 21842,
+        21844, 21846, 21848, 21849, 21850, 21857, 21860, 21861, 21862, 21865, 21893, 21905, 21908, 21909, 21910, 21913,
+        21925, 22024, 22037, 22085, 22097, 22100, 22101, 22102, 22105, 22117, 22165, 22545, 22566, 22568, 22594, 22608,
+        22613, 22676, 22697, 22793, 22805, 22853, 22865, 22868, 22869, 22870, 22873, 22885, 22933, 22946, 23046, 23072,
+        23125, 23209, 24597, 24640, 24665, 24673, 24725, 24833, 24840, 24869, 24917, 24934, 24965, 25001, 25108, 25110,
+        25152, 25184, 25192, 25234, 25616, 25618, 25625, 25685, 25704, 25738, 25744, 25770, 25877, 25897, 25925, 25937,
+        25940, 25941, 25942, 25945, 25957, 25986, 26005, 26186, 26197, 26276, 26632, 26634, 26725, 26757, 26770, 26885,
+        26965, 26976, 26986, 27032, 27153, 27174, 27200, 27208, 27240, 27269, 27282, 27290, 32778, 32800, 32802, 32808,
+        32810, 32853, 32904, 32922, 32930, 32932, 33105, 33110, 33112, 33125, 33157, 33280, 33288, 33301, 33312, 33320,
+        33424, 33797, 33829, 33858, 34068, 34133, 34146, 34176, 34217, 34306, 34342, 34441, 34454, 34468, 34832, 34918,
+        34965, 34984, 35094, 35137, 35161, 35208, 35232, 35332, 35338, 35368, 35429, 36932, 36934, 36953, 37009, 37125,
+        37136, 37138, 37145, 37157, 37205, 37220, 37258, 37290, 37444, 37446, 37465, 37478, 37525, 37905, 37968, 37973,
+        38040, 38054, 38145, 38154, 38165, 38180, 38186, 38213, 38225, 38228, 38229, 38230, 38233, 38245, 38293, 38485,
+        38504, 38530, 38938, 38985, 38993, 39012, 39040, 39173, 39192, 39253, 39265, 39301, 39316, 39322, 39442, 39497,
+        39504, 39590, 40970, 40984, 40992, 41002, 41045, 41120, 41128, 41237, 41289, 41297, 41317, 41364, 41366, 41514,
+        41557, 41633, 41989, 42021, 42056, 42068, 42074, 42113, 42242, 42265, 42274, 42325, 42340, 42402, 42501, 42512,
+        42533, 42624, 42632, 42666, 43040, 43093, 43106, 43168, 43176, 43264, 43286, 43345, 43429, 43590, 43618, 43680,
+    };
+
     const int kmap_size = 43692;
-    const int nwant = 2;
-    const uint16_t * kgrid = grid_size == 256 ? kgrid_256 : kgrid_512;
+    const int nwant = type == GGML_TYPE_IQ1_S ? 3 : 2;
+    const uint16_t * kgrid = type == GGML_TYPE_IQ2_XXS ? kgrid_2bit_256 :
+                             type == GGML_TYPE_IQ2_XS  ? kgrid_2bit_512 : kgrid_1bit_512;
     uint64_t * kgrid_q2xs;
     int      * kmap_q2xs;
     uint16_t * kneighbors_q2xs;
@@ -9286,9 +9679,9 @@ void iq2xs_init_impl(int grid_size) {
     free(dist2);
 }
 
-void iq2xs_free_impl(int grid_size) {
-    GGML_ASSERT(grid_size == 256 || grid_size == 512 || grid_size == 1024);
-    const int gindex = iq2_data_index(grid_size);
+void iq2xs_free_impl(enum ggml_type type) {
+    GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S);
+    const int gindex = iq2_data_index(type);
     if (iq2_data[gindex].grid) {
         free(iq2_data[gindex].grid);       iq2_data[gindex].grid = NULL;
         free(iq2_data[gindex].map);        iq2_data[gindex].map  = NULL;
@@ -9322,7 +9715,7 @@ static int iq2_find_best_neighbour(const uint16_t * restrict neighbours, const u
 
 static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
 
-    const int gindex = iq2_data_index(256);
+    const int gindex = iq2_data_index(GGML_TYPE_IQ2_XXS);
 
     const uint64_t * kgrid_q2xs      = iq2_data[gindex].grid;
     const int      * kmap_q2xs       = iq2_data[gindex].map;
@@ -9495,7 +9888,7 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
 
 static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
 
-    const int gindex = iq2_data_index(512);
+    const int gindex = iq2_data_index(GGML_TYPE_IQ2_XS);
 
     const uint64_t * kgrid_q2xs      = iq2_data[gindex].grid;
     const int      * kmap_q2xs       = iq2_data[gindex].map;
@@ -10132,3 +10525,207 @@ void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * re
     assert(k % QK_K == 0);
     quantize_row_iq3_xxs_impl(x, y, k, NULL);
 }
+
+// =================================== 1.5 bpw ===================================================
+
+static int iq1_find_best_neighbour(const uint16_t * restrict neighbours, const uint64_t * restrict grid,
+        const float * restrict xval, const float * restrict weight, float * scale, int8_t * restrict L, int ngrid) {
+    int num_neighbors = neighbours[0];
+    GGML_ASSERT(num_neighbors > 0);
+    float best_score = 0;
+    int grid_index = -1;
+    for (int j = 1; j <= num_neighbors; ++j) {
+        const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
+        float sumqx = 0, sumq2 = 0;
+        for (int i = 0; i < 8; ++i) {
+            float q = (pg[i] - 3)/2;
+            float w = weight[i];
+            sumqx += w*q*xval[i];
+            sumq2 += w*q*q;
+        }
+        if (sumqx > 0 && sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
+            *scale = sumqx/sumq2; best_score = *scale * sumqx;
+            grid_index = neighbours[j];
+        }
+    }
+    if (grid_index < 0) {
+        for (int i = 0; i < ngrid; ++i) {
+            const int8_t * grid_i = (const int8_t *)(grid + i);
+            float sumqx = 0, sumq2 = 0;
+            for (int j = 0; j < 8; ++j) {
+                float w = weight[j];
+                float q = (grid_i[j] - 3)/2;
+                sumqx += w*q*xval[j];
+                sumq2 += w*q*q;
+            }
+            if (sumqx > 0 && sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
+                *scale = sumqx/sumq2; best_score = *scale*sumqx;
+                grid_index = i;
+            }
+        }
+    }
+    if (grid_index < 0) {
+        printf("Oops, did not find grid point\n");
+        printf("Have %d neighbours\n", num_neighbors);
+        for (int j = 1; j <= num_neighbors; ++j) {
+            const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
+            float sumqx = 0, sumq2 = 0;
+            for (int i = 0; i < 8; ++i) {
+                float q = (pg[i] - 3)/2;
+                float w = weight[i];
+                sumqx += w*q*xval[i];
+                sumq2 += w*q*q;
+            }
+            printf("    neighbour %d: sumqx = %g sumq2 = %g\n", j, (double)sumqx, (double)sumq2);
+        }
+    }
+    GGML_ASSERT(grid_index >= 0);
+    //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+    *scale *= 1.05f;  // This is a fudge factor. Don't ask me why it improves the result.
+    //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+    const int8_t * pg = (const int8_t *)(grid + grid_index);
+    for (int i = 0; i < 8; ++i) L[i] = (pg[i] - 1)/2;
+    return grid_index;
+}
+
+static int iq1_sort_helper(const void * left, const void * right) {
+    const float * l = left;
+    const float * r = right;
+    return *l < *r ? -1 : *l > *r ? 1 : 0;
+}
+
+static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
+
+    const int gindex = iq2_data_index(GGML_TYPE_IQ1_S);
+
+    const uint64_t * kgrid_q2xs      = iq2_data[gindex].grid;
+    const int      * kmap_q2xs       = iq2_data[gindex].map;
+    const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
+
+    GGML_ASSERT(quant_weights   && "missing quantization weights");
+    GGML_ASSERT(kgrid_q2xs      && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(kmap_q2xs       && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(n%QK_K == 0);
+
+    const int nbl = n/256;
+
+    block_iq1_s * y = vy;
+
+    float  scales[QK_K/8];
+    float  weight[8];
+    int8_t L[8];
+    float  sumx[9];
+    float  sumw[9];
+    float  pairs[16];
+    int * idx = (int *)(pairs + 1);
+    uint8_t hbit[QK_K/8];
+
+    for (int ibl = 0; ibl < nbl; ++ibl) {
+
+        y[ibl].d = GGML_FP32_TO_FP16(0.f);
+        memset(y[ibl].qs, 0, QK_K/8);
+        memset(y[ibl].scales, 0, QK_K/16);
+
+        float max_scale = 0;
+
+        const float * xbl = x + QK_K*ibl;
+        float sumx2 = 0;
+        for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
+        float sigma2 = sumx2/QK_K;
+
+        for (int ib = 0; ib < QK_K/8; ++ib) {
+            const float * xb = xbl + 8*ib;
+            const float * qw = quant_weights + QK_K*ibl + 8*ib;
+            for (int i = 0; i < 8; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+            float max = fabsf(xb[0]);
+            for (int i = 1; i < 8; ++i) max = MAX(max, fabsf(xb[i]));
+            if (!max) {
+                scales[ib] = 0;
+                memset(L, 1, 8);
+                continue;
+            }
+            // Here we solve exactly the sum of squared difference (SSD) weighted minimization problem.
+            // With just 3 allowed quant values (-1, 0, 1), we can search exhaustively for the two
+            // boundaries that split the weights xb[i] into 3 groups. To do so, we sort the weights
+            // in ascending order, compute Si = sum[weight[j] xb[j], j = 0...i] and
+            // Wi = sum[weight[j], j = 0...i], and use these to quckly get get the optimum scale
+            // for each possible and score for each split.
+            for (int j = 0; j < 8; ++j) {
+                pairs[2*j] = xb[j];
+                idx[2*j] = j;
+            }
+            qsort(pairs, 8, 2*sizeof(float), iq1_sort_helper);
+            {
+                sumx[0] = sumw[0] = 0;
+                for (int j = 0; j < 8; ++j) {
+                    int i = idx[2*j];
+                    sumx[j+1] = sumx[j] + weight[i]*xb[i];
+                    sumw[j+1] = sumw[j] + weight[i];
+                }
+            }
+            float best_score = 0, scale = max;
+            int besti1 = 0, besti2 = 0;
+            for (int i1 = 0; i1 <= 8; ++i1) {
+                for (int i2 = i1; i2 <= 8; ++i2) {
+                    float sumqx = -(sumx[i1] - sumx[0]) + (sumx[8] - sumx[i2]);
+                    float sumq2 =  (sumw[i1] - sumw[0]) + (sumw[8] - sumw[i2]);
+                    if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
+                        scale = sumqx/sumq2; best_score = scale*sumqx;
+                        besti1 = i1; besti2 = i2;
+                    }
+                }
+            }
+            for (int j =      0; j < besti1; ++j) L[idx[2*j]] = 0;
+            for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1;
+            for (int j = besti2; j <      8; ++j) L[idx[2*j]] = 2;
+            if (scale < 0) {
+                for (int j = 0; j < 8; ++j) L[j] = 2 - L[j];
+                scale = -scale;
+            }
+            // Now we check if the solution found above corresponds to a grid point and, if not, use a neighbouring
+            // grid point that minimizes SSD.
+            uint16_t u = 0;
+            for (int j = 0; j < 8; ++j) u |= (L[j] << 2*j);
+            int grid_index = kmap_q2xs[u];
+            if (grid_index < 0) {
+                const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
+                grid_index = iq1_find_best_neighbour(neighbours, kgrid_q2xs, xb, weight, &scale, L, NGRID_IQ2XXS);
+                GGML_ASSERT(grid_index >= 0);
+            }
+            y[ibl].qs[ib] = grid_index & 255;
+            hbit[ib] = grid_index >> 8;
+            GGML_ASSERT(scale >= 0);
+            scales[ib] = scale;
+            max_scale = MAX(max_scale, scale);
+        }
+
+        if (!max_scale) {
+            memset(y[ibl].qs, 0, QK_K/8);
+            continue;
+        }
+
+        float d = max_scale/15;
+        y[ibl].d = GGML_FP32_TO_FP16(d*1.085f); // 1.085f is another fudge factor. Don't ask me why it is needed.
+        float id = 1/d;
+        for (int ib = 0; ib < QK_K/8; ++ib) {
+            int l = nearest_int(0.5f*(id*scales[ib]-1));
+            l = MAX(0, MIN(7, l));
+            if (hbit[ib]) l |= 8;
+            y[ibl].scales[ib/2] |= (l << 4*(ib%2));
+        }
+    }
+}
+
+size_t quantize_iq1_s(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
+    (void)hist;
+    GGML_ASSERT(n_per_row%QK_K == 0);
+    int nblock = n_per_row/QK_K;
+    char * qrow = (char *)dst;
+    for (int row = 0; row < nrow; ++row) {
+        quantize_row_iq1_s_impl(src, qrow, n_per_row, quant_weights);
+        src += n_per_row;
+        qrow += nblock*sizeof(block_iq1_s);
+    }
+    return nrow * nblock * sizeof(block_iq1_s);
+}
index 68f09b1e12f256a07a3aa47c8732b5dd0b7c0171..ad381cfab022d21c48bc8f524cb78d6d565d8a57 100644 (file)
@@ -191,6 +191,13 @@ typedef struct {
 } block_iq3_xxs;
 static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
 
+typedef struct {
+    ggml_fp16_t d;
+    uint8_t qs[QK_K/8];
+    uint8_t scales[QK_K/16];
+} block_iq1_s;
+static_assert(sizeof(block_iq1_s) == sizeof(ggml_fp16_t) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding");
+
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -243,6 +250,7 @@ void dequantize_row_q8_K(const block_q8_K * GGML_RESTRICT x, float * GGML_RESTRI
 void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
 void dequantize_row_iq2_xs (const block_iq2_xs  * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
 void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
+void dequantize_row_iq1_s  (const block_iq1_s   * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
 
 // Dot product
 void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
@@ -259,6 +267,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq1_s_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 
 //
 // Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
@@ -266,6 +275,7 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const
 size_t quantize_iq2_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_iq2_xs (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_iq3_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
+size_t quantize_iq1_s  (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_q2_K   (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_q3_K   (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_q4_K   (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
@@ -276,8 +286,8 @@ size_t quantize_q4_1   (const float * src, void * dst, int nrows, int n_per_row,
 size_t quantize_q5_0   (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_q5_1   (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 
-void iq2xs_init_impl(int grid_size);
-void iq2xs_free_impl(int grid_size);
+void iq2xs_init_impl(enum ggml_type type);
+void iq2xs_free_impl(enum ggml_type type);
 void iq3xs_init_impl(int grid_size);
 void iq3xs_free_impl(int grid_size);
 
diff --git a/ggml.c b/ggml.c
index e94024c62a1238883e5f367d92dea8928a18e416..aefcda6d42a0dc36d9c24c2d77445b5b10e8d8b7 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -673,6 +673,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .vec_dot_type             = GGML_TYPE_Q8_K,
         .nrows                    = 1,
     },
+    [GGML_TYPE_IQ1_S] = {
+        .type_name                = "iq1_s",
+        .blck_size                = QK_K,
+        .type_size                = sizeof(block_iq1_s),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_iq1_s,
+        .from_float               = NULL,
+        .from_float_reference     = NULL,
+        .vec_dot                  = ggml_vec_dot_iq1_s_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
+    },
     [GGML_TYPE_Q8_K] = {
         .type_name                = "q8_K",
         .blck_size                = QK_K,
@@ -2267,6 +2279,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
         case GGML_FTYPE_MOSTLY_IQ2_XXS:       wtype = GGML_TYPE_IQ2_XXS;  break;
         case GGML_FTYPE_MOSTLY_IQ2_XS:        wtype = GGML_TYPE_IQ2_XS;   break;
         case GGML_FTYPE_MOSTLY_IQ3_XXS:       wtype = GGML_TYPE_IQ3_XXS;  break;
+        case GGML_FTYPE_MOSTLY_IQ1_S:         wtype = GGML_TYPE_IQ1_S;    break;
         case GGML_FTYPE_UNKNOWN:              wtype = GGML_TYPE_COUNT; break;
         case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
     }
@@ -7677,6 +7690,7 @@ static void ggml_compute_forward_add(
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ1_S:
             {
                 ggml_compute_forward_add_q_f32(params, src0, src1, dst);
             } break;
@@ -7944,6 +7958,7 @@ static void ggml_compute_forward_add1(
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ1_S:
             {
                 ggml_compute_forward_add1_q_f32(params, src0, src1, dst);
             } break;
@@ -8064,6 +8079,7 @@ static void ggml_compute_forward_acc(
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ1_S:
         default:
             {
                 GGML_ASSERT(false);
@@ -10830,6 +10846,7 @@ static void ggml_compute_forward_out_prod(
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ1_S:
             {
                 ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
             } break;
@@ -11010,6 +11027,7 @@ static void ggml_compute_forward_set(
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ1_S:
         default:
             {
                 GGML_ASSERT(false);
@@ -11207,6 +11225,7 @@ static void ggml_compute_forward_get_rows(
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ1_S:
             {
                 ggml_compute_forward_get_rows_q(params, src0, src1, dst);
             } break;
@@ -11880,6 +11899,7 @@ static void ggml_compute_forward_alibi(
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ1_S:
         case GGML_TYPE_Q8_K:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
@@ -11957,6 +11977,7 @@ static void ggml_compute_forward_clamp(
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ1_S:
         case GGML_TYPE_Q8_K:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
@@ -19136,8 +19157,9 @@ void ggml_quantize_init(enum ggml_type type) {
     ggml_critical_section_start();
 
     switch (type) {
-        case GGML_TYPE_IQ2_XXS: iq2xs_init_impl(256); break;
-        case GGML_TYPE_IQ2_XS:  iq2xs_init_impl(512); break;
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ1_S:   iq2xs_init_impl(type); break;
         case GGML_TYPE_IQ3_XXS: iq3xs_init_impl(256); break;
         default: // nothing
             break;
@@ -19149,8 +19171,10 @@ void ggml_quantize_init(enum ggml_type type) {
 void ggml_quantize_free(void) {
     ggml_critical_section_start();
 
-    iq2xs_free_impl(256);
-    iq2xs_free_impl(512);
+    iq2xs_free_impl(GGML_TYPE_IQ2_XXS);
+    iq2xs_free_impl(GGML_TYPE_IQ2_XS);
+    iq2xs_free_impl(GGML_TYPE_IQ1_S);
+    iq3xs_free_impl(256);
 
     ggml_critical_section_end();
 }
@@ -19285,7 +19309,8 @@ size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t *
 bool ggml_quantize_requires_imatrix(enum ggml_type type) {
     return
         type == GGML_TYPE_IQ2_XXS ||
-        type == GGML_TYPE_IQ2_XS;
+        type == GGML_TYPE_IQ2_XS  ||
+        type == GGML_TYPE_IQ1_S;
 }
 
 size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start,
@@ -19410,6 +19435,15 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
                 result = quantize_iq3_xxs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
                 GGML_ASSERT(result == row_size * nrows);
             } break;
+        case GGML_TYPE_IQ1_S:
+            {
+                GGML_ASSERT(start % QK_K == 0);
+                GGML_ASSERT(start % n_per_row == 0);
+                size_t start_row = start / n_per_row;
+                size_t row_size = ggml_row_size(type, n_per_row);
+                result = quantize_iq1_s(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
+                GGML_ASSERT(result == row_size * nrows);
+            } break;
         case GGML_TYPE_F16:
             {
                 size_t elemsize = sizeof(ggml_fp16_t);
diff --git a/ggml.h b/ggml.h
index 6c1956772324c2daff38fcd24bba429a225256f3..004d09c700f3264d78bcbd9f2613398a21d16ba7 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -354,6 +354,7 @@ extern "C" {
         GGML_TYPE_IQ2_XXS = 16,
         GGML_TYPE_IQ2_XS  = 17,
         GGML_TYPE_IQ3_XXS = 18,
+        GGML_TYPE_IQ1_S   = 19,
         GGML_TYPE_I8,
         GGML_TYPE_I16,
         GGML_TYPE_I32,
@@ -391,6 +392,7 @@ extern "C" {
         GGML_FTYPE_MOSTLY_IQ2_XXS = 15, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ2_XS  = 16, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ3_XXS = 17, // except 1d tensors
+        GGML_FTYPE_MOSTLY_IQ1_S   = 18, // except 1d tensors
     };
 
     // available tensor operations: