]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
metal : add GGML_OP_REPEAT kernels (llama/7557)
authorGeorgi Gerganov <redacted>
Mon, 27 May 2024 09:10:19 +0000 (12:10 +0300)
committerGeorgi Gerganov <redacted>
Wed, 29 May 2024 10:16:38 +0000 (13:16 +0300)
ggml-ci

src/ggml-metal.m
src/ggml-metal.metal

index 15fb68fc489af1c1ca14cf74b87ba5966d5b54d0..ff9ae55aada74247a20e14c330553ffdbc85cee5 100644 (file)
@@ -35,6 +35,10 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_ROW,
     GGML_METAL_KERNEL_TYPE_DIV,
     GGML_METAL_KERNEL_TYPE_DIV_ROW,
+    GGML_METAL_KERNEL_TYPE_REPEAT_F32,
+    GGML_METAL_KERNEL_TYPE_REPEAT_F16,
+    GGML_METAL_KERNEL_TYPE_REPEAT_I32,
+    GGML_METAL_KERNEL_TYPE_REPEAT_I16,
     GGML_METAL_KERNEL_TYPE_SCALE,
     GGML_METAL_KERNEL_TYPE_SCALE_4,
     GGML_METAL_KERNEL_TYPE_CLAMP,
@@ -485,6 +489,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW,                       mul_row,                        true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV,                           div,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW,                       div_row,                        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32,                    repeat_f32,                     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16,                    repeat_f16,                     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32,                    repeat_i32,                     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16,                    repeat_i16,                     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE,                         scale,                          true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4,                       scale_4,                        true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP,                         clamp,                          true);
@@ -746,6 +754,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
         case GGML_OP_ACC:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
+        case GGML_OP_REPEAT:
         case GGML_OP_SCALE:
         case GGML_OP_CLAMP:
         case GGML_OP_SQR:
@@ -979,8 +988,6 @@ static enum ggml_status ggml_metal_graph_compute(
             switch (dst->op) {
                 case GGML_OP_CONCAT:
                     {
-                        const int64_t nb = ne00;
-
                         id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
 
                         [encoder setComputePipelineState:pipeline];
@@ -1011,7 +1018,6 @@ static enum ggml_status ggml_metal_graph_compute(
                         [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:24];
                         [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:25];
                         [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:26];
-                        [encoder setBytes:&nb   length:sizeof(nb)   atIndex:27];
 
                         const int nth = MIN(1024, ne0);
 
@@ -1021,11 +1027,14 @@ static enum ggml_status ggml_metal_graph_compute(
                 case GGML_OP_MUL:
                 case GGML_OP_DIV:
                     {
+                        GGML_ASSERT(src0t == GGML_TYPE_F32);
+                        GGML_ASSERT(src1t == GGML_TYPE_F32);
+
                         const size_t offs = 0;
 
                         bool bcast_row = false;
 
-                        int64_t nb = ne00;
+                        int64_t nb = ne00; // used by the "row" kernels
 
                         id<MTLComputePipelineState> pipeline = nil;
 
@@ -1094,6 +1103,42 @@ static enum ggml_status ggml_metal_graph_compute(
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         }
                     } break;
+                case GGML_OP_REPEAT:
+                    {
+                        id<MTLComputePipelineState> pipeline;
+
+                        switch (src0t) {
+                            case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break;
+                            case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break;
+                            case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break;
+                            case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break;
+                            default: GGML_ASSERT(false);
+                        }
+
+                        [encoder setComputePipelineState:pipeline];
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                        [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                        [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                        [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                        [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+                        [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                        [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                        [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                        [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+                        [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:10];
+                        [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:11];
+                        [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:12];
+                        [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:13];
+                        [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:14];
+                        [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:15];
+                        [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:16];
+                        [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:17];
+
+                        const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
+
+                        [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                    } break;
                 case GGML_OP_ACC:
                     {
                         GGML_ASSERT(src0t == GGML_TYPE_F32);
index ce51c74d5158d8a629bfd4ca377ac4ed64ea99c4..174086b5b62939763b4afc1873d6d523914f3d49 100644 (file)
@@ -168,6 +168,53 @@ kernel void kernel_div(
     }
 }
 
+template<typename T>
+kernel void kernel_repeat(
+        device const char * src0,
+        device       char * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant  int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        constant  int64_t & ne2,
+        constant  int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3 % ne03;
+    const int64_t i02 = i2 % ne02;
+    const int64_t i01 = i1 % ne01;
+
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    device       char * dst_ptr  = dst  +  i3*nb3  +  i2*nb2  +  i1*nb1 ;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        const int i00 = i0 % ne00;
+        *((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00));
+    }
+}
+
+typedef decltype(kernel_repeat<float>) kernel_repeat_t;
+
+template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
+template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
+template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
+template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
+
 // assumption: src1 is a row
 // broadcast src1 into src0
 kernel void kernel_add_row(