]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : add Q2_K implementation (#1762)
authorKawrakow <redacted>
Thu, 8 Jun 2023 19:28:21 +0000 (22:28 +0300)
committerGitHub <redacted>
Thu, 8 Jun 2023 19:28:21 +0000 (22:28 +0300)
* metal : add Q2_K implementation

27.1 ms / token on M2 Max 30-core GPU, so about the
same speed as Q4_0. Memory throughput is ~156 GB/s.

The access pattern used in the Q2_K
CUDA implementation resulted in significantly lower
performance (~31 ms/token).

* Fixing merge conflicts

---------

Co-authored-by: Iwan Kawrakow <redacted>
ggml-metal.m
ggml-metal.metal

index 626ca871cd5729cbbda92552516c86bff76319f8..ac4f1346c8bcc80cabf4f66799016dfaecc9c5c6 100644 (file)
@@ -49,11 +49,13 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(diag_mask_inf);
     GGML_METAL_DECL_KERNEL(get_rows_f16);
     GGML_METAL_DECL_KERNEL(get_rows_q4_0);
+    GGML_METAL_DECL_KERNEL(get_rows_q2_k);
     GGML_METAL_DECL_KERNEL(get_rows_q4_k);
     GGML_METAL_DECL_KERNEL(get_rows_q6_k);
     GGML_METAL_DECL_KERNEL(rms_norm);
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q2_k_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32);
     GGML_METAL_DECL_KERNEL(rope);
@@ -137,11 +139,13 @@ struct ggml_metal_context * ggml_metal_init(void) {
         GGML_METAL_ADD_KERNEL(diag_mask_inf);
         GGML_METAL_ADD_KERNEL(get_rows_f16);
         GGML_METAL_ADD_KERNEL(get_rows_q4_0);
+        GGML_METAL_ADD_KERNEL(get_rows_q2_k);
         GGML_METAL_ADD_KERNEL(get_rows_q4_k);
         GGML_METAL_ADD_KERNEL(get_rows_q6_k);
         GGML_METAL_ADD_KERNEL(rms_norm);
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q2_k_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32);
         GGML_METAL_ADD_KERNEL(rope);
@@ -525,6 +529,15 @@ void ggml_metal_graph_compute(
                                     nth1 = 4;
                                     [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
                                 } break;
+                            case GGML_TYPE_Q2_K:
+                                {
+                                    GGML_ASSERT(ne02 == 1);
+                                    GGML_ASSERT(ne12 == 1);
+
+                                    nth0 = 4;
+                                    nth1 = 16;
+                                    [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_k_f32];
+                                } break;
                             case GGML_TYPE_Q4_K:
                                 {
                                     GGML_ASSERT(ne02 == 1);
@@ -570,6 +583,9 @@ void ggml_metal_graph_compute(
                         if (src0t == GGML_TYPE_Q4_0) {
                             [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        } else if (src0t == GGML_TYPE_Q2_K) {
+                            [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                         } else if (src0t == GGML_TYPE_Q4_K) {
                             [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -591,6 +607,7 @@ void ggml_metal_graph_compute(
                     switch (src0->type) {
                         case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
                         case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
+                        case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break;
                         case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break;
                         case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break;
                         default: GGML_ASSERT(false && "not implemented");
index e851cbd4de82b9b671886d255c771f167fe677a6..43814ed09bf9b5a7ef8a6fe37df8b8d858f825de 100644 (file)
@@ -527,6 +527,13 @@ kernel void kernel_cpy_f32_f32(
 
 #define QK_K 256
 
+typedef struct {
+    uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
+    uint8_t qs[QK_K/4];      // quants
+    half d;           // super-block scale for quantized scales
+    half dmin;        // super-block scale for quantized mins
+} block_q2_k;
+
 typedef struct {
     half d;             // super-block scale for quantized scales
     half dmin;          // super-block scale for quantized mins
@@ -555,6 +562,41 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
     return r;
 }
 
+//========================================== dequantization =============================
+
+static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d = x[i].d;
+        const float min = x[i].dmin;
+
+        device const uint8_t * q = x[i].qs;
+
+        int is = 0;
+        float dl, ml;
+        for (int n = 0; n < QK_K; n += 128) {
+            int shift = 0;
+            for (int j = 0; j < 4; ++j) {
+
+                uint8_t sc = x[i].scales[is++];
+                dl = d * (sc & 0xF); ml = min * (sc >> 4);
+                for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;
+
+                sc = x[i].scales[is++];
+                dl = d * (sc & 0xF); ml = min * (sc >> 4);
+                for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml;
+
+                shift += 2;
+            }
+            q += 32;
+        }
+
+    }
+}
+
 static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
@@ -586,12 +628,12 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i
 
     for (int i = 0; i < nb; i++) {
 
-        const float d = x[i].d;
-
         device const uint8_t * ql = x[i].ql;
         device const uint8_t * qh = x[i].qh;
         device const int8_t  * sc = x[i].scales;
 
+        const float d = x[i].d;
+
         for (int n = 0; n < QK_K; n += 128) {
             for (int l = 0; l < 32; ++l) {
                 int is = l/16;
@@ -612,6 +654,22 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i
     }
 }
 
+kernel void kernel_get_rows_q2_k(
+        device const  void * src0,
+        device const   int * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb1,
+        uint tpig[[thread_position_in_grid]]) {
+    const int i = tpig;
+    const int r = ((device int32_t *) src1)[i];
+
+    dequantize_row_q2_k(
+            (device const block_q2_k *) ((device char *) src0 + r*nb01),
+                       (device float *) ((device char *)  dst + i*nb1), ne00);
+}
+
 kernel void kernel_get_rows_q4_k(
         device const  void * src0,
         device const   int * src1,
@@ -628,6 +686,129 @@ kernel void kernel_get_rows_q4_k(
                        (device float *) ((device char *)  dst + i*nb1), ne00);
 }
 
+kernel void kernel_get_rows_q6_k(
+        device const  void * src0,
+        device const   int * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb1,
+        uint tpig[[thread_position_in_grid]]) {
+    const int i = tpig;
+    const int r = ((device int32_t *) src1)[i];
+
+    dequantize_row_q6_k(
+            (device const block_q6_k *) ((device char *) src0 + r*nb01),
+                       (device float *) ((device char *)  dst + i*nb1), ne00);
+}
+
+//====================================== dot products =========================
+
+kernel void kernel_mul_mat_q2_k_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        threadgroup float  * sum [[threadgroup(0)]],
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint2  tpig[[thread_position_in_grid]],               // we don't use this for now
+        uint2 tpitg[[thread_position_in_threadgroup]],
+        uint2  tptg[[threads_per_threadgroup]]) {
+
+    const int nb = ne00/QK_K;
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+
+    device const block_q2_k * x = (device const block_q2_k *) src0 + r0*nb;
+    device const float     * yy = (device const float      *) src1 + r1*ne10;
+
+    const int nth = tptg.x*tptg.y;
+    const int ith = tptg.y*tpitg.x + tpitg.y;
+
+
+    const int tid = tpitg.y;    // 0...16
+    const int il  = tid/4;      // 0...3
+    const int ir  = tid%4;      // 0...3
+    const int ip  = il/2;       // 0 or 1
+    const int shift1 = 4*(il%2);// 0 or 4
+    const int shift2 = shift1+2;// 2 or 6
+    const int n   = 8;
+    const int is  = 4*il + (n*ir)/16;
+
+    sum[ith] = 0.0f;
+
+    float sumf = 0;
+    for (int i = tpitg.x; i < nb; i += tptg.x) {
+
+        device const uint8_t * q = x[i].qs + 32*ip + n*ir;
+        device const uint8_t * scales = x[i].scales + is;
+
+        uint8_t d1 = scales[0] & 0xF;
+        uint8_t m1 = scales[0] >>  4;
+        uint8_t d2 = scales[2] & 0xF;
+        uint8_t m2 = scales[2] >>  4;
+
+        device const float   * y = yy + i*QK_K + 64*il + n*ir;
+
+        const float dall = (float)x[i].d;
+        const float dmin = (float)x[i].dmin;
+
+        float4 s = {0.f, 0.f, 0.f, 0.f};
+        for (int l = 0; l < n; ++l) {
+            s[0] += y[l+ 0] * ((q[l] >> shift1) & 3); s[1] += y[l+ 0];
+            s[2] += y[l+32] * ((q[l] >> shift2) & 3); s[3] += y[l+32];
+        }
+        sumf += dall * (s[0] * d1 + s[2] * d2) - dmin * (s[1] * m1 + s[3] * m2);
+
+
+    }
+    sum[ith] = sumf;
+
+    //
+    // Accumulate the sum from all threads in the threadgroup
+    // This version is slightly faster than the commented out one below,
+    // which I copy-pasted from ggerganov's q4_0 dot product for metal.
+    //
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%4 == 0) {
+        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%16 == 0) {
+        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith == 0) {
+        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+        dst[r1*ne0 + r0] = sum[0];
+    }
+
+    //// accumulate the sum from all threads in the threadgroup
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
+    //for (uint i = nth/2; i > 0; i /= 2) {
+    //    if (ith < i) {
+    //        sum[ith] += sum[ith + i];
+    //    }
+    //    threadgroup_barrier(mem_flags::mem_threadgroup);
+    //}
+
+    //if (ith == 0) {
+    //    dst[r1*ne0 + r0] = sum[0];
+    //}
+}
+
 kernel void kernel_mul_mat_q4_k_f32(
         device const  void * src0,
         device const float * src1,
@@ -724,22 +905,6 @@ kernel void kernel_mul_mat_q4_k_f32(
     //}
 }
 
-kernel void kernel_get_rows_q6_k(
-        device const  void * src0,
-        device const   int * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb1,
-        uint tpig[[thread_position_in_grid]]) {
-    const int i = tpig;
-    const int r = ((device int32_t *) src1)[i];
-
-    dequantize_row_q6_k(
-            (device const block_q6_k *) ((device char *) src0 + r*nb01),
-                       (device float *) ((device char *)  dst + i*nb1), ne00);
-}
-
 kernel void kernel_mul_mat_q6_k_f32(
         device const  void * src0,
         device const float * src1,