]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA : faster k-quant dot kernels (#1862)
authorKawrakow <redacted>
Fri, 16 Jun 2023 17:08:44 +0000 (20:08 +0300)
committerGitHub <redacted>
Fri, 16 Jun 2023 17:08:44 +0000 (20:08 +0300)
* cuda : faster k-quant dot kernels

* Imrove Q2_K dot kernel on older GPUs

We now have a K_QUANTS_PER_ITERATION macro, which should be
set to 1 on older and to 2 on newer GPUs.
With this, we preserve the performance of the original
PR on RTX-4080, and are faster compared to master on
GTX-1660.

* Imrove Q6_K dot kernel on older GPUs

Using the same K_QUANTS_PER_ITERATION macro as last commit,
we preserve performance on RTX-4080 and speed up
Q6_K on a GTX-1660.

* Add LLAMA_CUDA_KQUANTS_ITER to CMakeLists.txt and Makefile

Allowed values are 1 or 2. 2 gives the best performance on
modern GPUs and is set as default. On older GPUs 1 may work
better.

* PR comments

---------

Co-authored-by: Iwan Kawrakow <redacted>
CMakeLists.txt
Makefile
ggml-cuda.cu

index ea9f80b803c1d05a5b95de29dc6144d42497147d..dbbc0b5d3b8201e10d2916c181e0a54f86892fdc 100644 (file)
@@ -70,6 +70,7 @@ set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor")
 option(LLAMA_CUBLAS                          "llama: use cuBLAS"                                OFF)
 set(LLAMA_CUDA_DMMV_X      "32" CACHE STRING "llama: x stride for dmmv CUDA kernels")
 set(LLAMA_CUDA_DMMV_Y       "1" CACHE STRING "llama: y block size for dmmv CUDA kernels")
+set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K")
 option(LLAMA_CLBLAST                         "llama: use CLBlast"                               OFF)
 option(LLAMA_METAL                           "llama: use Metal"                                 OFF)
 option(LLAMA_K_QUANTS                        "llama: use k-quants"                              ON)
@@ -201,6 +202,7 @@ if (LLAMA_CUBLAS)
         add_compile_definitions(GGML_USE_CUBLAS)
         add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
         add_compile_definitions(GGML_CUDA_DMMV_Y=${LLAMA_CUDA_DMMV_Y})
+        add_compile_definitions(K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER})
 
         if (LLAMA_STATIC)
             set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
index 09c8834f5830d9cfe716ab0a49ad96c50bccae92..b24caf8dd60e52b0f57b5a4562e1dd32648eb181 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -171,6 +171,11 @@ ifdef LLAMA_CUDA_DMMV_Y
 else
        NVCCFLAGS += -DGGML_CUDA_DMMV_Y=1
 endif # LLAMA_CUDA_DMMV_Y
+ifdef LLAMA_CUDA_KQUANTS_ITER
+       NVCCFLAGS += -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER)
+else
+       NVCCFLAGS += -DK_QUANTS_PER_ITERATION=2
+endif
 ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
        $(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
 endif # LLAMA_CUBLAS
index bd89d0a1f79d967e4bff6e6aac6f6c7064e65aa0..7edd1a9f8ef0a95220ebd2e3d16821825b403e50 100644 (file)
@@ -167,6 +167,12 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define GGML_CUDA_DMMV_Y 1
 #endif
 
+#ifndef K_QUANTS_PER_ITERATION
+#define K_QUANTS_PER_ITERATION 2
+#else
+static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
+#endif
+
 static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -326,37 +332,6 @@ static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
 
 }
 
-static __device__ void vec_dot_q2_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) {
-
-    const block_q2_K * x = (const block_q2_K *) vx;
-
-    // if n is 0, we want to do the lower 128, else the upper 128,
-    // covering y[l+0],  y[l+32], y[l+64], y[l+96] and
-    //          y[l+16], y[l+48], y[l+80], y[l+112]
-    int n = iqs/128;                // 0 or 1
-    int r = iqs - 128*n;            // 0...120 in steps of 8
-    int l = r/8;                    // 0...15 in steps of 1
-
-    const float   * y = yy + 128*n + l;
-    const uint8_t * q = x[ib].qs + 32*n + l;
-    const uint8_t * s = x[ib].scales + 8*n;
-
-    const float dall = x[ib].d;
-    const float dmin = x[ib].dmin;
-
-    float sum = y[  0] * (dall * ((s[0] & 0xF) * ((q[ 0] >> 0) & 3)) - dmin * (s[0] >> 4))
-              + y[ 32] * (dall * ((s[2] & 0xF) * ((q[ 0] >> 2) & 3)) - dmin * (s[2] >> 4))
-              + y[ 64] * (dall * ((s[4] & 0xF) * ((q[ 0] >> 4) & 3)) - dmin * (s[4] >> 4))
-              + y[ 96] * (dall * ((s[6] & 0xF) * ((q[ 0] >> 6) & 3)) - dmin * (s[6] >> 4))
-              + y[ 16] * (dall * ((s[1] & 0xF) * ((q[16] >> 0) & 3)) - dmin * (s[1] >> 4))
-              + y[ 48] * (dall * ((s[3] & 0xF) * ((q[16] >> 2) & 3)) - dmin * (s[3] >> 4))
-              + y[ 80] * (dall * ((s[5] & 0xF) * ((q[16] >> 4) & 3)) - dmin * (s[5] >> 4))
-              + y[112] * (dall * ((s[7] & 0xF) * ((q[16] >> 6) & 3)) - dmin * (s[7] >> 4));
-
-    result = sum;
-
-}
-
 static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
 
     int r = threadIdx.x/4;
@@ -388,51 +363,6 @@ static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
 
 }
 
-static __device__ void vec_dot_q3_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) {
-
-    const block_q3_K * x = (const block_q3_K *) vx;
-
-    const uint32_t kmask1 = 0x03030303;
-    const uint32_t kmask2 = 0x0f0f0f0f;
-
-    uint32_t aux[3];
-    uint32_t utmp[4];
-
-    // if n is 0, we want to do the lower 128, else the upper 128,
-    // covering y[l+0],  y[l+32], y[l+64], y[l+96] and
-    //          y[l+16], y[l+48], y[l+80], y[l+112]
-    int n = iqs/128;                // 0 or 1
-    int r = iqs - 128*n;            // 0...120 in steps of 8
-    int l = r/8;                    // 0...15 in steps of 1
-
-    const float   * y = yy + 128*n + l;
-    const uint8_t * q = x[ib].qs + 32*n + l;
-    const uint8_t * hm = x[ib].hmask + l;
-    const int8_t  * s = (const int8_t *)utmp + 8*n;
-
-    memcpy(aux, x[ib].scales, 12);
-    utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
-    utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
-    utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
-    utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
-
-    const float dall = x[ib].d;
-
-    const uint8_t m = 1 << (4*n);
-
-    float sum = y[  0] * (s[0] - 32) * (((q[ 0] >> 0) & 3) - (hm[ 0] & (m << 0) ? 0 : 4))
-              + y[ 32] * (s[2] - 32) * (((q[ 0] >> 2) & 3) - (hm[ 0] & (m << 1) ? 0 : 4))
-              + y[ 64] * (s[4] - 32) * (((q[ 0] >> 4) & 3) - (hm[ 0] & (m << 2) ? 0 : 4))
-              + y[ 96] * (s[6] - 32) * (((q[ 0] >> 6) & 3) - (hm[ 0] & (m << 3) ? 0 : 4))
-              + y[ 16] * (s[1] - 32) * (((q[16] >> 0) & 3) - (hm[16] & (m << 0) ? 0 : 4))
-              + y[ 48] * (s[3] - 32) * (((q[16] >> 2) & 3) - (hm[16] & (m << 1) ? 0 : 4))
-              + y[ 80] * (s[5] - 32) * (((q[16] >> 4) & 3) - (hm[16] & (m << 2) ? 0 : 4))
-              + y[112] * (s[7] - 32) * (((q[16] >> 6) & 3) - (hm[16] & (m << 3) ? 0 : 4));
-
-    result = sum * dall;
-
-}
-
 static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
     if (j < 4) {
         d = q[j] & 63; m = q[j + 4] & 63;
@@ -479,38 +409,6 @@ static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
     }
 }
 
-static __device__ void vec_dot_q4_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) {
-
-    const block_q4_K * x = (const block_q4_K *) vx;
-
-                                    // iqs is in 0...248 in steps of 8 =>
-    const int j  = iqs / 64;        // j  is in 0...3
-    const int ir = (iqs - 64*j)/2;  // ir is in 0...28 in steps of 4
-    const int is = 2*j;             // is is in 0...6 in steps of 2
-
-    const float   * y = yy + 64*j + ir;
-    const uint8_t * q = x[ib].qs + 32*j + ir;
-
-    const float dall = x[ib].d;
-    const float dmin = x[ib].dmin;
-
-    uint8_t sc, m;
-    get_scale_min_k4(is + 0, x[ib].scales, sc, m);
-    const float d1 = dall * sc;
-    const float m1 = dmin * m;
-    get_scale_min_k4(is + 1, x[ib].scales, sc, m);
-    const float d2 = dall * sc;
-    const float m2 = dmin * m;
-
-    float sum = 0;
-    for (int k = 0; k < 4; ++k) {
-        sum += y[k +  0] * (d1 * (q[k] & 0xF) - m1);
-        sum += y[k + 32] * (d2 * (q[k] >>  4) - m2);
-    }
-    result = sum;
-
-}
-
 static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
     const block_q5_K * x = (const block_q5_K *) vx;
 
@@ -544,43 +442,6 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
     y[33] = d2 * ((ql[ 1] >>  4) + (qh[ 1] & hm ? 16 : 0)) - m2;
 }
 
-static __device__ void vec_dot_q5_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) {
-
-    const block_q5_K * x = (const block_q5_K *) vx;
-
-                                    // iqs is in 0...248 in steps of 8 =>
-    const int j  = iqs / 64;        // j  is in 0...3
-    const int ir = (iqs - 64*j)/2;  // ir is in 0...28 in steps of 4
-    const int is = 2*j;             // is is in 0...6 in steps of 2
-
-    const float   * y  = yy + 64*j + ir;
-    const uint8_t * ql = x[ib].qs + 32*j + ir;
-    const uint8_t * qh = x[ib].qh + ir;
-
-    const float dall = x[ib].d;
-    const float dmin = x[ib].dmin;
-
-    uint8_t sc, m;
-    get_scale_min_k4(is + 0, x[ib].scales, sc, m);
-    const float d1 = dall * sc;
-    const float m1 = dmin * m;
-    get_scale_min_k4(is + 1, x[ib].scales, sc, m);
-    const float d2 = dall * sc;
-    const float m2 = dmin * m;
-
-    uint8_t   hm  = 1 << is;
-    float sum = 0;
-    for (int k = 0; k < 4; ++k) {
-        sum += y[k +  0] * (d1 * ((ql[k] & 0xF) + (qh[k] & hm ? 16 : 0)) - m1);
-    }
-    hm <<= 1;
-    for (int k = 0; k < 4; ++k) {
-        sum += y[k + 32] * (d2 * ((ql[k] >>  4) + (qh[k] & hm ? 16 : 0)) - m2);
-    }
-    result = sum;
-
-}
-
 static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
     const block_q6_K * x = (const block_q6_K *) vx;
 
@@ -606,31 +467,376 @@ static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
     y[96] = d * sc[6] * ((int8_t)((ql[32]  >> 4) | (((qh >> 6) & 3) << 4)) - 32);
 }
 
-static __device__ void vec_dot_q6_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) {
+static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
 
-    const block_q6_K * x = (const block_q6_K *) vx;
+    static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
 
-    const int ip = iqs / 128;        // 0 or 1
-    const int il = (iqs - 128*ip)/8; // 0...15
-    const int is = 8*ip;
+    const int row = blockIdx.y*blockDim.y + threadIdx.y;
+    if (row > nrows) return;
 
-    const float * y = yy + 128*ip + il;
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
 
-    const float d = x[ib].d;
+    const block_q2_K * x = (const block_q2_K *)vx + ib0;
+
+    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31
+    const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0
+
+    const int step = 16/K_QUANTS_PER_ITERATION;
+
+    const int im = tid/step;      // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in = tid - step*im; // 0...7
+
+    const int l0 = K_QUANTS_PER_ITERATION*in;        // 0...14 in steps of 4
+    const int q_offset = 32*im + l0;
+    const int s_offset = 8*im;
+    const int y_offset = 128*im + l0;
+
+    float tmp = 0; // partial sum for thread in warp
+
+    uint32_t aux[4];
+    const uint8_t * d = (const uint8_t *)aux;
+    const uint8_t * m = (const uint8_t *)(aux + 2);
+
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+        const float   * y = yy + i * QK_K + y_offset;
+        const uint8_t * q = x[i].qs + q_offset;
+
+        const float dall = x[i].d;
+        const float dmin = x[i].dmin;
+
+        const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset);
+        aux[0] = a[0] & 0x0f0f0f0f;
+        aux[1] = a[1] & 0x0f0f0f0f;
+        aux[2] = (a[0] >> 4) & 0x0f0f0f0f;
+        aux[3] = (a[1] >> 4) & 0x0f0f0f0f;
+
+        float sum1 = 0, sum2 = 0;
+        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+            sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3)
+                  + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3)
+                  + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3)
+                  + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3)
+                  + y[l+16] * d[1] * ((q[l+16] >> 0) & 3)
+                  + y[l+48] * d[3] * ((q[l+16] >> 2) & 3)
+                  + y[l+80] * d[5] * ((q[l+16] >> 4) & 3)
+                  +y[l+112] * d[7] * ((q[l+16] >> 6) & 3);
+            sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6]
+                  + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7];
+
+        }
+        tmp += dall * sum1 - dmin * sum2;
+
+    }
+
+    // sum up partial sums and write back result
+    __syncthreads();
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    if (tid == 0) {
+        dst[row] = tmp;
+    }
+}
+
+static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols) {
+
+    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask2 = 0x0f0f;
+
+    const int row = blockIdx.x;
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const block_q3_K * x = (const block_q3_K *)vx + ib0;
+
+    const int tid = threadIdx.x/2;  // 0...15
+    const int ix  = threadIdx.x%2;  // 0, 1
+
+    const int n  = 2;           // iterations in the inner loop
+    const int im = tid/8;       // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in = tid - 8*im;  // 0...7
+
+    const uint8_t m = 1 << (4*im);
+
+    const int l0 = n*in;        // 0...28 in steps of 4
+    const int q_offset =  32*im + l0;
+    const int y_offset = 128*im + l0;
+
+    uint16_t utmp[4];
+    const int8_t * s = (const int8_t *)utmp;
+
+    const uint16_t s_shift = 4*im;
+
+    float tmp = 0; // partial sum for thread in warp
+
+    for (int i = ix; i < num_blocks_per_row; i += 2) {
+
+        const float   * y  = yy + i * QK_K + y_offset;
+        const uint8_t * q = x[i].qs + q_offset;
+        const uint8_t * h = x[i].hmask + l0;
+
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4);
+        utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4);
+        utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4);
+        utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4);
+
+        const float d = x[i].d;
+
+        float sum = 0;
+        for (int l = 0; l < n; ++l) {
+            sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4))
+                 + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4))
+                 + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4))
+                 + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4));
+            sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4))
+                 + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4))
+                 + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4))
+                + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4));
+        }
+        tmp += d * sum;
+
+    }
+
+    // sum up partial sums and write back result
+    __syncthreads();
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    if (tid == 0) {
+        dst[row] = tmp;
+    }
+}
+
+static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols) {
+
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
+    const int row = blockIdx.x;
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const int tid = threadIdx.x/2;  // 0...15
+    const int ix  = threadIdx.x%2;
+
+    const int il  = tid/4;     // 0...3
+    const int ir  = tid - 4*il;// 0...3
+    const int n   = 4;
+
+    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+    const int in = il%2;
+
+    const int l0 = n*(2*ir + in);
+    const int q_offset = 32*im + l0;
+    const int y_offset = 64*im + l0;
+
+    uint16_t aux[4];
+    const uint8_t * sc = (const uint8_t *)aux;
+
+    const block_q4_K * x = (const block_q4_K *)vx + ib0;
+
+    float tmp = 0; // partial sum for thread in warp
 
-    const uint8_t * ql = x[ib].ql + 64*ip + il;
-    const uint8_t * qh = x[ib].qh + 32*ip + il;
-    const int8_t  * sc = x[ib].scales + is;
+    for (int i = ix; i < num_blocks_per_row; i += 2) {
 
-    result = y[  0] * d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh[ 0] >> 0) & 3) << 4)) - 32)
-           + y[ 32] * d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh[ 0] >> 2) & 3) << 4)) - 32)
-           + y[ 64] * d * sc[4] * ((int8_t)((ql[ 0]  >> 4) | (((qh[ 0] >> 4) & 3) << 4)) - 32)
-           + y[ 96] * d * sc[6] * ((int8_t)((ql[32]  >> 4) | (((qh[ 0] >> 6) & 3) << 4)) - 32)
-           + y[ 16] * d * sc[1] * ((int8_t)((ql[16] & 0xF) | (((qh[16] >> 0) & 3) << 4)) - 32)
-           + y[ 48] * d * sc[3] * ((int8_t)((ql[48] & 0xF) | (((qh[16] >> 2) & 3) << 4)) - 32)
-           + y[ 80] * d * sc[5] * ((int8_t)((ql[16]  >> 4) | (((qh[16] >> 4) & 3) << 4)) - 32)
-           + y[112] * d * sc[7] * ((int8_t)((ql[48]  >> 4) | (((qh[16] >> 6) & 3) << 4)) - 32);
+        const uint8_t * q1 = x[i].qs + q_offset;
+        const uint8_t * q2 = q1 + 64;
+        const float   * y1 = yy + i*QK_K + y_offset;
+        const float   * y2 = y1 + 128;
 
+        const float dall = x[i].d;
+        const float dmin = x[i].dmin;
+
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        aux[0] = a[im+0] & kmask1;
+        aux[1] = a[im+2] & kmask1;
+        aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
+        aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
+
+        float4 s = {0.f, 0.f, 0.f, 0.f};
+        float smin = 0;
+        for (int l = 0; l < n; ++l) {
+            s.x += y1[l] * (q1[l] & 0xF); s.y += y1[l+32] * (q1[l] >> 4);
+            s.z += y2[l] * (q2[l] & 0xF); s.w += y2[l+32] * (q2[l] >> 4);
+            smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
+        }
+        tmp += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin;
+
+    }
+
+    // sum up partial sums and write back result
+    __syncthreads();
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    if (tid == 0) {
+        dst[row] = tmp;
+    }
+}
+
+static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float * yy, float * dst, const int ncols) {
+
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
+    //const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    const int row = blockIdx.x;
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const int tid = threadIdx.x/2;  // 0...15
+    const int ix  = threadIdx.x%2;
+
+    const int il  = tid/4;     // 0...3
+    const int ir  = tid - 4*il;// 0...3
+    const int n   = 4;
+
+    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+    const int in = il%2;
+
+    const int l0 = n*(2*ir + in);
+    const int q_offset = 32*im + l0;
+    const int y_offset = 64*im + l0;
+
+    const uint8_t hm1  = 1 << (2*im);
+    const uint8_t hm2  = hm1 << 4;
+
+    uint16_t aux[4];
+    const uint8_t * sc = (const uint8_t *)aux;
+
+    const block_q5_K * x = (const block_q5_K *)vx + ib0;
+
+    float tmp = 0; // partial sum for thread in warp
+
+    for (int i = ix; i < num_blocks_per_row; i += 2) {
+
+        const uint8_t * ql1 = x[i].qs + q_offset;
+        const uint8_t * ql2 = ql1 + 64;
+        const uint8_t * qh  = x[i].qh + l0;
+        const float   * y1  = yy + i*QK_K + y_offset;
+        const float   * y2  = y1 + 128;
+
+        const float dall = x[i].d;
+        const float dmin = x[i].dmin;
+
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        aux[0] = a[im+0] & kmask1;
+        aux[1] = a[im+2] & kmask1;
+        aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
+        aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
+
+        float4 sum = {0.f, 0.f, 0.f, 0.f};
+        float smin = 0;
+        for (int l = 0; l < n; ++l) {
+            sum.x += y1[l+ 0] * ((ql1[l] & 0xF) + (qh[l] & (hm1 << 0) ? 16 : 0));
+            sum.y += y1[l+32] * ((ql1[l] >>  4) + (qh[l] & (hm1 << 1) ? 16 : 0));
+            sum.z += y2[l+ 0] * ((ql2[l] & 0xF) + (qh[l] & (hm2 << 0) ? 16 : 0));
+            sum.w += y2[l+32] * ((ql2[l] >>  4) + (qh[l] & (hm2 << 1) ? 16 : 0));
+            smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
+        }
+        tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;
+
+    }
+
+    // sum up partial sums and write back result
+    __syncthreads();
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    if (tid == 0) {
+        dst[row] = tmp;
+    }
+}
+
+static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
+
+    static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
+
+    const int row = blockIdx.y*blockDim.y + threadIdx.y;
+    if (row > nrows) return;
+
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const block_q6_K * x = (const block_q6_K *)vx + ib0;
+
+    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
+    const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0, 1
+
+    const int step = 16/K_QUANTS_PER_ITERATION;          // 16 or 8
+
+    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in = tid - step*im;                        // 0...15 or 0...7
+
+#if K_QUANTS_PER_ITERATION == 1
+    const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15
+    const int is = 0;
+#else
+    const int l0 = 4 * in;                               // 0, 4, 8, ..., 28
+    const int is = in / 4;
+#endif
+    const int ql_offset = 64*im + l0;
+    const int qh_offset = 32*im + l0;
+    const int s_offset  =  8*im + is;
+    const int y_offset = 128*im + l0;
+
+    float tmp = 0; // partial sum for thread in warp
+
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+        const float   * y  = yy + i * QK_K + y_offset;
+        const uint8_t * ql = x[i].ql + ql_offset;
+        const uint8_t * qh = x[i].qh + qh_offset;
+        const int8_t  * s  = x[i].scales + s_offset;
+
+        const float d = x[i].d;
+
+#if K_QUANTS_PER_ITERATION == 1
+        float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
+                  + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
+                  + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
+                  + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
+                  + y[64] * s[4] * d * ((int8_t)((ql[ 0]  >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
+                  + y[80] * s[5] * d * ((int8_t)((ql[16]  >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
+                  + y[96] * s[6] * d * ((int8_t)((ql[32]  >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
+                  +y[112] * s[7] * d * ((int8_t)((ql[48]  >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
+        tmp += sum;
+#else
+        float sum = 0;
+        for (int l = 0; l < 4; ++l) {
+            sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
+                 + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32)
+                 + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32)
+                 + y[l+96] * s[6] * d * ((int8_t)((ql[l+32]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
+        }
+        tmp += sum;
+#endif
+
+    }
+
+    // sum up partial sums and write back result
+    __syncthreads();
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    if (tid == 0) {
+        dst[row] = tmp;
+    }
 }
 
 static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1){
@@ -712,46 +918,6 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
     }
 }
 
-template <int n_thread, dot_kernel_k_t dot_kernel>
-static __global__ void dequantize_mul_mat_vec_k(const void * vx, const float * y, float * dst, const int ncols, const int nrows) {
-    const int row = blockIdx.y*blockDim.y + threadIdx.y;
-
-    if (row >= nrows) {
-        return;
-    }
-
-    const int tid = threadIdx.x;
-
-    const int iter_stride = QK_K;
-    const int vals_per_iter = iter_stride / n_thread;
-    const int num_blocks_per_row = ncols / QK_K;
-    const int ib0 = row*num_blocks_per_row;
-
-    float tmp = 0; // partial sum for thread in warp
-
-    for (int i = 0; i < ncols; i += iter_stride) {
-        const int col = i + vals_per_iter*tid;
-        const int ib = ib0 + col/QK_K; // x block index
-        const int iqs = col%QK_K; // x quant index
-        const int iybs = col - col%QK_K; // y block start index
-
-        float v;
-        dot_kernel(vx, ib, iqs, y + iybs, v);
-        tmp += v;
-    }
-
-    // sum up partial sums and write back result
-    __syncthreads();
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
-
-    if (tid == 0) {
-        dst[row] = tmp;
-    }
-}
-
 static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
     const half * x = (half *) vx;
 
@@ -1094,43 +1260,34 @@ static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, f
     const int block_num_y = (nrows + ny - 1) / ny;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(32, ny, 1);
-    dequantize_mul_mat_vec_k<32, vec_dot_q2_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+    dequantize_mul_mat_vec_q2_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
 
 static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK_K == 0);
-    const int ny = 2;
-    const int block_num_y = (nrows + ny - 1) / ny;
-    const dim3 block_nums(1, block_num_y, 1);
-    const dim3 block_dims(32, ny, 1);
-    dequantize_mul_mat_vec_k<32, vec_dot_q3_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+    const dim3 block_dims(32, 1, 1);
+    dequantize_mul_mat_vec_q3_k<<<nrows, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK_K == 0);
-    const int ny = 2;
-    const int block_num_y = (nrows + ny - 1) / ny;
-    const dim3 block_nums(1, block_num_y, 1);
-    const dim3 block_dims(32, ny, 1);
-    dequantize_mul_mat_vec_k<32, vec_dot_q4_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+    const dim3 block_dims(32, 1, 1);
+    dequantize_mul_mat_vec_q4_k<<<nrows, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK_K == 0);
-    const int ny = 2;
-    const int block_num_y = (nrows + ny - 1) / ny;
-    const dim3 block_nums(1, block_num_y, 1);
-    const dim3 block_dims(32, ny, 1);
-    dequantize_mul_mat_vec_k<32, vec_dot_q5_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+    const dim3 block_dims(32, 1, 1);
+    dequantize_mul_mat_vec_q5_k<<<nrows, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK_K == 0);
-    const int ny = 2;
+    const int ny = 2 / K_QUANTS_PER_ITERATION;
     const int block_num_y = (nrows + ny - 1) / ny;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(32, ny, 1);
-    dequantize_mul_mat_vec_k<32, vec_dot_q6_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+    dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
 
 static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {