]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
metal: add neg operator (llama/13029)
authorJeffrey Morgan <redacted>
Sun, 20 Apr 2025 05:28:40 +0000 (22:28 -0700)
committerGeorgi Gerganov <redacted>
Thu, 24 Apr 2025 17:39:16 +0000 (20:39 +0300)
ggml/src/ggml-metal/ggml-metal.m
ggml/src/ggml-metal/ggml-metal.metal

index 85f3ae7bfdc31d42231f6b4aac1b590f87a1985c..266d8af4693c211526cd8da2d3ff581abe061f32 100644 (file)
@@ -481,6 +481,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_SQRT,
     GGML_METAL_KERNEL_TYPE_SIN,
     GGML_METAL_KERNEL_TYPE_COS,
+    GGML_METAL_KERNEL_TYPE_NEG,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
     GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
     GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
@@ -1159,6 +1160,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT,                            sqrt,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN,                             sin,                             true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS,                             cos,                             true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG,                             neg,                             true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,                        sum_rows,                        true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX,                          argmax,                          true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,                 pool_2d_avg_f32,                 true);
@@ -1320,6 +1322,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_SILU:
                 case GGML_UNARY_OP_ELU:
+                case GGML_UNARY_OP_NEG:
                     return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
                 default:
                     return false;
@@ -2010,6 +2013,18 @@ static void ggml_metal_encode_node(
 
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
+                case GGML_UNARY_OP_NEG:
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NEG].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
                 default:
                 {
                     GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
index dc7eab03ee8a29c187235ce976ac66eb6b65bd6f..8d6e99e621e9e3933b149a15e6eb1007f593ab90 100644 (file)
@@ -949,6 +949,13 @@ kernel void kernel_cos(
     dst[tpig] = cos(src0[tpig]);
 }
 
+kernel void kernel_neg(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = -src0[tpig];
+}
+
 kernel void kernel_sum_rows(
         device const float * src0,
         device       float * dst,