]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : sync latest llama.cpp (ggml_task_type changes + GPU backends)
authorGeorgi Gerganov <redacted>
Sun, 2 Jul 2023 14:33:57 +0000 (17:33 +0300)
committerGeorgi Gerganov <redacted>
Sun, 2 Jul 2023 14:33:57 +0000 (17:33 +0300)
include/ggml/ggml.h
src/ggml-cuda.cu
src/ggml-cuda.h
src/ggml-metal.m
src/ggml-opencl.cpp
src/ggml.c

index 459913222e0833e0a124be41efd73fbbdd124a9b..11b51f8bd656a39b79317c16d8346d999389c182 100644 (file)
@@ -444,6 +444,9 @@ extern "C" {
 
 
     // compute types
+
+    // NOTE: the INIT or FINALIZE pass is not scheduled unless explicitly enabled.
+    // This behavior was changed since https://github.com/ggerganov/llama.cpp/pull/1995.
     enum ggml_task_type {
         GGML_TASK_INIT = 0,
         GGML_TASK_COMPUTE,
index c34e96abfd08c105407164f244d11d6622af4420..50df20edd7a7b211324e0af1c18e970bc250251d 100644 (file)
@@ -214,6 +214,11 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
 #endif
 
+struct ggml_tensor_extra_gpu {
+    void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
+    cudaEvent_t events[GGML_CUDA_MAX_DEVICES]; // events for synchronizing multiple GPUs
+};
+
 static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -223,6 +228,15 @@ static __global__ void add_f32(const float * x, const float * y, float * dst, co
     dst[i] = x[i] + y[i];
 }
 
+static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = __hadd(x[i], __float2half(y[i]));
+}
+
 static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -1235,7 +1249,7 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y,
 }
 
 static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
-    const half * x = (half *) vx;
+    const half * x = (const half *) vx;
 
     const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
     const int channel = blockDim.z*blockIdx.z + threadIdx.z;
@@ -1283,9 +1297,9 @@ static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, fl
 
 static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
     const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
-    const int row_stride_x, const int nchannels_x, const int channel_stride_x) {
+    const int row_stride_x, const int channel_stride_x) {
 
-    const half * x = (half *) vx;
+    const half * x = (const half *) vx;
 
     const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
     const int channel = blockDim.z*blockIdx.z + threadIdx.z;
@@ -1328,14 +1342,14 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
 }
 
 static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
-    const float * xi = (float *) cxi;
+    const float * xi = (const float *) cxi;
     float * dsti = (float *) cdsti;
 
     *dsti = *xi;
 }
 
 static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
-    const float * xi = (float *) cxi;
+    const float * xi = (const float *) cxi;
     half * dsti = (half *) cdsti;
 
     *dsti = __float2half(*xi);
@@ -1459,6 +1473,11 @@ static void add_f32_cuda(const float * x, const float * y, float * dst, const in
     add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
 }
 
+static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
+    add_f16_f32_f16<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
+}
+
 static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
     const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE;
     mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
@@ -1684,7 +1703,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_cuda(
     const dim3 block_nums(1, nrows_x, nchannels_x);
     const dim3 block_dims(WARP_SIZE, 1, 1);
     mul_mat_vec_nc_f16_f32<<<block_nums, block_dims, 0, stream>>>
-        (vx, y, dst, ncols_x, nrows_x, row_stride_x, nchannels_x, channel_stride_x);
+        (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x);
 }
 
 static void ggml_cpy_f32_f32_cuda(
@@ -1941,7 +1960,7 @@ inline void ggml_cuda_op_add(
     float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
     cudaStream_t & cudaStream_main){
 
-    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(src0_ddq_i != nullptr || src0_ddf_i != nullptr);
     GGML_ASSERT(src1_ddf_i != nullptr);
     GGML_ASSERT(dst_ddf_i != nullptr);
 
@@ -1949,8 +1968,13 @@ inline void ggml_cuda_op_add(
     const int64_t i01_diff = i01_high - i01_low;
 
     // compute
-    add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne0*i01_diff, cudaStream_main);
-    CUDA_CHECK(cudaGetLastError());
+    if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+        add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne0*i01_diff, cudaStream_main);
+    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+        add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne0*i01_diff, cudaStream_main);
+    } else {
+        GGML_ASSERT(false);
+    }
 
     (void) src1;
     (void) dst;
@@ -1982,7 +2006,6 @@ inline void ggml_cuda_op_mul(
 
         // compute
         mul_f32_cuda(src0_ddf_i01, src1_ddf_i01, dst_ddf_i01, ne00, ne10, cudaStream_main);
-        CUDA_CHECK(cudaGetLastError());
     }
 
     (void) dst;
@@ -2003,7 +2026,6 @@ inline void ggml_cuda_op_silu(
 
     // compute
     silu_f32_cuda(src0_ddf_i, dst_ddf_i, ne00*i01_diff, cudaStream_main);
-    CUDA_CHECK(cudaGetLastError());
 
     (void) src1;
     (void) dst;
@@ -2026,7 +2048,6 @@ inline void ggml_cuda_op_rms_norm(
 
     // compute
     rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
-    CUDA_CHECK(cudaGetLastError());
 
     (void) src1;
     (void) dst;
@@ -2105,7 +2126,6 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
             GGML_ASSERT(false);
             break;
     }
-    CUDA_CHECK(cudaGetLastError());
 
 #ifdef GGML_CUDA_DMMV_F16
     if (src1_convert_f16) {
@@ -2182,7 +2202,6 @@ inline void ggml_cuda_op_rope(
 
     // compute
     rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);
-    CUDA_CHECK(cudaGetLastError());
 
     (void) dst;
     (void) src0_ddq_i;
@@ -2206,7 +2225,6 @@ inline void ggml_cuda_op_diag_mask_inf(
 
     // compute
     diag_mask_inf_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_past, cudaStream_main);
-    CUDA_CHECK(cudaGetLastError());
 
     (void) dst;
     (void) src0_ddq_i;
@@ -2228,7 +2246,6 @@ inline void ggml_cuda_op_soft_max(
 
     // compute
     soft_max_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
-    CUDA_CHECK(cudaGetLastError());
 
     (void) src1;
     (void) dst;
@@ -2324,10 +2341,11 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
     size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0};
     size_t  dst_asf[GGML_CUDA_MAX_DEVICES] = {0};
 
-    // if multiple GPUs are used they need to wait for the main GPU to finish
+    // if multiple devices are used they need to wait for the main device
+    // here an event is recorded that signifies that the main device has finished calculating the input data
     if (split && g_device_count > 1) {
         CUDA_CHECK(cudaSetDevice(g_main_device));
-        CUDA_CHECK(cudaDeviceSynchronize());
+        CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device], g_cudaStreams_main[g_main_device]));
     }
 
     for (int id = 0; id < g_device_count; ++id) {
@@ -2353,6 +2371,12 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
         int64_t row_diff = row_high - row_low;
 
         cudaSetDevice(id);
+        cudaStream_t cudaStream_main = g_cudaStreams_main[id];
+
+        // wait for main GPU data if necessary
+        if (split && id != g_main_device) {
+            CUDA_CHECK(cudaStreamWaitEvent(cudaStream_main, src0_extra->events[g_main_device]));
+        }
 
         if (src0_on_device && src0_is_contiguous) {
             if (src0_is_f32) {
@@ -2428,8 +2452,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                 }
                 const int64_t i11 = i13*ne12 + i12;
 
-                cudaStream_t cudaStream_main = g_cudaStreams_main[id];
-
                 // for split tensors the data begins at i0 == i0_offset_low
                 char  * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs;
                 float * src0_ddf_i = src0_ddf[id] + (i0 - i0_offset_low)*src0_stride;
@@ -2489,6 +2511,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
 
                 // do the computation
                 op(src0, src1, dst, src0_ddq_i, src0_ddf_i, src1_ddf_i, dst_ddf_i, i02, i01_low, i01_high, i11, cudaStream_main);
+                CUDA_CHECK(cudaGetLastError());
 
                 // copy dst to host or other device if necessary
                 if (!dst_on_device) {
@@ -2518,6 +2541,11 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                         CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i, dst_stride*sizeof(float), kind, cudaStream_main));
                     }
                 }
+
+                // signify to main device that other device is done
+                if (split && g_device_count > 1 && id != g_main_device) {
+                    CUDA_CHECK(cudaEventRecord(src0_extra->events[id], cudaStream_main));
+                }
             }
         }
     }
@@ -2529,7 +2557,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
         }
 
         CUDA_CHECK(cudaSetDevice(id));
-        CUDA_CHECK(cudaDeviceSynchronize());
 
         if (src0_asq[id] > 0) {
             ggml_cuda_pool_free(src0_ddq[id], src0_asq[id]);
@@ -2544,11 +2571,32 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
             ggml_cuda_pool_free(dst_ddf[id], dst_asf[id]);
         }
     }
+
+    // main device waits for all other devices to be finished
+    if (split && g_device_count > 1) {
+        CUDA_CHECK(cudaSetDevice(g_main_device));
+        for (int id = 0; id < g_device_count; ++id) {
+            if (id != g_main_device) {
+                CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams_main[g_main_device], src0_extra->events[id]));
+            }
+        }
+    }
+
+    if (dst->backend == GGML_BACKEND_CPU) {
+        CUDA_CHECK(cudaSetDevice(g_main_device));
+        CUDA_CHECK(cudaDeviceSynchronize());
+    }
 }
 
 void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, true, true);
+    // ggml_cuda_add permits f16 dst even though this could in theory cause problems with the pointer arithmetic in ggml_cuda_op.
+    // Due to flatten_rows == true this does in practice not make a difference however.
+    // Better solution would be nice but right now that would require disproportionate changes.
+    GGML_ASSERT(
+        (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) &&
+        src1->type == GGML_TYPE_F32 &&
+        (dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16));
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, false, true);
 }
 
 void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -2777,6 +2825,10 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
         cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);
 
         extra->data_device[id] = buf;
+
+        if (backend == GGML_BACKEND_GPU_SPLIT) {
+            CUDA_CHECK(cudaEventCreateWithFlags(&extra->events[id], cudaEventDisableTiming));
+        }
     }
 
     tensor->extra = extra;
@@ -2790,18 +2842,21 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) {
     ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
 
     for (int id = 0; id < g_device_count; ++id) {
-        if (extra->data_device[id] == nullptr) {
-            continue;
+        if (extra->data_device[id] != nullptr) {
+            CUDA_CHECK(cudaSetDevice(id));
+            CUDA_CHECK(cudaFree(extra->data_device[id]));
         }
 
-        CUDA_CHECK(cudaSetDevice(id));
-        CUDA_CHECK(cudaFree(extra->data_device[id]));
+        if (extra->events[id] != nullptr) {
+            CUDA_CHECK(cudaSetDevice(id));
+            CUDA_CHECK(cudaEventDestroy(extra->events[id]));
+        }
     }
 
     delete extra;
 }
 
-void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
+void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace) {
     if (scratch && g_scratch_size == 0) {
         return;
     }
@@ -2810,11 +2865,11 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
     if (tensor->src0 != nullptr && tensor->src0->backend == GGML_BACKEND_CPU) {
         const ggml_op src0_op = tensor->src0->op;
         if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW) {
-            ggml_cuda_assign_buffers_impl(tensor->src0, scratch);
+            ggml_cuda_assign_buffers_impl(tensor->src0, scratch, force_inplace);
         }
     }
     if (tensor->op == GGML_OP_CPY && tensor->src1->backend == GGML_BACKEND_CPU) {
-        ggml_cuda_assign_buffers_impl(tensor->src1, scratch);
+        ggml_cuda_assign_buffers_impl(tensor->src1, scratch, force_inplace);
     }
 
     tensor->backend = GGML_BACKEND_GPU;
@@ -2822,11 +2877,12 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
     memset(extra, 0, sizeof(*extra));
 
     const bool inplace = (tensor->src0 != nullptr && tensor->src0->data == tensor->data) ||
-        tensor->op == GGML_OP_VIEW;
+        tensor->op == GGML_OP_VIEW ||
+        force_inplace;
     const size_t size = ggml_nbytes(tensor);
 
     CUDA_CHECK(cudaSetDevice(g_main_device));
-    if (inplace && tensor->src0->backend == GGML_BACKEND_GPU) {
+    if (inplace && (tensor->src0->backend == GGML_BACKEND_GPU || tensor->src0->backend == GGML_BACKEND_GPU_SPLIT)) {
         struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src0->extra;
         char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
         size_t offset = 0;
@@ -2865,11 +2921,15 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
 }
 
 void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, true);
+    ggml_cuda_assign_buffers_impl(tensor, true, false);
 }
 
 void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, false);
+    ggml_cuda_assign_buffers_impl(tensor, false, false);
+}
+
+void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor) {
+    ggml_cuda_assign_buffers_impl(tensor, false, true);
 }
 
 void ggml_cuda_set_main_device(int main_device) {
index d32b4484267ab5a3f9028db07330302432c1fd93..3c1e8deb6a6ddbfbb43e22bd01cc64e560fe313d 100644 (file)
@@ -8,10 +8,6 @@ extern "C" {
 
 #define GGML_CUDA_MAX_DEVICES       16
 
-struct ggml_tensor_extra_gpu {
-    void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
-};
-
 void   ggml_init_cublas(void);
 void   ggml_cuda_set_tensor_split(const float * tensor_split);
 
@@ -29,6 +25,7 @@ void   ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
 void   ggml_cuda_free_data(struct ggml_tensor * tensor);
 void   ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
 void   ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
+void   ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
 void   ggml_cuda_set_main_device(int main_device);
 void   ggml_cuda_set_scratch_size(size_t scratch_size);
 void   ggml_cuda_free_scratch(void);
index 7551231b9cf32cf2705de66fce9a15f339943370..fd69c41fe357d6e1636f2640186c7908352450b9 100644 (file)
@@ -202,7 +202,9 @@ struct ggml_metal_context * ggml_metal_init(void) {
 
 void ggml_metal_free(struct ggml_metal_context * ctx) {
     fprintf(stderr, "%s: deallocating\n", __func__);
-
+    for (int i = 0; i < ctx->n_buffers; ++i) {
+        [ctx->buffers[i].metal release];
+    }
     free(ctx);
 }
 
index 95f4cec6dd59cdc7b7ece51549420e355dfe3607..fed4ffb0ccd0538b56aa552023581974b3fe2332 100644 (file)
 
 #define CL_DMMV_BLOCK_SIZE 32
 
+#ifndef K_QUANTS_PER_ITERATION
+#define K_QUANTS_PER_ITERATION 1
+#else
+static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
+#endif
+
 #define MULTILINE_QUOTE(...) #__VA_ARGS__
 static std::string program_source = MULTILINE_QUOTE(
 
 typedef char int8_t;
 typedef uchar uint8_t;
+typedef short int16_t;
+typedef ushort uint16_t;
 typedef int int32_t;
 typedef uint uint32_t;
 
@@ -175,7 +183,9 @@ void convert_f16(__global half* x, const int ib, const int iqs, float* v0, float
     *v0 = vload_half(0, &x[ib + 0]);
     *v1 = vload_half(0, &x[ib + 1]);
 }
+);
 
+static std::string k_quants_source = MULTILINE_QUOTE(
 inline void get_scale_min_k4(int j, const __global uint8_t *q, uint8_t *d, uint8_t *m)
 {
     if (j < 4)
@@ -199,7 +209,7 @@ __kernel void dequantize_block_q2_K(__global const struct block_q2_K *x, __globa
     const int is = 8 * n + l / 16;
 
     const uint8_t q = x[i].qs[32 * n + l];
-    __global float *y = yy + i * 256 + 128 * n;
+    __global float *y = yy + i * QK_K + 128 * n;
 
     const float dall = vload_half(0, &x[i].d);
     const float dmin = vload_half(0, &x[i].dmin);
@@ -231,7 +241,7 @@ __kernel void dequantize_block_q3_K(__global const struct block_q3_K *x, __globa
     float d_all = vload_half(0, &x[i].d);
     float dl = d_all * (us - 32);
 
-    __global float *y = yy + i * 256 + 128 * n + 32 * j;
+    __global float *y = yy + i * QK_K + 128 * n + 32 * j;
     const __global uint8_t *q = x[i].qs + 32 * n;
     const __global uint8_t *hm = x[i].hmask;
 
@@ -248,7 +258,7 @@ __kernel void dequantize_block_q4_K(__global const struct block_q4_K *x, __globa
     const int is = 2 * il;
     const int n = 4;
 
-    __global float *y = yy + i * 256 + 64 * il + n * ir;
+    __global float *y = yy + i * QK_K + 64 * il + n * ir;
 
     const float dall = vload_half(0, &x[i].d);
     const float dmin = vload_half(0, &x[i].dmin);
@@ -277,7 +287,7 @@ __kernel void dequantize_block_q5_K(__global const struct block_q5_K *x, __globa
     const int ir = tid % 16;
     const int is = 2 * il;
 
-    __global float *y = yy + i * 256 + 64 * il + 2 * ir;
+    __global float *y = yy + i * QK_K + 64 * il + 2 * ir;
 
     const float dall = vload_half(0, &x[i].d);
     const float dmin = vload_half(0, &x[i].dmin);
@@ -309,7 +319,7 @@ __kernel void dequantize_block_q6_K(__global const struct block_q6_K *x, __globa
     const int il = tid - 32 * ip;
     const int is = 8 * ip + il / 16;
 
-    __global float *y = yy + i * 256 + 128 * ip + il;
+    __global float *y = yy + i * QK_K + 128 * ip + il;
 
     const float d = vload_half(0, &x[i].d);
 
@@ -323,161 +333,383 @@ __kernel void dequantize_block_q6_K(__global const struct block_q6_K *x, __globa
     y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
 }
 
+__kernel void dequantize_mul_mat_vec_q2_K(__global const struct block_q2_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) {
 
-void vec_dot_q2_K(__global const struct block_q2_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
+    const int row = get_group_id(0);
 
-    int n = iqs / 128;
-    int r = iqs - 128 * n;
-    int l = r / 8;
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
 
-    __global const float *y = yy + 128 * n + l;
-    __global const uint8_t *q = x[ib].qs + 32 * n + l;
-    __global const uint8_t *s = x[ib].scales + 8 * n;
+    __global const struct block_q2_K * x = xx + ib0;
 
-    const float dall = vload_half(0, &x[ib].d);
-    const float dmin = vload_half(0, &x[ib].dmin);
+    const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
+    const int ix  = get_local_id(0)%K_QUANTS_PER_ITERATION;  // 0 or 0,1
 
-    float sum = y[  0] * (dall * ((s[0] & 0xF) * ((q[ 0] >> 0) & 3)) - dmin * (s[0] >> 4))
-              + y[ 32] * (dall * ((s[2] & 0xF) * ((q[ 0] >> 2) & 3)) - dmin * (s[2] >> 4))
-              + y[ 64] * (dall * ((s[4] & 0xF) * ((q[ 0] >> 4) & 3)) - dmin * (s[4] >> 4))
-              + y[ 96] * (dall * ((s[6] & 0xF) * ((q[ 0] >> 6) & 3)) - dmin * (s[6] >> 4))
-              + y[ 16] * (dall * ((s[1] & 0xF) * ((q[16] >> 0) & 3)) - dmin * (s[1] >> 4))
-              + y[ 48] * (dall * ((s[3] & 0xF) * ((q[16] >> 2) & 3)) - dmin * (s[3] >> 4))
-              + y[ 80] * (dall * ((s[5] & 0xF) * ((q[16] >> 4) & 3)) - dmin * (s[5] >> 4))
-              + y[112] * (dall * ((s[7] & 0xF) * ((q[16] >> 6) & 3)) - dmin * (s[7] >> 4));
+    const int step = 16/K_QUANTS_PER_ITERATION;
 
-    *result = sum;
-}
+    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in = tid - step*im;                        // 0...15 or 0...7
 
-void vec_dot_q3_K(__global const struct block_q3_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
+    const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15 or 0...14 in steps of 2
+    const int q_offset = 32*im + l0;
+    const int s_offset = 8*im;
+    const int y_offset = 128*im + l0;
 
-    const uint32_t kmask1 = 0x03030303;
-    const uint32_t kmask2 = 0x0f0f0f0f;
+    tmp[16 * ix + tid] = 0;
 
-    uint32_t aux[3];
-    uint32_t utmp[4];
+    uint32_t aux[4];
+    const uint8_t * d = (const uint8_t *)aux;
+    const uint8_t * m = (const uint8_t *)(aux + 2);
 
-    int n = iqs/128;
-    int r = iqs - 128*n;
-    int l = r/8;
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
 
-    __global const float   * y = yy + 128*n + l;
-    __global const uint8_t * q = x[ib].qs + 32*n + l;
-    __global const uint8_t * hm = x[ib].hmask + l;
-    const int8_t * s = (const int8_t *)utmp + 8*n;
+        __global const float   * y = yy + i * QK_K + y_offset;
+        __global const uint8_t * q = x[i].qs + q_offset;
 
-    aux[0] = x[ib].scales[0] | x[ib].scales[1] << 8 | x[ib].scales[2] << 16 | x[ib].scales[3] << 24;
-    aux[1] = x[ib].scales[4] | x[ib].scales[5] << 8 | x[ib].scales[6] << 16 | x[ib].scales[7] << 24;
-    aux[2] = x[ib].scales[8] | x[ib].scales[9] << 8 | x[ib].scales[10] << 16 | x[ib].scales[11] << 24;
+        const float dall = vload_half(0, &x[i].d);
+        const float dmin = vload_half(0, &x[i].dmin);
 
-    utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
-    utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
-    utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
-    utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
+        __global const uint32_t * a = (__global const uint32_t *)(x[i].scales + s_offset);
+        aux[0] = a[0] & 0x0f0f0f0f;
+        aux[1] = a[1] & 0x0f0f0f0f;
+        aux[2] = (a[0] >> 4) & 0x0f0f0f0f;
+        aux[3] = (a[1] >> 4) & 0x0f0f0f0f;
 
-    const float dall = vload_half(0, &x[ib].d);
-    const uint8_t m = 1 << (4*n);
+        float sum1 = 0, sum2 = 0;
+        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+            sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3)
+                  + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3)
+                  + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3)
+                  + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3)
+                  + y[l+16] * d[1] * ((q[l+16] >> 0) & 3)
+                  + y[l+48] * d[3] * ((q[l+16] >> 2) & 3)
+                  + y[l+80] * d[5] * ((q[l+16] >> 4) & 3)
+                  +y[l+112] * d[7] * ((q[l+16] >> 6) & 3);
+            sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6]
+                  + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7];
 
-    float sum = y[  0] * (s[0] - 32) * (((q[ 0] >> 0) & 3) - (hm[ 0] & (m << 0) ? 0 : 4))
-              + y[ 32] * (s[2] - 32) * (((q[ 0] >> 2) & 3) - (hm[ 0] & (m << 1) ? 0 : 4))
-              + y[ 64] * (s[4] - 32) * (((q[ 0] >> 4) & 3) - (hm[ 0] & (m << 2) ? 0 : 4))
-              + y[ 96] * (s[6] - 32) * (((q[ 0] >> 6) & 3) - (hm[ 0] & (m << 3) ? 0 : 4))
-              + y[ 16] * (s[1] - 32) * (((q[16] >> 0) & 3) - (hm[16] & (m << 0) ? 0 : 4))
-              + y[ 48] * (s[3] - 32) * (((q[16] >> 2) & 3) - (hm[16] & (m << 1) ? 0 : 4))
-              + y[ 80] * (s[5] - 32) * (((q[16] >> 4) & 3) - (hm[16] & (m << 2) ? 0 : 4))
-              + y[112] * (s[7] - 32) * (((q[16] >> 6) & 3) - (hm[16] & (m << 3) ? 0 : 4));
+        }
+        tmp[16 * ix + tid] += dall * sum1 - dmin * sum2;
 
-    *result = sum * dall;
+    }
 
+    // sum up partial sums and write back result
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for (int s=16; s>0; s>>=1) {
+        if (tid < s) {
+            tmp[tid] += tmp[tid + s];
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    if (tid == 0) {
+        dst[row] = tmp[0];
+    }
 }
 
-void vec_dot_q4_K(__global const struct block_q4_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
+__kernel void dequantize_mul_mat_vec_q3_K(__global const struct block_q3_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) {
+    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask2 = 0x0f0f;
+
+    const int row = get_group_id(0);
+
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
 
-    const int j  = iqs / 64;        // j  is in 0...3
-    const int ir = (iqs - 64*j)/2;  // ir is in 0...28 in steps of 4
-    const int is = 2*j;             // is is in 0...6 in steps of 2
+    __global const struct block_q3_K * x = xx + ib0;
 
-    __global const float   * y = yy + 64*j + ir;
-    __global const uint8_t * q = x[ib].qs + 32*j + ir;
+    const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
+    const int ix  = get_local_id(0)%K_QUANTS_PER_ITERATION;  // 0 or 0,1
 
-    const float dall = vload_half(0, &x[ib].d);
-    const float dmin = vload_half(0, &x[ib].dmin);
+    const int n  = K_QUANTS_PER_ITERATION;               // iterations in the inner loop
+    const int step = 16/K_QUANTS_PER_ITERATION;
+    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in = tid - step*im;                        // 0....15 or 0...7
 
-    uint8_t sc, m;
-    get_scale_min_k4(is + 0, x[ib].scales, &sc, &m);
-    const float d1 = dall * sc;
-    const float m1 = dmin * m;
-    get_scale_min_k4(is + 1, x[ib].scales, &sc, &m);
-    const float d2 = dall * sc;
-    const float m2 = dmin * m;
+    const uint8_t m = 1 << (4*im);
+
+    const int l0 = n*in;                                 // 0...15 or 0...14 in steps of 2
+    const int q_offset =  32*im + l0;
+    const int y_offset = 128*im + l0;
+
+    uint16_t utmp[4];
+    const int8_t * s = (const int8_t *)utmp;
+
+    const uint16_t s_shift = 4*im;
+
+    tmp[16 * ix + tid] = 0;
+
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+        __global const float   * y  = yy + i * QK_K + y_offset;
+        __global const uint8_t * q = x[i].qs + q_offset;
+        __global const uint8_t * h = x[i].hmask + l0;
+
+        __global const uint16_t * a = (__global const uint16_t *)x[i].scales;
+        utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4);
+        utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4);
+        utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4);
+        utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4);
+
+        const float d = vload_half(0, &x[i].d);
+
+        float sum = 0;
+        for (int l = 0; l < n; ++l) {
+            sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4))
+                 + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4))
+                 + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4))
+                 + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4));
+            sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4))
+                 + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4))
+                 + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4))
+                + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4));
+        }
+        tmp[16 * ix + tid] += d * sum;
 
-    float sum = 0;
-    for (int k = 0; k < 4; ++k) {
-        sum += y[k +  0] * (d1 * (q[k] & 0xF) - m1);
-        sum += y[k + 32] * (d2 * (q[k] >>  4) - m2);
     }
 
-    *result = sum;
+    // sum up partial sums and write back result
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for (int s=16; s>0; s>>=1) {
+        if (tid < s) {
+            tmp[tid] += tmp[tid + s];
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    if (tid == 0) {
+        dst[row] = tmp[0];
+    }
 }
 
-void vec_dot_q5_K(__global const struct block_q5_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
+__kernel void dequantize_mul_mat_vec_q4_K(__global const struct block_q4_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) {
 
-    const int j  = iqs / 64;
-    const int ir = (iqs - 64*j)/2;
-    const int is = 2*j;
+    //to rename it later, just to test now
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
 
-    __global const float   * y  = yy + 64*j + ir;
-    __global const uint8_t * ql = x[ib].qs + 32*j + ir;
-    __global const uint8_t * qh = x[ib].qh + ir;
+    const int row = get_group_id(0);
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
 
-    const float dall = vload_half(0, &x[ib].d);
-    const float dmin = vload_half(0, &x[ib].dmin);
+    const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION;  // 0...15
+    const int ix  = get_local_id(0)%K_QUANTS_PER_ITERATION;
 
-    uint8_t sc, m;
-    get_scale_min_k4(is + 0, x[ib].scales, &sc, &m);
-    const float d1 = dall * sc;
-    const float m1 = dmin * m;
-    get_scale_min_k4(is + 1, x[ib].scales, &sc, &m);
-    const float d2 = dall * sc;
-    const float m2 = dmin * m;
+    const int step = 8/K_QUANTS_PER_ITERATION;
+
+    const int il  = tid/step;     // 0...3
+    const int ir  = tid - step*il;// 0...3
+    const int n   = 2*K_QUANTS_PER_ITERATION;
+
+    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+    const int in = il%2;
+
+    const int l0 = n*(2*ir + in);
+    const int q_offset = 32*im + l0;
+    const int y_offset = 64*im + l0;
+
+    uint16_t aux[4];
+    const uint8_t * sc = (const uint8_t *)aux;
+
+    __global const struct block_q4_K * x = xx + ib0;
+
+    tmp[16 * ix + tid] = 0;
+
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+        __global const uint8_t * q1 = x[i].qs + q_offset;
+        __global const uint8_t * q2 = q1 + 64;
+        __global const float   * y1 = yy + i*QK_K + y_offset;
+        __global const float   * y2 = y1 + 128;
+
+        const float dall = vload_half(0, &x[i].d);
+        const float dmin = vload_half(0, &x[i].dmin);
+
+        __global const uint16_t * a = (__global const uint16_t *)x[i].scales;
+        aux[0] = a[im+0] & kmask1;
+        aux[1] = a[im+2] & kmask1;
+        aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
+        aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
+
+        float4 s = (float4)(0.f);
+        float smin = 0;
+        for (int l = 0; l < n; ++l) {
+            s.x += y1[l] * (q1[l] & 0xF); s.y += y1[l+32] * (q1[l] >> 4);
+            s.z += y2[l] * (q2[l] & 0xF); s.w += y2[l+32] * (q2[l] >> 4);
+            smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
+        }
+        tmp[16 * ix + tid] += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin;
 
-    uint8_t hm  = 1 << is;
-    float sum = 0;
-    for (int k = 0; k < 4; ++k) {
-        sum += y[k +  0] * (d1 * ((ql[k] & 0xF) + (qh[k] & hm ? 16 : 0)) - m1);
     }
-    hm <<= 1;
-    for (int k = 0; k < 4; ++k) {
-        sum += y[k + 32] * (d2 * ((ql[k] >>  4) + (qh[k] & hm ? 16 : 0)) - m2);
+
+    // sum up partial sums and write back result
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for (int s=16; s>0; s>>=1) {
+        if (tid < s) {
+            tmp[tid] += tmp[tid + s];
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    if (tid == 0) {
+        dst[row] = tmp[0];
+    }
+}
+
+__kernel void dequantize_mul_mat_vec_q5_K(__global const struct block_q5_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) {
+
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
+    const int row = get_group_id(0);
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const int tid = get_local_id(0)/2;  // 0...15
+    const int ix  = get_local_id(0)%2;
+
+    const int il  = tid/4;     // 0...3
+    const int ir  = tid - 4*il;// 0...3
+    const int n   = 2;
+
+    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+    const int in = il%2;
+
+    const int l0 = n*(2*ir + in);
+    const int q_offset = 32*im + l0;
+    const int y_offset = 64*im + l0;
+
+    const uint8_t hm1  = 1 << (2*im);
+    const uint8_t hm2  = hm1 << 4;
+
+    uint16_t aux[4];
+    const uint8_t * sc = (const uint8_t *)aux;
+
+    __global const struct block_q5_K * x = xx + ib0;
+
+    tmp[16 * ix + tid] = 0;
+
+    for (int i = ix; i < num_blocks_per_row; i += 2) {
+
+        __global const uint8_t * ql1 = x[i].qs + q_offset;
+        __global const uint8_t * ql2 = ql1 + 64;
+        __global const uint8_t * qh  = x[i].qh + l0;
+        __global const float   * y1  = yy + i*QK_K + y_offset;
+        __global const float   * y2  = y1 + 128;
+
+        const float dall = vload_half(0, &x[i].d);
+        const float dmin = vload_half(0, &x[i].dmin);
+
+        __global const uint16_t * a = (__global const uint16_t *)x[i].scales;
+        aux[0] = a[im+0] & kmask1;
+        aux[1] = a[im+2] & kmask1;
+        aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
+        aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
+
+        float4 sum = (float4)(0.f);
+        float smin = 0;
+        for (int l = 0; l < n; ++l) {
+            sum.x += y1[l+ 0] * ((ql1[l+ 0] & 0xF) + (qh[l+ 0] & (hm1 << 0) ? 16 : 0))
+                   + y1[l+16] * ((ql1[l+16] & 0xF) + (qh[l+16] & (hm1 << 0) ? 16 : 0));
+            sum.y += y1[l+32] * ((ql1[l+ 0] >>  4) + (qh[l+ 0] & (hm1 << 1) ? 16 : 0))
+                   + y1[l+48] * ((ql1[l+16] >>  4) + (qh[l+16] & (hm1 << 1) ? 16 : 0));
+            sum.z += y2[l+ 0] * ((ql2[l+ 0] & 0xF) + (qh[l+ 0] & (hm2 << 0) ? 16 : 0))
+                   + y2[l+16] * ((ql2[l+16] & 0xF) + (qh[l+16] & (hm2 << 0) ? 16 : 0));
+            sum.w += y2[l+32] * ((ql2[l+ 0] >>  4) + (qh[l+ 0] & (hm2 << 1) ? 16 : 0))
+                   + y2[l+48] * ((ql2[l+16] >>  4) + (qh[l+16] & (hm2 << 1) ? 16 : 0));
+            smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3]
+                  + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
+        }
+        tmp[16 * ix + tid] += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;
+
     }
-    *result = sum;
 
+    // sum up partial sums and write back result
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for (int s=16; s>0; s>>=1) {
+        if (tid < s) {
+            tmp[tid] += tmp[tid + s];
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    if (tid == 0) {
+        dst[row] = tmp[0];
+    }
 }
 
-void vec_dot_q6_K(__global const struct block_q6_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
+__kernel void dequantize_mul_mat_vec_q6_K(__global const struct block_q6_K * xx, __local float* tmp, __global const float * yy, __global float * dst, const int ncols) {
+
+    const int row = get_group_id(0);
 
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
 
-    const int ip = iqs / 128;        // 0 or 1
-    const int il = (iqs - 128*ip)/8; // 0...15
-    const int is = 8*ip;
+    __global const struct block_q6_K * x = xx + ib0;
 
-    __global const float * y = yy + 128*ip + il;
+    const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
+    const int ix  = get_local_id(0)%K_QUANTS_PER_ITERATION;  // 0 or 0, 1
 
-    const float d = vload_half(0, &x[ib].d);
+    const int step = 16/K_QUANTS_PER_ITERATION;          // 16 or 8
 
-    __global const uint8_t * ql = x[ib].ql + 64*ip + il;
-    __global const uint8_t * qh = x[ib].qh + 32*ip + il;
-    __global const int8_t  * sc = x[ib].scales + is;
+    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in = tid - step*im;                        // 0...15 or 0...7
+
+#if K_QUANTS_PER_ITERATION == 1
+    const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15
+    const int is = 0;
+#else
+    const int l0 = 4 * in;                               // 0, 4, 8, ..., 28
+    const int is = in / 4;
+#endif
+    const int ql_offset = 64*im + l0;
+    const int qh_offset = 32*im + l0;
+    const int s_offset  =  8*im + is;
+    const int y_offset = 128*im + l0;
+
+    tmp[16 * ix + tid] = 0; // partial sum for thread in warp
+
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+        __global const float   * y  = yy + i * QK_K + y_offset;
+        __global const uint8_t * ql = x[i].ql + ql_offset;
+        __global const uint8_t * qh = x[i].qh + qh_offset;
+        __global const int8_t  * s  = x[i].scales + s_offset;
+
+        const float d = vload_half(0, &x[i].d);
+
+#if K_QUANTS_PER_ITERATION == 1
+        float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
+                  + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
+                  + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
+                  + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
+                  + y[64] * s[4] * d * ((int8_t)((ql[ 0]  >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
+                  + y[80] * s[5] * d * ((int8_t)((ql[16]  >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
+                  + y[96] * s[6] * d * ((int8_t)((ql[32]  >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
+                  +y[112] * s[7] * d * ((int8_t)((ql[48]  >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
+        tmp[16 * ix + tid] += sum;
+#else
+        float sum = 0;
+        for (int l = 0; l < 4; ++l) {
+            sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
+                 + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32)
+                 + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32)
+                 + y[l+96] * s[6] * d * ((int8_t)((ql[l+32]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
+        }
+        tmp[16 * ix + tid] += sum;
+#endif
 
-    *result = y[  0] * d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh[ 0] >> 0) & 3) << 4)) - 32)
-           + y[ 32] * d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh[ 0] >> 2) & 3) << 4)) - 32)
-           + y[ 64] * d * sc[4] * ((int8_t)((ql[ 0]  >> 4) | (((qh[ 0] >> 4) & 3) << 4)) - 32)
-           + y[ 96] * d * sc[6] * ((int8_t)((ql[32]  >> 4) | (((qh[ 0] >> 6) & 3) << 4)) - 32)
-           + y[ 16] * d * sc[1] * ((int8_t)((ql[16] & 0xF) | (((qh[16] >> 0) & 3) << 4)) - 32)
-           + y[ 48] * d * sc[3] * ((int8_t)((ql[48] & 0xF) | (((qh[16] >> 2) & 3) << 4)) - 32)
-           + y[ 80] * d * sc[5] * ((int8_t)((ql[16]  >> 4) | (((qh[16] >> 4) & 3) << 4)) - 32)
-           + y[112] * d * sc[7] * ((int8_t)((ql[48]  >> 4) | (((qh[16] >> 6) & 3) << 4)) - 32);
+    }
 
+    // sum up partial sums and write back result
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for (int s=16; s>0; s>>=1) {
+        if (tid < s) {
+            tmp[tid] += tmp[tid + s];
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    if (tid == 0) {
+        dst[row] = tmp[0];
+    }
 }
 
 );
@@ -549,44 +781,6 @@ __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float
 }
 );
 
-std::string dequant_mul_mat_vec_k_template = MULTILINE_QUOTE(
-__kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float* y, __global float* dst, const int ncols) {
-    const int block_size = get_local_size(0);
-    const int row = get_group_id(0);
-    const int tid = get_local_id(0);
-
-    const int iter_stride = 256;
-    const int vals_per_iter = iter_stride / block_size;
-    const int num_blocks_per_row = ncols / 256;
-    const int ib0 = row*num_blocks_per_row;
-
-    tmp[tid] = 0;
-
-    for (int i = 0; i < ncols; i += iter_stride) {
-        const int col = i + vals_per_iter*tid;
-        const int ib = ib0 + col/256; // x block index
-        const int iqs = col%256; // x quant index
-        const int iybs = col - col%256; // y block start index
-
-        // dequantize
-        float v;
-        DOT_KERNEL(x, ib, iqs, y + iybs, &v);
-        tmp[tid] += v;
-    }
-
-    // sum up partial sums and write back result
-    barrier(CLK_LOCAL_MEM_FENCE);
-    for (int s=block_size/2; s>0; s>>=1) {
-        if (tid < s) {
-            tmp[tid] += tmp[tid + s];
-        }
-        barrier(CLK_LOCAL_MEM_FENCE);
-    }
-    if (tid == 0) {
-        dst[row] = tmp[0];
-    }
-}
-);
 
 std::string mul_template = MULTILINE_QUOTE(
 __kernel void KERNEL_NAME(__global TYPE* x, const int x_offset, __global TYPE* y, const int y_offset, __global TYPE* dst, const int dst_offset, const int ky) {
@@ -649,18 +843,6 @@ std::array<std::string, 2> mul_str_values = {
     "mul_f32", "float"
 };
 
-std::array<std::string, 3> dmmv_k_str_keys = {
-    "KERNEL_NAME", "X_TYPE", "DOT_KERNEL"
-};
-
-std::array<std::string, 15> dmmv_k_str_values = {
-    "dequantize_mul_mat_vec_q2_K", "struct block_q2_K", "vec_dot_q2_K",
-    "dequantize_mul_mat_vec_q3_K", "struct block_q3_K", "vec_dot_q3_K",
-    "dequantize_mul_mat_vec_q4_K", "struct block_q4_K", "vec_dot_q4_K",
-    "dequantize_mul_mat_vec_q5_K", "struct block_q5_K", "vec_dot_q5_K",
-    "dequantize_mul_mat_vec_q6_K", "struct block_q6_K", "vec_dot_q6_K",
-};
-
 std::string& replace(std::string& s, const std::string& from, const std::string& to) {
     size_t pos = 0;
     while ((pos = s.find(from, pos)) != std::string::npos) {
@@ -673,6 +855,7 @@ std::string& replace(std::string& s, const std::string& from, const std::string&
 std::string generate_kernels() {
     std::stringstream src;
     src << program_source << '\n';
+    src << k_quants_source << '\n';
     for (size_t i = 0; i < dequant_str_values.size(); i += dequant_str_keys.size()) {
         std::string dequant_kernel = dequant_template;
         std::string dmmv_kernel = dequant_mul_mat_vec_template;
@@ -690,13 +873,6 @@ std::string generate_kernels() {
         }
         src << mul_kernel << '\n';
     }
-    for (size_t i = 0; i < dmmv_k_str_values.size(); i += dmmv_k_str_keys.size()) {
-        std::string dmmv_k_kernel = dequant_mul_mat_vec_k_template;
-        for (size_t j = 0; j < dmmv_k_str_keys.size(); j++) {
-            replace(dmmv_k_kernel, dmmv_k_str_keys[j], dmmv_k_str_values[i + j]);
-        }
-        src << dmmv_k_kernel << '\n';
-    }
 
     return src.str();
 }
@@ -729,10 +905,11 @@ static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, co
         exit(1);
     }
 
-    const char* compile_opts = "-cl-mad-enable -cl-unsafe-math-optimizations -cl-finite-math-only -cl-fast-relaxed-math "
-                               "-DQK4_0=32 -DQR4_0=2 -DQK4_1=32 -DQR4_1=2 -DQK5_0=32 -DQR5_0=2 -DQK5_1=32 -DQR5_1=2 -DQK8_0=32 -DQR8_0=1";
+    std::string compile_opts = "-cl-mad-enable -cl-unsafe-math-optimizations -cl-finite-math-only -cl-fast-relaxed-math "
+                               "-DQK4_0=32 -DQR4_0=2 -DQK4_1=32 -DQR4_1=2 -DQK5_0=32 -DQR5_0=2 -DQK5_1=32 -DQR5_1=2 -DQK8_0=32 -DQR8_0=1 "
+                               "-DQK_K=256 -DK_QUANTS_PER_ITERATION=" + std::to_string(K_QUANTS_PER_ITERATION);
 
-    err = clBuildProgram(p, 0, NULL, compile_opts, NULL, NULL);
+    err = clBuildProgram(p, 0, NULL, compile_opts.c_str(), NULL, NULL);
     if(err < 0) {
 
         clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, 0, NULL, &log_size);
index 92faf03f746a1c4203f6ab98725ad3660b3e2298..afeb72ff00ccfc3872849823d6c62e2f43d40be9 100644 (file)
@@ -3846,6 +3846,41 @@ static_assert(GGML_OP_COUNT == 64, "GGML_OP_COUNT != 64");
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
 
+// WARN:
+// Mis-confguration can lead to problem that's hard to reason about:
+// * At best  it crash or talks nosense.
+// * At worst it talks slightly difference but hard to perceive.
+//
+// An op has to enable INIT or FINALIZE when any of it's branch needs that pass.
+// Take care about compile options (e.g., GGML_USE_xxx).
+static bool GGML_OP_HAS_INIT    [GGML_OP_COUNT] = { 0 };
+static bool GGML_OP_HAS_FINALIZE[GGML_OP_COUNT] = { 0 };
+
+static void ggml_setup_op_has_task_pass(void) {
+    {   // INIT
+        bool * p = GGML_OP_HAS_INIT;
+
+        p[GGML_OP_ACC                    ] = true;
+        p[GGML_OP_MUL_MAT                ] = true;
+        p[GGML_OP_OUT_PROD               ] = true;
+        p[GGML_OP_SET                    ] = true;
+        p[GGML_OP_GET_ROWS_BACK          ] = true;
+        p[GGML_OP_DIAG_MASK_INF          ] = true;
+        p[GGML_OP_DIAG_MASK_ZERO         ] = true;
+        p[GGML_OP_CONV_1D_S1_PH          ] = true;
+        p[GGML_OP_CONV_1D_S2_PH          ] = true;
+        p[GGML_OP_CONV_2D_SK_P0          ] = true;
+        p[GGML_OP_FLASH_ATTN_BACK        ] = true;
+        p[GGML_OP_CROSS_ENTROPY_LOSS     ] = true;
+    }
+
+    {   // FINALIZE
+        bool * p = GGML_OP_HAS_FINALIZE;
+
+        p[GGML_OP_CROSS_ENTROPY_LOSS     ] = true;
+    }
+}
+
 //
 // ggml context
 //
@@ -4267,6 +4302,8 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
         ggml_cl_init();
 #endif
 
+        ggml_setup_op_has_task_pass();
+
         is_first_call = false;
     }
 
@@ -16684,7 +16721,8 @@ typedef pthread_t ggml_thread_t;
 
 #endif
 
-#ifdef __linux__
+// Android's libc implementation "bionic" does not support setting affinity
+#if defined(__linux__) && !defined(__BIONIC__)
 void set_numa_thread_affinity(int thread_n, int n_threads) {
     if (!ggml_is_numa()) {
         return;
@@ -16790,9 +16828,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
             if (node_n != -1) {
                 /* FINALIZE */
                 struct ggml_tensor * node = state->shared->cgraph->nodes[node_n];
-                params.nth = node->n_tasks;
-                ggml_compute_forward(&params, node);
-                ggml_graph_compute_perf_stats_node(node, state->shared);
+                if (GGML_OP_HAS_FINALIZE[node->op]) {
+                    params.nth = node->n_tasks;
+                    ggml_compute_forward(&params, node);
+                    ggml_graph_compute_perf_stats_node(node, state->shared);
+                }
             }
 
             // distribute new work or execute it direct if 1T
@@ -16804,10 +16844,13 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
                 state->shared->perf_node_start_cycles  = ggml_perf_cycles();
                 state->shared->perf_node_start_time_us = ggml_perf_time_us();
 
+                params.nth = node->n_tasks;
+
                 /* INIT */
-                params.type = GGML_TASK_INIT;
-                params.nth  = node->n_tasks;
-                ggml_compute_forward(&params, node);
+                if (GGML_OP_HAS_INIT[node->op]) {
+                    params.type = GGML_TASK_INIT;
+                    ggml_compute_forward(&params, node);
+                }
 
                 if (node->n_tasks == 1) {
                     // TODO: maybe push node_n to the atomic but if other threads see n_tasks is 1,
@@ -16815,9 +16858,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
                     params.type = GGML_TASK_COMPUTE;
                     ggml_compute_forward(&params, node);
 
-                    params.type = GGML_TASK_FINALIZE;
-                    ggml_compute_forward(&params, node);
-                    ggml_graph_compute_perf_stats_node(node, state->shared);
+                    if (GGML_OP_HAS_FINALIZE[node->op]) {
+                        params.type = GGML_TASK_FINALIZE;
+                        ggml_compute_forward(&params, node);
+                        ggml_graph_compute_perf_stats_node(node, state->shared);
+                    }
                 } else {
                     break;
                 }