]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : add Q8_0 support (#2763)
authorGeorgi Gerganov <redacted>
Thu, 24 Aug 2023 13:19:57 +0000 (16:19 +0300)
committerGitHub <redacted>
Thu, 24 Aug 2023 13:19:57 +0000 (16:19 +0300)
* metal : add dequantize_q8_0 kernel

* metal : add mul_mat_q8_0_f32 kernel

* metal : add Q8_0 mul_mm kernel

ggml-metal.m
ggml-metal.metal

index 969cf7daa74c5e8d47a77dd1f83b0010b858b462..06eb3872e25e400efa07209bbac4253ea7a879ce 100644 (file)
@@ -63,6 +63,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(get_rows_f16);
     GGML_METAL_DECL_KERNEL(get_rows_q4_0);
     GGML_METAL_DECL_KERNEL(get_rows_q4_1);
+    GGML_METAL_DECL_KERNEL(get_rows_q8_0);
     GGML_METAL_DECL_KERNEL(get_rows_q2_K);
     GGML_METAL_DECL_KERNEL(get_rows_q3_K);
     GGML_METAL_DECL_KERNEL(get_rows_q4_K);
@@ -73,6 +74,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
@@ -81,6 +83,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
@@ -188,6 +191,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(get_rows_f16);
         GGML_METAL_ADD_KERNEL(get_rows_q4_0);
         GGML_METAL_ADD_KERNEL(get_rows_q4_1);
+        GGML_METAL_ADD_KERNEL(get_rows_q8_0);
         GGML_METAL_ADD_KERNEL(get_rows_q2_K);
         GGML_METAL_ADD_KERNEL(get_rows_q3_K);
         GGML_METAL_ADD_KERNEL(get_rows_q4_K);
@@ -198,6 +202,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
@@ -205,6 +210,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
@@ -747,9 +753,10 @@ void ggml_metal_graph_compute(
                                 ne00%32 == 0 &&
                                 ne11 > 1) {
                                 switch (src0->type) {
-                                    case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
+                                    case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32];  break;
                                     case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
                                     case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
+                                    case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
                                     case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
                                     case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
                                     case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
@@ -800,6 +807,15 @@ void ggml_metal_graph_compute(
                                             nth1 = 8;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
                                         } break;
+                                    case GGML_TYPE_Q8_0:
+                                        {
+                                            GGML_ASSERT(ne02 == 1);
+                                            GGML_ASSERT(ne12 == 1);
+
+                                            nth0 = 8;
+                                            nth1 = 8;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32];
+                                        } break;
                                     case GGML_TYPE_Q2_K:
                                         {
                                             GGML_ASSERT(ne02 == 1);
@@ -871,7 +887,7 @@ void ggml_metal_graph_compute(
                                 [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:16];
                                 [encoder setBytes:&gqa  length:sizeof(gqa)  atIndex:17];
 
-                                if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
+                                if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
                                     src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
@@ -896,9 +912,10 @@ void ggml_metal_graph_compute(
                     case GGML_OP_GET_ROWS:
                         {
                             switch (src0->type) {
-                                case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
+                                case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f16];  break;
                                 case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
                                 case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
+                                case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
                                 case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
                                 case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
                                 case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
index 7bc3fdf371897311707bb76d328ee7b22de5b2a1..82e1a0c7aca06e124898dc1c8caa7064bafef261 100644 (file)
@@ -18,6 +18,12 @@ typedef struct {
     uint8_t qs[QK4_1 / 2];  // nibbles / quants
 } block_q4_1;
 
+#define QK8_0 32
+typedef struct {
+    half    d;         // delta
+    int8_t  qs[QK8_0]; // quants
+} block_q8_0;
+
 kernel void kernel_add(
         device const float * src0,
         device const float * src1,
@@ -357,7 +363,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
     const int first_row = (r0 * nsg + sgitg) * nr;
     const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
     device const block_q_type * x = (device const block_q_type *) src0 + offset0;
-    device const float      * y = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
+    device const float        * y = (device const float        *) src1 + r1*ne10 + im*ne00*ne1;
     float yl[16];       // src1 vector cache
     float sumf[nr]={0.f};
 
@@ -429,6 +435,68 @@ kernel void kernel_mul_mat_q4_1_f32(
      mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
 }
 
+kernel void kernel_mul_mat_q8_0_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01[[buffer(4)]],
+        constant   int64_t & ne02[[buffer(5)]],
+        constant   int64_t & ne10[[buffer(9)]],
+        constant   int64_t & ne12[[buffer(11)]],
+        constant   int64_t & ne0[[buffer(15)]],
+        constant   int64_t & ne1[[buffer(16)]],
+        constant   uint    & gqa[[buffer(17)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+    const int nr  = N_DST;
+    const int nsg = N_SIMDGROUP;
+    const int nw  = N_SIMDWIDTH;
+
+    const int nb = ne00/QK8_0;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+    const int first_row = (r0 * nsg + sgitg) * nr;
+    const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
+    device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
+    device const float      * y = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float yl[16];
+    float sumf[nr]={0.f};
+
+    const int ix = tiisg/2;
+    const int il = tiisg%2;
+
+    device const float * yb = y + ix * QK8_0 + 16*il;
+
+    // each thread in a SIMD group deals with half a block.
+    for (int ib = ix; ib < nb; ib += nw/2) {
+        for (int i = 0; i < 16; ++i) {
+            yl[i] = yb[i];
+        }
+
+        for (int row = 0; row < nr; row++) {
+            device const int8_t * qs = x[ib+row*nb].qs + 16*il;
+            float sumq = 0.f;
+            for (int iq = 0; iq < 16; ++iq) {
+                sumq += qs[iq] * yl[iq];
+            }
+            sumf[row] += sumq*x[ib+row*nb].d;
+        }
+
+        yb += QK8_0 * 16;
+    }
+
+    for (int row = 0; row < nr; ++row) {
+        const float tot = simd_sum(sumf[row]);
+        if (tiisg == 0 && first_row + row < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
+        }
+    }
+}
+
 kernel void kernel_mul_mat_f16_f32(
         device const  char * src0,
         device const  char * src1,
@@ -480,7 +548,6 @@ kernel void kernel_mul_mat_f16_f32(
     }
 }
 
-
 kernel void kernel_alibi_f32(
         device const float * src0,
         device       float * dst,
@@ -1621,12 +1688,12 @@ template <typename type4x4>
 void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
     device const uint16_t * qs = ((device const uint16_t *)xb + 1);
     const half d = il ? (xb->d / 16.h) : xb->d;
-    const half m = il ? (-8.h * 16.h) : -8.h;
+    const half m = il ? ( -8.h * 16.h) : -8.h;
     const ushort mask0 = il ? 0x00F0 : 0x000F;
     const ushort mask1 = il ? 0xF000 : 0x0F00;
 
     for (int i=0;i<8;i++) {
-        reg[i/2][2*(i%2)] = (((qs[i] & mask0)) + m) * d;
+        reg[i/2][2*(i%2)]   = (((qs[i] & mask0)     ) + m) * d;
         reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
     }
 }
@@ -1640,11 +1707,21 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
     const ushort mask1 = il ? 0xF000 : 0x0F00;
 
     for (int i=0;i<8;i++) {
-        reg[i/2][2*(i%2)] = (((qs[i] & mask0)) * d) + m;
+        reg[i/2][2*(i%2)]   = (((qs[i] & mask0)     ) * d) + m;
         reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
     }
 }
 
+template <typename type4x4>
+void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
+    device const int8_t * qs = ((device const int8_t *)xb->qs);
+    const half d = xb->d;
+
+    for (int i=0;i<16;i++) {
+        reg[i/4][i%4] = (qs[i + 16*il] * d);
+    }
+}
+
 template <typename type4x4>
 void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
     const half d = xb->d;
@@ -1947,9 +2024,10 @@ kernel void kernel_mul_mm(device const  uchar * src0,
 typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
                           constant uint64_t &, constant uint64_t &, uint, uint, uint);
 
-template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
+template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4,    1, dequantize_f16>;
 template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
 template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
 template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
@@ -1960,9 +2038,10 @@ typedef void (mat_mm_t)(device const uchar *, device const float *, device float
                              constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
                              constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);
 
-template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
+template [[host_name("kernel_mul_mm_f16_f32")]]  kernel mat_mm_t kernel_mul_mm<half4x4,    1, dequantize_f16>;
 template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
 template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
 template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;