]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : PP speedup (#3084)
authorKawrakow <redacted>
Mon, 11 Sep 2023 07:30:11 +0000 (09:30 +0200)
committerGitHub <redacted>
Mon, 11 Sep 2023 07:30:11 +0000 (10:30 +0300)
* Minor speed gains for all quantization types

* metal: faster kernel_scale via float4

* Various other speedups for "small" kernels

* metal: faster soft_max vial float4

* metal: faster diagonal infinity

Although, to me it looks like one should simply
fuse scale + diagnonal infinity + soft_max on the
KQtensor.

* Another faster f16 x f32 matrix multiply kernel

* Reverting the diag infinity change

It does work for PP, but somehow it fails for TG.
Need to look more into it.

* metal: add back faster diagonal infinity

This time more carefully

* metal : minor (readibility)

---------

Co-authored-by: Iwan Kawrakow <redacted>
Co-authored-by: Georgi Gerganov <redacted>
ggml-metal.m
ggml-metal.metal

index b577d7f6088a425155f2df3f76ea8668e7d2c8bb..4f3f14e24d9d8cf9e37b4bbf0d28edac43cfea76 100644 (file)
@@ -63,7 +63,9 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(relu);
     GGML_METAL_DECL_KERNEL(gelu);
     GGML_METAL_DECL_KERNEL(soft_max);
+    GGML_METAL_DECL_KERNEL(soft_max_4);
     GGML_METAL_DECL_KERNEL(diag_mask_inf);
+    GGML_METAL_DECL_KERNEL(diag_mask_inf_8);
     GGML_METAL_DECL_KERNEL(get_rows_f16);
     GGML_METAL_DECL_KERNEL(get_rows_q4_0);
     GGML_METAL_DECL_KERNEL(get_rows_q4_1);
@@ -77,6 +79,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(norm);
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
+    GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
@@ -218,7 +221,9 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(relu);
         GGML_METAL_ADD_KERNEL(gelu);
         GGML_METAL_ADD_KERNEL(soft_max);
+        GGML_METAL_ADD_KERNEL(soft_max_4);
         GGML_METAL_ADD_KERNEL(diag_mask_inf);
+        GGML_METAL_ADD_KERNEL(diag_mask_inf_8);
         GGML_METAL_ADD_KERNEL(get_rows_f16);
         GGML_METAL_ADD_KERNEL(get_rows_q4_0);
         GGML_METAL_ADD_KERNEL(get_rows_q4_1);
@@ -232,6 +237,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(norm);
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
+        GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
@@ -286,7 +292,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(relu);
     GGML_METAL_DEL_KERNEL(gelu);
     GGML_METAL_DEL_KERNEL(soft_max);
-    GGML_METAL_DEL_KERNEL(diag_mask_inf);
+    GGML_METAL_DEL_KERNEL(soft_max_4);
+    GGML_METAL_DEL_KERNEL(diag_mask_inf_8);
     GGML_METAL_DEL_KERNEL(get_rows_f16);
     GGML_METAL_DEL_KERNEL(get_rows_q4_0);
     GGML_METAL_DEL_KERNEL(get_rows_q4_1);
@@ -300,6 +307,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(norm);
     GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
     GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
+    GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
     GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
     GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
     GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
@@ -767,7 +775,7 @@ void ggml_metal_graph_compute(
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
                             [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
 
-                            const int64_t n = ggml_nelements(dst);
+                            const int64_t n = ggml_nelements(dst)/4;
 
                             [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
@@ -779,7 +787,7 @@ void ggml_metal_graph_compute(
                                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
 
-                                    const int64_t n = ggml_nelements(dst);
+                                    const int64_t n = ggml_nelements(dst)/4;
 
                                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                                 } break;
@@ -799,7 +807,7 @@ void ggml_metal_graph_compute(
                                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
 
-                                    const int64_t n = ggml_nelements(dst);
+                                    const int64_t n = ggml_nelements(dst)/4;
 
                                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                                 } break;
@@ -813,13 +821,16 @@ void ggml_metal_graph_compute(
                         {
                             const int nth = 32;
 
-                            [encoder setComputePipelineState:ctx->pipeline_soft_max];
+                            if (ne00%4 == 0) {
+                                [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
+                            } else {
+                                [encoder setComputePipelineState:ctx->pipeline_soft_max];
+                            }
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
                             [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
                             [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
                             [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                            [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
 
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
@@ -827,14 +838,23 @@ void ggml_metal_graph_compute(
                         {
                             const int n_past = ((int32_t *)(dst->op_params))[0];
 
-                            [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
+                            if (ne00%8 == 0) {
+                                [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf_8];
+                            } else {
+                                [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
+                            }
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
                             [encoder setBytes:&ne00   length:sizeof(ne00) atIndex:2];
                             [encoder setBytes:&ne01   length:sizeof(ne01) atIndex:3];
                             [encoder setBytes:&n_past length:sizeof(int)  atIndex:4];
 
-                            [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                            if (ne00%8 == 0) {
+                                [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                            }
+                            else {
+                                [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                            }
                         } break;
                     case GGML_OP_MUL_MAT:
                         {
@@ -881,6 +901,7 @@ void ggml_metal_graph_compute(
                             } else {
                                 int nth0 = 32;
                                 int nth1 = 1;
+                                int nrows = 1;
 
                                 // use custom matrix x vector kernel
                                 switch (src0t) {
@@ -890,8 +911,12 @@ void ggml_metal_graph_compute(
                                             nth1 = 1;
                                             if (ne11 * ne12 < 4) {
                                                 [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
+                                            } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
+                                                [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
+                                                nrows = ne11;
                                             } else {
                                                 [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
+                                                nrows = 4;
                                             }
                                         } break;
                                     case GGML_TYPE_Q4_0:
@@ -1012,7 +1037,7 @@ void ggml_metal_graph_compute(
                                 else if (src0t == GGML_TYPE_Q6_K) {
                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 } else {
-                                    int64_t ny = (ne11 + 3)/4;
+                                    int64_t ny = (ne11 + nrows - 1)/nrows;
                                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
                             }
index 7b5c21d92ab63fca5a784c3842042ad8c85a9e4a..f45b1490f5b492c2b8eea6010f4b9bf40179da49 100644 (file)
@@ -63,18 +63,18 @@ kernel void kernel_mul_row(
 }
 
 kernel void kernel_scale(
-        device const float * src0,
-        device       float * dst,
+        device const float4 * src0,
+        device       float4 * dst,
         constant     float & scale,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] * scale;
 }
 
 kernel void kernel_silu(
-        device const float * src0,
-        device       float * dst,
+        device const float4 * src0,
+        device       float4 * dst,
         uint tpig[[thread_position_in_grid]]) {
-    float x = src0[tpig];
+    device const float4 & x = src0[tpig];
     dst[tpig] = x / (1.0f + exp(-x));
 }
 
@@ -89,10 +89,10 @@ constant float GELU_COEF_A    = 0.044715f;
 constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
 
 kernel void kernel_gelu(
-    device const float * src0,
-    device       float * dst,
+    device const float4 * src0,
+    device       float4 * dst,
     uint tpig[[thread_position_in_grid]]) {
-    float x = src0[tpig];
+    device const float4 & x = src0[tpig];
 
     // BEWARE !!!
     // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
@@ -107,7 +107,6 @@ kernel void kernel_soft_max(
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
-        threadgroup float  * buf [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3   ntg[[threads_per_threadgroup]]) {
@@ -119,61 +118,67 @@ kernel void kernel_soft_max(
     device       float * pdst  = dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 
     // parallel max
-    buf[tpitg[0]] = -INFINITY;
-    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
-        buf[tpitg[0]] = MAX(buf[tpitg[0]], psrc0[i00]);
-    }
-
-    // reduce
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (uint i = ntg[0]/2; i > 0; i /= 2) {
-        if (tpitg[0] < i) {
-            buf[tpitg[0]] = MAX(buf[tpitg[0]], buf[tpitg[0] + i]);
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
+    float lmax = psrc0[tpitg[0]];
+    for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
+        lmax = MAX(lmax, psrc0[i00]);
     }
-
-    //// broadcast - not needed. There is a threadgroup barrier above in the last iteration of
-    //               the loop, and when that is done, buf[0] has the correct (synchronized) value
-    //if (tpitg[0] == 0) {
-    //    buf[0] = buf[0];
-    //}
-
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    const float max = buf[0];
+    const float max = simd_max(lmax);
 
     // parallel sum
-    buf[tpitg[0]] = 0.0f;
+    float lsum = 0.0f;
     for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
         const float exp_psrc0 = exp(psrc0[i00] - max);
-        buf[tpitg[0]] += exp_psrc0;
+        lsum += exp_psrc0;
         // Remember the result of exp here. exp is expensive, so we really do not
         // whish to compute it twice.
         pdst[i00] = exp_psrc0;
     }
 
-    // reduce
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (uint i = ntg[0]/2; i > 0; i /= 2) {
-        if (tpitg[0] < i) {
-            buf[tpitg[0]] += buf[tpitg[0] + i];
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
+    const float sum = simd_sum(lsum);
+
+    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+        pdst[i00] /= sum;
     }
+}
 
-    // broadcast - not needed, see above
-    //// broadcast
-    //if (tpitg[0] == 0) {
-    //    buf[0] = buf[0];
-    //}
+kernel void kernel_soft_max_4(
+        device const float * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
 
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
+    device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device       float4 * pdst4 = (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
 
-    const float sum = buf[0];
+    // parallel max
+    float4 lmax4 = psrc4[tpitg[0]];
+    for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
+        lmax4 = fmax(lmax4, psrc4[i00]);
+    }
+    float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
 
-    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
-        pdst[i00] /= sum;
+    const float max = simd_max(lmax);
+
+    // parallel sum
+    float4 lsum4 = 0.0f;
+    for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+        const float4 exp_psrc4 = exp(psrc4[i00] - max);
+        lsum4 += exp_psrc4;
+        pdst4[i00] = exp_psrc4;
+    }
+    float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
+
+    const float sum = simd_sum(lsum);
+
+    for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+        pdst4[i00] /= sum;
     }
 }
 
@@ -192,6 +197,33 @@ kernel void kernel_diag_mask_inf(
         dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
     } else {
         dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
+     }
+}
+
+kernel void kernel_diag_mask_inf_8(
+        device const float4 * src0,
+        device       float4 * dst,
+        constant    int64_t & ne00,
+        constant    int64_t & ne01,
+        constant        int & n_past,
+        uint3 tpig[[thread_position_in_grid]]) {
+
+    const int64_t i = 2*tpig[0];
+
+    dst[i+0] = src0[i+0];
+    dst[i+1] = src0[i+1];
+    int64_t i4 = 4*i;
+    const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
+    const int64_t i01 = i4/(ne00);      i4 -= i01*ne00;
+    const int64_t i00 = i4;
+    for (int k = 3; k >= 0; --k) {
+        if (i00 + 4 + k <= n_past + i01) {
+            break;
+        }
+        dst[i+1][k] = -INFINITY;
+        if (i00 + k > n_past + i01) {
+            dst[i][k] = -INFINITY;
+        }
     }
 }
 
@@ -616,6 +648,49 @@ kernel void kernel_mul_mat_f16_f32(
     }
 }
 
+// Assumes row size (ne00) is a multiple of 4
+kernel void kernel_mul_mat_f16_f32_l4(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]]) {
+
+    const int nrows = ne11;
+    const int64_t r0 = tgpig.x;
+    const int64_t im = tgpig.z;
+
+    device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+
+    for (int r1 = 0; r1 < nrows; ++r1) {
+        device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
+
+        float sumf = 0;
+        for (int i = tiisg; i < ne00/4; i += 32) {
+            for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+        }
+
+        float all_sum = simd_sum(sumf);
+        if (tiisg == 0) {
+            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+        }
+    }
+}
+
 kernel void kernel_alibi_f32(
         device const float * src0,
         device       float * dst,
@@ -1800,29 +1875,34 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
 
 template <typename type4x4>
 void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
+
     device const uint16_t * qs = ((device const uint16_t *)xb + 1);
-    const half d = il ? (xb->d / 16.h) : xb->d;
-    const half m = il ? ( -8.h * 16.h) : -8.h;
+    const float d1 = il ? (xb->d / 16.h) : xb->d;
+    const float d2 = d1 / 256.f;
+    const float md = -8.h * xb->d;
     const ushort mask0 = il ? 0x00F0 : 0x000F;
-    const ushort mask1 = il ? 0xF000 : 0x0F00;
+    const ushort mask1 = mask0 << 8;
 
     for (int i=0;i<8;i++) {
-        reg[i/2][2*(i%2)]   = (((qs[i] & mask0)     ) + m) * d;
-        reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
+        reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
+        reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
     }
+
 }
 
 template <typename type4x4>
 void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
+
     device const uint16_t * qs = ((device const uint16_t *)xb + 2);
-    const half d = il ? (xb->d / 16.h) : xb->d;
-    const half m = xb->m;
+    const float d1 = il ? (xb->d / 16.h) : xb->d;
+    const float d2 = d1 / 256.f;
+    const float  m = xb->m;
     const ushort mask0 = il ? 0x00F0 : 0x000F;
-    const ushort mask1 = il ? 0xF000 : 0x0F00;
+    const ushort mask1 = mask0 << 8;
 
     for (int i=0;i<8;i++) {
-        reg[i/2][2*(i%2)]   = (((qs[i] & mask0)     ) * d) + m;
-        reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
+        reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
+        reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
     }
 }
 
@@ -1858,7 +1938,7 @@ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg
 
 template <typename type4x4>
 void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
-    const float d_all = (float)(xb->d);
+    const half d_all = xb->d;
     device const uint8_t * q = (device const uint8_t *)xb->qs;
     device const uint8_t * h = (device const uint8_t *)xb->hmask;
     device const int8_t * scales = (device const int8_t *)xb->scales;
@@ -1871,17 +1951,20 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
                                  ((il/4)>0 ? 12  : 3);
     uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
     uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
-    int16_t  dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) : \
-                                 (scale_2&kmask2) | ((scale_1&kmask1) << 4);
-    float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
+    int16_t  dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
+                               : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
+    half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h);
+    const half ml = 4.h * dl;
 
-    il = (il/2)%4;
-    float   coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
-    uint8_t mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);
+    il = (il/2) & 3;
+    const half    coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+    const uint8_t mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);
+    dl *= coef;
 
     for (int i = 0; i < 16; ++i) {
-        reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i] & m) ? 0 : 4.f/coef));
+        reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
     }
+
 #else
     float    kcoef = il&1 ? 1.f/16.f : 1.f;
     uint16_t kmask = il&1 ? 0xF0     : 0x0F;
@@ -1895,31 +1978,37 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
 #endif
 }
 
+static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
+    return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
+                 : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
+}
+
 template <typename type4x4>
 void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
-    device const uint8_t * q = xb->qs;
+    device const uchar * q = xb->qs;
 
 #if QK_K == 256
-    const float d = (float)(xb->d);
-    const float min = (float)(xb->dmin);
     short is = (il/4) * 2;
     q = q + (il/4) * 32 + 16 * (il&1);
-    il = il%4;
-    const uchar4 sc = get_scale_min_k4(is, xb->scales);
-    const float dl = il<2 ? d * sc[0]   : d * sc[2]/16.h;
-    const float ml = il<2 ? min * sc[1] : min * sc[3];
+    il = il & 3;
+    const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+    const half d   = il < 2 ? xb->d : xb->d / 16.h;
+    const half min = xb->dmin;
+    const half dl = d * sc[0];
+    const half ml = min * sc[1];
 #else
     q = q + 16 * (il&1);
     device const uint8_t * s = xb->scales;
     device const half2 * dh = (device const half2 *)xb->d;
     const float2 d = (float2)dh[0];
     const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
-    const float ml = il<2 ? d[1] * (s[0]>>4)  : d[1 ]* (s[1]>>4);
+    const float ml = il<2 ? d[1] * (s[0]>>4)  : d[1* (s[1]>>4);
 #endif
     const ushort mask = il<2 ? 0x0F : 0xF0;
     for (int i = 0; i < 16; ++i) {
         reg[i/4][i%4] = dl * (q[i] & mask) - ml;
     }
+
 }
 
 template <typename type4x4>
@@ -1928,19 +2017,19 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
     device const uint8_t * qh = xb->qh;
 
 #if QK_K == 256
-    const float d = (float)(xb->d);
-    const float min = (float)(xb->dmin);
     short is = (il/4) * 2;
     q  = q + 32 * (il/4) + 16 * (il&1);
     qh = qh + 16 * (il&1);
     uint8_t ul = 1 << (il/2);
-    il = il%4;
-    const uchar4 sc = get_scale_min_k4(is, xb->scales);
-    const float dl = il<2 ? d * sc[0]   : d * sc[2]/16.h;
-    const float ml = il<2 ? min * sc[1] : min * sc[3];
+    il = il & 3;
+    const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+    const half d = il < 2 ? xb->d : xb->d / 16.h;
+    const half min = xb->dmin;
+    const half dl = d * sc[0];
+    const half ml = min * sc[1];
 
-    const ushort mask   = il<2 ? 0x0F : 0xF0;
-    const float  qh_val = il<2 ? 16.f : 256.f;
+    const ushort mask = il<2 ? 0x0F : 0xF0;
+    const half qh_val = il<2 ? 16.h : 256.h;
     for (int i = 0; i < 16; ++i) {
         reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
     }
@@ -1959,7 +2048,7 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
 
 template <typename type4x4>
 void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
-    const float d_all = (float)(xb->d);
+    const half d_all = xb->d;
     device const uint8_t * ql = (device const uint8_t *)xb->ql;
     device const uint8_t * qh = (device const uint8_t *)xb->qh;
     device const int8_t * scales = (device const int8_t *)xb->scales;
@@ -1967,19 +2056,21 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
 #if QK_K == 256
     ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
     qh = qh + 32*(il/8) + 16*(il&1);
-    float sc = scales[(il%2) + 2 * ((il/2))];
-    il = (il/2)%4;
+    half sc = scales[(il%2) + 2 * ((il/2))];
+    il = (il/2) & 3;
 #else
     ql = ql + 16 * (il&1);
-    float sc = scales[il];
+    half sc = scales[il];
 #endif
+    const uint16_t  kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+    const uint16_t  kmask2 = il>1 ? 0xF0              : 0x0F;
+    const half        coef = il>1 ? 1.f/16.h          : 1.h;
+    const half ml = d_all * sc * 32.h;
+    const half dl = d_all * sc * coef;
     for (int i = 0; i < 16; ++i) {
-        uint16_t  kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
-        uint16_t  kmask2 = il>1 ? 0xF0              : 0x0F;
-        const float coef = il>1 ? 1.f/16.f          : 1.f;
-        float q = il&1 ? ((ql[i]&kmask2)|((qh[i]&kmask1)<<2)) - 32.f/coef : \
-                         ((ql[i]&kmask2)|((qh[i]&kmask1)<<4)) - 32.f/coef;
-        reg[i/4][i%4] = d_all * sc * q * coef;
+        const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
+                            : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
+        reg[i/4][i%4] = dl * q - ml;
     }
 }