]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : Add missing unary ops Metal support (#14660)
authorYavor Ivanov <redacted>
Sun, 13 Jul 2025 05:38:13 +0000 (22:38 -0700)
committerGitHub <redacted>
Sun, 13 Jul 2025 05:38:13 +0000 (08:38 +0300)
ggml/src/ggml-metal/ggml-metal.m
ggml/src/ggml-metal/ggml-metal.metal

index 83a0739809a6e53b0ea1a4f92c07c4b47857a797..44ddc69d08f1ce752e2662cdfa351cac72c7c118 100644 (file)
@@ -173,6 +173,12 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_SILU,
     GGML_METAL_KERNEL_TYPE_SILU_4,
     GGML_METAL_KERNEL_TYPE_ELU,
+    GGML_METAL_KERNEL_TYPE_ABS,
+    GGML_METAL_KERNEL_TYPE_SGN,
+    GGML_METAL_KERNEL_TYPE_STEP,
+    GGML_METAL_KERNEL_TYPE_HARDSWISH,
+    GGML_METAL_KERNEL_TYPE_HARDSIGMOID,
+    GGML_METAL_KERNEL_TYPE_EXP,
     GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
     GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
     GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
@@ -1155,6 +1161,12 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU,                            silu,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4,                          silu_4,                          true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU,                             elu,                             true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ABS,                             abs,                             true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SGN,                             sgn,                             true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_STEP,                            step,                            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_HARDSWISH,                       hardswish,                       true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_HARDSIGMOID,                     hardsigmoid,                     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_EXP,                             exp,                             true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,                    soft_max_f16,                    has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,                  soft_max_f16_4,                  has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,                    soft_max_f32,                    has_simdgroup_reduction);
@@ -1688,6 +1700,12 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                 case GGML_UNARY_OP_SILU:
                 case GGML_UNARY_OP_ELU:
                 case GGML_UNARY_OP_NEG:
+                case GGML_UNARY_OP_ABS:
+                case GGML_UNARY_OP_SGN:
+                case GGML_UNARY_OP_STEP:
+                case GGML_UNARY_OP_HARDSWISH:
+                case GGML_UNARY_OP_HARDSIGMOID:
+                case GGML_UNARY_OP_EXP:
                     return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
                 default:
                     return false;
@@ -2439,6 +2457,78 @@ static bool ggml_metal_encode_node(
 
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
+                case GGML_UNARY_OP_ABS:
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ABS].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+                case GGML_UNARY_OP_SGN:
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SGN].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+                case GGML_UNARY_OP_STEP:
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_STEP].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+                case GGML_UNARY_OP_HARDSWISH:
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSWISH].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+                case GGML_UNARY_OP_HARDSIGMOID:
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSIGMOID].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+                case GGML_UNARY_OP_EXP:
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_EXP].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
                 default:
                 {
                     GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
index 239ec31fbcb5892cc0b694f99f6f3a24ef35e517..13235e28852414436929a79611b78933a0e0daaf 100644 (file)
@@ -1199,6 +1199,51 @@ kernel void kernel_neg(
     dst[tpig] = -src0[tpig];
 }
 
+kernel void kernel_abs(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = fabs(src0[tpig]);
+}
+
+kernel void kernel_sgn(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+    dst[tpig] = (x > 0.0f) ? 1.0f : ((x < 0.0f) ? -1.0f : 0.0f);
+}
+
+kernel void kernel_step(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] > 0.0f ? 1.0f : 0.0f;
+}
+
+kernel void kernel_hardswish(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+    dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
+}
+
+kernel void kernel_hardsigmoid(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+    dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
+}
+
+kernel void kernel_exp(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = exp(src0[tpig]);
+}
+
 kernel void kernel_reglu(
         device const char * src0,
         device const char * src1,