]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : add GELU implementation (#1770)
authorAT <redacted>
Fri, 9 Jun 2023 08:00:51 +0000 (04:00 -0400)
committerGitHub <redacted>
Fri, 9 Jun 2023 08:00:51 +0000 (11:00 +0300)
Co-authored-by: Adam Treat <redacted>
ggml-metal.m
ggml-metal.metal

index 54cbaf860d0c850f076ce2b7de701f4bad63c038..5c9ecd76e78c07630f19b2a09ac1b3c52b9666e3 100644 (file)
@@ -45,6 +45,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(scale);
     GGML_METAL_DECL_KERNEL(silu);
     GGML_METAL_DECL_KERNEL(relu);
+    GGML_METAL_DECL_KERNEL(gelu);
     GGML_METAL_DECL_KERNEL(soft_max);
     GGML_METAL_DECL_KERNEL(diag_mask_inf);
     GGML_METAL_DECL_KERNEL(get_rows_f16);
@@ -135,6 +136,7 @@ struct ggml_metal_context * ggml_metal_init(void) {
         GGML_METAL_ADD_KERNEL(scale);
         GGML_METAL_ADD_KERNEL(silu);
         GGML_METAL_ADD_KERNEL(relu);
+        GGML_METAL_ADD_KERNEL(gelu);
         GGML_METAL_ADD_KERNEL(soft_max);
         GGML_METAL_ADD_KERNEL(diag_mask_inf);
         GGML_METAL_ADD_KERNEL(get_rows_f16);
@@ -420,6 +422,20 @@ void ggml_metal_graph_compute(
 
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
+            case GGML_OP_GELU:
+            {
+                    if (encoder == nil) {
+                        encoder = [command_buffer computeCommandEncoder];
+                    }
+
+                    [encoder setComputePipelineState:ctx->pipeline_gelu];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
             case GGML_OP_SOFT_MAX:
                 {
                     if (encoder == nil) {
index 8e730eb9c33e5d2fbe08a2c1bdd607e98f5e8226..745fe8ad30cd7f74005a3edd13202532f60ba873 100644 (file)
@@ -81,6 +81,17 @@ kernel void kernel_relu(
     dst[tpig] = max(0.0f, src0[tpig]);
 }
 
+constant float GELU_COEF_A    = 0.044715f;
+constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+
+kernel void kernel_gelu(
+    device const float * src0,
+    device       float * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    float x = src0[tpig];
+    dst[tpig] = 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
 kernel void kernel_soft_max(
         device const float * src0,
         device       float * dst,