]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Porting the improved K-Quant CUDA kernels to OpenCL (#1966)
authorLostRuins <redacted>
Thu, 29 Jun 2023 03:56:43 +0000 (11:56 +0800)
committerGitHub <redacted>
Thu, 29 Jun 2023 03:56:43 +0000 (05:56 +0200)
* Added broken new q4k quant

* xx + ib0

* Fix q2_k fast kernel

* Use preprocessor for QK_K

* Add q6_k fast matmul kernel

* ported q3k speedup successfully

* ported q2k and q5k speedups

* remove old dot kernels and template

* fixed global const struct types

* fixing address spaces

* fixed string too long CI issue

---------

Co-authored-by: 0cc4m <redacted>
ggml-opencl.cpp

index 95f4cec6dd59cdc7b7ece51549420e355dfe3607..fed4ffb0ccd0538b56aa552023581974b3fe2332 100644 (file)
 
 #define CL_DMMV_BLOCK_SIZE 32
 
+#ifndef K_QUANTS_PER_ITERATION
+#define K_QUANTS_PER_ITERATION 1
+#else
+static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
+#endif
+
 #define MULTILINE_QUOTE(...) #__VA_ARGS__
 static std::string program_source = MULTILINE_QUOTE(
 
 typedef char int8_t;
 typedef uchar uint8_t;
+typedef short int16_t;
+typedef ushort uint16_t;
 typedef int int32_t;
 typedef uint uint32_t;
 
@@ -175,7 +183,9 @@ void convert_f16(__global half* x, const int ib, const int iqs, float* v0, float
     *v0 = vload_half(0, &x[ib + 0]);
     *v1 = vload_half(0, &x[ib + 1]);
 }
+);
 
+static std::string k_quants_source = MULTILINE_QUOTE(
 inline void get_scale_min_k4(int j, const __global uint8_t *q, uint8_t *d, uint8_t *m)
 {
     if (j < 4)
@@ -199,7 +209,7 @@ __kernel void dequantize_block_q2_K(__global const struct block_q2_K *x, __globa
     const int is = 8 * n + l / 16;
 
     const uint8_t q = x[i].qs[32 * n + l];
-    __global float *y = yy + i * 256 + 128 * n;
+    __global float *y = yy + i * QK_K + 128 * n;
 
     const float dall = vload_half(0, &x[i].d);
     const float dmin = vload_half(0, &x[i].dmin);
@@ -231,7 +241,7 @@ __kernel void dequantize_block_q3_K(__global const struct block_q3_K *x, __globa
     float d_all = vload_half(0, &x[i].d);
     float dl = d_all * (us - 32);
 
-    __global float *y = yy + i * 256 + 128 * n + 32 * j;
+    __global float *y = yy + i * QK_K + 128 * n + 32 * j;
     const __global uint8_t *q = x[i].qs + 32 * n;
     const __global uint8_t *hm = x[i].hmask;
 
@@ -248,7 +258,7 @@ __kernel void dequantize_block_q4_K(__global const struct block_q4_K *x, __globa
     const int is = 2 * il;
     const int n = 4;
 
-    __global float *y = yy + i * 256 + 64 * il + n * ir;
+    __global float *y = yy + i * QK_K + 64 * il + n * ir;
 
     const float dall = vload_half(0, &x[i].d);
     const float dmin = vload_half(0, &x[i].dmin);
@@ -277,7 +287,7 @@ __kernel void dequantize_block_q5_K(__global const struct block_q5_K *x, __globa
     const int ir = tid % 16;
     const int is = 2 * il;
 
-    __global float *y = yy + i * 256 + 64 * il + 2 * ir;
+    __global float *y = yy + i * QK_K + 64 * il + 2 * ir;
 
     const float dall = vload_half(0, &x[i].d);
     const float dmin = vload_half(0, &x[i].dmin);
@@ -309,7 +319,7 @@ __kernel void dequantize_block_q6_K(__global const struct block_q6_K *x, __globa
     const int il = tid - 32 * ip;
     const int is = 8 * ip + il / 16;
 
-    __global float *y = yy + i * 256 + 128 * ip + il;
+    __global float *y = yy + i * QK_K + 128 * ip + il;
 
     const float d = vload_half(0, &x[i].d);
 
@@ -323,161 +333,383 @@ __kernel void dequantize_block_q6_K(__global const struct block_q6_K *x, __globa
     y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
 }
 
+__kernel void dequantize_mul_mat_vec_q2_K(__global const struct block_q2_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) {
 
-void vec_dot_q2_K(__global const struct block_q2_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
+    const int row = get_group_id(0);
 
-    int n = iqs / 128;
-    int r = iqs - 128 * n;
-    int l = r / 8;
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
 
-    __global const float *y = yy + 128 * n + l;
-    __global const uint8_t *q = x[ib].qs + 32 * n + l;
-    __global const uint8_t *s = x[ib].scales + 8 * n;
+    __global const struct block_q2_K * x = xx + ib0;
 
-    const float dall = vload_half(0, &x[ib].d);
-    const float dmin = vload_half(0, &x[ib].dmin);
+    const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
+    const int ix  = get_local_id(0)%K_QUANTS_PER_ITERATION;  // 0 or 0,1
 
-    float sum = y[  0] * (dall * ((s[0] & 0xF) * ((q[ 0] >> 0) & 3)) - dmin * (s[0] >> 4))
-              + y[ 32] * (dall * ((s[2] & 0xF) * ((q[ 0] >> 2) & 3)) - dmin * (s[2] >> 4))
-              + y[ 64] * (dall * ((s[4] & 0xF) * ((q[ 0] >> 4) & 3)) - dmin * (s[4] >> 4))
-              + y[ 96] * (dall * ((s[6] & 0xF) * ((q[ 0] >> 6) & 3)) - dmin * (s[6] >> 4))
-              + y[ 16] * (dall * ((s[1] & 0xF) * ((q[16] >> 0) & 3)) - dmin * (s[1] >> 4))
-              + y[ 48] * (dall * ((s[3] & 0xF) * ((q[16] >> 2) & 3)) - dmin * (s[3] >> 4))
-              + y[ 80] * (dall * ((s[5] & 0xF) * ((q[16] >> 4) & 3)) - dmin * (s[5] >> 4))
-              + y[112] * (dall * ((s[7] & 0xF) * ((q[16] >> 6) & 3)) - dmin * (s[7] >> 4));
+    const int step = 16/K_QUANTS_PER_ITERATION;
 
-    *result = sum;
-}
+    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in = tid - step*im;                        // 0...15 or 0...7
 
-void vec_dot_q3_K(__global const struct block_q3_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
+    const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15 or 0...14 in steps of 2
+    const int q_offset = 32*im + l0;
+    const int s_offset = 8*im;
+    const int y_offset = 128*im + l0;
 
-    const uint32_t kmask1 = 0x03030303;
-    const uint32_t kmask2 = 0x0f0f0f0f;
+    tmp[16 * ix + tid] = 0;
 
-    uint32_t aux[3];
-    uint32_t utmp[4];
+    uint32_t aux[4];
+    const uint8_t * d = (const uint8_t *)aux;
+    const uint8_t * m = (const uint8_t *)(aux + 2);
 
-    int n = iqs/128;
-    int r = iqs - 128*n;
-    int l = r/8;
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
 
-    __global const float   * y = yy + 128*n + l;
-    __global const uint8_t * q = x[ib].qs + 32*n + l;
-    __global const uint8_t * hm = x[ib].hmask + l;
-    const int8_t * s = (const int8_t *)utmp + 8*n;
+        __global const float   * y = yy + i * QK_K + y_offset;
+        __global const uint8_t * q = x[i].qs + q_offset;
 
-    aux[0] = x[ib].scales[0] | x[ib].scales[1] << 8 | x[ib].scales[2] << 16 | x[ib].scales[3] << 24;
-    aux[1] = x[ib].scales[4] | x[ib].scales[5] << 8 | x[ib].scales[6] << 16 | x[ib].scales[7] << 24;
-    aux[2] = x[ib].scales[8] | x[ib].scales[9] << 8 | x[ib].scales[10] << 16 | x[ib].scales[11] << 24;
+        const float dall = vload_half(0, &x[i].d);
+        const float dmin = vload_half(0, &x[i].dmin);
 
-    utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
-    utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
-    utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
-    utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
+        __global const uint32_t * a = (__global const uint32_t *)(x[i].scales + s_offset);
+        aux[0] = a[0] & 0x0f0f0f0f;
+        aux[1] = a[1] & 0x0f0f0f0f;
+        aux[2] = (a[0] >> 4) & 0x0f0f0f0f;
+        aux[3] = (a[1] >> 4) & 0x0f0f0f0f;
 
-    const float dall = vload_half(0, &x[ib].d);
-    const uint8_t m = 1 << (4*n);
+        float sum1 = 0, sum2 = 0;
+        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+            sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3)
+                  + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3)
+                  + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3)
+                  + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3)
+                  + y[l+16] * d[1] * ((q[l+16] >> 0) & 3)
+                  + y[l+48] * d[3] * ((q[l+16] >> 2) & 3)
+                  + y[l+80] * d[5] * ((q[l+16] >> 4) & 3)
+                  +y[l+112] * d[7] * ((q[l+16] >> 6) & 3);
+            sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6]
+                  + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7];
 
-    float sum = y[  0] * (s[0] - 32) * (((q[ 0] >> 0) & 3) - (hm[ 0] & (m << 0) ? 0 : 4))
-              + y[ 32] * (s[2] - 32) * (((q[ 0] >> 2) & 3) - (hm[ 0] & (m << 1) ? 0 : 4))
-              + y[ 64] * (s[4] - 32) * (((q[ 0] >> 4) & 3) - (hm[ 0] & (m << 2) ? 0 : 4))
-              + y[ 96] * (s[6] - 32) * (((q[ 0] >> 6) & 3) - (hm[ 0] & (m << 3) ? 0 : 4))
-              + y[ 16] * (s[1] - 32) * (((q[16] >> 0) & 3) - (hm[16] & (m << 0) ? 0 : 4))
-              + y[ 48] * (s[3] - 32) * (((q[16] >> 2) & 3) - (hm[16] & (m << 1) ? 0 : 4))
-              + y[ 80] * (s[5] - 32) * (((q[16] >> 4) & 3) - (hm[16] & (m << 2) ? 0 : 4))
-              + y[112] * (s[7] - 32) * (((q[16] >> 6) & 3) - (hm[16] & (m << 3) ? 0 : 4));
+        }
+        tmp[16 * ix + tid] += dall * sum1 - dmin * sum2;
 
-    *result = sum * dall;
+    }
 
+    // sum up partial sums and write back result
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for (int s=16; s>0; s>>=1) {
+        if (tid < s) {
+            tmp[tid] += tmp[tid + s];
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    if (tid == 0) {
+        dst[row] = tmp[0];
+    }
 }
 
-void vec_dot_q4_K(__global const struct block_q4_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
+__kernel void dequantize_mul_mat_vec_q3_K(__global const struct block_q3_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) {
+    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask2 = 0x0f0f;
+
+    const int row = get_group_id(0);
+
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
 
-    const int j  = iqs / 64;        // j  is in 0...3
-    const int ir = (iqs - 64*j)/2;  // ir is in 0...28 in steps of 4
-    const int is = 2*j;             // is is in 0...6 in steps of 2
+    __global const struct block_q3_K * x = xx + ib0;
 
-    __global const float   * y = yy + 64*j + ir;
-    __global const uint8_t * q = x[ib].qs + 32*j + ir;
+    const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
+    const int ix  = get_local_id(0)%K_QUANTS_PER_ITERATION;  // 0 or 0,1
 
-    const float dall = vload_half(0, &x[ib].d);
-    const float dmin = vload_half(0, &x[ib].dmin);
+    const int n  = K_QUANTS_PER_ITERATION;               // iterations in the inner loop
+    const int step = 16/K_QUANTS_PER_ITERATION;
+    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in = tid - step*im;                        // 0....15 or 0...7
 
-    uint8_t sc, m;
-    get_scale_min_k4(is + 0, x[ib].scales, &sc, &m);
-    const float d1 = dall * sc;
-    const float m1 = dmin * m;
-    get_scale_min_k4(is + 1, x[ib].scales, &sc, &m);
-    const float d2 = dall * sc;
-    const float m2 = dmin * m;
+    const uint8_t m = 1 << (4*im);
+
+    const int l0 = n*in;                                 // 0...15 or 0...14 in steps of 2
+    const int q_offset =  32*im + l0;
+    const int y_offset = 128*im + l0;
+
+    uint16_t utmp[4];
+    const int8_t * s = (const int8_t *)utmp;
+
+    const uint16_t s_shift = 4*im;
+
+    tmp[16 * ix + tid] = 0;
+
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+        __global const float   * y  = yy + i * QK_K + y_offset;
+        __global const uint8_t * q = x[i].qs + q_offset;
+        __global const uint8_t * h = x[i].hmask + l0;
+
+        __global const uint16_t * a = (__global const uint16_t *)x[i].scales;
+        utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4);
+        utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4);
+        utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4);
+        utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4);
+
+        const float d = vload_half(0, &x[i].d);
+
+        float sum = 0;
+        for (int l = 0; l < n; ++l) {
+            sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4))
+                 + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4))
+                 + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4))
+                 + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4));
+            sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4))
+                 + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4))
+                 + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4))
+                + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4));
+        }
+        tmp[16 * ix + tid] += d * sum;
 
-    float sum = 0;
-    for (int k = 0; k < 4; ++k) {
-        sum += y[k +  0] * (d1 * (q[k] & 0xF) - m1);
-        sum += y[k + 32] * (d2 * (q[k] >>  4) - m2);
     }
 
-    *result = sum;
+    // sum up partial sums and write back result
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for (int s=16; s>0; s>>=1) {
+        if (tid < s) {
+            tmp[tid] += tmp[tid + s];
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    if (tid == 0) {
+        dst[row] = tmp[0];
+    }
 }
 
-void vec_dot_q5_K(__global const struct block_q5_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
+__kernel void dequantize_mul_mat_vec_q4_K(__global const struct block_q4_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) {
 
-    const int j  = iqs / 64;
-    const int ir = (iqs - 64*j)/2;
-    const int is = 2*j;
+    //to rename it later, just to test now
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
 
-    __global const float   * y  = yy + 64*j + ir;
-    __global const uint8_t * ql = x[ib].qs + 32*j + ir;
-    __global const uint8_t * qh = x[ib].qh + ir;
+    const int row = get_group_id(0);
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
 
-    const float dall = vload_half(0, &x[ib].d);
-    const float dmin = vload_half(0, &x[ib].dmin);
+    const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION;  // 0...15
+    const int ix  = get_local_id(0)%K_QUANTS_PER_ITERATION;
 
-    uint8_t sc, m;
-    get_scale_min_k4(is + 0, x[ib].scales, &sc, &m);
-    const float d1 = dall * sc;
-    const float m1 = dmin * m;
-    get_scale_min_k4(is + 1, x[ib].scales, &sc, &m);
-    const float d2 = dall * sc;
-    const float m2 = dmin * m;
+    const int step = 8/K_QUANTS_PER_ITERATION;
+
+    const int il  = tid/step;     // 0...3
+    const int ir  = tid - step*il;// 0...3
+    const int n   = 2*K_QUANTS_PER_ITERATION;
+
+    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+    const int in = il%2;
+
+    const int l0 = n*(2*ir + in);
+    const int q_offset = 32*im + l0;
+    const int y_offset = 64*im + l0;
+
+    uint16_t aux[4];
+    const uint8_t * sc = (const uint8_t *)aux;
+
+    __global const struct block_q4_K * x = xx + ib0;
+
+    tmp[16 * ix + tid] = 0;
+
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+        __global const uint8_t * q1 = x[i].qs + q_offset;
+        __global const uint8_t * q2 = q1 + 64;
+        __global const float   * y1 = yy + i*QK_K + y_offset;
+        __global const float   * y2 = y1 + 128;
+
+        const float dall = vload_half(0, &x[i].d);
+        const float dmin = vload_half(0, &x[i].dmin);
+
+        __global const uint16_t * a = (__global const uint16_t *)x[i].scales;
+        aux[0] = a[im+0] & kmask1;
+        aux[1] = a[im+2] & kmask1;
+        aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
+        aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
+
+        float4 s = (float4)(0.f);
+        float smin = 0;
+        for (int l = 0; l < n; ++l) {
+            s.x += y1[l] * (q1[l] & 0xF); s.y += y1[l+32] * (q1[l] >> 4);
+            s.z += y2[l] * (q2[l] & 0xF); s.w += y2[l+32] * (q2[l] >> 4);
+            smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
+        }
+        tmp[16 * ix + tid] += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin;
 
-    uint8_t hm  = 1 << is;
-    float sum = 0;
-    for (int k = 0; k < 4; ++k) {
-        sum += y[k +  0] * (d1 * ((ql[k] & 0xF) + (qh[k] & hm ? 16 : 0)) - m1);
     }
-    hm <<= 1;
-    for (int k = 0; k < 4; ++k) {
-        sum += y[k + 32] * (d2 * ((ql[k] >>  4) + (qh[k] & hm ? 16 : 0)) - m2);
+
+    // sum up partial sums and write back result
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for (int s=16; s>0; s>>=1) {
+        if (tid < s) {
+            tmp[tid] += tmp[tid + s];
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    if (tid == 0) {
+        dst[row] = tmp[0];
+    }
+}
+
+__kernel void dequantize_mul_mat_vec_q5_K(__global const struct block_q5_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) {
+
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
+    const int row = get_group_id(0);
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const int tid = get_local_id(0)/2;  // 0...15
+    const int ix  = get_local_id(0)%2;
+
+    const int il  = tid/4;     // 0...3
+    const int ir  = tid - 4*il;// 0...3
+    const int n   = 2;
+
+    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+    const int in = il%2;
+
+    const int l0 = n*(2*ir + in);
+    const int q_offset = 32*im + l0;
+    const int y_offset = 64*im + l0;
+
+    const uint8_t hm1  = 1 << (2*im);
+    const uint8_t hm2  = hm1 << 4;
+
+    uint16_t aux[4];
+    const uint8_t * sc = (const uint8_t *)aux;
+
+    __global const struct block_q5_K * x = xx + ib0;
+
+    tmp[16 * ix + tid] = 0;
+
+    for (int i = ix; i < num_blocks_per_row; i += 2) {
+
+        __global const uint8_t * ql1 = x[i].qs + q_offset;
+        __global const uint8_t * ql2 = ql1 + 64;
+        __global const uint8_t * qh  = x[i].qh + l0;
+        __global const float   * y1  = yy + i*QK_K + y_offset;
+        __global const float   * y2  = y1 + 128;
+
+        const float dall = vload_half(0, &x[i].d);
+        const float dmin = vload_half(0, &x[i].dmin);
+
+        __global const uint16_t * a = (__global const uint16_t *)x[i].scales;
+        aux[0] = a[im+0] & kmask1;
+        aux[1] = a[im+2] & kmask1;
+        aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
+        aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
+
+        float4 sum = (float4)(0.f);
+        float smin = 0;
+        for (int l = 0; l < n; ++l) {
+            sum.x += y1[l+ 0] * ((ql1[l+ 0] & 0xF) + (qh[l+ 0] & (hm1 << 0) ? 16 : 0))
+                   + y1[l+16] * ((ql1[l+16] & 0xF) + (qh[l+16] & (hm1 << 0) ? 16 : 0));
+            sum.y += y1[l+32] * ((ql1[l+ 0] >>  4) + (qh[l+ 0] & (hm1 << 1) ? 16 : 0))
+                   + y1[l+48] * ((ql1[l+16] >>  4) + (qh[l+16] & (hm1 << 1) ? 16 : 0));
+            sum.z += y2[l+ 0] * ((ql2[l+ 0] & 0xF) + (qh[l+ 0] & (hm2 << 0) ? 16 : 0))
+                   + y2[l+16] * ((ql2[l+16] & 0xF) + (qh[l+16] & (hm2 << 0) ? 16 : 0));
+            sum.w += y2[l+32] * ((ql2[l+ 0] >>  4) + (qh[l+ 0] & (hm2 << 1) ? 16 : 0))
+                   + y2[l+48] * ((ql2[l+16] >>  4) + (qh[l+16] & (hm2 << 1) ? 16 : 0));
+            smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3]
+                  + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
+        }
+        tmp[16 * ix + tid] += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;
+
     }
-    *result = sum;
 
+    // sum up partial sums and write back result
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for (int s=16; s>0; s>>=1) {
+        if (tid < s) {
+            tmp[tid] += tmp[tid + s];
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    if (tid == 0) {
+        dst[row] = tmp[0];
+    }
 }
 
-void vec_dot_q6_K(__global const struct block_q6_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
+__kernel void dequantize_mul_mat_vec_q6_K(__global const struct block_q6_K * xx, __local float* tmp, __global const float * yy, __global float * dst, const int ncols) {
+
+    const int row = get_group_id(0);
 
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
 
-    const int ip = iqs / 128;        // 0 or 1
-    const int il = (iqs - 128*ip)/8; // 0...15
-    const int is = 8*ip;
+    __global const struct block_q6_K * x = xx + ib0;
 
-    __global const float * y = yy + 128*ip + il;
+    const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
+    const int ix  = get_local_id(0)%K_QUANTS_PER_ITERATION;  // 0 or 0, 1
 
-    const float d = vload_half(0, &x[ib].d);
+    const int step = 16/K_QUANTS_PER_ITERATION;          // 16 or 8
 
-    __global const uint8_t * ql = x[ib].ql + 64*ip + il;
-    __global const uint8_t * qh = x[ib].qh + 32*ip + il;
-    __global const int8_t  * sc = x[ib].scales + is;
+    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in = tid - step*im;                        // 0...15 or 0...7
+
+#if K_QUANTS_PER_ITERATION == 1
+    const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15
+    const int is = 0;
+#else
+    const int l0 = 4 * in;                               // 0, 4, 8, ..., 28
+    const int is = in / 4;
+#endif
+    const int ql_offset = 64*im + l0;
+    const int qh_offset = 32*im + l0;
+    const int s_offset  =  8*im + is;
+    const int y_offset = 128*im + l0;
+
+    tmp[16 * ix + tid] = 0; // partial sum for thread in warp
+
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+        __global const float   * y  = yy + i * QK_K + y_offset;
+        __global const uint8_t * ql = x[i].ql + ql_offset;
+        __global const uint8_t * qh = x[i].qh + qh_offset;
+        __global const int8_t  * s  = x[i].scales + s_offset;
+
+        const float d = vload_half(0, &x[i].d);
+
+#if K_QUANTS_PER_ITERATION == 1
+        float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
+                  + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
+                  + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
+                  + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
+                  + y[64] * s[4] * d * ((int8_t)((ql[ 0]  >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
+                  + y[80] * s[5] * d * ((int8_t)((ql[16]  >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
+                  + y[96] * s[6] * d * ((int8_t)((ql[32]  >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
+                  +y[112] * s[7] * d * ((int8_t)((ql[48]  >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
+        tmp[16 * ix + tid] += sum;
+#else
+        float sum = 0;
+        for (int l = 0; l < 4; ++l) {
+            sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
+                 + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32)
+                 + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32)
+                 + y[l+96] * s[6] * d * ((int8_t)((ql[l+32]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
+        }
+        tmp[16 * ix + tid] += sum;
+#endif
 
-    *result = y[  0] * d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh[ 0] >> 0) & 3) << 4)) - 32)
-           + y[ 32] * d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh[ 0] >> 2) & 3) << 4)) - 32)
-           + y[ 64] * d * sc[4] * ((int8_t)((ql[ 0]  >> 4) | (((qh[ 0] >> 4) & 3) << 4)) - 32)
-           + y[ 96] * d * sc[6] * ((int8_t)((ql[32]  >> 4) | (((qh[ 0] >> 6) & 3) << 4)) - 32)
-           + y[ 16] * d * sc[1] * ((int8_t)((ql[16] & 0xF) | (((qh[16] >> 0) & 3) << 4)) - 32)
-           + y[ 48] * d * sc[3] * ((int8_t)((ql[48] & 0xF) | (((qh[16] >> 2) & 3) << 4)) - 32)
-           + y[ 80] * d * sc[5] * ((int8_t)((ql[16]  >> 4) | (((qh[16] >> 4) & 3) << 4)) - 32)
-           + y[112] * d * sc[7] * ((int8_t)((ql[48]  >> 4) | (((qh[16] >> 6) & 3) << 4)) - 32);
+    }
 
+    // sum up partial sums and write back result
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for (int s=16; s>0; s>>=1) {
+        if (tid < s) {
+            tmp[tid] += tmp[tid + s];
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    if (tid == 0) {
+        dst[row] = tmp[0];
+    }
 }
 
 );
@@ -549,44 +781,6 @@ __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float
 }
 );
 
-std::string dequant_mul_mat_vec_k_template = MULTILINE_QUOTE(
-__kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float* y, __global float* dst, const int ncols) {
-    const int block_size = get_local_size(0);
-    const int row = get_group_id(0);
-    const int tid = get_local_id(0);
-
-    const int iter_stride = 256;
-    const int vals_per_iter = iter_stride / block_size;
-    const int num_blocks_per_row = ncols / 256;
-    const int ib0 = row*num_blocks_per_row;
-
-    tmp[tid] = 0;
-
-    for (int i = 0; i < ncols; i += iter_stride) {
-        const int col = i + vals_per_iter*tid;
-        const int ib = ib0 + col/256; // x block index
-        const int iqs = col%256; // x quant index
-        const int iybs = col - col%256; // y block start index
-
-        // dequantize
-        float v;
-        DOT_KERNEL(x, ib, iqs, y + iybs, &v);
-        tmp[tid] += v;
-    }
-
-    // sum up partial sums and write back result
-    barrier(CLK_LOCAL_MEM_FENCE);
-    for (int s=block_size/2; s>0; s>>=1) {
-        if (tid < s) {
-            tmp[tid] += tmp[tid + s];
-        }
-        barrier(CLK_LOCAL_MEM_FENCE);
-    }
-    if (tid == 0) {
-        dst[row] = tmp[0];
-    }
-}
-);
 
 std::string mul_template = MULTILINE_QUOTE(
 __kernel void KERNEL_NAME(__global TYPE* x, const int x_offset, __global TYPE* y, const int y_offset, __global TYPE* dst, const int dst_offset, const int ky) {
@@ -649,18 +843,6 @@ std::array<std::string, 2> mul_str_values = {
     "mul_f32", "float"
 };
 
-std::array<std::string, 3> dmmv_k_str_keys = {
-    "KERNEL_NAME", "X_TYPE", "DOT_KERNEL"
-};
-
-std::array<std::string, 15> dmmv_k_str_values = {
-    "dequantize_mul_mat_vec_q2_K", "struct block_q2_K", "vec_dot_q2_K",
-    "dequantize_mul_mat_vec_q3_K", "struct block_q3_K", "vec_dot_q3_K",
-    "dequantize_mul_mat_vec_q4_K", "struct block_q4_K", "vec_dot_q4_K",
-    "dequantize_mul_mat_vec_q5_K", "struct block_q5_K", "vec_dot_q5_K",
-    "dequantize_mul_mat_vec_q6_K", "struct block_q6_K", "vec_dot_q6_K",
-};
-
 std::string& replace(std::string& s, const std::string& from, const std::string& to) {
     size_t pos = 0;
     while ((pos = s.find(from, pos)) != std::string::npos) {
@@ -673,6 +855,7 @@ std::string& replace(std::string& s, const std::string& from, const std::string&
 std::string generate_kernels() {
     std::stringstream src;
     src << program_source << '\n';
+    src << k_quants_source << '\n';
     for (size_t i = 0; i < dequant_str_values.size(); i += dequant_str_keys.size()) {
         std::string dequant_kernel = dequant_template;
         std::string dmmv_kernel = dequant_mul_mat_vec_template;
@@ -690,13 +873,6 @@ std::string generate_kernels() {
         }
         src << mul_kernel << '\n';
     }
-    for (size_t i = 0; i < dmmv_k_str_values.size(); i += dmmv_k_str_keys.size()) {
-        std::string dmmv_k_kernel = dequant_mul_mat_vec_k_template;
-        for (size_t j = 0; j < dmmv_k_str_keys.size(); j++) {
-            replace(dmmv_k_kernel, dmmv_k_str_keys[j], dmmv_k_str_values[i + j]);
-        }
-        src << dmmv_k_kernel << '\n';
-    }
 
     return src.str();
 }
@@ -729,10 +905,11 @@ static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, co
         exit(1);
     }
 
-    const char* compile_opts = "-cl-mad-enable -cl-unsafe-math-optimizations -cl-finite-math-only -cl-fast-relaxed-math "
-                               "-DQK4_0=32 -DQR4_0=2 -DQK4_1=32 -DQR4_1=2 -DQK5_0=32 -DQR5_0=2 -DQK5_1=32 -DQR5_1=2 -DQK8_0=32 -DQR8_0=1";
+    std::string compile_opts = "-cl-mad-enable -cl-unsafe-math-optimizations -cl-finite-math-only -cl-fast-relaxed-math "
+                               "-DQK4_0=32 -DQR4_0=2 -DQK4_1=32 -DQR4_1=2 -DQK5_0=32 -DQR5_0=2 -DQK5_1=32 -DQR5_1=2 -DQK8_0=32 -DQR8_0=1 "
+                               "-DQK_K=256 -DK_QUANTS_PER_ITERATION=" + std::to_string(K_QUANTS_PER_ITERATION);
 
-    err = clBuildProgram(p, 0, NULL, compile_opts, NULL, NULL);
+    err = clBuildProgram(p, 0, NULL, compile_opts.c_str(), NULL, NULL);
     if(err < 0) {
 
         clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, 0, NULL, &log_size);