]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : add norm, cpy f16->f16, alibi kernels (#1823)
authorAaron Miller <redacted>
Sat, 17 Jun 2023 14:37:49 +0000 (07:37 -0700)
committerGitHub <redacted>
Sat, 17 Jun 2023 14:37:49 +0000 (17:37 +0300)
ggml-metal.m
ggml-metal.metal

index 0e9b56aa33efa5445e66b0433316f71a2ac2ac84..8148512037d44e0d2eff77751c4c0ecd630b561b 100644 (file)
@@ -57,6 +57,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(get_rows_q5_k);
     GGML_METAL_DECL_KERNEL(get_rows_q6_k);
     GGML_METAL_DECL_KERNEL(rms_norm);
+    GGML_METAL_DECL_KERNEL(norm);
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
@@ -66,8 +67,10 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mat_q5_k_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32);
     GGML_METAL_DECL_KERNEL(rope);
+    GGML_METAL_DECL_KERNEL(alibi_f32);
     GGML_METAL_DECL_KERNEL(cpy_f32_f16);
     GGML_METAL_DECL_KERNEL(cpy_f32_f32);
+    GGML_METAL_DECL_KERNEL(cpy_f16_f16);
 
 #undef GGML_METAL_DECL_KERNEL
 };
@@ -162,6 +165,7 @@ struct ggml_metal_context * ggml_metal_init(void) {
         GGML_METAL_ADD_KERNEL(get_rows_q5_k);
         GGML_METAL_ADD_KERNEL(get_rows_q6_k);
         GGML_METAL_ADD_KERNEL(rms_norm);
+        GGML_METAL_ADD_KERNEL(norm);
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
@@ -171,8 +175,10 @@ struct ggml_metal_context * ggml_metal_init(void) {
         GGML_METAL_ADD_KERNEL(mul_mat_q5_k_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32);
         GGML_METAL_ADD_KERNEL(rope);
+        GGML_METAL_ADD_KERNEL(alibi_f32);
         GGML_METAL_ADD_KERNEL(cpy_f32_f16);
         GGML_METAL_ADD_KERNEL(cpy_f32_f32);
+        GGML_METAL_ADD_KERNEL(cpy_f16_f16);
 
 #undef GGML_METAL_ADD_KERNEL
     }
@@ -735,6 +741,65 @@ void ggml_metal_graph_compute(
 
                             [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
+                    case GGML_OP_NORM:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            const float eps = 1e-5f;
+
+                            const int nth = 256;
+
+                            [encoder setComputePipelineState:ctx->pipeline_norm];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+                            [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
+                            [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
+
+                            const int64_t nrows = ggml_nrows(src0);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                        } break;
+                    case GGML_OP_ALIBI:
+                        {
+                            GGML_ASSERT((src0t == GGML_TYPE_F32));
+                            const int   n_past   = ((int32_t *) src1->data)[0];
+                            const int   n_head   = ((int32_t *) src1->data)[1];
+                            const float max_bias = ((float *)   src1->data)[2];
+                            if (__builtin_popcount(n_head) != 1) {
+                                GGML_ASSERT(false && "only power-of-two n_head implemented");
+                            }
+                            const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
+                            const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+                            [encoder setComputePipelineState:ctx->pipeline_alibi_f32];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                            [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+                            [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+                            [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+                            [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+                            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+                            [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+                            [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+                            [encoder setBytes:&ne0  length:sizeof( int64_t) atIndex:10];
+                            [encoder setBytes:&ne1  length:sizeof( int64_t) atIndex:11];
+                            [encoder setBytes:&ne2  length:sizeof( int64_t) atIndex:12];
+                            [encoder setBytes:&ne3  length:sizeof( int64_t) atIndex:13];
+                            [encoder setBytes:&nb0  length:sizeof(uint64_t) atIndex:14];
+                            [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:15];
+                            [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16];
+                            [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17];
+                            [encoder setBytes:&m0  length:sizeof(    float) atIndex:18];
+                            const int nth = 32;
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                        } break;
                     case GGML_OP_ROPE:
                         {
                             if (encoder == nil) {
@@ -788,6 +853,14 @@ void ggml_metal_graph_compute(
                                             default: GGML_ASSERT(false && "not implemented");
                                         };
                                     } break;
+                                case GGML_TYPE_F16:
+                                    {
+                                        switch (dstt) {
+                                            case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
+                                            case GGML_TYPE_F32: GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break;
+                                            default: GGML_ASSERT(false && "not implemented");
+                                        };
+                                    } break;
                                 default: GGML_ASSERT(false && "not implemented");
                             }
 
index 09e12a879a1154f75bb4a9822a4c1041b8342ae9..d1e49222db2eb6c9c7614978fc88c349b86ed5e9 100644 (file)
@@ -256,6 +256,72 @@ kernel void kernel_get_rows_q4_1(
                        (device float *) ((device char *)  dst + i*nb1), ne00);
 }
 
+kernel void kernel_norm(
+        device const  void * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant     float & eps,
+        threadgroup float  * sum [[threadgroup(0)]],
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
+    // MEAN
+    // parallel sum
+    sum[tpitg] = 0.0f;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        sum[tpitg] += x[i00];
+    }
+    // reduce
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = ntg/2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            sum[tpitg] += sum[tpitg + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+    // broadcast
+    if (tpitg == 0) {
+        sum[0] /= ne00;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    const float mean  = sum[0];
+
+    // recenter
+    device float * y = dst + tgpig*ne00;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        y[i00] = x[i00] - mean;
+    }
+
+    // VARIANCE
+    // parallel sum
+    sum[tpitg] = 0.0f;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        sum[tpitg] += y[i00] * y[i00];
+    }
+    // reduce
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = ntg/2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            sum[tpitg] += sum[tpitg + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+    // broadcast
+    if (tpitg == 0) {
+        sum[0] /= ne00;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    const float variance = sum[0];
+
+    const float scale = 1.0f/sqrt(variance + eps);
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        y[i00] = y[i00] * scale;
+    }
+}
+
+
 kernel void kernel_rms_norm(
         device const  void * src0,
         device       float * dst,
@@ -485,6 +551,48 @@ kernel void kernel_mul_mat_f16_f32(
     }
 }
 
+kernel void kernel_alibi_f32(
+        device const float * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        constant      float & m0,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+    float m_k = pow(m0, i2 + 1);
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
+    }
+}
+
 kernel void kernel_rope(
         device const  void * src0,
         device       float * dst,
@@ -540,6 +648,47 @@ kernel void kernel_rope(
     }
 }
 
+kernel void kernel_cpy_f16_f16(
+        device const half * src0,
+        device       half * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        dst_data[i00] = src[0];
+    }
+}
+
 kernel void kernel_cpy_f32_f16(
         device const float * src0,
         device        half * dst,