]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : add Q4_1 implementation (#1785)
authorKawrakow <redacted>
Sat, 10 Jun 2023 08:28:11 +0000 (11:28 +0300)
committerGitHub <redacted>
Sat, 10 Jun 2023 08:28:11 +0000 (11:28 +0300)
23.3 ms / token, so just ~1% slower than q4_0.
Achieves 290 GB/s memory throughput.

Co-authored-by: Iwan Kawrakow <redacted>
ggml-metal.m
ggml-metal.metal

index 5c9ecd76e78c07630f19b2a09ac1b3c52b9666e3..167ebd467f01f56b42a11a419f09406b806a94cd 100644 (file)
@@ -50,12 +50,14 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(diag_mask_inf);
     GGML_METAL_DECL_KERNEL(get_rows_f16);
     GGML_METAL_DECL_KERNEL(get_rows_q4_0);
+    GGML_METAL_DECL_KERNEL(get_rows_q4_1);
     GGML_METAL_DECL_KERNEL(get_rows_q2_k);
     GGML_METAL_DECL_KERNEL(get_rows_q4_k);
     GGML_METAL_DECL_KERNEL(get_rows_q6_k);
     GGML_METAL_DECL_KERNEL(rms_norm);
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q2_k_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32);
@@ -141,12 +143,14 @@ struct ggml_metal_context * ggml_metal_init(void) {
         GGML_METAL_ADD_KERNEL(diag_mask_inf);
         GGML_METAL_ADD_KERNEL(get_rows_f16);
         GGML_METAL_ADD_KERNEL(get_rows_q4_0);
+        GGML_METAL_ADD_KERNEL(get_rows_q4_1);
         GGML_METAL_ADD_KERNEL(get_rows_q2_k);
         GGML_METAL_ADD_KERNEL(get_rows_q4_k);
         GGML_METAL_ADD_KERNEL(get_rows_q6_k);
         GGML_METAL_ADD_KERNEL(rms_norm);
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q2_k_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32);
@@ -545,6 +549,15 @@ void ggml_metal_graph_compute(
                                     nth1 = 8;
                                     [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
                                 } break;
+                            case GGML_TYPE_Q4_1:
+                                {
+                                    GGML_ASSERT(ne02 == 1);
+                                    GGML_ASSERT(ne12 == 1);
+
+                                    nth0 = 8;
+                                    nth1 = 8;
+                                    [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
+                                } break;
                             case GGML_TYPE_Q2_K:
                                 {
                                     GGML_ASSERT(ne02 == 1);
@@ -596,7 +609,7 @@ void ggml_metal_graph_compute(
                         [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:13];
                         [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:14];
 
-                        if (src0t == GGML_TYPE_Q4_0) {
+                        if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) {
                             [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                         } else if (src0t == GGML_TYPE_Q2_K) {
@@ -623,6 +636,7 @@ void ggml_metal_graph_compute(
                     switch (src0->type) {
                         case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
                         case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
+                        case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
                         case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break;
                         case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break;
                         case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break;
index c94ef83f9e5e0f3d1816c1d6dd74eeb554bb26e3..ccd36386b583243335faa7849ab262ccd3c335e7 100644 (file)
@@ -11,6 +11,13 @@ typedef struct {
     uint8_t qs[QK4_0 / 2]; // nibbles / quants
 } block_q4_0;
 
+#define QK4_1 32
+typedef struct {
+    half d;          // delta
+    half m;          // min
+    uint8_t qs[QK4_1 / 2];  // nibbles / quants
+} block_q4_1;
+
 static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) {
     const int qk = QK4_0;
 
@@ -31,6 +38,27 @@ static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, i
     }
 }
 
+static void dequantize_row_q4_1(device const block_q4_1 * x, device float * y, int k) {
+    const int qk = QK4_1;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        const half d = x[i].d;
+        const half m = x[i].m;
+
+        for (int j = 0; j < qk/2; ++j) {
+            const int x0 = (x[i].qs[j] & 0x0F);
+            const int x1 = (x[i].qs[j] >>   4);
+
+            y[i*qk + j + 0   ] = x0*d + m;
+            y[i*qk + j + qk/2] = x1*d + m;
+        }
+    }
+}
+
 kernel void kernel_add(
         device const float * src0,
         device const float * src1,
@@ -212,6 +240,22 @@ kernel void kernel_get_rows_q4_0(
                        (device float *) ((device char *)  dst + i*nb1), ne00);
 }
 
+kernel void kernel_get_rows_q4_1(
+        device const  void * src0,
+        device const   int * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb1,
+        uint tpig[[thread_position_in_grid]]) {
+    const int i = tpig;
+    const int r = ((device int32_t *) src1)[i];
+
+    dequantize_row_q4_1(
+            (device const block_q4_1 *) ((device char *) src0 + r*nb01),
+                       (device float *) ((device char *)  dst + i*nb1), ne00);
+}
+
 kernel void kernel_rms_norm(
         device const  void * src0,
         device       float * dst,
@@ -350,6 +394,85 @@ kernel void kernel_mul_mat_q4_0_f32(
     //}
 }
 
+kernel void kernel_mul_mat_q4_1_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        threadgroup float  * sum [[threadgroup(0)]],
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint2  tpig[[thread_position_in_grid]],
+        uint2 tpitg[[thread_position_in_threadgroup]],
+        uint2  tptg[[threads_per_threadgroup]]) {
+    const int nb = ne00/QK4_1;
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+
+    device const block_q4_1 * x = (device const block_q4_1 *) src0 + r0*nb;
+    device const float      * y = (device const float      *) src1 + r1*ne10;
+
+    const uint nth = tptg.x*tptg.y;
+    const uint ith = tptg.y*tpitg.x + tpitg.y;
+
+    const int ix = tpitg.y/4;           // 0 or 1
+    const int iy = tpitg.y - 4*ix;      // 0...3
+
+    const int first = 4 * iy;
+
+    float sumf = 0;
+
+    for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {
+
+        const float d = (float)x[i].d;
+        const float m = (float)x[i].m;
+
+        device const uint8_t * xl = x[i].qs + first;
+        device const float   * yl = y + i * QK4_1 + first;
+
+        float2 acc = {0.0f, 0.0f};
+
+        for (int j = 0; j < 4; ++j) {
+
+            acc[0] += yl[j+ 0] * (d * (xl[j] & 0xF) + m);
+            acc[1] += yl[j+16] * (d * (xl[j] >>  4) + m);
+
+        }
+
+        sumf += acc[0] + acc[1];
+    }
+
+    sum[ith] = sumf;
+
+    //
+    // Accumulate the sum from all threads in the threadgroup
+    //
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%4 == 0) {
+        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%16 == 0) {
+        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith == 0) {
+        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+        dst[r1*ne0 + r0] = sum[0];
+    }
+}
+
 kernel void kernel_mul_mat_f16_f32(
         device const  char * src0,
         device const  char * src1,