]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
metal : move dequantize templates to beginning of MSL source (llama/0)
authorGeorgi Gerganov <redacted>
Mon, 4 Nov 2024 11:43:32 +0000 (13:43 +0200)
committerGeorgi Gerganov <redacted>
Fri, 15 Nov 2024 13:21:04 +0000 (15:21 +0200)
ggml/src/ggml-metal.metal

index 57eb34f13ac8548aa8164d8d2e09d78a24d4ce4c..3eb97663328d828b5911f561090a1bc69bf798b0 100644 (file)
@@ -12,435 +12,454 @@ using namespace metal;
 
 #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
 
-enum ggml_sort_order {
-    GGML_SORT_ORDER_ASC,
-    GGML_SORT_ORDER_DESC,
+constexpr constant static float kvalues_iq4nl_f[16] = {
+    -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
 };
 
-// general-purpose kernel for addition, subtraction, multiplication and division of two tensors
-// pros: works for non-contiguous tensors, supports broadcast across all dims
-// cons: not very efficient
-kernel void kernel_add(
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        constant  int64_t & ne00,
-        constant  int64_t & ne01,
-        constant  int64_t & ne02,
-        constant  int64_t & ne03,
-        constant uint64_t & nb00,
-        constant uint64_t & nb01,
-        constant uint64_t & nb02,
-        constant uint64_t & nb03,
-        constant  int64_t & ne10,
-        constant  int64_t & ne11,
-        constant  int64_t & ne12,
-        constant  int64_t & ne13,
-        constant uint64_t & nb10,
-        constant uint64_t & nb11,
-        constant uint64_t & nb12,
-        constant uint64_t & nb13,
-        constant  int64_t & ne0,
-        constant  int64_t & ne1,
-        constant  int64_t & ne2,
-        constant  int64_t & ne3,
-        constant uint64_t & nb0,
-        constant uint64_t & nb1,
-        constant uint64_t & nb2,
-        constant uint64_t & nb3,
-        constant  int64_t & offs,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig.z;
-    const int64_t i02 = tgpig.y;
-    const int64_t i01 = tgpig.x;
-
-    const int64_t i13 = i03 % ne13;
-    const int64_t i12 = i02 % ne12;
-    const int64_t i11 = i01 % ne11;
-
-    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
-    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
-    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + offs;
+// NOTE: this is not dequantizing - we are simply fitting the template
+template <typename type4x4>
+void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
+    float4x4 temp = *(((device float4x4 *)src));
+    for (int i = 0; i < 16; i++){
+        reg[i/4][i%4] = temp[i/4][i%4];
+    }
+}
 
-    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-        const int i10 = i0 % ne10;
-        *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10));
+template <typename type4x4>
+void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
+    half4x4 temp = *(((device half4x4 *)src));
+    for (int i = 0; i < 16; i++){
+        reg[i/4][i%4] = temp[i/4][i%4];
     }
 }
 
-kernel void kernel_sub(
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        constant  int64_t & ne00,
-        constant  int64_t & ne01,
-        constant  int64_t & ne02,
-        constant  int64_t & ne03,
-        constant uint64_t & nb00,
-        constant uint64_t & nb01,
-        constant uint64_t & nb02,
-        constant uint64_t & nb03,
-        constant  int64_t & ne10,
-        constant  int64_t & ne11,
-        constant  int64_t & ne12,
-        constant  int64_t & ne13,
-        constant uint64_t & nb10,
-        constant uint64_t & nb11,
-        constant uint64_t & nb12,
-        constant uint64_t & nb13,
-        constant  int64_t & ne0,
-        constant  int64_t & ne1,
-        constant  int64_t & ne2,
-        constant  int64_t & ne3,
-        constant uint64_t & nb0,
-        constant uint64_t & nb1,
-        constant uint64_t & nb2,
-        constant uint64_t & nb3,
-        constant  int64_t & offs,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig.z;
-    const int64_t i02 = tgpig.y;
-    const int64_t i01 = tgpig.x;
+template <typename type4x4>
+void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 1);
+    const float d1 = il ? (xb->d / 16.h) : xb->d;
+    const float d2 = d1 / 256.f;
+    const float md = -8.h * xb->d;
+    const ushort mask0 = il ? 0x00F0 : 0x000F;
+    const ushort mask1 = mask0 << 8;
 
-    const int64_t i13 = i03 % ne13;
-    const int64_t i12 = i02 % ne12;
-    const int64_t i11 = i01 % ne11;
+    for (int i=0;i<8;i++) {
+        reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
+        reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
+    }
+}
 
-    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
-    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
-    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + offs;
+template <typename type4x4>
+void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 2);
+    const float d1 = il ? (xb->d / 16.h) : xb->d;
+    const float d2 = d1 / 256.f;
+    const float  m = xb->m;
+    const ushort mask0 = il ? 0x00F0 : 0x000F;
+    const ushort mask1 = mask0 << 8;
 
-    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-        const int i10 = i0 % ne10;
-        *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) - *((device float *)(src1_ptr + i10*nb10));
+    for (int i=0;i<8;i++) {
+        reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
+        reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
     }
 }
 
-kernel void kernel_mul(
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        constant  int64_t & ne00,
-        constant  int64_t & ne01,
-        constant  int64_t & ne02,
-        constant  int64_t & ne03,
-        constant uint64_t & nb00,
-        constant uint64_t & nb01,
-        constant uint64_t & nb02,
-        constant uint64_t & nb03,
-        constant  int64_t & ne10,
-        constant  int64_t & ne11,
-        constant  int64_t & ne12,
-        constant  int64_t & ne13,
-        constant uint64_t & nb10,
-        constant uint64_t & nb11,
-        constant uint64_t & nb12,
-        constant uint64_t & nb13,
-        constant  int64_t & ne0,
-        constant  int64_t & ne1,
-        constant  int64_t & ne2,
-        constant  int64_t & ne3,
-        constant uint64_t & nb0,
-        constant uint64_t & nb1,
-        constant uint64_t & nb2,
-        constant uint64_t & nb3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig.z;
-    const int64_t i02 = tgpig.y;
-    const int64_t i01 = tgpig.x;
+template <typename type4x4>
+void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 3);
+    const float d = xb->d;
+    const float md = -16.h * xb->d;
+    const ushort mask = il ? 0x00F0 : 0x000F;
 
-    const int64_t i13 = i03 % ne13;
-    const int64_t i12 = i02 % ne12;
-    const int64_t i11 = i01 % ne11;
+    const uint32_t qh = *((device const uint32_t *)xb->qh);
 
-    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
-    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
-    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;
+    const int x_mv = il ? 4 : 0;
 
-    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-        const int i10 = i0 % ne10;
-        *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10));
+    const int gh_mv = il ? 12 : 0;
+    const int gh_bk = il ?  0 : 4;
+
+    for (int i = 0; i < 8; i++) {
+        // extract the 5-th bits for x0 and x1
+        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
+        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+        // combine the 4-bits from qs with the 5th bit
+        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
+        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+        reg[i/2][2*(i%2)+0] = d * x0 + md;
+        reg[i/2][2*(i%2)+1] = d * x1 + md;
     }
 }
 
-kernel void kernel_div(
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        constant  int64_t & ne00,
-        constant  int64_t & ne01,
-        constant  int64_t & ne02,
-        constant  int64_t & ne03,
-        constant uint64_t & nb00,
-        constant uint64_t & nb01,
-        constant uint64_t & nb02,
-        constant uint64_t & nb03,
-        constant  int64_t & ne10,
-        constant  int64_t & ne11,
-        constant  int64_t & ne12,
-        constant  int64_t & ne13,
-        constant uint64_t & nb10,
-        constant uint64_t & nb11,
-        constant uint64_t & nb12,
-        constant uint64_t & nb13,
-        constant  int64_t & ne0,
-        constant  int64_t & ne1,
-        constant  int64_t & ne2,
-        constant  int64_t & ne3,
-        constant uint64_t & nb0,
-        constant uint64_t & nb1,
-        constant uint64_t & nb2,
-        constant uint64_t & nb3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig.z;
-    const int64_t i02 = tgpig.y;
-    const int64_t i01 = tgpig.x;
+template <typename type4x4>
+void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 4);
+    const float d = xb->d;
+    const float m = xb->m;
+    const ushort mask = il ? 0x00F0 : 0x000F;
 
-    const int64_t i13 = i03 % ne13;
-    const int64_t i12 = i02 % ne12;
-    const int64_t i11 = i01 % ne11;
+    const uint32_t qh = *((device const uint32_t *)xb->qh);
 
-    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
-    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
-    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;
+    const int x_mv = il ? 4 : 0;
 
-    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-        const int i10 = i0 % ne10;
-        *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10));
-    }
-}
+    const int gh_mv = il ? 12 : 0;
+    const int gh_bk = il ?  0 : 4;
 
-template<typename T>
-kernel void kernel_repeat(
-        device const char * src0,
-        device       char * dst,
-        constant  int64_t & ne00,
-        constant  int64_t & ne01,
-        constant  int64_t & ne02,
-        constant  int64_t & ne03,
-        constant uint64_t & nb00,
-        constant uint64_t & nb01,
-        constant uint64_t & nb02,
-        constant uint64_t & nb03,
-        constant  int64_t & ne0,
-        constant  int64_t & ne1,
-        constant  int64_t & ne2,
-        constant  int64_t & ne3,
-        constant uint64_t & nb0,
-        constant uint64_t & nb1,
-        constant uint64_t & nb2,
-        constant uint64_t & nb3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i3 = tgpig.z;
-    const int64_t i2 = tgpig.y;
-    const int64_t i1 = tgpig.x;
-
-    const int64_t i03 = i3 % ne03;
-    const int64_t i02 = i2 % ne02;
-    const int64_t i01 = i1 % ne01;
+    for (int i = 0; i < 8; i++) {
+        // extract the 5-th bits for x0 and x1
+        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
+        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
 
-    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
-    device       char * dst_ptr  = dst  +  i3*nb3  +  i2*nb2  +  i1*nb1 ;
+        // combine the 4-bits from qs with the 5th bit
+        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
+        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
 
-    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-        const int i00 = i0 % ne00;
-        *((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00));
+        reg[i/2][2*(i%2)+0] = d * x0 + m;
+        reg[i/2][2*(i%2)+1] = d * x1 + m;
     }
 }
 
-typedef decltype(kernel_repeat<float>) kernel_repeat_t;
-
-template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
-template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
-template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
-template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
-
-// assumption: src1 is a row
-// broadcast src1 into src0
-kernel void kernel_add_row(
-        device const float4 * src0,
-        device const float4 * src1,
-        device       float4 * dst,
-        constant   uint64_t & nb [[buffer(28)]],
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] + src1[tpig % nb];
-}
+template <typename type4x4>
+void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
+    device const int8_t * qs = ((device const int8_t *)xb->qs);
+    const half d = xb->d;
 
-kernel void kernel_sub_row(
-        device const float4 * src0,
-        device const float4 * src1,
-        device       float4 * dst,
-        constant   uint64_t & nb [[buffer(28)]],
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] - src1[tpig % nb];
+    for (int i = 0; i < 16; i++) {
+        reg[i/4][i%4] = (qs[i + 16*il] * d);
+    }
 }
 
-kernel void kernel_mul_row(
-        device const float4 * src0,
-        device const float4 * src1,
-        device       float4 * dst,
-        constant   uint64_t & nb  [[buffer(28)]],
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * src1[tpig % nb];
-}
+template <typename type4x4>
+void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
+    const float d = xb->d;
+    const float min = xb->dmin;
+    device const uint8_t * q = (device const uint8_t *)xb->qs;
+    float dl, ml;
+    uint8_t sc = xb->scales[il];
 
-kernel void kernel_div_row(
-        device const float4 * src0,
-        device const float4 * src1,
-        device       float4 * dst,
-        constant   uint64_t & nb  [[buffer(28)]],
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] / src1[tpig % nb];
-}
+    q = q + 32*(il/8) + 16*(il&1);
+    il = (il/2)%4;
 
-kernel void kernel_scale(
-        device const float * src0,
-        device       float * dst,
-        constant     float & scale,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * scale;
+    half  coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+    uchar mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);
+    dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
+    for (int i = 0; i < 16; ++i) {
+        reg[i/4][i%4] = dl * (q[i] & mask) - ml;
+    }
 }
 
-kernel void kernel_scale_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        constant     float  & scale,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * scale;
-}
+template <typename type4x4>
+void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
+    const half d_all = xb->d;
+    device const uint8_t * q = (device const uint8_t *)xb->qs;
+    device const uint8_t * h = (device const uint8_t *)xb->hmask;
+    device const int8_t * scales = (device const int8_t *)xb->scales;
 
-kernel void kernel_clamp(
-        device const float * src0,
-        device       float * dst,
-        constant     float & min,
-        constant     float & max,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] < min ? min : (src0[tpig] > max ? max : src0[tpig]);
-}
+    q = q + 32 * (il/8) + 16 * (il&1);
+    h = h + 16 * (il&1);
+    uint8_t m = 1 << (il/2);
+    uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
+                                 ((il/4)>0 ? 12  : 3);
+    uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
+    uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
+    int16_t  dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
+                               : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
+    float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
+    const float ml = 4.f * dl;
 
-kernel void kernel_relu(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = max(0.0f, src0[tpig]);
-}
+    il = (il/2) & 3;
+    const half    coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+    const uint8_t mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);
+    dl *= coef;
 
-kernel void kernel_sigmoid(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
+    for (int i = 0; i < 16; ++i) {
+        reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
+    }
 }
 
-kernel void kernel_tanh(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    device const float & x = src0[tpig];
-    dst[tpig] = precise::tanh(x);
+static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
+    return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
+                 : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
 }
 
-constant float GELU_COEF_A     = 0.044715f;
-constant float GELU_QUICK_COEF = -1.702f;
-constant float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
+template <typename type4x4>
+void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
+    device const uchar * q = xb->qs;
 
-kernel void kernel_gelu(
-    device const float * src0,
-    device       float * dst,
-    uint tpig[[thread_position_in_grid]]) {
-    device const float & x = src0[tpig];
+    short is = (il/4) * 2;
+    q = q + (il/4) * 32 + 16 * (il&1);
+    il = il & 3;
+    const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+    const float d   = il < 2 ? xb->d : xb->d / 16.h;
+    const float min = xb->dmin;
+    const float dl = d * sc[0];
+    const float ml = min * sc[1];
 
-    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+    const ushort mask = il<2 ? 0x0F : 0xF0;
+    for (int i = 0; i < 16; ++i) {
+        reg[i/4][i%4] = dl * (q[i] & mask) - ml;
+    }
 }
 
-kernel void kernel_gelu_4(
-    device const float4 * src0,
-    device       float4 * dst,
-    uint tpig[[thread_position_in_grid]]) {
-    device const float4 & x = src0[tpig];
-
-    // BEWARE !!!
-    // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
-    // This was observed with Falcon 7B and 40B models
-    //
-    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
-}
+template <typename type4x4>
+void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
+    device const uint8_t * q  = xb->qs;
+    device const uint8_t * qh = xb->qh;
 
-kernel void kernel_gelu_quick(
-    device const float * src0,
-    device       float * dst,
-    uint tpig[[thread_position_in_grid]]) {
-    device const float & x = src0[tpig];
+    short is = (il/4) * 2;
+    q  = q + 32 * (il/4) + 16 * (il&1);
+    qh = qh + 16 * (il&1);
+    uint8_t ul = 1 << (il/2);
+    il = il & 3;
+    const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+    const float d = il < 2 ? xb->d : xb->d / 16.f;
+    const float min = xb->dmin;
+    const float dl = d * sc[0];
+    const float ml = min * sc[1];
 
-    dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
+    const ushort mask  = il<2 ? 0x0F : 0xF0;
+    const float qh_val = il<2 ? 16.f : 256.f;
+    for (int i = 0; i < 16; ++i) {
+        reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
+    }
 }
 
-kernel void kernel_gelu_quick_4(
-    device const float4 * src0,
-    device       float4 * dst,
-    uint tpig[[thread_position_in_grid]]) {
-    device const float4 & x = src0[tpig];
-
-    dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
-}
+template <typename type4x4>
+void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
+    const half d_all = xb->d;
+    device const uint8_t * ql = (device const uint8_t *)xb->ql;
+    device const uint8_t * qh = (device const uint8_t *)xb->qh;
+    device const int8_t * scales = (device const int8_t *)xb->scales;
 
-kernel void kernel_silu(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    device const float & x = src0[tpig];
-    dst[tpig] = x / (1.0f + exp(-x));
-}
+    ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
+    qh = qh + 32*(il/8) + 16*(il&1);
+    float sc = scales[(il%2) + 2 * ((il/2))];
+    il = (il/2) & 3;
 
-kernel void kernel_silu_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    device const float4 & x = src0[tpig];
-    dst[tpig] = x / (1.0f + exp(-x));
+    const uint16_t  kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+    const uint16_t  kmask2 = il>1 ? 0xF0              : 0x0F;
+    const float       coef = il>1 ? 1.f/16.f          : 1.f;
+    const float ml = d_all * sc * 32.f;
+    const float dl = d_all * sc * coef;
+    for (int i = 0; i < 16; ++i) {
+        const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
+                            : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
+        reg[i/4][i%4] = dl * q - ml;
+    }
 }
 
-kernel void kernel_sqr(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * src0[tpig];
+template <typename type4x4>
+void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const float d = xb->d;
+    const int ib32 = il/2;
+    il = il%2;
+    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+    // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.
+    device const uint16_t * q2 = xb->qs + 4*ib32;
+    const uint32_t aux32_g = q2[0] | (q2[1] << 16);
+    const uint32_t aux32_s = q2[2] | (q2[3] << 16);
+    thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
+    const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
+    constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
+    uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
+    for (int i = 0; i < 8; ++i) {
+        reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
+    }
+    grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
+    signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
+    for (int i = 0; i < 8; ++i) {
+        reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
+    }
 }
 
-kernel void kernel_sqrt(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = sqrt(src0[tpig]);
+template <typename type4x4>
+void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const float d = xb->d;
+    const int ib32 = il/2;
+    il = il%2;
+    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+    device const uint16_t * q2 = xb->qs + 4*ib32;
+    const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
+    constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511));
+    uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9];
+    for (int i = 0; i < 8; ++i) {
+        reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
+    }
+    grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511));
+    signs = ksigns_iq2xs[q2[2*il+1] >> 9];
+    for (int i = 0; i < 8; ++i) {
+        reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
+    }
 }
 
-kernel void kernel_sin(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = sin(src0[tpig]);
+template <typename type4x4>
+void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const float d = xb->d;
+    const int ib32 = il/2;
+    il = il%2;
+    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+    device const uint8_t * q3 = xb->qs + 8*ib32;
+    device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32;
+    const uint32_t aux32 = gas[0] | (gas[1] << 16);
+    const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f;
+    constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]);
+    constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]);
+    uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127];
+    for (int i = 0; i < 4; ++i) {
+        reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
+        reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
+    }
+    grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]);
+    grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]);
+    signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127];
+    for (int i = 0; i < 4; ++i) {
+        reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
+        reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
+    }
 }
 
-kernel void kernel_cos(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = cos(src0[tpig]);
+template <typename type4x4>
+void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const float d = xb->d;
+    const int ib32 = il/2;
+    il = il%2;
+    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+    device const uint8_t * qs = xb->qs + 8*ib32;
+    device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
+    const uint8_t qh = xb->qh[ib32] >> 4*il;
+    const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf));
+    constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256)));
+    constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256)));
+    for (int i = 0; i < 4; ++i) {
+        reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
+        reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
+    }
+    grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256)));
+    grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256)));
+    for (int i = 0; i < 4; ++i) {
+        reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
+        reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
+    }
 }
 
-kernel void kernel_sum_rows(
-        device const float * src0,
-        device       float * dst,
+template <typename type4x4>
+void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const float d = xb->d;
+    const int ib32 = il/2;
+    il = il%2;
+    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+    device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
+    device const uint8_t * signs = qs + QK_K/8;
+    const uint8_t qh = xb->qh[ib32] >> 4*il;
+    const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
+    constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[0] | ((qh << 8) & 0x300)));
+    constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[1] | ((qh << 6) & 0x300)));
+    for (int i = 0; i < 8; ++i) {
+        reg[i/4+0][i%4] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i]);
+        reg[i/4+2][i%4] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i]);
+    }
+}
+
+template <typename type4x4>
+void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const int ib32 = il/2;
+    il = il%2;
+    const float d = xb->d;
+    device const uint8_t  * qs = xb->qs + 4*ib32 + 2*il;
+    device const uint16_t * qh = xb->qh;
+    const float dl = d * (2*((qh[ib32] >> 12) & 7) + 1);
+    const float ml = dl * (qh[ib32] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA);
+    const uint16_t h = qh[ib32] >> 6*il;
+    constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((h << 8) & 0x700)));
+    constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((h << 5) & 0x700)));
+    for (int i = 0; i < 4; ++i) {
+        reg[0][i] = dl * (grid1[i] & 0xf) + ml;
+        reg[1][i] = dl * (grid1[i] >>  4) + ml;
+        reg[2][i] = dl * (grid2[i] & 0xf) + ml;
+        reg[3][i] = dl * (grid2[i] >>  4) + ml;
+    }
+}
+
+template <typename type4x4>
+void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const int ib32 = il/2;
+    il = il%2;
+    device const uint16_t * sc = (device const uint16_t *)xb->scales;
+
+    iq1m_scale_t scale;
+    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+    const float d = scale.f16;
+
+    device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
+    device const uint8_t * qh = xb->qh + 2*ib32 + il;
+
+    const float dl  = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
+    const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
+    const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
+    constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
+    constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
+    for (int i = 0; i < 4; ++i) {
+        reg[0][i] = dl * (grid1[i] & 0xf) + ml1;
+        reg[1][i] = dl * (grid1[i] >>  4) + ml1;
+        reg[2][i] = dl * (grid2[i] & 0xf) + ml2;
+        reg[3][i] = dl * (grid2[i] >>  4) + ml2;
+    }
+}
+
+template <typename type4x4>
+void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
+    device const uint16_t * q4 = (device const uint16_t *)xb->qs;
+    const float d = xb->d;
+    uint32_t aux32;
+    thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
+    for (int i = 0; i < 4; ++i) {
+        aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f;
+        reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
+        reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
+        reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
+        reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
+    }
+}
+
+template <typename type4x4>
+void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const int ib32 = il/2;
+    il = il%2;
+    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+    device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32;
+    const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4);
+    const float d = (float)xb->d * (ls - 32);
+    uint32_t aux32;
+    thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
+    for (int i = 0; i < 4; ++i) {
+        aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f;
+        reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
+        reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
+        reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
+        reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
+    }
+}
+
+enum ggml_sort_order {
+    GGML_SORT_ORDER_ASC,
+    GGML_SORT_ORDER_DESC,
+};
+
+// general-purpose kernel for addition, subtraction, multiplication and division of two tensors
+// pros: works for non-contiguous tensors, supports broadcast across all dims
+// cons: not very efficient
+kernel void kernel_add(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
         constant  int64_t & ne00,
         constant  int64_t & ne01,
         constant  int64_t & ne02,
@@ -465,132 +484,446 @@ kernel void kernel_sum_rows(
         constant uint64_t & nb1,
         constant uint64_t & nb2,
         constant uint64_t & nb3,
-        uint3 tpig[[thread_position_in_grid]]) {
-    int64_t i3 = tpig.z;
-    int64_t i2 = tpig.y;
-    int64_t i1 = tpig.x;
+        constant  int64_t & offs,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig.z;
+    const int64_t i02 = tgpig.y;
+    const int64_t i01 = tgpig.x;
 
-    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
-        return;
-    }
-
-    device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
-    device       float * dst_row = (device       float *) ((device       char *) dst  + i1*nb1  + i2*nb2  + i3*nb3);
+    const int64_t i13 = i03 % ne13;
+    const int64_t i12 = i02 % ne12;
+    const int64_t i11 = i01 % ne11;
 
-    float row_sum = 0;
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
+    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + offs;
 
-    for (int64_t i0 = 0; i0 < ne00; i0++) {
-        row_sum += src_row[i0];
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        const int i10 = i0 % ne10;
+        *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10));
     }
-
-    dst_row[0] = row_sum;
 }
 
-template<typename T>
-kernel void kernel_soft_max(
-        device const  char * src0,
-        device const  char * src1,
-        device        char * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant     float & scale,
-        constant     float & max_bias,
-        constant     float & m0,
-        constant     float & m1,
-        constant  uint32_t & n_head_log2,
-        threadgroup  float * buf [[threadgroup(0)]],
-        uint  tgpig[[threadgroup_position_in_grid]],
-        uint  tpitg[[thread_position_in_threadgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint    ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = (tgpig) / (ne02*ne01);
-    const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
-    const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
+kernel void kernel_sub(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant  int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant  int64_t & ne10,
+        constant  int64_t & ne11,
+        constant  int64_t & ne12,
+        constant  int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant uint64_t & nb13,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        constant  int64_t & ne2,
+        constant  int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        constant  int64_t & offs,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig.z;
+    const int64_t i02 = tgpig.y;
+    const int64_t i01 = tgpig.x;
 
-    device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
-    device const     T * pmask = src1 != src0 ? (device const    T *) src1         + i01*ne00 : nullptr;
-    device       float * pdst  = (device       float *) dst  + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    const int64_t i13 = i03 % ne13;
+    const int64_t i12 = i02 % ne12;
+    const int64_t i11 = i01 % ne11;
 
-    float slope = 1.0f;
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
+    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + offs;
 
-    // ALiBi
-    if (max_bias > 0.0f) {
-        const int64_t h = i02;
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        const int i10 = i0 % ne10;
+        *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) - *((device float *)(src1_ptr + i10*nb10));
+    }
+}
 
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+kernel void kernel_mul(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant  int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant  int64_t & ne10,
+        constant  int64_t & ne11,
+        constant  int64_t & ne12,
+        constant  int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant uint64_t & nb13,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        constant  int64_t & ne2,
+        constant  int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig.z;
+    const int64_t i02 = tgpig.y;
+    const int64_t i01 = tgpig.x;
 
-        slope = pow(base, exp);
-    }
+    const int64_t i13 = i03 % ne13;
+    const int64_t i12 = i02 % ne12;
+    const int64_t i11 = i01 % ne11;
 
-    // parallel max
-    float lmax = -INFINITY;
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;
 
-    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        const int i10 = i0 % ne10;
+        *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10));
     }
+}
 
-    // find the max value in the block
-    float max_val = simd_max(lmax);
-    if (ntg > N_SIMDWIDTH) {
-        if (sgitg == 0) {
-            buf[tiisg] = -INFINITY;
-        }
-
-        threadgroup_barrier(mem_flags::mem_threadgroup);
+kernel void kernel_div(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant  int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant  int64_t & ne10,
+        constant  int64_t & ne11,
+        constant  int64_t & ne12,
+        constant  int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant uint64_t & nb13,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        constant  int64_t & ne2,
+        constant  int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig.z;
+    const int64_t i02 = tgpig.y;
+    const int64_t i01 = tgpig.x;
 
-        if (tiisg == 0) {
-            buf[sgitg] = max_val;
-        }
+    const int64_t i13 = i03 % ne13;
+    const int64_t i12 = i02 % ne12;
+    const int64_t i11 = i01 % ne11;
 
-        threadgroup_barrier(mem_flags::mem_threadgroup);
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;
 
-        max_val = buf[tiisg];
-        max_val = simd_max(max_val);
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        const int i10 = i0 % ne10;
+        *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10));
     }
+}
 
-    // parallel sum
-    float lsum = 0.0f;
-    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
-        lsum += exp_psrc0;
-        pdst[i00] = exp_psrc0;
-    }
+template<typename T>
+kernel void kernel_repeat(
+        device const char * src0,
+        device       char * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant  int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        constant  int64_t & ne2,
+        constant  int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
 
-    // This barrier fixes a failing test
-    // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
-    threadgroup_barrier(mem_flags::mem_none);
+    const int64_t i03 = i3 % ne03;
+    const int64_t i02 = i2 % ne02;
+    const int64_t i01 = i1 % ne01;
 
-    float sum = simd_sum(lsum);
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    device       char * dst_ptr  = dst  +  i3*nb3  +  i2*nb2  +  i1*nb1 ;
 
-    if (ntg > N_SIMDWIDTH) {
-        if (sgitg == 0) {
-            buf[tiisg] = 0.0f;
-        }
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        const int i00 = i0 % ne00;
+        *((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00));
+    }
+}
 
-        threadgroup_barrier(mem_flags::mem_threadgroup);
+typedef decltype(kernel_repeat<float>) kernel_repeat_t;
 
-        if (tiisg == 0) {
-            buf[sgitg] = sum;
-        }
+template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
+template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
+template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
+template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
 
-        threadgroup_barrier(mem_flags::mem_threadgroup);
+// assumption: src1 is a row
+// broadcast src1 into src0
+kernel void kernel_add_row(
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
+        constant   uint64_t & nb [[buffer(28)]],
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] + src1[tpig % nb];
+}
 
-        sum = buf[tiisg];
-        sum = simd_sum(sum);
+kernel void kernel_sub_row(
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
+        constant   uint64_t & nb [[buffer(28)]],
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] - src1[tpig % nb];
+}
+
+kernel void kernel_mul_row(
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
+        constant   uint64_t & nb  [[buffer(28)]],
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * src1[tpig % nb];
+}
+
+kernel void kernel_div_row(
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
+        constant   uint64_t & nb  [[buffer(28)]],
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] / src1[tpig % nb];
+}
+
+kernel void kernel_scale(
+        device const float * src0,
+        device       float * dst,
+        constant     float & scale,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * scale;
+}
+
+kernel void kernel_scale_4(
+        device const float4 * src0,
+        device       float4 * dst,
+        constant     float  & scale,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * scale;
+}
+
+kernel void kernel_clamp(
+        device const float * src0,
+        device       float * dst,
+        constant     float & min,
+        constant     float & max,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] < min ? min : (src0[tpig] > max ? max : src0[tpig]);
+}
+
+kernel void kernel_relu(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = max(0.0f, src0[tpig]);
+}
+
+kernel void kernel_sigmoid(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
+}
+
+kernel void kernel_tanh(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+    dst[tpig] = precise::tanh(x);
+}
+
+constant float GELU_COEF_A     = 0.044715f;
+constant float GELU_QUICK_COEF = -1.702f;
+constant float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
+
+kernel void kernel_gelu(
+    device const float * src0,
+    device       float * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+
+    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
+kernel void kernel_gelu_4(
+    device const float4 * src0,
+    device       float4 * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    device const float4 & x = src0[tpig];
+
+    // BEWARE !!!
+    // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
+    // This was observed with Falcon 7B and 40B models
+    //
+    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
+kernel void kernel_gelu_quick(
+    device const float * src0,
+    device       float * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+
+    dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
+}
+
+kernel void kernel_gelu_quick_4(
+    device const float4 * src0,
+    device       float4 * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    device const float4 & x = src0[tpig];
+
+    dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
+}
+
+kernel void kernel_silu(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+    dst[tpig] = x / (1.0f + exp(-x));
+}
+
+kernel void kernel_silu_4(
+        device const float4 * src0,
+        device       float4 * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float4 & x = src0[tpig];
+    dst[tpig] = x / (1.0f + exp(-x));
+}
+
+kernel void kernel_sqr(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * src0[tpig];
+}
+
+kernel void kernel_sqrt(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = sqrt(src0[tpig]);
+}
+
+kernel void kernel_sin(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = sin(src0[tpig]);
+}
+
+kernel void kernel_cos(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = cos(src0[tpig]);
+}
+
+kernel void kernel_sum_rows(
+        device const float * src0,
+        device       float * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant  int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant  int64_t & ne10,
+        constant  int64_t & ne11,
+        constant  int64_t & ne12,
+        constant  int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant uint64_t & nb13,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        constant  int64_t & ne2,
+        constant  int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        uint3 tpig[[thread_position_in_grid]]) {
+    int64_t i3 = tpig.z;
+    int64_t i2 = tpig.y;
+    int64_t i1 = tpig.x;
+
+    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+        return;
     }
 
-    const float inv_sum = 1.0f/sum;
+    device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
+    device       float * dst_row = (device       float *) ((device       char *) dst  + i1*nb1  + i2*nb2  + i3*nb3);
 
-    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        pdst[i00] *= inv_sum;
+    float row_sum = 0;
+
+    for (int64_t i0 = 0; i0 < ne00; i0++) {
+        row_sum += src_row[i0];
     }
+
+    dst_row[0] = row_sum;
 }
 
 template<typename T>
-kernel void kernel_soft_max_4(
+kernel void kernel_soft_max(
         device const  char * src0,
         device const  char * src1,
         device        char * dst,
@@ -612,12 +945,13 @@ kernel void kernel_soft_max_4(
     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
-    device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
-    device const      T * pmask = src1 != src0 ? (device const     T *) src1         + i01*ne00/4 : nullptr;
-    device       float4 * pdst4 = (device       float4 *) dst  + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
+    device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device const     T * pmask = src1 != src0 ? (device const    T *) src1         + i01*ne00 : nullptr;
+    device       float * pdst  = (device       float *) dst  + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
 
     float slope = 1.0f;
 
+    // ALiBi
     if (max_bias > 0.0f) {
         const int64_t h = i02;
 
@@ -628,14 +962,13 @@ kernel void kernel_soft_max_4(
     }
 
     // parallel max
-    float4 lmax4 = -INFINITY;
+    float lmax = -INFINITY;
 
-    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
     }
 
-    const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
-
+    // find the max value in the block
     float max_val = simd_max(lmax);
     if (ntg > N_SIMDWIDTH) {
         if (sgitg == 0) {
@@ -655,14 +988,117 @@ kernel void kernel_soft_max_4(
     }
 
     // parallel sum
-    float4 lsum4 = 0.0f;
-    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
-        lsum4 += exp_psrc4;
-        pdst4[i00] = exp_psrc4;
+    float lsum = 0.0f;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
+        lsum += exp_psrc0;
+        pdst[i00] = exp_psrc0;
     }
 
-    const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
+    // This barrier fixes a failing test
+    // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+    threadgroup_barrier(mem_flags::mem_none);
+
+    float sum = simd_sum(lsum);
+
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = sum;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        sum = buf[tiisg];
+        sum = simd_sum(sum);
+    }
+
+    const float inv_sum = 1.0f/sum;
+
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        pdst[i00] *= inv_sum;
+    }
+}
+
+template<typename T>
+kernel void kernel_soft_max_4(
+        device const  char * src0,
+        device const  char * src1,
+        device        char * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant     float & scale,
+        constant     float & max_bias,
+        constant     float & m0,
+        constant     float & m1,
+        constant  uint32_t & n_head_log2,
+        threadgroup  float * buf [[threadgroup(0)]],
+        uint  tgpig[[threadgroup_position_in_grid]],
+        uint  tpitg[[thread_position_in_threadgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint    ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = (tgpig) / (ne02*ne01);
+    const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
+    const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
+
+    device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
+    device const      T * pmask = src1 != src0 ? (device const     T *) src1         + i01*ne00/4 : nullptr;
+    device       float4 * pdst4 = (device       float4 *) dst  + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
+
+    float slope = 1.0f;
+
+    if (max_bias > 0.0f) {
+        const int64_t h = i02;
+
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slope = pow(base, exp);
+    }
+
+    // parallel max
+    float4 lmax4 = -INFINITY;
+
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+        lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
+    }
+
+    const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
+
+    float max_val = simd_max(lmax);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = -INFINITY;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = max_val;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        max_val = buf[tiisg];
+        max_val = simd_max(max_val);
+    }
+
+    // parallel sum
+    float4 lsum4 = 0.0f;
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
+        lsum4 += exp_psrc4;
+        pdst4[i00] = exp_psrc4;
+    }
+
+    const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
 
     // This barrier fixes a failing test
     // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
@@ -3339,10 +3775,6 @@ static inline int best_index_int8(int n, constant float * val, float x) {
     return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
 }
 
-constexpr constant static float kvalues_iq4nl_f[16] = {
-    -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
-};
-
 kernel void kernel_cpy_f32_iq4_nl(
         device const float * src0,
         device        void * dst,
@@ -5457,440 +5889,6 @@ kernel void kernel_mul_mv_iq4_xs_f32(
     kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
 }
 
-//============================= templates and their specializations =============================
-
-// NOTE: this is not dequantizing - we are simply fitting the template
-template <typename type4x4>
-void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
-    float4x4 temp = *(((device float4x4 *)src));
-    for (int i = 0; i < 16; i++){
-        reg[i/4][i%4] = temp[i/4][i%4];
-    }
-}
-
-template <typename type4x4>
-void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
-    half4x4 temp = *(((device half4x4 *)src));
-    for (int i = 0; i < 16; i++){
-        reg[i/4][i%4] = temp[i/4][i%4];
-    }
-}
-
-template <typename type4x4>
-void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
-    device const uint16_t * qs = ((device const uint16_t *)xb + 1);
-    const float d1 = il ? (xb->d / 16.h) : xb->d;
-    const float d2 = d1 / 256.f;
-    const float md = -8.h * xb->d;
-    const ushort mask0 = il ? 0x00F0 : 0x000F;
-    const ushort mask1 = mask0 << 8;
-
-    for (int i=0;i<8;i++) {
-        reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
-        reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
-    }
-}
-
-template <typename type4x4>
-void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
-    device const uint16_t * qs = ((device const uint16_t *)xb + 2);
-    const float d1 = il ? (xb->d / 16.h) : xb->d;
-    const float d2 = d1 / 256.f;
-    const float  m = xb->m;
-    const ushort mask0 = il ? 0x00F0 : 0x000F;
-    const ushort mask1 = mask0 << 8;
-
-    for (int i=0;i<8;i++) {
-        reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
-        reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
-    }
-}
-
-template <typename type4x4>
-void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
-    device const uint16_t * qs = ((device const uint16_t *)xb + 3);
-    const float d = xb->d;
-    const float md = -16.h * xb->d;
-    const ushort mask = il ? 0x00F0 : 0x000F;
-
-    const uint32_t qh = *((device const uint32_t *)xb->qh);
-
-    const int x_mv = il ? 4 : 0;
-
-    const int gh_mv = il ? 12 : 0;
-    const int gh_bk = il ?  0 : 4;
-
-    for (int i = 0; i < 8; i++) {
-        // extract the 5-th bits for x0 and x1
-        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
-        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
-
-        // combine the 4-bits from qs with the 5th bit
-        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
-        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
-
-        reg[i/2][2*(i%2)+0] = d * x0 + md;
-        reg[i/2][2*(i%2)+1] = d * x1 + md;
-    }
-}
-
-template <typename type4x4>
-void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
-    device const uint16_t * qs = ((device const uint16_t *)xb + 4);
-    const float d = xb->d;
-    const float m = xb->m;
-    const ushort mask = il ? 0x00F0 : 0x000F;
-
-    const uint32_t qh = *((device const uint32_t *)xb->qh);
-
-    const int x_mv = il ? 4 : 0;
-
-    const int gh_mv = il ? 12 : 0;
-    const int gh_bk = il ?  0 : 4;
-
-    for (int i = 0; i < 8; i++) {
-        // extract the 5-th bits for x0 and x1
-        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
-        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
-
-        // combine the 4-bits from qs with the 5th bit
-        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
-        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
-
-        reg[i/2][2*(i%2)+0] = d * x0 + m;
-        reg[i/2][2*(i%2)+1] = d * x1 + m;
-    }
-}
-
-template <typename type4x4>
-void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
-    device const int8_t * qs = ((device const int8_t *)xb->qs);
-    const half d = xb->d;
-
-    for (int i = 0; i < 16; i++) {
-        reg[i/4][i%4] = (qs[i + 16*il] * d);
-    }
-}
-
-template <typename type4x4>
-void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
-    const float d = xb->d;
-    const float min = xb->dmin;
-    device const uint8_t * q = (device const uint8_t *)xb->qs;
-    float dl, ml;
-    uint8_t sc = xb->scales[il];
-
-    q = q + 32*(il/8) + 16*(il&1);
-    il = (il/2)%4;
-
-    half  coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
-    uchar mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);
-    dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
-    for (int i = 0; i < 16; ++i) {
-        reg[i/4][i%4] = dl * (q[i] & mask) - ml;
-    }
-}
-
-template <typename type4x4>
-void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
-    const half d_all = xb->d;
-    device const uint8_t * q = (device const uint8_t *)xb->qs;
-    device const uint8_t * h = (device const uint8_t *)xb->hmask;
-    device const int8_t * scales = (device const int8_t *)xb->scales;
-
-    q = q + 32 * (il/8) + 16 * (il&1);
-    h = h + 16 * (il&1);
-    uint8_t m = 1 << (il/2);
-    uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
-                                 ((il/4)>0 ? 12  : 3);
-    uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
-    uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
-    int16_t  dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
-                               : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
-    float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
-    const float ml = 4.f * dl;
-
-    il = (il/2) & 3;
-    const half    coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
-    const uint8_t mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);
-    dl *= coef;
-
-    for (int i = 0; i < 16; ++i) {
-        reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
-    }
-}
-
-static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
-    return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
-                 : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
-}
-
-template <typename type4x4>
-void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
-    device const uchar * q = xb->qs;
-
-    short is = (il/4) * 2;
-    q = q + (il/4) * 32 + 16 * (il&1);
-    il = il & 3;
-    const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
-    const float d   = il < 2 ? xb->d : xb->d / 16.h;
-    const float min = xb->dmin;
-    const float dl = d * sc[0];
-    const float ml = min * sc[1];
-
-    const ushort mask = il<2 ? 0x0F : 0xF0;
-    for (int i = 0; i < 16; ++i) {
-        reg[i/4][i%4] = dl * (q[i] & mask) - ml;
-    }
-}
-
-template <typename type4x4>
-void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
-    device const uint8_t * q  = xb->qs;
-    device const uint8_t * qh = xb->qh;
-
-    short is = (il/4) * 2;
-    q  = q + 32 * (il/4) + 16 * (il&1);
-    qh = qh + 16 * (il&1);
-    uint8_t ul = 1 << (il/2);
-    il = il & 3;
-    const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
-    const float d = il < 2 ? xb->d : xb->d / 16.f;
-    const float min = xb->dmin;
-    const float dl = d * sc[0];
-    const float ml = min * sc[1];
-
-    const ushort mask  = il<2 ? 0x0F : 0xF0;
-    const float qh_val = il<2 ? 16.f : 256.f;
-    for (int i = 0; i < 16; ++i) {
-        reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
-    }
-}
-
-template <typename type4x4>
-void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
-    const half d_all = xb->d;
-    device const uint8_t * ql = (device const uint8_t *)xb->ql;
-    device const uint8_t * qh = (device const uint8_t *)xb->qh;
-    device const int8_t * scales = (device const int8_t *)xb->scales;
-
-    ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
-    qh = qh + 32*(il/8) + 16*(il&1);
-    float sc = scales[(il%2) + 2 * ((il/2))];
-    il = (il/2) & 3;
-
-    const uint16_t  kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
-    const uint16_t  kmask2 = il>1 ? 0xF0              : 0x0F;
-    const float       coef = il>1 ? 1.f/16.f          : 1.f;
-    const float ml = d_all * sc * 32.f;
-    const float dl = d_all * sc * coef;
-    for (int i = 0; i < 16; ++i) {
-        const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
-                            : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
-        reg[i/4][i%4] = dl * q - ml;
-    }
-}
-
-template <typename type4x4>
-void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
-    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
-    const float d = xb->d;
-    const int ib32 = il/2;
-    il = il%2;
-    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
-    // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.
-    device const uint16_t * q2 = xb->qs + 4*ib32;
-    const uint32_t aux32_g = q2[0] | (q2[1] << 16);
-    const uint32_t aux32_s = q2[2] | (q2[3] << 16);
-    thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
-    const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
-    constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
-    uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
-    for (int i = 0; i < 8; ++i) {
-        reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
-    }
-    grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
-    signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
-    for (int i = 0; i < 8; ++i) {
-        reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
-    }
-}
-
-template <typename type4x4>
-void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {
-    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
-    const float d = xb->d;
-    const int ib32 = il/2;
-    il = il%2;
-    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
-    device const uint16_t * q2 = xb->qs + 4*ib32;
-    const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
-    constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511));
-    uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9];
-    for (int i = 0; i < 8; ++i) {
-        reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
-    }
-    grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511));
-    signs = ksigns_iq2xs[q2[2*il+1] >> 9];
-    for (int i = 0; i < 8; ++i) {
-        reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
-    }
-}
-
-template <typename type4x4>
-void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) {
-    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
-    const float d = xb->d;
-    const int ib32 = il/2;
-    il = il%2;
-    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
-    device const uint8_t * q3 = xb->qs + 8*ib32;
-    device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32;
-    const uint32_t aux32 = gas[0] | (gas[1] << 16);
-    const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f;
-    constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]);
-    constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]);
-    uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127];
-    for (int i = 0; i < 4; ++i) {
-        reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
-        reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
-    }
-    grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]);
-    grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]);
-    signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127];
-    for (int i = 0; i < 4; ++i) {
-        reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
-        reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
-    }
-}
-
-template <typename type4x4>
-void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) {
-    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
-    const float d = xb->d;
-    const int ib32 = il/2;
-    il = il%2;
-    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
-    device const uint8_t * qs = xb->qs + 8*ib32;
-    device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
-    const uint8_t qh = xb->qh[ib32] >> 4*il;
-    const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf));
-    constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256)));
-    constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256)));
-    for (int i = 0; i < 4; ++i) {
-        reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
-        reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
-    }
-    grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256)));
-    grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256)));
-    for (int i = 0; i < 4; ++i) {
-        reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
-        reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
-    }
-}
-
-template <typename type4x4>
-void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) {
-    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
-    const float d = xb->d;
-    const int ib32 = il/2;
-    il = il%2;
-    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
-    device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
-    device const uint8_t * signs = qs + QK_K/8;
-    const uint8_t qh = xb->qh[ib32] >> 4*il;
-    const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
-    constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[0] | ((qh << 8) & 0x300)));
-    constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[1] | ((qh << 6) & 0x300)));
-    for (int i = 0; i < 8; ++i) {
-        reg[i/4+0][i%4] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i]);
-        reg[i/4+2][i%4] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i]);
-    }
-}
-
-template <typename type4x4>
-void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
-    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
-    const int ib32 = il/2;
-    il = il%2;
-    const float d = xb->d;
-    device const uint8_t  * qs = xb->qs + 4*ib32 + 2*il;
-    device const uint16_t * qh = xb->qh;
-    const float dl = d * (2*((qh[ib32] >> 12) & 7) + 1);
-    const float ml = dl * (qh[ib32] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA);
-    const uint16_t h = qh[ib32] >> 6*il;
-    constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((h << 8) & 0x700)));
-    constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((h << 5) & 0x700)));
-    for (int i = 0; i < 4; ++i) {
-        reg[0][i] = dl * (grid1[i] & 0xf) + ml;
-        reg[1][i] = dl * (grid1[i] >>  4) + ml;
-        reg[2][i] = dl * (grid2[i] & 0xf) + ml;
-        reg[3][i] = dl * (grid2[i] >>  4) + ml;
-    }
-}
-
-template <typename type4x4>
-void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) {
-    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
-    const int ib32 = il/2;
-    il = il%2;
-    device const uint16_t * sc = (device const uint16_t *)xb->scales;
-
-    iq1m_scale_t scale;
-    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
-    const float d = scale.f16;
-
-    device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
-    device const uint8_t * qh = xb->qh + 2*ib32 + il;
-
-    const float dl  = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
-    const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
-    const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
-    constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
-    constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
-    for (int i = 0; i < 4; ++i) {
-        reg[0][i] = dl * (grid1[i] & 0xf) + ml1;
-        reg[1][i] = dl * (grid1[i] >>  4) + ml1;
-        reg[2][i] = dl * (grid2[i] & 0xf) + ml2;
-        reg[3][i] = dl * (grid2[i] >>  4) + ml2;
-    }
-}
-
-template <typename type4x4>
-void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
-    device const uint16_t * q4 = (device const uint16_t *)xb->qs;
-    const float d = xb->d;
-    uint32_t aux32;
-    thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
-    for (int i = 0; i < 4; ++i) {
-        aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f;
-        reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
-        reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
-        reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
-        reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
-    }
-}
-
-template <typename type4x4>
-void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
-    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
-    const int ib32 = il/2;
-    il = il%2;
-    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
-    device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32;
-    const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4);
-    const float d = (float)xb->d * (ls - 32);
-    uint32_t aux32;
-    thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
-    for (int i = 0; i < 4; ++i) {
-        aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f;
-        reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
-        reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
-        reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
-        reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
-    }
-}
-
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
 kernel void kernel_get_rows_q(
         device const  void * src0,