]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
Add ability to use Q5_0, Q5_1, and IQ4_NL for quantized K cache (llama/6183)
authorKawrakow <redacted>
Thu, 21 Mar 2024 07:27:57 +0000 (08:27 +0100)
committerGeorgi Gerganov <redacted>
Wed, 27 Mar 2024 11:20:00 +0000 (13:20 +0200)
* k_cache: be able to use Q5_0

* k_cache: be able to use Q5_1 on CODA

* k_cache: be able to use Q5_0 on Metal

* k_cache: be able to use Q5_1 on Metal

* k_cache: be able to use IQ4_NL - just CUDA for now

* k_cache: be able to use IQ4_NL on Metal

* k_cache: add newly added supported types to llama-bench and CUDA supports_op

---------

Co-authored-by: Iwan Kawrakow <redacted>
src/ggml-cuda.cu
src/ggml-metal.m
src/ggml-metal.metal

index 27816a7be6ec2651871420141cc3bd0dcd867b2c..280839ea475cd92c25c2e60a5e8f4a365ad6dc68 100644 (file)
@@ -6757,6 +6757,123 @@ static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
     }
 }
 
+static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    block_q5_0 * dsti = (block_q5_0 *) cdsti;
+
+    float amax = 0.0f;
+    float vmax = 0.0f;
+
+    for (int j = 0; j < QK5_0; ++j) {
+        const float v = xi[j];
+        if (amax < fabsf(v)) {
+            amax = fabsf(v);
+            vmax = v;
+        }
+    }
+
+    const float d  = vmax / -16;
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dsti->d = d;
+
+    uint32_t qh = 0;
+    for (int j = 0; j < QK5_0/2; ++j) {
+        const float x0 = xi[0       + j]*id;
+        const float x1 = xi[QK5_0/2 + j]*id;
+
+        const uint8_t xi0 = min(31, (int8_t)(x0 + 16.5f));
+        const uint8_t xi1 = min(31, (int8_t)(x1 + 16.5f));
+
+        dsti->qs[j]  = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
+    }
+    memcpy(dsti->qh, &qh, sizeof(qh));
+}
+
+static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    block_q5_1 * dsti = (block_q5_1 *) cdsti;
+
+    float min = xi[0];
+    float max = xi[0];
+
+    for (int j = 1; j < QK5_1; ++j) {
+        const float v = xi[j];
+        min = v < min ? v : min;
+        max = v > max ? v : max;
+    }
+
+    const float d  = (max - min) / 31;
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dsti->dm.x = d;
+    dsti->dm.y = min;
+
+    uint32_t qh = 0;
+    for (int j = 0; j < QK5_1/2; ++j) {
+        const float x0 = (xi[0       + j] - min)*id;
+        const float x1 = (xi[QK5_1/2 + j] - min)*id;
+
+        const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
+        const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
+
+        dsti->qs[j]  = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
+    }
+    memcpy(dsti->qh, &qh, sizeof(qh));
+}
+
+static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
+    if (x <= val[0]) return 0;
+    if (x >= val[n-1]) return n-1;
+    int ml = 0, mu = n-1;
+    while (mu-ml > 1) {
+        int mav = (ml+mu)/2;
+        if (x < val[mav]) mu = mav; else ml = mav;
+    }
+    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
+}
+
+static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    block_iq4_nl * dsti = (block_iq4_nl *) cdsti;
+
+    float amax = 0.0f;
+    float vmax = 0.0f;
+
+    for (int j = 0; j < QK4_NL; ++j) {
+        const float v = xi[j];
+        if (amax < fabsf(v)) {
+            amax = fabsf(v);
+            vmax = v;
+        }
+    }
+
+    float d = vmax / kvalues_iq4nl[0];
+    const float id = d ? 1.0f/d : 0.0f;
+
+    float sumqx = 0, sumq2 = 0;
+    for (int j = 0; j < QK4_NL/2; ++j) {
+        const float x0 = xi[0        + j]*id;
+        const float x1 = xi[QK4_NL/2 + j]*id;
+        const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0);
+        const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1);
+        dsti->qs[j] = xi0 | (xi1 << 4);
+        const float v0 = kvalues_iq4nl[xi0];
+        const float v1 = kvalues_iq4nl[xi1];
+        const float w0 = xi[0        + j]*xi[0        + j];
+        const float w1 = xi[QK4_NL/2 + j]*xi[QK4_NL/2 + j];
+        sumqx += w0*v0*xi[j] + w1*v1*xi[QK4_NL/2 + j];
+        sumq2 += w0*v0*v0 + w1*v1*v1;
+    }
+
+    dsti->d = sumq2 > 0 ? sumqx/sumq2 : d;
+}
+
+
 template <cpy_kernel_t cpy_blck, int qk>
 static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
                                  const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -8490,6 +8607,39 @@ static void ggml_cpy_f32_q4_1_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
+static void ggml_cpy_f32_q5_0_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+    GGML_ASSERT(ne % QK5_0 == 0);
+    const int num_blocks = ne / QK5_0;
+    cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_f32_q5_1_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+    GGML_ASSERT(ne % QK5_1 == 0);
+    const int num_blocks = ne / QK5_1;
+    cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_f32_iq4_nl_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+    GGML_ASSERT(ne % QK4_NL == 0);
+    const int num_blocks = ne / QK4_NL;
+    cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
 static void ggml_cpy_f16_f16_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -10888,6 +11038,12 @@ static void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * s
         ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
         ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
+        ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
+        ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
+        ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
         ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
@@ -11304,6 +11460,15 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
                     return true;
                 }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
+                    return true;
+                }
                 if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
                     return true;
                 }
index 109e5fe6ba13dd45038a2dc7a881186c9d3b14b0..54d1f1d5a3cee85b2c30c4889a52b4c936f8e363 100644 (file)
@@ -173,8 +173,9 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
-  //GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
-  //GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
+    GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
+    GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
+    GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
     GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
     GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
     GGML_METAL_KERNEL_TYPE_CONCAT,
@@ -598,8 +599,9 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,              cpy_f32_q8_0,           true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,              cpy_f32_q4_0,           true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,              cpy_f32_q4_1,           true);
-      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,              cpy_f32_q5_0,           true);
-      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,              cpy_f32_q5_1,           true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,              cpy_f32_q5_0,           true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,              cpy_f32_q5_1,           true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,            cpy_f32_iq4_nl,         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16,               cpy_f16_f16,            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32,               cpy_f16_f32,            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT,                    concat,                 true);
@@ -739,6 +741,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
                            case GGML_TYPE_Q8_0:
                            case GGML_TYPE_Q4_0:
                            case GGML_TYPE_Q4_1:
+                           case GGML_TYPE_Q5_0:
+                           case GGML_TYPE_Q5_1:
+                           case GGML_TYPE_IQ4_NL:
                                 return true;
                            default:
                                 return false;
@@ -2431,13 +2436,14 @@ static enum ggml_status ggml_metal_graph_compute(
                                     GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
 
                                     switch (dstt) {
-                                        case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline;  break;
-                                        case GGML_TYPE_F32:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;  break;
-                                        case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
-                                        case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
-                                        case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
-                                      //case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
-                                      //case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
+                                        case GGML_TYPE_F16:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline;  break;
+                                        case GGML_TYPE_F32:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;  break;
+                                        case GGML_TYPE_Q8_0:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
+                                        case GGML_TYPE_Q4_0:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
+                                        case GGML_TYPE_Q4_1:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
+                                        case GGML_TYPE_Q5_0:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
+                                        case GGML_TYPE_Q5_1:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
+                                        case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break;
                                         default: GGML_ASSERT(false && "not implemented");
                                     };
                                 } break;
index 63de563251d8935c10668c0d5d0177a58cc436a3..5552c8b28834a605daf41632f56043eeeb03adab 100644 (file)
@@ -2388,6 +2388,242 @@ kernel void kernel_cpy_f32_q4_1(
     }
 }
 
+kernel void kernel_cpy_f32_q5_0(
+        device const float * src0,
+        device        void * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_0;
+
+    device block_q5_0 * dst_data = (device block_q5_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x*QK5_0; i00 < ne00; i00 += ntg.x*QK5_0) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+        float amax = 0.0f; // absolute max
+        float max  = 0.0f;
+
+        for (int j = 0; j < QK5_0; j++) {
+            const float v = src[j];
+            if (amax < fabs(v)) {
+                amax = fabs(v);
+                max  = v;
+            }
+        }
+
+        const float d = max / -16;
+        const float id = d ? 1.0f/d : 0.0f;
+
+        dst_data[i00/QK5_0].d = d;
+
+        uint32_t qh = 0;
+        for (int j = 0; j < QK5_0/2; ++j) {
+            const float x0 = src[0       + j]*id;
+            const float x1 = src[QK5_0/2 + j]*id;
+
+            const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
+            const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
+
+            dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+            qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
+        }
+        thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
+        for (int j = 0; j < 4; ++j) {
+            dst_data[i00/QK5_0].qh[j] = qh8[j];
+        }
+    }
+}
+
+kernel void kernel_cpy_f32_q5_1(
+        device const float * src0,
+        device        void * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_1;
+
+    device block_q5_1 * dst_data = (device block_q5_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x*QK5_1; i00 < ne00; i00 += ntg.x*QK5_1) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+        float max = src[0];
+        float min = src[0];
+
+        for (int j = 1; j < QK5_1; j++) {
+            const float v = src[j];
+            min = v < min ? v : min;
+            max = v > max ? v : max;
+        }
+
+        const float d = (max - min) / 31;
+        const float id = d ? 1.0f/d : 0.0f;
+
+        dst_data[i00/QK5_1].d = d;
+        dst_data[i00/QK5_1].m = min;
+
+        uint32_t qh = 0;
+        for (int j = 0; j < QK5_1/2; ++j) {
+            const float x0 = (src[0       + j] - min)*id;
+            const float x1 = (src[QK5_1/2 + j] - min)*id;
+
+            const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
+            const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
+
+            dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+            qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
+        }
+        thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
+        for (int j = 0; j < 4; ++j) {
+            dst_data[i00/QK5_1].qh[j] = qh8[j];
+        }
+    }
+}
+
+static inline int best_index_int8(int n, constant float * val, float x) {
+    if (x <= val[0]) return 0;
+    if (x >= val[n-1]) return n-1;
+    int ml = 0, mu = n-1;
+    while (mu-ml > 1) {
+        int mav = (ml+mu)/2;
+        if (x < val[mav]) mu = mav; else ml = mav;
+    }
+    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
+}
+
+constexpr constant static float kvalues_iq4nl_f[16] = {
+    -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
+};
+
+kernel void kernel_cpy_f32_iq4_nl(
+        device const float * src0,
+        device        void * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_NL;
+
+    device block_iq4_nl * dst_data = (device block_iq4_nl *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x*QK4_NL; i00 < ne00; i00 += ntg.x*QK4_NL) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+        float amax = 0.0f; // absolute max
+        float max  = 0.0f;
+
+        for (int j = 0; j < QK4_0; j++) {
+            const float v = src[j];
+            if (amax < fabs(v)) {
+                amax = fabs(v);
+                max  = v;
+            }
+        }
+
+        const float d = max / kvalues_iq4nl_f[0];
+        const float id = d ? 1.0f/d : 0.0f;
+
+        float sumqx = 0, sumq2 = 0;
+        for (int j = 0; j < QK4_NL/2; ++j) {
+            const float x0 = src[0        + j]*id;
+            const float x1 = src[QK4_NL/2 + j]*id;
+
+            const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
+            const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
+
+            dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4);
+
+            const float v0 = kvalues_iq4nl_f[xi0];
+            const float v1 = kvalues_iq4nl_f[xi1];
+            const float w0 = src[0        + j]*src[0        + j];
+            const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
+            sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
+            sumq2 += w0*v0*v0 + w1*v1*v1;
+
+        }
+
+        dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d;
+
+    }
+}
+
 kernel void kernel_concat(
     device  const char * src0,
     device  const char * src1,
@@ -4220,10 +4456,6 @@ void kernel_mul_mv_iq1_s_f32_impl(
     }
 }
 
-constexpr constant static float kvalues_iq4nl_f[16] = {
-    -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
-};
-
 void kernel_mul_mv_iq4_nl_f32_impl(
         device const  void * src0,
         device const float * src1,