]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
opencl : support k-quants (#1836)
author0cc4m <redacted>
Fri, 16 Jun 2023 18:59:49 +0000 (20:59 +0200)
committerGitHub <redacted>
Fri, 16 Jun 2023 18:59:49 +0000 (21:59 +0300)
* Porting q2_k kernel to OpenCL

* Set global and local sizes for kernel calls for dequantizing k-quants

* Added q6_k kernel

* Fix q4_k opencl struct order

* Replace uchar with uint8_t

* Finish dequant kernels

* Added OpenCL DMMV kernels

* Fix q2_k, improve code

* Fix q3_k

* Shorten switch statements

* Improve code formatting

---------

Co-authored-by: Concedo <redacted>
ggml-opencl.cpp

index 5df922abd720e062bd27297a066c088df51df072..1d4db96ee9b6117ede0a35bcdd81ff0cc7a7e37b 100644 (file)
@@ -15,7 +15,7 @@
 
 #include "ggml.h"
 
-#define CL_DMMV_BLOCK_SIZE 32;
+#define CL_DMMV_BLOCK_SIZE 32
 
 #define MULTILINE_QUOTE(...) #__VA_ARGS__
 static std::string program_source = MULTILINE_QUOTE(
@@ -59,6 +59,46 @@ struct __attribute__ ((packed)) block_q8_0
     int8_t qs[QK8_0];
 };
 
+struct __attribute__((packed)) block_q2_K
+{
+    uint8_t scales[16];
+    uint8_t qs[64];
+    half d;
+    half dmin;
+};
+
+struct __attribute__((packed)) block_q3_K
+{
+    uint8_t hmask[32];
+    uint8_t qs[64];
+    uint8_t scales[12];
+    half d;
+};
+
+struct __attribute__((packed)) block_q4_K
+{
+    half d;
+    half dmin;
+    uint8_t scales[12];
+    uint8_t qs[128];
+};
+
+struct __attribute__((packed)) block_q5_K
+{
+    half d;
+    half dmin;
+    uint8_t scales[12];
+    uint8_t qh[32];
+    uint8_t qs[128];
+};
+
+struct __attribute__((packed)) block_q6_K
+{
+    uint8_t ql[128];
+    uint8_t qh[64];
+    int8_t scales[16];
+    half d;
+};
 
 __kernel void convert_fp16_to_fp32(__global half* x, __global float* y) {
     const uint i = get_global_id(0);
@@ -131,8 +171,314 @@ void convert_f16(__global half* x, const int ib, const int iqs, float* v0, float
     *v0 = vload_half(0, &x[ib + 0]);
     *v1 = vload_half(0, &x[ib + 1]);
 }
+
+inline void get_scale_min_k4(int j, const __global uint8_t *q, uint8_t *d, uint8_t *m)
+{
+    if (j < 4)
+    {
+        *d = q[j] & 63;
+        *m = q[j + 4] & 63;
+    }
+    else
+    {
+        *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
+        *m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4);
+    }
+}
+
+__kernel void dequantize_block_q2_K(__global const struct block_q2_K *x, __global float *yy)
+{
+    const int i = get_group_id(0);
+    const int tid = get_local_id(0);
+    const int n = tid / 32;
+    const int l = tid - 32 * n;
+    const int is = 8 * n + l / 16;
+
+    const uint8_t q = x[i].qs[32 * n + l];
+    __global float *y = yy + i * 256 + 128 * n;
+
+    const float dall = vload_half(0, &x[i].d);
+    const float dmin = vload_half(0, &x[i].dmin);
+
+    y[l + 0] = dall * (x[i].scales[is + 0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is + 0] >> 4);
+    y[l + 32] = dall * (x[i].scales[is + 2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is + 2] >> 4);
+    y[l + 64] = dall * (x[i].scales[is + 4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is + 4] >> 4);
+    y[l + 96] = dall * (x[i].scales[is + 6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is + 6] >> 4);
+}
+
+__kernel void dequantize_block_q3_K(__global const struct block_q3_K *x, __global float *yy)
+{
+    int r = get_local_id(0) / 4;
+    int i = get_group_id(0);
+    int tid = r / 2;
+    int is0 = r % 2;
+    int l0 = 16 * is0 + 4 * (get_local_id(0) % 4);
+    int n = tid / 4;
+    int j = tid - 4 * n;
+
+    uint8_t m = 1 << (4 * n + j);
+    int is = 8 * n + 2 * j + is0;
+    int shift = 2 * j;
+
+    int8_t us = is < 4 ? (x[i].scales[is - 0] & 0xF) | (((x[i].scales[is + 8] >> 0) & 3) << 4)
+              : is < 8 ? (x[i].scales[is - 0] & 0xF) | (((x[i].scales[is + 4] >> 2) & 3) << 4)
+              : is < 12  ? (x[i].scales[is - 8] >> 4) | (((x[i].scales[is + 0] >> 4) & 3) << 4)
+              : (x[i].scales[is - 8] >> 4) | (((x[i].scales[is - 4] >> 6) & 3) << 4);
+    float d_all = vload_half(0, &x[i].d);
+    float dl = d_all * (us - 32);
+
+    __global float *y = yy + i * 256 + 128 * n + 32 * j;
+    const __global uint8_t *q = x[i].qs + 32 * n;
+    const __global uint8_t *hm = x[i].hmask;
+
+    for (int l = l0; l < l0 + 4; ++l)
+        y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
+}
+
+__kernel void dequantize_block_q4_K(__global const struct block_q4_K *x, __global float *yy)
+{
+    const int i = get_group_id(0);
+    const int tid = get_local_id(0);
+    const int il = tid / 8;
+    const int ir = tid % 8;
+    const int is = 2 * il;
+    const int n = 4;
+
+    __global float *y = yy + i * 256 + 64 * il + n * ir;
+
+    const float dall = vload_half(0, &x[i].d);
+    const float dmin = vload_half(0, &x[i].dmin);
+
+    __global const uint8_t *q = x[i].qs + 32 * il + n * ir;
+
+    uint8_t sc, m;
+    get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
+    float d1 = dall * sc;
+    float m1 = dmin * m;
+    get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
+    float d2 = dall * sc;
+    float m2 = dmin * m;
+    for (int l = 0; l < n; ++l)
+    {
+        y[l + 0] = d1 * (q[l] & 0xF) - m1;
+        y[l + 32] = d2 * (q[l] >> 4) - m2;
+    }
+}
+
+__kernel void dequantize_block_q5_K(__global const struct block_q5_K *x, __global float *yy)
+{
+    const int i = get_group_id(0);
+    const int tid = get_local_id(0);
+    const int il = tid / 16;
+    const int ir = tid % 16;
+    const int is = 2 * il;
+
+    __global float *y = yy + i * 256 + 64 * il + 2 * ir;
+
+    const float dall = vload_half(0, &x[i].d);
+    const float dmin = vload_half(0, &x[i].dmin);
+
+    __global const uint8_t *ql = x[i].qs + 32 * il + 2 * ir;
+    __global const uint8_t *qh = x[i].qh + 2 * ir;
+
+    uint8_t sc, m;
+    get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
+    const float d1 = dall * sc;
+    const float m1 = dmin * m;
+    get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
+    const float d2 = dall * sc;
+    const float m2 = dmin * m;
+
+    uint8_t hm = 1 << (2 * il);
+    y[0] = d1 * ((ql[0] & 0xF) + (qh[0] & hm ? 16 : 0)) - m1;
+    y[1] = d1 * ((ql[1] & 0xF) + (qh[1] & hm ? 16 : 0)) - m1;
+    hm <<= 1;
+    y[32] = d2 * ((ql[0] >> 4) + (qh[0] & hm ? 16 : 0)) - m2;
+    y[33] = d2 * ((ql[1] >> 4) + (qh[1] & hm ? 16 : 0)) - m2;
+}
+
+__kernel void dequantize_block_q6_K(__global const struct block_q6_K *x, __global float *yy)
+{
+    const int i = get_group_id(0);
+    const int tid = get_local_id(0);
+    const int ip = tid / 32;
+    const int il = tid - 32 * ip;
+    const int is = 8 * ip + il / 16;
+
+    __global float *y = yy + i * 256 + 128 * ip + il;
+
+    const float d = vload_half(0, &x[i].d);
+
+    __global const uint8_t *ql = x[i].ql + 64 * ip + il;
+    const uint8_t qh = x[i].qh[32 * ip + il];
+    __global const int8_t *sc = x[i].scales + is;
+
+    y[0] = d * sc[0] * ((int8_t)((ql[0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
+    y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
+    y[64] = d * sc[4] * ((int8_t)((ql[0] >> 4) | (((qh >> 4) & 3) << 4)) - 32);
+    y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
+}
+
+
+void vec_dot_q2_K(__global const struct block_q2_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
+
+    int n = iqs / 128;
+    int r = iqs - 128 * n;
+    int l = r / 8;
+
+    __global const float *y = yy + 128 * n + l;
+    __global const uint8_t *q = x[ib].qs + 32 * n + l;
+    __global const uint8_t *s = x[ib].scales + 8 * n;
+
+    const float dall = vload_half(0, &x[ib].d);
+    const float dmin = vload_half(0, &x[ib].dmin);
+
+    float sum = y[  0] * (dall * ((s[0] & 0xF) * ((q[ 0] >> 0) & 3)) - dmin * (s[0] >> 4))
+              + y[ 32] * (dall * ((s[2] & 0xF) * ((q[ 0] >> 2) & 3)) - dmin * (s[2] >> 4))
+              + y[ 64] * (dall * ((s[4] & 0xF) * ((q[ 0] >> 4) & 3)) - dmin * (s[4] >> 4))
+              + y[ 96] * (dall * ((s[6] & 0xF) * ((q[ 0] >> 6) & 3)) - dmin * (s[6] >> 4))
+              + y[ 16] * (dall * ((s[1] & 0xF) * ((q[16] >> 0) & 3)) - dmin * (s[1] >> 4))
+              + y[ 48] * (dall * ((s[3] & 0xF) * ((q[16] >> 2) & 3)) - dmin * (s[3] >> 4))
+              + y[ 80] * (dall * ((s[5] & 0xF) * ((q[16] >> 4) & 3)) - dmin * (s[5] >> 4))
+              + y[112] * (dall * ((s[7] & 0xF) * ((q[16] >> 6) & 3)) - dmin * (s[7] >> 4));
+
+    *result = sum;
+}
+
+void vec_dot_q3_K(__global const struct block_q3_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
+
+    const uint32_t kmask1 = 0x03030303;
+    const uint32_t kmask2 = 0x0f0f0f0f;
+
+    uint32_t aux[3];
+    uint32_t utmp[4];
+
+    int n = iqs/128;
+    int r = iqs - 128*n;
+    int l = r/8;
+
+    __global const float   * y = yy + 128*n + l;
+    __global const uint8_t * q = x[ib].qs + 32*n + l;
+    __global const uint8_t * hm = x[ib].hmask + l;
+    const int8_t * s = (const int8_t *)utmp + 8*n;
+
+    aux[0] = x[ib].scales[0] | x[ib].scales[1] << 8 | x[ib].scales[2] << 16 | x[ib].scales[3] << 24;
+    aux[1] = x[ib].scales[4] | x[ib].scales[5] << 8 | x[ib].scales[6] << 16 | x[ib].scales[7] << 24;
+    aux[2] = x[ib].scales[8] | x[ib].scales[9] << 8 | x[ib].scales[10] << 16 | x[ib].scales[11] << 24;
+
+    utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
+    utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
+    utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
+    utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
+
+    const float dall = vload_half(0, &x[ib].d);
+    const uint8_t m = 1 << (4*n);
+
+    float sum = y[  0] * (s[0] - 32) * (((q[ 0] >> 0) & 3) - (hm[ 0] & (m << 0) ? 0 : 4))
+              + y[ 32] * (s[2] - 32) * (((q[ 0] >> 2) & 3) - (hm[ 0] & (m << 1) ? 0 : 4))
+              + y[ 64] * (s[4] - 32) * (((q[ 0] >> 4) & 3) - (hm[ 0] & (m << 2) ? 0 : 4))
+              + y[ 96] * (s[6] - 32) * (((q[ 0] >> 6) & 3) - (hm[ 0] & (m << 3) ? 0 : 4))
+              + y[ 16] * (s[1] - 32) * (((q[16] >> 0) & 3) - (hm[16] & (m << 0) ? 0 : 4))
+              + y[ 48] * (s[3] - 32) * (((q[16] >> 2) & 3) - (hm[16] & (m << 1) ? 0 : 4))
+              + y[ 80] * (s[5] - 32) * (((q[16] >> 4) & 3) - (hm[16] & (m << 2) ? 0 : 4))
+              + y[112] * (s[7] - 32) * (((q[16] >> 6) & 3) - (hm[16] & (m << 3) ? 0 : 4));
+
+    *result = sum * dall;
+
+}
+
+void vec_dot_q4_K(__global const struct block_q4_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
+
+    const int j  = iqs / 64;        // j  is in 0...3
+    const int ir = (iqs - 64*j)/2;  // ir is in 0...28 in steps of 4
+    const int is = 2*j;             // is is in 0...6 in steps of 2
+
+    __global const float   * y = yy + 64*j + ir;
+    __global const uint8_t * q = x[ib].qs + 32*j + ir;
+
+    const float dall = vload_half(0, &x[ib].d);
+    const float dmin = vload_half(0, &x[ib].dmin);
+
+    uint8_t sc, m;
+    get_scale_min_k4(is + 0, x[ib].scales, &sc, &m);
+    const float d1 = dall * sc;
+    const float m1 = dmin * m;
+    get_scale_min_k4(is + 1, x[ib].scales, &sc, &m);
+    const float d2 = dall * sc;
+    const float m2 = dmin * m;
+
+    float sum = 0;
+    for (int k = 0; k < 4; ++k) {
+        sum += y[k +  0] * (d1 * (q[k] & 0xF) - m1);
+        sum += y[k + 32] * (d2 * (q[k] >>  4) - m2);
+    }
+
+    *result = sum;
+}
+
+void vec_dot_q5_K(__global const struct block_q5_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
+
+    const int j  = iqs / 64;
+    const int ir = (iqs - 64*j)/2;
+    const int is = 2*j;
+
+    __global const float   * y  = yy + 64*j + ir;
+    __global const uint8_t * ql = x[ib].qs + 32*j + ir;
+    __global const uint8_t * qh = x[ib].qh + ir;
+
+    const float dall = vload_half(0, &x[ib].d);
+    const float dmin = vload_half(0, &x[ib].dmin);
+
+    uint8_t sc, m;
+    get_scale_min_k4(is + 0, x[ib].scales, &sc, &m);
+    const float d1 = dall * sc;
+    const float m1 = dmin * m;
+    get_scale_min_k4(is + 1, x[ib].scales, &sc, &m);
+    const float d2 = dall * sc;
+    const float m2 = dmin * m;
+
+    uint8_t hm  = 1 << is;
+    float sum = 0;
+    for (int k = 0; k < 4; ++k) {
+        sum += y[k +  0] * (d1 * ((ql[k] & 0xF) + (qh[k] & hm ? 16 : 0)) - m1);
+    }
+    hm <<= 1;
+    for (int k = 0; k < 4; ++k) {
+        sum += y[k + 32] * (d2 * ((ql[k] >>  4) + (qh[k] & hm ? 16 : 0)) - m2);
+    }
+    *result = sum;
+
+}
+
+void vec_dot_q6_K(__global const struct block_q6_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
+
+
+    const int ip = iqs / 128;        // 0 or 1
+    const int il = (iqs - 128*ip)/8; // 0...15
+    const int is = 8*ip;
+
+    __global const float * y = yy + 128*ip + il;
+
+    const float d = vload_half(0, &x[ib].d);
+
+    __global const uint8_t * ql = x[ib].ql + 64*ip + il;
+    __global const uint8_t * qh = x[ib].qh + 32*ip + il;
+    __global const int8_t  * sc = x[ib].scales + is;
+
+    *result = y[  0] * d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh[ 0] >> 0) & 3) << 4)) - 32)
+           + y[ 32] * d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh[ 0] >> 2) & 3) << 4)) - 32)
+           + y[ 64] * d * sc[4] * ((int8_t)((ql[ 0]  >> 4) | (((qh[ 0] >> 4) & 3) << 4)) - 32)
+           + y[ 96] * d * sc[6] * ((int8_t)((ql[32]  >> 4) | (((qh[ 0] >> 6) & 3) << 4)) - 32)
+           + y[ 16] * d * sc[1] * ((int8_t)((ql[16] & 0xF) | (((qh[16] >> 0) & 3) << 4)) - 32)
+           + y[ 48] * d * sc[3] * ((int8_t)((ql[48] & 0xF) | (((qh[16] >> 2) & 3) << 4)) - 32)
+           + y[ 80] * d * sc[5] * ((int8_t)((ql[16]  >> 4) | (((qh[16] >> 4) & 3) << 4)) - 32)
+           + y[112] * d * sc[7] * ((int8_t)((ql[48]  >> 4) | (((qh[16] >> 6) & 3) << 4)) - 32);
+
+}
+
 );
 
+
 std::string dequant_template = MULTILINE_QUOTE(
 __kernel void KERNEL_NAME(__global X_TYPE* x, __global float* y) {
     const int i = get_group_id(0)*get_local_size(0) + get_local_id(0)*2;
@@ -160,7 +506,7 @@ __kernel void KERNEL_NAME(__global X_TYPE* x, __global float* y) {
 std::string dequant_mul_mat_vec_template = MULTILINE_QUOTE(
 __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float* y, __global float* dst, const int ncols) {
     const int block_size = get_local_size(0);
-    const int row = get_global_id(0) / block_size;
+    const int row = get_group_id(0);
     const int tid = get_local_id(0);
 
     const uint qk = QUANT_K;
@@ -199,6 +545,45 @@ __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float
 }
 );
 
+std::string dequant_mul_mat_vec_k_template = MULTILINE_QUOTE(
+__kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float* y, __global float* dst, const int ncols) {
+    const int block_size = get_local_size(0);
+    const int row = get_group_id(0);
+    const int tid = get_local_id(0);
+
+    const int iter_stride = 256;
+    const int vals_per_iter = iter_stride / block_size;
+    const int num_blocks_per_row = ncols / 256;
+    const int ib0 = row*num_blocks_per_row;
+
+    tmp[tid] = 0;
+
+    for (int i = 0; i < ncols; i += iter_stride) {
+        const int col = i + vals_per_iter*tid;
+        const int ib = ib0 + col/256; // x block index
+        const int iqs = col%256; // x quant index
+        const int iybs = col - col%256; // y block start index
+
+        // dequantize
+        float v;
+        DOT_KERNEL(x, ib, iqs, y + iybs, &v);
+        tmp[tid] += v;
+    }
+
+    // sum up partial sums and write back result
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for (int s=block_size/2; s>0; s>>=1) {
+        if (tid < s) {
+            tmp[tid] += tmp[tid + s];
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    if (tid == 0) {
+        dst[row] = tmp[0];
+    }
+}
+);
+
 std::string mul_template = MULTILINE_QUOTE(
 __kernel void KERNEL_NAME(__global TYPE* x, const int x_offset, __global TYPE* y, const int y_offset, __global TYPE* dst, const int dst_offset, const int ky) {
     const int i = get_group_id(0)*get_local_size(0) + get_local_id(0);
@@ -260,6 +645,18 @@ std::array<std::string, 2> mul_str_values = {
     "mul_f32", "float"
 };
 
+std::array<std::string, 3> dmmv_k_str_keys = {
+    "KERNEL_NAME", "X_TYPE", "DOT_KERNEL"
+};
+
+std::array<std::string, 15> dmmv_k_str_values = {
+    "dequantize_mul_mat_vec_q2_K", "struct block_q2_K", "vec_dot_q2_K",
+    "dequantize_mul_mat_vec_q3_K", "struct block_q3_K", "vec_dot_q3_K",
+    "dequantize_mul_mat_vec_q4_K", "struct block_q4_K", "vec_dot_q4_K",
+    "dequantize_mul_mat_vec_q5_K", "struct block_q5_K", "vec_dot_q5_K",
+    "dequantize_mul_mat_vec_q6_K", "struct block_q6_K", "vec_dot_q6_K",
+};
+
 std::string& replace(std::string& s, const std::string& from, const std::string& to) {
     size_t pos = 0;
     while ((pos = s.find(from, pos)) != std::string::npos) {
@@ -289,6 +686,14 @@ std::string generate_kernels() {
         }
         src << mul_kernel << '\n';
     }
+    for (size_t i = 0; i < dmmv_k_str_values.size(); i += dmmv_k_str_keys.size()) {
+        std::string dmmv_k_kernel = dequant_mul_mat_vec_k_template;
+        for (size_t j = 0; j < dmmv_k_str_keys.size(); j++) {
+            replace(dmmv_k_kernel, dmmv_k_str_keys[j], dmmv_k_str_values[i + j]);
+        }
+        src << dmmv_k_kernel << '\n';
+    }
+
     return src.str();
 }
 
@@ -300,6 +705,8 @@ static cl_program program;
 static cl_kernel convert_row_f16_cl;
 static cl_kernel dequantize_row_q4_0_cl, dequantize_row_q4_1_cl, dequantize_row_q5_0_cl, dequantize_row_q5_1_cl, dequantize_row_q8_0_cl;
 static cl_kernel dequantize_mul_mat_vec_q4_0_cl, dequantize_mul_mat_vec_q4_1_cl, dequantize_mul_mat_vec_q5_0_cl, dequantize_mul_mat_vec_q5_1_cl, dequantize_mul_mat_vec_q8_0_cl, convert_mul_mat_vec_f16_cl;
+static cl_kernel dequantize_block_q2_k_cl, dequantize_block_q3_k_cl, dequantize_block_q4_k_cl, dequantize_block_q5_k_cl, dequantize_block_q6_k_cl;
+static cl_kernel dequantize_mul_mat_vec_q2_K_cl, dequantize_mul_mat_vec_q3_K_cl, dequantize_mul_mat_vec_q4_K_cl, dequantize_mul_mat_vec_q5_K_cl, dequantize_mul_mat_vec_q6_K_cl;
 static cl_kernel mul_f32_cl;
 static bool fp16_support;
 
@@ -529,6 +936,12 @@ void ggml_cl_init(void) {
     CL_CHECK((dequantize_row_q5_0_cl = clCreateKernel(program, "dequantize_row_q5_0", &err), err));
     CL_CHECK((dequantize_row_q5_1_cl = clCreateKernel(program, "dequantize_row_q5_1", &err), err));
     CL_CHECK((dequantize_row_q8_0_cl = clCreateKernel(program, "dequantize_row_q8_0", &err), err));
+    CL_CHECK((dequantize_row_q8_0_cl = clCreateKernel(program, "dequantize_row_q8_0", &err), err));
+    CL_CHECK((dequantize_block_q2_k_cl = clCreateKernel(program, "dequantize_block_q2_K", &err), err));
+    CL_CHECK((dequantize_block_q3_k_cl = clCreateKernel(program, "dequantize_block_q3_K", &err), err));
+    CL_CHECK((dequantize_block_q4_k_cl = clCreateKernel(program, "dequantize_block_q4_K", &err), err));
+    CL_CHECK((dequantize_block_q5_k_cl = clCreateKernel(program, "dequantize_block_q5_K", &err), err));
+    CL_CHECK((dequantize_block_q6_k_cl = clCreateKernel(program, "dequantize_block_q6_K", &err), err));
 
     // dequant mul mat kernel
     CL_CHECK((dequantize_mul_mat_vec_q4_0_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q4_0", &err), err));
@@ -537,6 +950,11 @@ void ggml_cl_init(void) {
     CL_CHECK((dequantize_mul_mat_vec_q5_1_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q5_1", &err), err));
     CL_CHECK((dequantize_mul_mat_vec_q8_0_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q8_0", &err), err));
     CL_CHECK((convert_mul_mat_vec_f16_cl = clCreateKernel(program, "convert_mul_mat_vec_f16", &err), err));
+    CL_CHECK((dequantize_mul_mat_vec_q2_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q2_K", &err), err));
+    CL_CHECK((dequantize_mul_mat_vec_q3_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q3_K", &err), err));
+    CL_CHECK((dequantize_mul_mat_vec_q4_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q4_K", &err), err));
+    CL_CHECK((dequantize_mul_mat_vec_q5_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q5_K", &err), err));
+    CL_CHECK((dequantize_mul_mat_vec_q6_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q6_K", &err), err));
 
     // mul kernel
     CL_CHECK((mul_f32_cl = clCreateKernel(program, "mul_f32", &err), err));
@@ -554,6 +972,16 @@ static cl_kernel* ggml_get_to_fp32_cl(ggml_type type) {
             return &dequantize_row_q5_1_cl;
         case GGML_TYPE_Q8_0:
             return &dequantize_row_q8_0_cl;
+        case GGML_TYPE_Q2_K:
+            return &dequantize_block_q2_k_cl;
+        case GGML_TYPE_Q3_K:
+            return &dequantize_block_q3_k_cl;
+        case GGML_TYPE_Q4_K:
+            return &dequantize_block_q4_k_cl;
+        case GGML_TYPE_Q5_K:
+            return &dequantize_block_q5_k_cl;
+        case GGML_TYPE_Q6_K:
+            return &dequantize_block_q6_k_cl;
         case GGML_TYPE_F16:
             return &convert_row_f16_cl;
         default:
@@ -561,6 +989,50 @@ static cl_kernel* ggml_get_to_fp32_cl(ggml_type type) {
     }
 }
 
+static size_t ggml_cl_global_denom(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+            return 1;
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+            return 4;
+        case GGML_TYPE_Q4_K:
+            return 8;
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+            return 4;
+        case GGML_TYPE_F16:
+        default:
+            return 1;
+    }
+}
+
+static size_t ggml_cl_local_size(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+            return 0;
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+            return 64;
+        case GGML_TYPE_Q4_K:
+            return 32;
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+            return 64;
+        case GGML_TYPE_F16:
+        default:
+            return 0;
+    }
+}
+
 static cl_kernel* ggml_get_dequantize_mul_mat_vec_cl(ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
@@ -575,6 +1047,16 @@ static cl_kernel* ggml_get_dequantize_mul_mat_vec_cl(ggml_type type) {
             return &dequantize_mul_mat_vec_q8_0_cl;
         case GGML_TYPE_F16:
             return &convert_mul_mat_vec_f16_cl;
+        case GGML_TYPE_Q2_K:
+            return &dequantize_mul_mat_vec_q2_K_cl;
+        case GGML_TYPE_Q3_K:
+            return &dequantize_mul_mat_vec_q3_K_cl;
+        case GGML_TYPE_Q4_K:
+            return &dequantize_mul_mat_vec_q4_K_cl;
+        case GGML_TYPE_Q5_K:
+            return &dequantize_mul_mat_vec_q5_K_cl;
+        case GGML_TYPE_Q6_K:
+            return &dequantize_mul_mat_vec_q6_K_cl;
         default:
             return nullptr;
     }
@@ -1017,6 +1499,9 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
     cl_kernel* dmmv = ggml_get_dequantize_mul_mat_vec_cl(type);
     GGML_ASSERT(to_fp32_cl != nullptr);
 
+    const size_t global_denom = ggml_cl_global_denom(type);
+    const size_t local = ggml_cl_local_size(type);
+
     size_t ev_idx = 0;
     std::vector<cl_event> events;
 
@@ -1049,10 +1534,10 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
                 CL_CHECK(clEnqueueNDRangeKernel(queue, *dmmv, 1, NULL, &global, &local, events.size() - 1, events.data(), events.data() + ev_idx++));
             } else { // general dequantization kernel + CLBlast matrix matrix multiplication
                 // convert src0 to fp32 on device
-                const size_t global = x_ne;
+                const size_t global = x_ne / global_denom;
                 CL_CHECK(clSetKernelArg(*to_fp32_cl, 0, sizeof(cl_mem), &d_Q));
                 CL_CHECK(clSetKernelArg(*to_fp32_cl, 1, sizeof(cl_mem), &d_X));
-                CL_CHECK(clEnqueueNDRangeKernel(queue, *to_fp32_cl, 1, NULL, &global, NULL, events.size(), !events.empty() ? events.data() : NULL, NULL));
+                CL_CHECK(clEnqueueNDRangeKernel(queue, *to_fp32_cl, 1, NULL, &global, local > 0 ? &local : NULL, events.size(), !events.empty() ? events.data() : NULL, NULL));
 
                 // copy src1 to device
                 CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i03, i02, NULL));