metal : add `GGML_UNARY_OP_ELU` kernel (ggml/1018)
author PAB <redacted>
Mon, 18 Nov 2024 09:02:49 +0000 (10:02 +0100)
committer Georgi Gerganov <redacted>
Tue, 19 Nov 2024 18:03:21 +0000 (20:03 +0200)
ggml/src/ggml-metal/ggml-metal.m
ggml/src/ggml-metal/ggml-metal.metal
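
The new kernel applies the standard ELU activation (alpha = 1) elementwise to a contiguous f32 tensor, matching the expression in kernel_elu below:

\[
\operatorname{ELU}(x) =
\begin{cases}
x, & x > 0 \\
e^{x} - 1, & x \le 0
\end{cases}
\]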

diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index 58fee4bfd1296cd335e70e01ffc2c3f19bfc4943..d1abb3cef0ec4a375755e3ae1cf7d3e0d98f0846 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -126,6 +126,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
     GGML_METAL_KERNEL_TYPE_SILU,
     GGML_METAL_KERNEL_TYPE_SILU_4,
+    GGML_METAL_KERNEL_TYPE_ELU,
     GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
     GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
     GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
@@ -649,6 +650,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,                  gelu_quick_4,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU,                          silu,                           true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4,                        silu_4,                         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU,                           elu,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,                  soft_max_f16,                   has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,                soft_max_f16_4,                 has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,                  soft_max_f32,                   has_simdgroup_reduction);
@@ -968,6 +970,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                 case GGML_UNARY_OP_GELU:
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_SILU:
+                case GGML_UNARY_OP_ELU:
                     return ggml_is_contiguous(op->src[0]);
                 default:
                     return false;
@@ -1589,6 +1592,18 @@ static void ggml_metal_encode_node(
 
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
+                case GGML_UNARY_OP_ELU:
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ELU].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
                 default:
                 {
                     GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
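
For context (not part of this commit), a minimal sketch of how the dispatch path above gets exercised from the public ggml C API: ggml_elu() creates a GGML_UNARY_OP_ELU node, ggml_metal_supports_op() accepts it for contiguous inputs, and ggml_metal_encode_node() encodes the pipeline added above. The API names below are the usual ggml/ggml-backend ones and may differ slightly between ggml revisions.

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-metal.h"

int main(void) {
    // small context that only holds tensor/graph metadata;
    // the actual data lives in the backend buffer (no_alloc = true)
    struct ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead()*8 + ggml_graph_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024);
    struct ggml_tensor * y = ggml_elu(ctx, x); // lowers to GGML_UNARY_OP_ELU

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, y);

    ggml_backend_t backend = ggml_backend_metal_init();
    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);

    // ... fill x with ggml_backend_tensor_set(x, data, 0, ggml_nbytes(x)) ...

    ggml_backend_graph_compute(backend, gf); // encodes and runs kernel_elu

    ggml_backend_buffer_free(buf);
    ggml_backend_free(backend);
    ggml_free(ctx);
    return 0;
}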
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 86fdf1c18cfb692f82a6c2400d6712777e8c9f27..819b20ba89bec6ffc79696680306716e6d4eb324 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -782,6 +782,14 @@ kernel void kernel_silu_4(
     dst[tpig] = x / (1.0f + exp(-x));
 }
 
+kernel void kernel_elu(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+    dst[tpig] = (x > 0.0f) ? x : (exp(x) - 1.0f);
+}
+
 kernel void kernel_sqr(
         device const float * src0,
         device       float * dst,
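
For reference (illustrative, not part of this commit), the host-side scalar equivalent of kernel_elu, handy for cross-checking the Metal output against a CPU result:

#include <math.h>
#include <stdint.h>

// reference of what kernel_elu computes per element (ELU with alpha = 1)
static void elu_ref_f32(const float * src0, float * dst, int64_t n) {
    for (int64_t i = 0; i < n; ++i) {
        const float x = src0[i];
        dst[i] = (x > 0.0f) ? x : (expf(x) - 1.0f);
    }
}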