]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
k-quants : support for super-block size of 64 (#2001)
authorKawrakow <redacted>
Mon, 26 Jun 2023 16:43:07 +0000 (19:43 +0300)
committerGitHub <redacted>
Mon, 26 Jun 2023 16:43:07 +0000 (19:43 +0300)
* k_quants: WIP super-blocks with 64 weights

* k_quants: WIP super-blocks with 64 weights

Q6_K scalar and AVX2 works

* k_quants: WIP super-blocks with 64 weights

Q4_K scalar and AVX2 works

* k_quants: WIP super-blocks with 64 weights

Q2_K scalar and AVX2 works. Q2_K is way too slow (it is actually slower
than the scalar implementation)

* k_quants: WIP super-blocks with 64 weights

Q3_K scalar and AVX2 works.

* k_quants: WIP super-blocks with 64 weights

Q5_K scalar and AVX2 works, and with that all
k_quants are done on AVX2 and scalar

* k_quants: WIP super-blocks with 64 weights

Q6_K working on CUDA. Cannot make it run quite as gast as
with super-blocks with 256 weigths: 8% slower on 4080,
20% slower on the 1660 (but there we fit 1 less layer on the
GPU because pf the larger model size), so some fraction of
these 20% is due to that,

* k_quants: WIP super-blocks with 64 weights

Q4_K working on CUDA. ~10% slower on GTX-1660,
16% slower on 4080.

* k_quants: WIP super-blocks with 64 weights

Q2_K working on CUDA. ~3% slower on GTX-1660,
10% slower on 4080.

* k_quants: WIP super-blocks with 64 weights

Q3_K working on CUDA.

* k_quants: WIP super-blocks with 64 weights

Q5_K working on CUDA, and with this CUDA is done.

* k_quants: WIP super-blocks with 64 weights

Q6_K working on ARM_NEON

* k_quants: WIP super-blocks with 64 weights

Q4_K working on ARM_NEON, but quite a bit slower than 256 weights

* k_quants: WIP super-blocks with 64 weights

Q2_K working on ARM_NEON, but quite a bit slower than 256 weights

* k_quants: WIP super-blocks with 64 weights

Q3_K working on ARM_NEON, but quite a bit slower than 256 weights.

* k_quants: WIP super-blocks with 64 weights

Q5_K working on ARM_NEON, but quite a bit slower than 256 weights.

With that, we have full support for ARM_NEON, although
performance is not quite there.

* k_quants: WIP super-blocks with 64 weights

Slightly more efficient Q3_K and Q5_K

* k_quants: WIP super-blocks with 64 weights

Another small improvement for Q3_K and Q5_K on ARM_NEON

* k_quants: WIP super-blocks with 64 weights

Yet another speedup for Q5_K on ARM_NEON.
We are now within 10% of the QK_K = 256 version.

* k_quants: WIP super-blocks with 64 weights

* We are able to pass preprocessor macros to the Metal
  compiler
* Q6_K works and is actually slightly more efficient than
  the QK_K = 256 version (25.2 ms vs 25.8 ms)

* k_quants: WIP super-blocks with 64 weights

Q4_K works on Metal and is actually slightly faster
than QK_K = 256 (21.95 ms vs 24.0 ms).

* k_quants: WIP super-blocks with 64 weights

Q2_K works on Metal and is very slightly faster
than QK_K = 256 (23.8 ms vs 24.2 ms).

* k_quants: WIP super-blocks with 64 weights

Q3_K works on Metal and is slightly faster
than QK_K = 256 (26.6 ms vs 28.3 ms).

* k_quants: WIP super-blocks with 64 weights

Q5_K works on Metal and is slightly faster
than QK_K = 256 (23.7 ms vs 26.3 ms).

* k_quants: call them _K, not _k, also on Metal

* k_quants: correctly define QK_K in llama.cpp

* Fixed bug in q4_K quantization added with the 64-block addition

* Simplify via lambda

* k_quants: swicth Q3_K to 4-bit scales when QK_K = 64

Otherwise there isn't much benefit from this
quantization type. There is some very slight loss
in accuracy, but we reduce size by ~7%.
E.g., for OpenLLaMA-3B, Q3_K_S perplexity is
8.6131 with 8-bit scales and 8.6352 with 4-bit,
while file size decreases from 1.53G to 1.44G.

* k_quants: switch Q4_K to 4-bit scales when QK_K = 64

 Here the loss in accuracy is greater than for Q3_K,
 but the Q4_K points still move further to the left on
 the perplexity vs size curve.

* k_quants: forgot to add the Metal changes in last commit

* k_quants: change Q5_K to be type 0 when QK_K = 64

Still needs AVX2 implementation

* k_quants: AVX2 implementation for new 64-weight Q5_K

* k_quants: 10% faster ARM_NEON Q5_K dot product

* k_quants: fixed issue caused by merging with master

---------

Co-authored-by: Iwan Kawrakow <redacted>
CMakeLists.txt
Makefile
ggml-cuda.cu
ggml-metal.m
ggml-metal.metal
k_quants.c
k_quants.h
llama.cpp

index cc7560a7ae54edf6f5e9cce6f1d1b8f621ec0b56..ffda74a700beff20e26a279416499025ab4c2372 100644 (file)
@@ -75,6 +75,7 @@ set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for
 option(LLAMA_CLBLAST                         "llama: use CLBlast"                               OFF)
 option(LLAMA_METAL                           "llama: use Metal"                                 OFF)
 option(LLAMA_K_QUANTS                        "llama: use k-quants"                              ON)
+option(LLAMA_QKK_64                          "llama: use super-block size of 64 for k-quants"   OFF)
 
 option(LLAMA_BUILD_TESTS                "llama: build tests"    ${LLAMA_STANDALONE})
 option(LLAMA_BUILD_EXAMPLES             "llama: build examples" ${LLAMA_STANDALONE})
@@ -225,6 +226,14 @@ if (LLAMA_BLAS)
     endif()
 endif()
 
+if (LLAMA_K_QUANTS)
+    set(GGML_SOURCES_EXTRA ${GGML_SOURCES_EXTRA} k_quants.c k_quants.h)
+    add_compile_definitions(GGML_USE_K_QUANTS)
+    if (LLAMA_QKK_64)
+        add_compile_definitions(GGML_QKK_64)
+    endif()
+endif()
+
 if (LLAMA_CUBLAS)
     cmake_minimum_required(VERSION 3.17)
 
@@ -289,11 +298,6 @@ if (LLAMA_METAL)
         )
 endif()
 
-if (LLAMA_K_QUANTS)
-    set(GGML_SOURCES_EXTRA ${GGML_SOURCES_EXTRA} k_quants.c k_quants.h)
-    add_compile_definitions(GGML_USE_K_QUANTS)
-endif()
-
 if (LLAMA_CLBLAST)
     find_package(CLBlast)
     if (CLBlast_FOUND)
index 5dd676fada41729424c0b89de53a1024ab7f1667..bda11791d5e23a2979a03afe0a8d5e5f8c354067 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -43,8 +43,11 @@ endif
 
 # keep standard at C11 and C++11
 # -Ofast tends to produce faster code, but may not be available for some compilers.
-#OPT = -Ofast
+ifdef LLAMA_FAST
+OPT = -Ofast
+else
 OPT = -O3
+endif
 CFLAGS   = -I.              $(OPT) -std=c11   -fPIC
 CXXFLAGS = -I. -I./examples $(OPT) -std=c++11 -fPIC
 LDFLAGS  =
@@ -131,6 +134,10 @@ ifndef LLAMA_NO_K_QUANTS
        CFLAGS   += -DGGML_USE_K_QUANTS
        CXXFLAGS += -DGGML_USE_K_QUANTS
        OBJS     += k_quants.o
+ifdef LLAMA_QKK_64
+       CFLAGS   += -DGGML_QKK_64
+       CXXFLAGS += -DGGML_QKK_64
+endif
 endif
 
 ifndef LLAMA_NO_ACCELERATE
index 5e2fbc72442d5487c88d9b2df81217fb47e96c10..c34e96abfd08c105407164f244d11d6622af4420 100644 (file)
@@ -117,7 +117,13 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 blo
 
 //================================= k-quants
 
+#ifdef GGML_QKK_64
+#define QK_K 64
+#define K_SCALE_SIZE 4
+#else
 #define QK_K 256
+#define K_SCALE_SIZE 12
+#endif
 
 typedef struct {
     uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
@@ -128,13 +134,25 @@ typedef struct {
 static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
 
 typedef struct {
-    uint8_t hmask[QK_K/8];
-    uint8_t qs[QK_K/4]; // nibbles / quants
-    uint8_t scales[3*QK_K/64];
-    half d;
+    uint8_t hmask[QK_K/8];     // quants - high bit
+    uint8_t qs[QK_K/4];        // quants - low 2 bits
+#ifdef GGML_QKK_64
+    uint8_t scales[2]; // scales, quantized with 8 bits
+#else
+    uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
+#endif
+    half d;             // super-block scale
 } block_q3_K;
-static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_K block size/padding");
+//static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding");
 
+#ifdef GGML_QKK_64
+typedef struct {
+    half    d[2];              // super-block scales/mins
+    uint8_t scales[2];         // 4-bit block scales/mins
+    uint8_t qs[QK_K/2];        // 4--bit quants
+} block_q4_K;
+static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding");
+#else
 typedef struct {
     half d;                    // super-block scale for quantized scales
     half dmin;                 // super-block scale for quantized mins
@@ -142,15 +160,26 @@ typedef struct {
     uint8_t qs[QK_K/2];        // 4--bit quants
 } block_q4_K;
 static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
+#endif
 
+#ifdef GGML_QKK_64
 typedef struct {
-    half    d;                   // super-block scale for quantized scales
-    half    dmin;                // super-block scale for quantized mins
-    uint8_t scales[3*QK_K/64];   // scales, quantized with 6 bits
+    half d;                  // super-block scale
+    int8_t scales[QK_K/16];  // block scales
+    uint8_t qh[QK_K/8];      // quants, high bit
+    uint8_t qs[QK_K/2];      // quants, low 4 bits
+} block_q5_K;
+static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding");
+#else
+typedef struct {
+    half d;               // super-block scale for quantized scales
+    half dmin;            // super-block scale for quantized mins
+    uint8_t scales[K_SCALE_SIZE];   // scales and mins, quantized with 6 bits
     uint8_t qh[QK_K/8];          // quants, high bit
     uint8_t qs[QK_K/2];          // quants, low 4 bits
 } block_q5_K;
-static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
+static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
+#endif
 
 typedef struct {
     uint8_t ql[QK_K/2];   // quants, lower 4 bits
@@ -349,13 +378,14 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
 static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
 
     const int i   = blockIdx.x;
+    const block_q2_K * x = (const block_q2_K *) vx;
+
     const int tid = threadIdx.x;
+#if QK_K == 256
     const int n   = tid/32;
     const int l   = tid - 32*n;
     const int is  = 8*n + l/16;
 
-    const block_q2_K * x = (const block_q2_K *) vx;
-
     const uint8_t q = x[i].qs[32*n + l];
     float * y = yy + i*QK_K + 128*n;
 
@@ -365,21 +395,32 @@ static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
     y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
     y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
     y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
+#else
+    const int is = tid/16;  // 0 or 1
+    const int il = tid%16;  // 0...15
+    const uint8_t q = x[i].qs[il] >> (2*is);
+    float * y = yy + i*QK_K + 16*is + il;
+    float dall = x[i].d;
+    float dmin = x[i].dmin;
+    y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
+    y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4);
+#endif
 
 }
 
 static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
 
-    int r = threadIdx.x/4;
-    int i = blockIdx.x;
-    int tid = r/2;
-    int is0 = r%2;
-    int l0 = 16*is0 + 4*(threadIdx.x%4);
-    int n = tid / 4;
-    int j = tid - 4*n;
-
+    const int i = blockIdx.x;
     const block_q3_K * x = (const block_q3_K *) vx;
 
+#if QK_K == 256
+    const int r = threadIdx.x/4;
+    const int tid = r/2;
+    const int is0 = r%2;
+    const int l0 = 16*is0 + 4*(threadIdx.x%4);
+    const int n = tid / 4;
+    const int j = tid - 4*n;
+
     uint8_t m = 1 << (4*n + j);
     int is = 8*n + 2*j + is0;
     int shift = 2*j;
@@ -396,9 +437,31 @@ static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
     const uint8_t * hm = x[i].hmask;
 
     for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
+#else
+    const int tid = threadIdx.x;
+    const int is  = tid/16;  // 0 or 1
+    const int il  = tid%16;  // 0...15
+    const int im  = il/8;    // 0...1
+    const int in  = il%8;    // 0...7
+
+    float * y = yy + i*QK_K + 16*is + il;
+
+    const uint8_t q = x[i].qs[il] >> (2*is);
+    const uint8_t h = x[i].hmask[in] >> (2*is + im);
+    const float   d = (float)x[i].d;
+
+    if (is == 0) {
+        y[ 0] = d * ((x[i].scales[0] & 0xF) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
+        y[32] = d * ((x[i].scales[1] & 0xF) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
+    } else {
+        y[ 0] = d * ((x[i].scales[0] >>  4) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
+        y[32] = d * ((x[i].scales[1] >>  4) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
+    }
+#endif
 
 }
 
+#if QK_K == 256
 static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
     if (j < 4) {
         d = q[j] & 63; m = q[j + 4] & 63;
@@ -407,19 +470,14 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t
         m = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
     }
 }
+#endif
 
 static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
     const block_q4_K * x = (const block_q4_K *) vx;
 
     const int i = blockIdx.x;
 
-    //// assume 64 threads - this is very slightly better than the one below
-    //const int tid = threadIdx.x;
-    //const int il  = tid/16;
-    //const int ir  = tid%16;
-    //const int is  = 2*il;
-    //const int n   = 2;
-
+#if QK_K == 256
     // assume 32 threads
     const int tid = threadIdx.x;
     const int il  = tid/8;
@@ -443,6 +501,15 @@ static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
         y[l + 0] = d1 * (q[l] & 0xF) - m1;
         y[l +32] = d2 * (q[l] >>  4) - m2;
     }
+#else
+    const int tid = threadIdx.x;
+    const uint8_t * q = x[i].qs;
+    float * y = yy + i*QK_K;
+    const float d = (float)x[i].d[0];
+    const float m = (float)x[i].d[1];
+    y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);
+    y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >>  4) - m * (x[i].scales[1] >> 4);
+#endif
 }
 
 static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
@@ -450,6 +517,7 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
 
     const int i = blockIdx.x;
 
+#if QK_K == 256
     // assume 64 threads - this is very slightly better than the one below
     const int tid = threadIdx.x;
     const int il  = tid/16;   // il is in 0...3
@@ -476,12 +544,25 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
     hm <<= 1;
     y[32] = d2 * ((ql[ 0] >>  4) + (qh[ 0] & hm ? 16 : 0)) - m2;
     y[33] = d2 * ((ql[ 1] >>  4) + (qh[ 1] & hm ? 16 : 0)) - m2;
+#else
+    const int tid = threadIdx.x;
+    const uint8_t q = x[i].qs[tid];
+    const int im = tid/8;  // 0...3
+    const int in = tid%8;  // 0...7
+    const int is = tid/16; // 0 or 1
+    const uint8_t h = x[i].qh[in] >> im;
+    const float d = x[i].d;
+    float * y = yy + i*QK_K + tid;
+    y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16));
+    y[32] = d * x[i].scales[is+2] * ((q >>  4) - ((h >> 4) & 1 ? 0 : 16));
+#endif
 }
 
 static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
     const block_q6_K * x = (const block_q6_K *) vx;
 
     const int i = blockIdx.x;
+#if QK_K == 256
 
     // assume 64 threads - this is very slightly better than the one below
     const int tid = threadIdx.x;
@@ -501,6 +582,24 @@ static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
     y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
     y[64] = d * sc[4] * ((int8_t)((ql[ 0]  >> 4) | (((qh >> 4) & 3) << 4)) - 32);
     y[96] = d * sc[6] * ((int8_t)((ql[32]  >> 4) | (((qh >> 6) & 3) << 4)) - 32);
+#else
+
+    // assume 32 threads
+    const int tid = threadIdx.x;
+    const int ip  = tid/16;         // 0 or 1
+    const int il  = tid - 16*ip;    // 0...15
+
+    float * y = yy + i*QK_K + 16*ip + il;
+
+    const float d = x[i].d;
+
+    const uint8_t   ql = x[i].ql[16*ip + il];
+    const uint8_t   qh = x[i].qh[il] >> (2*ip);
+    const int8_t  * sc = x[i].scales;
+
+    y[ 0] = d * sc[ip+0] * ((int8_t)((ql & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
+    y[32] = d * sc[ip+2] * ((int8_t)((ql  >> 4) | (((qh >> 4) & 3) << 4)) - 32);
+#endif
 }
 
 static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
@@ -515,6 +614,9 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
 
     const block_q2_K * x = (const block_q2_K *)vx + ib0;
 
+    float tmp = 0; // partial sum for thread in warp
+
+#if QK_K == 256
     const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
     const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1
 
@@ -528,8 +630,6 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
     const int s_offset = 8*im;
     const int y_offset = 128*im + l0;
 
-    float tmp = 0; // partial sum for thread in warp
-
     uint32_t aux[4];
     const uint8_t * d = (const uint8_t *)aux;
     const uint8_t * m = (const uint8_t *)(aux + 2);
@@ -565,6 +665,39 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
         tmp += dall * sum1 - dmin * sum2;
 
     }
+#else
+    const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION);  // 0...15 or 0...7
+    const int ix  = threadIdx.x%(2*K_QUANTS_PER_ITERATION);  // 0....1 or 0...3
+    const int offset = tid * K_QUANTS_PER_ITERATION;
+
+    uint32_t uaux[2];
+    const uint8_t * d = (const uint8_t *)uaux;
+
+    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+
+        const float   * y = yy + i * QK_K + offset;
+        const uint8_t * q = x[i].qs + offset;
+        const uint32_t * s = (const uint32_t *)x[i].scales;
+
+        uaux[0] = s[0] & 0x0f0f0f0f;
+        uaux[1] = (s[0] >> 4) & 0x0f0f0f0f;
+
+        const half2 * dh = (const half2 *)&x[i].d;
+
+        const float2 dall = __half22float2(dh[0]);
+
+        float sum1 = 0, sum2 = 0;
+        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+            const uint8_t ql = q[l];
+            sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3)
+                  + y[l+16] * d[1] * ((ql >> 2) & 3)
+                  + y[l+32] * d[2] * ((ql >> 4) & 3)
+                  + y[l+48] * d[3] * ((ql >> 6) & 3);
+            sum2 += y[l+0] * d[4] + y[l+16] * d[5] + y[l+32] * d[6] + y[l+48] * d[7];
+        }
+        tmp += dall.x * sum1 - dall.y * sum2;
+    }
+#endif
 
     // sum up partial sums and write back result
     __syncthreads();
@@ -573,16 +706,13 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
         tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
     }
 
-    if (tid == 0) {
+    if (threadIdx.x == 0) {
         dst[row] = tmp;
     }
 }
 
 static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
 
-    const uint16_t kmask1 = 0x0303;
-    const uint16_t kmask2 = 0x0f0f;
-
     const int row = blockIdx.y*blockDim.y + threadIdx.y;
     if (row > nrows) return;
 
@@ -591,6 +721,13 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
 
     const block_q3_K * x = (const block_q3_K *)vx + ib0;
 
+    float tmp = 0; // partial sum for thread in warp
+
+#if QK_K == 256
+
+    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask2 = 0x0f0f;
+
     const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
     const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1
 
@@ -610,8 +747,6 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
 
     const uint16_t s_shift = 4*im;
 
-    float tmp = 0; // partial sum for thread in warp
-
     for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
 
         const float   * y  = yy + i * QK_K + y_offset;
@@ -640,6 +775,34 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
         tmp += d * sum;
 
     }
+#else
+
+    const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION);  // 0...15 or 0...7
+    const int ix  = threadIdx.x%(2*K_QUANTS_PER_ITERATION);  // 0....1 or 0...3
+    const int offset = tid * K_QUANTS_PER_ITERATION;         // 0...15 or 0...14
+    const int in = offset/8;                                 // 0 or 1
+    const int im = offset%8;                                 // 0...7
+
+    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+
+        const float   * y = yy + i * QK_K + offset;
+        const uint8_t * q = x[i].qs + offset;
+        const uint8_t * s = x[i].scales;
+
+        const float dall = (float)x[i].d;
+
+        float sum = 0;
+        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+            const uint8_t hl = x[i].hmask[im+l] >> in;
+            const uint8_t ql = q[l];
+            sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4))
+                 + y[l+16] * dall * ((s[0] >>  4) - 8) * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4))
+                 + y[l+32] * dall * ((s[1] & 0xF) - 8) * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4))
+                 + y[l+48] * dall * ((s[1] >>  4) - 8) * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 0 : 4));
+        }
+        tmp += sum;
+    }
+#endif
 
     // sum up partial sums and write back result
     __syncthreads();
@@ -648,22 +811,25 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
         tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
     }
 
-    if (tid == 0) {
+    if (threadIdx.x == 0) {
         dst[row] = tmp;
     }
 }
 
 static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
 
-    const uint16_t kmask1 = 0x3f3f;
-    const uint16_t kmask2 = 0x0f0f;
-    const uint16_t kmask3 = 0xc0c0;
-
     const int row = blockIdx.y*blockDim.y + threadIdx.y;
     if (row > nrows) return;
     const int num_blocks_per_row = ncols / QK_K;
     const int ib0 = row*num_blocks_per_row;
 
+    const block_q4_K * x = (const block_q4_K *)vx + ib0;
+
+#if QK_K == 256
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
     const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
     const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1
 
@@ -683,8 +849,6 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
     uint16_t aux[4];
     const uint8_t * sc = (const uint8_t *)aux;
 
-    const block_q4_K * x = (const block_q4_K *)vx + ib0;
-
     float tmp = 0; // partial sum for thread in warp
 
     for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
@@ -713,6 +877,36 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
         tmp += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin;
 
     }
+#else
+    const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION);  // 0...15
+    const int ix  = threadIdx.x%(2*K_QUANTS_PER_ITERATION);
+
+    const int step = tid * K_QUANTS_PER_ITERATION;
+
+    uint16_t aux16[2];
+    const uint8_t * s = (const uint8_t *)aux16;
+
+    float tmp = 0;
+
+    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+        const uint8_t * q = x[i].qs + step;
+        const float   * y = yy + i*QK_K + step;
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        aux16[0] = a[0] & 0x0f0f;
+        aux16[1] = (a[0] >> 4) & 0x0f0f;
+        const float d = (float)x[i].d[0];
+        const float m = (float)x[i].d[1];
+        float sum = 0.f;
+        for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+            sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2])
+                 + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2])
+                 + y[j+32] * (d * s[1] * (q[j+ 0] >>  4) - m * s[3])
+                 + y[j+48] * (d * s[1] * (q[j+16] >>  4) - m * s[3]);
+        }
+        tmp += sum;
+    }
+
+#endif
 
     // sum up partial sums and write back result
     __syncthreads();
@@ -728,15 +922,19 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
 
 static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float * yy, float * dst, const int ncols) {
 
-    const uint16_t kmask1 = 0x3f3f;
-    const uint16_t kmask2 = 0x0f0f;
-    const uint16_t kmask3 = 0xc0c0;
-
-    //const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int row = blockIdx.x;
     const int num_blocks_per_row = ncols / QK_K;
     const int ib0 = row*num_blocks_per_row;
 
+    const block_q5_K * x = (const block_q5_K *)vx + ib0;
+
+    float tmp = 0; // partial sum for thread in warp
+
+#if QK_K == 256
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
     const int tid = threadIdx.x/2;  // 0...15
     const int ix  = threadIdx.x%2;
 
@@ -757,10 +955,6 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
     uint16_t aux[4];
     const uint8_t * sc = (const uint8_t *)aux;
 
-    const block_q5_K * x = (const block_q5_K *)vx + ib0;
-
-    float tmp = 0; // partial sum for thread in warp
-
     for (int i = ix; i < num_blocks_per_row; i += 2) {
 
         const uint8_t * ql1 = x[i].qs + q_offset;
@@ -793,8 +987,31 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
                   + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
         }
         tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;
+    }
 
+#else
+    const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION);  // 0...15
+    const int ix  = threadIdx.x%(2*K_QUANTS_PER_ITERATION);
+    const int step = tid * K_QUANTS_PER_ITERATION;
+    const int im = step/8;
+    const int in = step%8;
+
+    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+        const uint8_t * q = x[i].qs + step;
+        const int8_t  * s = x[i].scales;
+        const float   * y = yy + i*QK_K + step;
+        const float     d = x[i].d;
+        float sum = 0.f;
+        for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+            const uint8_t h = x[i].qh[in+j] >> im;
+            sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16))
+                 + y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16))
+                 + y[j+32] * d * s[2] * ((q[j+ 0] >>  4) - ((h >> 4) & 1 ? 0 : 16))
+                 + y[j+48] * d * s[3] * ((q[j+16] >>  4) - ((h >> 6) & 1 ? 0 : 16));
+        }
+        tmp += sum;
     }
+#endif
 
     // sum up partial sums and write back result
     __syncthreads();
@@ -803,7 +1020,7 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
         tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
     }
 
-    if (tid == 0) {
+    if (threadIdx.x == 0) {
         dst[row] = tmp;
     }
 }
@@ -820,6 +1037,8 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float
 
     const block_q6_K * x = (const block_q6_K *)vx + ib0;
 
+#if QK_K == 256
+
     const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
     const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0, 1
 
@@ -874,6 +1093,37 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float
 
     }
 
+#else
+
+    const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION);  // 0...7
+    const int ix  = threadIdx.x%(2*K_QUANTS_PER_ITERATION);  // 0...3
+
+    const int step = tid * K_QUANTS_PER_ITERATION;
+
+    float tmp = 0; // partial sum for thread in warp
+
+    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+
+        const float   * y  = yy + i * QK_K + step;
+        const uint8_t * ql = x[i].ql + step;
+        const uint8_t * qh = x[i].qh + step;
+        const int8_t  * s  = x[i].scales;
+
+        const float d = x[i+0].d;
+
+        float sum = 0;
+        for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+            sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32)
+                 + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32)
+                 + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >>  4) | ((qh[j] & 0x30) >> 0)) - 32)
+                 + y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >>  4) | ((qh[j] & 0xc0) >> 2)) - 32);
+        }
+        tmp += sum;
+
+    }
+
+#endif
+
     // sum up partial sums and write back result
     __syncthreads();
 #pragma unroll
@@ -1252,12 +1502,20 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu
 
 static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
+#if QK_K == 256
     dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
+#else
+    dequantize_block_q2_K<<<nb, 32, 0, stream>>>(vx, y);
+#endif
 }
 
 static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
+#if QK_K == 256
     dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
+#else
+    dequantize_block_q3_K<<<nb, 32, 0, stream>>>(vx, y);
+#endif
 }
 
 static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
@@ -1267,12 +1525,20 @@ static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cu
 
 static void dequantize_row_q5_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
+#if QK_K == 256
     dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
+#else
+    dequantize_block_q5_K<<<nb, 32, 0, stream>>>(vx, y);
+#endif
 }
 
 static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
+#if QK_K == 256
     dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
+#else
+    dequantize_block_q6_K<<<nb, 32, 0, stream>>>(vx, y);
+#endif
 }
 
 static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
index a7e104dc76fcafbe7bb405dfd235778eb79d0388..7551231b9cf32cf2705de66fce9a15f339943370 100644 (file)
@@ -51,21 +51,21 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(get_rows_f16);
     GGML_METAL_DECL_KERNEL(get_rows_q4_0);
     GGML_METAL_DECL_KERNEL(get_rows_q4_1);
-    GGML_METAL_DECL_KERNEL(get_rows_q2_k);
-    GGML_METAL_DECL_KERNEL(get_rows_q3_k);
-    GGML_METAL_DECL_KERNEL(get_rows_q4_k);
-    GGML_METAL_DECL_KERNEL(get_rows_q5_k);
-    GGML_METAL_DECL_KERNEL(get_rows_q6_k);
+    GGML_METAL_DECL_KERNEL(get_rows_q2_K);
+    GGML_METAL_DECL_KERNEL(get_rows_q3_K);
+    GGML_METAL_DECL_KERNEL(get_rows_q4_K);
+    GGML_METAL_DECL_KERNEL(get_rows_q5_K);
+    GGML_METAL_DECL_KERNEL(get_rows_q6_K);
     GGML_METAL_DECL_KERNEL(rms_norm);
     GGML_METAL_DECL_KERNEL(norm);
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q2_k_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q3_k_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q5_k_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
     GGML_METAL_DECL_KERNEL(rope);
     GGML_METAL_DECL_KERNEL(alibi_f32);
     GGML_METAL_DECL_KERNEL(cpy_f32_f16);
@@ -132,7 +132,13 @@ struct ggml_metal_context * ggml_metal_init(void) {
             exit(1);
         }
 
+#ifdef GGML_QKK_64
+        MTLCompileOptions* options = [MTLCompileOptions new];
+        options.preprocessorMacros = @{ @"QK_K" : @(64) };
+        ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
+#else
         ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
+#endif
         if (error) {
             fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
             exit(1);
@@ -159,21 +165,21 @@ struct ggml_metal_context * ggml_metal_init(void) {
         GGML_METAL_ADD_KERNEL(get_rows_f16);
         GGML_METAL_ADD_KERNEL(get_rows_q4_0);
         GGML_METAL_ADD_KERNEL(get_rows_q4_1);
-        GGML_METAL_ADD_KERNEL(get_rows_q2_k);
-        GGML_METAL_ADD_KERNEL(get_rows_q3_k);
-        GGML_METAL_ADD_KERNEL(get_rows_q4_k);
-        GGML_METAL_ADD_KERNEL(get_rows_q5_k);
-        GGML_METAL_ADD_KERNEL(get_rows_q6_k);
+        GGML_METAL_ADD_KERNEL(get_rows_q2_K);
+        GGML_METAL_ADD_KERNEL(get_rows_q3_K);
+        GGML_METAL_ADD_KERNEL(get_rows_q4_K);
+        GGML_METAL_ADD_KERNEL(get_rows_q5_K);
+        GGML_METAL_ADD_KERNEL(get_rows_q6_K);
         GGML_METAL_ADD_KERNEL(rms_norm);
         GGML_METAL_ADD_KERNEL(norm);
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_q2_k_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_q3_k_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_q5_k_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
         GGML_METAL_ADD_KERNEL(rope);
         GGML_METAL_ADD_KERNEL(alibi_f32);
         GGML_METAL_ADD_KERNEL(cpy_f32_f16);
@@ -662,7 +668,7 @@ void ggml_metal_graph_compute(
 
                                             nth0 = 4;
                                             nth1 = 16;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_k_f32];
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
                                         } break;
                                     case GGML_TYPE_Q3_K:
                                         {
@@ -671,7 +677,7 @@ void ggml_metal_graph_compute(
 
                                             nth0 = 4;
                                             nth1 = 16;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_k_f32];
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
                                         } break;
                                     case GGML_TYPE_Q4_K:
                                         {
@@ -680,7 +686,7 @@ void ggml_metal_graph_compute(
 
                                             nth0 = 4;
                                             nth1 = 16;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32];
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
                                         } break;
                                     case GGML_TYPE_Q5_K:
                                         {
@@ -689,7 +695,7 @@ void ggml_metal_graph_compute(
 
                                             nth0 = 4;
                                             nth1 = 16;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_k_f32];
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
                                         } break;
                                     case GGML_TYPE_Q6_K:
                                         {
@@ -698,7 +704,7 @@ void ggml_metal_graph_compute(
 
                                             nth0 = 4;
                                             nth1 = 16;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_k_f32];
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
                                         } break;
                                     default:
                                         {
@@ -750,11 +756,11 @@ void ggml_metal_graph_compute(
                                 case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
                                 case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
                                 case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
-                                case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break;
-                                case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_k]; break;
-                                case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break;
-                                case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_k]; break;
-                                case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break;
+                                case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
+                                case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
+                                case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
+                                case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break;
+                                case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
                                 default: GGML_ASSERT(false && "not implemented");
                             }
 
index d1e49222db2eb6c9c7614978fc88c349b86ed5e9..e62fe6842ea72bf950619935c74e9fecf1d10799 100644 (file)
@@ -428,7 +428,7 @@ kernel void kernel_mul_mat_q4_0_f32(
     }
     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (ith == 0) {
-        for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
+        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
         dst[r1*ne0 + r0] = sum[0];
     }
 }
@@ -497,7 +497,7 @@ kernel void kernel_mul_mat_q4_1_f32(
     }
     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (ith == 0) {
-        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+        for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
         dst[r1*ne0 + r0] = sum[0];
     }
 }
@@ -775,47 +775,76 @@ kernel void kernel_cpy_f32_f32(
 
 //============================================ k-quants ======================================================
 
+#ifndef QK_K
 #define QK_K 256
+#else
+static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64");
+#endif
+
+#if QK_K == 256
+#define K_SCALE_SIZE 12
+#else
+#define K_SCALE_SIZE 4
+#endif
 
 typedef struct {
     uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
     uint8_t qs[QK_K/4];      // quants
     half d;           // super-block scale for quantized scales
     half dmin;        // super-block scale for quantized mins
-} block_q2_k;
+} block_q2_K;
 // 84 bytes / block
 
 typedef struct {
     uint8_t hmask[QK_K/8];     // quants - high bit
     uint8_t qs[QK_K/4];        // quants - low 2 bits
-    uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
-    half d;                    // super-block scale
-} block_q3_k;
-// 110 bytes / block
-
+#if QK_K == 64
+    uint8_t scales[2];
+#else
+    uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
+#endif
+    half d;             // super-block scale
+} block_q3_K;
+
+#if QK_K == 64
+typedef struct {
+    half    d[2];          // super-block scales/mins
+    uint8_t scales[2];
+    uint8_t qs[QK_K/2];    // 4-bit quants
+} block_q4_K;
+#else
 typedef struct {
     half d;             // super-block scale for quantized scales
     half dmin;          // super-block scale for quantized mins
-    uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
+    uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
     uint8_t qs[QK_K/2];        // 4--bit quants
-} block_q4_k;
-// 144 bytes / block
+} block_q4_K;
+#endif
 
+#if QK_K == 64
+typedef struct {
+    half  d;                     // super-block scales/mins
+    int8_t  scales[QK_K/16];     // 8-bit block scales
+    uint8_t qh[QK_K/8];          // quants, high bit
+    uint8_t qs[QK_K/2];          // quants, low 4 bits
+} block_q5_K;
+#else
 typedef struct {
     half d;                      // super-block scale for quantized scales
     half dmin;                   // super-block scale for quantized mins
     uint8_t scales[3*QK_K/64];   // scales and mins, quantized with 6 bits
     uint8_t qh[QK_K/8];          // quants, high bit
     uint8_t qs[QK_K/2];          // quants, low 4 bits
-} block_q5_k;
+} block_q5_K;
 // 176 bytes / block
+#endif
 
 typedef struct {
     uint8_t ql[QK_K/2];      // quants, lower 4 bits
     uint8_t qh[QK_K/4];      // quants, upper 2 bits
     int8_t  scales[QK_K/16]; // scales, quantized with 8 bits
     half d;                  // super-block scale
-} block_q6_k;
+} block_q6_K;
 // 210 bytes / block
 
 static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
@@ -836,7 +865,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
 
 //========================================== dequantization =============================
 
-static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, int k) {
+static void dequantize_row_q2_K(device const block_q2_K * x, device float * y, int k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
 
@@ -847,6 +876,7 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i
 
         device const uint8_t * q = x[i].qs;
 
+#if QK_K == 256
         int is = 0;
         float dl, ml;
         for (int n = 0; n < QK_K; n += 128) {
@@ -865,14 +895,29 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i
             }
             q += 32;
         }
+#else
+        float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
+        float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
+        float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
+        float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
+        for (int l = 0; l < 16; ++l) {
+            y[l+ 0] = dl1 * ((q[l] >> 0) & 3) - ml1;
+            y[l+16] = dl2 * ((q[l] >> 2) & 3) - ml2;
+            y[l+32] = dl3 * ((q[l] >> 4) & 3) - ml3;
+            y[l+48] = dl4 * ((q[l] >> 6) & 3) - ml4;
+        }
+        y += QK_K;
+#endif
 
     }
 }
 
-static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, int k) {
+static void dequantize_row_q3_K(device const block_q3_K * x, device float * y, int k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
 
+#if QK_K == 256
+
     const uint16_t kmask1 = 0x0303;
     const uint16_t kmask2 = 0x0f0f;
 
@@ -918,22 +963,49 @@ static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, i
             }
             q += 32;
         }
+    }
+#else
+    for (int i = 0; i < nb; i++) {
 
+        const float d_all = (float)(x[i].d);
+
+        device const uint8_t * q = x[i].qs;
+        device const uint8_t * hm = x[i].hmask;
+
+        const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
+        const float d2 = d_all * ((x[i].scales[0] >>  4) - 8);
+        const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
+        const float d4 = d_all * ((x[i].scales[1] >>  4) - 8);
+
+        for (int l = 0; l < 8; ++l) {
+            uint8_t h = hm[l];
+            y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
+            y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
+            y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
+            y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
+            y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
+            y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
+            y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
+            y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
+        }
+        y += QK_K;
     }
+#endif
 
 }
 
-static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) {
+static void dequantize_row_q4_K(device const block_q4_K * x, device float * y, int k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
 
-
     for (int i = 0; i < nb; i++) {
 
+        device const uint8_t * q = x[i].qs;
+
+#if QK_K == 256
         const float d = x[i].d;
         const float min = x[i].dmin;
 
-        device const uint8_t * q = x[i].qs;
         device const uint8_t * scales = x[i].scales;
 
         int is = 0;
@@ -945,14 +1017,29 @@ static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, i
             for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l]  >> 4) - m2;
             q += 32; is += 2;
         }
+#else
+        device const uint8_t * s = x[i].scales;
+        device const half2 * dh = (device const half2 *)x[i].d;
+        const float2 d = (float2)dh[0];
+        const float d1 = d[0] * (s[0] & 0xF);
+        const float d2 = d[0] * (s[1] & 0xF);
+        const float m1 = d[1] * (s[0] >>  4);
+        const float m2 = d[1] * (s[1] >>  4);
+        for (int l = 0; l < 32; ++l) {
+            y[l+ 0] = d1 * (q[l] & 0xF) - m1;
+            y[l+32] = d2 * (q[l] >>  4) - m2;
+        }
+        y += QK_K;
+#endif
 
     }
 }
 
-static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, int k) {
+static void dequantize_row_q5_K(device const block_q5_K * x, device float * y, int k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
 
+#if QK_K == 256
    for (int i = 0; i < nb; i++) {
 
         const float d = (float)(x[i].d);
@@ -973,10 +1060,32 @@ static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, i
             u1 <<= 2; u2 <<= 2;
         }
     }
+#else
+    for (int i = 0; i < nb; i++) {
+
+        const float d = (float)x[i].d;
+
+        device const uint8_t * ql = x[i].qs;
+        device const uint8_t * qh = x[i].qh;
+        device const int8_t  * sc = x[i].scales;
+
+        for (int l = 0; l < 8; ++l) {
+            y[l+ 0] = d * sc[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
+            y[l+ 8] = d * sc[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
+            y[l+16] = d * sc[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
+            y[l+24] = d * sc[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
+            y[l+32] = d * sc[2] * ((ql[l+ 0] >>  4) - (qh[l] & 0x10 ? 0 : 16));
+            y[l+40] = d * sc[2] * ((ql[l+ 8] >>  4) - (qh[l] & 0x20 ? 0 : 16));
+            y[l+48] = d * sc[3] * ((ql[l+16] >>  4) - (qh[l] & 0x40 ? 0 : 16));
+            y[l+56] = d * sc[3] * ((ql[l+24] >>  4) - (qh[l] & 0x80 ? 0 : 16));
+        }
+        y += QK_K;
+    }
+#endif
 
 }
 
-static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, int k) {
+static void dequantize_row_q6_K(device const block_q6_K * x, device float * y, int k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
 
@@ -988,6 +1097,7 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i
 
         const float d = x[i].d;
 
+#if QK_K == 256
         for (int n = 0; n < QK_K; n += 128) {
             for (int l = 0; l < 32; ++l) {
                 int is = l/16;
@@ -1005,10 +1115,23 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i
             qh += 32;
             sc += 8;
         }
+#else
+        for (int l = 0; l < 16; ++l) {
+            const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
+            const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
+            const int8_t q3 = (int8_t)((ql[l+ 0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
+            const int8_t q4 = (int8_t)((ql[l+16]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
+            y[l+ 0] = d * sc[0] * q1;
+            y[l+16] = d * sc[1] * q2;
+            y[l+32] = d * sc[2] * q3;
+            y[l+48] = d * sc[3] * q4;
+        }
+        y  += 64;
+#endif
     }
 }
 
-kernel void kernel_get_rows_q2_k(
+kernel void kernel_get_rows_q2_K(
         device const  void * src0,
         device const   int * src1,
         device       float * dst,
@@ -1019,12 +1142,12 @@ kernel void kernel_get_rows_q2_k(
     const int i = tpig;
     const int r = ((device int32_t *) src1)[i];
 
-    dequantize_row_q2_k(
-            (device const block_q2_k *) ((device char *) src0 + r*nb01),
+    dequantize_row_q2_K(
+            (device const block_q2_K *) ((device char *) src0 + r*nb01),
                        (device float *) ((device char *)  dst + i*nb1), ne00);
 }
 
-kernel void kernel_get_rows_q3_k(
+kernel void kernel_get_rows_q3_K(
         device const  void * src0,
         device const   int * src1,
         device       float * dst,
@@ -1035,12 +1158,12 @@ kernel void kernel_get_rows_q3_k(
     const int i = tpig;
     const int r = ((device int32_t *) src1)[i];
 
-    dequantize_row_q3_k(
-            (device const block_q3_k *) ((device char *) src0 + r*nb01),
+    dequantize_row_q3_K(
+            (device const block_q3_K *) ((device char *) src0 + r*nb01),
                        (device float *) ((device char *)  dst + i*nb1), ne00);
 }
 
-kernel void kernel_get_rows_q4_k(
+kernel void kernel_get_rows_q4_K(
         device const  void * src0,
         device const   int * src1,
         device       float * dst,
@@ -1051,12 +1174,12 @@ kernel void kernel_get_rows_q4_k(
     const int i = tpig;
     const int r = ((device int32_t *) src1)[i];
 
-    dequantize_row_q4_k(
-            (device const block_q4_k *) ((device char *) src0 + r*nb01),
+    dequantize_row_q4_K(
+            (device const block_q4_K *) ((device char *) src0 + r*nb01),
                        (device float *) ((device char *)  dst + i*nb1), ne00);
 }
 
-kernel void kernel_get_rows_q5_k(
+kernel void kernel_get_rows_q5_K(
         device const  void * src0,
         device const   int * src1,
         device       float * dst,
@@ -1067,12 +1190,12 @@ kernel void kernel_get_rows_q5_k(
     const int i = tpig;
     const int r = ((device int32_t *) src1)[i];
 
-    dequantize_row_q5_k(
-            (device const block_q5_k *) ((device char *) src0 + r*nb01),
+    dequantize_row_q5_K(
+            (device const block_q5_K *) ((device char *) src0 + r*nb01),
                        (device float *) ((device char *)  dst + i*nb1), ne00);
 }
 
-kernel void kernel_get_rows_q6_k(
+kernel void kernel_get_rows_q6_K(
         device const  void * src0,
         device const   int * src1,
         device       float * dst,
@@ -1083,14 +1206,14 @@ kernel void kernel_get_rows_q6_k(
     const int i = tpig;
     const int r = ((device int32_t *) src1)[i];
 
-    dequantize_row_q6_k(
-            (device const block_q6_k *) ((device char *) src0 + r*nb01),
+    dequantize_row_q6_K(
+            (device const block_q6_K *) ((device char *) src0 + r*nb01),
                        (device float *) ((device char *)  dst + i*nb1), ne00);
 }
 
 //====================================== dot products =========================
 
-kernel void kernel_mul_mat_q2_k_f32(
+kernel void kernel_mul_mat_q2_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1107,12 +1230,15 @@ kernel void kernel_mul_mat_q2_k_f32(
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
 
-    device const block_q2_k * x = (device const block_q2_k *) src0 + r0*nb;
+    device const block_q2_K * x = (device const block_q2_K *) src0 + r0*nb;
     device const float     * yy = (device const float      *) src1 + r1*ne10;
 
     const int nth = tptg.x*tptg.y;
     const int ith = tptg.y*tpitg.x + tpitg.y;
 
+    float sumf = 0;
+
+#if QK_K == 256
     const int tid = tpitg.y;    // 0...16
     const int il  = tid/4;      // 0...3
     const int ir  = tid%4;      // 0...3
@@ -1125,9 +1251,6 @@ kernel void kernel_mul_mat_q2_k_f32(
     const int y_offset = 64*il + n*ir;
     const int q_offset = 32*ip + n*ir;
 
-    sum[ith] = 0.0f;
-
-    float sumf = 0;
     for (int i = tpitg.x; i < nb; i += tptg.x) {
 
         device const uint8_t * q = x[i].qs + q_offset;
@@ -1140,7 +1263,6 @@ kernel void kernel_mul_mat_q2_k_f32(
 
         device const float   * y = yy + i*QK_K + y_offset;
 
-        //float4 s = {0.f, 0.f, 0.f, 0.f};
         float2 s = {0.f, 0.f};
         float smin = 0;
         for (int l = 0; l < n; ++l) {
@@ -1155,25 +1277,38 @@ kernel void kernel_mul_mat_q2_k_f32(
         sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
 
     }
-    sum[ith] = sumf;
+#else
+    const int il = 4 * tpitg.x;
 
-    //int mask1 = (ith%4 == 0);
-    //int mask2 = (ith%16 == 0);
+    uint32_t aux[2];
+    thread const uint8_t * d = (thread const uint8_t *)aux;
+    thread const uint8_t * m = (thread const uint8_t *)aux + 4;
 
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
-    //for (int i = 1; i < 4; ++i) sum[ith] += mask1 * sum[ith + i];
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
-    //for (int i = 4; i < 16; i += 4) sum[ith] += mask2 * sum[ith + i];
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
-    //if (ith == 0) {
-    //    for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
-    //    dst[r1*ne0 + r0] = sum[0];
-    //}
+    for (int i = tpitg.y; i < nb; i += tptg.y) {
+
+        device const uint8_t * q = x[i].qs + il;
+        device const float   * y = yy + i*QK_K + il;
+
+        const float dall = (float)x[i].d;
+        const float dmin = (float)x[i].dmin;
+
+        device const uint32_t * a = (device const uint32_t *)x[i].scales;
+        aux[0] = a[0] & 0x0f0f0f0f;
+        aux[1] = (a[0] >> 4) & 0x0f0f0f0f;
+
+        for (int l = 0; l < 4; ++l) {
+            sumf += y[l+ 0] * (dall * d[0] * ((q[l] >> 0) & 3) - dmin * m[0])
+                  + y[l+16] * (dall * d[1] * ((q[l] >> 2) & 3) - dmin * m[1])
+                  + y[l+32] * (dall * d[2] * ((q[l] >> 4) & 3) - dmin * m[2])
+                  + y[l+48] * (dall * d[3] * ((q[l] >> 6) & 3) - dmin * m[3]);
+        }
+    }
+#endif
+
+    sum[ith] = sumf;
 
     //
     // Accumulate the sum from all threads in the threadgroup
-    // This version is slightly faster than the commented out one below,
-    // which I copy-pasted from ggerganov's q4_0 dot product for metal.
     //
     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (ith%4 == 0) {
@@ -1190,7 +1325,7 @@ kernel void kernel_mul_mat_q2_k_f32(
     }
 }
 
-kernel void kernel_mul_mat_q3_k_f32(
+kernel void kernel_mul_mat_q3_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1203,23 +1338,25 @@ kernel void kernel_mul_mat_q3_k_f32(
         uint2 tpitg[[thread_position_in_threadgroup]],
         uint2  tptg[[threads_per_threadgroup]]) {
 
-    const uint16_t kmask1 = 0x0303;
-    const uint16_t kmask2 = 0x0f0f;
-
-    const uint8_t m3 = 3;
-    const int8_t  m4 = 4;
-
     const int nb = ne00/QK_K;
 
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
 
-    device const block_q3_k * x = (device const block_q3_k *) src0 + r0*nb;
+    device const block_q3_K * x = (device const block_q3_K *) src0 + r0*nb;
     device const float     * yy = (device const float      *) src1 + r1*ne10;
 
     const int nth = tptg.x*tptg.y;
     const int ith = tptg.y*tpitg.x + tpitg.y;
 
+#if QK_K == 256
+
+    const uint8_t m3 = 3;
+    const int8_t  m4 = 4;
+
+    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask2 = 0x0f0f;
+
     const int tid = tpitg.y;        // expecting 16
     const int ip  = tid/8;          // 0 or 1
     const int il  = tid/2 - 4*ip;   // 0...3
@@ -1273,6 +1410,39 @@ kernel void kernel_mul_mat_q3_k_f32(
 
     //sum[ith] = sumf;
     sum[ith] = sumf1 - 32.f*sumf2;
+#else
+    const int il = 4 * tpitg.x;  // 0, 4, 8, 12
+    const int im = il/8;         // 0, 0, 1, 1
+    const int in = il%8;         // 0, 4, 0, 4
+
+    float sumf = 0;
+
+    for (int i = tpitg.y; i < nb; i += tptg.y) {
+
+        const float d_all = (float)(x[i].d);
+
+        device const uint8_t * q = x[i].qs + il;
+        device const uint8_t * h = x[i].hmask + in;
+        device const float   * y = yy + i * QK_K + il;
+
+        const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
+        const float d2 = d_all * ((x[i].scales[0] >>  4) - 8);
+        const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
+        const float d4 = d_all * ((x[i].scales[1] >>  4) - 8);
+
+        for (int l = 0; l < 4; ++l) {
+            const uint8_t hm = h[l] >> im;
+            sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4))
+                  + y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 0 : 4))
+                  + y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4))
+                  + y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 0 : 4));
+        }
+
+    }
+
+    sum[ith] = sumf;
+
+#endif
 
     //
     // Accumulate the sum from all threads in the threadgroup
@@ -1293,7 +1463,7 @@ kernel void kernel_mul_mat_q3_k_f32(
 
 }
 
-kernel void kernel_mul_mat_q4_k_f32(
+kernel void kernel_mul_mat_q4_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1305,21 +1475,25 @@ kernel void kernel_mul_mat_q4_k_f32(
         uint2 tpitg[[thread_position_in_threadgroup]],
         uint2  tptg[[threads_per_threadgroup]]) {
 
-    const uint16_t kmask1 = 0x3f3f;
-    const uint16_t kmask2 = 0x0f0f;
-    const uint16_t kmask3 = 0xc0c0;
-
     const int nb = ne00/QK_K;
 
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
 
-    device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb;
-    device const float     * yy = (device const float      *) src1 + r1*ne10;
-
     const int nth = tptg.x*tptg.y;
     const int ith = tptg.y*tpitg.x + tpitg.y;
 
+    device const block_q4_K * x = (device const block_q4_K *) src0 + r0*nb;
+    device const float     * yy = (device const float      *) src1 + r1*ne10;
+
+    float sumf = 0;
+
+#if QK_K == 256
+
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
     const int tid = tpitg.y;   // 0...16
     const int il  = tid/4;     // 0...3
     const int ir  = tid - 4*il;// 0...3
@@ -1332,11 +1506,8 @@ kernel void kernel_mul_mat_q4_k_f32(
     const int q_offset = 32*im + l0;
     const int y_offset = 64*im + l0;
 
-    sum[ith] = 0.0f;
-
     uchar2 sc1, sc2, sc3, sc4;
 
-    float sumf = 0;
     for (int i = tpitg.x; i < nb; i += tptg.x) {
 
         device const uint8_t * q1 = (x + i)->qs + q_offset;
@@ -1365,6 +1536,30 @@ kernel void kernel_mul_mat_q4_k_f32(
         sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
 
     }
+#else
+    uint16_t aux16[2];
+    thread const uint8_t * scales = (thread const uint8_t *)aux16;
+
+    const int il  = 4*tpitg.x;
+
+    for (int i = tpitg.y; i < nb; i += tptg.y) {
+
+        device const uint8_t * q = x[i].qs + il;
+        device const float   * y = yy + i * QK_K + il;
+
+        const float d = (float)x[i].d[0];
+        const float m = (float)x[i].d[1];
+
+        device const uint16_t * a = (device const uint16_t *)x[i].scales;
+        aux16[0] = a[0] & 0x0f0f;
+        aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+        for (int l = 0; l < 4; ++l) {
+            sumf += d * scales[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - m * scales[2] * (y[l+ 0] + y[l+16])
+                  + d * scales[1] * (y[l+32] * (q[l] >>  4) + y[l+48] * (q[l+16] >>  4)) - m * scales[3] * (y[l+32] + y[l+48]);
+        }
+    }
+#endif
 
     sum[ith] = sumf;
 
@@ -1401,7 +1596,7 @@ kernel void kernel_mul_mat_q4_k_f32(
     //}
 }
 
-kernel void kernel_mul_mat_q5_k_f32(
+kernel void kernel_mul_mat_q5_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1413,21 +1608,25 @@ kernel void kernel_mul_mat_q5_k_f32(
         uint2 tpitg[[thread_position_in_threadgroup]],
         uint2  tptg[[threads_per_threadgroup]]) {
 
-    const uint16_t kmask1 = 0x3f3f;
-    const uint16_t kmask2 = 0x0f0f;
-    const uint16_t kmask3 = 0xc0c0;
-
     const int nb = ne00/QK_K;
 
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
 
-    device const block_q5_k * x = (device const block_q5_k *) src0 + r0*nb;
+    device const block_q5_K * x = (device const block_q5_K *) src0 + r0*nb;
     device const float     * yy = (device const float      *) src1 + r1*ne10;
 
     const int nth = tptg.x*tptg.y;
     const int ith = tptg.y*tpitg.x + tpitg.y;
 
+    float sumf = 0;
+
+#if QK_K == 256
+
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
     const int tid = tpitg.y;   // 0...16
     const int il  = tid/4;     // 0...3
     const int ir  = tid - 4*il;// 0...3
@@ -1447,7 +1646,6 @@ kernel void kernel_mul_mat_q5_k_f32(
 
     uchar2 sc1, sc2, sc3, sc4;
 
-    float sumf = 0;
     for (int i = tpitg.x; i < nb; i += tptg.x) {
 
         device const uint8_t * q1 = (x + i)->qs + q_offset;
@@ -1479,6 +1677,28 @@ kernel void kernel_mul_mat_q5_k_f32(
         sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
 
     }
+#else
+    const int il  = 4 * tpitg.x;  // 0, 4, 8, 12
+    const int im  = il/8;         // 0, 0, 1, 1
+    const int in  = il%8;         // 0, 4, 0, 4
+
+    for (int i = tpitg.y; i < nb; i += tptg.y) {
+
+        const float d = (float)x[i].d;
+        device const uint8_t * q = x[i].qs + il;
+        device const uint8_t * h = x[i].qh + in;
+        device const int8_t  * s = x[i].scales;
+        device const float   * y = yy + i*QK_K + il;
+
+        for (int l = 0; l < 4; ++l) {
+            const uint8_t hl = h[l] >> im;
+            sumf += y[l+ 0] * d * s[0] * ((q[l+ 0] & 0xF) - (hl & 0x01 ? 0 : 16))
+                  + y[l+16] * d * s[1] * ((q[l+16] & 0xF) - (hl & 0x04 ? 0 : 16))
+                  + y[l+32] * d * s[2] * ((q[l+ 0] >>  4) - (hl & 0x10 ? 0 : 16))
+                  + y[l+48] * d * s[3] * ((q[l+16] >>  4) - (hl & 0x40 ? 0 : 16));
+        }
+    }
+#endif
     sum[ith] = sumf;
 
     //
@@ -1500,7 +1720,7 @@ kernel void kernel_mul_mat_q5_k_f32(
 
 }
 
-kernel void kernel_mul_mat_q6_k_f32(
+kernel void kernel_mul_mat_q6_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1522,12 +1742,15 @@ kernel void kernel_mul_mat_q6_k_f32(
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
 
-    device const block_q6_k * x = (device const block_q6_k *) src0 + r0*nb;
+    device const block_q6_K * x = (device const block_q6_K *) src0 + r0*nb;
     device const float     * yy = (device const float      *) src1 + r1*ne10;
 
     const int nth = tptg.x*tptg.y;
     const int ith = tptg.y*tpitg.x + tpitg.y;
 
+    float sumf = 0;
+
+#if QK_K == 256
     // Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
     const int iqs  = 16 * tpitg.y;
     const int ip   = iqs / 128;         // 0 or 1
@@ -1540,7 +1763,6 @@ kernel void kernel_mul_mat_q6_k_f32(
     const int q_offset_l = 64*ip + l0;
     const int q_offset_h = 32*ip + l0;
 
-    float sumf = 0;
     for (int i = tpitg.x; i < nb; i += tptg.x) {
 
         device const uint8_t * ql = x[i].ql + q_offset_l;
@@ -1562,6 +1784,28 @@ kernel void kernel_mul_mat_q6_k_f32(
         sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
 
     }
+#else
+    const int il  = 4*tpitg.x;    // 0, 4, 8, 12
+
+    for (int i = tpitg.y; i < nb; i += tptg.y) {
+        device const float * y = yy + i * QK_K + il;
+        device const uint8_t * ql = x[i].ql + il;
+        device const uint8_t * qh = x[i].qh + il;
+        device const int8_t  * s  = x[i].scales;
+
+        const float d = x[i].d;
+
+        float4 sums = {0.f, 0.f, 0.f, 0.f};
+        for (int l = 0; l < 4; ++l) {
+            sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
+            sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
+            sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >>  4) | ((qh[l] & kmask3) >> 0)) - 32);
+            sums[3] += y[l+48] * ((int8_t)((ql[l+16] >>  4) | ((qh[l] & kmask4) >> 2)) - 32);
+        }
+        sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]);
+    }
+
+#endif
 
     sum[ith] = sumf;
 
index a48c821710cbb206069a01ad1a48cb65e7dade4f..46dd884b0aa8094e420d504e53b75d49bb6e853f 100644 (file)
@@ -261,6 +261,7 @@ static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t
     return scale;
 }
 
+#if QK_K == 256
 static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
     if (j < 4) {
         *d = q[j] & 63; *m = q[j + 4] & 63;
@@ -269,6 +270,7 @@ static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t *
         *m = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
     }
 }
+#endif
 
 //========================- 2-bit (de)-quantization
 
@@ -330,11 +332,17 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict
             }
         }
 
+#if QK_K == 256
         for (int j = 0; j < QK_K; j += 128) {
             for (int l = 0; l < 32; ++l) {
                 y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
             }
         }
+#else
+        for (int l = 0; l < 16; ++l) {
+            y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
+        }
+#endif
 
         x += QK_K;
 
@@ -352,6 +360,7 @@ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int
 
         const uint8_t * q = x[i].qs;
 
+#if QK_K == 256
         int is = 0;
         float dl, ml;
         for (int n = 0; n < QK_K; n += 128) {
@@ -370,7 +379,19 @@ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int
             }
             q += 32;
         }
-
+#else
+        float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
+        float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
+        float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
+        float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
+        for (int l = 0; l < 16; ++l) {
+            y[l+ 0] = dl1 * ((int8_t)((q[l] >> 0) & 3)) - ml1;
+            y[l+16] = dl2 * ((int8_t)((q[l] >> 2) & 3)) - ml2;
+            y[l+32] = dl3 * ((int8_t)((q[l] >> 4) & 3)) - ml3;
+            y[l+48] = dl4 * ((int8_t)((q[l] >> 6) & 3)) - ml4;
+        }
+        y += QK_K;
+#endif
     }
 }
 
@@ -412,6 +433,7 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict
             }
         }
 
+#if QK_K == 256
         memset(y[i].scales, 0, 12);
         if (max_scale) {
             float iscale = -32.f/max_scale;
@@ -445,9 +467,39 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict
                 L[16*j + ii] = l + 4;
             }
         }
+#else
+        if (max_scale) {
+            float iscale = -8.f/max_scale;
+            for (int j = 0; j < QK_K/16; j+=2) {
+                int l1 = nearest_int(iscale*scales[j]);
+                l1 = 8 + MAX(-8, MIN(7, l1));
+                int l2 = nearest_int(iscale*scales[j+1]);
+                l2 = 8 + MAX(-8, MIN(7, l2));
+                y[i].scales[j/2] = l1 | (l2 << 4);
+            }
+            y[i].d = ggml_fp32_to_fp16(1/iscale);
+        } else {
+            for (int j = 0; j < QK_K/16; j+=2) {
+                y[i].scales[j/2] = 0;
+            }
+            y[i].d = ggml_fp32_to_fp16(0.f);
+        }
+        for (int j = 0; j < QK_K/16; ++j) {
+            int s = j%2 == 0 ? y[i].scales[j/2] & 0xF : y[i].scales[j/2] >> 4;
+            float d = ggml_fp16_to_fp32(y[i].d) * (s - 8);
+            if (!d) {
+                continue;
+            }
+            for (int ii = 0; ii < 16; ++ii) {
+                int l = nearest_int(x[16*j + ii]/d);
+                l = MAX(-4, MIN(3, l));
+                L[16*j + ii] = l + 4;
+            }
+        }
+#endif
 
         memset(y[i].hmask, 0, QK_K/8);
-        // We put the high-bit for the 1st 32 quants into bit 0, the next 32 into bit 1, etc.
+        // We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc.
         int m = 0;
         uint8_t hm = 1;
         for (int j = 0; j < QK_K; ++j) {
@@ -459,19 +511,25 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict
                 m = 0; hm <<= 1;
             }
         }
+#if QK_K == 256
         for (int j = 0; j < QK_K; j += 128) {
             for (int l = 0; l < 32; ++l) {
                 y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
             }
         }
+#else
+        for (int l = 0; l < 16; ++l) {
+            y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
+        }
+#endif
 
         x += QK_K;
     }
 }
 
+#if QK_K == 256
 void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
     assert(k % QK_K == 0);
-    assert(QK_K == 256);
     const int nb = k / QK_K;
 
     const uint32_t kmask1 = 0x03030303;
@@ -519,6 +577,39 @@ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int
 
     }
 }
+#else
+void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
+    assert(k % QK_K == 0);
+    assert(QK_K == 64);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d_all = ggml_fp16_to_fp32(x[i].d);
+
+        const uint8_t * restrict q = x[i].qs;
+        const uint8_t * restrict hm = x[i].hmask;
+
+        const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
+        const float d2 = d_all * ((x[i].scales[0] >>  4) - 8);
+        const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
+        const float d4 = d_all * ((x[i].scales[1] >>  4) - 8);
+
+        for (int l=0; l<8; ++l) {
+            uint8_t h = hm[l];
+            y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
+            y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
+            y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
+            y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
+            y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
+            y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
+            y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
+            y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
+        }
+        y += QK_K;
+    }
+}
+#endif
 
 void quantize_row_q3_K(const float * restrict x, void * restrict vy, int k) {
     quantize_row_q3_K_reference(x, vy, k);
@@ -563,6 +654,7 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
             }
         }
 
+#if QK_K == 256
         float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
         float inv_min   = max_min   > 0 ? 63.f/max_min   : 0.f;
         for (int j = 0; j < QK_K/32; ++j) {
@@ -594,9 +686,43 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
                 L[32*j + ii] = l;
             }
         }
+#else
+        const float s_factor = 15.f;
+        float inv_scale = max_scale > 0 ? s_factor/max_scale : 0.f;
+        float inv_min   = max_min   > 0 ? s_factor/max_min   : 0.f;
+        int d1 = nearest_int(inv_scale*scales[0]);
+        int m1 = nearest_int(inv_min*mins[0]);
+        int d2 = nearest_int(inv_scale*scales[1]);
+        int m2 = nearest_int(inv_min*mins[1]);
+        y[i].scales[0] = d1 | (m1 << 4);
+        y[i].scales[1] = d2 | (m2 << 4);
+        y[i].d[0] = ggml_fp32_to_fp16(max_scale/s_factor);
+        y[i].d[1] = ggml_fp32_to_fp16(max_min/s_factor);
+
+        float sumlx = 0;
+        int   suml2 = 0;
+        for (int j = 0; j < QK_K/32; ++j) {
+            const uint8_t sd = y[i].scales[j] & 0xF;
+            const uint8_t sm = y[i].scales[j] >>  4;
+            const float d = ggml_fp16_to_fp32(y[i].d[0]) * sd;
+            if (!d) continue;
+            const float m = ggml_fp16_to_fp32(y[i].d[1]) * sm;
+            for (int ii = 0; ii < 32; ++ii) {
+                int l = nearest_int((x[32*j + ii] + m)/d);
+                l = MAX(0, MIN(15, l));
+                L[32*j + ii] = l;
+                sumlx += (x[32*j + ii] + m)*l*sd;
+                suml2 += l*l*sd*sd;
+            }
+        }
+        if (suml2) {
+            y[i].d[0] = ggml_fp32_to_fp16(sumlx/suml2);
+        }
+#endif
         uint8_t * q = y[i].qs;
         for (int j = 0; j < QK_K; j += 64) {
-            for (int l = 0; l < 32; ++l) *q++ = L[j + l] | (L[j + l + 32] << 4);
+            for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4);
+            q += 32;
         }
 
         x += QK_K;
@@ -610,11 +736,13 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int
 
     for (int i = 0; i < nb; i++) {
 
-        const float d = ggml_fp16_to_fp32(x[i].d);
-        const float min = ggml_fp16_to_fp32(x[i].dmin);
-
         const uint8_t * q = x[i].qs;
 
+#if QK_K == 256
+
+        const float d   = ggml_fp16_to_fp32(x[i].d);
+        const float min = ggml_fp16_to_fp32(x[i].dmin);
+
         int is = 0;
         uint8_t sc, m;
         for (int j = 0; j < QK_K; j += 64) {
@@ -626,6 +754,17 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int
             for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l]  >> 4) - m2;
             q += 32; is += 2;
         }
+#else
+        const float dall = ggml_fp16_to_fp32(x[i].d[0]);
+        const float mall = ggml_fp16_to_fp32(x[i].d[1]);
+        const float d1 = dall * (x[i].scales[0] & 0xF), m1 = mall * (x[i].scales[0] >> 4);
+        const float d2 = dall * (x[i].scales[1] & 0xF), m2 = mall * (x[i].scales[1] >> 4);
+        for (int l = 0; l < 32; ++l) {
+            y[l+ 0] = d1 * (q[l] & 0xF) - m1;
+            y[l+32] = d2 * (q[l] >>  4) - m2;
+        }
+        y += QK_K;
+#endif
 
     }
 }
@@ -653,12 +792,19 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
 
+#if QK_K == 256
     uint8_t L[QK_K];
     float mins[QK_K/32];
     float scales[QK_K/32];
+#else
+    int8_t L[QK_K];
+    float scales[QK_K/16];
+#endif
 
     for (int i = 0; i < nb; i++) {
 
+#if QK_K == 256
+
         float max_scale = 0; // as we are deducting the min, scales are always positive
         float max_min = 0;
         for (int j = 0; j < QK_K/32; ++j) {
@@ -725,6 +871,52 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
             m1 <<= 2; m2 <<= 2;
             ql += 32;
         }
+#else
+        float max_scale = 0, amax = 0;
+        for (int j = 0; j < QK_K/16; ++j) {
+            scales[j] = make_qx_quants(16, 16, x + 16*j, L + 16*j, 1);
+            float abs_scale = fabsf(scales[j]);
+            if (abs_scale > amax) {
+                amax = abs_scale;
+                max_scale = scales[j];
+            }
+        }
+
+        float iscale = -128.f/max_scale;
+        for (int j = 0; j < QK_K/16; ++j) {
+            int l = nearest_int(iscale*scales[j]);
+            y[i].scales[j] = MAX(-128, MIN(127, l));
+        }
+        y[i].d = ggml_fp32_to_fp16(1/iscale);
+
+        for (int j = 0; j < QK_K/16; ++j) {
+            const float d = ggml_fp16_to_fp32(y[i].d) * y[i].scales[j];
+            if (!d) continue;
+            for (int ii = 0; ii < 16; ++ii) {
+                int l = nearest_int(x[16*j + ii]/d);
+                l = MAX(-16, MIN(15, l));
+                L[16*j + ii] = l + 16;
+            }
+        }
+
+        uint8_t * restrict qh = y[i].qh;
+        uint8_t * restrict ql = y[i].qs;
+        memset(qh, 0, QK_K/8);
+
+        for (int j = 0; j < 32; ++j) {
+            int jm = j%8;
+            int is = j/8;
+            int l1 = L[j];
+            if (l1 > 15) {
+                l1 -= 16; qh[jm] |= (1 << is);
+            }
+            int l2 = L[j + 32];
+            if (l2 > 15) {
+                l2 -= 16; qh[jm] |= (1 << (4 + is));
+            }
+            ql[j] = l1 | (l2 << 4);
+        }
+#endif
 
         x += QK_K;
 
@@ -737,12 +929,14 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int
 
     for (int i = 0; i < nb; i++) {
 
-        const float d = ggml_fp16_to_fp32(x[i].d);
-        const float min = ggml_fp16_to_fp32(x[i].dmin);
-
         const uint8_t * ql = x[i].qs;
         const uint8_t * qh = x[i].qh;
 
+#if QK_K == 256
+
+        const float d = ggml_fp16_to_fp32(x[i].d);
+        const float min = ggml_fp16_to_fp32(x[i].dmin);
+
         int is = 0;
         uint8_t sc, m;
         uint8_t u1 = 1, u2 = 2;
@@ -756,6 +950,21 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int
             ql += 32; is += 2;
             u1 <<= 2; u2 <<= 2;
         }
+#else
+        float d = ggml_fp16_to_fp32(x[i].d);
+        const int8_t * restrict s = x[i].scales;
+        for (int l = 0; l < 8; ++l) {
+            y[l+ 0] = d * s[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
+            y[l+ 8] = d * s[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
+            y[l+16] = d * s[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
+            y[l+24] = d * s[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
+            y[l+32] = d * s[2] * ((ql[l+ 0] >>  4) - (qh[l] & 0x10 ? 0 : 16));
+            y[l+40] = d * s[2] * ((ql[l+ 8] >>  4) - (qh[l] & 0x20 ? 0 : 16));
+            y[l+48] = d * s[3] * ((ql[l+16] >>  4) - (qh[l] & 0x40 ? 0 : 16));
+            y[l+56] = d * s[3] * ((ql[l+24] >>  4) - (qh[l] & 0x80 ? 0 : 16));
+        }
+        y += QK_K;
+#endif
     }
 }
 
@@ -823,6 +1032,7 @@ void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict
 
         uint8_t * restrict ql = y[i].ql;
         uint8_t * restrict qh = y[i].qh;
+#if QK_K == 256
         for (int j = 0; j < QK_K; j += 128) {
             for (int l = 0; l < 32; ++l) {
                 const uint8_t q1 = L[j + l +  0] & 0xF;
@@ -836,6 +1046,16 @@ void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict
             ql += 64;
             qh += 32;
         }
+#else
+        for (int l = 0; l < 32; ++l) {
+            const uint8_t q1 = L[l +  0] & 0xF;
+            const uint8_t q2 = L[l + 32] & 0xF;
+            ql[l] = q1 | (q2 << 4);
+        }
+        for (int l = 0; l < 16; ++l) {
+            qh[l] = (L[l] >> 4) | ((L[l + 16] >> 4) << 2) | ((L[l + 32] >> 4) << 4) | ((L[l + 48] >> 4) << 6);
+        }
+#endif
 
         x += QK_K;
 
@@ -854,6 +1074,7 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int
         const uint8_t * restrict qh = x[i].qh;
         const int8_t  * restrict sc = x[i].scales;
 
+#if QK_K == 256
         for (int n = 0; n < QK_K; n += 128) {
             for (int l = 0; l < 32; ++l) {
                 int is = l/16;
@@ -871,6 +1092,19 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int
             qh += 32;
             sc += 8;
         }
+#else
+        for (int l = 0; l < 16; ++l) {
+            const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
+            const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
+            const int8_t q3 = (int8_t)((ql[l+ 0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
+            const int8_t q4 = (int8_t)((ql[l+16]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
+            y[l+ 0] = d * sc[0] * q1;
+            y[l+16] = d * sc[1] * q2;
+            y[l+32] = d * sc[2] * q3;
+            y[l+48] = d * sc[3] * q4;
+        }
+        y  += 64;
+#endif
 
     }
 }
@@ -1002,6 +1236,7 @@ static inline __m128i get_scale_shuffle(int i) {
 }
 #endif
 
+#if QK_K == 256
 void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
 
     const block_q2_K * restrict x = vx;
@@ -1201,6 +1436,168 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
 #endif
 }
 
+#else
+
+void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+
+    const block_q2_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#ifdef __ARM_NEON
+
+    const uint8x16_t m3 = vdupq_n_u8(0x3);
+    const int32x4_t  vzero = vdupq_n_s32(0);
+
+    int8x16x4_t q2bytes;
+
+    uint32_t aux32[2];
+    const uint8_t * scales = (const uint8_t *)aux32;
+
+    float sum = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * (float)x[i].d;
+        const float dmin = -y[i].d * (float)x[i].dmin;
+
+        const uint8_t * restrict q2 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+        const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
+
+        aux32[0] = sc[0] & 0x0f0f0f0f;
+        aux32[1] = (sc[0] >> 4) & 0x0f0f0f0f;
+
+        sum += dmin * (scales[4] * y[i].bsums[0] + scales[5] * y[i].bsums[1] + scales[6] * y[i].bsums[2] + scales[7] * y[i].bsums[3]);
+
+        int isum1 = 0, isum2 = 0;
+
+        const uint8x16_t q2bits = vld1q_u8(q2);
+
+        const int8x16x4_t q8bytes = vld1q_s8_x4(q8);
+
+        q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits, m3));
+        q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 2), m3));
+        q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 4), m3));
+        q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 6), m3));
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0];
+        isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1];
+        isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2];
+        isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3];
+#else
+        const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
+                                       vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));
+        const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
+                                       vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));
+        isum1 += vaddvq_s16(p1) * scales[0];
+        isum2 += vaddvq_s16(p2) * scales[1];
+
+        const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
+                                       vmull_s8(vget_high_s8(q2bytes.val[2]), vget_high_s8(q8bytes.val[2])));
+        const int16x8_t p4 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
+                                       vmull_s8(vget_high_s8(q2bytes.val[3]), vget_high_s8(q8bytes.val[3])));
+        isum1 += vaddvq_s16(p3) * scales[2];
+        isum2 += vaddvq_s16(p4) * scales[3];
+#endif
+        sum += d * (isum1 + isum2);
+
+    }
+
+    *s = sum;
+
+#elif defined __AVX2__
+
+    const __m256i m3 = _mm256_set1_epi8(3);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    uint32_t ud, um;
+    const uint8_t * restrict db = (const uint8_t *)&ud;
+    const uint8_t * restrict mb = (const uint8_t *)&um;
+
+    float summs = 0;
+
+    // TODO: optimize this
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+        const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+
+        const uint8_t * restrict q2 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
+        ud = (sc[0] >> 0) & 0x0f0f0f0f;
+        um = (sc[0] >> 4) & 0x0f0f0f0f;
+
+        int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3];
+        summs += dmin * smin;
+
+        const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2);
+        const __m256i q2_0 = _mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q2bits, 2), q2bits), m3);
+        const __m256i q2_1 = _mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q2bits, 6), _mm_srli_epi16(q2bits, 4)), m3);
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        const __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0);
+        const __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1);
+
+        const __m256i p_0 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 0));
+        const __m256i p_1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 1));
+        const __m256i p_2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 0));
+        const __m256i p_3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 1));
+
+        acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0), acc);
+        acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1), acc);
+        acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2), acc);
+        acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3), acc);
+    }
+
+    *s = hsum_float_8(acc) + summs;
+
+#else
+
+    float sumf = 0;
+
+    int isum[4];
+
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * q2 = x[i].qs;
+        const  int8_t * q8 = y[i].qs;
+        const uint8_t * sc = x[i].scales;
+
+        int summs = 0;
+        for (int j = 0; j < QK_K/16; ++j) {
+            summs += y[i].bsums[j] * (sc[j] >> 4);
+        }
+
+        const float dall = y[i].d * ggml_fp16_to_fp32(x[i].d);
+        const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+
+        isum[0] = isum[1] = isum[2] = isum[3] = 0;
+        for (int l =  0; l < 16; ++l) {
+            isum[0] += q8[l+ 0] * ((q2[l] >> 0) & 3);
+            isum[1] += q8[l+16] * ((q2[l] >> 2) & 3);
+            isum[2] += q8[l+32] * ((q2[l] >> 4) & 3);
+            isum[3] += q8[l+48] * ((q2[l] >> 6) & 3);
+        }
+        for (int l = 0; l < 4; ++l) {
+            isum[l] *= (sc[l] & 0xF);
+        }
+        sumf += dall * (isum[0] + isum[1] + isum[2] + isum[3]) - dmin * summs;
+    }
+    *s = sumf;
+#endif
+}
+#endif
+
+#if QK_K == 256
 void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
     assert(n % QK_K == 0);
 
@@ -1501,71 +1898,271 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
 
 }
 
-void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+#else
+
+void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
     assert(n % QK_K == 0);
 
-    const block_q4_K * restrict x = vx;
+    const block_q3_K * restrict x = vx;
     const block_q8_K * restrict y = vy;
 
     const int nb = n / QK_K;
 
-    static const uint32_t kmask1 = 0x3f3f3f3f;
-    static const uint32_t kmask2 = 0x0f0f0f0f;
-    static const uint32_t kmask3 = 0x03030303;
-
-    uint32_t utmp[4];
-
 #ifdef __ARM_NEON
 
-    const uint8x16_t m4b = vdupq_n_u8(0xf);
 #ifdef __ARM_FEATURE_DOTPROD
-    const int32x4_t mzero = vdupq_n_s32(0);
+    const int32x4_t  vzero = vdupq_n_s32(0);
 #endif
 
-    int8x16x2_t q4bytes;
-    int8x16x2_t q8bytes;
+    const uint8x16_t m3b = vdupq_n_u8(0x3);
+    const uint8x16_t mh  = vdupq_n_u8(4);
 
-    float sumf = 0;
+    int8x16x4_t q3bytes;
 
-    for (int i = 0; i < nb; ++i) {
+    uint16_t aux16[2];
+    int8_t * scales = (int8_t *)aux16;
 
-        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
-        const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+    float sum = 0;
 
-        const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
+    for (int i = 0; i < nb; ++i) {
 
-        memcpy(utmp, x[i].scales, 12);
+        uint8x16x4_t q3h;
 
-        const uint32x2_t mins8 = {utmp[1] & kmask1, ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4)};
-        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
-        utmp[0] &= kmask1;
+        const uint8x8_t  hbits    = vld1_u8(x[i].hmask);
+        const uint8x16_t q3bits   = vld1q_u8(x[i].qs);
+        const int8x16x4_t q8bytes = vld1q_s8_x4(y[i].qs);
 
-        const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
-        const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
-                                         vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
-        sumf -= dmin * vaddvq_s32(prod);
+        const uint16_t a = *(const uint16_t *)x[i].scales;
+        aux16[0] = a & 0x0f0f;
+        aux16[1] = (a >> 4) & 0x0f0f;
 
-        const uint8_t * scales = (const uint8_t *)utmp;
+        for (int j = 0; j < 4; ++j) scales[j] -= 8;
 
-        const uint8_t * restrict q4 = x[i].qs;
-        const int8_t  * restrict q8 = y[i].qs;
+        int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]);
 
-        //int32x4_t isum = mzero;
+        const float d = y[i].d * (float)x[i].d;
 
-        int32_t sumi1 = 0;
-        int32_t sumi2 = 0;
+        const uint8x16_t htmp = vcombine_u8(hbits, vshr_n_u8(hbits, 1));
+        q3h.val[0] = vandq_u8(mh, vshlq_n_u8(htmp, 2));
+        q3h.val[1] = vandq_u8(mh, htmp);
+        q3h.val[2] = vandq_u8(mh, vshrq_n_u8(htmp, 2));
+        q3h.val[3] = vandq_u8(mh, vshrq_n_u8(htmp, 4));
 
-        for (int j = 0; j < QK_K/64; ++j) {
+        q3bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q3bits, m3b),                q3h.val[0]));
+        q3bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 2), m3b), q3h.val[1]));
+        q3bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2]));
+        q3bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q3bits, 6),                q3h.val[3]));
 
-            const uint8x16x2_t q4bits = vld1q_u8_x2(q4); q4 += 32;
+#if defined(__ARM_FEATURE_DOTPROD)
+        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0];
+        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2];
+        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1];
+        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3];
+#else
+        const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
+                                       vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes.val[0])));
+        const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
+                                       vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes.val[1])));
+        const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
+                                       vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes.val[2])));
+        const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
+                                       vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes.val[3])));
+        isum += vaddvq_s16(p0) * scales[0] + vaddvq_s16(p1) * scales[2] + vaddvq_s16(p2) * scales[1] + vaddvq_s16(p3) * scales[3];
+#endif
 
-#ifdef __ARM_FEATURE_DOTPROD
-            q8bytes = vld1q_s8_x2(q8); q8 += 32;
-            q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[0], m4b));
-            q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[1], m4b));
+        sum += d * isum;
 
-            const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
-            sumi1 += vaddvq_s32(p1) * scales[2*j+0];
+    }
+
+    *s = sum;
+
+#elif defined __AVX2__
+
+    const __m256i m3 = _mm256_set1_epi8(3);
+    const __m256i m1 = _mm256_set1_epi8(1);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    uint64_t aux64;
+
+    uint16_t aux16[2];
+    const int8_t * aux8 = (const int8_t *)aux16;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+
+        const uint8_t * restrict q3 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const uint16_t a = *(const uint16_t *)x[i].scales;
+        aux16[0] = a & 0x0f0f;
+        aux16[1] = (a >> 4) & 0x0f0f;
+
+        const __m256i scale_0 = _mm256_set_m128i(_mm_set1_epi16(aux8[2] - 8), _mm_set1_epi16(aux8[0] - 8));
+        const __m256i scale_1 = _mm256_set_m128i(_mm_set1_epi16(aux8[3] - 8), _mm_set1_epi16(aux8[1] - 8));
+
+        memcpy(&aux64, x[i].hmask, 8);
+
+        const __m128i haux = _mm_set_epi64x(aux64 >> 1, aux64 >> 0);
+        __m256i q3h_0 = _mm256_set_m128i(_mm_srli_epi16(haux, 2), haux);
+        __m256i q3h_1 = _mm256_srli_epi16(q3h_0, 4);
+        q3h_0 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_0, m1), 2);
+        q3h_1 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_1, m1), 2);
+
+        // load low 2 bits
+        const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3);
+
+        // prepare low and high bits
+        const __m256i q3aux  = _mm256_set_m128i(_mm_srli_epi16(q3bits, 2), q3bits);
+        const __m256i q3l_0 = _mm256_and_si256(q3aux, m3);
+        const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3aux, 4), m3);
+
+        // load Q8 quants
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16,
+        // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
+        // and 2 if the high bit was set)
+        const __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0);
+        const __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1);
+
+        __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0);
+        __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1);
+
+        p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
+        p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
+
+        // multiply with scales
+        p16_0 = _mm256_madd_epi16(scale_0, p16_0);
+        p16_1 = _mm256_madd_epi16(scale_1, p16_1);
+
+        p16_0 = _mm256_add_epi32(p16_0, p16_1);
+
+        // multiply with block scale and accumulate
+        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16_0), acc);
+
+    }
+
+    *s = hsum_float_8(acc);
+
+#else
+
+    int8_t  aux8[QK_K];
+    int16_t aux16[8];
+    float   sums [8];
+    int32_t aux32[8];
+    int32_t scales[4];
+    memset(sums, 0, 8*sizeof(float));
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * restrict q3 = x[i].qs;
+        const uint8_t * restrict hm = x[i].hmask;
+        const  int8_t * restrict q8 = y[i].qs;
+        int8_t * restrict a = aux8;
+        for (int l = 0; l < 8; ++l) {
+            a[l+ 0] = (int8_t)((q3[l+0] >> 0) & 3) - (hm[l] & 0x01 ? 0 : 4);
+            a[l+ 8] = (int8_t)((q3[l+8] >> 0) & 3) - (hm[l] & 0x02 ? 0 : 4);
+            a[l+16] = (int8_t)((q3[l+0] >> 2) & 3) - (hm[l] & 0x04 ? 0 : 4);
+            a[l+24] = (int8_t)((q3[l+8] >> 2) & 3) - (hm[l] & 0x08 ? 0 : 4);
+            a[l+32] = (int8_t)((q3[l+0] >> 4) & 3) - (hm[l] & 0x10 ? 0 : 4);
+            a[l+40] = (int8_t)((q3[l+8] >> 4) & 3) - (hm[l] & 0x20 ? 0 : 4);
+            a[l+48] = (int8_t)((q3[l+0] >> 6) & 3) - (hm[l] & 0x40 ? 0 : 4);
+            a[l+56] = (int8_t)((q3[l+8] >> 6) & 3) - (hm[l] & 0x80 ? 0 : 4);
+        }
+
+        scales[0] = (x[i].scales[0] & 0xF) - 8;
+        scales[1] = (x[i].scales[0] >>  4) - 8;
+        scales[2] = (x[i].scales[1] & 0xF) - 8;
+        scales[3] = (x[i].scales[1] >>  4) - 8;
+
+        memset(aux32, 0, 8*sizeof(int32_t));
+        for (int j = 0; j < QK_K/16; ++j) {
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] += q8[l] * a[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux32[l] += scales[j] * aux16[l];
+        }
+        const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
+        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
+    }
+    for (int l = 0; l < 8; ++l) sumf += sums[l];
+    *s = sumf;
+
+#endif
+
+}
+#endif
+
+#if QK_K == 256
+void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    assert(n % QK_K == 0);
+
+    const block_q4_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+    static const uint32_t kmask1 = 0x3f3f3f3f;
+    static const uint32_t kmask2 = 0x0f0f0f0f;
+    static const uint32_t kmask3 = 0x03030303;
+
+    uint32_t utmp[4];
+
+#ifdef __ARM_NEON
+
+    const uint8x16_t m4b = vdupq_n_u8(0xf);
+#ifdef __ARM_FEATURE_DOTPROD
+    const int32x4_t mzero = vdupq_n_s32(0);
+#endif
+
+    int8x16x2_t q4bytes;
+    int8x16x2_t q8bytes;
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+        const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+
+        const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
+
+        memcpy(utmp, x[i].scales, 12);
+
+        const uint32x2_t mins8 = {utmp[1] & kmask1, ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4)};
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[0] &= kmask1;
+
+        const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
+        const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
+                                         vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
+        sumf -= dmin * vaddvq_s32(prod);
+
+        const uint8_t * scales = (const uint8_t *)utmp;
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        //int32x4_t isum = mzero;
+
+        int32_t sumi1 = 0;
+        int32_t sumi2 = 0;
+
+        for (int j = 0; j < QK_K/64; ++j) {
+
+            const uint8x16x2_t q4bits = vld1q_u8_x2(q4); q4 += 32;
+
+#ifdef __ARM_FEATURE_DOTPROD
+            q8bytes = vld1q_s8_x2(q8); q8 += 32;
+            q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[0], m4b));
+            q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[1], m4b));
+
+            const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
+            sumi1 += vaddvq_s32(p1) * scales[2*j+0];
 
             q8bytes = vld1q_s8_x2(q8); q8 += 32;
             q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
@@ -1614,9 +2211,6 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
         const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
         const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
 
-        const uint8_t * restrict q4 = x[i].qs;
-        const int8_t  * restrict q8 = y[i].qs;
-
         memcpy(utmp, x[i].scales, 12);
         utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
         const uint32_t uaux = utmp[1] & kmask1;
@@ -1624,6 +2218,9 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
         utmp[2] = uaux;
         utmp[0] &= kmask1;
 
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
         const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
 
         const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
@@ -1726,7 +2323,176 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
     *s = sumf;
 #endif
 }
+#else
+void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    assert(n % QK_K == 0);
+
+    const block_q4_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#ifdef __ARM_NEON
+
+    const uint8x16_t m4b = vdupq_n_u8(0xf);
+
+#ifdef __ARM_FEATURE_DOTPROD
+    const int32x4_t mzero = vdupq_n_s32(0);
+#endif
+
+    float sumf = 0;
+
+    int8x16x2_t q4bytes;
+    int8x16x4_t q8bytes;
+
+    float sum_mins = 0.f;
+
+    uint16_t aux16[2];
+    const uint8_t * restrict scales = (const uint8_t *)aux16;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const uint16_t * restrict a = (const uint16_t *)x[i].scales;
+        aux16[0] = a[0] & 0x0f0f;
+        aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+        const int32_t summi = scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]);
+        sum_mins += y[i].d * (float)x[i].d[1] * summi;
+
+        const float d = y[i].d * (float)x[i].d[0];
 
+        const uint8x16x2_t q4bits = vld1q_u8_x2(q4);
+
+#ifdef __ARM_FEATURE_DOTPROD
+        q8bytes = vld1q_s8_x4(q8);
+        q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[0], m4b));
+        q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[1], m4b));
+
+        const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
+        const int32_t sumi1 = vaddvq_s32(p1) * scales[0];
+
+        q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
+        q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
+
+        const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]);
+        const int32_t sumi2 = vaddvq_s32(p2) * scales[1];
+
+#else
+        q8bytes = vld1q_s8_x4(q8);
+        q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[0], m4b));
+        q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[1], m4b));
+        const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
+                                       vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
+        const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
+                                       vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
+        int32_t sumi1 = vaddvq_s16(vaddq_s16(p0, p1)) * scales[0];
+
+        q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
+        q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
+        const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[2])),
+                                       vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[2])));
+        const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[3])),
+                                       vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[3])));
+        int32_t sumi2 = vaddvq_s16(vaddq_s16(p2, p3)) * scales[1];
+
+#endif
+        sumf += d * (sumi1 + sumi2);
+
+    }
+
+    *s = sumf - sum_mins;
+
+#elif defined __AVX2__
+
+    const __m256i m4 = _mm256_set1_epi8(0xF);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    float summs = 0;
+
+    uint16_t aux16[2];
+    const uint8_t * scales = (const uint8_t *)aux16;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = ggml_fp16_to_fp32(x[i].d[0]) * y[i].d;
+        const float m = ggml_fp16_to_fp32(x[i].d[1]) * y[i].d;
+        const __m256 vd = _mm256_set1_ps(d);
+
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        aux16[0] = a[0] & 0x0f0f;
+        aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+        summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4);
+        const __m256i q4l = _mm256_and_si256(q4bits, m4);
+        const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);
+
+        const __m256i q8l = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8h = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        const __m256i p16l = _mm256_maddubs_epi16(q4l, q8l);
+        const __m256i p16h = _mm256_maddubs_epi16(q4h, q8h);
+
+        const __m256i p32l = _mm256_madd_epi16(_mm256_set1_epi16(scales[0]), p16l);
+        acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32l), acc);
+
+        const __m256i p32h = _mm256_madd_epi16(_mm256_set1_epi16(scales[1]), p16h);
+        acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32h), acc);
+
+    }
+
+    *s = hsum_float_8(acc) - summs;
+
+#else
+
+    uint8_t aux8[QK_K];
+    int16_t aux16[16];
+    float   sums [8];
+    memset(sums, 0, 8*sizeof(float));
+
+    uint16_t s16[2];
+    const uint8_t * restrict scales = (const uint8_t *)s16;
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * restrict q4 = x[i].qs;
+        const  int8_t * restrict q8 = y[i].qs;
+        uint8_t * restrict a = aux8;
+        for (int l = 0; l < 32; ++l) a[l+ 0] = q4[l] & 0xF;
+        for (int l = 0; l < 32; ++l) a[l+32] = q4[l]  >> 4;
+
+        const uint16_t * restrict b = (const uint16_t *)x[i].scales;
+        s16[0] = b[0] & 0x0f0f;
+        s16[1] = (b[0] >> 4) & 0x0f0f;
+
+        sumf -= y[i].d * ggml_fp16_to_fp32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d[0]);
+
+        for (int j = 0; j < QK_K/32; ++j) {
+            for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l];
+            q8 += 16; a += 16;
+            for (int l = 0; l < 16; ++l) aux16[l] += q8[l] * a[l];
+            q8 += 16; a += 16;
+            const float dl = d * scales[j];
+            for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[l+8]);
+        }
+    }
+    for (int l = 0; l < 8; ++l) sumf += sums[l];
+    *s = sumf;
+#endif
+}
+#endif
+
+#if QK_K == 256
 void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
     assert(n % QK_K == 0);
 
@@ -1840,18 +2606,23 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
 
    for (int i = 0; i < nb; ++i) {
 
-        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
-        const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
-
         const uint8_t * restrict q5 = x[i].qs;
         const int8_t  * restrict q8 = y[i].qs;
 
+#if QK_K == 256
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+        const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+
         memcpy(utmp, x[i].scales, 12);
         utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
         const uint32_t uaux = utmp[1] & kmask1;
         utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
         utmp[2] = uaux;
         utmp[0] &= kmask1;
+#else
+        // TODO
+        const float d = 0, dmin = 0;
+#endif
 
         const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
 
@@ -1972,8 +2743,169 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
 #endif
 }
 
+#else
+
+void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    assert(n % QK_K == 0);
+
+    const block_q5_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#ifdef __ARM_NEON
+
+    const uint8x16_t m4b = vdupq_n_u8(0xf);
+    const int32x4_t mzero = vdupq_n_s32(0);
+    const uint8x16_t mh = vdupq_n_u8(16);
+
+    int8x16x4_t q5bytes;
+    uint8x16x4_t q5h;
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * (float)x[i].d;
+        const int8_t * sc = x[i].scales;
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const uint8x8_t qhbits = vld1_u8(qh);
+
+        const uint8x16x2_t q5bits = vld1q_u8_x2(q5);
+        const int8x16x4_t q8bytes = vld1q_s8_x4(q8);
+
+        const uint8x16_t htmp = vcombine_u8(qhbits, vshr_n_u8(qhbits, 1));
+        q5h.val[0] = vbicq_u8(mh, vshlq_n_u8(htmp, 4));
+        q5h.val[1] = vbicq_u8(mh, vshlq_n_u8(htmp, 2));
+        q5h.val[2] = vbicq_u8(mh, htmp);
+        q5h.val[3] = vbicq_u8(mh, vshrq_n_u8(htmp, 2));
+
+        q5bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[0], m4b)), vreinterpretq_s8_u8(q5h.val[0]));
+        q5bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[1], m4b)), vreinterpretq_s8_u8(q5h.val[1]));
+        q5bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[0], 4)), vreinterpretq_s8_u8(q5h.val[2]));
+        q5bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[1], 4)), vreinterpretq_s8_u8(q5h.val[3]));
+
+#if defined(__ARM_FEATURE_DOTPROD)
+
+        int32_t sumi1 = sc[0] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]));
+        int32_t sumi2 = sc[1] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1]));
+        int32_t sumi3 = sc[2] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]));
+        int32_t sumi4 = sc[3] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3]));
+
+        sumf += d * (sumi1 + sumi2 + sumi3 + sumi4);
+
+#else
+
+        const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
+                                       vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0])));
+        const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
+                                       vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1])));
+        int32_t sumi = sc[0] * vaddvq_s16(p0) + sc[1] * vaddvq_s16(p1);
+
+        const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
+                                       vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2])));
+        const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
+                                       vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3])));
+        sumi += sc[2] * vaddvq_s16(p2) + sc[3] * vaddvq_s16(p3);
+
+        sumf += d*sumi;
+#endif
+
+    }
+
+    *s = sumf;
+
+#elif defined __AVX2__
+
+    const __m256i m4 = _mm256_set1_epi8(0xF);
+    const __m256i mone  = _mm256_set1_epi8(1);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+
+        const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5);
+
+        const __m256i scale_l = _mm256_set_m128i(_mm_set1_epi16(x[i].scales[1]), _mm_set1_epi16(x[i].scales[0]));
+        const __m256i scale_h = _mm256_set_m128i(_mm_set1_epi16(x[i].scales[3]), _mm_set1_epi16(x[i].scales[2]));
+
+        int64_t aux64;
+        memcpy(&aux64, x[i].qh, 8);
+        const __m128i haux128 = _mm_set_epi64x(aux64 >> 1, aux64);
+        const __m256i haux256 = _mm256_set_m128i(_mm_srli_epi16(haux128, 2), haux128);
+
+        const __m256i q5h_0 = _mm256_slli_epi16(_mm256_andnot_si256(haux256, mone), 4);
+        const __m256i q5h_1 = _mm256_slli_epi16(_mm256_andnot_si256(_mm256_srli_epi16(haux256, 4), mone), 4);
+
+        const __m256i q5l_0 = _mm256_and_si256(q5bits, m4);
+        const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        const __m256i p16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5l_0, q8_0));
+        const __m256i p16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5l_1, q8_1));
+        const __m256i s16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5h_0, q8_0));
+        const __m256i s16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5h_1, q8_1));
+
+        const __m256i dot = _mm256_sub_epi32(_mm256_add_epi32(p16_0, p16_1), _mm256_add_epi32(s16_0, s16_1));
+
+        acc = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(dot), acc);
+
+    }
+
+    *s = hsum_float_8(acc);
+
+#else
+
+
+    uint8_t aux8[QK_K];
+    int16_t aux16[16];
+    float   sums [8];
+    memset(sums, 0, 8*sizeof(float));
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * restrict q4 = x[i].qs;
+        const uint8_t * restrict hm = x[i].qh;
+        const  int8_t * restrict q8 = y[i].qs;
+        uint8_t * restrict a = aux8;
+        for (int l = 0; l < 32; ++l) {
+            a[l+ 0] = q4[l] & 0xF;
+            a[l+32] = q4[l]  >> 4;
+        }
+        for (int is = 0; is < 8; ++is) {
+            uint8_t m = 1 << is;
+            for (int l = 0; l < 8; ++l) a[8*is + l] -= (hm[l] & m ? 0 : 16);
+        }
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+        const int8_t * restrict sc = x[i].scales;
+
+        for (int j = 0; j < QK_K/16; ++j) {
+            const float dl = d * sc[j];
+            for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l <  8; ++l) sums[l] += dl * (aux16[l] + aux16[8+l]);
+            q8 += 16; a += 16;
+        }
+    }
+    for (int l = 0; l < 8; ++l) sumf += sums[l];
+    *s = sumf;
+#endif
+}
+#endif
 
 
+#if QK_K == 256
 void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
     assert(n % QK_K == 0);
 
@@ -2242,3 +3174,179 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
     *s = sumf;
 #endif
 }
+
+#else
+
+void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    assert(n % QK_K == 0);
+
+    const block_q6_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#ifdef __ARM_NEON
+
+    float sum = 0;
+
+    const uint8x16_t m4b = vdupq_n_u8(0xF);
+    const int32x4_t  vzero = vdupq_n_s32(0);
+    const int8x16_t  m32s = vdupq_n_s8(32);
+
+    const uint8x16_t mone = vdupq_n_u8(3);
+
+    int8x16x4_t q6bytes;
+    uint8x16x4_t q6h;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d_all = (float)x[i].d;
+
+        const uint8_t * restrict q6 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const int8_t * restrict scale = x[i].scales;
+
+        int32_t isum = 0;
+
+        uint8x16_t   qhbits = vld1q_u8(qh);
+        uint8x16x2_t q6bits = vld1q_u8_x2(q6);
+        int8x16x4_t q8bytes = vld1q_s8_x4(q8);
+
+        q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits), 4);
+        uint8x16_t shifted = vshrq_n_u8(qhbits, 2);
+        q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+        shifted = vshrq_n_u8(qhbits, 4);
+        q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+        shifted = vshrq_n_u8(qhbits, 6);
+        q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+
+        q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s);
+        q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s);
+        q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[2])), m32s);
+        q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[3])), m32s);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+
+        isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
+                vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
+                vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
+                vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
+#else
+
+        int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
+                                 vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
+        int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
+                                 vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
+        isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
+
+        int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
+                                 vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
+        int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
+                                 vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
+        isum += vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
+#endif
+
+        sum += isum * d_all * y[i].d;
+
+    }
+    *s = sum;
+
+#elif defined __AVX2__
+
+    const __m256i m4 = _mm256_set1_epi8(0xF);
+    const __m256i m2 = _mm256_set1_epi8(3);
+    const __m256i m32s = _mm256_set1_epi8(32);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+
+        const uint8_t * restrict q4 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]);
+        const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]);
+        const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]);
+        const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]);
+
+        __m256i sumi = _mm256_setzero_si256();
+
+        const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1);
+        const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3);
+
+        const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4);
+        const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh);
+
+        const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q4bitsH, 2), q4bitsH), m2), 4);
+        const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q4bitsH, 6), _mm_srli_epi16(q4bitsH, 4)), m2), 4);
+
+        const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
+        const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_1);
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0);
+        __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1);
+
+        __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0);
+        __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1);
+
+        p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
+        p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
+
+        p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0);
+        p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1);
+
+        sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
+
+        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
+    }
+
+    *s = hsum_float_8(acc);
+
+#else
+
+    int8_t  aux8[QK_K];
+    int16_t aux16[8];
+    float   sums [8];
+    int32_t aux32[8];
+    memset(sums, 0, 8*sizeof(float));
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * restrict q4 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const  int8_t * restrict q8 = y[i].qs;
+        memset(aux32, 0, 8*sizeof(int32_t));
+        int8_t * restrict a = aux8;
+        for (int l = 0; l < 16; ++l) {
+            a[l+ 0] = (int8_t)((q4[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
+            a[l+16] = (int8_t)((q4[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
+            a[l+32] = (int8_t)((q4[l+ 0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;
+            a[l+48] = (int8_t)((q4[l+16] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;
+        }
+        int is = 0;
+        for (int j = 0; j < QK_K/16; ++j) {
+            int scale = x[i].scales[is++];
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+        }
+        const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
+        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
+    }
+    for (int l = 0; l < 8; ++l) sumf += sums[l];
+    *s = sumf;
+#endif
+}
+
+#endif
index 10a0baac73a078f459ce0c21302bf79bd796cec2..6abe3d7b8f42201f655388c29360cf721619746e 100644 (file)
@@ -7,7 +7,13 @@
 #include <stddef.h>
 
 // Super-block size
+#ifdef GGML_QKK_64
+#define QK_K 64
+#define K_SCALE_SIZE 4
+#else
 #define QK_K 256
+#define K_SCALE_SIZE 12
+#endif
 
 //
 // Super-block quantization structures
@@ -29,38 +35,67 @@ static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "w
 // weight is represented as x = a * q
 // 16 blocks of 16 elemenets each
 // Effectively 3.4375 bits per weight
+#ifdef GGML_QKK_64
 typedef struct {
     uint8_t hmask[QK_K/8];     // quants - high bit
     uint8_t qs[QK_K/4];        // quants - low 2 bits
-    uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
+    uint8_t scales[2];
     ggml_fp16_t d;             // super-block scale
 } block_q3_K;
-static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_K block size/padding");
+static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 2, "wrong q3_K block size/padding");
+#else
+typedef struct {
+    uint8_t hmask[QK_K/8];     // quants - high bit
+    uint8_t qs[QK_K/4];        // quants - low 2 bits
+    uint8_t scales[12];        // scales, quantized with 6 bits
+    ggml_fp16_t d;             // super-block scale
+} block_q3_K;
+static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 12, "wrong q3_K block size/padding");
+#endif
 
 // 4-bit quantization
 // 16 blocks of 32 elements each
 // weight is represented as x = a * q + b
 // Effectively 4.5 bits per weight
+#ifdef GGML_QKK_64
+typedef struct {
+    ggml_fp16_t d[2];          // super-block scales/mins
+    uint8_t scales[2];         // 4-bit block scales/mins
+    uint8_t qs[QK_K/2];        // 4--bit quants
+} block_q4_K;
+static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding");
+#else
 typedef struct {
     ggml_fp16_t d;             // super-block scale for quantized scales
     ggml_fp16_t dmin;          // super-block scale for quantized mins
-    uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
+    uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
     uint8_t qs[QK_K/2];        // 4--bit quants
 } block_q4_K;
-static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
+static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding");
+#endif
 
 // 5-bit quantization
 // 16 blocks of 32 elements each
 // weight is represented as x = a * q + b
 // Effectively 5.5 bits per weight
+#ifdef GGML_QKK_64
+typedef struct {
+    ggml_fp16_t d;               // super-block scale
+    int8_t  scales[QK_K/16];     // 8-bit block scales
+    uint8_t qh[QK_K/8];          // quants, high bit
+    uint8_t qs[QK_K/2];          // quants, low 4 bits
+} block_q5_K;
+static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding");
+#else
 typedef struct {
     ggml_fp16_t d;               // super-block scale for quantized scales
     ggml_fp16_t dmin;            // super-block scale for quantized mins
-    uint8_t scales[3*QK_K/64];   // scales and mins, quantized with 6 bits
+    uint8_t scales[K_SCALE_SIZE];   // scales and mins, quantized with 6 bits
     uint8_t qh[QK_K/8];          // quants, high bit
     uint8_t qs[QK_K/2];          // quants, low 4 bits
 } block_q5_K;
-static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
+static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
+#endif
 
 // 6-bit quantization
 // weight is represented as x = a * q
index ac22a48f8ab971781b2bf05ce9d9616634a4160b..c41c2a8a3299224f394c92a72d9aeebf8a5e4023 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
 #endif
 #ifdef GGML_USE_K_QUANTS
 #ifndef QK_K
+#ifdef GGML_QKK_64
+#define QK_K 64
+#else
 #define QK_K 256
 #endif
 #endif
+#endif
 
 #include <array>
 #include <ctime>
@@ -2470,6 +2474,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     std::vector<std::thread> workers;
     std::mutex mutex;
 
+    auto use_more_bits = [] (int i_layer, int num_layers) -> bool {
+        return i_layer < num_layers/8 || i_layer >= 7*num_layers/8 || (i_layer - num_layers/8)%3 == 2;
+    };
+
     size_t idx = 0;
     for (llama_load_tensor & tensor : model_loader->tensors_map.tensors) {
         llama_buffer read_data;
@@ -2524,15 +2532,16 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
                 if      (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K;
                 else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
                 else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
-                         (i_attention_wv < n_attention_wv/8 || i_attention_wv >= 7*n_attention_wv/8 ||
-                         (i_attention_wv - n_attention_wv/8)%3 == 2)) new_type = GGML_TYPE_Q6_K;
+                        use_more_bits(i_attention_wv, n_attention_wv)) new_type = GGML_TYPE_Q6_K;
+                else if (QK_K == 64 && (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S) &&
+                        (i_attention_wv < n_attention_wv/8 || i_attention_wv >= 7*n_attention_wv/8)) new_type = GGML_TYPE_Q6_K;
                 ++i_attention_wv;
             } else if (tensor.name.find("feed_forward.w2.weight") != std::string::npos) {
                 if      (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K;
                 else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
                 else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
-                         (i_feed_forward_w2 < n_feed_forward_w2/8 || i_feed_forward_w2 >= 7*n_feed_forward_w2/8 ||
-                         (i_feed_forward_w2 - n_feed_forward_w2/8)%3 == 2)) new_type = GGML_TYPE_Q6_K;
+                         use_more_bits(i_feed_forward_w2, n_feed_forward_w2)) new_type = GGML_TYPE_Q6_K;
+                //else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && i_feed_forward_w2 < n_feed_forward_w2/8) new_type = GGML_TYPE_Q6_K;
                 ++i_feed_forward_w2;
             } else if (tensor.name.find("attention.wo.weight") != std::string::npos) {
                 if      (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K;